diff --git a/include/fusilli/support/asm_emitter.h b/include/fusilli/support/asm_emitter.h index 05b06bc2..e350a57f 100644 --- a/include/fusilli/support/asm_emitter.h +++ b/include/fusilli/support/asm_emitter.h @@ -1853,73 +1853,57 @@ inline ErrorOr ReductionNode::emitNodePreAsm() const { std::string permuteY = getLayoutConversionOpsAsm(yT, "permute_Y", suffix, /*isInput=*/false); - switch (reductionAttr.getMode()) { - case ReductionAttr::Mode::SUM: { - constexpr std::string_view schema = R"( + constexpr std::string_view kKeepdimReductionSchema = R"( {0} {1} %keepdim_{2} = torch.constant.bool true - %dtype_{2} = torch.constant.none - {3}_{2}_perm = torch.aten.sum.dim_IntList {4}, %reduction_dims_{2}, %keepdim_{2}, %dtype_{2} : {5}, !torch.list, !torch.bool, !torch.none -> {6} + {3}_{2}_perm = {8} {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list, !torch.bool -> {6} {7} )"; - return std::format(schema, - permuteX, // {0} - dimListOss.str(), // {1} - suffix, // {2} - getResultNamesAsm(), // {3} - getOperandNamesAsm(), // {4} - getOperandTypesAsm(), // {5} - getResultTypesAsm(), // {6} - permuteY // {7} - ); - } - case ReductionAttr::Mode::MIN: { - constexpr std::string_view schema = R"( + constexpr std::string_view kKeepdimDtypeReductionSchema = R"( {0} {1} %keepdim_{2} = torch.constant.bool true - {3}_{2}_perm = torch.aten.amin {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list, !torch.bool -> {6} + %dtype_{2} = torch.constant.none + {3}_{2}_perm = {8} {4}, %reduction_dims_{2}, %keepdim_{2}, %dtype_{2} : {5}, !torch.list, !torch.bool, !torch.none -> {6} {7} )"; - return std::format(schema, - permuteX, // {0} - dimListOss.str(), // {1} - suffix, // {2} - getResultNamesAsm(), // {3} - getOperandNamesAsm(), // {4} - getOperandTypesAsm(), // {5} - getResultTypesAsm(), // {6} - permuteY // {7} - ); +#define FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, SCHEMA, OPIR) \ + case ReductionAttr::Mode::MODE: { \ + return std::format(SCHEMA, permuteX, /* {0} */ \ + dimListOss.str(), /* {1} */ \ + suffix, /* {2} */ \ + getResultNamesAsm(), /* {3} */ \ + getOperandNamesAsm(), /* {4} */ \ + getOperandTypesAsm(), /* {5} */ \ + getResultTypesAsm(), /* {6} */ \ + permuteY, /* {7} */ \ + #OPIR /* {8} */ \ + ); \ } - case ReductionAttr::Mode::MAX: { - constexpr std::string_view schema = R"( - {0} - {1} - %keepdim_{2} = torch.constant.bool true - {3}_{2}_perm = torch.aten.amax {4}, %reduction_dims_{2}, %keepdim_{2} : {5}, !torch.list, !torch.bool -> {6} - {7} - )"; - return std::format(schema, - permuteX, // {0} - dimListOss.str(), // {1} - suffix, // {2} - getResultNamesAsm(), // {3} - getOperandNamesAsm(), // {4} - getOperandTypesAsm(), // {5} - getResultTypesAsm(), // {6} - permuteY // {7} - ); - } +#define FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MODE, OPIR) \ + FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, kKeepdimReductionSchema, OPIR) + +#define FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER(MODE, OPIR) \ + FUSILLI_DECLARE_REDUCTION_EMITTER(MODE, kKeepdimDtypeReductionSchema, OPIR) + + switch (reductionAttr.getMode()) { + FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER(SUM, + torch.aten.sum.dim_IntList) + FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MIN, torch.aten.amin) + FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER(MAX, torch.aten.amax) default: return error(ErrorCode::InternalError, "Unsupported reduction mode"); } } +#undef FUSILLI_DECLARE_REDUCTION_EMITTER +#undef FUSILLI_DECLARE_KEEPDIM_REDUCTION_EMITTER +#undef FUSILLI_DECLARE_KEEPDIM_DTYPE_REDUCTION_EMITTER + //===----------------------------------------------------------------------===// // // CustomOpNode ASM Emitter Methods diff --git a/samples/reduction/reduction_ops.cpp b/samples/reduction/reduction_ops.cpp index 9597fcb7..19e3d9a0 100644 --- a/samples/reduction/reduction_ops.cpp +++ b/samples/reduction/reduction_ops.cpp @@ -44,8 +44,12 @@ TEST_CASE("Reduction ops", "[reduction][graph]") { const auto xDims = std::vector{2, 16, 8, 8}; const auto yDims = std::vector{2, 16, 1, 1}; - const auto mode = GENERATE(ReductionAttr::Mode::SUM, ReductionAttr::Mode::MIN, - ReductionAttr::Mode::MAX); + // clang-format off + const auto mode = GENERATE( + ReductionAttr::Mode::SUM, + ReductionAttr::Mode::MIN, + ReductionAttr::Mode::MAX); + // clang-format on auto execute = [&](Handle &handle, DataType dt, T initValue) { // Create graph.