Skip to content

Commit b3c2aa6

Browse files
authored
PyAnextAwaitable (#6427)
1 parent db71554 commit b3c2aa6

File tree

12 files changed

+255
-37
lines changed

12 files changed

+255
-37
lines changed

.cspell.dict/python-more.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ aenter
44
aexit
55
aiter
66
anext
7+
anextawaitable
78
appendleft
89
argcount
910
arrayiterator

Lib/test/test_asyncgen.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,6 @@ async def async_gen_wrapper():
375375

376376
self.compare_generators(sync_gen_wrapper(), async_gen_wrapper())
377377

378-
# TODO: RUSTPYTHON
379-
@unittest.expectedFailure
380378
def test_async_gen_api_01(self):
381379
async def gen():
382380
yield 123
@@ -467,16 +465,12 @@ async def test_throw():
467465
result = self.loop.run_until_complete(test_throw())
468466
self.assertEqual(result, "completed")
469467

470-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
471-
@unittest.expectedFailure
472468
def test_async_generator_anext(self):
473469
async def agen():
474470
yield 1
475471
yield 2
476472
self.check_async_iterator_anext(agen)
477473

478-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
479-
@unittest.expectedFailure
480474
def test_python_async_iterator_anext(self):
481475
class MyAsyncIter:
482476
"""Asynchronously yield 1, then 2."""
@@ -492,8 +486,6 @@ async def __anext__(self):
492486
return self.yielded
493487
self.check_async_iterator_anext(MyAsyncIter)
494488

495-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
496-
@unittest.expectedFailure
497489
def test_python_async_iterator_types_coroutine_anext(self):
498490
import types
499491
class MyAsyncIterWithTypesCoro:
@@ -523,8 +515,6 @@ async def consume():
523515
res = self.loop.run_until_complete(consume())
524516
self.assertEqual(res, [1, 2])
525517

526-
# TODO: RUSTPYTHON, NameError: name 'aiter' is not defined
527-
@unittest.expectedFailure
528518
def test_async_gen_aiter_class(self):
529519
results = []
530520
class Gen:
@@ -549,8 +539,6 @@ async def gen():
549539
applied_twice = aiter(applied_once)
550540
self.assertIs(applied_once, applied_twice)
551541

552-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
553-
@unittest.expectedFailure
554542
def test_anext_bad_args(self):
555543
async def gen():
556544
yield 1
@@ -571,7 +559,7 @@ async def call_with_kwarg():
571559
with self.assertRaises(TypeError):
572560
self.loop.run_until_complete(call_with_kwarg())
573561

574-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
562+
# TODO: RUSTPYTHON, error message mismatch
575563
@unittest.expectedFailure
576564
def test_anext_bad_await(self):
577565
async def bad_awaitable():
@@ -642,7 +630,7 @@ async def do_test():
642630
result = self.loop.run_until_complete(do_test())
643631
self.assertEqual(result, "completed")
644632

645-
# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
633+
# TODO: RUSTPYTHON, anext coroutine iteration issue
646634
@unittest.expectedFailure
647635
def test_anext_iter(self):
648636
@types.coroutine

Lib/test/test_generators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ def g3(): return (yield from f())
117117

118118
class GeneratorTest(unittest.TestCase):
119119

120-
@unittest.expectedFailure # TODO: RUSTPYTHON
121120
def test_name(self):
122121
def func():
123122
yield 1

Lib/test/test_types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,8 +2052,6 @@ async def corofunc():
20522052
else:
20532053
self.fail('StopIteration was expected')
20542054

2055-
# TODO: RUSTPYTHON
2056-
@unittest.expectedFailure
20572055
def test_gen(self):
20582056
def gen_func():
20592057
yield 1

crates/vm/src/builtins/asyncgenerator.rs

Lines changed: 155 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ impl PyAsyncGen {
3333
&self.inner
3434
}
3535

36-
pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
36+
pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
3737
Self {
38-
inner: Coro::new(frame, name),
38+
inner: Coro::new(frame, name, qualname),
3939
running_async: AtomicCell::new(false),
4040
}
4141
}
@@ -50,6 +50,16 @@ impl PyAsyncGen {
5050
self.inner.set_name(name)
5151
}
5252

53+
#[pygetset]
54+
fn __qualname__(&self) -> PyStrRef {
55+
self.inner.qualname()
56+
}
57+
58+
#[pygetset(setter)]
59+
fn set___qualname__(&self, qualname: PyStrRef) {
60+
self.inner.set_qualname(qualname)
61+
}
62+
5363
#[pygetset]
5464
fn ag_await(&self, _vm: &VirtualMachine) -> Option<PyObjectRef> {
5565
self.inner.frame().yield_from_target()
@@ -424,8 +434,151 @@ impl IterNext for PyAsyncGenAThrow {
424434
}
425435
}
426436

