diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h index 956bf5db058b..2f4efc4be51f 100644 --- a/mlir/include/mlir/InitAllTranslations.h +++ b/mlir/include/mlir/InitAllTranslations.h @@ -24,6 +24,7 @@ void registerFromWasmTranslation(); void registerFromDxsaBinTranslation(); void registerToCppTranslation(); void registerToDxsaBinTranslation(); +void registerToDxsaTranslation(); void registerToLLVMIRTranslation(); void registerToSPIRVTranslation(); @@ -43,6 +44,7 @@ inline void registerAllTranslations() { registerFromDxsaBinTranslation(); registerToCppTranslation(); registerToDxsaBinTranslation(); + registerToDxsaTranslation(); registerToLLVMIRTranslation(); registerToSPIRVTranslation(); smt::registerExportSMTLIBTranslation(); diff --git a/mlir/include/mlir/Target/DXSA/BinaryParser.h b/mlir/include/mlir/Target/DXSA/BinaryParser.h index 3b81b8772a06..314c60ac381f 100644 --- a/mlir/include/mlir/Target/DXSA/BinaryParser.h +++ b/mlir/include/mlir/Target/DXSA/BinaryParser.h @@ -21,6 +21,10 @@ OwningOpRef importDxsaBinaryToModule(llvm::SourceMgr &source, MLIRContext *context); /// Encode \p source to DXSA binary. LogicalResult exportModuleToDxsaBinary(ModuleOp source, raw_ostream &output); + +/// Print \p source to DXSA text assembly. 
+LogicalResult exportModuleToDxsa(ModuleOp source, raw_ostream &output);
+
 } // namespace mlir::dxsa
 
 #endif // MLIR_TARGET_DXSA_BINARYPARSER_H
