From 2cb3f0d3305e400ae0d0f45bfad107b16eaf58c6 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 23 Apr 2026 15:22:46 +0200 Subject: [PATCH] use upstream normal forms Remove the local copy of the NormalForm dialect and rely on the upstream MLIR interfaces from the transform dialect instead. Additional functionally is introduced to walk attributes and types since upstream doesn't currently provide native support for that. It keeps checking the form despite seeing silenceable failures with the idea that these may be silenced later and we don't want to miss a definite failure later in the IR, similarly to the upstream design. This functionality may later be moved upstream as well. Remove the test scaffolding that was needed to verify the NormalForm dialect. Requires LLVM bump, available as commit 9deb1c631b11230787f0fb56583b17f060b194a0 in https://github.com/ftynse/llvm-project Signed-off-by: Alex Zinenko --- include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h | 3 +- .../aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td | 43 +- include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td | 8 +- include/aster/Dialect/CMakeLists.txt | 1 - .../aster/Dialect/NormalForm/CMakeLists.txt | 2 - .../Dialect/NormalForm/IR/CMakeLists.txt | 23 - .../Dialect/NormalForm/IR/NormalFormDialect.h | 18 - .../NormalForm/IR/NormalFormDialect.td | 28 -- .../NormalForm/IR/NormalFormInterfaces.h | 48 -- .../NormalForm/IR/NormalFormInterfaces.td | 92 ---- .../Dialect/NormalForm/IR/NormalFormOps.h | 25 -- .../Dialect/NormalForm/IR/NormalFormOps.td | 124 ------ .../NormalForm/Transforms/CMakeLists.txt | 3 - .../Dialect/NormalForm/Transforms/Passes.h | 27 -- .../Dialect/NormalForm/Transforms/Passes.td | 26 -- lib/CMakeLists.txt | 2 - lib/Dialect/AMDGCN/IR/AMDGCN.cpp | 41 +- lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp | 409 ++++++++++++------ lib/Dialect/AMDGCN/IR/CMakeLists.txt | 3 +- lib/Dialect/AMDGCN/Transforms/CMakeLists.txt | 3 +- .../AMDGCN/Transforms/LowLevelScheduler.cpp | 6 +- .../AMDGCN/Transforms/SetNormalForms.cpp | 10 +- lib/Dialect/CMakeLists.txt | 1 - lib/Dialect/NormalForm/CMakeLists.txt | 2 - lib/Dialect/NormalForm/IR/CMakeLists.txt | 18 - .../NormalForm/IR/NormalFormDialect.cpp | 26 -- .../NormalForm/IR/NormalFormInterfaces.cpp | 16 - lib/Dialect/NormalForm/IR/NormalFormOps.cpp | 250 ----------- .../NormalForm/Transforms/CMakeLists.txt | 17 - .../Transforms/LowerNormalFormModule.cpp | 102 ----- lib/Init.cpp | 4 - llvm/LLVM_COMMIT | 2 +- .../AMDGCN/IR/normal-forms-kernel.mlir | 13 + .../amdgcn-no-value-semantic-registers.mlir | 52 --- ...n-to-register-semantics-postcondition.mlir | 25 -- .../lower-normalform-module-invalid.mlir | 19 - .../NormalForm/lower-normalform-module.mlir | 51 --- test/Dialect/NormalForm/ops-invalid.mlir | 134 ------ test/Dialect/NormalForm/ops.mlir | 89 ---- test/lib/CMakeLists.txt | 1 - test/lib/Dialect/AsterTestDialect.cpp | 88 ---- test/lib/Dialect/AsterTestDialect.h | 25 -- test/lib/Dialect/AsterTestDialect.td | 24 - test/lib/Dialect/CMakeLists.txt | 30 -- test/lib/Dialect/TestNormalFormAttr.td | 54 --- tools/aster-opt/CMakeLists.txt | 1 - tools/aster-opt/aster-opt.cpp | 5 - 47 files changed, 366 insertions(+), 1628 deletions(-) delete mode 100644 include/aster/Dialect/NormalForm/CMakeLists.txt delete mode 100644 include/aster/Dialect/NormalForm/IR/CMakeLists.txt delete mode 100644 include/aster/Dialect/NormalForm/IR/NormalFormDialect.h delete mode 100644 include/aster/Dialect/NormalForm/IR/NormalFormDialect.td delete mode 100644 include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.h delete mode 100644 include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.td delete mode 100644 include/aster/Dialect/NormalForm/IR/NormalFormOps.h delete mode 100644 include/aster/Dialect/NormalForm/IR/NormalFormOps.td delete mode 100644 include/aster/Dialect/NormalForm/Transforms/CMakeLists.txt delete mode 100644 include/aster/Dialect/NormalForm/Transforms/Passes.h delete mode 100644 include/aster/Dialect/NormalForm/Transforms/Passes.td delete mode 100644 lib/Dialect/NormalForm/CMakeLists.txt delete mode 100644 lib/Dialect/NormalForm/IR/CMakeLists.txt delete mode 100644 lib/Dialect/NormalForm/IR/NormalFormDialect.cpp delete mode 100644 lib/Dialect/NormalForm/IR/NormalFormInterfaces.cpp delete mode 100644 lib/Dialect/NormalForm/IR/NormalFormOps.cpp delete mode 100644 lib/Dialect/NormalForm/Transforms/CMakeLists.txt delete mode 100644 lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp delete mode 100644 test/Dialect/NormalForm/amdgcn-no-value-semantic-registers.mlir delete mode 100644 test/Dialect/NormalForm/amdgcn-to-register-semantics-postcondition.mlir delete mode 100644 test/Dialect/NormalForm/lower-normalform-module-invalid.mlir delete mode 100644 test/Dialect/NormalForm/lower-normalform-module.mlir delete mode 100644 test/Dialect/NormalForm/ops-invalid.mlir delete mode 100644 test/Dialect/NormalForm/ops.mlir delete mode 100644 test/lib/Dialect/AsterTestDialect.cpp delete mode 100644 test/lib/Dialect/AsterTestDialect.h delete mode 100644 test/lib/Dialect/AsterTestDialect.td delete mode 100644 test/lib/Dialect/CMakeLists.txt delete mode 100644 test/lib/Dialect/TestNormalFormAttr.td diff --git a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h index 0667572ce..b4133cc13 100644 --- a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h +++ b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h @@ -21,10 +21,11 @@ #include "aster/Dialect/AMDGCN/IR/Hazards.h" #include "aster/Dialect/AMDGCN/IR/Interfaces/KernelArgInterface.h" #include "aster/Dialect/AMDGCN/IR/Sched.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" #include "aster/Interfaces/MemorySpaceConstraints.h" #include "aster/Interfaces/SchedInterfaces.h" #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" #include "mlir/IR/Attributes.h" namespace mlir { diff --git a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td index c8219f0a8..93bf3239d 100644 --- a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td +++ b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td @@ -20,10 +20,27 @@ include "aster/Dialect/AMDGCN/IR/AMDGCNEnums.td" include "aster/Dialect/AMDGCN/IR/Interfaces/KernelArgInterface.td" include "aster/Dialect/AMDGCN/IR/Interfaces/HazardAttrInterface.td" include "aster/Dialect/AMDGCN/IR/Hazards.td" -include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.td" include "aster/Interfaces/MemorySpaceConstraints.td" include "aster/Interfaces/SchedInterfaces.td" include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/CommonAttrConstraints.td" + +//===----------------------------------------------------------------------===// +// Normal form attribute array constraints. +//===----------------------------------------------------------------------===// + +/// An array of attributes implementing `mlir::transform::NormalFormAttrInterface`. +def NormalFormAttrArray : TypedArrayAttrBase; + +/// Optional normal form array: omitted when empty, prints in attr-dict. +def OptionalNormalFormAttrArray + : DefaultValuedAttr()"> { + let constBuilderCall = "$_builder.getArrayAttr($0)"; + let isOptional = true; +} //===----------------------------------------------------------------------===// // AddressSpaceAttr @@ -372,7 +389,7 @@ def KernelArgumentsAttr def NoValueSemanticRegistersAttr : AMDGCN_Attr< "NoValueSemanticRegisters", "no_value_semantic_registers", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no register types with value semantics"; let description = [{ Verifies that no register types in the IR have value semantics. @@ -386,7 +403,7 @@ def NoValueSemanticRegistersAttr : AMDGCN_Attr< def AllRegistersAllocatedAttr : AMDGCN_Attr< "AllRegistersAllocated", "all_registers_allocated", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: all register types have allocated semantics"; let description = [{ Verifies that every register type in the IR has allocated semantics @@ -400,7 +417,7 @@ def AllRegistersAllocatedAttr : AMDGCN_Attr< def NoRegCastOpsAttr : AMDGCN_Attr< "NoRegCastOps", "no_reg_cast_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no lsir.reg_cast operations"; let description = [{ Verifies that no `lsir.reg_cast` operations exist in the IR. @@ -414,7 +431,7 @@ def NoRegCastOpsAttr : AMDGCN_Attr< def NoCfBranchesAttr : AMDGCN_Attr< "NoCfBranches", "no_cf_branches", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no CF dialect branch operations"; let description = [{ Verifies that no `cf.br` or `cf.cond_br` operations exist in the IR. @@ -426,7 +443,7 @@ def NoCfBranchesAttr : AMDGCN_Attr< def NoScfOpsAttr : AMDGCN_Attr< "NoScfOps", "no_scf_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no SCF dialect operations"; let description = [{ Verifies that no SCF dialect operations (`scf.for`, `scf.if`, `scf.while`, @@ -438,7 +455,7 @@ def NoScfOpsAttr : AMDGCN_Attr< def NoLsirOpsAttr : AMDGCN_Attr< "NoLsirOps", "no_lsir_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no LSIR dialect operations"; let description = [{ Verifies that no LSIR dialect operations exist in the IR. @@ -450,7 +467,7 @@ def NoLsirOpsAttr : AMDGCN_Attr< def NoLsirComputeOpsAttr : AMDGCN_Attr< "NoLsirComputeOps", "no_lsir_compute_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no LSIR compute/memory operations"; let description = [{ Verifies that no LSIR arithmetic, memory, or utility operations exist in @@ -469,7 +486,7 @@ def NoLsirComputeOpsAttr : AMDGCN_Attr< def NoLsirControlOpsAttr : AMDGCN_Attr< "NoLsirControlOps", "no_lsir_control_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no LSIR control-flow operations"; let description = [{ Verifies that no LSIR control-flow-related operations (`lsir.cmpi`, @@ -485,7 +502,7 @@ def NoLsirControlOpsAttr : AMDGCN_Attr< def NoRegisterBlockArgsAttr : AMDGCN_Attr< "NoRegisterBlockArgs", "no_register_block_args", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no block arguments with register types"; let description = [{ Verifies that no block arguments have register types @@ -497,7 +514,7 @@ def NoRegisterBlockArgsAttr : AMDGCN_Attr< def NoAffineOpsAttr : AMDGCN_Attr< "NoAffineOps", "no_affine_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no affine dialect operations"; let description = [{ Verifies that no affine dialect operations (`affine.apply`, `affine.for`, @@ -509,7 +526,7 @@ def NoAffineOpsAttr : AMDGCN_Attr< def NoMetadataOpsAttr : AMDGCN_Attr< "NoMetadataOps", "no_metadata_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no AMDGCN metadata operations"; let description = [{ Verifies that no AMDGCN metadata operations remain in the IR. @@ -523,7 +540,7 @@ def NoMetadataOpsAttr : AMDGCN_Attr< def AllInlinedAttr : AMDGCN_Attr< "AllInlined", "all_inlined", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: all function calls inlined"; let description = [{ Verifies that no `func.call` operations exist in the IR. diff --git a/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td b/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td index b769d5de3..03acf008b 100644 --- a/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td +++ b/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td @@ -445,11 +445,11 @@ def AMDGCN_KernelOp : AMDGCN_Op<"kernel", [ /// Adds normal form attributes. Returns true if the op was changed. bool addNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); /// Removes normal form attributes. Returns true if the op was changed. bool removeNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); }]; } @@ -572,11 +572,11 @@ def AMDGCN_ModuleOp : AMDGCN_Op<"module", [ /// Adds normal form attributes. Returns true if the op was changed. bool addNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); /// Removes normal form attributes. Returns true if the op was changed. bool removeNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); }]; } diff --git a/include/aster/Dialect/CMakeLists.txt b/include/aster/Dialect/CMakeLists.txt index 68bb9f07d..407936f00 100644 --- a/include/aster/Dialect/CMakeLists.txt +++ b/include/aster/Dialect/CMakeLists.txt @@ -2,4 +2,3 @@ add_subdirectory(AMDGCN) add_subdirectory(AsterUtils) add_subdirectory(Layout) add_subdirectory(LSIR) -add_subdirectory(NormalForm) diff --git a/include/aster/Dialect/NormalForm/CMakeLists.txt b/include/aster/Dialect/NormalForm/CMakeLists.txt deleted file mode 100644 index 9f57627c3..000000000 --- a/include/aster/Dialect/NormalForm/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/include/aster/Dialect/NormalForm/IR/CMakeLists.txt b/include/aster/Dialect/NormalForm/IR/CMakeLists.txt deleted file mode 100644 index 8d55f3851..000000000 --- a/include/aster/Dialect/NormalForm/IR/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_custom_target(MLIRNormalFormIncGen) - -set(LLVM_TARGET_DEFINITIONS NormalFormDialect.td) -mlir_tablegen(NormalFormDialect.h.inc -gen-dialect-decls -dialect=normalform) -mlir_tablegen(NormalFormDialect.cpp.inc -gen-dialect-defs -dialect=normalform) -add_public_tablegen_target(MLIRNormalFormDialectIncGen) - -set(LLVM_TARGET_DEFINITIONS NormalFormOps.td) -mlir_tablegen(NormalFormOps.h.inc -gen-op-decls) -mlir_tablegen(NormalFormOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRNormalFormOpsIncGen) -add_dependencies(MLIRNormalFormIncGen MLIRNormalFormOpsIncGen) - -set(LLVM_TARGET_DEFINITIONS NormalFormInterfaces.td) -mlir_tablegen(NormalFormAttrInterfaces.h.inc -gen-attr-interface-decls) -mlir_tablegen(NormalFormAttrInterfaces.cpp.inc -gen-attr-interface-defs) -add_public_tablegen_target(MLIRNormalFormInterfacesIncGen) -add_dependencies(MLIRNormalFormIncGen MLIRNormalFormInterfacesIncGen) -add_mlir_doc(NormalFormInterfaces NormalFormAttrInterfaces Dialects/ -gen-attr-interface-docs) - -add_dependencies(mlir-headers MLIRNormalFormIncGen) - -add_mlir_doc(NormalForm NormalForm Dialects/ -gen-dialect-doc -dialect normalform) diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.h b/include/aster/Dialect/NormalForm/IR/NormalFormDialect.h deleted file mode 100644 index 7b0841bc3..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.h +++ /dev/null @@ -1,18 +0,0 @@ -//===- NormalFormDialect.h ------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_IR_NORMALFORMDIALECT_H -#define ASTER_DIALECT_NORMALFORM_IR_NORMALFORMDIALECT_H - -#include "mlir/IR/Dialect.h" - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h.inc" - -#endif // ASTER_DIALECT_NORMALFORM_IR_NORMALFORMDIALECT_H diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.td b/include/aster/Dialect/NormalForm/IR/NormalFormDialect.td deleted file mode 100644 index 588bd8bb0..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.td +++ /dev/null @@ -1,28 +0,0 @@ -//===- NormalFormDialect.td -----------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_NORMALFORMDIALECT -#define ASTER_DIALECT_NORMALFORM_NORMALFORMDIALECT - -include "mlir/IR/DialectBase.td" - -def NormalFormDialect : Dialect { - let name = "normalform"; - let summary = "Dialect for normal form IR verification"; - - let description = [{ - The NormalForm dialect provides infrastructure for defining and verifying - normal forms of IR. It uses the `normalform.module` operation as a - container that specifies the expected normal forms through an array of - attributes implementing the NormalFormAttrInterface. - }]; -} - -#endif // ASTER_DIALECT_NORMALFORM_NORMALFORMDIALECT diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.h b/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.h deleted file mode 100644 index 2ef4d5214..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.h +++ /dev/null @@ -1,48 +0,0 @@ -//===- NormalFormInterfaces.h ---------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_IR_NORMALFORMINTERFACES_H -#define ASTER_DIALECT_NORMALFORM_IR_NORMALFORMINTERFACES_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseSet.h" - -#include "aster/Dialect/NormalForm/IR/NormalFormAttrInterfaces.h.inc" - -namespace normalform { - -/// Verify that all IR nested under `root` satisfies the given normal form. -/// If `emitDiagnostics` is true, errors are reported; otherwise the check -/// is silent (useful for inferNormalForms). -/// `excludeAttrNames` optionally specifies named attributes whose nested -/// types should be skipped during verification (e.g., kernel argument -/// attributes that contain ABI metadata types, not computational types). -::llvm::LogicalResult verifyNormalForm( - ::mlir::Operation *root, NormalFormAttrInterface normalForm, - bool emitDiagnostics, - const ::llvm::DenseSet<::mlir::StringAttr> *excludeAttrNames = nullptr); - -} // namespace normalform - -namespace llvm { -template <> -struct PointerLikeTypeTraits - : public PointerLikeTypeTraits { - static inline normalform::NormalFormAttrInterface - getFromVoidPointer(void *p) { - return normalform::NormalFormAttrInterface( - mlir::Attribute::getFromOpaquePointer(p)); - } -}; -} // namespace llvm - -#endif // ASTER_DIALECT_NORMALFORM_IR_NORMALFORMINTERFACES_H diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.td b/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.td deleted file mode 100644 index 9ed21e7a9..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.td +++ /dev/null @@ -1,92 +0,0 @@ -//===- NormalFormInterfaces.td --------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_NORMALFORMINTERFACES -#define ASTER_DIALECT_NORMALFORM_NORMALFORMINTERFACES - -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/OpBase.td" -include "mlir/IR/CommonAttrConstraints.td" - -//===----------------------------------------------------------------------===// -// Normal form attribute interface. -//===----------------------------------------------------------------------===// -def NormalFormAttrInterface : AttrInterface<"NormalFormAttrInterface"> { - let description = [{ - Interface for attributes that define normal form constraints on IR. - - A normal form attribute specifies invariants that types, attributes, and - operations must satisfy. When attached to a `normalform.module`, all - contained IR is verified against the constraints defined by the attribute. - - Implementers should override one or more of `verifyType`, `verifyAttribute`, - or `verifyOperation` to define custom verification logic. The default - implementations return success. - }]; - let cppNamespace = "::normalform"; - let methods = [ - InterfaceMethod< - /*desc=*/[{ - Verify the type using this normal form attribute. - }], - /*returnType=*/"::llvm::LogicalResult", - /*name=*/"verifyType", - /*arguments=*/(ins "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError, "::mlir::Type":$type), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::success(); - }] - >, - - InterfaceMethod< - /*desc=*/[{ - Verify the attribute using this normal form attribute. - }], - /*returnType=*/"::llvm::LogicalResult", - /*name=*/"verifyAttribute", - /*arguments=*/(ins "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError, "::mlir::Attribute":$attr), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::success(); - }] - >, - - InterfaceMethod< - /*desc=*/[{ - Verify the operation using this normal form attribute. - }], - /*returnType=*/"::llvm::LogicalResult", - /*name=*/"verifyOperation", - /*arguments=*/(ins "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError, "::mlir::Operation*":$op), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::success(); - }] - >, - ]; -} - -//===----------------------------------------------------------------------===// -// Normal form attribute array constraints. -//===----------------------------------------------------------------------===// - -/// An array of attributes implementing NormalFormAttrInterface. -def NormalFormAttrArray : TypedArrayAttrBase; - -/// Optional normal form array: omitted when empty, prints in attr-dict. -def OptionalNormalFormAttrArray - : DefaultValuedAttr()"> { - let constBuilderCall = "$_builder.getArrayAttr($0)"; - let isOptional = true; -} - - -#endif // ASTER_DIALECT_NORMALFORM_NORMALFORMINTERFACES diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormOps.h b/include/aster/Dialect/NormalForm/IR/NormalFormOps.h deleted file mode 100644 index 24d5dde27..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormOps.h +++ /dev/null @@ -1,25 +0,0 @@ -//===- NormalFormOps.h ----------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_IR_NORMALFORMOPS_H -#define ASTER_DIALECT_NORMALFORM_IR_NORMALFORMOPS_H - -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" - -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/RegionKindInterface.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" - -#define GET_OP_CLASSES -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h.inc" - -#endif // ASTER_DIALECT_NORMALFORM_IR_NORMALFORMOPS_H diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormOps.td b/include/aster/Dialect/NormalForm/IR/NormalFormOps.td deleted file mode 100644 index d1f1c71a2..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormOps.td +++ /dev/null @@ -1,124 +0,0 @@ -//===- NormalFormOps.td ---------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -include "aster/Dialect/NormalForm/IR/NormalFormDialect.td" -include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.td" -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/IR/CommonAttrConstraints.td" -include "mlir/IR/CommonTypeConstraints.td" -include "mlir/IR/BuiltinAttributeInterfaces.td" -include "mlir/IR/RegionKindInterface.td" -include "mlir/IR/SymbolInterfaces.td" - - - -#ifndef ASTER_DIALECT_NORMALFORM_NORMALFORMOPS -#define ASTER_DIALECT_NORMALFORM_NORMALFORMOPS - -//----------------------------------------------------------------------------- -// Base class for all NormalForm operations. -//----------------------------------------------------------------------------- - -class NormalFormOp traits = []> : - Op; - -//----------------------------------------------------------------------------- -// Structure Ops -//----------------------------------------------------------------------------- - -def ModuleOp : NormalFormOp<"module", [ - IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, - ] # GraphRegionNoTerminator.traits> { - let summary = "A container that enforces IR invariants specified by normal form attributes."; - let description = [{ - The `normalform.module` operation defines a scoped region whose contents must - satisfy the constraints specified by the attached normal form attributes. Each - attribute must implement `NormalFormAttrInterface`, which provides hooks for - verifying types, attributes, and operations within the module. - - During region verification, the module walks all nested operations and - invokes the interface methods (`verifyType`, `verifyAttribute`, - `verifyOperation`) on each element for all normal form attributes. - Normal forms are verified in the order given by the normal form array attribute, - walking over each operation in pre-order while also visiting its types, followed by - its attributes. - - Example: - - ```mlir - // Enforce that all tensor types are fully specified. - normalform.module @my_kernel [#wave.normal_form] { - func.func @compute(%arg: !wave.tensor<[64, 128] of f32>) { - return - } - } - - // Multiple normal form attributes from different dialects. - normalform.module @validated [#wave.normal_form, #other.normal_form] { - func.func @compute(%arg: !wave.tensor<[64, 128] of f32>) { - return - } - } - ``` - }]; - let arguments = (ins NormalFormAttrArray:$normal_forms, OptionalAttr:$sym_name); - let regions = (region SizedRegion<1>:$bodyRegion); - let assemblyFormat = [{ - ($sym_name^)? $normal_forms attr-dict-with-keyword $bodyRegion - }]; - - let builders = [ - OpBuilder<(ins - CArg<"::llvm::ArrayRef">:$normal_forms, - CArg<"std::optional<::llvm::StringRef>", "{}">:$name)> - ]; - - let extraClassDeclaration = [{ - /// Construct a module from the given location with an optional name. - static ModuleOp create(::mlir::Location loc, - ::llvm::ArrayRef normalForms, - std::optional<::llvm::StringRef> name = {}); - - /// Return the name of this module if present. - std::optional<::llvm::StringRef> getName() { return getSymName(); } - - /// Checks whether the normal forms passed in `normalForms` apply to this - /// module and attaches them to the module if true. Returns true if the module was changed. - bool inferNormalForms(::llvm::ArrayRef normalForms); - - /// Checks wheter a given normal form applies to this module. - ::llvm::LogicalResult - verifyNormalForm(NormalFormAttrInterface normalForm, bool emitDiagnostics); - - /// Adds normal form attributes to the module. Returns true if the module was changed. - bool addNormalForms(::llvm::ArrayRef normalForms); - - /// Removes normal form attributes from the module. Returns true if the module was changed. - bool removeNormalForms(::llvm::ArrayRef normalForms); - - //===------------------------------------------------------------------===// - // SymbolOpInterface Methods - //===------------------------------------------------------------------===// - - /// A ModuleOp may optionally define a symbol. - bool isOptionalSymbol() const { return true; } - }]; - - let hasVerifier = 1; - let hasRegionVerifier = 1; - - // We need to ensure the block inside the region is properly terminated; - // the auto-generated builders do not guarantee that. - let skipDefaultBuilders = 1; -} - - -#endif // ASTER_DIALECT_NORMALFORM_NORMALFORMOPS diff --git a/include/aster/Dialect/NormalForm/Transforms/CMakeLists.txt b/include/aster/Dialect/NormalForm/Transforms/CMakeLists.txt deleted file mode 100644 index 2594d3ce1..000000000 --- a/include/aster/Dialect/NormalForm/Transforms/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name NormalForm) -add_public_tablegen_target(MLIRAsterNormalFormPassesIncGen) diff --git a/include/aster/Dialect/NormalForm/Transforms/Passes.h b/include/aster/Dialect/NormalForm/Transforms/Passes.h deleted file mode 100644 index 8f3eb68d7..000000000 --- a/include/aster/Dialect/NormalForm/Transforms/Passes.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- Passes.h - Pass entrypoints ----------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES_H -#define ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES_H - -#include "mlir/Pass/Pass.h" -#include - -namespace normalform { - -#define GEN_PASS_DECL -#include "aster/Dialect/NormalForm/Transforms/Passes.h.inc" - -#define GEN_PASS_REGISTRATION -#include "aster/Dialect/NormalForm/Transforms/Passes.h.inc" - -} // namespace normalform - -#endif // ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES_H diff --git a/include/aster/Dialect/NormalForm/Transforms/Passes.td b/include/aster/Dialect/NormalForm/Transforms/Passes.td deleted file mode 100644 index af07a9fff..000000000 --- a/include/aster/Dialect/NormalForm/Transforms/Passes.td +++ /dev/null @@ -1,26 +0,0 @@ -//===- Passes.td - Pass definitions ---------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES -#define ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES - -include "mlir/Pass/PassBase.td" - -def LowerNormalFormModulePass : Pass<"lower-normalform-module"> { - let summary = "Lower normalform.module to builtin.module"; - let description = [{ - This pass converts `normalform.module` operations to `builtin.module` - operations. The normal form attributes are not retained on the resulting - module. - }]; - let dependentDialects = []; -} - -#endif // ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 03b0d0f0d..07d75b9e6 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -22,8 +22,6 @@ add_mlir_library(ASTERInit LayoutDialect LayoutTransforms LSIRDialect - MLIRNormalFormDialect - MLIRNormalFormTransforms MLIRAffineDialect MLIRAffineTransforms MLIRArithDialect diff --git a/lib/Dialect/AMDGCN/IR/AMDGCN.cpp b/lib/Dialect/AMDGCN/IR/AMDGCN.cpp index 02014d8a4..05aa8c85f 100644 --- a/lib/Dialect/AMDGCN/IR/AMDGCN.cpp +++ b/lib/Dialect/AMDGCN/IR/AMDGCN.cpp @@ -16,7 +16,6 @@ #include "aster/Dialect/AMDGCN/IR/AMDGCNVerifiers.h" #include "aster/Dialect/AMDGCN/IR/Interfaces/AMDGCNInterfaces.h" #include "aster/Dialect/AMDGCN/IR/Utils.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" #include "aster/IR/ParsePrintUtils.h" #include "aster/Interfaces/RegisterType.h" #include "mlir/IR/BuiltinTypes.h" @@ -921,17 +920,17 @@ LogicalResult KernelOp::verify() { //===----------------------------------------------------------------------===// /// Verify all normal forms attached to an operation via its normal_forms attr. -/// `excludeAttrNames` optionally specifies named attributes whose nested -/// types should be skipped during verification. -static LogicalResult verifyNormalFormsRegions( - Operation *op, ArrayAttr normalFormsAttr, - const DenseSet *excludeAttrNames = nullptr) { +/// Each normal form's `checkOperation` method is responsible for walking the +/// IR (and excluding op-attribute payloads that should not participate in +/// normal-form type checking; see `getNormalFormTypeWalkExcludes` in +/// AMDGCNAttrs.cpp for the kernel-arguments exclusion). +static LogicalResult verifyNormalFormsRegions(Operation *op, + ArrayAttr normalFormsAttr) { if (!normalFormsAttr || normalFormsAttr.empty()) return success(); for (Attribute attr : normalFormsAttr) { - auto nf = cast(attr); - if (failed(normalform::verifyNormalForm(op, nf, /*emitDiagnostics=*/true, - excludeAttrNames))) + auto nf = cast(attr); + if (failed(nf.checkOperation(op).checkAndReport())) return failure(); } return success(); @@ -941,7 +940,7 @@ static LogicalResult verifyNormalFormsRegions( /// semantics. Returns true if the attribute was changed. static bool addNormalFormsImpl(Operation *op, StringRef attrName, ArrayAttr currentAttr, - ArrayRef nfs, + ArrayRef nfs, function_ref setter) { if (nfs.empty()) return false; @@ -951,7 +950,7 @@ addNormalFormsImpl(Operation *op, StringRef attrName, ArrayAttr currentAttr, nfSet.insert_range(currentAttr.getValue()); bool changed = false; - for (normalform::NormalFormAttrInterface nf : nfs) + for (mlir::transform::NormalFormAttrInterface nf : nfs) changed |= nfSet.insert(nf); if (!changed) @@ -966,7 +965,7 @@ addNormalFormsImpl(Operation *op, StringRef attrName, ArrayAttr currentAttr, /// Returns true if the attribute was changed. static bool removeNormalFormsImpl(Operation *op, ArrayAttr currentAttr, - ArrayRef nfs, + ArrayRef nfs, function_ref setter) { if (nfs.empty() || !currentAttr || currentAttr.empty()) return false; @@ -975,7 +974,7 @@ removeNormalFormsImpl(Operation *op, ArrayAttr currentAttr, nfSet.insert_range(currentAttr.getValue()); bool changed = false; - for (normalform::NormalFormAttrInterface nf : nfs) + for (mlir::transform::NormalFormAttrInterface nf : nfs) changed |= nfSet.remove(nf); if (!changed) @@ -995,14 +994,14 @@ LogicalResult amdgcn::ModuleOp::verifyRegions() { } bool amdgcn::ModuleOp::addNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return addNormalFormsImpl(getOperation(), getNormalFormsAttrName(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); } bool amdgcn::ModuleOp::removeNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return removeNormalFormsImpl( getOperation(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); @@ -1013,24 +1012,18 @@ bool amdgcn::ModuleOp::removeNormalForms( //===----------------------------------------------------------------------===// LogicalResult KernelOp::verifyRegions() { - // Exclude 'arguments' attribute from normal form type walking: kernel - // argument attrs (by_val_arg, buffer_arg) contain ABI metadata types, - // not computational register types in the kernel body. - DenseSet excludeAttrs; - excludeAttrs.insert(getArgumentsAttrName()); - return verifyNormalFormsRegions(getOperation(), getNormalFormsAttr(), - &excludeAttrs); + return verifyNormalFormsRegions(getOperation(), getNormalFormsAttr()); } bool KernelOp::addNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return addNormalFormsImpl(getOperation(), getNormalFormsAttrName(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); } bool KernelOp::removeNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return removeNormalFormsImpl( getOperation(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); diff --git a/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp b/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp index 05f1da1d2..47d953fb7 100644 --- a/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp +++ b/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp @@ -17,7 +17,13 @@ #include "aster/Interfaces/RegisterType.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" +#include "mlir/IR/AttrTypeSubElements.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/Support/WalkResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/ErrorHandling.h" @@ -225,198 +231,357 @@ int32_t GenericSchedLabelerAttr::getLabel(Operation *op, int32_t, } //===----------------------------------------------------------------------===// -// NoValueSemanticRegistersAttr +// Normal-form helpers //===----------------------------------------------------------------------===// -LogicalResult NoValueSemanticRegistersAttr::verifyType( - function_ref emitError, Type type) const { - auto regType = dyn_cast(type); - if (!regType) - return success(); - - if (regType.hasValueSemantics()) - return emitError() << "normal form violation: register types with value " - "semantics are disallowed but found: " - << type; +namespace { +/// Filter callback used by `walkTypes` to decide whether to descend into a +/// given `NamedAttribute` of an operation. Returning `false` skips the +/// attribute payload entirely. +using NamedAttrFilter = llvm::function_ref; + +/// Skips named attributes that carry ABI metadata whose register types are +/// not subject to normal-form invariants enforced on the kernel body. For +/// `amdgcn.kernel`, this excludes the `arguments` attribute (e.g. +/// `by_val_arg` parameter types). +bool skipKernelAbiMetadata(Operation *op, NamedAttribute attr) { + if (auto kernel = dyn_cast(op)) + return attr.getName() != kernel.getArgumentsAttrName(); + return true; +} - return success(); +/// Aggregates `DiagnosedSilenceableFailure` results across multiple visits: +/// records the first silenceable failure (silencing later ones) and short- +/// circuits on definite failures. +struct AttrTypeAggregator { + DiagnosedSilenceableFailure overall = DiagnosedSilenceableFailure::success(); + bool stop = false; + + void merge(DiagnosedSilenceableFailure &&result) { + if (result.isDefiniteFailure()) { + overall = std::move(result); + stop = true; + return; + } + if (result.isSilenceableFailure()) { + if (overall.succeeded()) + overall = std::move(result); + else + (void)result.silence(); + } + } +}; + +/// Walks all distinct types reachable from operations under `root`: operation +/// result types, block argument types, and types nested inside operation +/// attributes. Invokes `visitor` on each distinct type with a `Location` near +/// its discovery point. When `filter` is non-null, it is consulted for each +/// named attribute of every visited operation; returning `false` skips the +/// attribute payload (e.g. to exclude `amdgcn.kernel`'s `arguments` ABI +/// metadata). +DiagnosedSilenceableFailure walkTypes( + Operation *root, + llvm::function_ref visitor, + NamedAttrFilter filter = nullptr) { + AttrTypeAggregator agg; + llvm::SmallPtrSet seenTypes; + llvm::SmallPtrSet seenAttrs; + Location currentLoc = root->getLoc(); + AttrTypeWalker walker; + + walker.addWalk([&](Type type) { + auto [it, inserted] = seenTypes.insert(type); + if (!inserted) + return WalkResult::skip(); + agg.merge(visitor(type, currentLoc)); + if (agg.stop) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + walker.addWalk([&](Attribute attr) { + auto [it, inserted] = seenAttrs.insert(attr); + if (!inserted) + return WalkResult::skip(); + return WalkResult::advance(); + }); + + root->walk([&](Operation *op) { + currentLoc = op->getLoc(); + for (OpResult result : op->getResults()) { + currentLoc = result.getLoc(); + if (walker.walk(result.getType()).wasInterrupted()) + return WalkResult::interrupt(); + } + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + currentLoc = arg.getLoc(); + if (walker.walk(arg.getType()).wasInterrupted()) + return WalkResult::interrupt(); + } + } + } + currentLoc = op->getLoc(); + for (NamedAttribute attr : op->getAttrs()) { + if (filter && !filter(op, attr)) + continue; + if (walker.walk(attr.getValue()).wasInterrupted()) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return std::move(agg.overall); } +} // namespace + //===----------------------------------------------------------------------===// -// AllRegistersAllocatedAttr +// NoValueSemanticRegistersAttr //===----------------------------------------------------------------------===// -LogicalResult AllRegistersAllocatedAttr::verifyType( - function_ref emitError, Type type) const { - auto regType = dyn_cast(type); - if (!regType) - return success(); +DiagnosedSilenceableFailure +NoValueSemanticRegistersAttr::checkOperation(Operation *op) const { + return walkTypes( + op, + [](Type type, Location loc) -> DiagnosedSilenceableFailure { + auto regType = dyn_cast(type); + if (!regType || !regType.hasValueSemantics()) + return DiagnosedSilenceableFailure::success(); + return emitSilenceableFailure(loc) + << "normal form violation: register types with value " + "semantics are disallowed but found: " + << type; + }, + skipKernelAbiMetadata); +} - if (!regType.hasAllocatedSemantics()) - return emitError() << "normal form violation: all registers must have " - "allocated semantics but found: " - << type; +//===----------------------------------------------------------------------===// +// AllRegistersAllocatedAttr +//===----------------------------------------------------------------------===// - return success(); +DiagnosedSilenceableFailure +AllRegistersAllocatedAttr::checkOperation(Operation *op) const { + return walkTypes( + op, + [](Type type, Location loc) -> DiagnosedSilenceableFailure { + auto regType = dyn_cast(type); + if (!regType || regType.hasAllocatedSemantics()) + return DiagnosedSilenceableFailure::success(); + return emitSilenceableFailure(loc) + << "normal form violation: all registers must have " + "allocated semantics but found: " + << type; + }, + skipKernelAbiMetadata); } //===----------------------------------------------------------------------===// // NoRegCastOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoRegCastOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: lsir.reg_cast should not " - "survive past aster-to-amdgcn; this indicates an " - "incorrect lsir.to_reg or lsir.from_reg surviving " - "from high-level (hand-authored ?) IR"; - return success(); +DiagnosedSilenceableFailure +NoRegCastOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: lsir.reg_cast should not " + "survive past aster-to-amdgcn; this indicates an " + "incorrect lsir.to_reg or lsir.from_reg surviving " + "from high-level (hand-authored ?) IR"); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoLsirOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoLsirOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (op->getDialect() && op->getDialect()->getNamespace() == "lsir") - return emitError() << "normal form violation: LSIR dialect operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure NoLsirOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "lsir") + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: LSIR dialect operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoLsirComputeOpsAttr //===----------------------------------------------------------------------===// -LogicalResult NoLsirComputeOpsAttr::verifyOperation( - function_ref emitError, Operation *op) const { - if (!op->getDialect() || op->getDialect()->getNamespace() != "lsir") - return success(); - - // Allow control-flow ops (lowered by LegalizeCF) and copy (regalloc - // primitive). - if (isa(op)) - return success(); - - return emitError() << "normal form violation: LSIR compute/memory " - "operations are disallowed but found: " - << op->getName(); +DiagnosedSilenceableFailure +NoLsirComputeOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "lsir") + return WalkResult::advance(); + // Allow control-flow ops (lowered by LegalizeCF) and copy (regalloc + // primitive). + if (isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: LSIR compute/memory " + "operations are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoLsirControlOpsAttr //===----------------------------------------------------------------------===// -LogicalResult NoLsirControlOpsAttr::verifyOperation( - function_ref emitError, Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: LSIR control-flow " - "operations are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoLsirControlOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: LSIR control-flow " + "operations are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoScfOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoScfOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (op->getDialect() && op->getDialect()->getNamespace() == "scf") - return emitError() << "normal form violation: SCF dialect operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure NoScfOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "scf") + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: SCF dialect operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoCfBranchesAttr //===----------------------------------------------------------------------===// -LogicalResult -NoCfBranchesAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: cf.br/cf.cond_br operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoCfBranchesAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: cf.br/cf.cond_br operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoRegisterBlockArgsAttr //===----------------------------------------------------------------------===// -LogicalResult NoRegisterBlockArgsAttr::verifyOperation( - function_ref emitError, Operation *op) const { - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - for (BlockArgument arg : block.getArguments()) { - if (isa(arg.getType())) - return emitError() - << "normal form violation: block arguments with register " - "types are disallowed but found: " - << arg.getType(); +DiagnosedSilenceableFailure +NoRegisterBlockArgsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + for (Region ®ion : innerOp->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (!isa(arg.getType())) + continue; + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: block arguments with " + "register types are disallowed but found: " + << arg.getType()); + if (agg.stop) + return WalkResult::interrupt(); + // Only report once per op to mirror the original behavior. + return WalkResult::advance(); + } } } - } - return success(); + return WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoAffineOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoAffineOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (op->getDialect() && op->getDialect()->getNamespace() == "affine") - return emitError() << "normal form violation: affine dialect operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoAffineOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "affine") + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: affine dialect operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoMetadataOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoMetadataOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: AMDGCN metadata operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoMetadataOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: AMDGCN metadata operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // AllInlinedAttr //===----------------------------------------------------------------------===// -LogicalResult -AllInlinedAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: func.call operations " - "are disallowed (all functions should be inlined) " - "but found call to '" - << cast(op).getCallee() << "'"; - - return success(); +DiagnosedSilenceableFailure +AllInlinedAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + auto callOp = dyn_cast(innerOp); + if (!callOp) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: func.call operations " + "are disallowed (all functions should be inlined) " + "but found call to '" + << callOp.getCallee() << "'"); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } diff --git a/lib/Dialect/AMDGCN/IR/CMakeLists.txt b/lib/Dialect/AMDGCN/IR/CMakeLists.txt index 4443a22e4..8cf2c79e8 100644 --- a/lib/Dialect/AMDGCN/IR/CMakeLists.txt +++ b/lib/Dialect/AMDGCN/IR/CMakeLists.txt @@ -24,7 +24,6 @@ add_mlir_dialect_library(AMDGCNDialect ASTERIR AsterInterfaces LSIRDialect - MLIRNormalFormDialect MLIRArithDialect MLIRControlFlowDialect MLIRControlFlowInterfaces @@ -33,4 +32,6 @@ add_mlir_dialect_library(AMDGCNDialect MLIRIR MLIRInferTypeOpInterface MLIRPtrDialect + MLIRTransformDialectInterfaces + MLIRTransformDialectUtils ) diff --git a/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt b/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt index 4e39717c3..7c1fd264c 100644 --- a/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt +++ b/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt @@ -38,7 +38,6 @@ add_mlir_library(AMDGCNTransforms AMDGCNAnalysis AMDGCNDialect AsterUtilsDialect - MLIRNormalFormDialect ASTERAnalysis AsterTransforms AsterInterfaces @@ -49,6 +48,8 @@ add_mlir_library(AMDGCNTransforms MLIRPass MLIRSCFDialect MLIRSCFUtils + MLIRTransformDialectInterfaces + MLIRTransformDialectUtils MLIRTransforms MLIRUBDialect ) diff --git a/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp b/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp index 7ff314f2d..a68e21d87 100644 --- a/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp +++ b/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp @@ -18,7 +18,8 @@ #include "aster/Dialect/AMDGCN/Transforms/Passes.h" #include "aster/Dialect/AsterUtils/IR/AsterUtilsAttrs.h" #include "aster/Dialect/AsterUtils/IR/AsterUtilsDialect.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" namespace mlir::aster { namespace amdgcn { @@ -43,8 +44,7 @@ struct LowLevelSchedulerPass MLIRContext *ctx = kernel.getContext(); auto allInlined = AllInlinedAttr::get(ctx); - if (failed(normalform::verifyNormalForm(kernel, allInlined, - /*emitDiagnostics=*/true))) + if (failed(allInlined.checkOperation(kernel).checkAndReport())) return signalPassFailure(); GenericSchedulerAttr compositeAttr = GenericSchedulerAttr::get( diff --git a/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp b/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp index 1a768dd31..032c5d9be 100644 --- a/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp +++ b/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp @@ -11,8 +11,8 @@ #include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" #include "aster/Dialect/AMDGCN/Transforms/Passes.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" namespace mlir::aster { namespace amdgcn { @@ -35,15 +35,15 @@ struct SetNormalForms private: /// Parse a list of mnemonic strings into NormalFormAttrInterface attributes. /// Returns failure if any mnemonic is invalid. - FailureOr> + FailureOr> parseFormMnemonics(ArrayRef mnemonics); }; } // namespace -FailureOr> +FailureOr> SetNormalForms::parseFormMnemonics(ArrayRef mnemonics) { MLIRContext *ctx = &getContext(); - SmallVector attrs; + SmallVector attrs; for (const std::string &mnemonic : mnemonics) { std::string attrStr = "#amdgcn." + mnemonic; Attribute attr = mlir::parseAttribute(attrStr, ctx); @@ -53,7 +53,7 @@ SetNormalForms::parseFormMnemonics(ArrayRef mnemonics) { << "(tried to parse as '" << attrStr << "')"; return failure(); } - auto nfAttr = dyn_cast(attr); + auto nfAttr = dyn_cast(attr); if (!nfAttr) { getOperation()->emitError() << "attribute '" << attrStr diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 68bb9f07d..407936f00 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -2,4 +2,3 @@ add_subdirectory(AMDGCN) add_subdirectory(AsterUtils) add_subdirectory(Layout) add_subdirectory(LSIR) -add_subdirectory(NormalForm) diff --git a/lib/Dialect/NormalForm/CMakeLists.txt b/lib/Dialect/NormalForm/CMakeLists.txt deleted file mode 100644 index 9f57627c3..000000000 --- a/lib/Dialect/NormalForm/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/lib/Dialect/NormalForm/IR/CMakeLists.txt b/lib/Dialect/NormalForm/IR/CMakeLists.txt deleted file mode 100644 index 7ed5ba119..000000000 --- a/lib/Dialect/NormalForm/IR/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -add_mlir_dialect_library(MLIRNormalFormDialect - NormalFormDialect.cpp - NormalFormOps.cpp - NormalFormInterfaces.cpp - - DEPENDS - MLIRNormalFormIncGen - - LINK_LIBS PUBLIC - MLIRIR -) - -# Install the dialect library so Python can find it at runtime. -install(TARGETS MLIRNormalFormDialect - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib - RUNTIME DESTINATION bin -) diff --git a/lib/Dialect/NormalForm/IR/NormalFormDialect.cpp b/lib/Dialect/NormalForm/IR/NormalFormDialect.cpp deleted file mode 100644 index a8fe40fb5..000000000 --- a/lib/Dialect/NormalForm/IR/NormalFormDialect.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//===- NormalFormDialect.cpp - dialect definition -------------------------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" - -using namespace mlir; - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.cpp.inc" - -void normalform::NormalFormDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "aster/Dialect/NormalForm/IR/NormalFormOps.cpp.inc" - >(); -} diff --git a/lib/Dialect/NormalForm/IR/NormalFormInterfaces.cpp b/lib/Dialect/NormalForm/IR/NormalFormInterfaces.cpp deleted file mode 100644 index 1cd00493e..000000000 --- a/lib/Dialect/NormalForm/IR/NormalFormInterfaces.cpp +++ /dev/null @@ -1,16 +0,0 @@ -//===- NormalFormInterfaces.cpp -------------------------------------------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" - -using namespace mlir; - -#include "aster/Dialect/NormalForm/IR/NormalFormAttrInterfaces.cpp.inc" diff --git a/lib/Dialect/NormalForm/IR/NormalFormOps.cpp b/lib/Dialect/NormalForm/IR/NormalFormOps.cpp deleted file mode 100644 index 522044412..000000000 --- a/lib/Dialect/NormalForm/IR/NormalFormOps.cpp +++ /dev/null @@ -1,250 +0,0 @@ -//===- NormalFormOps.cpp --------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/WalkResult.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/LogicalResult.h" - -using namespace mlir; -using namespace normalform; - -#define GET_OP_CLASSES -#include "aster/Dialect/NormalForm/IR/NormalFormOps.cpp.inc" - -//----------------------------------------------------------------------------- -// ModuleOp -//----------------------------------------------------------------------------- - -void normalform::ModuleOp::build(OpBuilder &builder, OperationState &state, - ArrayRef normalForms, - std::optional name) { - state.addRegion()->emplaceBlock(); - ArrayRef attributeArray = - ArrayRef(normalForms.begin(), normalForms.end()); - ArrayAttr normalFormsArray = builder.getArrayAttr(attributeArray); - // Use the tablegen-generated attribute name "normal_forms". - state.addAttribute(getNormalFormsAttrName(state.name), normalFormsArray); - if (name) { - state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), - builder.getStringAttr(*name)); - } -} - -/// Construct a module from the given context. -normalform::ModuleOp -normalform::ModuleOp::create(Location loc, - ArrayRef normalForms, - std::optional name) { - OpBuilder builder(loc->getContext()); - return ModuleOp::create(builder, loc, normalForms, name); -} - -LogicalResult normalform::verifyNormalForm( - Operation *root, NormalFormAttrInterface normalForm, bool emitDiagnostics, - const DenseSet *excludeAttrNames) { - SmallPtrSet seenTypes; - SmallPtrSet seenAttrs; - Location loc = root->getLoc(); - AttrTypeWalker walker; - - auto emitLocError = [&]() { - InFlightDiagnostic diag = mlir::emitError(loc); - if (!emitDiagnostics) - diag.abandon(); - - return diag; - }; - - auto visitType = [&](Type type) { - auto [it, inserted] = seenTypes.insert(type); - if (!inserted) - return WalkResult::skip(); - - if (llvm::failed(normalForm.verifyType(emitLocError, type))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }; - - auto visitAttr = [&](Attribute attr) { - auto [it, inserted] = seenAttrs.insert(attr); - if (!inserted) - return WalkResult::skip(); - - if (llvm::failed(normalForm.verifyAttribute(emitLocError, attr))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }; - - walker.addWalk(visitType); - walker.addWalk(visitAttr); - - auto visitOp = [&](Operation *op) { - loc = op->getLoc(); - - // TODO: skip when we reach another normalform.module which has normalform - // in its attributes - if (llvm::failed(normalForm.verifyOperation(emitLocError, op))) - return WalkResult::interrupt(); - - for (OpResult result : op->getResults()) { - loc = result.getLoc(); - WalkResult walkResult = walker.walk(result.getType()); - if (walkResult.wasInterrupted()) - return WalkResult::interrupt(); - } - - for (mlir::Region ®ion : op->getRegions()) { - for (mlir::Block &block : region) { - for (mlir::BlockArgument arg : block.getArguments()) { - loc = arg.getLoc(); - WalkResult walkResult = walker.walk(arg.getType()); - if (walkResult.wasInterrupted()) - return WalkResult::interrupt(); - } - } - } - - for (NamedAttribute attr : op->getAttrs()) { - if (excludeAttrNames && excludeAttrNames->contains(attr.getName())) - continue; - WalkResult walkResult = walker.walk(attr.getValue()); - if (walkResult.wasInterrupted()) - return WalkResult::interrupt(); - } - - return WalkResult::advance(); - }; - - WalkResult walkResult = root->walk(visitOp); - - return llvm::failure(walkResult.wasInterrupted()); -} - -LogicalResult -normalform::ModuleOp::verifyNormalForm(NormalFormAttrInterface normalForm, - bool emitDiagnostics) { - return normalform::verifyNormalForm(getOperation(), normalForm, - emitDiagnostics); -} - -bool normalform::ModuleOp::inferNormalForms( - ArrayRef normalForms) { - ArrayRef currentNormalForms = getNormalFormsAttr().getValue(); - SetVector normalFormSet; - normalFormSet.insert_range(currentNormalForms); - - bool changed = false; - for (NormalFormAttrInterface nf : normalForms) { - if (normalFormSet.contains(nf)) - continue; - - if (llvm::succeeded(verifyNormalForm(nf, /*emitDiagnostics*/ false))) { - normalFormSet.insert(nf); - changed = true; - } - } - - if (!changed) - return false; - - OpBuilder builder(getContext()); - ArrayAttr newNormalFormsAttr = - builder.getArrayAttr(normalFormSet.getArrayRef()); - setNormalFormsAttr(newNormalFormsAttr); - return true; -} - -bool normalform::ModuleOp::addNormalForms( - ArrayRef normalForms) { - if (normalForms.empty()) - return false; - - ArrayRef currentNormalForms = getNormalFormsAttr().getValue(); - SetVector normalFormSet; - normalFormSet.insert_range(currentNormalForms); - - bool changed = false; - for (NormalFormAttrInterface nf : normalForms) - changed |= normalFormSet.insert(nf); - - if (!changed) - return false; - - OpBuilder builder(getContext()); - ArrayAttr newNormalFormsAttr = - builder.getArrayAttr(normalFormSet.getArrayRef()); - setNormalFormsAttr(newNormalFormsAttr); - return true; -} - -bool normalform::ModuleOp::removeNormalForms( - ArrayRef normalForms) { - if (normalForms.empty()) - return false; - - ArrayRef currentNormalForms = getNormalFormsAttr().getValue(); - SetVector normalFormSet; - normalFormSet.insert_range(currentNormalForms); - - bool changed = false; - for (NormalFormAttrInterface nf : normalForms) - changed |= normalFormSet.remove(nf); - - if (!changed) - return false; - - OpBuilder builder(getContext()); - ArrayAttr newNormalFormsAttr = - builder.getArrayAttr(normalFormSet.getArrayRef()); - setNormalFormsAttr(newNormalFormsAttr); - return true; -} - -LogicalResult normalform::ModuleOp::verify() { - // Verify that normal form attributes are unique (set semantics). - ArrayAttr normalForms = getNormalFormsAttr(); - SmallPtrSet seenForms; - for (Attribute attr : normalForms) { - auto [it, inserted] = seenForms.insert(attr); - if (!inserted) - return emitOpError() << "contains duplicate normal form attribute: " - << attr; - } - return llvm::success(); -} - -LogicalResult normalform::ModuleOp::verifyRegions() { - ArrayRef normalFormsAttrs = getNormalForms().getValue(); - auto normalFormRange = - llvm::map_range(normalFormsAttrs, llvm::CastTo); - - for (NormalFormAttrInterface normalForm : normalFormRange) { - if (llvm::failed(verifyNormalForm(normalForm, /*emitDiagnostics*/ true))) { - return llvm::failure(); - } - } - return llvm::success(); -} diff --git a/lib/Dialect/NormalForm/Transforms/CMakeLists.txt b/lib/Dialect/NormalForm/Transforms/CMakeLists.txt deleted file mode 100644 index 663f5a8ff..000000000 --- a/lib/Dialect/NormalForm/Transforms/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -add_mlir_dialect_library(MLIRNormalFormTransforms - LowerNormalFormModule.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/aster/ - - DEPENDS - MLIRAsterNormalFormPassesIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRNormalFormDialect - MLIRParser - MLIRPass - MLIRRewrite - MLIRTransformUtils -) diff --git a/lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp b/lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp deleted file mode 100644 index 01311689b..000000000 --- a/lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp +++ /dev/null @@ -1,102 +0,0 @@ -//===- LowerNormalFormModule.cpp ------------------------------------------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h" -#include "aster/Dialect/NormalForm/Transforms/Passes.h" -#include "mlir/Transforms/WalkPatternRewriteDriver.h" - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" -#include - -#define GEN_PASS_DEF_LOWERNORMALFORMMODULEPASS -#include "aster/Dialect/NormalForm/Transforms/Passes.h.inc" - -using namespace mlir; - -namespace { - -//===----------------------------------------------------------------------===// -// LowerNormalFormModulePattern -//===----------------------------------------------------------------------===// - -/// Lower `normalform.module` to `builtin.module`, discarding normal form -/// attributes. -class LowerNormalFormModulePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(normalform::ModuleOp nfModule, - PatternRewriter &rewriter) const override { - // Check if parent is a builtin module - if so, inline contents into parent. - if (auto parentModule = dyn_cast(nfModule->getParentOp())) { - rewriter.setInsertionPoint(nfModule); - for (Operation &op : llvm::make_early_inc_range(*nfModule.getBody())) - rewriter.moveOpBefore(&op, nfModule); - rewriter.eraseOp(nfModule); - return success(); - } - - // Otherwise, create a new builtin module. - ModuleOp builtinModule = - ModuleOp::create(rewriter, nfModule.getLoc(), nfModule.getName()); - - // Move all blocks from the normalform module to the builtin module. - rewriter.inlineRegionBefore(nfModule.getRegion(), builtinModule.getBody()); - - // Remove the empty terminator block that was automatically added by the - // builder. - rewriter.eraseBlock(&builtinModule.getBodyRegion().back()); - - rewriter.eraseOp(nfModule); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// LowerNormalFormModulePass -//===----------------------------------------------------------------------===// - -struct LowerNormalFormModulePass - : public ::impl::LowerNormalFormModulePassBase { - using LowerNormalFormModulePassBase::LowerNormalFormModulePassBase; - - void runOnOperation() override { - Operation *root = getOperation(); - - if (auto rootModule = dyn_cast(root)) { - int64_t count = llvm::count_if(rootModule.getBody()->getOperations(), - llvm::IsaPred); - - if (count > 1) { - rootModule.emitError() - << "expected at most one top-level " - << normalform::ModuleOp::getOperationName() << ", found " << count; - return signalPassFailure(); - } - } - - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - - walkAndApplyPatterns(getOperation(), std::move(patterns)); - } -}; - -} // namespace - -std::unique_ptr normalform::createLowerNormalFormModulePass() { - return std::make_unique(); -} diff --git a/lib/Init.cpp b/lib/Init.cpp index 2f3cc9d72..189108899 100644 --- a/lib/Init.cpp +++ b/lib/Init.cpp @@ -17,8 +17,6 @@ #include "aster/Dialect/LSIR/IR/LSIRDialect.h" #include "aster/Dialect/Layout/IR/LayoutDialect.h" #include "aster/Dialect/Layout/Transforms/Passes.h" -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" -#include "aster/Dialect/NormalForm/Transforms/Passes.h" #include "aster/Interfaces/UpstreamExternalModels.h" #include "aster/Transforms/Passes.h" #include "mlir/CAPI/IR.h" @@ -411,7 +409,6 @@ void mlir::aster::initDialects(DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); - registry.insert(); registerUpstreamExternalModels(registry); if (contribRegisterFn) contribRegisterFn(registry); @@ -424,7 +421,6 @@ void mlir::aster::registerPasses() { aster::registerAsterPasses(); aster::registerCodeGenPasses(); layout::registerLayoutPasses(); - normalform::registerNormalFormPasses(); } /// diff --git a/llvm/LLVM_COMMIT b/llvm/LLVM_COMMIT index c4e2952b0..f3e340346 100644 --- a/llvm/LLVM_COMMIT +++ b/llvm/LLVM_COMMIT @@ -1 +1 @@ -3bbb7cc6de0100c101604a9e69aae9ef345368f2 +9deb1c631b11230787f0fb56583b17f060b194a0 diff --git a/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir b/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir index 4257bbe84..745246eea 100644 --- a/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir +++ b/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir @@ -19,4 +19,17 @@ amdgcn.module @test target = #amdgcn.target isa = #amdgcn.isa { ^bb0: amdgcn.end_kernel } + + // The kernel's `arguments` attribute carries ABI metadata whose register + // types are not subject to the no_value_semantic_registers check. + // CHECK: kernel @by_val_arg_metadata + // CHECK-SAME: attributes {normal_forms = [#amdgcn.no_value_semantic_registers] + amdgcn.kernel @by_val_arg_metadata arguments <[ + #amdgcn.by_val_arg + ]> attributes { + normal_forms = [#amdgcn.no_value_semantic_registers], + shared_memory_size = 0 : i32 + } { + amdgcn.end_kernel + } } diff --git a/test/Dialect/NormalForm/amdgcn-no-value-semantic-registers.mlir b/test/Dialect/NormalForm/amdgcn-no-value-semantic-registers.mlir deleted file mode 100644 index 60b6b4235..000000000 --- a/test/Dialect/NormalForm/amdgcn-no-value-semantic-registers.mlir +++ /dev/null @@ -1,52 +0,0 @@ -// RUN: aster-opt %s --split-input-file --verify-diagnostics - -normalform.module @value_vgpr [#amdgcn.no_value_semantic_registers] { - func.func @f() { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - %0 = amdgcn.alloca : !amdgcn.vgpr - return - } -} - -// ----- - -normalform.module @value_sgpr [#amdgcn.no_value_semantic_registers] { - func.func @f() { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - %0 = amdgcn.alloca : !amdgcn.sgpr - return - } -} - -// ----- - -normalform.module @value_in_func_arg [#amdgcn.no_value_semantic_registers] { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - func.func @f(%arg: !amdgcn.vgpr) { - return - } -} - -// ----- - -normalform.module @value_in_func_result [#amdgcn.no_value_semantic_registers] { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - func.func @f() -> !amdgcn.vgpr { - %0 = amdgcn.alloca : !amdgcn.vgpr - return %0 : !amdgcn.vgpr - } -} - -// ----- - -// by_val_arg type is ABI metadata, not a value-semantic register in the body -amdgcn.module @by_val_arg_metadata target = #amdgcn.target isa = #amdgcn.isa { - amdgcn.kernel @test arguments <[ - #amdgcn.by_val_arg - ]> attributes { - normal_forms = [#amdgcn.no_value_semantic_registers], - shared_memory_size = 0 : i32 - } { - amdgcn.end_kernel - } -} diff --git a/test/Dialect/NormalForm/amdgcn-to-register-semantics-postcondition.mlir b/test/Dialect/NormalForm/amdgcn-to-register-semantics-postcondition.mlir deleted file mode 100644 index 6f6233eba..000000000 --- a/test/Dialect/NormalForm/amdgcn-to-register-semantics-postcondition.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// Demonstrates that #amdgcn.no_value_semantic_registers is a post-condition -// of the amdgcn-to-register-semantics pass. -// -// Step 1: Run the pass on IR with value-semantic registers. -// Step 2: Wrap the output in a normalform.module and re-verify. -// -// RUN: aster-opt --amdgcn-to-register-semantics %s \ -// RUN: | sed '1s/^module/normalform.module [#amdgcn.no_value_semantic_registers]/' \ -// RUN: | aster-opt --verify-diagnostics - -func.func @value_to_unallocated() { - %0 = amdgcn.alloca : !amdgcn.vgpr - %1 = amdgcn.alloca : !amdgcn.vgpr - %2 = lsir.copy %1, %0 : !amdgcn.vgpr, !amdgcn.vgpr - %3 = amdgcn.test_inst outs %0 : (!amdgcn.vgpr) -> !amdgcn.vgpr - amdgcn.test_inst ins %3, %2 : (!amdgcn.vgpr, !amdgcn.vgpr) -> () - func.return -} - -func.func @mixed_types(%arg: i32) -> f32 { - %0 = amdgcn.alloca : !amdgcn.vgpr - %1 = amdgcn.alloca : !amdgcn.sgpr - %cst = arith.constant 0.0 : f32 - return %cst : f32 -} diff --git a/test/Dialect/NormalForm/lower-normalform-module-invalid.mlir b/test/Dialect/NormalForm/lower-normalform-module-invalid.mlir deleted file mode 100644 index 24d06367d..000000000 --- a/test/Dialect/NormalForm/lower-normalform-module-invalid.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: aster-opt %s -lower-normalform-module --split-input-file --verify-diagnostics - -//----------------------------------------------------------------------------- -// Test that multiple top-level normalform.module operations are rejected. -//----------------------------------------------------------------------------- - -// expected-error @below {{expected at most one top-level normalform.module, found 2}} -module { - normalform.module [] { - func.func @foo() { - return - } - } - normalform.module [] { - func.func @bar() { - return - } - } -} diff --git a/test/Dialect/NormalForm/lower-normalform-module.mlir b/test/Dialect/NormalForm/lower-normalform-module.mlir deleted file mode 100644 index f0b3c8bdd..000000000 --- a/test/Dialect/NormalForm/lower-normalform-module.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: aster-opt %s -lower-normalform-module --mlir-print-local-scope --split-input-file | FileCheck %s - -//----------------------------------------------------------------------------- -// Test lowering of normalform.module to builtin.module. -//----------------------------------------------------------------------------- - -// Test that a top-level normalform.module is inlined into the root module. -// CHECK: module { -// CHECK-NOT: normalform.module -// CHECK-NOT: module { -// CHECK: func.func @inlined_into_root() -// CHECK: } -normalform.module [] { - func.func @inlined_into_root() { - return - } -} - -// ----- - -// Test that a named normalform.module is inlined into the root module. -// CHECK: module { -// CHECK-NOT: normalform.module -// CHECK-NOT: module { -// CHECK: func.func @from_named_module() -// CHECK: } -normalform.module @named [] { - func.func @from_named_module() { - return - } -} - -// ----- - -// Test that multiple operations are preserved when inlining. -// CHECK: module { -// CHECK: func.func @first() -// CHECK: func.func @second() -// CHECK: func.func @third() -// CHECK: } -normalform.module [] { - func.func @first() { - return - } - func.func @second() { - return - } - func.func @third() { - return - } -} diff --git a/test/Dialect/NormalForm/ops-invalid.mlir b/test/Dialect/NormalForm/ops-invalid.mlir deleted file mode 100644 index 308f635c3..000000000 --- a/test/Dialect/NormalForm/ops-invalid.mlir +++ /dev/null @@ -1,134 +0,0 @@ -// RUN: aster-opt %s --split-input-file --verify-diagnostics - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (single attribute). -//----------------------------------------------------------------------------- - -// no_index_types: index function argument. -normalform.module @no_index_types_arg [#aster_test.no_index_types] { - // expected-error @below {{normal form prohibits index types}} - func.func @f(%arg: index) { - return - } -} - -// ----- - -// no_index_types: index result type. -normalform.module @no_index_types_result [#aster_test.no_index_types] { - // expected-error @below {{normal form prohibits index types}} - func.func @f() -> index { - %0 = arith.constant 0 : index - return %0 : index - } -} - -// ----- - -// no_index_types: index-typed block argument in nested region (scf.for iter_arg). -// This tests that block arguments in nested regions are verified for type -// constraints. The error is triggered by the arith.constant result type, but -// the scf.for iter_arg and induction variable block arguments are also checked. -normalform.module @no_index_types_nested_block_arg [#aster_test.no_index_types] { - func.func @f() { - // expected-error @below {{normal form prohibits index types}} - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - %result = scf.for %iv = %c0 to %c10 step %c1 iter_args(%iter = %c0) -> (index) { - scf.yield %iter : index - } - return - } -} - -// ----- - -// no_invalid_ops: division operation. -normalform.module @no_invalid_ops [#aster_test.no_invalid_ops] { - func.func @f(%a: f32, %b: f32) -> f32 { - // expected-error @below {{normal form prohibits division operations}} - %0 = arith.divf %a, %b : f32 - return %0 : f32 - } -} - -// ----- - -// no_invalid_attrs: string attribute with value "invalid". -normalform.module @no_invalid_attrs [#aster_test.no_invalid_attrs] { - // expected-error @below {{normal form prohibits 'invalid' string attribute values}} - func.func @f() attributes {foo = "invalid"} { - return - } -} - -// ----- - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (multiple attributes). -//----------------------------------------------------------------------------- - -// Multiple test attributes: violation on no_invalid_ops. -normalform.module @multi_attrs_invalid_op [#aster_test.no_index_types, #aster_test.no_invalid_ops] { - func.func @f(%a: i32, %b: i32) -> i32 { - // expected-error @below {{normal form prohibits division operations}} - %0 = arith.divsi %a, %b : i32 - return %0 : i32 - } -} - -// ----- - -// Multiple test attributes: only no_invalid_attrs violation present. -normalform.module @multi_attrs_invalid_attr [#aster_test.no_index_types, #aster_test.no_invalid_attrs] { - // expected-error @below {{normal form prohibits 'invalid' string attribute values}} - func.func @f() attributes {x = "invalid"} { - return - } -} - -// ----- - -// Multiple test attributes: only no_index_types violation present. -normalform.module @multi_attrs_index_type [#aster_test.no_invalid_ops, #aster_test.no_index_types] { - // expected-error @below {{normal form prohibits index types}} - func.func @f(%arg: index) { - return - } -} - -// ----- - -// All three attributes: violation on no_invalid_ops. -normalform.module @all_attrs_invalid_op [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] { - func.func @f(%a: i32, %b: i32) -> i32 { - // expected-error @below {{normal form prohibits division operations}} - %0 = arith.divui %a, %b : i32 - return %0 : i32 - } -} - -// ----- - -// All three attributes: violation on no_invalid_attrs only. -normalform.module @all_attrs_invalid_attr [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] { - // expected-error @below {{normal form prohibits 'invalid' string attribute values}} - func.func @f() attributes {x = "invalid"} { - return - } -} - -// ----- - -//----------------------------------------------------------------------------- -// Duplicate attribute rejection. -//----------------------------------------------------------------------------- - -// Duplicate normal form attributes are rejected. -// expected-error @below {{contains duplicate normal form attribute}} -normalform.module @duplicate [#aster_test.no_index_types, #aster_test.no_index_types] { - func.func @f() { - return - } -} diff --git a/test/Dialect/NormalForm/ops.mlir b/test/Dialect/NormalForm/ops.mlir deleted file mode 100644 index 421472d99..000000000 --- a/test/Dialect/NormalForm/ops.mlir +++ /dev/null @@ -1,89 +0,0 @@ -// RUN: aster-opt %s | aster-opt | FileCheck %s -// RUN: aster-opt %s --mlir-print-op-generic | aster-opt | FileCheck %s - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (single attribute). -//----------------------------------------------------------------------------- - -// no_index_types passes when no index types are used. -// CHECK-LABEL: normalform.module @no_index_types_valid -// CHECK-SAME: [#aster_test.no_index_types] -normalform.module @no_index_types_valid [#aster_test.no_index_types] { - func.func @f(%arg: i32) -> f32 { - %cst = arith.constant 0.0 : f32 - return %cst : f32 - } -} - -// no_invalid_ops passes when no division operations are present. -// CHECK-LABEL: normalform.module @no_invalid_ops_valid -// CHECK-SAME: [#aster_test.no_invalid_ops] -normalform.module @no_invalid_ops_valid [#aster_test.no_invalid_ops] { - func.func @f(%a: f32, %b: f32) -> f32 { - %0 = arith.mulf %a, %b : f32 - return %0 : f32 - } -} - -// no_invalid_attrs passes when no "invalid" string attributes are present. -// CHECK-LABEL: normalform.module @no_invalid_attrs_valid -// CHECK-SAME: [#aster_test.no_invalid_attrs] -normalform.module @no_invalid_attrs_valid [#aster_test.no_invalid_attrs] { - func.func @f() attributes {foo = "valid", bar = 42 : i32} { - return - } -} - -// no_invalid_ops passes with scf.for block arguments (iter_args and induction -// variable). This tests that block arguments in nested regions are properly -// walked and verified. -// CHECK-LABEL: normalform.module @no_invalid_ops_nested_block_args -// CHECK-SAME: [#aster_test.no_invalid_ops] -normalform.module @no_invalid_ops_nested_block_args [#aster_test.no_invalid_ops] { - func.func @f() -> f32 { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - %init = arith.constant 0.0 : f32 - %result = scf.for %iv = %c0 to %c10 step %c1 iter_args(%acc = %init) -> (f32) { - %one = arith.constant 1.0 : f32 - %sum = arith.addf %acc, %one : f32 - scf.yield %sum : f32 - } - return %result : f32 - } -} - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (multiple attributes). -//----------------------------------------------------------------------------- - -// Two attributes pass with valid IR. -// CHECK-LABEL: normalform.module @two_attrs_valid -// CHECK-SAME: [#aster_test.no_index_types, #aster_test.no_invalid_ops] -normalform.module @two_attrs_valid [#aster_test.no_index_types, #aster_test.no_invalid_ops] { - func.func @f(%arg: i32) { - return - } -} - -// All three attributes pass with valid IR. -// CHECK-LABEL: normalform.module @all_attrs_valid -// CHECK-SAME: [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] -normalform.module @all_attrs_valid [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] { - func.func @f(%arg: i32) attributes {foo = "valid"} { - return - } -} - -//----------------------------------------------------------------------------- -// Module without name. -//----------------------------------------------------------------------------- - -// Anonymous module with single attribute. -// CHECK-LABEL: normalform.module [#aster_test.no_invalid_ops] -normalform.module [#aster_test.no_invalid_ops] { - func.func @f() { - return - } -} diff --git a/test/lib/CMakeLists.txt b/test/lib/CMakeLists.txt index 63807d95d..396350b01 100644 --- a/test/lib/CMakeLists.txt +++ b/test/lib/CMakeLists.txt @@ -1,2 +1 @@ -add_subdirectory(Dialect) add_subdirectory(Pass) diff --git a/test/lib/Dialect/AsterTestDialect.cpp b/test/lib/Dialect/AsterTestDialect.cpp deleted file mode 100644 index 931221507..000000000 --- a/test/lib/Dialect/AsterTestDialect.cpp +++ /dev/null @@ -1,88 +0,0 @@ -//===- AsterTestDialect.cpp - test dialect --------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "AsterTestDialect.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" - -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace mlir; - -#include "AsterTestDialect.cpp.inc" - -// Test normal form attribute. -#define GET_ATTRDEF_CLASSES -#include "TestNormalFormAttr.cpp.inc" - -void mlir::aster::test::AsterTestDialect::initialize() { - addAttributes< -#define GET_ATTRDEF_LIST -#include "TestNormalFormAttr.cpp.inc" - >(); -}; - -namespace mlir::aster::test { -void registerAsterTestDialect(DialectRegistry ®istry) { - registry.insert(); -} -} // namespace mlir::aster::test - -using namespace mlir::aster::test; - -//----------------------------------------------------------------------------- -// NoIndexTypesAttr interface implementations. -//----------------------------------------------------------------------------- - -llvm::LogicalResult -NoIndexTypesAttr::verifyType(llvm::function_ref emitError, - Type type) const { - if (!type) - return llvm::success(); - - if (type.isIndex()) - return emitError() << "normal form prohibits index types"; - - return llvm::success(); -} - -//----------------------------------------------------------------------------- -// NoInvalidOpsAttr interface implementations. -//----------------------------------------------------------------------------- - -llvm::LogicalResult NoInvalidOpsAttr::verifyOperation( - llvm::function_ref emitError, Operation *op) const { - if (isa(op)) - return emitError() << "normal form prohibits division operations"; - - return llvm::success(); -} - -//----------------------------------------------------------------------------- -// NoInvalidAttrsAttr interface implementations. -//----------------------------------------------------------------------------- - -llvm::LogicalResult NoInvalidAttrsAttr::verifyAttribute( - llvm::function_ref emitError, Attribute attr) const { - if (!attr) - return llvm::success(); - - if (auto strAttr = llvm::dyn_cast(attr)) { - if (strAttr.getValue() == "invalid") - return emitError() - << "normal form prohibits 'invalid' string attribute values"; - } - - return llvm::success(); -} diff --git a/test/lib/Dialect/AsterTestDialect.h b/test/lib/Dialect/AsterTestDialect.h deleted file mode 100644 index 94938ad8c..000000000 --- a/test/lib/Dialect/AsterTestDialect.h +++ /dev/null @@ -1,25 +0,0 @@ -//===- AsterTestDialect.h - test dialect ----------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT_H -#define ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT_H - -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/OpDefinition.h" - -#include "AsterTestDialect.h.inc" -#include "mlir/IR/Dialect.h" - -// Test normal form attribute. -#define GET_ATTRDEF_CLASSES -#include "TestNormalFormAttr.h.inc" - -#endif // ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT_H diff --git a/test/lib/Dialect/AsterTestDialect.td b/test/lib/Dialect/AsterTestDialect.td deleted file mode 100644 index 4a2899e31..000000000 --- a/test/lib/Dialect/AsterTestDialect.td +++ /dev/null @@ -1,24 +0,0 @@ -//===- AsterTestDialect.td - test dialect ---------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -include "mlir/IR/DialectBase.td" -include "mlir/IR/OpBase.td" - -#ifndef ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT -#define ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT - -def AsterTestDialect : Dialect { - let name = "aster_test"; - let summary = "Dialect for testing Aster"; - let cppNamespace = "::mlir::aster::test"; - let useDefaultAttributePrinterParser = 1; -} - -#endif // ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT diff --git a/test/lib/Dialect/CMakeLists.txt b/test/lib/Dialect/CMakeLists.txt deleted file mode 100644 index a5508b8e0..000000000 --- a/test/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS AsterTestDialect.td) -mlir_tablegen(AsterTestDialect.h.inc -gen-dialect-decls -dialect=aster_test) -mlir_tablegen(AsterTestDialect.cpp.inc -gen-dialect-defs -dialect=aster_test) -mlir_tablegen(AsterTestDialectOps.h.inc -gen-op-decls) -mlir_tablegen(AsterTestDialectOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRAsterTestDialectIncGen) - -set(LLVM_TARGET_DEFINITIONS TestNormalFormAttr.td) -mlir_tablegen(TestNormalFormAttr.h.inc -gen-attrdef-decls -attrdefs-dialect=aster_test) -mlir_tablegen(TestNormalFormAttr.cpp.inc -gen-attrdef-defs -attrdefs-dialect=aster_test) -add_public_tablegen_target(MLIRTestNormalFormAttrIncGen) - -add_mlir_dialect_library(MLIRAsterTestDialect - AsterTestDialect.cpp - - EXCLUDE_FROM_LIBMLIR - - DEPENDS - MLIRTestNormalFormAttrIncGen - - LINK_LIBS - MLIRArithDialect - MLIRIR - MLIRNormalFormDialect -) - -target_include_directories(MLIRAsterTestDialect - PRIVATE - ${PROJECT_BINARY_DIR}/test/lib/Dialect -) diff --git a/test/lib/Dialect/TestNormalFormAttr.td b/test/lib/Dialect/TestNormalFormAttr.td deleted file mode 100644 index 46700bc6f..000000000 --- a/test/lib/Dialect/TestNormalFormAttr.td +++ /dev/null @@ -1,54 +0,0 @@ -//===- TestNormalFormAttrs.td - definitions for NormalForm tests ----------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_TEST_LIB_DIALECT_TESTNORMALFORMATTR -#define ASTER_TEST_LIB_DIALECT_TESTNORMALFORMATTR - -include "mlir/IR/DialectBase.td" -include "mlir/IR/AttrTypeBase.td" -include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.td" -include "AsterTestDialect.td" - -//----------------------------------------------------------------------------- -// Test normal form attributes for testing NormalFormAttrInterface. -// Each attribute tests a single invariant. -//----------------------------------------------------------------------------- - -def NoIndexTypesAttr : AttrDef]> { - let mnemonic = "no_index_types"; - let summary = "Normal form attribute that prohibits index types"; - let description = [{ - A test attribute implementing NormalFormAttrInterface that verifies - no index types are used in operations, block arguments, or results. - }]; -} - -def NoInvalidOpsAttr : AttrDef]> { - let mnemonic = "no_invalid_ops"; - let summary = "Normal form attribute that prohibits division operations"; - let description = [{ - A test attribute implementing NormalFormAttrInterface that verifies - no division operations (arith.divf, arith.divsi, arith.divui) are present. - }]; -} - -def NoInvalidAttrsAttr : AttrDef]> { - let mnemonic = "no_invalid_attrs"; - let summary = "Normal form attribute that prohibits 'invalid' string values"; - let description = [{ - A test attribute implementing NormalFormAttrInterface that verifies - no string attributes have the value "invalid". - }]; -} - -#endif // ASTER_TEST_LIB_DIALECT_TESTNORMALFORMATTR diff --git a/tools/aster-opt/CMakeLists.txt b/tools/aster-opt/CMakeLists.txt index cc0c28d11..4bcc63c0a 100644 --- a/tools/aster-opt/CMakeLists.txt +++ b/tools/aster-opt/CMakeLists.txt @@ -2,7 +2,6 @@ set(LIBS AMDGCNTransforms ASTERInit ASTERTestPass - MLIRAsterTestDialect MLIROptLib ) diff --git a/tools/aster-opt/aster-opt.cpp b/tools/aster-opt/aster-opt.cpp index 43ab1fb7d..6e2bb75f0 100644 --- a/tools/aster-opt/aster-opt.cpp +++ b/tools/aster-opt/aster-opt.cpp @@ -22,10 +22,6 @@ namespace mlir::aster { void registerTestPasses(); } // namespace mlir::aster -namespace mlir::aster::test { -void registerAsterTestDialect(DialectRegistry ®istry); -} // namespace mlir::aster::test - using namespace llvm; using namespace mlir; @@ -41,7 +37,6 @@ int main(int argc, char **argv) { aster::initDialects(registry); aster::registerPasses(); aster::registerTestPasses(); - aster::test::registerAsterTestDialect(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "aster modular optimizer driver\n", registry)); }