49 changes: 49 additions & 0 deletions CLAUDE.md
@@ -11,8 +11,57 @@ Tilus caches generated kernels. During development, set the cache directory via
- To inspect the CUDA kernel generated for a specific script, delete the cache directory, run the program, and check the newly generated `source.cu`.
- Use `debug_schedule=dict(...)` to pin a specific schedule, so only that single configuration is compiled.
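
  For example (the dict keys here are hypothetical; real keys must match the schedule space your script declares):

  ```python
  # Hypothetical schedule keys, for illustration only.
  my_script(debug_schedule=dict(block_m=128, block_n=64))
  ```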

## Compilation Pipeline

`drivers.py:build_program` orchestrates the full compilation:
1. **Verify** — `ir.tools.verify(prog)`
2. **Optimize (Tilus IR)** — `optimize_program` applies `get_default_passes()` from `transforms/__init__.py`
3. **Lower to Hidet IR** — `backends.codegen.generate_ir_module`
4. **Optimize (Hidet IR)** — `optimize_ir_module` applies Hidet-level passes
5. **Codegen** — Emit CUDA C source
6. **Compile** — Build `.so` via nvcc
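
A minimal sketch of that orchestration, using the names the steps above give; `emit_cuda_source` and `compile_with_nvcc` are placeholders for steps 5 and 6, not the real `drivers.py` signatures:

```python
# Sketch of the pass order only; placeholder names stand in for the
# actual codegen and compile entry points.
def build_program(prog):
    ir.tools.verify(prog)                      # 1. verify Tilus IR
    prog = optimize_program(prog)              # 2. Tilus IR passes (get_default_passes)
    ir_module = generate_ir_module(prog)       # 3. lower to Hidet IR
    ir_module = optimize_ir_module(ir_module)  # 4. Hidet IR passes
    source = emit_cuda_source(ir_module)       # 5. placeholder: emit CUDA C source
    return compile_with_nvcc(source)           # 6. placeholder: build the .so
```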

### Tilus IR Structure

- **Program** (`ir/prog.py`) — `frozendict[str, Function]`
- **Function** (`ir/func.py`) — `name`, `params`, `body: Stmt`, `metadata: Metadata`
- **Stmt** (`ir/stmt.py`) — Frozen dataclasses: `SeqStmt`, `ForStmt`, `IfStmt`, `WhileStmt`, `LetStmt`, `InstStmt`, `ThreadGroupStmt`, etc.
- **InstStmt** wraps an `Instruction` — this is how instructions appear in the statement tree
- **Instruction** (`ir/inst.py`) — `output: Optional[Tensor]`, `inputs: tuple[Tensor, ...]`, plus type-specific attributes
- **Functional instructions** — pure computations on tensors (e.g., `AddInst`, `CastInst`, `SliceRegisterInst`). Whether an instruction is functional is determined by an explicit allowlist, NOT by checking `output is None`.
- **Side-effecting instructions** — memory ops, synchronization, etc. (e.g., `StoreGlobalInst`, `SyncThreadsInst`). Must never be eliminated.
- **Tensor** — `RegisterTensor`, `SharedTensor`, `GlobalTensor`, `TMemoryTensor` — identity-based (frozen dataclass with `eq=False`)
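
To make the nesting concrete, a hedged sketch of how these pieces fit together (the `create()` arguments are assumptions, not the real signatures; see `tests/transforms/test_dead_code_elimination.py` for actual usage):

```python
# Illustrative nesting only; create() signatures are assumed.
add = AddInst.create(a, b)                         # functional: output is a RegisterTensor
store = StoreGlobalInst.create(gmem, add.output)   # side-effecting: must never be eliminated
body = SeqStmt((InstStmt(add), InstStmt(store)))   # instructions enter the tree via InstStmt
func = Function.create(name="kernel", params=params, body=body, metadata=metadata)
# A Program is then a frozendict[str, Function] keyed by function name.
```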

### Writing Passes

- Base class: `transforms/base.py:Pass` — override `process_function(func) -> Function`
- `IRRewriter` (`ir/functors/functor.py`) — visitor/rewriter pattern. Override `visit_*` methods. Uses identity-based memoization.
  - `visit_Instruction` handles all instruction types generically (rewrites output/inputs/attributes)
  - `visit_InstStmt` delegates to `visit(stmt.inst)` — if the instruction visitor returns `None`, the stmt becomes `SeqStmt(())`
  - For instruction-specific handling, define `visit_<InstructionClassName>` methods
- `IRVisitor` — read-only traversal (same dispatch, returns None)
- Register pass in `transforms/__init__.py:get_default_passes()` and export from `__init__.py`
- Tensors use identity (`is`) for equality since `eq=False`. The memo dict in IRFunctor keys on object identity for IR nodes.
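
Putting those conventions together, a minimal pass skeleton (`FooInst` is a made-up instruction type; the imports match the paths used elsewhere in this PR):

```python
from tilus.ir.func import Function
from tilus.ir.functors import IRRewriter
from tilus.transforms.base import Pass

class RemoveFooRewriter(IRRewriter):
    def visit_FooInst(self, inst):  # FooInst: hypothetical example type
        return None                 # None => enclosing InstStmt becomes SeqStmt(())

class RemoveFooPass(Pass):
    def process_function(self, func: Function) -> Function:
        return RemoveFooRewriter().visit(func)

def remove_foo_pass() -> Pass:
    return RemoveFooPass()
# Then register it in transforms/__init__.py:get_default_passes().
```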

### IRVisitor/IRRewriter Pitfalls

- **`visit_Expr` is a no-op in `IRVisitor`** — it does NOT descend into Hidet sub-expressions. To collect Vars from Hidet expressions, use `hidet.ir.tools.collect(expr, Var)` explicitly.
- **Memo prevents re-visiting** — once a node is visited, `IRFunctor.visit()` returns the cached result. If you need to process the same Expr from multiple contexts, don't rely on `visit_Expr` being called again. Instead, collect data directly (e.g., iterate `inst.attributes.values()` and call `hidet_collect` yourself).
- **Instruction attributes contain Hidet Exprs** — many instructions store Vars (e.g., barrier addresses, offsets) in dataclass fields beyond `output`/`inputs`. Access these via `inst.attributes` (a dict of all non-output/inputs fields). When analyzing Var usage, always scan attributes too.
- **`TensorItemValueStmt`/`TensorItemPtrStmt`** bind Hidet `Var`s to tensor values/pointers. They bridge the Tilus tensor world and the Hidet scalar expression world. When checking liveness, both the tensor and the bound Var must be considered.
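
For example, to find every Hidet `Var` an instruction references, scan its attributes explicitly instead of relying on `visit_Expr` (this mirrors `_collect_expr_vars` in the DCE pass below):

```python
from hidet.ir.expr import Expr, Var
from hidet.ir.tools import collect as hidet_collect

def vars_used_by(inst) -> set[int]:
    """Collect ids of Vars referenced by an instruction's Expr attributes."""
    used: set[int] = set()
    for value in inst.attributes.values():  # attributes = all non-output/inputs fields
        stack = [value]
        while stack:
            v = stack.pop()
            if isinstance(v, Expr):
                used.update(id(var) for var in hidet_collect(v, Var))
            elif isinstance(v, (tuple, list)):
                stack.extend(v)              # Exprs may hide inside tuples/lists
    return used
```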

### Testing Passes

- Build test IR directly using `Function.create(...)`, `SeqStmt(...)`, `InstStmt(inst)`, and instruction `create()` methods. See `tests/transforms/test_dead_code_elimination.py` for examples.
- Use `ir/tools/instruction_collector.py:collect_instructions(func)` to count instructions by type after a pass.
- For end-to-end testing, use `InstantiatedScript._jit_instance_for(...)` to get a `JitInstance`, then access `ji.transpiled_programs[0]` for the `Program`.
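
A hedged test sketch in the style of `tests/transforms/test_dead_code_elimination.py` (the fixture and the shape of `collect_instructions`' return value are assumptions):

```python
# Sketch: a function whose only functional instruction is never consumed
# should come out of DCE without it.
from tilus.ir.tools.instruction_collector import collect_instructions
from tilus.transforms import dead_code_elimination_pass

def test_dce_removes_unused_add(func_with_dead_add):  # hypothetical fixture
    func = dead_code_elimination_pass().process_function(func_with_dead_add)
    remaining = collect_instructions(func)  # assumed: mapping keyed by inst type
    assert AddInst not in remaining
```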

## Debug Tips

### Dumping IR after each pass

Call `tilus.option.debug.dump_ir()` before running the kernel. The IR after each pass will be dumped into the cache directory under `ir/` (for Tilus IR passes) and `module/ir/` (for Hidet IR passes).
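
Minimal usage, with `run_kernel()` standing in for whatever invokes the script:

```python
import tilus

tilus.option.debug.dump_ir()  # enable per-pass IR dumps before the first run
run_kernel()                  # placeholder for the actual kernel invocation
# Then inspect <cache_dir>/ir/ (Tilus IR passes) and
# <cache_dir>/module/ir/ (Hidet IR passes).
```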

### Proxy fence required between `store_shared` and `tma.shared_to_global`

`self.store_shared(...)` writes to shared memory via the **generic proxy**, while `self.tma.shared_to_global(...)` reads from shared memory via the **async proxy**. A `fence.proxy.async.shared::cta` is required between them to ensure the generic proxy writes are visible to the async proxy. Without this fence, the TMA engine may read stale data.
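
Sketched as kernel code; `store_shared` and `tma.shared_to_global` are the methods named above, while `fence_proxy_async` is a hypothetical stand-in for however the fence is emitted:

```python
# Hypothetical fence call; only the PTX instruction it must produce
# (fence.proxy.async.shared::cta) is confirmed above.
self.store_shared(smem, regs)            # generic-proxy write to shared memory
self.fence_proxy_async()                 # make the write visible to the async proxy
self.tma.shared_to_global(gmem, smem)    # async-proxy read by the TMA engine
```
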
2 changes: 2 additions & 0 deletions python/tilus/transforms/__init__.py
@@ -16,6 +16,7 @@

from .base import Pass, PassContext, apply_transforms
from .bound_aware_simplify import bound_aware_simplify_pass
from .dead_code_elimination import dead_code_elimination_pass
from .declare_to_let import declare_to_let_pass
from .inject_print_instruction import inject_print_instruction_pass
from .layout_inference import layout_inference_pass
@@ -42,4 +43,5 @@ def get_default_passes() -> list[Pass]:
layout_inference_pass(),
bound_aware_simplify_pass(),
analyze_scalar_pass(),
dead_code_elimination_pass(),
]
244 changes: 244 additions & 0 deletions python/tilus/transforms/dead_code_elimination.py
@@ -0,0 +1,244 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Dead code elimination pass for Tilus IR.

Removes functional instructions whose output tensors are never consumed by any other instruction.
"""

from typing import Type

from hidet.ir.expr import Expr, Var
from hidet.ir.tools import collect as hidet_collect

from tilus.ir.func import Function
from tilus.ir.functors import IRRewriter, IRVisitor
from tilus.ir.inst import Instruction
from tilus.ir.instructions.cuda.clc import ClusterLaunchControlQueryResponseInst
from tilus.ir.instructions.cuda.ldmatrix import LoadMatrixInst
from tilus.ir.instructions.cuda.mapa import MapSharedAddrInst
from tilus.ir.instructions.cuda.mbarrier import AllocBarrierInst
from tilus.ir.instructions.cuda.mma_dot import DotInst
from tilus.ir.instructions.cuda.simt_dot import SimtDotInst
from tilus.ir.instructions.cuda.tcgen05 import Tcgen05LoadInst, Tcgen05SliceInst, Tcgen05ViewInst
from tilus.ir.instructions.generic import (
AllocateRegisterInst,
CastInst,
ElementwiseBinaryBaseInst,
ElementwiseUnaryBaseInst,
GlobalViewInst,
LoadGlobalGenericInst,
LoadGlobalInst,
LoadSharedGenericInst,
LoadSharedInst,
PermuteSharedInst,
ReduceInst,
RepeatInst,
RepeatInterleaveInst,
ReshapeSharedInst,
SliceGlobalInst,
SliceRegisterInst,
SliceSharedInst,
SqueezeInst,
TransposeInst,
UnsqueezeInst,
ViewInst,
WhereInst,
)
from tilus.ir.stmt import SeqStmt, Stmt, TensorItemPtrStmt, TensorItemValueStmt
from tilus.ir.tensor import Tensor
from tilus.transforms.base import Pass

# Functional instruction types: pure computations safe to eliminate if output is unused.
FUNCTIONAL_INST_TYPES: tuple[Type[Instruction], ...] = (
# Register tensor operations
AllocateRegisterInst,
SliceRegisterInst,
CastInst,
ElementwiseUnaryBaseInst, # covers NegInst, AbsInst, ClipInst, ElementwiseUnaryInst
ElementwiseBinaryBaseInst, # covers AddInst, SubInst, MulInst, DivInst, ModInst, ElementwiseBinaryInst
WhereInst,
RepeatInst,
RepeatInterleaveInst,
ReduceInst,
ViewInst,
SqueezeInst,
UnsqueezeInst,
TransposeInst,
# Load instructions
LoadGlobalInst,
LoadSharedInst,
LoadGlobalGenericInst,
LoadSharedGenericInst,
LoadMatrixInst,
Tcgen05LoadInst,
# Shared/Global tensor views
SliceGlobalInst,
SliceSharedInst,
ReshapeSharedInst,
PermuteSharedInst,
GlobalViewInst,
# TMemory views
Tcgen05SliceInst,
Tcgen05ViewInst,
# Other pure ops
DotInst,
SimtDotInst,
MapSharedAddrInst,
ClusterLaunchControlQueryResponseInst,
AllocBarrierInst,
)


def _is_functional(inst: Instruction) -> bool:
return isinstance(inst, FUNCTIONAL_INST_TYPES)


class UsedTensorCollector(IRVisitor):
"""
Collects all tensors that are "used" (consumed by a live instruction or statement).

An instruction's input tensors are used if:
- The instruction is side-effecting (not functional), OR
- The instruction is functional and its output tensor is itself used.

A TensorItemValueStmt/TensorItemPtrStmt's tensor is used only if the Var it
binds is referenced in some expression elsewhere in the function.

We iterate to a fixed point since a functional instruction's liveness
depends on whether its output is consumed by another live instruction.
"""

def __init__(self) -> None:
super().__init__()
self.used_tensors: set[int] = set() # set of id(tensor)
self.functional_insts: list[Instruction] = []
# Deferred: TensorItem stmts whose liveness depends on Var usage
self.tensor_item_stmts: list[TensorItemValueStmt | TensorItemPtrStmt] = []
# All Vars referenced in expressions (collected after traversal).
# We skip visiting the defining Var in visit_TensorItemValueStmt/PtrStmt,
# so a Var only appears here if it's referenced elsewhere.
self.expr_vars: set[int] = set() # set of id(Var)

def visit_Instruction(self, inst: Instruction) -> None:
if _is_functional(inst):
self.functional_insts.append(inst)
else:
# Side-effecting: all inputs are unconditionally used
for tensor in inst.inputs:
self.used_tensors.add(id(tensor))
# Collect Vars from all Expr-typed attributes so that TensorItemValueStmt
# vars referenced in instruction attributes are tracked.
for value in inst.attributes.values():
self._collect_expr_vars(value)

def _collect_expr_vars(self, value: Expr | tuple | list | object) -> None:
"""Recursively collect Vars from Expr-typed values (including inside tuples/lists)."""
if isinstance(value, Expr):
for var in hidet_collect(value, Var):
self.expr_vars.add(id(var))
elif isinstance(value, (tuple, list)):
for item in value:
self._collect_expr_vars(item)

def visit_TensorItemValueStmt(self, stmt: TensorItemValueStmt) -> None:
# Defer: only mark tensor as used if stmt.var is actually referenced.
# We skip visiting stmt.var here so it won't self-register in expr_vars.
self.tensor_item_stmts.append(stmt)
self.visit(stmt.tensor)

def visit_TensorItemPtrStmt(self, stmt: TensorItemPtrStmt) -> None:
# Defer: only mark tensor as used if stmt.ptr_var is actually referenced.
self.tensor_item_stmts.append(stmt)
self.visit(stmt.tensor)

def visit_Expr(self, expr: Expr) -> None:
# Collect all Vars referenced in Hidet expressions.
for var in hidet_collect(expr, Var):
self.expr_vars.add(id(var))

def _mark_used(self, tensor: Tensor) -> bool:
"""Mark a tensor as used. Returns True if it was newly added."""
tid = id(tensor)
if tid not in self.used_tensors:
self.used_tensors.add(tid)
return True
return False

def propagate(self) -> None:
"""Fixed-point propagation of tensor liveness."""
# Mark tensors from TensorItem stmts whose bound Var is referenced
for stmt in self.tensor_item_stmts:
bound_var = stmt.var if isinstance(stmt, TensorItemValueStmt) else stmt.ptr_var
if id(bound_var) in self.expr_vars:
self.used_tensors.add(id(stmt.tensor))

# Propagate through functional instruction chains
changed = True
while changed:
changed = False
for inst in self.functional_insts:
if inst.output is not None and id(inst.output) in self.used_tensors:
for tensor in inst.inputs:
if self._mark_used(tensor):
changed = True


class DeadCodeEliminator(IRRewriter):
"""Eliminates dead functional instructions and dead TensorItem stmts."""

def __init__(self, used_tensors: set[int]) -> None:
super().__init__()
self.used_tensors = used_tensors

def visit_Instruction(self, inst: Instruction) -> Instruction | None:
if _is_functional(inst) and inst.output is not None and id(inst.output) not in self.used_tensors:
return None
return super().visit_Instruction(inst)

def visit_TensorItemValueStmt(self, stmt: TensorItemValueStmt) -> Stmt:
if id(stmt.tensor) not in self.used_tensors:
return SeqStmt(())
return super().visit_TensorItemValueStmt(stmt)

def visit_TensorItemPtrStmt(self, stmt: TensorItemPtrStmt) -> Stmt:
if id(stmt.tensor) not in self.used_tensors:
return SeqStmt(())
return super().visit_TensorItemPtrStmt(stmt)


class DeadCodeEliminationPass(Pass):
def process_function(self, function: Function) -> Function:
# Pass 1: collect used tensors
collector = UsedTensorCollector()
collector.visit(function)
collector.propagate()

# Check if there's anything to eliminate
has_dead = any(
inst.output is not None and id(inst.output) not in collector.used_tensors
for inst in collector.functional_insts
) or any(id(stmt.tensor) not in collector.used_tensors for stmt in collector.tensor_item_stmts)

if not has_dead:
return function

# Pass 2: eliminate dead instructions
eliminator = DeadCodeEliminator(collector.used_tensors)
return eliminator.visit(function)


def dead_code_elimination_pass() -> Pass:
return DeadCodeEliminationPass()