Skip to content

Fix bug with shape_transform_descriptor for Llama3.2 model #4491

@CharlieL7

Description

@CharlieL7

DOR (Definition of Ready)

Ready as is to be worked on.

Description

There's a bug with the rewrite_reshapes pass for a case found when running Llama3.2.

Here's a code snippet that can reproduce the bug:

input_ids = @param:input_ids -> int64_type, {8, 1}, {1, 1}
@34 = pointwise(input_ids), [main:pointwise0] -> float_type, {8, 1}, {1, 1}
@35 = gather[axis=0](@32,@34) -> float_type, {8, 1, 2048}, {2048, 2048, 1}
@36 = broadcast[axis=2,out_lens={8, 1, 64, 1, 32}](@5) -> float_type, {8, 1, 64, 1, 32}, {0, 0, 32, 32, 1}
@37 = reshape[dims={8, 1, 64, 1, 32}](@35) -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}
@38 = fused_reduce[axes={2}](@35), [main:pointwise1:main:reduce_sum0:main:pointwise3] -> float_type, {8, 1, 1}, {1, 1, 1}
@39 = unsqueeze[axes={2, 4},steps={}](@38) -> float_type, {8, 1, 1, 1, 1}, {1, 1, 1, 1, 1}
@40 = multibroadcast[out_lens={8, 1, 64, 1, 32},out_dyn_dims={}](@39) -> float_type, {8, 1, 64, 1, 32}, {1, 1, 0, 1, 0}
@41 = pointwise(@37,@40,@36), [main:pointwise6] -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}
  • Tests were made for it here: https://github.com/ROCm/AMDMIGraphX/pull/4490/files.
  • PR #4482 ("Fix error with rewrite_reshapes") works around the bug by preventing the shape transformation for this case. That workaround leaves performance on the table, since the multibroadcast could otherwise be fused into the fused_reduce.
  • Without the workaround the model does not compile and produces a shape mismatch error when trying to insert a modified pointwise instruction.

DOD (Definition of Done)

  • Investigate the root cause of the issue in shape_transform_descriptor
  • Create a bugfix that allows this case to be handled correctly
  • The following tests pass:
