Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/DXSA/IR/DXSAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -158,4 +158,24 @@ def DXSA_DclTemps : DXSA_Op<"dcl_temps"> {
let assemblyFormat = [{ $count attr-dict }];
}

// Declaration op carrying the thread group size. The per-dimension and
// total-size limits listed in the description are not expressible through the
// attribute constraint below, which only fixes the element count.
def DXSA_DclThreadGroup : DXSA_Op<"dcl_thread_group"> {
let summary = "declare thread group size";
let description = [{
The `dxsa.dcl_thread_group` operation declares the number of threads in a
thread group along the x, y, and z dimensions.

The x and y values must be in [1, 1024], z must be in [1, 64], and the
product x * y * z must not exceed 1024.

Example:

```mlir
dxsa.dcl_thread_group [8, 8, 1]
```
}];
// `dims` holds the sizes in x, y, z order as a 64-bit integer array attribute
// constrained to exactly three elements.
let arguments = (ins ConfinedAttr<I64ArrayAttr, [ArrayCount<3>]>:$dims);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may be better to use `DenseI64ArrayAttr` here. You would then have direct access to the underlying array, without needing to cast the individual entries from `Attribute` in the code below.

let assemblyFormat = [{ $dims attr-dict }];
let hasVerifier = 1;
}

#endif // DXSA_OPS
33 changes: 33 additions & 0 deletions mlir/lib/Dialect/DXSA/IR/DXSA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,39 @@ LogicalResult DclGlobalFlags::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// DclThreadGroup
//===----------------------------------------------------------------------===//

/// Verifies the thread group dimensions stored in `dims`.
///
/// Each dimension is range-checked in x, y, z order (x and y against
/// [1, 1024], z against [1, 64]); the first violation is reported and stops
/// verification. When every dimension is in range, the product x * y * z is
/// additionally required not to exceed 1024.
LogicalResult DclThreadGroup::verify() {
  constexpr int64_t maxTotalThreads = 1024;

  // Unpack the three entries up front, in x, y, z order.
  ArrayAttr dimsAttr = getDims();
  const int64_t x = cast<IntegerAttr>(dimsAttr[0]).getInt();
  const int64_t y = cast<IntegerAttr>(dimsAttr[1]).getInt();
  const int64_t z = cast<IntegerAttr>(dimsAttr[2]).getInt();

  // Succeeds when `extent` lies in [1, limit]; otherwise emits an op error
  // naming the offending axis.
  auto checkRange = [&](StringRef axis, int64_t extent,
                        int64_t limit) -> LogicalResult {
    if (extent >= 1 && extent <= limit)
      return success();
    return emitOpError("thread group ")
           << axis << " dimension must be in [1, " << limit << "], got "
           << extent;
  };

  if (failed(checkRange("x", x, 1024)))
    return failure();
  if (failed(checkRange("y", y, 1024)))
    return failure();
  if (failed(checkRange("z", z, 64)))
    return failure();

  // The per-dimension checks above bound every factor, so the product stays
  // far below the int64_t range.
  const int64_t total = x * y * z;
  if (total <= maxTotalThreads)
    return success();
  return emitOpError("thread group size x*y*z must be <= ")
         << maxTotalThreads << ", got " << total;
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 38 additions & 0 deletions mlir/lib/Target/DXSA/BinaryParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,12 @@ class DXBuilder {
builder.getI32IntegerAttr(count));
}

/// Creates a `dxsa::DclThreadGroup` op at `loc` whose `dims` attribute is the
/// 64-bit integer array [x, y, z].
Instruction buildDclThreadGroup(uint32_t x, uint32_t y, uint32_t z,
                                Location loc) {
  // Widen the 32-bit sizes to the 64-bit element type of the attribute.
  const int64_t extents[] = {x, y, z};
  auto dimsAttr = builder.getI64ArrayAttr(extents);
  return dxsa::DclThreadGroup::create(builder, loc, dimsAttr);
}

private:
MLIRContext *context;
ModuleOp module;
Expand Down Expand Up @@ -831,6 +837,35 @@ class Parser {
return builder.buildDclTemps(count, loc);
}

