Refactor debug output to return full intermediate arrays instead of checksums#171
Open
oliverdutton wants to merge 34 commits intoclaude/refactor-sampling-kernels-IWkhwfrom
Open
Conversation
Changes: - Updated reference.py to return OrderedDict with exactly 5 fields: greedy_sampled, topk_logits_unsorted, topk_topp_unnorm_probs_i32_unsorted, random_unnorm_cdf_sampled, next_tokens - Modified arbitrary_k/kernel.py to return same OrderedDict structure with full intermediate arrays when debug=True - Added debug support to bounded_k/top_p_and_sample.py with OrderedDict containing k-slice versions of debug outputs - Updated tests to use new debug field names and verify all intermediates The bounded_k implementation returns debug values in the reduced k-slice format (batch, k) rather than full vocabulary size, matching the input dimensions. https://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R
Changed random_unnorm_cdf_sampled from single i64 value to tuple of two u32 values (high, low) in all debug outputs: - reference.py: Split i64 into high and low u32 components - arbitrary_k/kernel.py: Store U48 high/low as separate arrays - bounded_k/top_p_and_sample.py: Split target_cumsum into high/low - Updated test to verify both high and low components match https://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R
Renamed functions: - reference_topk_topp_mask_and_sample → reference_topk_topp_and_sample - topk_topp_mask_and_sample → arbitrary_topk_topp_and_sample - top_p_and_sample → bounded_topk_topp_and_sample Changes: - Added max_k parameter to bounded_topk_topp_and_sample - Moved random_u128_in_u32s generation outside sharding rules in bounded implementation for better control over RNG state - Updated all imports and exports in __init__.py files - Updated test imports to use new function names - Normalized sharding rules to pass random_u128_in_u32s explicitly https://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R
…om/oliverdutton/tallax into claude/refactor-debug-output-fs6HG
392b73f to
91d0acf
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR refactors the debug output mechanism in the top-k/top-p sampling kernels to return full intermediate arrays instead of checksums or partial values. This enables more comprehensive debugging and validation of the sampling pipeline.
Key Changes
Arbitrary-k Kernel (
tallax/vllm/arbitrary_k/kernel.py)debug_results_refparameter todebug_arrays_refto clarify that full arrays are returned_write_debug_results()helper function that computed checksums with direct array assignmentsgreedy_sampled: Full [batch, 1] array instead of [1, 1] checksumtopk_logits_unsorted: Full [batch, vocab_size] array (new)topk_topp_unnorm_probs_i32_unsorted: Full [batch, vocab_size] array instead of nonzero countrandom_unnorm_cdf_sampled: Full [batch, 1] array (new, extracted as int64)next_tokens: Full [batch, 1] array instead of [1, 1] checksumBlockSpecdefinitions to use proper memory layout for full arrays instead of SMEMBounded-k Kernel (
tallax/vllm/bounded_k/top_p_and_sample.py)debugparameter totop_p_and_sample_arrays()and related functionsgreedy_sampledcomputation before temperature scaling(tokens, debug_results)tuple whendebug=Truegreedy_sampled: [batch_size]topk_logits_unsorted: [batch_size, k]topk_topp_unnorm_probs_i32_unsorted: [batch_size, k]random_unnorm_cdf_sampled: [batch_size]next_tokens: [batch_size]Reference Implementation (
tallax/vllm/reference.py)OrderedDictfor consistent orderingtopp_unnorm_probs_i32→topk_topp_unnorm_probs_i32_unsortedtopp_nonzero_count→ removed (now returning full array)total_sum→ removed (now returningrandom_unnorm_cdf_sampled)Tests (
tests/vllm/arbitrary_k/kernel_test.py)topk_logits_unsorted: Validates full top-k logits matchtopk_topp_unnorm_probs_i32_unsorted: Validates full unnormalized probabilitiesrandom_unnorm_cdf_sampled: Validates the sampled CDF valueImplementation Details
OrderedDictto maintain consistent ordering of debug outputs across implementationshttps://claude.ai/code/session_01NJso6RoxkCzL1d3hSTnU5R