[Benchmark][Mamba2] Modernise SSD benchmark suite: parametrize, mamba baseline, model-scale params #815
stelladuyx wants to merge 15 commits into tile-ai:main
Conversation
… baseline, model-scale params

- Replace `@fixture` decorator with `@pytest.mark.parametrize` + typed signatures in `bench_da_cumsum_fwd`, `bench_ssd_chunk_scan_fwd`, `bench_ssd_chunk_state_fwd`, `bench_ssd_decode`, `bench_ssd_state_passing_fwd`
- Add `mamba_ssm.ops.triton` baseline comparison to `bench_ssd_chunk_scan_fwd`, `bench_ssd_chunk_state_fwd`, `bench_ssd_state_passing_fwd` (with ImportError fallback to torch-ref); `da_cumsum` and `decode` have no mamba equivalent
- Add `_to_mamba_inputs()` layout converter for chunk_scan (TileOPs chunked layout -> mamba flat-sequence layout)
- Add Mamba2 model-scale benchmark params to `bench_ssd_chunk_scan_fwd` and `bench_ssd_decode` covering 130M/370M/780M/1.3B/2.7B with latency/serving/throughput/long-context cases
- Fix `BenchmarkReport.record()` calls: replace string literals with the op object and normalise `tag=` to `"torch-ref"` / `"mamba"`

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Code Review
This pull request refactors several SSD (Mamba2) operator benchmarks by replacing custom fixtures with pytest.mark.parametrize and adding comprehensive model-scale parameters (ranging from 130M to 2.7B configurations). It also introduces baseline comparisons against mamba_ssm for the chunk scan, chunk state, and state passing operators. The review feedback identifies several critical issues where tensor layouts (permutations) and input mappings in the new benchmark code do not match the expectations of the mamba_ssm library, which would lead to incorrect baseline results.
benchmarks/ops/bench_ssd_chunk_scan_fwd.py (19-31)
The `_to_mamba_inputs` layout converter has several critical issues that will lead to incorrect baseline results or runtime errors:
- Incorrect `dt_m` and `dA_m` mapping: the `dt` input is ignored, and `dA_cumsum` is used for both `dt_m` and `dA_m`. In mamba_ssm, `dt` is the element-wise discretized decay, while `dA_cumsum` is the prefix sum. They are distinct.
- Missing permutations: mamba_ssm expects `dt` and `dA_cumsum` in `(B, C, H, L)` layout. In this benchmark, `dt` is `(b, c, L, h)` and `dA_cumsum` is `(b, h, c, L)`. Both need permutations.
- Incorrect state layout: `prev_states` is `(b, c, h, n, p)`. The permutation `(0, 1, 2, 4, 3)` changes it to `(b, c, h, p, n)`, but mamba_ssm expects `(B, C, H, N, P)`.
```python
def _to_mamba_inputs(x, cb, dA_cumsum, C, prev_states, dt):
    b, c, L, h, p = x.shape
    n = C.shape[-1]
    x_m = x.reshape(b, c * L, h, p).contiguous()
    cb_m = cb.contiguous()
    # TileOPs dt: (b, c, L, h) -> mamba: (b, c, h, L)
    dt_m = dt.permute(0, 1, 3, 2).contiguous()
    # TileOPs dA_cumsum: (b, h, c, L) -> mamba: (b, c, h, L)
    dA_m = dA_cumsum.permute(0, 2, 1, 3).contiguous()
    C_m = C.reshape(b, c * L, h, n).contiguous()
    # TileOPs prev_states: (b, c, h, n, p) -> mamba: (b, c, h, n, p)
    states_m = prev_states.contiguous()
    return x_m, cb_m, dt_m, dA_m, C_m, states_m
```
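Permute tuples like the ones quoted in this review are easy to mis-derive by hand. A small pure-Python helper (hypothetical, not part of the PR) can derive a permutation from two named axis layouts and double-check the tuples above:

```python
def axes_perm(src: str, dst: str) -> tuple:
    """Return p such that tensor.permute(*p) maps layout `src` to `dst`.

    Each character names one axis; both strings must use the same axes.
    """
    assert sorted(src) == sorted(dst), "layouts must contain the same axes"
    return tuple(src.index(ax) for ax in dst)

# dt: TileOPs (b, c, L, h) -> (b, c, h, L)
assert axes_perm("bcLh", "bchL") == (0, 1, 3, 2)
# dA_cumsum: TileOPs (b, h, c, L) -> (b, c, h, L)
assert axes_perm("bhcL", "bchL") == (0, 2, 1, 3)
# dt: TileOPs (b, c, L, h) -> (b, h, c, L)
assert axes_perm("bcLh", "bhcL") == (0, 3, 1, 2)
# prev_states: (b, c, h, n, p) -> (b, c, h, p, n)
assert axes_perm("bchnp", "bchpn") == (0, 1, 2, 4, 3)
```

Checking the tuples against the target layout name this way avoids the index-counting mistakes the review flags.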
benchmarks/ops/bench_ssd_chunk_state_fwd.py (86-93)
The mamba_ssm baseline comparison will produce incorrect results because the `dt` and `dA_cumsum` tensors are passed in the wrong layout. TileOPs uses `(b, h, c, L)` for these tensors, while mamba_ssm expects `(B, C, H, L)`. A permutation is required.

```python
def mamba_fwd():
    return _chunk_state_fwd(
        Bmat.contiguous(),
        x.contiguous(),
        # TileOPs (b, h, c, L) -> mamba (b, c, h, L)
        dt.permute(0, 2, 1, 3).contiguous(),
        dA_cumsum.permute(0, 2, 1, 3).contiguous(),
        seq_idx=seq_idx,
    )
```
benchmarks/ops/bench_ssd_state_passing_fwd.py (68-73)
The `dA_chunk_cumsum` tensor needs to be permuted before being passed to the mamba_ssm baseline. TileOPs provides it as `(b, h, c)`, but mamba_ssm expects `(B, C, H)`.

```python
def mamba_fwd():
    return _state_passing_fwd(
        states.contiguous(),
        # TileOPs (b, h, c) -> mamba (b, c, h)
        dA_chunk_cumsum.permute(0, 2, 1).contiguous(),
        initial_states=initial_states.contiguous(),
    )
```
…ine paths

bench_ssd_chunk_scan_fwd: `_to_mamba_inputs` had two layout bugs:
- `dt` was not permuted: TileOPs (b,c,L,h) -> mamba (b,h,c,L) requires `permute(0,3,1,2)`
- `dA_cumsum` was incorrectly used for both `dt_m` and `dA_m`; `dA_cumsum` (b,h,c,L) already matches the mamba layout and needs no permutation
- `prev_states` `permute(0,1,2,4,3)` correctly maps (b,c,h,n,p) -> (b,c,h,p,n) (restored)

bench_ssd_chunk_state_fwd / bench_ssd_state_passing_fwd: `dt`, `dA_cumsum`, and `dA_chunk_cumsum` are generated in mamba-compatible layouts by `gen_inputs()` and require no permutation — no kernel-call changes needed.

All three files: remove stale imports from tests/ (`ssd_chunk_scan_fwd_ref`, `ssd_chunk_state_fwd_ref`, `ssd_state_passing_fwd_ref` are now defined locally in the bench files); import `SsdChunkScanFwdTest`/`SsdChunkStateFwdTest`/`SsdStatePassingFwdTest` from workloads/ only.

Update the torch-ref baseline in chunk_scan to call the locally-defined `ssd_chunk_scan_fwd_torch`.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Ibuki-wind left a comment
Summary
Model-scale params, mamba_ssm baseline integration, and the parametrize migration are all good ideas. But there are correctness and hygiene issues that need fixing before merge.
Blocking
- Dead `tests.*` imports in two benchmark files — `bench_da_cumsum_fwd.py` and `bench_ssd_decode.py` both add `from tests.ops.*` imports where every imported name is immediately shadowed (by local function definitions or the `workloads.*` import on the next line). These imports are 100% dead code, but they re-introduce the `tests.*` coupling that #787 eliminated. Remove them entirely.
- Missing tensor layout permutations in mamba_ssm baselines — `bench_ssd_chunk_state_fwd.py` and `bench_ssd_state_passing_fwd.py` pass `dt`/`dA_cumsum`/`dA_chunk_cumsum` to mamba_ssm with only `.contiguous()`, no layout permutation. But the local reference functions in the same files explicitly `permute(0, 2, 1, 3)` these tensors from TileOPs `(b,h,c,L)` to computation `(b,c,h,L)` layout. If mamba_ssm expects the same computation layout, the baselines produce silently wrong benchmark numbers. The chunk_scan file has a proper `_to_mamba_inputs` converter — the other two files need equivalent treatment.
Non-blocking
- Unused fixture imports — `DaCumsumFwdFixture` and `SsdDecodeFixture` are still imported but no longer used after the `@fixture` → `@pytest.mark.parametrize` migration.
Action items
- Delete the `from tests.ops.*` import blocks in `bench_da_cumsum_fwd.py` and `bench_ssd_decode.py`.
- Clean up `DaCumsumFwdFixture` and `SsdDecodeFixture` from their respective workloads imports.
- Add layout permutations (or a shared `_to_mamba_inputs`-style converter) to the mamba baselines in `bench_ssd_chunk_state_fwd.py` and `bench_ssd_state_passing_fwd.py`. Document the expected mamba_ssm layout in a comment, like chunk_scan does.
- Verify the mamba baselines actually produce correct results by running them against the torch-ref baseline on at least one param set.
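One way to gain confidence in a layout permutation before wiring it into a baseline is a round-trip check: compute a reference in the original layout, then compute the same thing in the permuted layout and permute back. The sketch below (hypothetical, not from the PR) uses NumPy and a plain cumulative sum along the `L` axis as a stand-in for the kernel; the shapes and the `(0, 2, 1, 3)` permutation match the `(b, h, c, L)` → `(b, c, h, L)` case discussed above:

```python
import numpy as np

rng = np.random.default_rng(0)
dA = rng.standard_normal((2, 3, 4, 5))   # TileOPs layout (b, h, c, L)

# Reference: stand-in op (cumsum over L) in the original layout.
ref = np.cumsum(dA, axis=-1)

# Same op in the permuted (b, c, h, L) layout, then permuted back.
# (0, 2, 1, 3) swaps axes 1 and 2, so it is its own inverse.
permuted = np.transpose(dA, (0, 2, 1, 3))
out = np.transpose(np.cumsum(permuted, axis=-1), (0, 2, 1, 3))

assert out.shape == ref.shape
assert np.allclose(ref, out)
```

For the real benchmark the stand-in op would be replaced by the torch-ref function and the mamba_ssm kernel, compared with an appropriate tolerance.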
…r layout permutations
…rface

- Kernel: seqlen-fused x/C/out [B,S,H,P]/[B,S,G,N], group-owned cb [B,C,G,L,L], dt/dA_cumsum [B,H,C,L], prev_states [B,C,H,P,N] (P before N)
- Op: add n_groups param, update docstrings
- Workload: gen_inputs uses official layouts, fixture includes n_groups
- Test: new reference function handling group-owned tensors and P-before-N states
- Benchmark: remove _to_mamba_inputs (inputs already in official layout), add n_groups throughout, fix mamba baseline call signature
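The "official" layouts listed in the commit message relate through the chunking: the fused sequence axis `S` factors as `C` chunks of length `L`. A shape-only sketch (names follow the commit message; the concrete sizes are arbitrary placeholders, not real model dims) makes the relationships explicit:

```python
import numpy as np

# Arbitrary placeholder sizes (hypothetical, for shape illustration only).
B, H, P, G, N = 2, 8, 64, 2, 16   # batch, heads, head dim, groups, state dim
C, L = 4, 32                      # number of chunks, chunk length
S = C * L                         # seqlen-fused axis

x = np.zeros((B, S, H, P))             # x / out: [B, S, H, P]
C_mat = np.zeros((B, S, G, N))         # C: [B, S, G, N]
cb = np.zeros((B, C, G, L, L))         # group-owned cb: [B, C, G, L, L]
dt = np.zeros((B, H, C, L))            # dt / dA_cumsum: [B, H, C, L]
prev_states = np.zeros((B, C, H, P, N))  # P before N

# The fused sequence axis must decompose exactly into chunks.
assert x.shape[1] == C * L
```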
…a_inputs, restore in-place state comment, add model-scale params to state_passing bench
…/test files

Upstream refactored the individual bench_ssd_*.py and test_ssd_chunk_scan_fwd.py into unified bench_mamba.py and test_mamba.py. Merge our PR tile-ai#815 changes into those consolidated files:
- bench_mamba.py: add mamba_ssm Triton baselines (chunk_scan/chunk_state/state_passing), model-scale parametrize params for all ops, n_groups-aware memory calculation for chunk_scan, fix decode n_groups=1 (official default)
- test_mamba.py: replace ssd_chunk_scan_fwd_torch with the official-interface ssd_chunk_scan_fwd_ref (n_groups-aware), pass n_groups through the test fixture
- Remove individual bench/test files superseded by the consolidated versions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…e-ai#939)

Upstream tile-ai#939 consolidated all mamba workloads from workloads/ops/ssd_*.py into a single workloads/mamba.py and moved other workloads out of workloads/ops/ into a flat workloads/ layout.
- Apply n_groups-aware SsdChunkScanFwdFixture/Test to workloads/mamba.py: fixture params add n_groups, gen_inputs uses the official [B,S,H,P] layout with group-separated cb [B,C,G,L,L] and C [B,S,G,N] tensors
- Remove workloads/ops/ssd_chunk_scan_fwd.py (accepted upstream deletion, content merged into workloads/mamba.py)
- bench_mamba.py and test_mamba.py imports auto-resolved to workloads.mamba

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…sing (tile-ai#940) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
- Replace the `@fixture` decorator with `@pytest.mark.parametrize` + explicit typed `-> None` signatures across all 5 files, matching the style of `bench_gqa_decode_paged.py`
- Add `mamba_ssm.ops.triton` baseline comparison to `bench_ssd_chunk_scan_fwd`, `bench_ssd_chunk_state_fwd`, `bench_ssd_state_passing_fwd` (falls back to `torch-ref` when mamba-ssm is not installed); `da_cumsum` and `ssd_decode` have no mamba-ssm equivalent
- Add `_to_mamba_inputs()` layout converter for chunk_scan (TileOPs chunked → mamba flat-sequence layout)
- Add Mamba2 model-scale benchmark params to `bench_ssd_chunk_scan_fwd` and `bench_ssd_decode`: 130M/370M/780M/1.3B/2.7B with latency/serving/throughput/long-context cases
- Fix `BenchmarkReport.record()`: replace the string-literal first arg with `op`; normalise `tag=` to `"torch-ref"` / `"mamba"`

Test plan
- `--collect-only` → 55 tests collected, no errors
- With mamba-ssm installed, `tag="mamba"` results appear; without it, falls back to `tag="torch-ref"`
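The ImportError fallback the test plan exercises can be sketched as a module-level guard (a minimal sketch of the pattern; the exact import path inside mamba_ssm is an assumption, and the `baseline_tag` helper is hypothetical):

```python
# Probe for the optional mamba_ssm baseline at import time; fall back
# to the torch reference when the package is not installed.
try:
    import mamba_ssm.ops.triton  # assumed import path; noqa: F401
    HAS_MAMBA = True
except ImportError:
    HAS_MAMBA = False


def baseline_tag() -> str:
    """Tag recorded in the benchmark report for the baseline actually run."""
    return "mamba" if HAS_MAMBA else "torch-ref"
```

The benchmark body can then branch on `HAS_MAMBA` to pick the baseline callable while always recording a well-defined `tag=`.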