437+
/// Awaitable wrapper for anext() builtin with default value.
438+
/// When StopAsyncIteration is raised, it converts it to StopIteration(default).
439+
#[pyclass(module = false, name = "anext_awaitable")]
440+
#[derive(Debug)]
441+
pub struct PyAnextAwaitable {
442+
wrapped: PyObjectRef,
443+
default_value: PyObjectRef,
444+
}
445+
446+
impl PyPayload for PyAnextAwaitable {
447+
#[inline]
448+
fn class(ctx: &Context) -> &'static Py<PyType> {
449+
ctx.types.anext_awaitable
450+
}
451+
}
452+
453+
#[pyclass(with(IterNext, Iterable))]
454+
impl PyAnextAwaitable {
455+
pub fn new(wrapped: PyObjectRef, default_value: PyObjectRef) -> Self {
456+
Self {
457+
wrapped,
458+
default_value,
459+
}
460+
}
461+
462+
#[pymethod(name = "__await__")]
463+
fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
464+
zelf
465+
}
466+
467+
/// Get the awaitable iterator from wrapped object.
468+
// = anextawaitable_getiter.
469+
fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult {
470+
use crate::builtins::PyCoroutine;
471+
use crate::protocol::PyIter;
472+
473+
let wrapped = &self.wrapped;
474+
475+
// If wrapped is already an async_generator_asend, it's an iterator
476+
if wrapped.class().is(vm.ctx.types.async_generator_asend)
477+
|| wrapped.class().is(vm.ctx.types.async_generator_athrow)
478+
{
479+
return Ok(wrapped.clone());
480+
}
481+
482+
// _PyCoro_GetAwaitableIter equivalent
483+
let awaitable = if wrapped.class().is(vm.ctx.types.coroutine_type) {
484+
// Coroutine - get __await__ later
485+
wrapped.clone()
486+
} else {
487+
// Try to get __await__ method
488+
if let Some(await_method) = vm.get_method(wrapped.clone(), identifier!(vm, __await__)) {
489+
await_method?.call((), vm)?
490+
} else {
491+
return Err(vm.new_type_error(format!(
492+
"object {} can't be used in 'await' expression",
493+
wrapped.class().name()
494+
)));
495+
}
496+
};
497+
498+
// If awaitable is a coroutine, get its __await__
499+
if awaitable.class().is(vm.ctx.types.coroutine_type) {
500+
let coro_await = vm.call_method(&awaitable, "__await__", ())?;
501+
// Check that __await__ returned an iterator
502+
if !PyIter::check(&coro_await) {
503+
return Err(vm.new_type_error("__await__ returned a non-iterable"));
504+
}
505+
return Ok(coro_await);
506+
}
507+
508+
// Check the result is an iterator, not a coroutine
509+
if awaitable.downcast_ref::<PyCoroutine>().is_some() {
510+
return Err(vm.new_type_error("__await__() returned a coroutine"));
511+
}
512+
513+
// Check that the result is an iterator
514+
if !PyIter::check(&awaitable) {
515+
return Err(vm.new_type_error(format!(
516+
"__await__() returned non-iterator of type '{}'",
517+
awaitable.class().name()
518+
)));
519+
}
520+
521+
Ok(awaitable)
522+
}
523+
524+
#[pymethod]
525+
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
526+
let awaitable = self.get_awaitable_iter(vm)?;
527+
let result = vm.call_method(&awaitable, "send", (val,));
528+
self.handle_result(result, vm)
529+
}
530+
531+
#[pymethod]
532+
fn throw(
533+
&self,
534+
exc_type: PyObjectRef,
535+
exc_val: OptionalArg,
536+
exc_tb: OptionalArg,
537+
vm: &VirtualMachine,
538+
) -> PyResult {
539+
let awaitable = self.get_awaitable_iter(vm)?;
540+
let result = vm.call_method(
541+
&awaitable,
542+
"throw",
543+
(
544+
exc_type,
545+
exc_val.unwrap_or_none(vm),
546+
exc_tb.unwrap_or_none(vm),
547+
),
548+
);
549+
self.handle_result(result, vm)
550+
}
551+
552+
#[pymethod]
553+
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
554+
if let Ok(awaitable) = self.get_awaitable_iter(vm) {
555+
let _ = vm.call_method(&awaitable, "close", ());
556+
}
557+
Ok(())
558+
}
559+
560+
/// Convert StopAsyncIteration to StopIteration(default_value)
561+
fn handle_result(&self, result: PyResult, vm: &VirtualMachine) -> PyResult {
562+
match result {
563+
Ok(value) => Ok(value),
564+
Err(exc) if exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration) => {
565+
Err(vm.new_stop_iteration(Some(self.default_value.clone())))
566+
}
567+
Err(exc) => Err(exc),
568+
}
569+
}
570+
}
571+
572+
impl SelfIter for PyAnextAwaitable {}
573+
impl IterNext for PyAnextAwaitable {
574+
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
575+
PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
576+
}
577+
}
578+
427579
pub fn init(ctx: &Context) {
428580
PyAsyncGen::extend_class(ctx, ctx.types.async_generator);
429581
PyAsyncGenASend::extend_class(ctx, ctx.types.async_generator_asend);
430582
PyAsyncGenAThrow::extend_class(ctx, ctx.types.async_generator_athrow);
583+
PyAnextAwaitable::extend_class(ctx, ctx.types.anext_awaitable);
431584
}