/// Parses the three operand tokens of a thread group declaration (x, then y,
/// then z) and builds the corresponding op at `loc`.
///
/// Fails if a token cannot be read, if x or y is outside [1, 1024], if z is
/// outside [1, 64], or if the product x * y * z exceeds 1024. Parsing stops at
/// the first failure.
FailureOr<Instruction> parseDclThreadGroup(Location loc) {
  // Reads one token into `extent` and rejects it unless it lies in
  // [1, limit].
  auto readExtent = [&](StringRef axis, uint32_t limit,
                        uint32_t &extent) -> ParseResult {
    auto word = parseToken();
    if (failed(word))
      return failure();
    extent = *word;
    if (extent != 0 && extent <= limit)
      return success();
    return emitError(getLocation(), "thread group ")
           << axis << " dimension must be in [1, " << limit << "], got "
           << extent;
  };

  // Each variable is only read after its own readExtent call has succeeded.
  uint32_t x, y, z;
  if (readExtent("x", 1024, x))
    return failure();
  if (readExtent("y", 1024, y))
    return failure();
  if (readExtent("z", 64, z))
    return failure();

  // All factors are range-checked above, so the 32-bit product cannot wrap.
  const uint32_t total = x * y * z;
  if (total <= 1024)
    return builder.buildDclThreadGroup(x, y, z, loc);
  return emitError(getLocation(),
                   "thread group size x*y*z must be <= 1024, got ")
         << total;
}

OptionalParseResult parseDclInstruction(uint32_t opcodeToken, Location loc,
Instruction &out) {
FailureOr<Instruction> result;
Expand All @@ -841,6 +876,9 @@ class Parser {
case D3D10_SB_OPCODE_DCL_TEMPS:
result = parseDclTemps(loc);
break;
case D3D11_SB_OPCODE_DCL_THREAD_GROUP:
result = parseDclThreadGroup(loc);
break;
default:
return std::nullopt;
}
Expand Down
10 changes: 10 additions & 0 deletions mlir/test/Target/DXSA/dcl_thread_group.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// RUN: mlir-translate --import-dxsa-bin %S/inputs/dcl_thread_group.bin | FileCheck %s

// Import test: the binary input is translated to a module, and every thread
// group declaration is expected to appear as a `dxsa.dcl_thread_group` op with
// its sizes printed in x, y, z order. The expected ops include the smallest
// size, the per-dimension maxima (1024 for x and y, 64 for z), and shapes
// whose product is exactly 1024.
// CHECK: module {
// CHECK-NEXT: dxsa.dcl_thread_group [1, 1, 1]
// CHECK-NEXT: dxsa.dcl_thread_group [8, 8, 1]
// CHECK-NEXT: dxsa.dcl_thread_group [32, 32, 1]
// CHECK-NEXT: dxsa.dcl_thread_group [16, 1, 64]
// CHECK-NEXT: dxsa.dcl_thread_group [1024, 1, 1]
// CHECK-NEXT: dxsa.dcl_thread_group [1, 1024, 1]
// CHECK-NEXT: }
45 changes: 45 additions & 0 deletions mlir/test/Target/DXSA/dcl_thread_group_invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics

// Negative tests for `dxsa.dcl_thread_group`. The file is split into
// independent inputs, each of which must produce exactly the one diagnostic
// annotated on the line before the op.

// Too few elements in `dims`.
// expected-error@+1 {{attribute 'dims' failed to satisfy constraint: 64-bit integer array attribute with exactly 3 elements}}
dxsa.dcl_thread_group [1, 1]

// -----

// Too many elements in `dims`.
// expected-error@+1 {{attribute 'dims' failed to satisfy constraint: 64-bit integer array attribute with exactly 3 elements}}
dxsa.dcl_thread_group [1, 1, 1, 1]

// -----

// x below its lower bound.
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group x dimension must be in [1, 1024], got 0}}
dxsa.dcl_thread_group [0, 1, 1]

// -----

// x above its upper bound.
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group x dimension must be in [1, 1024], got 1025}}
dxsa.dcl_thread_group [1025, 1, 1]

// -----

// y below its lower bound.
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group y dimension must be in [1, 1024], got 0}}
dxsa.dcl_thread_group [1, 0, 1]

// -----

// y above its upper bound.
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group y dimension must be in [1, 1024], got 1025}}
dxsa.dcl_thread_group [1, 1025, 1]

// -----

// z below its lower bound.
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group z dimension must be in [1, 64], got 0}}
dxsa.dcl_thread_group [1, 1, 0]

// -----

// z above its upper bound.
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group z dimension must be in [1, 64], got 65}}
dxsa.dcl_thread_group [1, 1, 65]

// -----

// Every dimension is individually in range, but the product is too large.
// 64 * 8 * 4 == 2048
// expected-error@+1 {{'dxsa.dcl_thread_group' op thread group size x*y*z must be <= 1024, got 2048}}
dxsa.dcl_thread_group [64, 8, 4]
Binary file added mlir/test/Target/DXSA/inputs/dcl_thread_group.bin
Binary file not shown.