"test/fuse_reduce.cpp"
// TODO fix shape_transform_descriptor error
// See also: test/shape_transform_descriptor.cpp
TEST_CASE(reduce_to_scalar_and_broadcast)
{
    // Taken from bug found when compiling Llama 3.2.  Minimal reproducer for the
    // instruction sequence below: a reshape and a reduction share the same input,
    // and the reduced result is broadcast back to the reshaped lens.
    // input_ids = @param:input_ids -> int64_type, {8, 1}, {1, 1}
    // @34 = pointwise(input_ids), [main:pointwise0] -> float_type, {8, 1}, {1, 1}
    // @35 = gather[axis=0](@32,@34) -> float_type, {8, 1, 2048}, {2048, 2048, 1}
    // @36 = broadcast[axis=2,out_lens={8, 1, 64, 1, 32}](@5) -> float_type, {8, 1, 64, 1, 32}, {0, 0, 32, 32, 1}
    // @37 = reshape[dims={8, 1, 64, 1, 32}](@35) -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}
    // @38 = fused_reduce[axes={2}](@35), [main:pointwise1:main:reduce_sum0:main:pointwise3] -> float_type, {8, 1, 1}, {1, 1, 1}
    // @39 = unsqueeze[axes={2, 4},steps={}](@38) -> float_type, {8, 1, 1, 1, 1}, {1, 1, 1, 1, 1}
    // @40 = multibroadcast[out_lens={8, 1, 64, 1, 32},out_dyn_dims={}](@39) -> float_type, {8, 1, 64, 1, 32}, {1, 1, 0, 1, 0}
    // @41 = pointwise(@37,@40,@36), [main:pointwise6] -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}

    migraphx::shape shape_l1{migraphx::shape::float_type, {128256, 2048}};
    migraphx::shape shape_x1{migraphx::shape::float_type, {8, 1}};

    // Input program: reduce over the flat {8, 1, 2048} gather output, then
    // unsqueeze + multibroadcast the scalar-per-batch result up to the reshaped lens.
    migraphx::program p1;
    {
        auto* mm    = p1.get_main_module();
        auto l1     = mm->add_parameter("l1", shape_l1);
        auto x1     = mm->add_parameter("x1", shape_x1);
        auto gather = mm->add_instruction(migraphx::make_op("gather"), l1, x1);
        auto i1     = mm->add_instruction(
            migraphx::make_op("reshape", {{"dims", {8, 1, 64, 1, 32}}}), gather);
        auto reduce =
            mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), gather);
        auto unsqueeze =
            mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2, 4}}}), reduce);
        auto i2 = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {8, 1, 64, 1, 32}}}), unsqueeze);
        auto pw = mm->add_instruction(migraphx::make_op("add"), i1, i2);
        mm->add_return({pw});
    }

    // Expected program: the multibroadcast is along the reduction axes so it can
    // be fused into the reduction.
    migraphx::program p2;
    {
        auto* mm    = p2.get_main_module();
        auto l1     = mm->add_parameter("l1", shape_l1);
        auto x1     = mm->add_parameter("x1", shape_x1);
        auto gather = mm->add_instruction(migraphx::make_op("gather"), l1, x1);
        auto i1     = mm->add_instruction(
            migraphx::make_op("reshape", {{"dims", {8, 1, 64, 1, 32}}}), gather);
        // NOTE(review): this second reshape is identical to i1; after the pass a
        // common-subexpression step may leave only one — confirm against the
        // actual pass output when wiring this test up.
        auto reshape = mm->add_instruction(
            migraphx::make_op("reshape", {{"dims", {8, 1, 64, 1, 32}}}), gather);
        auto reduce =
            mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2, 4}}}), reshape);
        auto i2 = mm->add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {8, 1, 64, 1, 32}}}), reduce);
        auto pw = mm->add_instruction(migraphx::make_op("add"), i1, i2);
        mm->add_return({pw});
    }

    // The original draft built both programs but asserted nothing; run the pass
    // and compare, following the convention used elsewhere in test/fuse_reduce.cpp.
    run_pass(p1);
    EXPECT(p1.sort() == p2.sort());
}
"test/shape_transform_descriptor.cpp"
TEST_CASE(rebase_broadcasted_scalar_from_reduce)
{
    // Taken from bug found when compiling Llama 3.2.  The descriptor is built
    // from the unsqueeze + multibroadcast chain applied to the {8, 1, 1}
    // reduction output (@39/@40 below), then rebased onto the pre-reduction lens.
    // input_ids = @param:input_ids -> int64_type, {8, 1}, {1, 1}
    // @34 = pointwise(input_ids), [main:pointwise0] -> float_type, {8, 1}, {1, 1}
    // @35 = gather[axis=0](@32,@34) -> float_type, {8, 1, 2048}, {2048, 2048, 1}
    // @36 = broadcast[axis=2,out_lens={8, 1, 64, 1, 32}](@5) -> float_type, {8, 1, 64, 1, 32}, {0, 0, 32, 32, 1}
    // @37 = reshape[dims={8, 1, 64, 1, 32}](@35) -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}
    // @38 = fused_reduce[axes={2}](@35), [main:pointwise1:main:reduce_sum0:main:pointwise3] -> float_type, {8, 1, 1}, {1, 1, 1}
    // @39 = unsqueeze[axes={2, 4},steps={}](@38) -> float_type, {8, 1, 1, 1, 1}, {1, 1, 1, 1, 1}
    // @40 = multibroadcast[out_lens={8, 1, 64, 1, 32},out_dyn_dims={}](@39) -> float_type, {8, 1, 64, 1, 32}, {1, 1, 0, 1, 0}
    // @41 = pointwise(@37,@40,@36), [main:pointwise6] -> float_type, {8, 1, 64, 1, 32}, {2048, 2048, 32, 32, 1}

    auto base_desc =
        make_simple_descriptor({8, 1, 1},
                               make_op("unsqueeze", {{"axes", {2, 4}}}),
                               make_op("multibroadcast", {{"out_lens", {8, 1, 64, 1, 32}}}));

    {
        // Rebasing onto the pre-reduction lens {8, 1, 2048} should yield a pure
        // reshape to {8, 1, 64, 1, 32}: axis 2 (2048) splits into {64, 1, 32}.
        auto desc = base_desc.rebase({8, 1, 2048});
        EXPECT(not desc.empty());
        EXPECT(get_final_lens(desc) == final_lens{8, 1, 64, 1, 32});
        EXPECT(get_all_lens(desc) == all_lens{{8}, {1}, {64}, {1}, {32}});
        EXPECT(get_all_axes(desc) ==
               all_axes{d_axes{{0}}, d_axes{{1}}, d_axes{{2, 0}}, d_axes{{2, 1}}, d_axes{{2, 2}}});
        auto generated = desc.generate();
        // Bug fix vs. the draft: the reshape operator's attribute is "dims";
        // "out_lens" belongs to the broadcast operators and would not parse here.
        EXPECT(generated ==
               ops{
                   make_op("reshape", {{"dims", {8, 1, 64, 1, 32}}}),
               });
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions