diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index c7f3748e8f225..67aa24b3641c8 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -1603,6 +1603,9 @@ def add_coroutine_setup_call(self, class_name: str, obj: Value) -> Value: ) ) + def load_builtin(self, name: str, line: int) -> Value | None: + return self.builder.load_builtin(name, line) + def gen_arg_defaults(builder: IRBuilder) -> None: """Generate blocks for arguments that have default values. diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 6ef59559c0112..e8d22a051cc4d 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -136,7 +136,6 @@ from mypyc.primitives.generic_ops import iter_op, name_op from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op -from mypyc.primitives.registry import builtin_names from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op from mypyc.primitives.str_ops import str_slice_op from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op @@ -157,9 +156,8 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value: ) return builder.none(expr.line) fullname = expr.node.fullname - if fullname in builtin_names: - typ, src = builtin_names[fullname] - return builder.add(LoadAddress(typ, src, expr.line)) + if builtin := builder.load_builtin(fullname, expr.line): + return builtin # special cases if fullname == "builtins.None": return builder.none(expr.line) diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 33a716624b178..95894dcedaba8 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -36,8 +36,8 @@ IntOp, LoadAddress, LoadErrorValue, + LoadGlobal, LoadLiteral, - LoadMem, MethodCall, RaiseStandardError, Register, @@ -63,7 +63,6 @@ is_tuple_rprimitive, object_pointer_rprimitive, object_rprimitive, - pointer_rprimitive, short_int_rprimitive, ) from mypyc.irbuild.builder import IRBuilder @@ -828,8 +827,9 @@ def gen_condition(self) -> None: line = self.line def except_match() -> Value: - addr = builder.add(LoadAddress(pointer_rprimitive, stop_async_iteration_op.src, line)) - return builder.add(LoadMem(stop_async_iteration_op.type, addr, borrow=True)) + return builder.add( + LoadGlobal(stop_async_iteration_op.type, stop_async_iteration_op.src, line) + ) def try_body() -> None: awaitable = builder.call_c(anext_op, [builder.read(self.iter_target, line)], line) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 0cf0edb32f86c..ef7450aeab24f 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -45,7 +45,6 @@ ComparisonOp, GetAttr, Integer, - LoadAddress, LoadLiteral, Register, Return, @@ -85,7 +84,6 @@ ) from mypyc.primitives.generic_ops import generic_getattr, generic_setattr, py_setattr_op from mypyc.primitives.misc_ops import register_function -from mypyc.primitives.registry import builtin_names from mypyc.sametype import is_same_method_signature, is_same_type # Top-level transform functions @@ -935,9 +933,8 @@ def load_type(builder: IRBuilder, typ: TypeInfo, unbounded_type: Type | None, li if typ in builder.mapper.type_to_ir: class_ir = builder.mapper.type_to_ir[typ] class_obj = builder.builder.get_native_type(class_ir) - elif typ.fullname in builtin_names: - builtin_addr_type, src = builtin_names[typ.fullname] - class_obj = builder.add(LoadAddress(builtin_addr_type, src, line)) + elif builtin := builder.load_builtin(typ.fullname, line): + class_obj = builtin elif isinstance(unbounded_type, UnboundType): path_parts = unbounded_type.name.split(".") class_obj = builder.load_global_str(path_parts[0], line) @@ -1013,8 +1010,8 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) builder.add(Return(coerced)) - typ, src = builtin_names["builtins.int"] - int_type_obj = builder.add(LoadAddress(typ, src, line)) + int_type_obj = builder.load_builtin("builtins.int", line) + assert int_type_obj is_int = builder.builder.type_is_op(impl_to_use, int_type_obj, line) native_call, non_native_call = BasicBlock(), BasicBlock() diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 1a79fce59ab12..c19eded77464e 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -185,7 +185,9 @@ ERR_NEG_INT, CFunctionDescription, binary_ops, + builtin_names, function_ops, + global_names, method_call_ops, unary_ops, ) @@ -323,8 +325,18 @@ def set_mem(self, ptr: Value, value_type: RType, value: Value) -> None: def get_element(self, reg: Value, field: str) -> Value: return self.add(GetElement(reg, field)) - def load_address(self, name: str, rtype: RType) -> Value: - return self.add(LoadAddress(rtype, name)) + def load_address(self, name: str, rtype: RType, line: int = -1) -> Value: + return self.add(LoadAddress(rtype, name, line)) + + def load_global(self, name: str, rtype: RType, line: int) -> Value: + return self.add(LoadGlobal(rtype, name, line)) + + def load_builtin(self, name: str, line: int) -> Value | None: + if builtin := builtin_names.get(name): + return self.load_address(builtin[1], builtin[0], line) + if glob := global_names.get(name): + return self.load_global(glob[1], glob[0], line) + return None def load_struct_field( self, ptr: Value, struct: RStruct, field: str, *, borrow: bool = False diff --git a/mypyc/irbuild/vec.py b/mypyc/irbuild/vec.py index 128c0f62c0bf0..00e6f0adcdd8f 100644 --- a/mypyc/irbuild/vec.py +++ b/mypyc/irbuild/vec.py @@ -54,7 +54,6 @@ vec_api_by_item_type, vec_item_type_tags, ) -from mypyc.primitives.registry import builtin_names if TYPE_CHECKING: from mypyc.irbuild.ll_builder import LowLevelIRBuilder @@ -213,8 +212,7 @@ def vec_item_type_info( builder: LowLevelIRBuilder, typ: RType, line: int ) -> tuple[Value | None, bool, int]: if isinstance(typ, RPrimitive) and typ.is_refcounted: - typ, src = builtin_names[typ.name] - return builder.load_address(src, typ), False, 0 + return builder.load_builtin(typ.name, line), False, 0 elif isinstance(typ, RInstance): return builder.load_native_type_object(typ.name), False, 0 elif typ in vec_item_type_tags: diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 6be74baff3d0b..7b78b61f50e26 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -31,6 +31,7 @@ custom_primitive_op, function_op, load_address_op, + load_global_op, method_op, ) @@ -52,7 +53,7 @@ ) # Get the boxed StopAsyncIteration object -stop_async_iteration_op = load_address_op( +stop_async_iteration_op = load_global_op( name="builtins.StopAsyncIteration", type=object_rprimitive, src="PyExc_StopAsyncIteration" ) diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index 2376ddf9fbd0d..c04b4ff65a757 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -85,8 +85,12 @@ class LoadAddressDescription(NamedTuple): # Primitive ops for unary ops unary_ops: dict[str, list[PrimitiveDescription]] = {} +# Mapping of type name to (type, C value variable name). builtin_names: dict[str, tuple[RType, str]] = {} +# Mapping of type name to (type, C pointer variable name). +global_names: dict[str, tuple[RType, str]] = {} + def method_op( name: str, @@ -387,6 +391,12 @@ def load_address_op(name: str, type: RType, src: str) -> LoadAddressDescription: return LoadAddressDescription(name, type, src) +def load_global_op(name: str, type: RType, src: str) -> LoadAddressDescription: + assert name not in global_names, "already defined: %s" % name + global_names[name] = (type, src) + return LoadAddressDescription(name, type, src) + + # Import various modules that set up global state. import mypyc.primitives.bytearray_ops import mypyc.primitives.bytes_ops diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 60f06b12f2959..101c54ad7eff0 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -374,6 +374,9 @@ class ReferenceError(Exception): pass class StopIteration(Exception): value: Any +class StopAsyncIteration(Exception): + value: Any + class ArithmeticError(Exception): pass class ZeroDivisionError(ArithmeticError): pass class OverflowError(ArithmeticError): pass diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 6127a0dd47a33..2733a31f3af2a 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -1874,3 +1874,69 @@ def test_nested_coroutine_calls_another_nested_function(): from typing import Any, Generator def run(x: object) -> object: ... + +[case testRaiseStopAsyncIteration] +from async_iter import async_iter +from testutil import assertRaises + +class AsyncIter: + def __init__(self, vals: list[str]) -> None: + self._iter = iter(vals) + + def __aiter__(self) -> AsyncIter: + return self + + async def __anext__(self) -> str: + try: + return next(self._iter) + except StopIteration: + raise StopAsyncIteration + +async def test_iterator() -> None: + new_list: list[int] = [] + async for v in async_iter([1, 2, 3]): + new_list.append(v) + assert new_list == [1, 2, 3] + + new_list = [] + iter = async_iter([1, 2, 3]) + while True: + try: + v = await iter.__anext__() + new_list.append(v) + except StopAsyncIteration: + new_list.append(4) + break + assert new_list == [1, 2, 3, 4] + + with assertRaises(StopAsyncIteration): + await async_iter([]).__anext__() + +async def test_wrapper() -> None: + new_list: list[str] = [] + async for v in AsyncIter(['a', 'b', 'c']): + new_list.append(v) + assert new_list == ['a', 'b', 'c'] + + new_list = [] + iter = AsyncIter(['a', 'b', 'c']) + while True: + try: + v = await iter.__anext__() + new_list.append(v) + except StopAsyncIteration: + new_list.append('d') + break + assert new_list == ['a', 'b', 'c', 'd'] + + with assertRaises(StopAsyncIteration): + await AsyncIter([]).__anext__() + +[file async_iter.py] +from typing import AsyncIterator + +async def async_iter(vals: list[int]) -> AsyncIterator[int]: + for v in vals: + yield v + +[typing fixtures/typing-full.pyi]