
[Benchmark][Mamba2] Modernise SSD benchmark suite: parametrize, mamba baseline, model-scale params#815

Open
stelladuyx wants to merge 15 commits into tile-ai:main from stelladuyx:mamba

Conversation

Collaborator

@stelladuyx commented Apr 7, 2026

Summary

  • Replace the @Fixture decorator with @pytest.mark.parametrize plus explicit typed -> None signatures across all 5 files, matching the style of bench_gqa_decode_paged.py
  • Add mamba_ssm.ops.triton baseline comparison to bench_ssd_chunk_scan_fwd, bench_ssd_chunk_state_fwd, bench_ssd_state_passing_fwd (falls back to torch-ref when mamba-ssm is not installed); da_cumsum and ssd_decode have no mamba-ssm equivalent
  • Add _to_mamba_inputs() layout converter for chunk_scan (TileOPs chunked → mamba flat-sequence layout)
  • Add Mamba2 model-scale params to bench_ssd_chunk_scan_fwd and bench_ssd_decode: 130M/370M/780M/1.3B/2.7B with latency/serving/throughput/long-context cases
  • Fix BenchmarkReport.record(): replace string literal first arg with op; normalise tag= to "torch-ref" / "mamba"

Test plan

  • --collect-only → 55 tests collected, no errors
  • All pre-commit hooks pass
  • With mamba-ssm: tag="mamba" results appear; without: falls back to tag="torch-ref"

… baseline, model-scale params

- Replace @fixture decorator with @pytest.mark.parametrize + typed signatures
  in bench_da_cumsum_fwd, bench_ssd_chunk_scan_fwd, bench_ssd_chunk_state_fwd,
  bench_ssd_decode, bench_ssd_state_passing_fwd
- Add mamba_ssm.ops.triton baseline comparison to bench_ssd_chunk_scan_fwd,
  bench_ssd_chunk_state_fwd, bench_ssd_state_passing_fwd (with ImportError
  fallback to torch-ref); da_cumsum and decode have no mamba equivalent
- Add _to_mamba_inputs() layout converter for chunk_scan (TileOPs chunked
  layout -> mamba flat-sequence layout)
- Add Mamba2 model-scale benchmark params to bench_ssd_chunk_scan_fwd and
  bench_ssd_decode covering 130M/370M/780M/1.3B/2.7B with latency/serving/
  throughput/long-context cases
- Fix BenchmarkReport.record() calls: replace string literals with op object
  and normalise tag= to "torch-ref" / "mamba"

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@stelladuyx requested a review from a team April 7, 2026 05:40
Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request refactors several SSD (Mamba2) operator benchmarks by replacing custom fixtures with pytest.mark.parametrize and adding comprehensive model-scale parameters (ranging from 130M to 2.7B configurations). It also introduces baseline comparisons against mamba_ssm for the chunk scan, chunk state, and state passing operators. The review feedback identifies several critical issues where tensor layouts (permutations) and input mappings in the new benchmark code do not match the expectations of the mamba_ssm library, which would lead to incorrect baseline results.

I am having trouble creating individual review comments, so my feedback is inlined below.

benchmarks/ops/bench_ssd_chunk_scan_fwd.py (19-31)

high

The _to_mamba_inputs layout converter has several critical issues that will lead to incorrect baseline results or runtime errors:

  1. Incorrect dt_m and dA_m mapping: The dt input is ignored, and dA_cumsum is used for both dt_m and dA_m. In mamba_ssm, dt is the element-wise discretized decay, while dA_cumsum is the prefix sum. They are distinct.
  2. Missing Permutations: mamba_ssm expects dt and dA_cumsum in (B, C, H, L) layout. In this benchmark, dt is (b, c, L, h) and dA_cumsum is (b, h, c, L). Both need permutations.
  3. Incorrect State Layout: prev_states is (b, c, h, n, p). The permutation (0, 1, 2, 4, 3) changes it to (b, c, h, p, n), but mamba_ssm expects (B, C, H, N, P).
def _to_mamba_inputs(x, cb, dA_cumsum, C, prev_states, dt):
    b, c, L, h, p = x.shape
    n = C.shape[-1]

    x_m = x.reshape(b, c * L, h, p).contiguous()
    cb_m = cb.contiguous()
    # TileOPs dt: (b, c, L, h) -> mamba: (b, c, h, L)
    dt_m = dt.permute(0, 1, 3, 2).contiguous()
    # TileOPs dA_cumsum: (b, h, c, L) -> mamba: (b, c, h, L)
    dA_m = dA_cumsum.permute(0, 2, 1, 3).contiguous()
    C_m = C.reshape(b, c * L, h, n).contiguous()
    # TileOPs prev_states: (b, c, h, n, p) -> mamba: (b, c, h, n, p)
    states_m = prev_states.contiguous()

    return x_m, cb_m, dt_m, dA_m, C_m, states_m
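For readers checking the converter by hand, the permutations above can be verified with a shape-only sketch (a minimal illustration, not code from the PR: numpy's `transpose` stands in for torch's `permute`, and the sizes are arbitrary):

```python
import numpy as np

# Illustrative sizes only, not the benchmark's actual configs
b, c, L, h, p, n = 2, 3, 16, 4, 8, 5

x = np.zeros((b, c, L, h, p))            # TileOPs x: (b, c, L, h, p)
dt = np.zeros((b, c, L, h))              # TileOPs dt: (b, c, L, h)
dA_cumsum = np.zeros((b, h, c, L))       # TileOPs dA_cumsum: (b, h, c, L)
C = np.zeros((b, c, L, h, n))            # TileOPs C: (b, c, L, h, n)

x_m = x.reshape(b, c * L, h, p)          # flatten chunks: (B, S, H, P)
dt_m = dt.transpose(0, 1, 3, 2)          # (b, c, L, h) -> (b, c, h, L)
dA_m = dA_cumsum.transpose(0, 2, 1, 3)   # (b, h, c, L) -> (b, c, h, L)
C_m = C.reshape(b, c * L, h, n)          # flatten chunks: (B, S, H, N)

assert x_m.shape == (b, c * L, h, p)
assert dt_m.shape == (b, c, h, L)
assert dA_m.shape == (b, c, h, L)
assert C_m.shape == (b, c * L, h, n)
```

The key point the sketch makes visible: `dt` and `dA_cumsum` start in different TileOPs layouts, so they need different permutations to land in the same (b, c, h, L) computation layout.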

benchmarks/ops/bench_ssd_chunk_state_fwd.py (86-93)

high

The mamba_ssm baseline comparison will produce incorrect results because the dt and dA_cumsum tensors are passed in the wrong layout. tileops uses (b, h, c, L) for these tensors, while mamba_ssm expects (B, C, H, L). A permutation is required.

        def mamba_fwd():
            return _chunk_state_fwd(
                Bmat.contiguous(),
                x.contiguous(),
                # TileOPs (b, h, c, L) -> mamba (b, c, h, L)
                dt.permute(0, 2, 1, 3).contiguous(),
                dA_cumsum.permute(0, 2, 1, 3).contiguous(),
                seq_idx=seq_idx,
            )

benchmarks/ops/bench_ssd_state_passing_fwd.py (68-73)

high

The dA_chunk_cumsum tensor needs to be permuted before being passed to the mamba_ssm baseline. tileops provides it as (b, h, c), but mamba_ssm expects (B, C, H).

        def mamba_fwd():
            return _state_passing_fwd(
                states.contiguous(),
                # TileOPs (b, h, c) -> mamba (b, c, h)
                dA_chunk_cumsum.permute(0, 2, 1).contiguous(),
                initial_states=initial_states.contiguous(),
            )

…ine paths

bench_ssd_chunk_scan_fwd: _to_mamba_inputs had two layout bugs:
- dt was not permuted: TileOPs (b,c,L,h) -> mamba (b,h,c,L) requires permute(0,3,1,2)
- dA_cumsum was incorrectly used for both dt_m and dA_m; dA_cumsum (b,h,c,L)
  already matches mamba layout and needs no permutation
- prev_states permute(0,1,2,4,3) correctly maps (b,c,h,n,p) -> (b,c,h,p,n) (restored)

bench_ssd_chunk_state_fwd / bench_ssd_state_passing_fwd: dt, dA_cumsum, and
dA_chunk_cumsum are generated in mamba-compatible layouts by gen_inputs() and
require no permutation — no kernel-call changes needed.

All three files: remove stale imports from tests/ (ssd_chunk_scan_fwd_ref,
ssd_chunk_state_fwd_ref, ssd_state_passing_fwd_ref are now defined locally in
the bench files); import SsdChunkScanFwdTest/SsdChunkStateFwdTest/
SsdStatePassingFwdTest from workloads/ only. Update torch-ref baseline in
chunk_scan to call the locally-defined ssd_chunk_scan_fwd_torch.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Contributor

@Ibuki-wind left a comment


Summary

Model-scale params, mamba_ssm baseline integration, and the parametrize migration are all good ideas. But there are correctness and hygiene issues that need fixing before merge.

Blocking

  1. Dead tests.* imports in two benchmark files: bench_da_cumsum_fwd.py and bench_ssd_decode.py both add from tests.ops.* imports where every imported name is immediately shadowed (by local function definitions or the workloads.* import on the next line). These imports are 100% dead code, but they re-introduce the tests.* coupling that #787 eliminated. Remove them entirely.

  2. Missing tensor layout permutations in mamba_ssm baselines: bench_ssd_chunk_state_fwd.py and bench_ssd_state_passing_fwd.py pass dt/dA_cumsum/dA_chunk_cumsum to mamba_ssm with only .contiguous() and no layout permutation. But the local reference functions in the same files explicitly permute(0, 2, 1, 3) these tensors from the TileOPs (b,h,c,L) layout to the computation (b,c,h,L) layout. If mamba_ssm expects the same computation layout, the baselines produce silently wrong benchmark numbers. The chunk_scan file has a proper _to_mamba_inputs converter; the other two files need equivalent treatment.

Non-blocking

  1. Unused fixture imports: DaCumsumFwdFixture and SsdDecodeFixture are still imported but no longer used after the @Fixture -> @pytest.mark.parametrize migration.

Action items

  1. Delete the from tests.ops.* import blocks in bench_da_cumsum_fwd.py and bench_ssd_decode.py.
  2. Clean up DaCumsumFwdFixture and SsdDecodeFixture from their respective workloads imports.
  3. Add layout permutations (or a shared _to_mamba_inputs-style converter) to the mamba baselines in bench_ssd_chunk_state_fwd.py and bench_ssd_state_passing_fwd.py. Document the expected mamba_ssm layout in a comment, like chunk_scan does.
  4. Verify the mamba baselines actually produce correct results by running them against the torch-ref baseline on at least one param set.

Comment thread benchmarks/ops/bench_da_cumsum_fwd.py Outdated
Comment thread benchmarks/ops/bench_ssd_decode.py Outdated
Comment thread benchmarks/ops/bench_ssd_chunk_state_fwd.py Outdated
Comment thread benchmarks/ops/bench_ssd_state_passing_fwd.py Outdated
@stelladuyx requested a review from Ibuki-wind April 7, 2026 11:00
@stelladuyx self-assigned this Apr 8, 2026
stelladuyx and others added 9 commits April 10, 2026 11:13
…rface

- Kernel: seqlen-fused x/C/out [B,S,H,P]/[B,S,G,N], group-owned cb [B,C,G,L,L],
  dt/dA_cumsum [B,H,C,L], prev_states [B,C,H,P,N] (P before N)
- Op: add n_groups param, update docstrings
- Workload: gen_inputs uses official layouts, fixture includes n_groups
- Test: new reference function handling group-owned tensors and P-before-N states
- Benchmark: remove _to_mamba_inputs (inputs already in official layout),
  add n_groups throughout, fix mamba baseline call signature
…a_inputs, restore in-place state comment, add model-scale params to state_passing bench
…/test files

Upstream refactored individual bench_ssd_*.py and test_ssd_chunk_scan_fwd.py
into unified bench_mamba.py and test_mamba.py. Merge our PR tile-ai#815 changes into
those consolidated files:

- bench_mamba.py: add mamba_ssm Triton baselines (chunk_scan/chunk_state/
  state_passing), model-scale parametrize params for all ops, n_groups-aware
  memory calculation for chunk_scan, fix decode n_groups=1 (official default)
- test_mamba.py: replace ssd_chunk_scan_fwd_torch with official-interface
  ssd_chunk_scan_fwd_ref (n_groups-aware), pass n_groups through test fixture
- Remove individual bench/test files superseded by consolidated versions

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…e-ai#939)

Upstream tile-ai#939 consolidated all mamba workloads from workloads/ops/ssd_*.py
into a single workloads/mamba.py and moved other workloads out of
workloads/ops/ into a flat workloads/ layout.

- Apply n_groups-aware SsdChunkScanFwdFixture/Test to workloads/mamba.py:
  fixture params add n_groups, gen_inputs uses official [B,S,H,P] layout
  with group-separated cb [B,C,G,L,L] and C [B,S,G,N] tensors
- Remove workloads/ops/ssd_chunk_scan_fwd.py (accepted upstream deletion,
  content merged into workloads/mamba.py)
- bench_mamba.py and test_mamba.py imports auto-resolved to workloads.mamba

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…sing (tile-ai#940)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>