From b46617964659a17d0b7bbf24acf5258679366993 Mon Sep 17 00:00:00 2001 From: Ian Wood Date: Tue, 7 Apr 2026 12:31:45 -0700 Subject: [PATCH] Refactor reduction emitter to macro-based dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract schema templates and dispatch macros for the reduction ASM emitter, matching the pattern used by the pointwise emitter. This makes adding new reduction modes a one-line change. Also reformats the integration test GENERATE list to one mode per line with clang-format guards, matching the pointwise cleanup in https://github.com/iree-org/fusilli/pull/296. No behavior change — emitted MLIR is byte-identical. Signed-off-by: Ian Wood --- include/fusilli/support/asm_emitter.h | 80 +++++++++++---------------- samples/reduction/reduction_ops.cpp | 8 ++- 2 files changed, 38 insertions(+), 50 deletions(-) diff --git a/include/fusilli/support/asm_emitter.h b/include/fusilli/support/asm_emitter.h index 1fd9ea47..c4ba77a3 100644 --- a/include/fusilli/support/asm_emitter.h +++ b/include/fusilli/support/asm_emitter.h @@ -1849,73 +1849,57 @@ inline ErrorOr ReductionNode::emitNodePreAsm() const { std::string permuteY = getLayoutConversionOpsAsm(yT, "permute_Y", suffix, /*isInput=*/false); - switch (reductionAttr.getMode()) { - case ReductionAttr::Mode::SUM: { - constexpr std::string_view schema = R"( + constexpr std::string_view kKeepdimReductionSchema = R"( {0} {1} %keepdim_{2} = torch.constant.bool true - %dtype_{2} = torch.constant.none - {3}_{2}_perm = torch.aten.sum.dim_IntList {4}, %reduction_dims_{2}, %keepdim_{2}, %dtype_{2} : {5}, !torch.list, !torch.bool, !torch.none -> {6} + {3}_{2}_perm = {8} {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list, !torch.bool -> {6} {7} )"; - return std::format(schema, - permuteX, // {0} - dimListOss.str(), // {1} - suffix, // {2} - getResultNamesAsm(), // {3} - getOperandNamesAsm(), // {4} - getOperandTypesAsm(), // {5} - getResultTypesAsm(), // {6} - permuteY // {7} - ); - } - case ReductionAttr::Mode::MIN: { - constexpr std::string_view schema = R"( + constexpr std::string_view kKeepdimDtypeReductionSchema = R"( {0} {1} %keepdim_{2} = torch.constant.bool true - {3}_{2}_perm = torch.aten.amin {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list, !torch.bool -> {6} + %dtype_{2} = torch.constant.none + {3}_{2}_perm = {8} {4}, %reduction_dims_{2}, %keepdim_{2}, %dtype_{2} : {5}, !torch.list, !torch.bool, !torch.none -> {6} {7} )"; - return std::format(schema, - permuteX, // {0} - dimListOss.str(), // {1} - suffix, // {2} - getResultNamesAsm(), // {3} - getOperandNamesAsm(), // {4} - getOperandTypesAsm(), // {5} - getResultTypesAsm(), // {6} - permuteY // {7} - ); +#define FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, SCHEMA, OPIR) \ + case ReductionAttr::Mode::MODE: { \ + return std::format(SCHEMA, permuteX, /* {0} */ \ + dimListOss.str(), /* {1} */ \ + suffix, /* {2} */ \ + getResultNamesAsm(), /* {3} */ \ + getOperandNamesAsm(), /* {4} */ \ + getOperandTypesAsm(), /* {5} */ \ + getResultTypesAsm(), /* {6} */ \ + permuteY, /* {7} */ \ + #OPIR /* {8} */ \ + ); \ } - case ReductionAttr::Mode::MAX: { - constexpr std::string_view schema = R"( - {0} - {1} - %keepdim_{2} = torch.constant.bool true - {3}_{2}_perm = torch.aten.amax {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list, !torch.bool -> {6} - {7} - )"; - return std::format(schema, - permuteX, // {0} - dimListOss.str(), // {1} - suffix, // {2} - getResultNamesAsm(), // {3} - getOperandNamesAsm(), // {4} - getOperandTypesAsm(), // {5} - getResultTypesAsm(), // {6} - permuteY // {7} - ); - } +#define FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MODE, OPIR) \ + FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, kKeepdimReductionSchema, OPIR) + +#define FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER(MODE, OPIR) \ + FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, kKeepdimDtypeReductionSchema, OPIR) + + switch (reductionAttr.getMode()) { + FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER(SUM, + torch.aten.sum.dim_IntList) + FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MIN, torch.aten.amin) + FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MAX, torch.aten.amax) default: return error(ErrorCode::InternalError, "Unsupported reduction mode"); } } +#undef FUSILLI_DECLARE_REDUCTION_EMITTER +#undef FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER +#undef FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER + //===----------------------------------------------------------------------===// // // CustomOpNode ASM Emitter Methods diff --git a/samples/reduction/reduction_ops.cpp b/samples/reduction/reduction_ops.cpp index 9597fcb7..19e3d9a0 100644 --- a/samples/reduction/reduction_ops.cpp +++ b/samples/reduction/reduction_ops.cpp @@ -44,8 +44,12 @@ TEST_CASE("Reduction ops", "[reduction][graph]") { const auto xDims = std::vector{2, 16, 8, 8}; const auto yDims = std::vector{2, 16, 1, 1}; - const auto mode = GENERATE(ReductionAttr::Mode::SUM, ReductionAttr::Mode::MIN, - ReductionAttr::Mode::MAX); + // clang-format off + const auto mode = GENERATE( + ReductionAttr::Mode::SUM, + ReductionAttr::Mode::MIN, + ReductionAttr::Mode::MAX); + // clang-format on auto execute = [&](Handle &handle, DataType dt, T initValue) { // Create graph.