crates/vm/src/builtins/coroutine.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ impl PyCoroutine {
2929
&self.inner
3030
}
3131

32-
pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
32+
pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
3333
Self {
34-
inner: Coro::new(frame, name),
34+
inner: Coro::new(frame, name, qualname),
3535
}
3636
}
3737

@@ -45,6 +45,16 @@ impl PyCoroutine {
4545
self.inner.set_name(name)
4646
}
4747

48+
#[pygetset]
49+
fn __qualname__(&self) -> PyStrRef {
50+
self.inner.qualname()
51+
}
52+
53+
#[pygetset(setter)]
54+
fn set___qualname__(&self, qualname: PyStrRef) {
55+
self.inner.set_qualname(qualname)
56+
}
57+
4858
#[pymethod(name = "__await__")]
4959
const fn r#await(zelf: PyRef<Self>) -> PyCoroutineWrapper {
5060
PyCoroutineWrapper { coro: zelf }
@@ -156,6 +166,11 @@ impl PyCoroutineWrapper {
156166
) -> PyResult<PyIterReturn> {
157167
self.coro.throw(exc_type, exc_val, exc_tb, vm)
158168
}
169+
170+
#[pymethod]
171+
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
172+
self.coro.close(vm)
173+
}
159174
}
160175

161176
impl SelfIter for PyCoroutineWrapper {}

crates/vm/src/builtins/function.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,9 +425,15 @@ impl Py<PyFunction> {
425425
let is_gen = code.flags.contains(bytecode::CodeFlags::IS_GENERATOR);
426426
let is_coro = code.flags.contains(bytecode::CodeFlags::IS_COROUTINE);
427427
match (is_gen, is_coro) {
428-
(true, false) => Ok(PyGenerator::new(frame, self.__name__()).into_pyobject(vm)),
429-
(false, true) => Ok(PyCoroutine::new(frame, self.__name__()).into_pyobject(vm)),
430-
(true, true) => Ok(PyAsyncGen::new(frame, self.__name__()).into_pyobject(vm)),
428+
(true, false) => {
429+
Ok(PyGenerator::new(frame, self.__name__(), self.__qualname__()).into_pyobject(vm))
430+
}
431+
(false, true) => {
432+
Ok(PyCoroutine::new(frame, self.__name__(), self.__qualname__()).into_pyobject(vm))
433+
}
434+
(true, true) => {
435+
Ok(PyAsyncGen::new(frame, self.__name__(), self.__qualname__()).into_pyobject(vm))
436+
}
431437
(false, false) => vm.run_frame(frame),
432438
}
433439
}

crates/vm/src/builtins/generator.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ impl PyGenerator {
3232
&self.inner
3333
}
3434

35-
pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
35+
pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
3636
Self {
37-
inner: Coro::new(frame, name),
37+
inner: Coro::new(frame, name, qualname),
3838
}
3939
}
4040

@@ -48,6 +48,16 @@ impl PyGenerator {
4848
self.inner.set_name(name)
4949
}
5050

51+
#[pygetset]
52+
fn __qualname__(&self) -> PyStrRef {
53+
self.inner.qualname()
54+
}
55+
56+
#[pygetset(setter)]
57+
fn set___qualname__(&self, qualname: PyStrRef) {
58+
self.inner.set_qualname(qualname)
59+
}
60+
5161
#[pygetset]
5262
fn gi_frame(&self, _vm: &VirtualMachine) -> FrameRef {
5363
self.inner.frame()

0 commit comments

Comments
 (0)