Summary
FlyDSL currently does not have a direct high-level op/helper for loading scalar or vector values from a global tensor with a computed offset. Kernel code that needs a simple global memory load has to spell out lower-level plumbing, which is verbose and easy to get subtly wrong.
A first-class global_load op/helper would make kernels easier to write and would keep pointer/addressing details in one place.
Current Implementations
Today we have at least two ways to express this pattern.
1. Extract raw pointer, GEP, then LLVM load
This is the direct low-level approach:
    def _extract_global_ptr(tensor):
        from flydsl._mlir.dialects import fly as _fly

        # Unwrap a FlyDSL tensor wrapper to its underlying ir.Value if needed.
        raw = tensor.ir_value() if hasattr(tensor, "ir_value") and not isinstance(tensor, ir.Value) else tensor
        # !llvm.ptr<1>: address space 1, used here for global memory.
        ptr_type = ir.Type.parse("!llvm.ptr<1>")
        return _fly.extract_aligned_pointer_as_index(ptr_type, raw)

    def _global_load_i64x2(global_ptr, byte_offset_i64):
        # Byte-addressed GEP: elem_type is i8, so the offset is counted in bytes.
        ptr = buffer_ops.get_element_ptr(global_ptr, byte_offset=fx.Int64(byte_offset_i64), elem_type=T.i8)
        return llvm.LoadOp(T.i64x2, ptr, alignment=16).result

    def _global_load_i32(global_ptr, elem_offset_i32):
        # Scale the i32 element offset to bytes (4 bytes per element).
        byte_offset_i64 = fx.Int64(elem_offset_i32) * fx.Int64(4)
        ptr = buffer_ops.get_element_ptr(global_ptr, byte_offset=byte_offset_i64, elem_type=T.i8)
        return llvm.LoadOp(T.i32, ptr, alignment=4).result
Pros:
- Emits exactly the intended pointer arithmetic and llvm.load.
- Works for raw byte offsets and arbitrary result types.
Cons:
- Requires importing/using llvm directly from kernel code.
- Requires manual pointer extraction and byte-offset handling.
- Bypasses the higher-level FlyDSL tensor/view vocabulary.
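The manual byte-offset handling mentioned above comes down to a few unit conversions. The following is a minimal sketch in plain Python with no FlyDSL imports; the helper names elems_to_bytes and check_aligned are hypothetical and only illustrate the arithmetic, and the alignment check assumes the base pointer itself is suitably aligned.

```python
def elems_to_bytes(elem_offset: int, elem_size_bytes: int) -> int:
    """Convert an offset counted in elements to an offset counted in bytes."""
    return elem_offset * elem_size_bytes


def check_aligned(byte_offset: int, alignment: int) -> bool:
    """Return True if byte_offset is a multiple of alignment (a power of two)."""
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError("alignment must be a positive power of two")
    return byte_offset % alignment == 0


# An i32 element offset of 5 is byte offset 20, matching the `* 4` scaling
# performed in _global_load_i32.
i32_byte_offset = elems_to_bytes(5, 4)

# A dword offset of 6 is byte offset 24: fine for a 4-byte aligned load, but
# not a valid address for a load that claims alignment=16.
dword_byte_offset = elems_to_bytes(6, 4)
ok_for_i32 = check_aligned(dword_byte_offset, 4)
ok_for_i64x2 = check_aligned(dword_byte_offset, 16)
```

Mixing up which unit an offset is in changes the address by a factor of the element size, which is the kind of mistake a single helper could rule out.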
2. Recast tensor iterator, make a Fly view, then memref_load_vec
This avoids direct llvm.LoadOp in kernel code:
    def _recast_tensor_iter(tensor, elem_type):
        # Rebuild the pointer type with a new element type, keeping the
        # source address space and alignment.
        src_iter = fx.get_iter(tensor)
        src_ptr_type = fx.PointerType(src_iter.type)
        ptr_type = fx.PointerType.get(
            elem_ty=elem_type,
            address_space=src_ptr_type.address_space,
            alignment=src_ptr_type.alignment,
        )
        return fx.recast_iter(ptr_type, src_iter)

    def _global_load_i64x2(tensor, byte_offset_i64):
        # Byte-addressed pointer: the offset is counted in i8 elements (bytes).
        ptr = fx.add_offset(_recast_tensor_iter(tensor, T.i8), fx.make_int_tuple(fx.Int64(byte_offset_i64)))
        # View 16 contiguous bytes, load them, and reinterpret as 2 x i64.
        view = fx.Tensor(fx.make_view(ptr, fx.make_layout((16,), (1,))))
        raw = fx.memref_load_vec(view)
        return vector.bitcast(T.i64x2, raw)

    def _global_load_i32(tensor, elem_offset_i32):
        # Element-typed pointer: the offset is counted in i32 elements.
        ptr = fx.add_offset(_recast_tensor_iter(tensor, T.i32), fx.make_int_tuple(fx.Int32(elem_offset_i32)))
        # View a single element, load it as a 1-wide vector, and extract lane 0.
        view = fx.Tensor(fx.make_view(ptr, fx.make_layout((1,), (1,))))
        raw = fx.memref_load_vec(view)
        return vector.extract(raw, static_position=[0], dynamic_position=[])
Pros:
- Stays within FlyDSL iterator/view/memref abstractions.
- Avoids direct llvm.LoadOp in kernel code.
- Can be lowered through existing Fly memref/vector-load machinery.
Cons:
- Still very verbose for a simple global load.
- Requires callers to know when to recast to byte-addressed i8 vs element-typed pointers.
- Requires manually constructing one-off layouts and extracting/bitcasting results.
vector.load cannot be used directly on kernel fx.Tensor arguments because they are !fly.memref, not standard MLIR memref.
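The bitcast step in approach 2 relies on 16 loaded bytes reinterpreting cleanly as two 64-bit lanes. The following host-side illustration uses only Python's struct module and assumes little-endian byte order; it is not FlyDSL code, just a small model of what the load-then-bitcast sequence is meant to produce.

```python
import struct

# Stand-in for a slice of global memory: two little-endian i64 values.
memory = struct.pack("<2q", 0x0102030405060708, -1)

# "Load" 16 bytes as a vector of 16 x i8, as the (16,):(1,) view does.
raw_bytes = struct.unpack_from("<16b", memory, 0)

# "Bitcast" the byte vector to 2 x i64 by reinterpreting the same 16 bytes.
as_i64x2 = struct.unpack("<2q", struct.pack("<16b", *raw_bytes))
```

The round trip returns the two original values, which is why the byte-vector load followed by a bitcast is equivalent to a direct two-lane i64 load under this byte-order assumption.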
Proposal
Add a first-class FlyDSL global load op/helper, for example one or both of:
    fx.global_load(result_type, tensor, byte_offset, *, alignment=None)
    fx.global_load_elem(result_type, tensor, elem_offset, *, alignment=None)
or a lower-level dialect op such as:
    fly.global_load %tensor[%offset] : !fly.memref<...> -> vector<...>
The helper/op should support:
- Scalar loads, e.g. i32, f32.
- Vector loads, e.g. vector<2xi64>, vector<16xi8>.
- Byte offsets for packed/raw access patterns.
- Element offsets for typed access patterns.
- Optional alignment and cache modifier metadata if useful.
- Lowering to the appropriate LLVM/global memory load without requiring kernel authors to manually extract pointers or build temporary Fly views.
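To make the intended semantics of the requirements above concrete, here is a host-side reference model in plain Python. It is a sketch under stated assumptions, not a proposed implementation: the names global_load_model and global_load_elem_model are hypothetical, memory is a bytes object standing in for global memory, byte order is assumed little-endian, and defaulting the alignment to the full access size is an assumption rather than something this proposal specifies.

```python
import struct

# struct format code and size in bytes for a few element types (illustrative subset).
_ELEM = {"i8": ("b", 1), "i32": ("i", 4), "i64": ("q", 8), "f32": ("f", 4)}


def global_load_model(memory, elem, byte_offset, *, lanes=1, alignment=None):
    """Read `lanes` values of type `elem` starting at `byte_offset`."""
    code, size = _ELEM[elem]
    # Assumption: default to the size of the whole access, e.g. 16 for 2 x i64.
    if alignment is None:
        alignment = size * lanes
    if byte_offset % alignment:
        raise ValueError(f"byte offset {byte_offset} is not {alignment}-byte aligned")
    values = struct.unpack_from(f"<{lanes}{code}", memory, byte_offset)
    # Scalar loads return a bare value; vector loads return a tuple of lanes.
    return values[0] if lanes == 1 else values


def global_load_elem_model(memory, elem, elem_offset, *, lanes=1, alignment=None):
    """Same load, but the offset is counted in elements; scaling happens here."""
    _, size = _ELEM[elem]
    return global_load_model(memory, elem, elem_offset * size,
                             lanes=lanes, alignment=alignment)


memory = struct.pack("<8i", *range(10, 18))  # eight i32 values: 10..17
scalar = global_load_elem_model(memory, "i32", 3)        # typed access, element 3
vector = global_load_model(memory, "i64", 16, lanes=2)   # raw access, bytes 16..31
```

Keeping the element-to-byte scaling inside the element-offset entry point means callers never multiply by the element size themselves, which is the main point of offering both offset flavors.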
Motivation
Paged attention kernels currently need this pattern for K/V cache reads and metadata reads. The low-level spelling makes code harder to review and increases the chance of unit mistakes between byte offsets, dword offsets, and element offsets.
A dedicated op/helper would make the intent obvious:
    k2 = fx.global_load(T.i64x2, key_cache_ptr, byte_offset=ka_dw * fx.Int64(4), alignment=16)
    context_len = fx.global_load_elem(T.i32, context_lengths_ptr, elem_offset=batch_idx, alignment=4)
This would also make it easier to consistently tune lowering behavior in one place later.