diff --git a/mlir/lib/Target/DXSA/AsmPrinter.cpp b/mlir/lib/Target/DXSA/AsmPrinter.cpp
new file mode 100644
index 000000000000..1ea598b20393
--- /dev/null
+++ b/mlir/lib/Target/DXSA/AsmPrinter.cpp
@@ -0,0 +1,341 @@
+// Prints the instructions of a DXSA MLIR module as DXSA text assembly.
+
+#include "mlir/Dialect/DXSA/IR/DXSA.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Target/DXSA/BinaryParser.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/EndianStream.h"
+#include "llvm/Support/LogicalResult.h"
+#include "llvm/Support/NativeFormatting.h"
+
+#include "d3d12TokenizedProgramFormat.hpp"
+
+// Matches the name of the translation this file implements.
+#define DEBUG_TYPE "export-dxsa"
+
+using namespace mlir;
+using namespace llvm;
+
+// FIXME: remove OpcodeClass from BinaryParser
+enum OpcodeClass {
+  D3D10_SB_FLOAT_OP,
+  D3D10_SB_INT_OP,
+  D3D10_SB_UINT_OP,
+  D3D10_SB_BIT_OP,
+  D3D10_SB_FLOW_OP,
+  D3D10_SB_TEX_OP,
+  D3D10_SB_DCL_OP,
+  D3D11_SB_ATOMIC_OP,
+  D3D11_SB_MEM_OP,
+  D3D11_SB_DOUBLE_OP,
+  D3D11_SB_FLOAT_TO_DOUBLE_OP,
+  D3D11_SB_DOUBLE_TO_FLOAT_OP,
+  D3D11_SB_DEBUG_OP,
+};
+
+// Maps the `Name` field of an InstrInfo.def entry to its opcode class.
+using OpClassMap = llvm::DenseMap<StringRef, OpcodeClass>;
+
+// Fill `opcodes` with one entry per SET(...) record in InstrInfo.def.
+static void initOpClassMap(OpClassMap &opcodes) {
+#define SET(OpCode, Name, NumOperands, PrecMask, OpClass)                     \
+  opcodes[Name] = OpClass;
+#include "InstrInfo.def"
+#undef SET
+}
+
+// Print the letter (x, y, z or w) of the component selector `v`. Any other
+// value prints nothing.
+static void printComponent(raw_ostream &outs, uint32_t v) {
+  switch (v) {
+  case D3D10_SB_4_COMPONENT_X:
+    outs << 'x';
+    break;
+  case D3D10_SB_4_COMPONENT_Y:
+    outs << 'y';
+    break;
+  case D3D10_SB_4_COMPONENT_Z:
+    outs << 'z';
+    break;
+  case D3D10_SB_4_COMPONENT_W:
+    outs << 'w';
+    break;
+  }
+}
+
+// Print the letter of every component whose bit is set in `mask`, always in
+// x, y, z, w order.
+static void printComponentMask(raw_ostream &outs, uint32_t mask) {
+  if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_X)
+    outs << 'x';
+  if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Y)
+    outs << 'y';
+  if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_Z)
+    outs << 'z';
+  if (mask & D3D10_SB_OPERAND_4_COMPONENT_MASK_W)
+    outs << 'w';
+}
+
+// Writes DXSA text assembly for a module to a stream. Every emit* method
+// returns failure after reporting a diagnostic on the offending op.
+class Printer {
+public:
+  explicit Printer(raw_ostream &output) : outs(output) {
+    initOpClassMap(opClass);
+  }
+
+  // Emit every dxsa::Instruction found directly in the body of `source`.
+  // Other ops in the body are not visited here; operands and indices are
+  // printed through the instruction that uses them.
+  LogicalResult emitModule(ModuleOp source) {
+    for (Operation &op : *source.getBody()) {
+      auto inst = dyn_cast<dxsa::Instruction>(op);
+      if (inst && failed(emitInstruction(inst)))
+        return failure();
+    }
+    return success();
+  }
+
+  // Emit an instruction and all its operands recursively, as one line:
+  // the mnemonic, a tab, then the operands separated by ", ".
+  // FIXME: add extended instructions
+  LogicalResult emitInstruction(dxsa::Instruction inst) {
+    outs << inst.getMnemonic();
+
+    StringRef separator = "\t";
+    for (Value value : inst.getOperands()) {
+      Operation *op = value.getDefiningOp();
+      // A value without a defining op cannot be printed; report it instead
+      // of dereferencing a null pointer in builds without assertions.
+      if (!op)
+        return emitError(inst.getLoc(),
+                         "operand is not defined by an operation");
+
+      outs << separator;
+      separator = ", ";
+
+      LogicalResult result =
+          llvm::TypeSwitch<Operation &, LogicalResult>(*op)
+              .Case<dxsa::Operand>(
+                  [this](auto &operand) { return emitOperand(operand); })
+              .Case<dxsa::OperandImm>([this, &inst](auto &imm) {
+                // lookup() does not insert: a mnemonic missing from the
+                // table yields the zero-valued class without growing it.
+                return emitOperandImm(imm, opClass.lookup(inst.getMnemonic()));
+              })
+              .Default([](auto &other) {
+                return emitError(other.getLoc(), "unexpected operand kind");
+              });
+
+      if (failed(result))
+        return result;
+    }
+    outs << '\n';
+    return success();
+  }
+
+  // Emit an operand and all its indices recursively.
+  LogicalResult emitOperand(dxsa::Operand op) {
+    // First index of register-like operands is printed without subscript
+    // syntax: `r0`, `o0`, cb0
+    //
+    // Some operations must use subscript syntax for the first index:
+    // `icb[0]`
+    //
+    bool firstIndexIsSubscript = false;
+
+    // Print the operand name that precedes any index.
+    switch (op.getType()) {
+    case D3D10_SB_OPERAND_TYPE_TEMP:
+      outs << 'r';
+      break;
+    case D3D10_SB_OPERAND_TYPE_INPUT:
+      outs << 'v';
+      break;
+    case D3D10_SB_OPERAND_TYPE_OUTPUT:
+      outs << 'o';
+      break;
+    case D3D10_SB_OPERAND_TYPE_INDEXABLE_TEMP:
+      outs << 'x';
+      break;
+    case D3D10_SB_OPERAND_TYPE_SAMPLER:
+      outs << 's';
+      break;
+    case D3D10_SB_OPERAND_TYPE_RESOURCE:
+      outs << 't';
+      break;
+    case D3D10_SB_OPERAND_TYPE_CONSTANT_BUFFER:
+      outs << "cb";
+      break;
+    case D3D10_SB_OPERAND_TYPE_IMMEDIATE_CONSTANT_BUFFER:
+      outs << "icb";
+      firstIndexIsSubscript = true;
+      break;
+    case D3D10_SB_OPERAND_TYPE_NULL:
+      outs << "NULL";
+      break;
+    case D3D11_SB_OPERAND_TYPE_OUTPUT_CONTROL_POINT_ID:
+      outs << "vOutputControlPointID";
+      break;
+    // Operand types that are known but not printed yet.
+    case D3D10_SB_OPERAND_TYPE_LABEL:
+    case D3D10_SB_OPERAND_TYPE_INPUT_PRIMITIVEID:
+    case D3D10_SB_OPERAND_TYPE_OUTPUT_DEPTH:
+    case D3D10_SB_OPERAND_TYPE_RASTERIZER:
+    case D3D10_SB_OPERAND_TYPE_OUTPUT_COVERAGE_MASK:
+    case D3D11_SB_OPERAND_TYPE_STREAM:
+    case D3D11_SB_OPERAND_TYPE_FUNCTION_BODY:
+    case D3D11_SB_OPERAND_TYPE_FUNCTION_TABLE:
+    case D3D11_SB_OPERAND_TYPE_INTERFACE:
+    case D3D11_SB_OPERAND_TYPE_FUNCTION_INPUT:
+    case D3D11_SB_OPERAND_TYPE_FUNCTION_OUTPUT:
+    case D3D11_SB_OPERAND_TYPE_INPUT_FORK_INSTANCE_ID:
+    case D3D11_SB_OPERAND_TYPE_INPUT_JOIN_INSTANCE_ID:
+    case D3D11_SB_OPERAND_TYPE_INPUT_CONTROL_POINT:
+    case D3D11_SB_OPERAND_TYPE_OUTPUT_CONTROL_POINT:
+    case D3D11_SB_OPERAND_TYPE_INPUT_PATCH_CONSTANT:
+    case D3D11_SB_OPERAND_TYPE_INPUT_DOMAIN_POINT:
+    case D3D11_SB_OPERAND_TYPE_THIS_POINTER:
+    case D3D11_SB_OPERAND_TYPE_UNORDERED_ACCESS_VIEW:
+    case D3D11_SB_OPERAND_TYPE_THREAD_GROUP_SHARED_MEMORY:
+    case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_ID:
+    case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_GROUP_ID:
+    case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_ID_IN_GROUP:
+    case D3D11_SB_OPERAND_TYPE_INPUT_COVERAGE_MASK:
+    case D3D11_SB_OPERAND_TYPE_INPUT_THREAD_ID_IN_GROUP_FLATTENED:
+    case D3D11_SB_OPERAND_TYPE_INPUT_GS_INSTANCE_ID:
+    case D3D11_SB_OPERAND_TYPE_OUTPUT_DEPTH_GREATER_EQUAL:
+    case D3D11_SB_OPERAND_TYPE_OUTPUT_DEPTH_LESS_EQUAL:
+    case D3D11_SB_OPERAND_TYPE_CYCLE_COUNTER:
+    case D3D11_SB_OPERAND_TYPE_OUTPUT_STENCIL_REF:
+    case D3D11_SB_OPERAND_TYPE_INNER_COVERAGE:
+    default:
+      // `default` also rejects any value not listed above, so an operand is
+      // never printed without its name.
+      return emitError(op->getLoc(), "unsupported operand type");
+    }
+
+    bool printSubscript = firstIndexIsSubscript;
+    for (Value value : op.getOperands()) {
+      Operation *index = value.getDefiningOp();
+      // Report a missing defining op instead of dereferencing a null
+      // pointer in builds without assertions.
+      if (!index)
+        return emitError(op->getLoc(),
+                         "index is not defined by an operation");
+
+      // Non-immediate indices always use subscript syntax.
+      if (!isa<dxsa::IndexImm>(*index))
+        printSubscript = true;
+
+      if (printSubscript)
+        outs << '[';
+
+      if (failed(emitIndex(index)))
+        return failure();
+
+      if (printSubscript)
+        outs << ']';
+
+      // First index may be a register number (immediate), but other
+      // indices are always subscripts.
+      printSubscript = true;
+    }
+
+    // At most one component selection is printed; swizzle takes precedence
+    // over mask, and mask over a single selected component.
+    if (auto swizzle = op.getSwizzle()) {
+      outs << '.';
+      for (const APInt &v : *swizzle)
+        printComponent(outs, v.getZExtValue());
+    } else if (auto mask = op.getMask()) {
+      outs << '.';
+      printComponentMask(outs, *mask);
+    } else if (auto one = op.getOne()) {
+      outs << '.';
+      printComponent(outs, *one);
+    }
+
+    return success();
+  }
+
+  // Emit an immediate operand. Its format (integer or floating point) depends
+  // on the instruction it is used for: `kind` is the opcode class of that
+  // instruction. Only 32-bit elements are supported.
+  LogicalResult emitOperandImm(dxsa::OperandImm op, OpcodeClass kind) {
+    auto attr = cast<DenseIntElementsAttr>(op.getImm());
+    auto elementType = cast<IntegerType>(attr.getType().getElementType());
+
+    // FIXME: how 64-bit immediates should be printed?
+    if (elementType.getWidth() != 32)
+      return emitError(op.getLoc(), "unsupported immediate operand type");
+
+    // FIXME: encode OperandImm with the correct type in MLIR
+    const bool isInt = kind == D3D10_SB_INT_OP || kind == D3D10_SB_UINT_OP ||
+                       kind == D3D10_SB_BIT_OP;
+
+    // A single integer is printed bare; vectors and floating-point values
+    // are wrapped in `l(...)`.
+    const bool printVec = attr.getNumElements() > 1 || !isInt;
+
+    if (printVec)
+      outs << "l(";
+
+    interleaveComma(attr, outs, [isInt, this](const APInt &v) {
+      const auto bits = static_cast<uint32_t>(v.getZExtValue());
+      if (isInt)
+        outs << bits;
+      else
+        // Reinterpret the stored bits as a float, printed with six decimals.
+        write_double(outs, llvm::bit_cast<float>(bits), FloatStyle::Fixed, 6);
+    });
+
+    if (printVec)
+      outs << ")";
+
+    return success();
+  }
+
+  // Emit the index defined by `op`, dispatching on its kind. Any op that is
+  // not one of the three index ops is reported as an error.
+  LogicalResult emitIndex(Operation *op) {
+    return llvm::TypeSwitch<Operation &, LogicalResult>(*op)
+        .Case<dxsa::IndexImm>(
+            [this](auto index) { return emitIndexImm(index); })
+        .Case<dxsa::IndexRel>(
+            [this](auto index) { return emitIndexRel(index); })
+        .Case<dxsa::IndexRelImm>(
+            [this](auto index) { return emitIndexRelImm(index); })
+        .Default([](auto &other) {
+          return emitError(other.getLoc(), "invalid index kind");
+        });
+  }
+
+  // Emit an immediate index.
+  LogicalResult emitIndexImm(dxsa::IndexImm op) {
+    auto imm = cast<IntegerAttr>(op.getImm()).getInt();
+    outs << imm;
+    return success();
+  }
+
+  // Emit an operand used as an index.
+  LogicalResult emitIndexRel(dxsa::IndexRel index) {
+    // Fail with a diagnostic when the value is not produced by a
+    // dxsa::Operand, instead of asserting inside cast<>.
+    auto operand = index.getOperand().getDefiningOp<dxsa::Operand>();
+    if (!operand)
+      return emitError(index.getLoc(),
+                       "relative index is not defined by a dxsa.operand");
+
+    // Recursively emit an operand, which may also have other indices.
+    return emitOperand(operand);
+  }
+
+  // Emit an index as an operand + a 32 bit immediate offset.
+  LogicalResult emitIndexRelImm(dxsa::IndexRelImm index) {
+    // Fail with a diagnostic when the value is not produced by a
+    // dxsa::Operand, instead of asserting inside cast<>.
+    auto operand = index.getOperand().getDefiningOp<dxsa::Operand>();
+    if (!operand)
+      return emitError(index.getLoc(),
+                       "relative index is not defined by a dxsa.operand");
+
+    // Recursively emit the operand, which may also have other indices.
+    if (failed(emitOperand(operand)))
+      return failure();
+
+    // NOTE(review): emitIndexImm unwraps its immediate through an attribute
+    // cast before printing; confirm getImm() here already yields a plain
+    // integer rather than an attribute.
+    outs << " + " << index.getImm();
+    return success();
+  }
+
+private:
+  // Destination stream for the assembly text.
+  raw_ostream &outs;
+  // Opcode class per instruction name, filled once by the constructor.
+  OpClassMap opClass;
+};
+
+namespace mlir::dxsa {
+// Print every instruction of `source` to `output` as DXSA text assembly.
+// Returns failure if any instruction could not be printed.
+LogicalResult exportModuleToDxsa(ModuleOp source, raw_ostream &output) {
+  Printer printer(output);
+  return printer.emitModule(source);
+}
+} // namespace mlir::dxsa
diff --git a/mlir/lib/Target/DXSA/CMakeLists.txt b/mlir/lib/Target/DXSA/CMakeLists.txt
index 06d638fde18d..b0800c1c211d 100644
--- a/mlir/lib/Target/DXSA/CMakeLists.txt
+++ b/mlir/lib/Target/DXSA/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_translation_library(MLIRTargetDXSA
+  AsmPrinter.cpp
   BinaryParser.cpp
   BinaryWriter.cpp
   TranslateRegistration.cpp
diff --git a/mlir/lib/Target/DXSA/TranslateRegistration.cpp b/mlir/lib/Target/DXSA/TranslateRegistration.cpp
index 5dcf9307668a..e58bddefe244 100644
--- a/mlir/lib/Target/DXSA/TranslateRegistration.cpp
+++ b/mlir/lib/Target/DXSA/TranslateRegistration.cpp
@@ -32,4 +32,13 @@ void registerToDxsaBinTranslation() {
       },
       [](DialectRegistry &registry) { registry.insert(); }};
 }
