
Reorganize sampling kernels into clear code paths with shared utils#170

Open
oliverdutton wants to merge 2 commits into main from
claude/refactor-sampling-kernels-IWkhw

@oliverdutton
Owner

Restructure tallax/vllm/ from the fix-topp-integer branch into three
clear packages plus a pure JAX reference module:

vllm/utils/ - Shared utilities used by both code paths:

  • binary_search.py: Monotonic f32<->u32 conversion, binary search
  • high_precision_uint.py: U48 48-bit arithmetic, modulo_u128_u64, RNG
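
As a rough illustration of the monotonic f32<->u32 conversion listed above, the sketch below uses the standard sign-flip trick in plain Python. The function names are hypothetical and not taken from binary_search.py, which presumably operates on JAX arrays rather than Python scalars.

```python
import struct


def f32_to_monotonic_u32(x: float) -> int:
    """Map a float32 value to a u32 key whose unsigned order matches float order."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if bits & 0x80000000:
        # Negative floats: flip every bit so larger magnitudes sort lower.
        return bits ^ 0xFFFFFFFF
    # Non-negative floats: set the top bit so they sort above all negatives.
    return bits ^ 0x80000000


def monotonic_u32_to_f32(key: int) -> float:
    """Inverse of f32_to_monotonic_u32."""
    if key & 0x80000000:
        bits = key ^ 0x80000000
    else:
        bits = key ^ 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]
```

With such a mapping, a binary search over the 32-bit key space needs at most 32 steps to locate any float threshold.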

vllm/fullvocab/ - Full-vocabulary binary search path (never sorts):

  • topk_mask.py: Binary search in f32 space + stable boundary detection
  • topp_mask.py: Binary search in i32 probability space with U48 sums
  • kernel.py: Combined topk+topp+sample Pallas kernel with debug_results
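
The sort-free top-k idea can be sketched in plain Python as follows. This is not the Pallas implementation: `topk_mask` here is a hypothetical scalar version, it assumes 1 <= k <= len(logits), and it assumes that "stable boundary detection" means ties at the threshold are resolved in favour of lower indices.

```python
import struct


def _key(x: float) -> int:
    """Order-preserving u32 key for a float32 value (sign-flip trick)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits ^ 0xFFFFFFFF if bits & 0x80000000 else bits ^ 0x80000000


def topk_mask(logits, k):
    """Boolean mask selecting exactly k largest logits without sorting."""
    keys = [_key(x) for x in logits]
    # Binary search for the largest threshold with at least k keys >= it,
    # which is the key of the k-th largest element.
    lo, hi = 0, 0xFFFFFFFF
    while lo < hi:
        mid = (lo + hi + 1) // 2
        if sum(key >= mid for key in keys) >= k:
            lo = mid
        else:
            hi = mid - 1
    threshold = lo
    # Everything strictly above the threshold is kept; the remaining slots
    # go to threshold-valued entries in index order (assumed tie rule).
    remaining = k - sum(key > threshold for key in keys)
    mask = []
    for key in keys:
        if key > threshold:
            mask.append(True)
        elif key == threshold and remaining > 0:
            mask.append(True)
            remaining -= 1
        else:
            mask.append(False)
    return mask
```

For example, `topk_mask([1.0, 3.0, 2.0, 3.0, 2.0], 3)` keeps both 3.0 entries and only the first 2.0.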

vllm/reducedk/ - Sort-based reduced-k path (bitonic sort to k, then operate):

  • top_p_and_sample.py: top_p_integer_mask on sorted subset, re-sort to
    original order, sample via modulo_u128_u64
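
A minimal sketch of the sorted-subset top-p step, assuming probabilities have already been quantized to non-negative integers and sorted in descending order. The function name, argument layout, and the rational p = p_num / p_den are illustrative assumptions, not the signature of top_p_integer_mask.

```python
def top_p_mask_sorted(int_probs, indices, p_num, p_den):
    """Top-p mask over a descending-sorted subset, returned in original order.

    int_probs: integer probabilities, sorted descending.
    indices:   original vocabulary position of each entry.
    Returns (index, keep) pairs sorted by original index.
    """
    total = sum(int_probs)
    running = 0
    pairs = []
    for prob, idx in zip(int_probs, indices):
        # Keep a token while the mass strictly before it is below the target.
        # Cross-multiplying keeps the comparison in exact integer arithmetic.
        keep = running * p_den < p_num * total
        pairs.append((idx, keep))
        running += prob
    # "Re-sort to original order": order the result by vocabulary index.
    return sorted(pairs)
```

Because every comparison is an integer comparison, the mask does not depend on floating-point summation order.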

vllm/reference.py - Pure JAX reference (no Pallas):

  • Uses jax.enable_x64 for exact i64 arithmetic
  • Python arbitrary-precision u128 % u64 via pure_callback
  • Ground truth for both kernel paths
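
The role of the u128 % u64 step in sampling can be sketched with Python's arbitrary-precision integers. This is an assumed reading of how the reference uses the modulo (reduce a 128-bit random draw into the range of the total integer mass, then walk the cumulative sum); `sample_index` is a hypothetical name, and the pure_callback wiring is not shown.

```python
def sample_index(int_weights, random_u128):
    """Pick an index with probability proportional to its integer weight."""
    total = sum(int_weights)
    if not (0 < total < 2**64 and 0 <= random_u128 < 2**128):
        raise ValueError("total must fit in u64 and the draw in u128")
    # u128 % u64: exact in Python, since ints are arbitrary precision.
    r = random_u128 % total
    running = 0
    for i, w in enumerate(int_weights):
        running += w
        if r < running:
            return i
    raise AssertionError("unreachable: r < total by construction")
```

Reducing a 128-bit draw modulo a value below 2^64 leaves a relative bias below 2^-64, and entries with zero weight can never be selected.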

Debug support: fullvocab kernel accepts debug=True to return a nested dict
of int32[1] SMEM values for verifying intermediate stages match reference.
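
One way to compare such a nested debug dict against the reference is a small recursive diff, sketched below. The stage names used in the usage note are made up for illustration, and the sketch assumes leaf values have already been converted to Python ints (for example with `int(x[0])` on each int32[1] value).

```python
def diff_nested(kernel_out, reference_out, path=()):
    """Return (path, kernel_value, reference_value) for every mismatching leaf."""
    mismatches = []
    for key in sorted(set(kernel_out) | set(reference_out)):
        here = path + (key,)
        a, b = kernel_out.get(key), reference_out.get(key)
        if isinstance(a, dict) and isinstance(b, dict):
            # Recurse into matching sub-dicts, extending the path.
            mismatches += diff_nested(a, b, here)
        elif a != b:
            mismatches.append(("/".join(here), a, b))
    return mismatches
```

For instance, `diff_nested({"topk": {"threshold": 5}}, {"topk": {"threshold": 6}})` reports a single mismatch at `topk/threshold`.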

Also:

  • Cherry-pick stable dynamic_topk from stable_topk branch (stable sort in
    divide-and-filter, >= swap for stable index ordering, reverse loop order)
  • Fix syntax errors in cherry-picked code (missing parens)
  • Add import operator for stable comparison
  • Add sparse_random_bits() extracted from sparse_random_uniform()
  • Add map_reduce() to tax/utils.py for parallel chunked reduction
  • Add comprehensive component tests (binary_search, U48, modulo, topk_mask,
    topp_mask, reference greedy, fullvocab vs reference end-to-end)
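
For readers unfamiliar with chunked reduction, the general pattern looks roughly like this. The sketch is sequential plain Python with a hypothetical name and signature; it does not reproduce the actual map_reduce() added in this PR, which presumably maps over chunks in parallel.

```python
from functools import reduce
import operator


def chunked_map_reduce(map_fn, reduce_fn, xs, chunk_size):
    """Apply map_fn to each fixed-size chunk of xs, then fold the partial results."""
    if chunk_size <= 0 or not xs:
        raise ValueError("need a positive chunk_size and a non-empty input")
    chunks = [xs[i:i + chunk_size] for i in range(0, len(xs), chunk_size)]
    # Each chunk is processed independently, so this step could run in parallel.
    partials = [map_fn(chunk) for chunk in chunks]
    # reduce_fn should be associative for the result to be chunking-independent.
    return reduce(reduce_fn, partials)


total = chunked_map_reduce(sum, operator.add, list(range(10)), chunk_size=4)
```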

https://claude.ai/code/session_01Rd5Nd2h4ZSmMW5wAAkiuwt

…arse_random

- Rename fullvocab/ -> arbitrary_k/ (binary-search-based, works with any k)
- Rename reducedk/ -> bounded_k/ (sort-based, requires bounded k <= 128)
- Remove sparse_random.py and its tests (all sampling now uses i32 path)
- Remove dead top_p_and_sample.py (old code superseded by bounded_k)
- Remove random_u48 from high_precision_uint.py (unused, depended on sparse_random)
- Restructure tests to mirror source directory structure (tests/vllm/...)
- Simplify tests: one parameterized test per component, e2e test with debug intermediates

https://claude.ai/code/session_01Rd5Nd2h4ZSmMW5wAAkiuwt