Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/fuse_pointwise_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void fuse_pointwise_reduce::apply(module_pass_manager& mpm) const
if(not enabled(MIGRAPHX_DISABLE_MULTI_OUTPUT_FUSION{}))
{
mpm.run_pass(fuse_pointwise{.enable_multi_output = true});
mpm.run_pass(fuse_reduce{.enable_multi_output = true});
}
}

Expand Down
254 changes: 217 additions & 37 deletions src/fuse_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ struct fused_reduce
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
const auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported");
auto output_shapes = sm->get_output_shapes();
if(output_shapes.empty())
MIGRAPHX_THROW("submodule has no outputs");
if(not sm->bypass())
MIGRAPHX_THROW("fused_reduce: bypass flag is not set");
auto names = sm->get_parameter_names();
Expand All @@ -73,12 +74,29 @@ struct fused_reduce
}))
MIGRAPHX_THROW("Input dimension does not match the submodule.");

if(sm->get_output_shapes().front().dynamic())
return sm->get_output_shapes().front();
// If all outputs are dynamic, return them directly
if(std::all_of(output_shapes.begin(), output_shapes.end(), [](const shape& os) {
return os.dynamic();
}))
{
if(output_shapes.size() == 1)
return output_shapes.front();
return shape{output_shapes};
}

return shape::from_permutation(sm->get_output_shapes().front().type(),
sm->get_output_shapes().front().lens(),
find_permutation(inputs));
auto perm = find_permutation(inputs);
std::vector<shape> result_shapes;
std::transform(output_shapes.begin(),
output_shapes.end(),
std::back_inserter(result_shapes),
[&](const shape& os) {
if(os.dynamic())
return os;
return shape::from_permutation(os.type(), os.lens(), perm);
});
if(result_shapes.size() == 1)
return result_shapes.front();
return shape{result_shapes};
}

std::string name() const { return "fused_reduce"; }
Expand Down Expand Up @@ -247,9 +265,135 @@ static void finalize_reduce_module(module_ref m)
dead_code_elimination{}.apply(*m);
}

// Fuse two independent fused_reduce instructions ("input" comes before
// "output" in program order) into a single multi-output fused_reduce whose
// submodule returns the concatenation of both original return lists.
// Returns the newly inserted fused instruction (tuple-typed: outputs of
// `input` first, then outputs of `output`).
//
// NOTE(review): assumes `input` and `output` are independent (neither reaches
// the other) and carry the same operator — callers (try_multi_output_merge)
// are expected to have checked this; not re-validated here.
static instruction_ref
merge_reduces(module_pass_manager& mpm, instruction_ref input, instruction_ref output)
{
    auto& m = mpm.get_module();
    const auto* rm1 = input->module_inputs().front();
    const auto* rm2 = output->module_inputs().front();
    // New submodule named after both sources, e.g. "reduceA:reduceB".
    auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
    rm->set_bypass();

    // Inline both submodules into rm. The shared map_ins lets the second
    // insertion reuse parameters already created for the first, so common
    // inputs are deduplicated.
    std::unordered_map<instruction_ref, instruction_ref> map_ins;
    auto outs1 = insert_module_in_submodule(rm, input, &map_ins);
    auto outs2 = insert_module_in_submodule(rm, output, &map_ins);

    // Combined return: outputs of `input` first, then outputs of `output`.
    // This ordering fixes the tuple indices used below.
    std::vector<instruction_ref> all_outs;
    all_outs.insert(all_outs.end(), outs1.begin(), outs1.end());
    all_outs.insert(all_outs.end(), outs2.begin(), outs2.end());
    rm->replace_return(all_outs);
    finalize_reduce_module(rm);

    // Insert the fused instruction before `output` (the later of the two),
    // so every existing user of either instruction is still downstream.
    auto new_inputs = find_inputs(map_ins, &m, rm);
    auto fins = m.insert_instruction(output, input->get_operator(), new_inputs, {rm});

    // Replace the first instruction's usages with get_tuple_elem
    if(input->get_shape().type() == shape::tuple_type)
    {
        // `input` was already multi-output: retarget each of its
        // get_tuple_elem users at the fused instruction; indices are
        // unchanged because outs1 comes first in all_outs.
        // Copy the output list first — replace_instruction mutates it.
        auto input_outputs = input->outputs();
        for(auto inp_out : input_outputs)
        {
            if(inp_out->name() != "get_tuple_elem")
                continue;
            auto v = inp_out->get_operator().to_value();
            auto i = v.at("index").to<std::size_t>();
            m.replace_instruction(inp_out, make_op("get_tuple_elem", {{"index", i}}), fins);
        }
    }
    else
    {
        // Single-output `input`: its old value is now tuple element 0.
        auto elem =
            m.insert_instruction(std::next(fins), make_op("get_tuple_elem", {{"index", 0}}), fins);
        m.replace_instruction(input, elem);
    }

    // Replace the second instruction's usages with get_tuple_elem
    // Outputs of `output` start after all of `input`'s outputs.
    std::size_t start2 = outs1.size();
    if(output->get_shape().type() == shape::tuple_type)
    {
        // Multi-output `output`: shift each existing tuple index by start2.
        auto output_outputs = output->outputs();
        for(auto out_out : output_outputs)
        {
            if(out_out->name() != "get_tuple_elem")
                continue;
            auto v = out_out->get_operator().to_value();
            auto i = v.at("index").to<std::size_t>();
            m.replace_instruction(
                out_out, make_op("get_tuple_elem", {{"index", i + start2}}), fins);
        }
    }
    else
    {
        // Single-output `output`: its old value is tuple element start2.
        auto elem = m.insert_instruction(
            std::next(fins), make_op("get_tuple_elem", {{"index", start2}}), fins);
        m.replace_instruction(output, elem);
    }

    return fins;
}

