Skip to content
9 changes: 5 additions & 4 deletions src/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -790,23 +790,24 @@ struct find_kv_cache_attention
match::skip(match::name(skip_set))(match::name("transpose")(match::arg(0)(keys)));
auto queries = match::name("slice");
auto gemm1 = match::name("dot")(match::arg(0)(queries), match::arg(1)(k_transpose));
auto scale = match::name("mul")(match::any_arg(0, 1)(gemm1));
auto gemm1_maybe_cvt = match::skip(match::name("convert"))(gemm1);
auto scale = match::name("mul")(match::any_arg(0, 1)(gemm1_maybe_cvt));
auto broadcasted_const = match::name("multibroadcast")(match::arg(0)(match::is_constant()));
auto attn_scores = match::any_of(scale, gemm1);
auto attn_scores = match::any_of(scale, gemm1_maybe_cvt);
auto causal_mask =
match::name("where")(match::arg(0)(broadcasted_const), match::arg(2)(attn_scores));
auto conv_grtr = match::name("convert")(match::arg(0)(match::name("greater")));
auto local_window_comp = match::skip(match::name(skip_set))(conv_grtr);
auto local_window_mask =
match::name("where")(match::arg(0)(match::any_of(local_window_comp, broadcasted_const)),
match::arg(2)(match::any_of(causal_mask, scale, gemm1)));
match::arg(2)(match::any_of(causal_mask, scale, gemm1_maybe_cvt)));
auto greater = match::name("greater")(match::arg(1)(match::any().bind("total_sl")));
auto conv_greater =
match::skip(match::name("unsqueeze"))(match::name("convert")(match::arg(0)(greater)));
auto bc_greater = match::name("multibroadcast")(match::arg(0)(conv_greater));
auto mask = match::name("where")(
match::arg(0)(bc_greater),
match::arg(2)(match::any_of(local_window_mask, causal_mask, scale, gemm1)));
match::arg(2)(match::any_of(local_window_mask, causal_mask, scale, gemm1_maybe_cvt)));
auto attn_probabilities = match::skip(match::name("convert"))(
match::softmax_input(match::skip(match::name("convert"))(mask)));
auto values =
Expand Down
86 changes: 86 additions & 0 deletions src/rewrite_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <migraphx/common.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/eliminate_convert.hpp>
#include <migraphx/unfold.hpp>
#include <migraphx/dead_code_elimination.hpp>

MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_FP32_SOFTMAX);
Expand Down Expand Up @@ -79,6 +80,90 @@ struct find_softmax
}
};

// Widen the accumulation range of the attention Q*K scores to FP32.
//
// For non-fp32 graphs (e.g. fp16), a convert(->f32) is inserted on the output
// of the upstream dot feeding a softmax, and every op between the dot and the
// softmax input (mul scaling, where masking, ...) is recomputed in f32. This
// prevents fp16 overflow in the attention scores for models with large
// k_proj.bias values (e.g. Qwen, DeepSeek). The dot itself stays
// dot(f16,f16)->f16; MFMA/WMMA accumulate in f32 internally, and when the
// pattern is later fused into an attention kernel, rocMLIR's
// RemoveRedundantCasts pass is expected to preserve the f32 accumulator.
//
// Must run before find_softmax_base_ops so the softmax internals
// (reduce_max through div) are still in f16 when that matcher sees them.
struct find_dot_softmax_fp32
{
    auto matcher() const { return match::softmax(); }

    // Trace backwards from the softmax input along the attention data path
    // until a "dot" is reached. At every instruction the walk follows the
    // input that is neither a compile-time constant (scale factors, -inf mask
    // literals) nor of bool type (where conditions/masks).
    static std::optional<instruction_ref> find_upstream_dot(instruction_ref start)
    {
        auto advance = [](instruction_ref cur) -> std::optional<instruction_ref> {
            if(cur->name() == "dot")
                return std::nullopt; // reached the target; stop walking
            const auto& args = cur->inputs();
            if(args.size() == 1)
                return args.front();
            auto data_arg = std::find_if(args.begin(), args.end(), [](instruction_ref arg) {
                return not arg->can_eval() and arg->get_shape().type() != shape::bool_type;
            });
            if(data_arg == args.end())
                return std::nullopt; // no non-constant, non-bool path to follow
            return *data_arg;
        };
        for(auto ins : unfold(start, advance))
        {
            if(ins->name() == "dot")
                return ins;
        }
        return std::nullopt;
    }

