Skip to content

perf(ep): dispatch intranode opt within small token#333

Open
kawhil-amd wants to merge 3 commits into
mainfrom
dev/dispatch_opt
Open

perf(ep): dispatch intranode opt within small token#333
kawhil-amd wants to merge 3 commits into
mainfrom
dev/dispatch_opt

Conversation

@kawhil-amd
Copy link
Copy Markdown
Contributor

No description provided.

kawhil-amd and others added 3 commits May 20, 2026 02:56
…rnel

- Explicitly use Unroll=2 for token data WarpCopy in dispatch kernel
  to reduce loop iterations from 14 to 7 for hiddenDim=7168
- Add __launch_bounds__(512, 1) to help compiler optimize register allocation

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
@kawhil-amd kawhil-amd requested review from TianDi101 and jhchouuu May 21, 2026 08:47
@isytwu
Copy link
Copy Markdown
Collaborator

isytwu commented May 21, 2026

Maybe you could record the data before and after optimization in this PR, as well as block/warp, for 128 and 4096 tokens, respectively?

@kawhil-amd
Copy link
Copy Markdown
Contributor Author

Maybe you could record the data before and after optimization in this PR, as well as block/warp, for 128 and 4096 tokens, respectively?

Sure, will do it later.

@kawhil-amd kawhil-amd self-assigned this May 21, 2026
FlatTokenIndex(config, myPe, srcTokId);

// Wait for previous round to be consumed (use atomicAdd for shared memory)
while (atomicAdd(&groupCounters[warpGroupIdInBlock], 0) != 0) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should move forward?


core::WarpCopy(args.intraNodeTokBufs.dispatchOut->template GetAs<T*>(destPe) + destTokOffset,
args.inpTokenBuf + srcTokOffset, hiddenDim);
if (myDimChunk > 0) {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove branch within ensure divisibility condition?

// Use atomicAdd for shared memory atomic load (workgroup scope)
uint64_t val;
do {
val = atomicAdd((unsigned long long*)&groupData[warpGroupIdInBlock], 0ULL);
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lane 0 only?

@kawhil-amd kawhil-amd requested a review from isytwu May 22, 2026 03:39
@kawhil-amd kawhil-amd changed the title perf(ep): dispatch intranode opt perf(ep): dispatch intranode opt within small token May 22, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants