diff --git a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
index 5039217ee07b..f4659c6e4b08 100644
--- a/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
+++ b/mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
@@ -158,4 +158,24 @@ def DXSA_DclTemps : DXSA_Op<"dcl_temps"> {
   let assemblyFormat = [{ $count attr-dict }];
 }
 
+def DXSA_DclThreadGroup : DXSA_Op<"dcl_thread_group"> {
+  let summary = "declare thread group size";
+  let description = [{
+    The `dxsa.dcl_thread_group` operation declares the number of threads in a
+    thread group along the x, y, and z dimensions.
+
+    The x and y values must be in [1, 1024], z must be in [1, 64], and the
+    product x * y * z must not exceed 1024.
+
+    Example:
+
+    ```mlir
+    dxsa.dcl_thread_group [8, 8, 1]
+    ```
+  }];
+  let arguments = (ins ConfinedAttr<I64ArrayAttr, [ArrayCount<3>]>:$dims);
+  let assemblyFormat = [{ $dims attr-dict }];
+  let hasVerifier = 1;
+}
+
 #endif // DXSA_OPS
diff --git a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp
index 0281c642269a..cd66629e046b 100644
--- a/mlir/lib/Dialect/DXSA/IR/DXSA.cpp
+++ b/mlir/lib/Dialect/DXSA/IR/DXSA.cpp
@@ -44,6 +44,39 @@ LogicalResult DclGlobalFlags::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// DclThreadGroup
+//===----------------------------------------------------------------------===//
+
+LogicalResult DclThreadGroup::verify() {
+  auto verifyDim = [&](StringRef name, int64_t value,
+                       int64_t maxValue) -> LogicalResult {
+    if (value < 1 || value > maxValue)
+      return emitOpError("thread group ")
+             << name << " dimension must be in [1, " << maxValue << "], got "
+             << value;
+    return success();
+  };
+
+  ArrayAttr dims = getDims();
+
+  auto x = cast<IntegerAttr>(dims[0]).getInt();
+  auto y = cast<IntegerAttr>(dims[1]).getInt();
+  auto z = cast<IntegerAttr>(dims[2]).getInt();
+  // clang-format off
+  if (failed(verifyDim("x", x, 1024)) ||
+      failed(verifyDim("y", y, 1024)) ||
+      failed(verifyDim("z", z, 64)))
+    return failure();
+  // clang-format on
+
+  constexpr int64_t maxTotalThreads = 1024;
+  if (auto total = x * y * z; total > maxTotalThreads)
+    return emitOpError("thread group size x*y*z must be <= ")
+           << maxTotalThreads << ", got " << total;
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/DXSA/BinaryParser.cpp b/mlir/lib/Target/DXSA/BinaryParser.cpp
index 2ac6629e6d75..52f673cfb5bc 100644
--- a/mlir/lib/Target/DXSA/BinaryParser.cpp
+++ b/mlir/lib/Target/DXSA/BinaryParser.cpp
@@ -516,6 +516,12 @@ class DXBuilder {
                                   builder.getI32IntegerAttr(count));
   }
 
+  Instruction buildDclThreadGroup(uint32_t x, uint32_t y, uint32_t z,
+                                  Location loc) {
+    return dxsa::DclThreadGroup::create(builder, loc,
+                                        builder.getI64ArrayAttr({x, y, z}));
+  }
+
 private:
   MLIRContext *context;
   ModuleOp module;
@@ -831,6 +837,35 @@ class Parser {
     return builder.buildDclTemps(count, loc);
   }
 
