Reorganize sampling kernels into clear code paths with shared utils#170
Open
oliverdutton wants to merge 2 commits intomainfrom
Open
Reorganize sampling kernels into clear code paths with shared utils#170oliverdutton wants to merge 2 commits intomainfrom
oliverdutton wants to merge 2 commits intomainfrom
Conversation
Restructure tallax/vllm/ from the fix-topp-integer branch into three
clear packages:
**vllm/utils/** - Shared utilities used by both code paths:
- binary_search.py: Monotonic f32<->u32 conversion, binary search
- high_precision_uint.py: U48 48-bit arithmetic, modulo_u128_u64, RNG
**vllm/fullvocab/** - Full-vocabulary binary search path (never sorts):
- topk_mask.py: Binary search in f32 space + stable boundary detection
- topp_mask.py: Binary search in i32 probability space with U48 sums
- kernel.py: Combined topk+topp+sample Pallas kernel with debug_results
**vllm/reducedk/** - Sort-based reduced-k path (bitonic sort to k, then operate):
- top_p_and_sample.py: top_p_integer_mask on sorted subset, re-sort to
original order, sample via modulo_u128_u64
**vllm/reference.py** - Pure JAX reference (no Pallas):
- Uses jax.enable_x64 for exact i64 arithmetic
- Python arbitrary-precision u128 % u64 via pure_callback
- Ground truth for both kernel paths
Debug support: fullvocab kernel accepts debug=True to return a nested dict
of int32[1] SMEM values for verifying intermediate stages match reference.
Also:
- Cherry-pick stable dynamic_topk from stable_topk branch (stable sort in
divide-and-filter, >= swap for stable index ordering, reverse loop order)
- Fix syntax errors in cherry-picked code (missing parens)
- Add import operator for stable comparison
- Add sparse_random_bits() extracted from sparse_random_uniform()
- Add map_reduce() to tax/utils.py for parallel chunked reduction
- Add comprehensive component tests (binary_search, U48, modulo, topk_mask,
topp_mask, reference greedy, fullvocab vs reference end-to-end)
https://claude.ai/code/session_01Rd5Nd2h4ZSmMW5wAAkiuwt
…arse_random - Rename fullvocab/ -> arbitrary_k/ (binary-search-based, works with any k) - Rename reducedk/ -> bounded_k/ (sort-based, requires bounded k <= 128) - Remove sparse_random.py and its tests (all sampling now uses i32 path) - Remove dead top_p_and_sample.py (old code superseded by bounded_k) - Remove random_u48 from high_precision_uint.py (unused, depended on sparse_random) - Restructure tests to mirror source directory structure (tests/vllm/...) - Simplify tests: one parameterized test per component, e2e test with debug intermediates https://claude.ai/code/session_01Rd5Nd2h4ZSmMW5wAAkiuwt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Restructure tallax/vllm/ from the fix-topp-integer branch into three
clear packages:
vllm/utils/ - Shared utilities used by both code paths:
vllm/fullvocab/ - Full-vocabulary binary search path (never sorts):
vllm/reducedk/ - Sort-based reduced-k path (bitonic sort to k, then operate):
original order, sample via modulo_u128_u64
vllm/reference.py - Pure JAX reference (no Pallas):
Debug support: fullvocab kernel accepts debug=True to return a nested dict
of int32[1] SMEM values for verifying intermediate stages match reference.
Also:
divide-and-filter, >= swap for stable index ordering, reverse loop order)
topp_mask, reference greedy, fullvocab vs reference end-to-end)
https://claude.ai/code/session_01Rd5Nd2h4ZSmMW5wAAkiuwt