mla: add fp8 qh32 seqlen=1 persistent kernel support on gfx950#3304
Open
alexioslyrakis-amd wants to merge 1 commit into
Open
mla: add fp8 qh32 seqlen=1 persistent kernel support on gfx950#3304alexioslyrakis-amd wants to merge 1 commit into
alexioslyrakis-amd wants to merge 1 commit into
Conversation
Add the mla_a8w8_qh32_qseqlen1_gqaratio32_ps kernel for gfx950 (MI350X). This covers the decode case with gqa_ratio=32, fp8 Q/KV, and seqlen_q=1. - asm_mla.cu: add seqlen=1 dispatch branch for gqa_ratio=32 fp8/fp8 (sub_Q=32); update error message to reflect supported seqlens 1/2/4 - v1_2_device.cuh: add seqlen=1 to natively_supported conditions for gfx950 nhead=32 fp8 - mla.py: add gfx950/nhead=32/fp8/seqlen=1 to the native-path selector - mla_asm.csv: register new .co entry for qh32 seqlen=1 persistent kernel - mla_a8w8_qh32_qseqlen1_gqaratio32_ps.co: compiled kernel binary - test_mla.py, test_mla_persistent.py: enable nhead=32 fp8 seqlen=1 test paths
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds the
mla_a8w8_qh32_qseqlen1_gqaratio32_pskernel for gfx950 (MI350X), covering the decode case withgqa_ratio=32, fp8 Q/KV, andseqlen_q=1.asm_mla.cu: addseqlen=1dispatch branch forgqa_ratio=32fp8/fp8 (sub_Q=32); update error message to list all supported seqlens (1/2/4)v1_2_device.cuh: addseqlen=1tonatively_supportedconditions for gfx950 nhead=32 fp8mla.py: add gfx950/nhead=32/fp8/seqlen=1 to the native-path selector inmla_decode_fwdmla_asm.csv: register new.coentry for the qh32 seqlen=1 persistent kernelmla_a8w8_qh32_qseqlen1_gqaratio32_ps.co: compiled kernel binarytest_mla.py,test_mla_persistent.py: enable nhead=32 fp8 seqlen=1 test pathsTest plan
python op_tests/test_mla_persistent.py --nhead 32,1 --dtype fp8 --kv_dtype fp8 --batchSize 512 --ctxLen 4096on MI350X