Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,9 @@ def add_coroutine_setup_call(self, class_name: str, obj: Value) -> Value:
)
)

def load_builtin(self, name: str, line: int) -> Value | None:
return self.builder.load_builtin(name, line)


def gen_arg_defaults(builder: IRBuilder) -> None:
"""Generate blocks for arguments that have default values.
Expand Down
6 changes: 2 additions & 4 deletions mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@
from mypyc.primitives.generic_ops import iter_op, name_op
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
from mypyc.primitives.registry import builtin_names
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
from mypyc.primitives.str_ops import str_slice_op
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op
Expand All @@ -157,9 +156,8 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value:
)
return builder.none(expr.line)
fullname = expr.node.fullname
if fullname in builtin_names:
typ, src = builtin_names[fullname]
return builder.add(LoadAddress(typ, src, expr.line))
if builtin := builder.load_builtin(fullname, expr.line):
return builtin
# special cases
if fullname == "builtins.None":
return builder.none(expr.line)
Expand Down
8 changes: 4 additions & 4 deletions mypyc/irbuild/for_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
IntOp,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
MethodCall,
RaiseStandardError,
Register,
Expand All @@ -63,7 +63,6 @@
is_tuple_rprimitive,
object_pointer_rprimitive,
object_rprimitive,
pointer_rprimitive,
short_int_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder
Expand Down Expand Up @@ -828,8 +827,9 @@ def gen_condition(self) -> None:
line = self.line

def except_match() -> Value:
addr = builder.add(LoadAddress(pointer_rprimitive, stop_async_iteration_op.src, line))
return builder.add(LoadMem(stop_async_iteration_op.type, addr, borrow=True))
return builder.add(
LoadGlobal(stop_async_iteration_op.type, stop_async_iteration_op.src, line)
)

def try_body() -> None:
awaitable = builder.call_c(anext_op, [builder.read(self.iter_target, line)], line)
Expand Down
11 changes: 4 additions & 7 deletions mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
ComparisonOp,
GetAttr,
Integer,
LoadAddress,
LoadLiteral,
Register,
Return,
Expand Down Expand Up @@ -85,7 +84,6 @@
)
from mypyc.primitives.generic_ops import generic_getattr, generic_setattr, py_setattr_op
from mypyc.primitives.misc_ops import register_function
from mypyc.primitives.registry import builtin_names
from mypyc.sametype import is_same_method_signature, is_same_type

# Top-level transform functions
Expand Down Expand Up @@ -935,9 +933,8 @@ def load_type(builder: IRBuilder, typ: TypeInfo, unbounded_type: Type | None, li
if typ in builder.mapper.type_to_ir:
class_ir = builder.mapper.type_to_ir[typ]
class_obj = builder.builder.get_native_type(class_ir)
elif typ.fullname in builtin_names:
builtin_addr_type, src = builtin_names[typ.fullname]
class_obj = builder.add(LoadAddress(builtin_addr_type, src, line))
elif builtin := builder.load_builtin(typ.fullname, line):
class_obj = builtin
elif isinstance(unbounded_type, UnboundType):
path_parts = unbounded_type.name.split(".")
class_obj = builder.load_global_str(path_parts[0], line)
Expand Down Expand Up @@ -1013,8 +1010,8 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None:
coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line)
builder.add(Return(coerced))

typ, src = builtin_names["builtins.int"]
int_type_obj = builder.add(LoadAddress(typ, src, line))
int_type_obj = builder.load_builtin("builtins.int", line)
assert int_type_obj
is_int = builder.builder.type_is_op(impl_to_use, int_type_obj, line)

native_call, non_native_call = BasicBlock(), BasicBlock()
Expand Down
16 changes: 14 additions & 2 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@
ERR_NEG_INT,
CFunctionDescription,
binary_ops,
builtin_names,
function_ops,
global_names,
method_call_ops,
unary_ops,
)
Expand Down Expand Up @@ -323,8 +325,18 @@ def set_mem(self, ptr: Value, value_type: RType, value: Value) -> None:
def get_element(self, reg: Value, field: str) -> Value:
return self.add(GetElement(reg, field))

def load_address(self, name: str, rtype: RType) -> Value:
return self.add(LoadAddress(rtype, name))
def load_address(self, name: str, rtype: RType, line: int = -1) -> Value:
return self.add(LoadAddress(rtype, name, line))

def load_global(self, name: str, rtype: RType, line: int) -> Value:
return self.add(LoadGlobal(rtype, name, line))

def load_builtin(self, name: str, line: int) -> Value | None:
if builtin := builtin_names.get(name):
return self.load_address(builtin[1], builtin[0], line)
if glob := global_names.get(name):
return self.load_global(glob[1], glob[0], line)
return None

def load_struct_field(
self, ptr: Value, struct: RStruct, field: str, *, borrow: bool = False
Expand Down
4 changes: 1 addition & 3 deletions mypyc/irbuild/vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
vec_api_by_item_type,
vec_item_type_tags,
)
from mypyc.primitives.registry import builtin_names

if TYPE_CHECKING:
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
Expand Down Expand Up @@ -213,8 +212,7 @@ def vec_item_type_info(
builder: LowLevelIRBuilder, typ: RType, line: int
) -> tuple[Value | None, bool, int]:
if isinstance(typ, RPrimitive) and typ.is_refcounted:
typ, src = builtin_names[typ.name]
return builder.load_address(src, typ), False, 0
return builder.load_builtin(typ.name, line), False, 0
elif isinstance(typ, RInstance):
return builder.load_native_type_object(typ.name), False, 0
elif typ in vec_item_type_tags:
Expand Down
3 changes: 2 additions & 1 deletion mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
custom_primitive_op,
function_op,
load_address_op,
load_global_op,
method_op,
)

Expand All @@ -52,7 +53,7 @@
)

# Get the boxed StopAsyncIteration object
stop_async_iteration_op = load_address_op(
stop_async_iteration_op = load_global_op(
name="builtins.StopAsyncIteration", type=object_rprimitive, src="PyExc_StopAsyncIteration"
)

Expand Down
10 changes: 10 additions & 0 deletions mypyc/primitives/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,12 @@ class LoadAddressDescription(NamedTuple):
# Primitive ops for unary ops
unary_ops: dict[str, list[PrimitiveDescription]] = {}

# Mapping of type name to (type, C value variable name).
builtin_names: dict[str, tuple[RType, str]] = {}

# Mapping of type name to (type, C pointer variable name).
global_names: dict[str, tuple[RType, str]] = {}


def method_op(
name: str,
Expand Down Expand Up @@ -387,6 +391,12 @@ def load_address_op(name: str, type: RType, src: str) -> LoadAddressDescription:
return LoadAddressDescription(name, type, src)


def load_global_op(name: str, type: RType, src: str) -> LoadAddressDescription:
assert name not in global_names, "already defined: %s" % name
global_names[name] = (type, src)
return LoadAddressDescription(name, type, src)


# Import various modules that set up global state.
import mypyc.primitives.bytearray_ops
import mypyc.primitives.bytes_ops
Expand Down
3 changes: 3 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,9 @@ class ReferenceError(Exception): pass
class StopIteration(Exception):
value: Any

class StopAsyncIteration(Exception):
value: Any

class ArithmeticError(Exception): pass
class ZeroDivisionError(ArithmeticError): pass
class OverflowError(ArithmeticError): pass
Expand Down
66 changes: 66 additions & 0 deletions mypyc/test-data/run-async.test
Original file line number Diff line number Diff line change
Expand Up @@ -1874,3 +1874,69 @@ def test_nested_coroutine_calls_another_nested_function():
from typing import Any, Generator

def run(x: object) -> object: ...

[case testRaiseStopAsyncIteration]
from async_iter import async_iter
from testutil import assertRaises

class AsyncIter:
def __init__(self, vals: list[str]) -> None:
self._iter = iter(vals)

def __aiter__(self) -> AsyncIter:
return self

async def __anext__(self) -> str:
try:
return next(self._iter)
except StopIteration:
raise StopAsyncIteration

async def test_iterator() -> None:
new_list: list[int] = []
async for v in async_iter([1, 2, 3]):
new_list.append(v)
assert new_list == [1, 2, 3]

new_list = []
iter = async_iter([1, 2, 3])
while True:
try:
v = await iter.__anext__()
new_list.append(v)
except StopAsyncIteration:
new_list.append(4)
break
assert new_list == [1, 2, 3, 4]

with assertRaises(StopAsyncIteration):
await async_iter([]).__anext__()

async def test_wrapper() -> None:
new_list: list[str] = []
async for v in AsyncIter(['a', 'b', 'c']):
new_list.append(v)
assert new_list == ['a', 'b', 'c']

new_list = []
iter = AsyncIter(['a', 'b', 'c'])
while True:
try:
v = await iter.__anext__()
new_list.append(v)
except StopAsyncIteration:
new_list.append('d')
break
assert new_list == ['a', 'b', 'c', 'd']

with assertRaises(StopAsyncIteration):
await AsyncIter([]).__anext__()

[file async_iter.py]
from typing import AsyncIterator

async def async_iter(vals: list[int]) -> AsyncIterator[int]:
for v in vals:
yield v

[typing fixtures/typing-full.pyi]
Loading