dialects: (linalg) Add initial linalg tiling pass#5842
Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #5842 +/- ##
==========================================
+ Coverage 86.78% 86.81% +0.03%
==========================================
Files 424 426 +2
Lines 63822 63960 +138
Branches 7319 7329 +10
==========================================
+ Hits 55385 55524 +139
+ Misses 6868 6866 -2
- Partials 1569 1570 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
53134e8 to
aaf8c70
Compare
superlopuh
left a comment
There was a problem hiding this comment.
Looking great! I think the file structure is perfect, let's now iterate on the helpers.
aaf8c70 to
aba7946
Compare
| def _analyze_operand_tile_info( | ||
| indexing_map: AffineMap, | ||
| source_type: MemRefType[Attribute], | ||
| tile_sizes: Sequence[int], | ||
| ) -> OperandTileInfo: | ||
| """ | ||
| Analyze how one operand should be sliced for each tile, and returned an `OperandTileInfo`. | ||
| """ | ||
|
|
||
| source_shape = source_type.get_shape() | ||
| loop_dims = tuple( | ||
| cast(AffineDimExpr, expr).position for expr in indexing_map.results | ||
| ) | ||
| result_shape = tuple( | ||
| tile_sizes[loop_dim] | ||
| if tile_sizes[loop_dim] != 0 | ||
| else source_shape[result_index] | ||
| for result_index, loop_dim in enumerate(loop_dims) | ||
| ) | ||
| return OperandTileInfo(source_type, loop_dims, result_shape) | ||
|
|
||
|
|
||
| def _analyze_generic_op( | ||
| op: linalg.ops.GenericOp, | ||
| tile_sizes: tuple[int, ...], | ||
| ) -> TilingPlan: |
There was a problem hiding this comment.
let's move these two methods to be public static methods on OperandTileInfo, add pytests for them specifically, and merge them in a separate PR?
There was a problem hiding this comment.
Sounds good, I’ll split that into a separate PR with pytests!
There was a problem hiding this comment.
Hi, for _analyze_generic_op, would you prefer it on OperandTileInfo as well, or would TilingPlan.analyze_generic_op be more appropriate since it returns a TilingPlan?
Sorry if I misunderstood anything.
There was a problem hiding this comment.
To me it feels natural to have the function that returns TilingPlan as static method on TilingPlan and the method that returns OperandTileInfo on OperandTileInfo.
There was a problem hiding this comment.
Done, thanks! analyze_generic_op is now on TilingPlan, and I added tests.
One more question,
I have this on a separate local branch for now. Since the current tiling PR has not merged yet, opening it as a PR would make it a stacked PR on my branch.
Does that sound like the right workflow, or would you prefer I wait until this PR lands and then open the helper PR from main?
Thank you.
There was a problem hiding this comment.
I would flip the order of the branches. So first merge your local branch into this one, then create a new branch from main where you take only these changes. Depending on how comfortable you are with git, I would either cherry-pick the changes you want or just copy/paste the code by hand, it shouldn't take more than 5 minutes. In either case the first step is to merge your changes directly to this branch first to contain the final destination you're working towards.
There was a problem hiding this comment.
Thanks for the clear explanation! I followed that workflow.
I merged the helper changes back into this branch, then opened a clean helper-only PR from main: #6125
27fd734 to
f0f3844
Compare
| raise PassFailedException(str(e)) from e | ||
|
|
||
| if not plan.tiled_dims: | ||
| return False |
There was a problem hiding this comment.
can you please add a test for this case?
|
|
||
| zero = arith.ConstantOp(IntegerAttr.from_index_int_value(0)) |
There was a problem hiding this comment.
This is kind of a nit but would be quite nice to have a single index type and reuse it both for IntegerAttr construction and the block argument type below
| zero = arith.ConstantOp(IntegerAttr.from_index_int_value(0)) | |
| index = IndexType() | |
| zero = arith.ConstantOp(IntegerAttr(0, index)) |
| strides, | ||
| ) | ||
| except ValueError as e: | ||
| raise PassFailedException(str(e)) from e |
superlopuh
left a comment
There was a problem hiding this comment.
Beautiful! I think it would be niec to do the index thing, then I think it's ready to merge
No description provided.