From 2d1fac6ca469772898d039476ad8507384181c9b Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Thu, 21 May 2026 10:41:14 +0000 Subject: [PATCH 1/2] [Enh] Ensure type closure for primitive func --- kernels/layout_utils.py | 5 +- kernels/mfma_preshuffle_pipeline.py | 11 ++-- kernels/silu_and_mul_fq.py | 5 +- python/flydsl/expr/math.py | 27 +++++--- python/flydsl/expr/primitive.py | 99 ++++++++++++++++++++++++---- python/flydsl/expr/typing.py | 9 +-- tests/unit/test_layout_algebra.py | 22 ++----- tests/unit/test_static_vs_dynamic.py | 20 ++---- 8 files changed, 128 insertions(+), 70 deletions(-) diff --git a/kernels/layout_utils.py b/kernels/layout_utils.py index 976996c0..1439af18 100644 --- a/kernels/layout_utils.py +++ b/kernels/layout_utils.py @@ -156,9 +156,8 @@ def crd2idx(crd, layout): cv = raw crd_i32.append(cv) coord_val = fx.make_coord(*crd_i32) - result = fx.crd2idx(coord_val, layout) - scalar = fx.get_scalar(result) - if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): + scalar = fx.get_scalar(fx.crd2idx(coord_val, layout)).ir_value() + if not isinstance(scalar.type, ir.IndexType): scalar = arith.index_cast(T.index, scalar) return _wrap(scalar) diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index 118ba670..0a130995 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -20,12 +20,11 @@ def crd2idx(crd, layout): - """crd2idx returning an index-type scalar (unwraps fly.int_tuple).""" - result = fx.crd2idx(crd, layout) - scalar = fx.get_scalar(result) - if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): - scalar = _arith.IndexCastOp(T.index, scalar).result - return scalar + """crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple).""" + scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value() + if isinstance(scalar.type, ir.IndexType): + return scalar + return _arith.IndexCastOp(T.index, scalar).result def 
swizzle_xor16(row, col, k_blocks16): diff --git a/kernels/silu_and_mul_fq.py b/kernels/silu_and_mul_fq.py index cbe00c0b..71a80c3d 100644 --- a/kernels/silu_and_mul_fq.py +++ b/kernels/silu_and_mul_fq.py @@ -75,10 +75,7 @@ def _make_scale_tiled_layout(scale_cols_val): def _scale_byte_offset(layout_scale, row, col32): """Compute byte offset for one E8M0 scale element via layout algebra.""" - result = fx.crd2idx(fx.make_coord(row, col32), layout_scale) - scalar = fx.get_scalar(result) - if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType): - scalar = arith.index_cast(T.index, scalar) + scalar = fx.get_scalar(fx.crd2idx(fx.make_coord(row, col32), layout_scale)) return ArithValue(scalar) diff --git a/python/flydsl/expr/math.py b/python/flydsl/expr/math.py index c4d128c2..4cc061fe 100644 --- a/python/flydsl/expr/math.py +++ b/python/flydsl/expr/math.py @@ -23,17 +23,21 @@ def _traced_math_op(fn): - """Like @traced_op, but re-wraps results to preserve Numeric class hierarchy. + """Like @traced_op, but re-wraps results to preserve DslType closure. - If the first positional arg is a Numeric (Float32, Int32, …), the MLIR - result is wrapped back into the appropriate Numeric subclass via - ``Numeric.from_ir_type``. Raw ir.Value inputs pass through unchanged. + If the first positional arg is a ``Numeric`` (Float32, Int32, …) or a + ``Vector``, the MLIR result is wrapped back into the matching DSL type so + callers stay at the DSL level instead of dropping to raw ir.Value / ArithValue. + Raw ir.Value inputs pass through unchanged. 
""" @wraps(fn) def wrapper(*args, **kwargs): + from .typing import Vector + first = args[0] if args else None - do_rewrap = isinstance(first, Numeric) + is_vector = isinstance(first, Vector) + is_numeric = isinstance(first, Numeric) loc = kwargs.pop("loc", None) if loc is None: @@ -42,12 +46,19 @@ def wrapper(*args, **kwargs): with loc: result = fn(*args, **kwargs) - if not do_rewrap: + if not (is_vector or is_numeric): return result + + def _wrap_arith_type(value): + if is_vector: + elem_dtype = Numeric.from_ir_type(ir.VectorType(value.type).element_type) + return Vector(value, first.shape, elem_dtype) + return Numeric.from_ir_type(value.type)(value) + if isinstance(result, ir.Value): - return Numeric.from_ir_type(result.type)(result) + return _wrap_arith_type(result) # Multi-result (e.g. sincos) - return tuple(Numeric.from_ir_type(r.type)(r) for r in result) + return tuple(_wrap_arith_type(r) for r in result) return wrapper diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index 83349dcb..5f75c4e5 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -247,6 +247,16 @@ def _check_profile(match_func, lhs, rhs): raise ValueError(f"profile mismatch: {match_func.__name__}({lhs.type}, {rhs.type}) is False") +def _wrap_numeric_type(value): + from .numeric import Numeric + + if not isinstance(value, ir.Value): + return value + if isinstance(value, Numeric): + return value + return Numeric.from_ir_type(value.type)(value) + + # ===----------------------------------------------------------------------=== # # Compile-time utility # ===----------------------------------------------------------------------=== # @@ -478,13 +488,54 @@ def make_fragment_like(tensor, dtype=None, loc=None, ip=None): @traced_op def get_scalar(int_tuple, loc=None, ip=None): - return fly.get_scalar(int_tuple, loc=loc, ip=ip) + """Unwrap a rank-1, single-element tuple back to a plain scalar value. 
+ + Fails if the input has more than one leaf - use this only when you know + the tuple is a trivial wrapper. + + Examples: + get_scalar(make_coord(tid)) -> Int32(tid) + get_scalar(make_int_tuple(5)) -> 5 + """ + if not _is_int_tuple_value(int_tuple): + return int_tuple + if int_tuple.is_leaf and int_tuple.is_static: + return int_tuple.get_static_leaf_int + return _wrap_numeric_type(fly.get_scalar(int_tuple, loc=loc, ip=ip)) @traced_op def get_leaves(input, dynamic_only=False, loc=None, ip=None): - res_lists = fly.GetLeavesOp(input, dynamicOnly=dynamic_only, loc=loc, ip=ip) - return tuple(res_lists.results) + """Flatten an IntTuple into a flat sequence of leaf values. + + Set *dynamic_only=True* to keep only runtime values and drop static + constants - handy when you need the inputs that were passed at call time. + + Examples: + get_leaves(make_coord(tid, 0)) -> (Int32(tid), 0) + get_leaves(make_coord(tid, 0), dynamic_only=True) -> (Int32(tid),) # 0 is static, dropped + """ + if dynamic_only: + res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip) + return tuple(_wrap_numeric_type(r) for r in res_lists.results) + + def _walk_int_tuple_leaves(ty): + if ty.is_leaf: + yield ty + return + for i in range(ty.rank): + yield from _walk_int_tuple_leaves(ty.at(i)) + + ty = IntTupleType(input.type) + res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip) + dyn_iter = iter(res_lists.results) + out = [] + for leaf_ty in _walk_int_tuple_leaves(ty): + if leaf_ty.is_static: + out.append(leaf_ty.get_static_leaf_int) + else: + out.append(_wrap_numeric_type(next(dyn_iter))) + return tuple(out) @traced_op @@ -1041,11 +1092,17 @@ def ptrtoint(ptr, loc=None, ip=None): if is_generic_address_space(ptr.address_space, AddressSpace.Register): raise ValueError("ptrtoint is not supported for register address space") - return fly.ptrtoint(ptr, loc=loc, ip=ip) + return _wrap_numeric_type(fly.ptrtoint(ptr, loc=loc, ip=ip)) @traced_op def add_offset(ptr, offset, 
loc=None, ip=None): + """Shift *ptr* by *offset* elements. + + Examples: + ptr2 = add_offset(ptr, 16) # move forward 16 elements + ptr2 = add_offset(ptr, tile_id * BM) # runtime offset + """ if not _is_int_tuple_value(offset): offset = make_int_tuple(offset, loc=loc, ip=ip) return fly.add_offset(ptr, offset, loc=loc, ip=ip) @@ -1058,13 +1115,27 @@ def apply_swizzle(ptr, swizzle, loc=None, ip=None): @traced_op def ptr_load(ptr, result_type=None, loc=None, ip=None): + """Load one value (scalar or vector) from *ptr*; dtype defaults to ptr's element type. + + Examples: + v = ptr_load(ptr) + """ if result_type is None: result_type = ptr.element_type - return fly.ptr_load(result_type.ir_type, ptr, loc=loc, ip=ip) + if not isinstance(result_type, ir.Type): + result_type = result_type.ir_type + return _wrap_numeric_type(fly.ptr_load(result_type, ptr, loc=loc, ip=ip)) @traced_op def ptr_store(value, ptr, loc=None, ip=None): + """Store *value* into *ptr*. Types must match the pointer's element type. 
+ + Examples: + ptr_store(val, ptr) + """ + if not isinstance(value, ir.Value): + value = ptr.element_type(value).ir_value() return fly.ptr_store(value, ptr, loc=loc, ip=ip) @@ -1095,7 +1166,9 @@ def memref_alloca(memref_type, layout, loc=None, ip=None): @traced_op def memref_load_vec(memref, loc=None, ip=None): - return fly.memref_load_vec(memref, loc=loc, ip=ip) + from .typing import Vector + + return Vector(fly.memref_load_vec(memref, loc=loc, ip=ip), memref.shape.to_py_value(), memref.dtype) @traced_op @@ -1106,26 +1179,24 @@ def memref_store_vec(vector, memref, loc=None, ip=None): @traced_op def memref_load(memref, indices, loc=None, ip=None): if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_load(memref, indices, loc=loc, ip=ip) if str(indices.type) == "index": indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) - return fly.memref_load(memref, indices, loc=loc, ip=ip) + if not _is_int_tuple_value(indices): + indices = make_int_tuple(indices, loc=loc, ip=ip) + return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip)) indices = make_int_tuple(indices, loc=loc, ip=ip) _check_profile(is_profile_weakly_congruent, indices, memref) - return fly.memref_load(memref, indices, loc=loc, ip=ip) + return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip)) @traced_op def memref_store(value, memref, indices, loc=None, ip=None): if isinstance(indices, ir.Value): - if str(indices.type).startswith("!fly.int_tuple"): - return fly.memref_store(value, memref, indices, loc=loc, ip=ip) if str(indices.type) == "index": indices = _arith.IndexCastOp(T.i32(), indices) - indices = make_int_tuple(indices, loc=loc, ip=ip) + if not _is_int_tuple_value(indices): + indices = make_int_tuple(indices, loc=loc, ip=ip) return fly.memref_store(value, memref, indices, loc=loc, ip=ip) indices = make_int_tuple(indices, loc=loc, ip=ip) diff --git 
a/python/flydsl/expr/typing.py b/python/flydsl/expr/typing.py index 7ae5c6cc..d9b54bca 100644 --- a/python/flydsl/expr/typing.py +++ b/python/flydsl/expr/typing.py @@ -477,11 +477,8 @@ def _rebuild_py_value(self, leaf_iter): if self.is_leaf: if self.is_static: return self.get_static_leaf_int - val = next(leaf_iter) - width = ir.IntegerType(val.type).width - wrapper = Int64 if width == 64 else Int32 - return wrapper(val) - return tuple(IntTuple(get_(self, i))._rebuild_py_value(leaf_iter) for i in range(self.rank)) + return next(leaf_iter) + return tuple(get_(self, i)._rebuild_py_value(leaf_iter) for i in range(self.rank)) @traced_op def to_py_value(self, loc=None, ip=None): @@ -782,7 +779,7 @@ def __setitem__(self, coord, value, loc=None, ip=None): @traced_op def load(self, loc=None, ip=None): - return Vector(memref_load_vec(self, loc=loc, ip=ip), self.shape.to_py_value(), self.dtype) + return memref_load_vec(self, loc=loc, ip=ip) @traced_op def store(self, vector, loc=None, ip=None): diff --git a/tests/unit/test_layout_algebra.py b/tests/unit/test_layout_algebra.py index 4f7a55d6..610b3d85 100644 --- a/tests/unit/test_layout_algebra.py +++ b/tests/unit/test_layout_algebra.py @@ -32,13 +32,7 @@ FLY_PIPELINE = ( - "builtin.module(" - "fly-canonicalize," - "fly-layout-lowering," - "fly-canonicalize," - "convert-fly-to-rocdl," - "canonicalize," - "cse)" + "builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)" ) @@ -218,9 +212,8 @@ def build_static(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [idx])) + f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [i32])) entry = f.add_entry_block() with InsertionPoint(entry): args = list(entry.arguments) @@ -228,8 +221,8 @@ def build_static(): B = fx.make_layout(fx.make_shape(args[4], args[5]), 
fx.make_stride(args[6], args[7])) R = fx.composition(A, B) sz = fx.size(R) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) assert module.operation.verify() @@ -317,9 +310,8 @@ def test_complement_rank_2_dynamic_stride_error(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("compl_dyn", FunctionType.get([i32], [idx])) + f = func.FuncOp("compl_dyn", FunctionType.get([i32], [i32])) entry = f.add_entry_block() with InsertionPoint(entry): runtime_stride = entry.arguments[0] @@ -328,8 +320,8 @@ def test_complement_rank_2_dynamic_stride_error(): tiler = fx.make_layout(shape, stride) comp = fx.complement(tiler, 12) sz = fx.size(comp) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) diff --git a/tests/unit/test_static_vs_dynamic.py b/tests/unit/test_static_vs_dynamic.py index a3c42c32..e027320e 100644 --- a/tests/unit/test_static_vs_dynamic.py +++ b/tests/unit/test_static_vs_dynamic.py @@ -27,13 +27,7 @@ FLY_PIPELINE = ( - "builtin.module(" - "fly-canonicalize," - "fly-layout-lowering," - "fly-canonicalize," - "convert-fly-to-rocdl," - "canonicalize," - "cse)" + "builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)" ) @@ -85,9 +79,8 @@ def test_layout_dynamic_types(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [idx])) + f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [i32])) entry = f.add_entry_block() with 
InsertionPoint(entry): dim0, dim1, stride0, stride1 = entry.arguments @@ -98,7 +91,7 @@ def test_layout_dynamic_types(): layout = fx.make_layout(shape, stride) sz = fx.size(layout) sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + func.ReturnOp([sc.ir_value()]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) @@ -138,9 +131,8 @@ def test_mixed_static_dynamic(): with Location.unknown(ctx): module = Module.create() i32 = IntegerType.get_signless(32) - idx = IndexType.get() with InsertionPoint(module.body): - f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [idx])) + f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [i32])) entry = f.add_entry_block() with InsertionPoint(entry): runtime_extent, runtime_stride = entry.arguments @@ -152,8 +144,8 @@ def test_mixed_static_dynamic(): stride = fx.make_stride(c16, runtime_stride) layout = fx.make_layout(shape, stride) sz = fx.size(layout) - sc = fx.get_scalar(sz) - func.ReturnOp([arith.IndexCastOp(idx, sc).result]) + sc = fx.get_scalar(sz).ir_value() + func.ReturnOp([sc]) pm = PassManager.parse(FLY_PIPELINE, ctx) pm.run(module.operation) From 205ec384d4f668ffa599f03612b959ee35eb3734 Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Thu, 21 May 2026 10:41:14 +0000 Subject: [PATCH 2/2] fix --- kernels/silu_and_mul_fq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernels/silu_and_mul_fq.py b/kernels/silu_and_mul_fq.py index 71a80c3d..fd696a28 100644 --- a/kernels/silu_and_mul_fq.py +++ b/kernels/silu_and_mul_fq.py @@ -76,7 +76,7 @@ def _make_scale_tiled_layout(scale_cols_val): def _scale_byte_offset(layout_scale, row, col32): """Compute byte offset for one E8M0 scale element via layout algebra.""" scalar = fx.get_scalar(fx.crd2idx(fx.make_coord(row, col32), layout_scale)) - return ArithValue(scalar) + return ArithValue(scalar.ir_value()) def build_silu_and_mul_fq_module(