+  FailureOr<Instruction> parseDclThreadGroup(Location loc) {
+    auto parseDimension = [&](StringRef dimensionName, uint32_t maxValue,
+                              uint32_t &value) -> ParseResult {
+      auto token = parseToken();
+      if (failed(token))
+        return failure();
+      value = *token;
+      if (value == 0 || value > maxValue)
+        return emitError(getLocation(), "thread group ")
+               << dimensionName << " dimension must be in [1, " << maxValue
+               << "], got " << value;
+      return success();
+    };
+
+    uint32_t x, y, z;
+    // clang-format off
+    if (parseDimension("x", 1024, x) ||
+        parseDimension("y", 1024, y) ||
+        parseDimension("z", 64, z))
+      return failure();
+    // clang-format on
+
+    if (auto total = x * y * z; total > 1024)
+      return emitError(getLocation(),
+                       "thread group size x*y*z must be <= 1024, got ")
+             << total;
+    return builder.buildDclThreadGroup(x, y, z, loc);
+  }
+
   OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
                                           Instruction &out) {
     FailureOr<Instruction> result;
@@ -841,6 +876,9 @@ class Parser {
     case D3D10_SB_OPCODE_DCL_TEMPS:
       result = parseDclTemps(loc);
       break;
+    case D3D11_SB_OPCODE_DCL_THREAD_GROUP:
+      result = parseDclThreadGroup(loc);
+      break;
     default:
       return std::nullopt;
     }
diff --git a/mlir/test/Target/DXSA/dcl_thread_group.mlir b/mlir/test/Target/DXSA/dcl_thread_group.mlir
new file mode 100644
index 000000000000..b42433fb06bd
--- /dev/null
+++ b/mlir/test/Target/DXSA/dcl_thread_group.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_thread_group.bin | FileCheck %s
+
+// CHECK: module {
+// CHECK-NEXT: dxsa.dcl_thread_group [1, 1, 1]
+// CHECK-NEXT: dxsa.dcl_thread_group [8, 8, 1]
+// CHECK-NEXT: dxsa.dcl_thread_group [32, 32, 1]
+// CHECK-NEXT: dxsa.dcl_thread_group [16, 1, 64]
+// CHECK-NEXT: dxsa.dcl_thread_group [1024, 1, 1]
+// CHECK-NEXT: dxsa.dcl_thread_group [1, 1024, 1]
+// CHECK-NEXT: }
diff --git a/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir b/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir
new file mode 100644
index 000000000000..55f8833c93c7
--- /dev/null
+++ b/mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir
@@ -0,0 +1,45 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// expected-error@+1 {{attribute 'dims' failed to satisfy constraint: 64-bit integer array attribute with exactly 3 elements}}
+dxsa.dcl_thread_group [1, 1]
+
+// -----
+
+// expected-error@+1 {{attribute 'dims' failed to satisfy constraint: 64-bit integer array attribute with exactly 3 elements}}
+dxsa.dcl_thread_group [1, 1, 1, 1]
+
+// -----
+
+// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group x dimension must be in [1, 1024], got 0}}
+dxsa.dcl_thread_group [0, 1, 1]
+
+// -----
+
+// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group x dimension must be in [1, 1024], got 1025}}
+dxsa.dcl_thread_group [1025, 1, 1]
+
+// -----
+
+// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group y dimension must be in [1, 1024], got 0}}
+dxsa.dcl_thread_group [1, 0, 1]
+
+// -----
+
+// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group y dimension must be in [1, 1024], got 1025}}
+dxsa.dcl_thread_group [1, 1025, 1]
+
+// -----
+
+// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group z dimension must be in [1, 64], got 0}}
+dxsa.dcl_thread_group [1, 1, 0]
+
+// -----
+
+// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group z dimension must be in [1, 64], got 65}}
+dxsa.dcl_thread_group [1, 1, 65]
+
+// -----
+
+// 64 * 8 * 4 == 2048
+// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group size x*y*z must be <= 1024, got 2048}}
+dxsa.dcl_thread_group [64, 8, 4]
diff --git a/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin b/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin
new file mode 100644
index 000000000000..47801af335e0
Binary files /dev/null and b/mlir/test/Target/DXSA/inputs/dcl_thread_group.bin differ