// Starting from `reduce`, find sibling fused_reduce instructions that share
// an input, carry the same operator, and are mutually independent, then fold
// them pairwise into a single multi-output fused_reduce via merge_reduces.
// No-op when fewer than two mergeable instructions are found.
static void try_multi_output_merge(module_pass_manager& mpm, instruction_ref reduce)
{
    if(reduce->outputs().empty())
        return;
    auto& m = mpm.get_module();

    // Collect sibling fused_reduces sharing an input with the same operator
    std::vector<instruction_ref> candidates;
    candidates.push_back(reduce);
    for(auto inp : reduce->inputs())
    {
        std::copy_if(inp->outputs().begin(),
                     inp->outputs().end(),
                     std::back_inserter(candidates),
                     [&](instruction_ref output) {
                         if(output == reduce)
                             return false;
                         if(output->name() != "fused_reduce")
                             return false;
                         if(not m.has_instruction(output))
                             return false;
                         if(output->get_operator() != reduce->get_operator())
                             return false;
                         if(output->outputs().empty())
                             return false;
                         // Deduplicate: siblings can share several inputs.
                         return std::find(candidates.begin(), candidates.end(), output) ==
                                candidates.end();
                     });
    }

    if(candidates.size() < 2)
        return;

    // Sort by position in module. std::distance over the instruction list is
    // O(module size), so compute each candidate's position once up front
    // instead of inside the sort comparator.
    std::unordered_map<instruction_ref, std::size_t> position;
    position.reserve(candidates.size());
    for(auto candidate : candidates)
        position[candidate] = std::distance(m.begin(), candidate);
    std::sort(candidates.begin(), candidates.end(), by(std::less<>{}, [&](auto x) {
                  return position.at(x);
              }));

    // Filter to independent instructions (no reachability between them).
    // Candidates are in program order, so only earlier->later reachability
    // needs checking.
    std::vector<instruction_ref> independent;
    std::copy_if(
        candidates.begin(), candidates.end(), std::back_inserter(independent), [&](auto c) {
            return std::none_of(independent.begin(), independent.end(), [&](auto other) {
                return reaches(other, c, &m);
            });
        });

    if(independent.size() < 2)
        return;

    // Iteratively merge all independent reduces, folding left-to-right so
    // each step combines the accumulated fused instruction with the next.
    auto merged = independent.front();
    for(auto it = independent.begin() + 1; it != independent.end(); ++it)
        merged = merge_reduces(mpm, merged, *it);
}

