Skip to content

[Feat] Intranode Dispatch&Combine Kernel #522

Draft
yanboshao wants to merge 4 commits into
mainfrom
yanbo/dispatch_combine
Draft

[Feat] Intranode Dispatch&Combine Kernel #522
yanboshao wants to merge 4 commits into
mainfrom
yanbo/dispatch_combine

Conversation

@yanboshao
Copy link
Copy Markdown
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@yanboshao yanboshao marked this pull request as draft May 14, 2026 07:19
@yanboshao yanboshao changed the title feat(dispatch_combine): intranode dispatch/combine kernel [Feat]: intranode dispatch/combine kernel May 14, 2026
@yanboshao yanboshao changed the title [Feat]: intranode dispatch/combine kernel [Feat] Intranode Dispatch&Combine Kernel May 14, 2026
@yanboshao yanboshao force-pushed the yanbo/dispatch_combine branch from 7f53c40 to 1a8596c Compare May 14, 2026 07:53
xudoyuan
xudoyuan previously approved these changes May 14, 2026
Comment thread python/flydsl/compiler/ast_rewriter.py Outdated
yanboshao added 3 commits May 20, 2026 06:25
Trim the FlyDSL Python helper surface introduced by the dispatch/combine
kernel down to what is strictly necessary, by leaning on existing main-
branch idioms and pushing small kernel-only wrappers into the kernel
file itself.

FlyDSL helper modules
- python/flydsl/expr/arith.py: revert to origin/main. Drop the unused
  divui/remui/select_by_index extensions, and remove zext_i64 in favor
  of a kernel-local _to_i64 helper that wraps arith.extui(_lv_unwrap(...)).
- python/flydsl/expr/vector.py: revert to origin/main. Drop the
  bitcast_i32_to_v2bf16/bitcast_v2bf16_to_i32 helpers; the kernel now
  uses the standard vector.from_elements + vector.bitcast + vector.extract
  idiom (mirrors kernels/hgemm_splitk.py:578-585).
- python/flydsl/expr/rocdl/__init__.py: replace the bespoke ballot_i64 /
  readlane wrappers with generic ballot(res, pred, **kw) and
  readlane(res, src, lane, **kw) functions, aligned with the existing
  readfirstlane(res, src, **kw) style: capture the ODS-generated symbols
  as _ods_ballot / _ods_readlane up top, and use _to_ir coercion in the
  wrappers. Lets call sites pick the lane-mask width (i32 on wave32,
  i64 on wave64) explicitly.

Kernel
- kernels/dispatch_combine_intranode_kernel.py:
  - Add three file-local helpers: _to_i64, _i32_to_vec_bitcast,
    _vec_to_i32_bitcast (with docstrings pointing at the main-branch
    idioms they mirror).
  - Replace 31 arith.zext_i64(x) call sites with _to_i64(x); collapse
    two arith.zext_i64(arith.constant(rank)) sites into
    arith.constant(rank, type=T.i64()).
  - Update the 4 llvm_bitcast call sites to use the new
    _i32_to_vec_bitcast / _vec_to_i32_bitcast helpers.
  - Update ballot_i64(...) / readlane(...) call sites to the new generic
    APIs: ballot(T.i64(), pred), readlane(T.i32(), src, lane).

Net effect vs origin/main: arith.py and vector.py are now untouched;
rocdl/__init__.py keeps a +22 line delta (generic ballot/readlane
wrappers). All complexity that used to live in FlyDSL core has moved
into the kernel file where it belongs.

Verified
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify              -> ALL PASS (diff=0 on dispatch + combine)
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within
  StdMoE weighted tolerance)
Make the PR pass the `Check Python Code Style` CI step (.github/workflows/
pre-checks.yaml), which runs ``black --check --diff`` and ``ruff check``
on the set of Python files changed by the PR.

Auto-fixes (ruff --fix): I001 (5 unsorted-imports), F401 (3 unused-imports),
F811 (1 redefined-while-unused), W293 (1 blank-line-with-whitespace).

Manual fixes:
- F841 (7 unused-variable): drop dead assignments to ``tok_stride`` /
  ``inp_n_i32`` in dispatch_combine_intranode_kernel.py, and four
  ``hdim`` + one ``esz`` in dispatch_combine_intranode_op.py.
- E702 (23 multiple-statements-on-one-line): split ``a; b; c`` boilerplate
  in tests/kernels/test_profiler_dispatch_combine.py (mostly
  ``dist.all_reduce`` aggregation patterns).
- E402 (2 module-import-not-at-top): add ``# noqa: E402`` to the two
  imports that intentionally follow ``sys.path.insert(0, _p)`` in the
  test script.

Formatting: run ``black`` (line-length=120, per pyproject.toml) on the
four PR-modified Python files. ast_rewriter.py was already compliant.

CI parity locally: ``black --check`` + ``ruff check`` both clean on all
PR files.

Verified end-to-end (8x GPU, gfx942, bf16) after the style sweep:
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify              -> ALL PASS (diff=0 on dispatch + combine).
- torchrun -np 8 tests/kernels/test_profiler_dispatch_combine.py
  --mode verify --enable-std-moe -> ALL PASS (max_diff=0.015625 within
  StdMoE weighted tolerance).
The dispatch/combine intranode test depends on mori shmem, which is
only installed on the 8-GPU multi-gpu CI runners.  Previously pytest
collection on single-GPU / Navi-2-GPU runners would crash because
``import mori`` raises ModuleNotFoundError at module load time.

* tests/kernels/test_profiler_dispatch_combine.py: when imported under
  pytest collection (detected via ``"pytest" in sys.modules``), call
  ``pytest.importorskip("mori")`` so single/dual-GPU jobs cleanly skip
  this file.  Direct ``torchrun``/``python`` invocations are unaffected
  and still surface a normal ImportError when mori is genuinely missing.

* .github/workflows/flydsl.yaml: add two explicit multi-GPU steps to the
  multi-gpu job that run the dispatch/combine verify torchrun script
  (default config + --enable-std-moe).  These only execute when the PR
  carries the ``multi-gpu`` label, providing real 8-GPU coverage for the
  new kernel.

* kernels/dispatch_combine_intranode_op.py: drop unused local
  ``_disp_wpb`` alias, use ``config.warp_num_per_block`` directly.
@yanboshao yanboshao force-pushed the yanbo/dispatch_combine branch from 3f863f7 to 81e859d Compare May 20, 2026 06:27
Align arith module ordering/ruff pragmas with mainline formatting so the Python style pre-check passes reliably in PR CI.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants