Fix FP16 overflow in GQA attention and concat_past_present buffer overflow#4677
aditya-dl wants to merge 2 commits into ROCm:develop
Conversation
Unit tests need to be added and the CI failure fixed.
Pull request overview
This PR fixes two root causes of garbage output in Qwen1.5-architecture models during FP16 inference: (1) FP16 overflow in the dot→softmax attention chain, and (2) a buffer overflow in concat_past_present during prompt processing when the sequence length exceeds the past cache size.
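Root cause (1) can be reproduced with a self-contained sketch in plain Python, using `struct`'s binary16 format to simulate FP16 arithmetic (the vector length and values here are illustrative, not taken from the model):

```python
import math
import struct

def to_fp16(x: float) -> float:
    # Round x through IEEE 754 binary16, as FP16 hardware would.
    # struct raises OverflowError for finite values beyond the FP16
    # range (max ~65504); real hardware overflows to infinity.
    try:
        return struct.unpack('<e', struct.pack('<e', x))[0]
    except OverflowError:
        return math.inf if x > 0 else -math.inf

def fp16_dot(q, k):
    # Dot product with every intermediate value rounded to FP16.
    s = 0.0
    for a, b in zip(q, k):
        s = to_fp16(s + to_fp16(a * b))
    return s

q = [30.0] * 80                  # one illustrative query/key row
print(fp16_dot(q, q))            # inf: 80 * 900 = 72000 > 65504
print(sum(a * a for a in q))     # 72000.0: wider precision stays finite
```

An infinite attention score then poisons softmax (`exp(inf - inf)` is NaN), which is why upcasting the dot-to-softmax range to FP32 restores sane outputs.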
Changes:
- Extends `find_softmax_base_ops` in `rewrite_reduce.cpp` to walk backward from softmax through `mul`/`where`/`broadcast`/`convert` to find a feeding `dot` instruction, upcasting the entire range to FP32 (with bool inputs excluded).
- Fixes `concat_past_present` buffer sizing across the operator definition, GPU lowering, JIT compiler, and GPU kernel so that the output buffer is properly sized when `sequence_length > past_cache_sequence_length`.
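The sizing rule behind the second fix can be sketched as follows (a hypothetical helper for illustration, not the actual MIGraphX code):

```python
def present_buffer_seq_len(past_cache_seq_len: int, sequence_length: int) -> int:
    # Hypothetical helper mirroring the fix: the present KV buffer must
    # hold whichever is larger -- the cached past length or the incoming
    # sequence length. Sizing by past_cache_seq_len alone under-allocates
    # during prompt processing, when sequence_length exceeds the cache.
    return max(past_cache_seq_len, sequence_length)

print(present_buffer_seq_len(0, 1024))   # prompt processing: 1024
print(present_buffer_seq_len(1024, 1))   # single-token decode: 1024
```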
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `src/rewrite_reduce.cpp` | Adds backward walk from softmax to dot for FP32 upcast range extension; skips bool inputs in conversion |
| `src/include/migraphx/op/concat_past_present.hpp` | Updates `compute_shape` to return larger shape when needed; uses `std::max` for `present_buffer_sequence_length` |
| `src/targets/gpu/lowering.cpp` | Allocates properly-sized GPU buffer when output shape exceeds past cache shape |
| `src/targets/gpu/jit/concat_past_present.cpp` | Adjusts JIT compiler output shape to match larger buffer when needed |
| `src/targets/gpu/kernels/include/migraphx/kernels/concat_past_present.hpp` | GPU kernel uses `max(past_seq, seq_len)` for present buffer sequence length |
I think you should put the softmax change in a separate PR, as I don't think we will merge the concat_past_present change.
Force-pushed from 6f03700 to 9da5d9f
@pfultz2 Updated this PR with only the softmax change.
Motivation
Qwen1.5-architecture models produce garbage output when running FP16 inference through MIGraphX. Two root causes were identified: (1) FP16 overflow in the dot-to-softmax attention chain, and (2) a buffer overflow in `concat_past_present` during prompt processing when the sequence length exceeds the past cache size.
Technical Details
FP16 overflow fix: Extends `find_softmax_base_ops` to walk backwards through the attention chain (`mul`, `where`, `broadcast`, `convert`) to find the feeding `dot` instruction. The entire dot-to-softmax range is upcast to FP32, preventing overflow in attention score computation. Bool-type inputs (`where` conditions) are excluded from conversion.

Changelog Category
Add a `CHANGELOG.md` entry for any option other than `Not Applicable`.