Skip to content

Refactor debug output to return full intermediate arrays instead of checksums#171

Open
oliverdutton wants to merge 34 commits intoclaude/refactor-sampling-kernels-IWkhwfrom
claude/refactor-debug-output-fs6HG
Open

Refactor debug output to return full intermediate arrays instead of checksums#171
oliverdutton wants to merge 34 commits intoclaude/refactor-sampling-kernels-IWkhwfrom
claude/refactor-debug-output-fs6HG

Conversation

@oliverdutton
Copy link
Copy Markdown
Owner

Summary

This PR refactors the debug output mechanism in the top-k/top-p sampling kernels to return full intermediate arrays instead of checksums or partial values. This enables more comprehensive debugging and validation of the sampling pipeline.

Key Changes

Arbitrary-k Kernel (tallax/vllm/arbitrary_k/kernel.py)

  • Renamed debug_results_ref parameter to debug_arrays_ref to clarify that full arrays are returned
  • Replaced the _write_debug_results() helper function that computed checksums with direct array assignments
  • Changed debug output from SMEM-based checksums to full arrays:
    • greedy_sampled: Full [batch, 1] array instead of [1, 1] checksum
    • topk_logits_unsorted: Full [batch, vocab_size] array (new)
    • topk_topp_unnorm_probs_i32_unsorted: Full [batch, vocab_size] array instead of nonzero count
    • random_unnorm_cdf_sampled: Full [batch, 1] array (new, extracted as int64)
    • next_tokens: Full [batch, 1] array instead of [1, 1] checksum
  • Updated BlockSpec definitions to use proper memory layout for full arrays instead of SMEM
  • Added trimming of padded outputs to match original batch size in debug results

Bounded-k Kernel (tallax/vllm/bounded_k/top_p_and_sample.py)

  • Added debug parameter to top_p_and_sample_arrays() and related functions
  • Refactored to store intermediate values before transposition for proper debug output
  • Added greedy_sampled computation before temperature scaling
  • Returns (tokens, debug_results) tuple when debug=True
  • Debug output includes full arrays matching the arbitrary-k format:
    • greedy_sampled: [batch_size]
    • topk_logits_unsorted: [batch_size, k]
    • topk_topp_unnorm_probs_i32_unsorted: [batch_size, k]
    • random_unnorm_cdf_sampled: [batch_size]
    • next_tokens: [batch_size]

Reference Implementation (tallax/vllm/reference.py)

  • Updated debug output to use OrderedDict for consistent ordering
  • Renamed debug keys to match kernel implementation:
    • topp_unnorm_probs_i32topk_topp_unnorm_probs_i32_unsorted
    • topp_nonzero_count → removed (now returning full array)
    • total_sum → removed (now returning random_unnorm_cdf_sampled)

Tests (tests/vllm/arbitrary_k/kernel_test.py)

  • Updated assertions to work with full arrays instead of checksums
  • Added new assertions for:
    • topk_logits_unsorted: Validates full top-k logits match
    • topk_topp_unnorm_probs_i32_unsorted: Validates full unnormalized probabilities
    • random_unnorm_cdf_sampled: Validates the sampled CDF value
  • Removed assertions for checksums that are no longer computed

Implementation Details

  • Uses OrderedDict to maintain consistent ordering of debug outputs across implementations
  • Full arrays are now stored in HBM instead of SMEM, enabling larger intermediate values
  • Debug output shapes match the full computation shapes for easier validation
  • Padding is trimmed from debug outputs to match the original batch size

https://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R

claude and others added 30 commits February 15, 2026 16:03
Changes:
- Updated reference.py to return OrderedDict with exactly 5 fields:
  greedy_sampled, topk_logits_unsorted, topk_topp_unnorm_probs_i32_unsorted,
  random_unnorm_cdf_sampled, next_tokens
- Modified arbitrary_k/kernel.py to return same OrderedDict structure
  with full intermediate arrays when debug=True
- Added debug support to bounded_k/top_p_and_sample.py with OrderedDict
  containing k-slice versions of debug outputs
- Updated tests to use new debug field names and verify all intermediates

The bounded_k implementation returns debug values in the reduced k-slice
format (batch, k) rather than full vocabulary size, matching the input
dimensions.

https://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R
Changed random_unnorm_cdf_sampled from single i64 value to tuple of
two u32 values (high, low) in all debug outputs:
- reference.py: Split i64 into high and low u32 components
- arbitrary_k/kernel.py: Store U48 high/low as separate arrays
- bounded_k/top_p_and_sample.py: Split target_cumsum into high/low
- Updated test to verify both high and low components match

https://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R
Renamed functions:
- reference_topk_topp_mask_and_sample → reference_topk_topp_and_sample
- topk_topp_mask_and_sample → arbitrary_topk_topp_and_sample
- top_p_and_sample → bounded_topk_topp_and_sample

Changes:
- Added max_k parameter to bounded_topk_topp_and_sample
- Moved random_u128_in_u32s generation outside sharding rules in
  bounded implementation for better control over RNG state
- Updated all imports and exports in __init__.py files
- Updated test imports to use new function names
- Normalized sharding rules to pass random_u128_in_u32s explicitly

https://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R
@oliverdutton oliverdutton force-pushed the claude/refactor-debug-output-fs6HG branch from 392b73f to 91d0acf Compare February 27, 2026 20:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants