Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,44 @@ def NoAffineOpsAttr : AMDGCN_Attr<
}];
}

def NoIndexTypesAttr : AMDGCN_Attr<
"NoIndexTypes", "no_index_types",
[DeclareAttrInterfaceMethods<NormalFormAttrInterface, ["verifyType"]>]> {
let summary = "Normal form: no index types";
let description = [{
Verifies that no `index` types exist in the IR. This is the post-condition
of the `aster-to-int-arith` pass, which converts all `index` types to
concrete integer types (typically i32). An `index` type surviving into the
AMDGCN backend would crash register allocation or produce wrong code.
}];
}

def NoLdsBufferOpsAttr : AMDGCN_Attr<
"NoLdsBufferOps", "no_lds_buffer_ops",
[DeclareAttrInterfaceMethods<NormalFormAttrInterface, ["verifyOperation"]>]> {
let summary = "Normal form: no LDS buffer operations";
let description = [{
Verifies that no LDS buffer management operations remain in the IR.
This includes `amdgcn.alloc_lds`, `amdgcn.dealloc_lds`, and
`amdgcn.get_lds_offset`. These must be lowered to constants by the
`amdgcn-lds-alloc` + `amdgcn-convert-lds-buffers` passes before the
AMDGCN backend.
}];
}

def NoUnresolvedAnyTypesAttr : AMDGCN_Attr<
"NoUnresolvedAnyTypes", "no_unresolved_any_types",
[DeclareAttrInterfaceMethods<NormalFormAttrInterface, ["verifyType"]>]> {
let summary = "Normal form: no unresolved any types";
let description = [{
Verifies that no `!aster_utils.any` types exist in the IR. These are
type-erased wrappers used by the SCF pipeliner for iter_args. They must
be resolved to concrete types by `aster-resolve-any-iter-args` before
lowering to AMDGCN. An unresolved `any` type reaching register allocation
or codegen would crash.
}];
}