namespace {
struct find_pointwise_reduce
{
bool multi_output = false;

auto matcher() const
{
// fused_reduce instruction with pointwise inputs.
Expand Down Expand Up @@ -285,11 +429,14 @@ struct find_pointwise_reduce

auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
if(multi_output)
try_multi_output_merge(mpm, reduce);
}
};

struct find_reduce_pointwise
{
bool multi_output = false;

auto matcher() const
{
Expand Down Expand Up @@ -327,51 +474,82 @@ struct find_reduce_pointwise

auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
if(multi_output)
try_multi_output_merge(mpm, pw);
}
};

struct find_reduce_reduce
{
bool multi_output = false;

auto matcher() const
{
return match::name("fused_reduce")(match_broadcastable_input("fused_reduce", "reduce"));
return match::any_of(
match::name("fused_reduce")(match_broadcastable_input("fused_reduce", "reduce")),
match::name("fused_reduce"));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto reduce1 = r.result;
auto reduce2 = r.instructions["reduce"];
auto input = r.instructions["input"];

if(reduce1->get_operator() != reduce2->get_operator())
return;
if(contains(r.instructions, "reduce"))
{
// Re-validate broadcast axes since any_of may leave stale bindings
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
auto axes = reduce1->get_operator().to_value().at("axes").to_vector<std::size_t>();
if(not is_valid_broadcast(broadcast, axes))
{
if(multi_output)
try_multi_output_merge(mpm, reduce1);
return;
}
}

const auto* rm1 = reduce1->module_inputs().front();
const auto* rm2 = reduce2->module_inputs().front();
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
rm->set_bypass();
// Chain fusion
auto reduce2 = r.instructions["reduce"];
auto input = r.instructions["input"];

std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy reduce1 instructions
insert_module_in_submodule(rm, reduce2, &map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = rm->fuse({broadcast}, &map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}
if(reduce1->get_operator() != reduce2->get_operator())
{
if(multi_output)
try_multi_output_merge(mpm, reduce1);
return;
}

auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);
const auto* rm1 = reduce1->module_inputs().front();
const auto* rm2 = reduce2->module_inputs().front();
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
rm->set_bypass();

auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
std::unordered_map<instruction_ref, instruction_ref> map_ins;
insert_module_in_submodule(rm, reduce2, &map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = rm->fuse({broadcast}, &map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}

auto out = insert_module_in_submodule(rm, reduce1, &map_ins);
rm->replace_return(out);
finalize_reduce_module(rm);

auto new_inputs = find_inputs(map_ins, &mpm.get_module(), rm);
mpm.get_module().replace_instruction(
reduce1, reduce1->get_operator(), new_inputs, {rm});
}

if(multi_output)
try_multi_output_merge(mpm, reduce1);
}
};

Expand Down Expand Up @@ -455,8 +633,10 @@ void fuse_reduce::apply(module_pass_manager& mpm) const
{
if(enable_rewrite_reshapes)
mpm.run_pass(rewrite_reshapes<reduce_reshape>{});
match::find_matches(
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
match::find_matches(mpm,
find_reduce_pointwise{.multi_output = enable_multi_output},
find_pointwise_reduce{.multi_output = enable_multi_output},
find_reduce_reduce{.multi_output = enable_multi_output});
mpm.run_pass(dead_code_elimination{});
}
}
Expand Down
1 change: 1 addition & 0 deletions src/include/migraphx/fuse_reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ struct MIGRAPHX_EXPORT fuse_reduce
void apply(module_pass_manager& mpm) const;

bool enable_rewrite_reshapes = true;
bool enable_multi_output = false;
};

} // namespace MIGRAPHX_INLINE_NS
Expand Down
Loading
Loading