    void apply(module& m, const match::matcher_result& r) const
    {
        auto softmax_in = r.instructions["x"];
        auto in_type    = softmax_in->get_shape().type();

        // Already accumulating wide enough: nothing to do for f32/f64 graphs.
        if(contains({shape::float_type, shape::double_type}, in_type))
            return;

        auto maybe_dot = find_upstream_dot(softmax_in);
        if(not maybe_dot.has_value())
            return;

        auto dot_ins = *maybe_dot;
        // Recompute in f32 every op strictly after the dot up to (and
        // including) the softmax input.
        for(const auto& ins : find_instructions_between(dot_ins, softmax_in, &m))
        {
            if(ins == dot_ins)
                continue;

            std::vector<instruction_ref> wide_args;
            wide_args.reserve(ins->inputs().size());
            for(auto arg : ins->inputs())
            {
                // Bool inputs (where conditions) and inputs already in f32
                // pass through untouched; everything else gets upcast.
                auto arg_type = arg->get_shape().type();
                if(arg_type == shape::bool_type or arg_type == shape::float_type)
                    wide_args.push_back(arg);
                else
                    wide_args.push_back(m.insert_instruction(
                        ins, make_op("convert", {{"target_type", shape::float_type}}), arg));
            }

            // Rebuild the op on the widened inputs, then convert back to the
            // original type so downstream consumers are unaffected.
            auto wide_ins = m.insert_instruction(ins, ins->get_operator(), wide_args);
            m.replace_instruction(
                ins, make_op("convert", {{"target_type", ins->get_shape().type()}}), wide_ins);
        }
    }
};