def NoMetadataOpsAttr : AMDGCN_Attr<
"NoMetadataOps", "no_metadata_ops",
[DeclareAttrInterfaceMethods<NormalFormAttrInterface, ["verifyOperation"]>]> {
Expand Down
42 changes: 42 additions & 0 deletions lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "aster/Dialect/AMDGCN/IR/AMDGCNInst.h"
#include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h"
#include "aster/Dialect/AMDGCN/IR/Utils.h"
#include "aster/Dialect/AsterUtils/IR/AsterUtilsTypes.h"
#include "aster/Dialect/LSIR/IR/LSIROps.h"
#include "aster/Interfaces/RegisterType.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
Expand All @@ -22,6 +23,7 @@
using namespace mlir;
using namespace mlir::aster;
using namespace mlir::aster::amdgcn;
using namespace mlir::aster::aster_utils;

//===----------------------------------------------------------------------===//
// AMDGCN dialect
Expand Down Expand Up @@ -339,6 +341,46 @@ LogicalResult NoRegisterBlockArgsAttr::verifyOperation(
return success();
}

//===----------------------------------------------------------------------===//
// NoIndexTypesAttr
//===----------------------------------------------------------------------===//

LogicalResult
NoIndexTypesAttr::verifyType(function_ref<InFlightDiagnostic()> emitError,
Type type) const {
if (type.isIndex())
return emitError() << "normal form violation: index types are disallowed "
"but found: "
<< type;
return success();
}

//===----------------------------------------------------------------------===//
// NoLdsBufferOpsAttr
//===----------------------------------------------------------------------===//

LogicalResult NoLdsBufferOpsAttr::verifyOperation(
function_ref<InFlightDiagnostic()> emitError, Operation *op) const {
if (isa<AllocLDSOp, DeallocLDSOp, GetLDSOffsetOp>(op))
return emitError() << "normal form violation: LDS buffer operations "
"are disallowed but found: "
<< op->getName();
return success();
}

//===----------------------------------------------------------------------===//
// NoUnresolvedAnyTypesAttr
//===----------------------------------------------------------------------===//

LogicalResult NoUnresolvedAnyTypesAttr::verifyType(
function_ref<InFlightDiagnostic()> emitError, Type type) const {
if (isa<aster_utils::AnyTypeType>(type))
return emitError() << "normal form violation: unresolved any types are "
"disallowed but found: "
<< type;
return success();
}

//===----------------------------------------------------------------------===//
// NoAffineOpsAttr
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/AMDGCN/Transforms/Pipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ static void buildAMDGCNBackendPassPipeline(OpPassManager &pm) {
// Only lsir.cmpi/cmpf/select survive (lowered by LegalizeCF later).
{
SetNormalFormsOptions nfOpts;
nfOpts.moduleForms = {"no_lsir_compute_ops"};
nfOpts.moduleForms = {"no_lsir_compute_ops", "no_unresolved_any_types"};
pm.addPass(createSetNormalForms(nfOpts));
}
{
Expand Down
5 changes: 3 additions & 2 deletions lib/Transforms/ToIntArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,9 @@ void ToIntArith::runOnOperation() {
return signalPassFailure();
cse();

// Set post-condition: no affine ops remain.
// Set post-conditions: no affine ops and no index types remain.
if (auto amdgcnModule = dyn_cast<amdgcn::ModuleOp>(getOperation()))
amdgcnModule.addNormalForms(
{amdgcn::NoAffineOpsAttr::get(op->getContext())});
{amdgcn::NoAffineOpsAttr::get(op->getContext()),
amdgcn::NoIndexTypesAttr::get(op->getContext())});
}
11 changes: 7 additions & 4 deletions python/aster/pass_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def phase_scf_pipelining(lcm_unroll=False):
"amdgcn-lds-alloc",
"amdgcn-convert-lds-buffers",
"canonicalize", "cse",
# Assert: no alloc_lds/get_lds_offset/dealloc_lds remain.
'amdgcn-set-normal-forms{module-forms=no_lds_buffer_ops}',
)

# Lowering to LSIR and then AMDGCN
Expand All @@ -167,9 +169,8 @@ def phase_scf_pipelining(lcm_unroll=False):
"canonicalize", "cse",
"aster-resolve-any-iter-args",
"aster-amdgcn-set-abi", # "func.func(aster-amdgcn-set-abi)",
# Convert SCF control flow to AMDGCN control flow
# Note: control flow support is very limited atm, add NORMAL FORMS
# to harden invariants.
# Convert SCF control flow to AMDGCN control flow.
# Post-condition: #amdgcn.no_scf_ops (set by the pass itself).
"amdgcn-convert-scf-control-flow",
"canonicalize", "cse",
"aster-codegen",
Expand All @@ -180,7 +181,9 @@ def phase_scf_pipelining(lcm_unroll=False):

# Register allocation, and wait lowering.
# TODO: Move NOP insertion to backend.
# TODO: NORMAL FORMS for amdgcn-backend.
# Normal forms enforced by amdgcn-backend internally (see Pipelines.cpp):
# entry: #amdgcn.no_lsir_compute_ops
# exit: #amdgcn.no_lsir_ops, #amdgcn.no_lsir_control_ops
PHASE_AMDGCN_BACKEND = "amdgcn-backend"

# Note: needs to know about instructions and actual register number for WAW
Expand Down
23 changes: 23 additions & 0 deletions test/Dialect/AMDGCN/IR/normal-forms-no-index-types-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: aster-opt %s --split-input-file --verify-diagnostics

// Violation: index type in module with no_index_types.
amdgcn.module @has_index target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> attributes {normal_forms = [#amdgcn.no_index_types]} {
amdgcn.kernel @k {
^bb0:
// expected-error @below {{normal form violation: index types are disallowed but found}}
%0 = arith.constant 42 : index
amdgcn.end_kernel
}
}

// -----

// Violation: index type in kernel with no_index_types.
amdgcn.module @mod target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> {
amdgcn.kernel @k attributes {normal_forms = [#amdgcn.no_index_types]} {
^bb0:
// expected-error @below {{normal form violation: index types are disallowed but found}}
%0 = arith.constant 42 : index
amdgcn.end_kernel
}
}
14 changes: 14 additions & 0 deletions test/Dialect/AMDGCN/IR/normal-forms-no-index-types.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: aster-opt %s | aster-opt | FileCheck %s
// RUN: aster-opt %s --mlir-print-op-generic | aster-opt | FileCheck %s

// Roundtrip: #amdgcn.no_index_types on amdgcn.module.

// CHECK: amdgcn.module @with_nf target = <gfx942> isa = <cdna3>
// CHECK-SAME: attributes {normal_forms = [#amdgcn.no_index_types]}
amdgcn.module @with_nf target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> attributes {normal_forms = [#amdgcn.no_index_types]} {
amdgcn.kernel @k {
^bb0:
%0 = arith.constant 42 : i32
amdgcn.end_kernel
}
}
23 changes: 23 additions & 0 deletions test/Dialect/AMDGCN/IR/normal-forms-no-lds-buffer-ops-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: aster-opt %s --split-input-file --verify-diagnostics

// Violation: alloc_lds in module with no_lds_buffer_ops.
amdgcn.module @has_alloc_lds target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> attributes {normal_forms = [#amdgcn.no_lds_buffer_ops]} {
func.func @f(%arg0: index) {
// expected-error @below {{normal form violation: LDS buffer operations are disallowed but found}}
%0 = amdgcn.alloc_lds %arg0
return
}
}

// -----

// Violation: alloc_lds in kernel with no_lds_buffer_ops.
amdgcn.module @mod target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> {
amdgcn.kernel @k attributes {normal_forms = [#amdgcn.no_lds_buffer_ops]} {
^bb0:
%c = arith.constant 256 : index
// expected-error @below {{normal form violation: LDS buffer operations are disallowed but found}}
%0 = amdgcn.alloc_lds %c
amdgcn.end_kernel
}
}
14 changes: 14 additions & 0 deletions test/Dialect/AMDGCN/IR/normal-forms-no-lds-buffer-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: aster-opt %s | aster-opt | FileCheck %s
// RUN: aster-opt %s --mlir-print-op-generic | aster-opt | FileCheck %s

// Roundtrip: #amdgcn.no_lds_buffer_ops on amdgcn.module.

// CHECK: amdgcn.module @with_nf target = <gfx942> isa = <cdna3>
// CHECK-SAME: attributes {normal_forms = [#amdgcn.no_lds_buffer_ops]}
amdgcn.module @with_nf target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> attributes {normal_forms = [#amdgcn.no_lds_buffer_ops]} {
amdgcn.kernel @k {
^bb0:
%0 = arith.constant 42 : i32
amdgcn.end_kernel
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: aster-opt %s --split-input-file --verify-diagnostics

// Violation: !aster_utils.any type in module with no_unresolved_any_types.
amdgcn.module @has_any target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> attributes {normal_forms = [#amdgcn.no_unresolved_any_types]} {
// expected-error @below {{normal form violation: unresolved any types are disallowed but found}}
func.func @f(%arg0: !aster_utils.any) {
return
}
}
14 changes: 14 additions & 0 deletions test/Dialect/AMDGCN/IR/normal-forms-no-unresolved-any-types.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: aster-opt %s | aster-opt | FileCheck %s
// RUN: aster-opt %s --mlir-print-op-generic | aster-opt | FileCheck %s

// Roundtrip: #amdgcn.no_unresolved_any_types on amdgcn.module.

// CHECK: amdgcn.module @with_nf target = <gfx942> isa = <cdna3>
// CHECK-SAME: attributes {normal_forms = [#amdgcn.no_unresolved_any_types]}
amdgcn.module @with_nf target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> attributes {normal_forms = [#amdgcn.no_unresolved_any_types]} {
amdgcn.kernel @k {
^bb0:
%0 = arith.constant 42 : i32
amdgcn.end_kernel
}
}
4 changes: 4 additions & 0 deletions test/Dialect/AMDGCN/Transforms/set-normal-forms.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: aster-opt %s --amdgcn-set-normal-forms='module-forms=no_lsir_ops' | FileCheck %s --check-prefix=MODULE
// RUN: aster-opt %s --amdgcn-set-normal-forms='kernel-forms=no_scf_ops' | FileCheck %s --check-prefix=KERNEL
// RUN: aster-opt %s --amdgcn-set-normal-forms='module-forms=no_lsir_ops kernel-forms=no_scf_ops,no_cf_branches' | FileCheck %s --check-prefix=BOTH
// RUN: aster-opt %s --amdgcn-set-normal-forms='module-forms=no_lds_buffer_ops,no_index_types,no_unresolved_any_types' | FileCheck %s --check-prefix=NEW

// MODULE: amdgcn.module @test
// MODULE-SAME: attributes {normal_forms = [#amdgcn.no_lsir_ops]}
Expand All @@ -17,6 +18,9 @@
// BOTH: kernel @k
// BOTH-SAME: normal_forms = [#amdgcn.no_scf_ops, #amdgcn.no_cf_branches]

// NEW: amdgcn.module @test
// NEW-SAME: attributes {normal_forms = [#amdgcn.no_lds_buffer_ops, #amdgcn.no_index_types, #amdgcn.no_unresolved_any_types]}

amdgcn.module @test target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> {
amdgcn.kernel @k {
^bb0:
Expand Down
5 changes: 3 additions & 2 deletions test/Dialect/AMDGCN/Transforms/to-int-arith-nf.mlir
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// RUN: aster-opt --pass-pipeline='builtin.module(amdgcn.module(aster-to-int-arith))' %s \
// RUN: | FileCheck %s

// Verify that aster-to-int-arith sets the no_affine_ops post-condition on amdgcn.module.
// Verify that aster-to-int-arith sets the no_affine_ops and no_index_types
// post-conditions on amdgcn.module.

// CHECK: amdgcn.module @sets_postcondition
// CHECK-SAME: attributes {normal_forms = [#amdgcn.no_affine_ops]}
// CHECK-SAME: attributes {normal_forms = [#amdgcn.no_affine_ops, #amdgcn.no_index_types]}
amdgcn.module @sets_postcondition target = #amdgcn.target<gfx942> isa = #amdgcn.isa<cdna3> {
amdgcn.kernel @k {
^bb0:
Expand Down