+
+void registerToDxsaTranslation() {
+  TranslateFromMLIRRegistration registration{
+      "export-dxsa", "Translate MLIR to DXSA",
+      [](ModuleOp source, raw_ostream &output) {
+        return dxsa::exportModuleToDxsa(source, output);
+      },
+      [](DialectRegistry &registry) { registry.insert(); }};
+}
 } // namespace mlir
diff --git a/mlir/test/Target/DXSA/mov-index.mlir b/mlir/test/Target/DXSA/mov-index.mlir
index eb2323ebdb53..6cc17957f49a 100644
--- a/mlir/test/Target/DXSA/mov-index.mlir
+++ b/mlir/test/Target/DXSA/mov-index.mlir
@@ -1,9 +1,11 @@
-// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov-index.bin | FileCheck %s
-// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin
+// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov-index.bin -o %t.mlir
+// RUN: FileCheck %s --input-file %t.mlir
+// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/mov-index.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM -// mov o0.xyzw, v[r0.x][0].xyzw +// ASM: mov o0.xyzw, v[r0.x][0].xyzw // CHECK: module { // CHECK-NEXT: %0 = dxsa.index.imm {imm = 0 : i32} @@ -15,14 +17,3 @@ // CHECK-NEXT: %6 = dxsa.operand %4, %5 {num_components = 4 : i32, swizzle = dense<[0, 1, 2, 3]> : vector<4xi32>, type = 1 : i32} // CHECK-NEXT: dxsa.instruction "mov" %1, %6 // CHECK-NEXT: } - -module { - %0 = dxsa.index.imm {imm = 0 : i32} - %1 = dxsa.operand %0 {mask = 240 : i32, num_components = 4 : i32, type = 2 : i32} - %2 = dxsa.index.imm {imm = 0 : i32} - %3 = dxsa.operand %2 {num_components = 4 : i32, one = 0 : i32, type = 0 : i32} - %4 = dxsa.index.rel %3 - %5 = dxsa.index.imm {imm = 0 : i32} - %6 = dxsa.operand %4, %5 {num_components = 4 : i32, swizzle = dense<[0, 1, 2, 3]> : vector<4xi32>, type = 1 : i32} - dxsa.instruction "mov" %1, %6 -} diff --git a/mlir/test/Target/DXSA/mov.mlir b/mlir/test/Target/DXSA/mov.mlir index de2854b3e0a4..8e6f851df6d0 100644 --- a/mlir/test/Target/DXSA/mov.mlir +++ b/mlir/test/Target/DXSA/mov.mlir @@ -1,9 +1,11 @@ -// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov.bin | FileCheck %s -// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin +// RUN: mlir-translate --import-dxsa-bin %S/inputs/mov.bin -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir +// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/mov.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM -// mov r0.x, l(3.000000) +// ASM: mov r0.x, l(3.000000) // CHECK: module { // CHECK-NEXT: %0 = dxsa.index.imm {imm = 0 : i32} @@ -11,10 +13,3 @@ // CHECK-NEXT: %2 = dxsa.operand.imm {imm = dense<1077936128> : vector<1xi32>} // 
CHECK-NEXT: dxsa.instruction "mov" %1, %2 // CHECK-NEXT: } - -module { - %0 = dxsa.index.imm {imm = 0 : i32} - %1 = dxsa.operand %0 {mask = 16 : i32, num_components = 4 : i32, type = 0 : i32} - %2 = dxsa.operand.imm {imm = dense<1077936128> : vector<1xi32>} - dxsa.instruction "mov" %1, %2 -} diff --git a/mlir/test/Target/DXSA/ret.mlir b/mlir/test/Target/DXSA/ret.mlir index b9b352bcab74..905a57492549 100644 --- a/mlir/test/Target/DXSA/ret.mlir +++ b/mlir/test/Target/DXSA/ret.mlir @@ -1,7 +1,11 @@ -// RUN: mlir-translate --import-dxsa-bin %S/inputs/ret.bin | FileCheck %s -// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin +// RUN: mlir-translate --import-dxsa-bin %S/inputs/ret.bin -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir +// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/ret.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM + +// ASM: ret // CHECK: module { // CHECK-NEXT: dxsa.instruction "ret" diff --git a/mlir/test/Target/DXSA/udiv.mlir b/mlir/test/Target/DXSA/udiv.mlir index 69f13aa71488..68af092f7308 100644 --- a/mlir/test/Target/DXSA/udiv.mlir +++ b/mlir/test/Target/DXSA/udiv.mlir @@ -1,9 +1,11 @@ -// RUN: mlir-translate --import-dxsa-bin %S/inputs/udiv.bin | FileCheck %s -// RUN: mlir-translate --export-dxsa-bin %s -o %t.bin +// RUN: mlir-translate --import-dxsa-bin %S/inputs/udiv.bin -o %t.mlir +// RUN: FileCheck %s --input-file %t.mlir +// RUN: mlir-translate --export-dxsa-bin %t.mlir -o %t.bin // RUN: mlir-translate --import-dxsa-bin %t.bin | FileCheck %s // RUN: diff %t.bin %S/inputs/udiv.bin +// RUN: mlir-translate --export-dxsa %t.mlir -o - | FileCheck %s --check-prefix ASM -// udiv NULL, r0.x, vOutputControlPointID, 4 +// ASM: udiv NULL, r0.x, vOutputControlPointID, 4 // CHECK: module { // CHECK-NEXT: %0 = dxsa.operand {num_components = 0 : i32, type = 13 : i32} @@ -13,12 +15,3 @@ // CHECK-NEXT: %4 = 
dxsa.operand.imm {imm = dense<4> : vector<1xi32>} // CHECK-NEXT: dxsa.instruction "udiv" %0, %2, %3, %4 // CHECK-NEXT: } - -module { - %0 = dxsa.operand {num_components = 0 : i32, type = 13 : i32} - %1 = dxsa.index.imm {imm = 0 : i32} - %2 = dxsa.operand %1 {mask = 16 : i32, num_components = 4 : i32, type = 0 : i32} - %3 = dxsa.operand {num_components = 1 : i32, type = 22 : i32} - %4 = dxsa.operand.imm {imm = dense<4> : vector<1xi32>} - dxsa.instruction "udiv" %0, %2, %3, %4 -}