struct find_softmax_base_ops
{
bool full_precision;
Expand Down Expand Up @@ -218,6 +303,7 @@ void rewrite_reduce::apply(module& m) const

if(not enabled(MIGRAPHX_DISABLE_FP32_SOFTMAX{}))
{
match::find_matches(m, find_dot_softmax_fp32{});
match::find_matches(m, find_softmax_base_ops{});
Comment on lines +306 to 307
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new matcher should come after the base_ops transformation

Comment thread
aditya-dl marked this conversation as resolved.
migraphx::run_passes(m,
{migraphx::eliminate_convert{},
Expand Down
110 changes: 110 additions & 0 deletions test/fuse_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/fuse_attention.hpp>
#include <migraphx/rewrite_reduce.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/pass_manager.hpp>
Expand Down Expand Up @@ -1476,6 +1477,115 @@ TEST_CASE(kv_cache_attention)
EXPECT(p1.sort() == p2.sort());
}

// Verify that rewrite_reduce's FP32 upcast (which inserts convert(f16->f32)
// after the dot output) does not break kv_cache_attention fusion.
//
// In the real pipeline, rewrite_reduce runs first and extends the FP32 upcast
// range from the dot output through softmax. This inserts a convert(f16->f32)
// between the dot and mul. Then fuse_attention runs and must still recognize
// the attention pattern despite the convert. The matcher's skip(convert)
// allows it to see through the convert to the dot.
//
// This test runs both passes in sequence on a clean GQA graph and verifies
// that attention fusion still produces a group{tag="kv_cache_attention"}.
TEST_CASE(kv_cache_attention_with_fp32_softmax_upcast)
{
    // Shapes: s1 = scalar literals (scale, -inf); s2 = position range;
    // s3 = rotary sin/cos caches; s4 = seq-lens-k ("slk") parameter;
    // s5 = past K/V caches; s6 = packed QKV input.
    migraphx::shape s1{migraphx::shape::half_type, {1}};
    migraphx::shape s2{migraphx::shape::int32_type, {4}};
    migraphx::shape s3{migraphx::shape::half_type, {4, 1}};
    migraphx::shape s4{migraphx::shape::int32_type, {2, 1}};
    migraphx::shape s5{migraphx::shape::half_type, {2, 2, 4, 2}};
    migraphx::shape s6{migraphx::shape::half_type, {2, 1, 12}};

    // Build a clean GQA graph with softmax op (not decomposed).
    // This is the graph as it comes from the ONNX parser, before any passes.
    migraphx::program p;
    {
        auto* mm   = p.get_main_module();
        auto half  = mm->add_literal(migraphx::literal{s1, {0.5}});
        // -inf fill value for masked-out attention positions (stored as f16).
        auto ninf =
            mm->add_literal(migraphx::literal{s1, {-std::numeric_limits<float>::infinity()}});
        auto range     = mm->add_literal(migraphx::literal{s2, {1, 2, 3, 4}});
        auto sin_cache = mm->add_parameter("sin_cache", s3);
        auto cos_cache = mm->add_parameter("cos_cache", s3);
        auto slk       = mm->add_parameter("slk", s4);
        auto v         = mm->add_parameter("v", s5);
        auto k         = mm->add_parameter("k", s5);
        auto query     = mm->add_parameter("query", s6);
        // Split the packed QKV into heads and apply rotary embedding.
        auto rsp_q =
            mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 6, 2}}}), query);
        auto tsp_q = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), rsp_q);
        auto rope = mm->add_instruction(
            migraphx::make_op("gqa_rotary_embedding",
                              {{"num_heads", 2}, {"kv_num_heads", 2}, {"interleaved", 0}}),
            tsp_q,
            slk,
            cos_cache,
            sin_cache);
        // Slice rotary output into Q/K/V pieces and update the KV caches.
        auto slc_k = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {2}}, {"ends", {4}}}), rope);
        auto slc_v = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {4}}, {"ends", {6}}}), rope);
        auto cpp_k = mm->add_instruction(
            migraphx::make_op("concat_past_present", {{"kv_num_heads", 2}}), slc_k, slk, k);
        auto cpp_v = mm->add_instruction(
            migraphx::make_op("concat_past_present", {{"kv_num_heads", 2}}), slc_v, slk, v);
        auto slc_q = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {2}}}), rope);
        // Attention scores: Q * K^T, scaled by 0.5.
        auto tsp_k = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), cpp_k);
        auto gemm1    = mm->add_instruction(migraphx::make_op("dot"), slc_q, tsp_k);
        auto bc_range = mm->add_instruction(
            migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {2, 4}}}), range);
        auto bc_ninf = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 1, 4}}}), ninf);
        auto bc_half = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 1, 4}}}), half);
        auto scaled = mm->add_instruction(migraphx::make_op("mul"), gemm1, bc_half);
        // Mask: positions where range > slk get -inf before the softmax.
        auto bc_slk =
            mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {2, 4}}}), slk);
        auto grtr      = mm->add_instruction(migraphx::make_op("greater"), bc_range, bc_slk);
        auto conv_grtr = mm->add_instruction(
            migraphx::make_op("convert", {{"target_type", migraphx::shape::bool_type}}), grtr);
        auto unsq_grtr = mm->add_instruction(
            migraphx::make_op("unsqueeze", {{"axes", {1, 2}}, {"steps", {}}}), conv_grtr);
        auto bc_grtr = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 1, 4}}}), unsq_grtr);
        auto mask    = mm->add_instruction(migraphx::make_op("where"), bc_grtr, bc_ninf, scaled);
        auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), mask);
        // Attention output: probabilities * V, back to the input layout.
        auto gemm2   = mm->add_instruction(migraphx::make_op("dot"), softmax, cpp_v);
        auto tsp_out = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm2);
        auto rsp_out =
            mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 1, 4}}}), tsp_out);
        mm->add_return({rsp_out, cpp_k, cpp_v});
    }

    // Run rewrite_reduce first: decomposes softmax, inserts convert(f16->f32)
    // after dot, extending the FP32 upcast range through the attention chain.
    migraphx::run_passes(*p.get_main_module(),
                         {migraphx::rewrite_reduce{}, migraphx::dead_code_elimination{}});

    // Run fuse_attention: must still match the kv_cache_attention pattern
    // despite the convert between dot and mul.
    run_pass(p, {.attn_enabled = true});

    // Verify fusion happened: the output should contain a group instruction
    // with tag "kv_cache_attention"
    bool found_kv_cache_attention = false;
    for(const auto& ins : *p.get_main_module())
    {
        if(ins.name() == "group")
        {
            auto tag = ins.get_operator().to_value()["tag"].to<std::string>();
            if(tag == "kv_cache_attention")
                found_kv_cache_attention = true;
        }
    }
    EXPECT(found_kv_cache_attention);
}

// Verify that pointwise ops (add/mul from rotary embedding) that feed both
// the attention Q path and the K cache path are NOT fused into the attention
// group. Based on build/attn.py model structure.
Expand Down
Loading
Loading