Skip to content

[Enh] Ensure type closure for primitive func#552

Open
sjfeng1999 wants to merge 2 commits into
mainfrom
pr/enh-type-closure
Open

[Enh] Ensure type closure for primitive func#552
sjfeng1999 wants to merge 2 commits into
mainfrom
pr/enh-type-closure

Conversation

@sjfeng1999
Copy link
Copy Markdown
Collaborator

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings May 21, 2026 10:43
@sjfeng1999
Copy link
Copy Markdown
Collaborator Author

sjfeng1999 commented May 21, 2026

may need update the same aiter kernels simultaneously.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR updates FlyDSL’s Python expression-layer wrappers to better preserve (“close over”) DSL types (e.g., Numeric, Vector) when calling primitive ops, and adjusts affected tests/kernels to use the new return-value behavior.

Changes:

  • Wrap primitive-op scalar results back into Numeric types (and propagate this through helpers like get_scalar, get_leaves, ptr_load, memref_load).
  • Centralize memref_load_vec to return a Vector with shape/dtype metadata, and simplify Tensor.load() accordingly.
  • Update unit tests and a few kernel utilities to align with the revised scalar/vector return types and pipeline string formatting.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
tests/unit/test_static_vs_dynamic.py Adjusts dynamic-layout tests to return i32 scalars directly and simplifies pipeline string formatting.
tests/unit/test_layout_algebra.py Updates dynamic test functions to return i32 scalars from fx.get_scalar(...).ir_value() and simplifies pipeline string formatting.
python/flydsl/expr/typing.py Updates IntTuple reconstruction and makes Tensor.load() rely on the updated memref_load_vec wrapper.
python/flydsl/expr/primitive.py Introduces numeric re-wrapping helper and applies it across several primitive ops; moves vector wrapping into memref_load_vec.
python/flydsl/expr/math.py Extends traced math-op wrapping to preserve DSL closure for both Numeric and Vector inputs.
kernels/silu_and_mul_fq.py Simplifies scale-offset computation by relying on the updated fx.get_scalar behavior.
kernels/mfma_preshuffle_pipeline.py Updates crd2idx helper to unwrap int-tuples and cast to index type using the new scalar typing behavior.
kernels/layout_utils.py Updates dynamic-layout crd2idx fallback to unwrap/cast through fx.get_scalar(...).ir_value().

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

scalar = _arith.IndexCastOp(T.index, scalar).result
return scalar
"""crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple)."""
scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value()
Comment on lines +253 to +256
if not isinstance(value, ir.Value):
return value
if isinstance(value, Numeric):
return value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants