From a1665731a31dd30df9a59378cd8757d290084b74 Mon Sep 17 00:00:00 2001 From: Andrew Savonichev Date: Tue, 28 Apr 2026 20:34:59 +0900 Subject: [PATCH 1/6] [mlir][dxsa] Add AsmPrinter to translate from MLIR to DXSA text AsmPrinter traverses all instructions in an module and prints them in "canonical" assembly syntax. There is no description of its grammar, so the patch follows examples from "Direct3D 11.3 Functional Specification" and output and tests from existing DirectX compiler tools, such as DXC. Only standard instructions and some basic operand types are supported for now. --- mlir/include/mlir/InitAllTranslations.h | 2 + mlir/include/mlir/Target/DXSA/BinaryParser.h | 4 + mlir/lib/Target/DXSA/AsmPrinter.cpp | 362 ++++++++++++++++++ mlir/lib/Target/DXSA/CMakeLists.txt | 1 + .../lib/Target/DXSA/TranslateRegistration.cpp | 9 + mlir/test/Target/DXSA/mov-index.mlir | 19 +- mlir/test/Target/DXSA/mov.mlir | 15 +- mlir/test/Target/DXSA/ret.mlir | 8 +- mlir/test/Target/DXSA/udiv.mlir | 17 +- 9 files changed, 399 insertions(+), 38 deletions(-) create mode 100644 mlir/lib/Target/DXSA/AsmPrinter.cpp diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h index 956bf5db058b..2f4efc4be51f 100644 --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -24,6 +24,7 @@ void registerFromWasmTranslation(); void registerFromDxsaBinTranslation(); void registerToCppTranslation(); void registerToDxsaBinTranslation(); +void registerToDxsaTranslation(); void registerToLLVMIRTranslation(); void registerToSPIRVTranslation(); @@ -43,6 +44,7 @@ inline void registerAllTranslations() { registerFromDxsaBinTranslation(); registerToCppTranslation(); registerToDxsaBinTranslation(); + registerToDxsaTranslation(); registerToLLVMIRTranslation(); registerToSPIRVTranslation(); smt::registerExportSMTLIBTranslation(); diff --git a/mlir/include/mlir/Target/DXSA/BinaryParser.h b/mlir/include/mlir/Target/DXSA/BinaryParser.h index 3b81b8772a06..314c60ac381f 100644 --- a/mlir/include/mlir/Target/DXSA/BinaryParser.h +++ b/mlir/include/mlir/Target/DXSA/BinaryParser.h @@ -21,6 +21,10 @@ OwningOpRef importDxsaBinaryToModule(llvm::SourceMgr &source, MLIRContext *context); /// Encode \p source to DXSA binary. LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output); + +/// Print \p source to DXSA text assembly. +LogicalResult exportModuleToDxsa(ModuleOp source, raw_ostream &output); + } // namespace mlir::dxsa #endif // MLIR_TARGET_DXSA_BINARYPARSER_H diff --git a/mlir/lib/Target/DXSA/AsmPrinter.cpp b/mlir/lib/Target/DXSA/AsmPrinter.cpp new file mode 100644 index 000000000000..94cd45ae7ed7 --- /dev/null +++ b/mlir/lib/Target/DXSA/AsmPrinter.cpp @@ -0,0 +1,362 @@ +#include "mlir/Dialect/DXSA/IR/DXSA.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Location.h" +#include "mlir/Target/DXSA/BinaryParser.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/EndianStream.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/NativeFormatting.h" + +#include "d3d12TokenizedProgramFormat.hpp" + +#define DEBUG_TYPE "export-dxsa-bin" + +using namespace mlir; +using namespace llvm; + +// FIXME: remove OpcodeClass from BinaryParser +enum OpcodeClass { + D3D10_SB_FLOAT_OP, + D3D10_SB_INT_OP, + D3D10_SB_UINT_OP, + D3D10_SB_BIT_OP, + D3D10_SB_FLOW_OP, + D3D10_SB_TEX_OP, + D3D10_SB_DCL_OP, + D3D11_SB_ATOMIC_OP, + D3D11_SB_MEM_OP, + D3D11_SB_DOUBLE_OP, + D3D11_SB_FLOAT_TO_DOUBLE_OP, + D3D11_SB_DOUBLE_TO_FLOAT_OP, + D3D11_SB_DEBUG_OP, +}; + +using OpClassMap = llvm::DenseMap; + +static void initOpClassMap(OpClassMap &opcodes) { +#define SET(OpCode, Name, NumOperands, PrecMask, OpClass) \ + opcodes[Name] = OpClass; +#include "InstrInfo.def" +#undef SET +} + +static void printComponent(raw_ostream &outs, uint32_t v) { + switch (v) { + case D3D10_SB_4_COMPONENT_X: + outs << 'x'; + break; + case D3D10_SB_4_COMPONENT_Y: + outs << 'y'; + break; + case D3D10_SB_4_COMPONENT_Z: + outs << 'z'; + break; + case D3D10_SB_4_COMPONENT_W: + outs << 'w'; + break; + } +} + +static void printComponentMask(raw_ostream &outs, uint32_t mask) { + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_X) { + outs << 'x'; + } + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Y) { + outs << 'y'; + } + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Z) { + outs << 'z'; + } + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_W) { + outs << 'w'; + } +} + +class Printer { +public: + Printer(raw_ostream &output) : outs(output) { initOpClassMap(opClass); } + + LogicalResult emitModule(ModuleOp source) { + Region ®ion = source.getRegion(); + if (!region.hasOneBlock()) { + return emitError(region.getLoc(), "region should contain only one block"); + } + + for (auto &op : region.front()) { + if (auto inst = dyn_cast(op)) { + if (failed(emitInstruction(inst))) { + return failure(); + } + } + } + return success(); + } + + // Emit an instruction and all its operands recursively. + // FIXME: add extended instructions + LogicalResult emitInstruction(dxsa::Instruction inst) { + outs << inst.getMnemonic(); + + if (inst.getOperands().empty()) { + outs << '\n'; + return success(); + } + + StringRef separator = "\t"; + for (Value value : inst.getOperands()) { + Operation *op = value.getDefiningOp(); + assert(op && "undefined operand"); + + outs << separator; + separator = ", "; + + auto result = + llvm::TypeSwitch(*op) + .Case([this](auto &op) { return emitOperand(op); }) + .Case([this, &inst](auto &op) { + return emitOperandImm(op, opClass[inst.getMnemonic()]); + }) + .Default([](auto &op) { + return emitError(op.getLoc(), "unexpected operand kind"); + }); + + if (failed(result)) { + return result; + } + } + outs << '\n'; + return success(); + } + + // Emit an operand and all its indices recursively. + LogicalResult emitOperand(dxsa::Operand op) { + // First index of register-like operands is printed without subscript + // syntax: `r0`, `o0`, cb0 + // + // Some operations must use subscript syntax for the first index: + // `icb[0]` + // + bool firstIndexIsSubscript = false; + + switch (op.getType()) { + case D3D10_SB_OPERAND_TYPE_TEMP: + outs << 'r'; + break; + case D3D10_SB_OPERAND_TYPE_INPUT: { + outs << 'v'; + break; + } + case D3D10_SB_OPERAND_TYPE_OUTPUT: { + outs << 'o'; + break; + } + case D3D10_SB_OPERAND_TYPE_INDEXABLE_TEMP: { + outs << 'x'; + break; + } + case D3D10_SB_OPERAND_TYPE_SAMPLER: { + outs << 's'; + break; + } + case D3D10_SB_OPERAND_TYPE_RESOURCE: { + outs << 't'; + break; + } + case D3D10_SB_OPERAND_TYPE_CONSTANT_BUFFER: { + outs << "cb"; + break; + } + case D3D10_SB_OPERAND_TYPE_IMMEDIATE_CONSTANT_BUFFER: { + outs << "icb"; + firstIndexIsSubscript = true; + break; + } + case D3D10_SB_OPERAND_TYPE_LABEL: + case D3D10_SB_OPERAND_TYPE_INPUT_PRIMITIVEID: + case D3D10_SB_OPERAND_TYPE_OUTPUT_DEPTH: + return emitError(op->getLoc(), "unsupported operand type"); + case D3D10_SB_OPERAND_TYPE_NULL: { + outs << "NULL"; + break; + } + case D3D10_SB_OPERAND_TYPE_RASTERIZER: + case D3D10_SB_OPERAND_TYPE_OUTPUT_COVERAGE_MASK: + case D3D11_SB_OPERAND_TYPE_STREAM: + case D3D11_SB_OPERAND_TYPE_FUNCTION_BODY: + case D3D11_SB_OPERAND_TYPE_FUNCTION_TABLE: + case D3D11_SB_OPERAND_TYPE_INTERFACE: + case D3D11_SB_OPERAND_TYPE_FUNCTION_INPUT: + case D3D11_SB_OPERAND_TYPE_FUNCTION_OUTPUT: + return emitError(op->getLoc(), "unsupported operand type"); + case D3D11_SB_OPERAND_TYPE_OUTPUT_CONTROL_POINT_ID: { + outs << "vOutputControlPointID"; + break; + } + case D3D11_SB_OPERAND_TYPE_INPUT_FORK_INSTANCE_ID: + case D3D11_SB_OPERAND_TYPE_INPUT_JOIN_INSTANCE_ID: + case D3D11_SB_OPERAND_TYPE_INPUT_CONTROL_POINT: + case D3D11_SB_OPERAND_TYPE_OUTPUT_CONTROL_POINT: + case D3D11_SB_OPERAND_TYPE_INPUT_PATCH_CONSTANT: + case D3D11_SB_OPERAND_TYPE_INPUT_DOMAIN_POINT: + case D3D11_SB_OPERAND_TYPE_THIS_POINTER: + case D3D11_SB_OPERAND_TYPE_UNORDERED_ACCESS_VIEW: + case D3D11_SB_OPERAND_TYPE_THREAD_GROUP_SHARED_MEMORY: + case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_ID: + case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_GROUP_ID: + case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_ID_IN_GROUP: + case D3D11_SB_OPERAND_TYPE_INPUT_COVERAGE_MASK: + case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_ID_IN_GROUP_FLATTENED: + case D3D11_SB_OPERAND_TYPE_INPUT_GS_INSTANCE_ID: + case D3D11_SB_OPERAND_TYPE_OUTPUT_DEPTH_GREATER_EQUAL: + case D3D11_SB_OPERAND_TYPE_OUTPUT_DEPTH_LESS_EQUAL: + case D3D11_SB_OPERAND_TYPE_CYCLE_COUNTER: + case D3D11_SB_OPERAND_TYPE_OUTPUT_STENCIL_REF: + case D3D11_SB_OPERAND_TYPE_INNER_COVERAGE: + return emitError(op->getLoc(), "unsupported operand type"); + } + + bool printSubscript = firstIndexIsSubscript; + for (Value value : op.getOperands()) { + Operation *index = value.getDefiningOp(); + assert(index && "undefined index"); + + // Non-immediate indices always use subscript syntax. + if (!isa(*index)) { + printSubscript = true; + } + + if (printSubscript) { + outs << '['; + } + if (failed(emitIndex(index))) { + return failure(); + } + if (printSubscript) { + outs << ']'; + } + + // First index may be a register number (immediate), but other + // indices are always subscripts. + printSubscript = true; + } + + if (auto swizzle = op.getSwizzle()) { + outs << '.'; + for (const APInt &v : *swizzle) { + printComponent(outs, v.getZExtValue()); + } + } else if (auto mask = op.getMask()) { + outs << '.'; + printComponentMask(outs, *mask); + } else if (auto one = op.getOne()) { + outs << '.'; + printComponent(outs, *one); + } + + return success(); + } + + // Emit an immediate operand. Its format (integer or floating point) depends + // on the instruction it is used for. + LogicalResult emitOperandImm(dxsa::OperandImm op, OpcodeClass kind) { + auto attr = cast(op.getImm()); + auto elementType = cast(attr.getType().getElementType()); + + if (elementType.getWidth() != 32) { + return emitError(op.getLoc(), "unsupported immediate operand type"); + } + + // FIXME: encode OperandImm with the correct type in MLIR + bool isInt = false; + switch (kind) { + case D3D10_SB_INT_OP: + case D3D10_SB_UINT_OP: + case D3D10_SB_BIT_OP: + isInt = true; + break; + default: + break; + }; + + bool printVec = attr.getNumElements() > 1 || !isInt; + + if (printVec) { + outs << "l("; + } + + StringRef separator = ""; + for (APInt v : attr) { + outs << separator; + separator = ","; + + uint32_t bits = v.getZExtValue(); + if (isInt) { + outs << bits; + } else { + write_double(outs, llvm::bit_cast(bits), FloatStyle::Fixed, 6); + } + } + if (printVec) { + outs << ")"; + } + + return success(); + } + + LogicalResult emitIndex(Operation *op) { + return llvm::TypeSwitch(*op) + .Case( + [this](auto index) { return emitIndexImm(index); }) + .Case( + [this](auto index) -> LogicalResult { return emitIndexRel(index); }) + .Case([this](auto index) -> LogicalResult { + return emitIndexRelImm(index); + }) + .Default([this](auto &op) { + return emitError(op.getLoc(), "invalid index kind,"); + }); + } + + // Emit an immediate index. + LogicalResult emitIndexImm(dxsa::IndexImm op) { + auto imm = cast(op.getImm()).getInt(); + outs << imm; + return success(); + } + + // Emit an operand used as an index. + LogicalResult emitIndexRel(dxsa::IndexRel index) { + auto operand = cast(index.getOperand().getDefiningOp()); + + // Recursively emit an operand, which may also have other indices. + return emitOperand(operand); + } + + // Emit an index as an operand + a 32 bit immediate offset. + LogicalResult emitIndexRelImm(dxsa::IndexRelImm index) { + auto operand = cast(index.getOperand().getDefiningOp()); + + if (failed(emitOperand(operand))) { + return failure(); + } + + outs << " + " << index.getImm(); + return success(); + } + +private: + raw_ostream &outs; + OpClassMap opClass; +}; + +namespace mlir::dxsa { +LogicalResult exportModuleToDxsa(ModuleOp source, raw_ostream &output) { + Printer printer(output); + return printer.emitModule(source); +} +} // namespace mlir::dxsa diff --git a/mlir/lib/Target/DXSA/CMakeLists.txt b/mlir/lib/Target/DXSA/CMakeLists.txt index 06d638fde18d..b0800c1c211d 100644 --- a/mlir/lib/Target/DXSA/CMakeLists.txt +++ b/mlir/lib/Target/DXSA/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_translation_library(MLIRTargetDXSA + AsmPrinter.cpp BinaryParser.cpp BinaryWriter.cpp TranslateRegistration.cpp diff --git a/mlir/lib/Target/DXSA/TranslateRegistration.cpp b/mlir/lib/Target/DXSA/TranslateRegistration.cpp index 5dcf9307668a..e58bddefe244 100644 --- a/mlir/lib/Target/DXSA/TranslateRegistration.cpp +++ b/mlir/lib/Target/DXSA/TranslateRegistration.cpp @@ -32,4 +32,13 @@ void registerToDxsaBinTranslation() { }, [](DialectRegistry ®istry) { registry.insert(); }}; } + +void registerToDxsaTranslation() { + TranslateFromMLIRRegistration registration{ + "export-dxsa", "Translate MLIR to DXSA", + [](ModuleOp source, raw_ostream &output) { + return dxsa::exportModuleToDxsa(source, output); + }, + [](DialectRegistry ®istry) { registry.insert(); }}; +} } // namespace mlir diff --git a/mlir/test/Target/DXSA/mov-index.mlir b/mlir/test/Target/DXSA/mov-index.mlir index eb2323ebdb53..6cc17957f49a 100644 --- a/mlir/test/Target/DXSA/mov-index.mlir +++ b/mlir/test/Target/DXSA/mov-index.mlir @@ -1,9 +1,11 @@ -// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov-index.bin | FileCheck %s -// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin +// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov-index.bin -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir +// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/mov-index.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM -// mov o0.xyzw, v[r0.x][0].xyzw +// ASM: mov o0.xyzw, v[r0.x][0].xyzw // CHECK: module { // CHECK-NEXT: %0 = dxsa.index.imm {imm = 0 : i32} @@ -15,14 +17,3 @@ // CHECK-NEXT: %6 = dxsa.operand %4, %5 {num_components = 4 : i32, swizzle = dense<[0, 1, 2, 3]> : vector<4xi32>, type = 1 : i32} // CHECK-NEXT: dxsa.instruction "mov" %1, %6 // CHECK-NEXT: } - -module { - %0 = dxsa.index.imm {imm = 0 : i32} - %1 = dxsa.operand %0 {mask = 240 : i32, num_components = 4 : i32, type = 2 : i32} - %2 = dxsa.index.imm {imm = 0 : i32} - %3 = dxsa.operand %2 {num_components = 4 : i32, one = 0 : i32, type = 0 : i32} - %4 = dxsa.index.rel %3 - %5 = dxsa.index.imm {imm = 0 : i32} - %6 = dxsa.operand %4, %5 {num_components = 4 : i32, swizzle = dense<[0, 1, 2, 3]> : vector<4xi32>, type = 1 : i32} - dxsa.instruction "mov" %1, %6 -} diff --git a/mlir/test/Target/DXSA/mov.mlir b/mlir/test/Target/DXSA/mov.mlir index de2854b3e0a4..8e6f851df6d0 100644 --- a/mlir/test/Target/DXSA/mov.mlir +++ b/mlir/test/Target/DXSA/mov.mlir @@ -1,9 +1,11 @@ -// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov.bin | FileCheck %s -// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin +// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov.bin -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir +// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/mov.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM -// mov r0.x, l(3.000000) +// ASM: mov r0.x, l(3.000000) // CHECK: module { // CHECK-NEXT: %0 = dxsa.index.imm {imm = 0 : i32} @@ -11,10 +13,3 @@ // CHECK-NEXT: %2 = dxsa.operand.imm {imm = dense<1077936128> : vector<1xi32>} // CHECK-NEXT: dxsa.instruction "mov" %1, %2 // CHECK-NEXT: } - -module { - %0 = dxsa.index.imm {imm = 0 : i32} - %1 = dxsa.operand %0 {mask = 16 : i32, num_components = 4 : i32, type = 0 : i32} - %2 = dxsa.operand.imm {imm = dense<1077936128> : vector<1xi32>} - dxsa.instruction "mov" %1, %2 -} diff --git a/mlir/test/Target/DXSA/ret.mlir b/mlir/test/Target/DXSA/ret.mlir index b9b352bcab74..905a57492549 100644 --- a/mlir/test/Target/DXSA/ret.mlir +++ b/mlir/test/Target/DXSA/ret.mlir @@ -1,7 +1,11 @@ -// RUN: mlir-translate --import-dxsa-bin %S/inputs/ret.bin | FileCheck %s -// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin +// RUN: mlir-translate --import-dxsa-bin %S/inputs/ret.bin -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir +// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/ret.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM + +// ASM: ret // CHECK: module { // CHECK-NEXT: dxsa.instruction "ret" diff --git a/mlir/test/Target/DXSA/udiv.mlir b/mlir/test/Target/DXSA/udiv.mlir index 69f13aa71488..68af092f7308 100644 --- a/mlir/test/Target/DXSA/udiv.mlir +++ b/mlir/test/Target/DXSA/udiv.mlir @@ -1,9 +1,11 @@ -// RUN: mlir-translate --import-dxsa-bin %S/inputs/udiv.bin | FileCheck %s -// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin +// RUN: mlir-translate --import-dxsa-bin %S/inputs/udiv.bin -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir +// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/udiv.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM -// udiv NULL, r0.x, vOutputControlPointID, 4 +// ASM: udiv NULL, r0.x, vOutputControlPointID, 4 // CHECK: module { // CHECK-NEXT: %0 = dxsa.operand {num_components = 0 : i32, type = 13 : i32} @@ -13,12 +15,3 @@ // CHECK-NEXT: %4 = dxsa.operand.imm {imm = dense<4> : vector<1xi32>} // CHECK-NEXT: dxsa.instruction "udiv" %0, %2, %3, %4 // CHECK-NEXT: } - -module { - %0 = dxsa.operand {num_components = 0 : i32, type = 13 : i32} - %1 = dxsa.index.imm {imm = 0 : i32} - %2 = dxsa.operand %1 {mask = 16 : i32, num_components = 4 : i32, type = 0 : i32} - %3 = dxsa.operand {num_components = 1 : i32, type = 22 : i32} - %4 = dxsa.operand.imm {imm = dense<4> : vector<1xi32>} - dxsa.instruction "udiv" %0, %2, %3, %4 -} From ae17f26a44217c8076f9e77be3c319873e99a4f7 Mon Sep 17 00:00:00 2001 From: Andrew Savonichev Date: Sat, 9 May 2026 22:56:08 +0900 Subject: [PATCH 2/6] Address code review comments --- mlir/lib/Target/DXSA/AsmPrinter.cpp | 57 +++++++++++------------------ 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/mlir/lib/Target/DXSA/AsmPrinter.cpp b/mlir/lib/Target/DXSA/AsmPrinter.cpp index 94cd45ae7ed7..79d41b11a9d3 100644 --- a/mlir/lib/Target/DXSA/AsmPrinter.cpp +++ b/mlir/lib/Target/DXSA/AsmPrinter.cpp @@ -63,18 +63,14 @@ static void printComponent(raw_ostream &outs, uint32_t v) { } static void printComponentMask(raw_ostream &outs, uint32_t mask) { - if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_X) { + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_X) outs << 'x'; - } - if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Y) { + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Y) outs << 'y'; - } - if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Z) { + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Z) outs << 'z'; - } - if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_W) { + if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_W) outs << 'w'; - } } class Printer { @@ -83,15 +79,13 @@ class Printer { LogicalResult emitModule(ModuleOp source) { Region ®ion = source.getRegion(); - if (!region.hasOneBlock()) { + if (!region.hasOneBlock()) return emitError(region.getLoc(), "region should contain only one block"); - } for (auto &op : region.front()) { if (auto inst = dyn_cast(op)) { - if (failed(emitInstruction(inst))) { + if (failed(emitInstruction(inst))) return failure(); - } } } return success(); @@ -125,9 +119,8 @@ class Printer { return emitError(op.getLoc(), "unexpected operand kind"); }); - if (failed(result)) { + if (failed(result)) return result; - } } outs << '\n'; return success(); @@ -226,19 +219,17 @@ class Printer { assert(index && "undefined index"); // Non-immediate indices always use subscript syntax. - if (!isa(*index)) { + if (!isa(*index)) printSubscript = true; - } - if (printSubscript) { + if (printSubscript) outs << '['; - } - if (failed(emitIndex(index))) { + + if (failed(emitIndex(index))) return failure(); - } - if (printSubscript) { + + if (printSubscript) outs << ']'; - } // First index may be a register number (immediate), but other // indices are always subscripts. @@ -247,9 +238,8 @@ class Printer { if (auto swizzle = op.getSwizzle()) { outs << '.'; - for (const APInt &v : *swizzle) { + for (const APInt &v : *swizzle) printComponent(outs, v.getZExtValue()); - } } else if (auto mask = op.getMask()) { outs << '.'; printComponentMask(outs, *mask); @@ -267,9 +257,8 @@ class Printer { auto attr = cast(op.getImm()); auto elementType = cast(attr.getType().getElementType()); - if (elementType.getWidth() != 32) { + if (elementType.getWidth() != 32) return emitError(op.getLoc(), "unsupported immediate operand type"); - } // FIXME: encode OperandImm with the correct type in MLIR bool isInt = false; @@ -285,25 +274,22 @@ class Printer { bool printVec = attr.getNumElements() > 1 || !isInt; - if (printVec) { + if (printVec) outs << "l("; - } StringRef separator = ""; - for (APInt v : attr) { + for (const APInt &v : attr) { outs << separator; separator = ","; uint32_t bits = v.getZExtValue(); - if (isInt) { + if (isInt) outs << bits; - } else { + else write_double(outs, llvm::bit_cast(bits), FloatStyle::Fixed, 6); - } } - if (printVec) { + if (printVec) outs << ")"; - } return success(); } @@ -341,9 +327,8 @@ class Printer { LogicalResult emitIndexRelImm(dxsa::IndexRelImm index) { auto operand = cast(index.getOperand().getDefiningOp()); - if (failed(emitOperand(operand))) { + if (failed(emitOperand(operand))) return failure(); - } outs << " + " << index.getImm(); return success(); From e75dd076c739e2ebffcc1721bc9f4cc77f2d67f0 Mon Sep 17 00:00:00 2001 From: Andrew Savonichev Date: Tue, 12 May 2026 22:28:49 +0900 Subject: [PATCH 3/6] Add a FIXME for 64-bit immediate operands --- mlir/lib/Target/DXSA/AsmPrinter.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Target/DXSA/AsmPrinter.cpp b/mlir/lib/Target/DXSA/AsmPrinter.cpp index 79d41b11a9d3..2666dfae3b8e 100644 --- a/mlir/lib/Target/DXSA/AsmPrinter.cpp +++ b/mlir/lib/Target/DXSA/AsmPrinter.cpp @@ -257,6 +257,7 @@ class Printer { auto attr = cast(op.getImm()); auto elementType = cast(attr.getType().getElementType()); + // FIXME: how 64-bit immediates should be printed? if (elementType.getWidth() != 32) return emitError(op.getLoc(), "unsupported immediate operand type"); From d80b58bac48bf2d42089ff58846c7140682efd68 Mon Sep 17 00:00:00 2001 From: Andrew Savonichev Date: Wed, 13 May 2026 22:30:23 +0900 Subject: [PATCH 4/6] Use ModuleOp::getBody --- mlir/lib/Target/DXSA/AsmPrinter.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mlir/lib/Target/DXSA/AsmPrinter.cpp b/mlir/lib/Target/DXSA/AsmPrinter.cpp index 2666dfae3b8e..086a632a963d 100644 --- a/mlir/lib/Target/DXSA/AsmPrinter.cpp +++ b/mlir/lib/Target/DXSA/AsmPrinter.cpp @@ -78,11 +78,7 @@ class Printer { Printer(raw_ostream &output) : outs(output) { initOpClassMap(opClass); } LogicalResult emitModule(ModuleOp source) { - Region ®ion = source.getRegion(); - if (!region.hasOneBlock()) - return emitError(region.getLoc(), "region should contain only one block"); - - for (auto &op : region.front()) { + for (auto &op : *source.getBody()) { if (auto inst = dyn_cast(op)) { if (failed(emitInstruction(inst))) return failure(); From 3d064af8621c712028df6425c32f093d7976907c Mon Sep 17 00:00:00 2001 From: Andrew Savonichev Date: Wed, 13 May 2026 22:32:17 +0900 Subject: [PATCH 5/6] Fix typo --- mlir/lib/Target/DXSA/AsmPrinter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Target/DXSA/AsmPrinter.cpp b/mlir/lib/Target/DXSA/AsmPrinter.cpp index 086a632a963d..80863c38cbaf 100644 --- a/mlir/lib/Target/DXSA/AsmPrinter.cpp +++ b/mlir/lib/Target/DXSA/AsmPrinter.cpp @@ -301,7 +301,7 @@ class Printer { return emitIndexRelImm(index); }) .Default([this](auto &op) { - return emitError(op.getLoc(), "invalid index kind,"); + return emitError(op.getLoc(), "invalid index kind"); }); } From 1c09090a49a119b28ccc4ac3b41954457293bbff Mon Sep 17 00:00:00 2001 From: Andrew Savonichev Date: Fri, 15 May 2026 20:41:02 +0900 Subject: [PATCH 6/6] Remove explicit return types, use interleaveComma --- mlir/lib/Target/DXSA/AsmPrinter.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Target/DXSA/AsmPrinter.cpp b/mlir/lib/Target/DXSA/AsmPrinter.cpp index 80863c38cbaf..1ea598b20393 100644 --- a/mlir/lib/Target/DXSA/AsmPrinter.cpp +++ b/mlir/lib/Target/DXSA/AsmPrinter.cpp @@ -4,6 +4,7 @@ #include "mlir/Target/DXSA/BinaryParser.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -274,17 +275,14 @@ class Printer { if (printVec) outs << "l("; - StringRef separator = ""; - for (const APInt &v : attr) { - outs << separator; - separator = ","; - + interleaveComma(attr, outs, [isInt, this](const APInt &v) { uint32_t bits = v.getZExtValue(); if (isInt) outs << bits; else write_double(outs, llvm::bit_cast(bits), FloatStyle::Fixed, 6); - } + }); + if (printVec) outs << ")"; @@ -296,10 +294,9 @@ class Printer { .Case( [this](auto index) { return emitIndexImm(index); }) .Case( - [this](auto index) -> LogicalResult { return emitIndexRel(index); }) - .Case([this](auto index) -> LogicalResult { - return emitIndexRelImm(index); - }) + [this](auto index) { return emitIndexRel(index); }) + .Case( + [this](auto index) { return emitIndexRelImm(index); }) .Default([this](auto &op) { return emitError(op.getLoc(), "invalid index kind"); });