You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
"Fix error with rewrite_reshapes" (#4482) works around the bug by preventing the shape transformation for the case seen. This workaround leaves performance on the table, because the multibroadcast could otherwise be fused into the fused_reduce.
Without the workaround, the model fails to compile: a shape mismatch error is raised when trying to insert a modified pointwise instruction.
DOD (Definition of Done)
Investigate the root cause of the issue in shape_transform_descriptor
Create a bugfix that allows this case to be handled correctly
The following tests pass:
"test/fuse_reduce.cpp"
// TODO fix shape_transform_descriptor error
// See also: test/shape_transform_descriptor.cpp
TEST_CASE(reduce_to_scalar_and_broadcast)
{
    // Reproducer reduced from a failure hit while compiling Llama 3.2.
    // Relevant excerpt of the program that triggered it:
    // input_ids = @param:input_ids -> int64_type, {8, 1}, {1, 1}
    // @34 = pointwise(input_ids), [main:pointwise0] -> float_type, {8, 1}, {1, 1}
    // @35 = gather[axis=0](@32,@34) -> float_type, {8, 1, 2048}, {2048, 2048, 1}
    // @36 = broadcast[axis=2,out_lens={8, 1, 64, 1, 32}](@5) -> float_type, {8, 1, 64, 1, 32}, {0, 0, 32, 32, 1}
    // @37 = reshape[dims={8, 1, 64, 1, 32}](@35) -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}
    // @38 = fused_reduce[axes={2}](@35), [main:pointwise1:main:reduce_sum0:main:pointwise3] -> float_type, {8, 1, 1}, {1, 1, 1}
    // @39 = unsqueeze[axes={2, 4},steps={}](@38) -> float_type, {8, 1, 1, 1, 1}, {1, 1, 1, 1, 1}
    // @40 = multibroadcast[out_lens={8, 1, 64, 1, 32},out_dyn_dims={}](@39) -> float_type, {8, 1, 64, 1, 32}, {1, 1, 0, 1, 0}
    // @41 = pointwise(@37,@40,@36), [main:pointwise6] -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}
    //
    // NOTE(review): as written, this test only constructs p1 and p2; no pass is
    // run on p1 and the two programs are never compared.
    migraphx::shape table_shape{migraphx::shape::float_type, {128256, 2048}};
    migraphx::shape index_shape{migraphx::shape::float_type, {8, 1}};

    // Adds the "l1"/"x1" parameters followed by the gather to the given module.
    // Both programs start with this identical prologue.
    auto add_gather = [&](auto* mod) {
        auto table   = mod->add_parameter("l1", table_shape);
        auto indices = mod->add_parameter("x1", index_shape);
        return mod->add_instruction(migraphx::make_op("gather"), table, indices);
    };

    // p1: reduce on the un-reshaped gather output, then unsqueeze and
    // multibroadcast the result so it can be added to the reshaped gather.
    migraphx::program p1;
    {
        auto* mod     = p1.get_main_module();
        auto gathered = add_gather(mod);
        auto reshaped = mod->add_instruction(
            migraphx::make_op("reshape", {{"dims", {8, 1, 64, 1, 32}}}), gathered);
        auto summed = mod->add_instruction(
            migraphx::make_op("reduce_sum", {{"axes", {2}}}), gathered);
        auto expanded = mod->add_instruction(
            migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), summed);
        auto broadcasted = mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {8, 1, 64, 1, 32}}}), expanded);
        auto total = mod->add_instruction(migraphx::make_op("add"), reshaped, broadcasted);
        mod->add_return({total});
    }

    // p2: the desired form. The multibroadcast should be along the reduction
    // axes so that it can be fused into the reduction.
    migraphx::program p2;
    {
        auto* mod     = p2.get_main_module();
        auto gathered = add_gather(mod);
        auto reshaped = mod->add_instruction(
            migraphx::make_op("reshape", {{"dims", {8, 1, 64, 1, 32}}}), gathered);
        // A second reshape with the same dims as the one above feeds the reduction.
        auto reduce_input = mod->add_instruction(
            migraphx::make_op("reshape", {{"dims", {8, 1, 64, 1, 32}}}), gathered);
        auto summed = mod->add_instruction(
            migraphx::make_op("reduce_sum", {{"axes", {2, 4}}}), reduce_input);
        auto broadcasted = mod->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {8, 1, 64, 1, 32}}}), summed);
        auto total = mod->add_instruction(migraphx::make_op("add"), reshaped, broadcasted);
        mod->add_return({total});
    }
}
DOR (Definition of Ready)
Ready as is to be worked on.
Description
There is a bug in the rewrite_reshapes pass that was found when running Llama 3.2.
Here's a code snippet that can reproduce the bug:
DOD (Definition of Done)