Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions kernels/layout_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,8 @@ def crd2idx(crd, layout):
cv = raw
crd_i32.append(cv)
coord_val = fx.make_coord(*crd_i32)
result = fx.crd2idx(coord_val, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = fx.get_scalar(fx.crd2idx(coord_val, layout)).ir_value()
if not isinstance(scalar.type, ir.IndexType):
scalar = arith.index_cast(T.index, scalar)
return _wrap(scalar)

Expand Down
11 changes: 5 additions & 6 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@


def crd2idx(crd, layout):
"""crd2idx returning an index-type scalar (unwraps fly.int_tuple)."""
result = fx.crd2idx(crd, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = _arith.IndexCastOp(T.index, scalar).result
return scalar
"""crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple)."""
scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value()
if isinstance(scalar.type, ir.IndexType):
return scalar
return _arith.IndexCastOp(T.index, scalar).result


def swizzle_xor16(row, col, k_blocks16):
Expand Down
7 changes: 2 additions & 5 deletions kernels/silu_and_mul_fq.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,8 @@ def _make_scale_tiled_layout(scale_cols_val):

def _scale_byte_offset(layout_scale, row, col32):
"""Compute byte offset for one E8M0 scale element via layout algebra."""
result = fx.crd2idx(fx.make_coord(row, col32), layout_scale)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = arith.index_cast(T.index, scalar)
return ArithValue(scalar)
scalar = fx.get_scalar(fx.crd2idx(fx.make_coord(row, col32), layout_scale))
return ArithValue(scalar.ir_value())


def build_silu_and_mul_fq_module(
Expand Down
27 changes: 19 additions & 8 deletions python/flydsl/expr/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,21 @@


def _traced_math_op(fn):
"""Like @traced_op, but re-wraps results to preserve Numeric class hierarchy.
"""Like @traced_op, but re-wraps results to preserve DslType closure.

If the first positional arg is a Numeric (Float32, Int32, …), the MLIR
result is wrapped back into the appropriate Numeric subclass via
``Numeric.from_ir_type``. Raw ir.Value inputs pass through unchanged.
If the first positional arg is a ``Numeric`` (Float32, Int32, …) or a
``Vector``, the MLIR result is wrapped back into the matching DSL type so
callers stay at the DSL level instead of dropping to raw ir.Value / ArithValue.
Raw ir.Value inputs pass through unchanged.
"""

@wraps(fn)
def wrapper(*args, **kwargs):
from .typing import Vector

first = args[0] if args else None
do_rewrap = isinstance(first, Numeric)
is_vector = isinstance(first, Vector)
is_numeric = isinstance(first, Numeric)

loc = kwargs.pop("loc", None)
if loc is None:
Expand All @@ -42,12 +46,19 @@ def wrapper(*args, **kwargs):
with loc:
result = fn(*args, **kwargs)

if not do_rewrap:
if not (is_vector or is_numeric):
return result

def _wrap_arith_type(value):
if is_vector:
elem_dtype = Numeric.from_ir_type(ir.VectorType(value.type).element_type)
return Vector(value, first.shape, elem_dtype)
return Numeric.from_ir_type(value.type)(value)

if isinstance(result, ir.Value):
return Numeric.from_ir_type(result.type)(result)
return _wrap_arith_type(result)
# Multi-result (e.g. sincos)
return tuple(Numeric.from_ir_type(r.type)(r) for r in result)
return tuple(_wrap_arith_type(r) for r in result)

return wrapper

Expand Down
99 changes: 85 additions & 14 deletions python/flydsl/expr/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,16 @@ def _check_profile(match_func, lhs, rhs):
raise ValueError(f"profile mismatch: {match_func.__name__}({lhs.type}, {rhs.type}) is False")


def _wrap_numeric_type(value):
from .numeric import Numeric

if not isinstance(value, ir.Value):
return value
if isinstance(value, Numeric):
return value
Comment on lines +253 to +256
return Numeric.from_ir_type(value.type)(value)


# ===----------------------------------------------------------------------=== #
# Compile-time utility
# ===----------------------------------------------------------------------=== #
Expand Down Expand Up @@ -478,13 +488,54 @@ def make_fragment_like(tensor, dtype=None, loc=None, ip=None):

@traced_op
def get_scalar(int_tuple, loc=None, ip=None):
return fly.get_scalar(int_tuple, loc=loc, ip=ip)
"""Unwrap a rank-1, single-element tuple back to a plain scalar value.

Fails if the input has more than one leaf - use this only when you know
the tuple is a trivial wrapper.

Examples:
get_scalar(make_coord(tid)) -> Int32(tid)
get_scalar(make_int_tuple(5)) -> 5
"""
if not _is_int_tuple_value(int_tuple):
return int_tuple
if int_tuple.is_leaf and int_tuple.is_static:
return int_tuple.get_static_leaf_int
return _wrap_numeric_type(fly.get_scalar(int_tuple, loc=loc, ip=ip))


@traced_op
def get_leaves(input, dynamic_only=False, loc=None, ip=None):
res_lists = fly.GetLeavesOp(input, dynamicOnly=dynamic_only, loc=loc, ip=ip)
return tuple(res_lists.results)
"""Flatten an IntTuple into a flat sequence of leaf values.

Set *dynamic_only=True* to keep only runtime values and drop static
constants - handy when you need the inputs that were passed at call time.

Examples:
get_leaves(make_coord(tid, 0)) -> (Int32(tid), 0)
get_leaves(make_coord(tid, 0), dynamic_only=True) -> (Int32(tid),) # 0 is static, dropped
"""
if dynamic_only:
res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip)
return tuple(_wrap_numeric_type(r) for r in res_lists.results)

def _walk_int_tuple_leaves(ty):
if ty.is_leaf:
yield ty
return
for i in range(ty.rank):
yield from _walk_int_tuple_leaves(ty.at(i))

ty = IntTupleType(input.type)
res_lists = fly.GetLeavesOp(input, dynamicOnly=True, loc=loc, ip=ip)
dyn_iter = iter(res_lists.results)
out = []
for leaf_ty in _walk_int_tuple_leaves(ty):
if leaf_ty.is_static:
out.append(leaf_ty.get_static_leaf_int)
else:
out.append(_wrap_numeric_type(next(dyn_iter)))
return tuple(out)


@traced_op
Expand Down Expand Up @@ -1041,11 +1092,17 @@ def ptrtoint(ptr, loc=None, ip=None):

if is_generic_address_space(ptr.address_space, AddressSpace.Register):
raise ValueError("ptrtoint is not supported for register address space")
return fly.ptrtoint(ptr, loc=loc, ip=ip)
return _wrap_numeric_type(fly.ptrtoint(ptr, loc=loc, ip=ip))


@traced_op
def add_offset(ptr, offset, loc=None, ip=None):
"""Shift *ptr* by *offset* elements

Examples:
ptr2 = add_offset(ptr, 16) # move forward 16 elements
ptr2 = add_offset(ptr, tile_id * BM) # runtime offset
"""
if not _is_int_tuple_value(offset):
offset = make_int_tuple(offset, loc=loc, ip=ip)
return fly.add_offset(ptr, offset, loc=loc, ip=ip)
Expand All @@ -1058,13 +1115,27 @@ def apply_swizzle(ptr, swizzle, loc=None, ip=None):

@traced_op
def ptr_load(ptr, result_type=None, loc=None, ip=None):
"""Load one value (scalar or vector) from *ptr*; dtype defaults to ptr's element type.

Examples:
v = ptr_load(ptr)
"""
if result_type is None:
result_type = ptr.element_type
return fly.ptr_load(result_type.ir_type, ptr, loc=loc, ip=ip)
if not isinstance(result_type, ir.Type):
result_type = result_type.ir_type
return _wrap_numeric_type(fly.ptr_load(result_type, ptr, loc=loc, ip=ip))


@traced_op
def ptr_store(value, ptr, loc=None, ip=None):
"""Store *value* into *ptr*. Types must match the pointer's element type.

Examples:
ptr_store(val, ptr)
"""
if not isinstance(value, ir.Value):
value = ptr.element_type(value).ir_value()
return fly.ptr_store(value, ptr, loc=loc, ip=ip)


Expand Down Expand Up @@ -1095,7 +1166,9 @@ def memref_alloca(memref_type, layout, loc=None, ip=None):

@traced_op
def memref_load_vec(memref, loc=None, ip=None):
return fly.memref_load_vec(memref, loc=loc, ip=ip)
from .typing import Vector

return Vector(fly.memref_load_vec(memref, loc=loc, ip=ip), memref.shape.to_py_value(), memref.dtype)


@traced_op
Expand All @@ -1106,26 +1179,24 @@ def memref_store_vec(vector, memref, loc=None, ip=None):
@traced_op
def memref_load(memref, indices, loc=None, ip=None):
if isinstance(indices, ir.Value):
if str(indices.type).startswith("!fly.int_tuple"):
return fly.memref_load(memref, indices, loc=loc, ip=ip)
if str(indices.type) == "index":
indices = _arith.IndexCastOp(T.i32(), indices)
indices = make_int_tuple(indices, loc=loc, ip=ip)
return fly.memref_load(memref, indices, loc=loc, ip=ip)
if not _is_int_tuple_value(indices):
indices = make_int_tuple(indices, loc=loc, ip=ip)
return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip))

indices = make_int_tuple(indices, loc=loc, ip=ip)
_check_profile(is_profile_weakly_congruent, indices, memref)
return fly.memref_load(memref, indices, loc=loc, ip=ip)
return _wrap_numeric_type(fly.memref_load(memref, indices, loc=loc, ip=ip))


@traced_op
def memref_store(value, memref, indices, loc=None, ip=None):
if isinstance(indices, ir.Value):
if str(indices.type).startswith("!fly.int_tuple"):
return fly.memref_store(value, memref, indices, loc=loc, ip=ip)
if str(indices.type) == "index":
indices = _arith.IndexCastOp(T.i32(), indices)
indices = make_int_tuple(indices, loc=loc, ip=ip)
if not _is_int_tuple_value(indices):
indices = make_int_tuple(indices, loc=loc, ip=ip)
return fly.memref_store(value, memref, indices, loc=loc, ip=ip)

indices = make_int_tuple(indices, loc=loc, ip=ip)
Expand Down
9 changes: 3 additions & 6 deletions python/flydsl/expr/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,11 +477,8 @@ def _rebuild_py_value(self, leaf_iter):
if self.is_leaf:
if self.is_static:
return self.get_static_leaf_int
val = next(leaf_iter)
width = ir.IntegerType(val.type).width
wrapper = Int64 if width == 64 else Int32
return wrapper(val)
return tuple(IntTuple(get_(self, i))._rebuild_py_value(leaf_iter) for i in range(self.rank))
return next(leaf_iter)
return tuple(get_(self, i)._rebuild_py_value(leaf_iter) for i in range(self.rank))

@traced_op
def to_py_value(self, loc=None, ip=None):
Expand Down Expand Up @@ -782,7 +779,7 @@ def __setitem__(self, coord, value, loc=None, ip=None):

@traced_op
def load(self, loc=None, ip=None):
return Vector(memref_load_vec(self, loc=loc, ip=ip), self.shape.to_py_value(), self.dtype)
return memref_load_vec(self, loc=loc, ip=ip)

@traced_op
def store(self, vector, loc=None, ip=None):
Expand Down
22 changes: 7 additions & 15 deletions tests/unit/test_layout_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,7 @@


FLY_PIPELINE = (
"builtin.module("
"fly-canonicalize,"
"fly-layout-lowering,"
"fly-canonicalize,"
"convert-fly-to-rocdl,"
"canonicalize,"
"cse)"
"builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)"
)


Expand Down Expand Up @@ -218,18 +212,17 @@ def build_static():
with Location.unknown(ctx):
module = Module.create()
i32 = IntegerType.get_signless(32)
idx = IndexType.get()
with InsertionPoint(module.body):
f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [idx]))
f = func.FuncOp("comp_dyn", FunctionType.get([i32] * 8, [i32]))
entry = f.add_entry_block()
with InsertionPoint(entry):
args = list(entry.arguments)
A = fx.make_layout(fx.make_shape(args[0], args[1]), fx.make_stride(args[2], args[3]))
B = fx.make_layout(fx.make_shape(args[4], args[5]), fx.make_stride(args[6], args[7]))
R = fx.composition(A, B)
sz = fx.size(R)
sc = fx.get_scalar(sz)
func.ReturnOp([arith.IndexCastOp(idx, sc).result])
sc = fx.get_scalar(sz).ir_value()
func.ReturnOp([sc])
pm = PassManager.parse(FLY_PIPELINE, ctx)
pm.run(module.operation)
assert module.operation.verify()
Expand Down Expand Up @@ -317,9 +310,8 @@ def test_complement_rank_2_dynamic_stride_error():
with Location.unknown(ctx):
module = Module.create()
i32 = IntegerType.get_signless(32)
idx = IndexType.get()
with InsertionPoint(module.body):
f = func.FuncOp("compl_dyn", FunctionType.get([i32], [idx]))
f = func.FuncOp("compl_dyn", FunctionType.get([i32], [i32]))
entry = f.add_entry_block()
with InsertionPoint(entry):
runtime_stride = entry.arguments[0]
Expand All @@ -328,8 +320,8 @@ def test_complement_rank_2_dynamic_stride_error():
tiler = fx.make_layout(shape, stride)
comp = fx.complement(tiler, 12)
sz = fx.size(comp)
sc = fx.get_scalar(sz)
func.ReturnOp([arith.IndexCastOp(idx, sc).result])
sc = fx.get_scalar(sz).ir_value()
func.ReturnOp([sc])

pm = PassManager.parse(FLY_PIPELINE, ctx)
pm.run(module.operation)
Expand Down
20 changes: 6 additions & 14 deletions tests/unit/test_static_vs_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,7 @@


FLY_PIPELINE = (
"builtin.module("
"fly-canonicalize,"
"fly-layout-lowering,"
"fly-canonicalize,"
"convert-fly-to-rocdl,"
"canonicalize,"
"cse)"
"builtin.module(fly-canonicalize,fly-layout-lowering,fly-canonicalize,convert-fly-to-rocdl,canonicalize,cse)"
)


Expand Down Expand Up @@ -85,9 +79,8 @@ def test_layout_dynamic_types():
with Location.unknown(ctx):
module = Module.create()
i32 = IntegerType.get_signless(32)
idx = IndexType.get()
with InsertionPoint(module.body):
f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [idx]))
f = func.FuncOp("dynamic_layout", FunctionType.get([i32] * 4, [i32]))
entry = f.add_entry_block()
with InsertionPoint(entry):
dim0, dim1, stride0, stride1 = entry.arguments
Expand All @@ -98,7 +91,7 @@ def test_layout_dynamic_types():
layout = fx.make_layout(shape, stride)
sz = fx.size(layout)
sc = fx.get_scalar(sz)
func.ReturnOp([arith.IndexCastOp(idx, sc).result])
func.ReturnOp([sc.ir_value()])

pm = PassManager.parse(FLY_PIPELINE, ctx)
pm.run(module.operation)
Expand Down Expand Up @@ -138,9 +131,8 @@ def test_mixed_static_dynamic():
with Location.unknown(ctx):
module = Module.create()
i32 = IntegerType.get_signless(32)
idx = IndexType.get()
with InsertionPoint(module.body):
f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [idx]))
f = func.FuncOp("mixed_layout", FunctionType.get([i32, i32], [i32]))
entry = f.add_entry_block()
with InsertionPoint(entry):
runtime_extent, runtime_stride = entry.arguments
Expand All @@ -152,8 +144,8 @@ def test_mixed_static_dynamic():
stride = fx.make_stride(c16, runtime_stride)
layout = fx.make_layout(shape, stride)
sz = fx.size(layout)
sc = fx.get_scalar(sz)
func.ReturnOp([arith.IndexCastOp(idx, sc).result])
sc = fx.get_scalar(sz).ir_value()
func.ReturnOp([sc])

pm = PassManager.parse(FLY_PIPELINE, ctx)
pm.run(module.operation)
Expand Down
Loading