diff --git a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h index 0667572ce..b4133cc13 100644 --- a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h +++ b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h @@ -21,10 +21,11 @@ #include "aster/Dialect/AMDGCN/IR/Hazards.h" #include "aster/Dialect/AMDGCN/IR/Interfaces/KernelArgInterface.h" #include "aster/Dialect/AMDGCN/IR/Sched.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" #include "aster/Interfaces/MemorySpaceConstraints.h" #include "aster/Interfaces/SchedInterfaces.h" #include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" #include "mlir/IR/Attributes.h" namespace mlir { diff --git a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td index c8219f0a8..93bf3239d 100644 --- a/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td +++ b/include/aster/Dialect/AMDGCN/IR/AMDGCNAttrs.td @@ -20,10 +20,27 @@ include "aster/Dialect/AMDGCN/IR/AMDGCNEnums.td" include "aster/Dialect/AMDGCN/IR/Interfaces/KernelArgInterface.td" include "aster/Dialect/AMDGCN/IR/Interfaces/HazardAttrInterface.td" include "aster/Dialect/AMDGCN/IR/Hazards.td" -include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.td" include "aster/Interfaces/MemorySpaceConstraints.td" include "aster/Interfaces/SchedInterfaces.td" include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.td" +include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/CommonAttrConstraints.td" + +//===----------------------------------------------------------------------===// +// Normal form attribute array constraints. +//===----------------------------------------------------------------------===// + +/// An array of attributes implementing `mlir::transform::NormalFormAttrInterface`. +def NormalFormAttrArray : TypedArrayAttrBase; + +/// Optional normal form array: omitted when empty, prints in attr-dict. +def OptionalNormalFormAttrArray + : DefaultValuedAttr()"> { + let constBuilderCall = "$_builder.getArrayAttr($0)"; + let isOptional = true; +} //===----------------------------------------------------------------------===// // AddressSpaceAttr @@ -372,7 +389,7 @@ def KernelArgumentsAttr def NoValueSemanticRegistersAttr : AMDGCN_Attr< "NoValueSemanticRegisters", "no_value_semantic_registers", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no register types with value semantics"; let description = [{ Verifies that no register types in the IR have value semantics. @@ -386,7 +403,7 @@ def NoValueSemanticRegistersAttr : AMDGCN_Attr< def AllRegistersAllocatedAttr : AMDGCN_Attr< "AllRegistersAllocated", "all_registers_allocated", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: all register types have allocated semantics"; let description = [{ Verifies that every register type in the IR has allocated semantics @@ -400,7 +417,7 @@ def AllRegistersAllocatedAttr : AMDGCN_Attr< def NoRegCastOpsAttr : AMDGCN_Attr< "NoRegCastOps", "no_reg_cast_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no lsir.reg_cast operations"; let description = [{ Verifies that no `lsir.reg_cast` operations exist in the IR. @@ -414,7 +431,7 @@ def NoRegCastOpsAttr : AMDGCN_Attr< def NoCfBranchesAttr : AMDGCN_Attr< "NoCfBranches", "no_cf_branches", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no CF dialect branch operations"; let description = [{ Verifies that no `cf.br` or `cf.cond_br` operations exist in the IR. @@ -426,7 +443,7 @@ def NoCfBranchesAttr : AMDGCN_Attr< def NoScfOpsAttr : AMDGCN_Attr< "NoScfOps", "no_scf_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no SCF dialect operations"; let description = [{ Verifies that no SCF dialect operations (`scf.for`, `scf.if`, `scf.while`, @@ -438,7 +455,7 @@ def NoScfOpsAttr : AMDGCN_Attr< def NoLsirOpsAttr : AMDGCN_Attr< "NoLsirOps", "no_lsir_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no LSIR dialect operations"; let description = [{ Verifies that no LSIR dialect operations exist in the IR. @@ -450,7 +467,7 @@ def NoLsirOpsAttr : AMDGCN_Attr< def NoLsirComputeOpsAttr : AMDGCN_Attr< "NoLsirComputeOps", "no_lsir_compute_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no LSIR compute/memory operations"; let description = [{ Verifies that no LSIR arithmetic, memory, or utility operations exist in @@ -469,7 +486,7 @@ def NoLsirComputeOpsAttr : AMDGCN_Attr< def NoLsirControlOpsAttr : AMDGCN_Attr< "NoLsirControlOps", "no_lsir_control_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no LSIR control-flow operations"; let description = [{ Verifies that no LSIR control-flow-related operations (`lsir.cmpi`, @@ -485,7 +502,7 @@ def NoLsirControlOpsAttr : AMDGCN_Attr< def NoRegisterBlockArgsAttr : AMDGCN_Attr< "NoRegisterBlockArgs", "no_register_block_args", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no block arguments with register types"; let description = [{ Verifies that no block arguments have register types @@ -497,7 +514,7 @@ def NoRegisterBlockArgsAttr : AMDGCN_Attr< def NoAffineOpsAttr : AMDGCN_Attr< "NoAffineOps", "no_affine_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no affine dialect operations"; let description = [{ Verifies that no affine dialect operations (`affine.apply`, `affine.for`, @@ -509,7 +526,7 @@ def NoAffineOpsAttr : AMDGCN_Attr< def NoMetadataOpsAttr : AMDGCN_Attr< "NoMetadataOps", "no_metadata_ops", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: no AMDGCN metadata operations"; let description = [{ Verifies that no AMDGCN metadata operations remain in the IR. @@ -523,7 +540,7 @@ def NoMetadataOpsAttr : AMDGCN_Attr< def AllInlinedAttr : AMDGCN_Attr< "AllInlined", "all_inlined", - [DeclareAttrInterfaceMethods]> { + [DeclareAttrInterfaceMethods]> { let summary = "Normal form: all function calls inlined"; let description = [{ Verifies that no `func.call` operations exist in the IR. diff --git a/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td b/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td index b769d5de3..03acf008b 100644 --- a/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td +++ b/include/aster/Dialect/AMDGCN/IR/AMDGCNOps.td @@ -445,11 +445,11 @@ def AMDGCN_KernelOp : AMDGCN_Op<"kernel", [ /// Adds normal form attributes. Returns true if the op was changed. bool addNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); /// Removes normal form attributes. Returns true if the op was changed. bool removeNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); }]; } @@ -572,11 +572,11 @@ def AMDGCN_ModuleOp : AMDGCN_Op<"module", [ /// Adds normal form attributes. Returns true if the op was changed. bool addNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); /// Removes normal form attributes. Returns true if the op was changed. bool removeNormalForms( - ::llvm::ArrayRef<::normalform::NormalFormAttrInterface> normalForms); + ::llvm::ArrayRef<::mlir::transform::NormalFormAttrInterface> normalForms); }]; } diff --git a/include/aster/Dialect/CMakeLists.txt b/include/aster/Dialect/CMakeLists.txt index 68bb9f07d..407936f00 100644 --- a/include/aster/Dialect/CMakeLists.txt +++ b/include/aster/Dialect/CMakeLists.txt @@ -2,4 +2,3 @@ add_subdirectory(AMDGCN) add_subdirectory(AsterUtils) add_subdirectory(Layout) add_subdirectory(LSIR) -add_subdirectory(NormalForm) diff --git a/include/aster/Dialect/NormalForm/CMakeLists.txt b/include/aster/Dialect/NormalForm/CMakeLists.txt deleted file mode 100644 index 9f57627c3..000000000 --- a/include/aster/Dialect/NormalForm/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/include/aster/Dialect/NormalForm/IR/CMakeLists.txt b/include/aster/Dialect/NormalForm/IR/CMakeLists.txt deleted file mode 100644 index 8d55f3851..000000000 --- a/include/aster/Dialect/NormalForm/IR/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_custom_target(MLIRNormalFormIncGen) - -set(LLVM_TARGET_DEFINITIONS NormalFormDialect.td) -mlir_tablegen(NormalFormDialect.h.inc -gen-dialect-decls -dialect=normalform) -mlir_tablegen(NormalFormDialect.cpp.inc -gen-dialect-defs -dialect=normalform) -add_public_tablegen_target(MLIRNormalFormDialectIncGen) - -set(LLVM_TARGET_DEFINITIONS NormalFormOps.td) -mlir_tablegen(NormalFormOps.h.inc -gen-op-decls) -mlir_tablegen(NormalFormOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRNormalFormOpsIncGen) -add_dependencies(MLIRNormalFormIncGen MLIRNormalFormOpsIncGen) - -set(LLVM_TARGET_DEFINITIONS NormalFormInterfaces.td) -mlir_tablegen(NormalFormAttrInterfaces.h.inc -gen-attr-interface-decls) -mlir_tablegen(NormalFormAttrInterfaces.cpp.inc -gen-attr-interface-defs) -add_public_tablegen_target(MLIRNormalFormInterfacesIncGen) -add_dependencies(MLIRNormalFormIncGen MLIRNormalFormInterfacesIncGen) -add_mlir_doc(NormalFormInterfaces NormalFormAttrInterfaces Dialects/ -gen-attr-interface-docs) - -add_dependencies(mlir-headers MLIRNormalFormIncGen) - -add_mlir_doc(NormalForm NormalForm Dialects/ -gen-dialect-doc -dialect normalform) diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.h b/include/aster/Dialect/NormalForm/IR/NormalFormDialect.h deleted file mode 100644 index 7b0841bc3..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.h +++ /dev/null @@ -1,18 +0,0 @@ -//===- NormalFormDialect.h ------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_IR_NORMALFORMDIALECT_H -#define ASTER_DIALECT_NORMALFORM_IR_NORMALFORMDIALECT_H - -#include "mlir/IR/Dialect.h" - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h.inc" - -#endif // ASTER_DIALECT_NORMALFORM_IR_NORMALFORMDIALECT_H diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.td b/include/aster/Dialect/NormalForm/IR/NormalFormDialect.td deleted file mode 100644 index 588bd8bb0..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormDialect.td +++ /dev/null @@ -1,28 +0,0 @@ -//===- NormalFormDialect.td -----------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_NORMALFORMDIALECT -#define ASTER_DIALECT_NORMALFORM_NORMALFORMDIALECT - -include "mlir/IR/DialectBase.td" - -def NormalFormDialect : Dialect { - let name = "normalform"; - let summary = "Dialect for normal form IR verification"; - - let description = [{ - The NormalForm dialect provides infrastructure for defining and verifying - normal forms of IR. It uses the `normalform.module` operation as a - container that specifies the expected normal forms through an array of - attributes implementing the NormalFormAttrInterface. - }]; -} - -#endif // ASTER_DIALECT_NORMALFORM_NORMALFORMDIALECT diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.h b/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.h deleted file mode 100644 index 2ef4d5214..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.h +++ /dev/null @@ -1,48 +0,0 @@ -//===- NormalFormInterfaces.h ---------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_IR_NORMALFORMINTERFACES_H -#define ASTER_DIALECT_NORMALFORM_IR_NORMALFORMINTERFACES_H - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/DenseSet.h" - -#include "aster/Dialect/NormalForm/IR/NormalFormAttrInterfaces.h.inc" - -namespace normalform { - -/// Verify that all IR nested under `root` satisfies the given normal form. -/// If `emitDiagnostics` is true, errors are reported; otherwise the check -/// is silent (useful for inferNormalForms). -/// `excludeAttrNames` optionally specifies named attributes whose nested -/// types should be skipped during verification (e.g., kernel argument -/// attributes that contain ABI metadata types, not computational types). -::llvm::LogicalResult verifyNormalForm( - ::mlir::Operation *root, NormalFormAttrInterface normalForm, - bool emitDiagnostics, - const ::llvm::DenseSet<::mlir::StringAttr> *excludeAttrNames = nullptr); - -} // namespace normalform - -namespace llvm { -template <> -struct PointerLikeTypeTraits - : public PointerLikeTypeTraits { - static inline normalform::NormalFormAttrInterface - getFromVoidPointer(void *p) { - return normalform::NormalFormAttrInterface( - mlir::Attribute::getFromOpaquePointer(p)); - } -}; -} // namespace llvm - -#endif // ASTER_DIALECT_NORMALFORM_IR_NORMALFORMINTERFACES_H diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.td b/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.td deleted file mode 100644 index 9ed21e7a9..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormInterfaces.td +++ /dev/null @@ -1,92 +0,0 @@ -//===- NormalFormInterfaces.td --------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_NORMALFORMINTERFACES -#define ASTER_DIALECT_NORMALFORM_NORMALFORMINTERFACES - -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/OpBase.td" -include "mlir/IR/CommonAttrConstraints.td" - -//===----------------------------------------------------------------------===// -// Normal form attribute interface. -//===----------------------------------------------------------------------===// -def NormalFormAttrInterface : AttrInterface<"NormalFormAttrInterface"> { - let description = [{ - Interface for attributes that define normal form constraints on IR. - - A normal form attribute specifies invariants that types, attributes, and - operations must satisfy. When attached to a `normalform.module`, all - contained IR is verified against the constraints defined by the attribute. - - Implementers should override one or more of `verifyType`, `verifyAttribute`, - or `verifyOperation` to define custom verification logic. The default - implementations return success. - }]; - let cppNamespace = "::normalform"; - let methods = [ - InterfaceMethod< - /*desc=*/[{ - Verify the type using this normal form attribute. - }], - /*returnType=*/"::llvm::LogicalResult", - /*name=*/"verifyType", - /*arguments=*/(ins "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError, "::mlir::Type":$type), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::success(); - }] - >, - - InterfaceMethod< - /*desc=*/[{ - Verify the attribute using this normal form attribute. - }], - /*returnType=*/"::llvm::LogicalResult", - /*name=*/"verifyAttribute", - /*arguments=*/(ins "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError, "::mlir::Attribute":$attr), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::success(); - }] - >, - - InterfaceMethod< - /*desc=*/[{ - Verify the operation using this normal form attribute. - }], - /*returnType=*/"::llvm::LogicalResult", - /*name=*/"verifyOperation", - /*arguments=*/(ins "::llvm::function_ref<::mlir::InFlightDiagnostic()>":$emitError, "::mlir::Operation*":$op), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return llvm::success(); - }] - >, - ]; -} - -//===----------------------------------------------------------------------===// -// Normal form attribute array constraints. -//===----------------------------------------------------------------------===// - -/// An array of attributes implementing NormalFormAttrInterface. -def NormalFormAttrArray : TypedArrayAttrBase; - -/// Optional normal form array: omitted when empty, prints in attr-dict. -def OptionalNormalFormAttrArray - : DefaultValuedAttr()"> { - let constBuilderCall = "$_builder.getArrayAttr($0)"; - let isOptional = true; -} - - -#endif // ASTER_DIALECT_NORMALFORM_NORMALFORMINTERFACES diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormOps.h b/include/aster/Dialect/NormalForm/IR/NormalFormOps.h deleted file mode 100644 index 24d5dde27..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormOps.h +++ /dev/null @@ -1,25 +0,0 @@ -//===- NormalFormOps.h ----------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_IR_NORMALFORMOPS_H -#define ASTER_DIALECT_NORMALFORM_IR_NORMALFORMOPS_H - -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" - -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/RegionKindInterface.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/ControlFlowInterfaces.h" - -#define GET_OP_CLASSES -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h.inc" - -#endif // ASTER_DIALECT_NORMALFORM_IR_NORMALFORMOPS_H diff --git a/include/aster/Dialect/NormalForm/IR/NormalFormOps.td b/include/aster/Dialect/NormalForm/IR/NormalFormOps.td deleted file mode 100644 index d1f1c71a2..000000000 --- a/include/aster/Dialect/NormalForm/IR/NormalFormOps.td +++ /dev/null @@ -1,124 +0,0 @@ -//===- NormalFormOps.td ---------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -include "aster/Dialect/NormalForm/IR/NormalFormDialect.td" -include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.td" -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/IR/CommonAttrConstraints.td" -include "mlir/IR/CommonTypeConstraints.td" -include "mlir/IR/BuiltinAttributeInterfaces.td" -include "mlir/IR/RegionKindInterface.td" -include "mlir/IR/SymbolInterfaces.td" - - - -#ifndef ASTER_DIALECT_NORMALFORM_NORMALFORMOPS -#define ASTER_DIALECT_NORMALFORM_NORMALFORMOPS - -//----------------------------------------------------------------------------- -// Base class for all NormalForm operations. -//----------------------------------------------------------------------------- - -class NormalFormOp traits = []> : - Op; - -//----------------------------------------------------------------------------- -// Structure Ops -//----------------------------------------------------------------------------- - -def ModuleOp : NormalFormOp<"module", [ - IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol, - ] # GraphRegionNoTerminator.traits> { - let summary = "A container that enforces IR invariants specified by normal form attributes."; - let description = [{ - The `normalform.module` operation defines a scoped region whose contents must - satisfy the constraints specified by the attached normal form attributes. Each - attribute must implement `NormalFormAttrInterface`, which provides hooks for - verifying types, attributes, and operations within the module. - - During region verification, the module walks all nested operations and - invokes the interface methods (`verifyType`, `verifyAttribute`, - `verifyOperation`) on each element for all normal form attributes. - Normal forms are verified in the order given by the normal form array attribute, - walking over each operation in pre-order while also visiting its types, followed by - its attributes. - - Example: - - ```mlir - // Enforce that all tensor types are fully specified. - normalform.module @my_kernel [#wave.normal_form] { - func.func @compute(%arg: !wave.tensor<[64, 128] of f32>) { - return - } - } - - // Multiple normal form attributes from different dialects. - normalform.module @validated [#wave.normal_form, #other.normal_form] { - func.func @compute(%arg: !wave.tensor<[64, 128] of f32>) { - return - } - } - ``` - }]; - let arguments = (ins NormalFormAttrArray:$normal_forms, OptionalAttr:$sym_name); - let regions = (region SizedRegion<1>:$bodyRegion); - let assemblyFormat = [{ - ($sym_name^)? $normal_forms attr-dict-with-keyword $bodyRegion - }]; - - let builders = [ - OpBuilder<(ins - CArg<"::llvm::ArrayRef">:$normal_forms, - CArg<"std::optional<::llvm::StringRef>", "{}">:$name)> - ]; - - let extraClassDeclaration = [{ - /// Construct a module from the given location with an optional name. - static ModuleOp create(::mlir::Location loc, - ::llvm::ArrayRef normalForms, - std::optional<::llvm::StringRef> name = {}); - - /// Return the name of this module if present. - std::optional<::llvm::StringRef> getName() { return getSymName(); } - - /// Checks whether the normal forms passed in `normalForms` apply to this - /// module and attaches them to the module if true. Returns true if the module was changed. - bool inferNormalForms(::llvm::ArrayRef normalForms); - - /// Checks wheter a given normal form applies to this module. - ::llvm::LogicalResult - verifyNormalForm(NormalFormAttrInterface normalForm, bool emitDiagnostics); - - /// Adds normal form attributes to the module. Returns true if the module was changed. - bool addNormalForms(::llvm::ArrayRef normalForms); - - /// Removes normal form attributes from the module. Returns true if the module was changed. - bool removeNormalForms(::llvm::ArrayRef normalForms); - - //===------------------------------------------------------------------===// - // SymbolOpInterface Methods - //===------------------------------------------------------------------===// - - /// A ModuleOp may optionally define a symbol. - bool isOptionalSymbol() const { return true; } - }]; - - let hasVerifier = 1; - let hasRegionVerifier = 1; - - // We need to ensure the block inside the region is properly terminated; - // the auto-generated builders do not guarantee that. - let skipDefaultBuilders = 1; -} - - -#endif // ASTER_DIALECT_NORMALFORM_NORMALFORMOPS diff --git a/include/aster/Dialect/NormalForm/Transforms/CMakeLists.txt b/include/aster/Dialect/NormalForm/Transforms/CMakeLists.txt deleted file mode 100644 index 2594d3ce1..000000000 --- a/include/aster/Dialect/NormalForm/Transforms/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name NormalForm) -add_public_tablegen_target(MLIRAsterNormalFormPassesIncGen) diff --git a/include/aster/Dialect/NormalForm/Transforms/Passes.h b/include/aster/Dialect/NormalForm/Transforms/Passes.h deleted file mode 100644 index 8f3eb68d7..000000000 --- a/include/aster/Dialect/NormalForm/Transforms/Passes.h +++ /dev/null @@ -1,27 +0,0 @@ -//===- Passes.h - Pass entrypoints ----------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES_H -#define ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES_H - -#include "mlir/Pass/Pass.h" -#include - -namespace normalform { - -#define GEN_PASS_DECL -#include "aster/Dialect/NormalForm/Transforms/Passes.h.inc" - -#define GEN_PASS_REGISTRATION -#include "aster/Dialect/NormalForm/Transforms/Passes.h.inc" - -} // namespace normalform - -#endif // ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES_H diff --git a/include/aster/Dialect/NormalForm/Transforms/Passes.td b/include/aster/Dialect/NormalForm/Transforms/Passes.td deleted file mode 100644 index af07a9fff..000000000 --- a/include/aster/Dialect/NormalForm/Transforms/Passes.td +++ /dev/null @@ -1,26 +0,0 @@ -//===- Passes.td - Pass definitions ---------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES -#define ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES - -include "mlir/Pass/PassBase.td" - -def LowerNormalFormModulePass : Pass<"lower-normalform-module"> { - let summary = "Lower normalform.module to builtin.module"; - let description = [{ - This pass converts `normalform.module` operations to `builtin.module` - operations. The normal form attributes are not retained on the resulting - module. - }]; - let dependentDialects = []; -} - -#endif // ASTER_DIALECT_NORMALFORM_TRANSFORMS_PASSES diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 03b0d0f0d..07d75b9e6 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -22,8 +22,6 @@ add_mlir_library(ASTERInit LayoutDialect LayoutTransforms LSIRDialect - MLIRNormalFormDialect - MLIRNormalFormTransforms MLIRAffineDialect MLIRAffineTransforms MLIRArithDialect diff --git a/lib/Dialect/AMDGCN/IR/AMDGCN.cpp b/lib/Dialect/AMDGCN/IR/AMDGCN.cpp index 02014d8a4..05aa8c85f 100644 --- a/lib/Dialect/AMDGCN/IR/AMDGCN.cpp +++ b/lib/Dialect/AMDGCN/IR/AMDGCN.cpp @@ -16,7 +16,6 @@ #include "aster/Dialect/AMDGCN/IR/AMDGCNVerifiers.h" #include "aster/Dialect/AMDGCN/IR/Interfaces/AMDGCNInterfaces.h" #include "aster/Dialect/AMDGCN/IR/Utils.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" #include "aster/IR/ParsePrintUtils.h" #include "aster/Interfaces/RegisterType.h" #include "mlir/IR/BuiltinTypes.h" @@ -921,17 +920,17 @@ LogicalResult KernelOp::verify() { //===----------------------------------------------------------------------===// /// Verify all normal forms attached to an operation via its normal_forms attr. -/// `excludeAttrNames` optionally specifies named attributes whose nested -/// types should be skipped during verification. -static LogicalResult verifyNormalFormsRegions( - Operation *op, ArrayAttr normalFormsAttr, - const DenseSet *excludeAttrNames = nullptr) { +/// Each normal form's `checkOperation` method is responsible for walking the +/// IR (and excluding op-attribute payloads that should not participate in +/// normal-form type checking; see `getNormalFormTypeWalkExcludes` in +/// AMDGCNAttrs.cpp for the kernel-arguments exclusion). +static LogicalResult verifyNormalFormsRegions(Operation *op, + ArrayAttr normalFormsAttr) { if (!normalFormsAttr || normalFormsAttr.empty()) return success(); for (Attribute attr : normalFormsAttr) { - auto nf = cast(attr); - if (failed(normalform::verifyNormalForm(op, nf, /*emitDiagnostics=*/true, - excludeAttrNames))) + auto nf = cast(attr); + if (failed(nf.checkOperation(op).checkAndReport())) return failure(); } return success(); @@ -941,7 +940,7 @@ static LogicalResult verifyNormalFormsRegions( /// semantics. Returns true if the attribute was changed. static bool addNormalFormsImpl(Operation *op, StringRef attrName, ArrayAttr currentAttr, - ArrayRef nfs, + ArrayRef nfs, function_ref setter) { if (nfs.empty()) return false; @@ -951,7 +950,7 @@ addNormalFormsImpl(Operation *op, StringRef attrName, ArrayAttr currentAttr, nfSet.insert_range(currentAttr.getValue()); bool changed = false; - for (normalform::NormalFormAttrInterface nf : nfs) + for (mlir::transform::NormalFormAttrInterface nf : nfs) changed |= nfSet.insert(nf); if (!changed) @@ -966,7 +965,7 @@ addNormalFormsImpl(Operation *op, StringRef attrName, ArrayAttr currentAttr, /// Returns true if the attribute was changed. static bool removeNormalFormsImpl(Operation *op, ArrayAttr currentAttr, - ArrayRef nfs, + ArrayRef nfs, function_ref setter) { if (nfs.empty() || !currentAttr || currentAttr.empty()) return false; @@ -975,7 +974,7 @@ removeNormalFormsImpl(Operation *op, ArrayAttr currentAttr, nfSet.insert_range(currentAttr.getValue()); bool changed = false; - for (normalform::NormalFormAttrInterface nf : nfs) + for (mlir::transform::NormalFormAttrInterface nf : nfs) changed |= nfSet.remove(nf); if (!changed) @@ -995,14 +994,14 @@ LogicalResult amdgcn::ModuleOp::verifyRegions() { } bool amdgcn::ModuleOp::addNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return addNormalFormsImpl(getOperation(), getNormalFormsAttrName(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); } bool amdgcn::ModuleOp::removeNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return removeNormalFormsImpl( getOperation(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); @@ -1013,24 +1012,18 @@ bool amdgcn::ModuleOp::removeNormalForms( //===----------------------------------------------------------------------===// LogicalResult KernelOp::verifyRegions() { - // Exclude 'arguments' attribute from normal form type walking: kernel - // argument attrs (by_val_arg, buffer_arg) contain ABI metadata types, - // not computational register types in the kernel body. - DenseSet excludeAttrs; - excludeAttrs.insert(getArgumentsAttrName()); - return verifyNormalFormsRegions(getOperation(), getNormalFormsAttr(), - &excludeAttrs); + return verifyNormalFormsRegions(getOperation(), getNormalFormsAttr()); } bool KernelOp::addNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return addNormalFormsImpl(getOperation(), getNormalFormsAttrName(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); } bool KernelOp::removeNormalForms( - ArrayRef normalForms) { + ArrayRef normalForms) { return removeNormalFormsImpl( getOperation(), getNormalFormsAttr(), normalForms, [&](ArrayAttr attr) { setNormalFormsAttr(attr); }); diff --git a/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp b/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp index 05f1da1d2..47d953fb7 100644 --- a/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp +++ b/lib/Dialect/AMDGCN/IR/AMDGCNAttrs.cpp @@ -17,7 +17,13 @@ #include "aster/Interfaces/RegisterType.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" +#include "mlir/IR/AttrTypeSubElements.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/Support/WalkResult.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/ErrorHandling.h" @@ -225,198 +231,357 @@ int32_t GenericSchedLabelerAttr::getLabel(Operation *op, int32_t, } //===----------------------------------------------------------------------===// -// NoValueSemanticRegistersAttr +// Normal-form helpers //===----------------------------------------------------------------------===// -LogicalResult NoValueSemanticRegistersAttr::verifyType( - function_ref emitError, Type type) const { - auto regType = dyn_cast(type); - if (!regType) - return success(); - - if (regType.hasValueSemantics()) - return emitError() << "normal form violation: register types with value " - "semantics are disallowed but found: " - << type; +namespace { +/// Filter callback used by `walkTypes` to decide whether to descend into a +/// given `NamedAttribute` of an operation. Returning `false` skips the +/// attribute payload entirely. +using NamedAttrFilter = llvm::function_ref; + +/// Skips named attributes that carry ABI metadata whose register types are +/// not subject to normal-form invariants enforced on the kernel body. For +/// `amdgcn.kernel`, this excludes the `arguments` attribute (e.g. +/// `by_val_arg` parameter types). +bool skipKernelAbiMetadata(Operation *op, NamedAttribute attr) { + if (auto kernel = dyn_cast(op)) + return attr.getName() != kernel.getArgumentsAttrName(); + return true; +} - return success(); +/// Aggregates `DiagnosedSilenceableFailure` results across multiple visits: +/// records the first silenceable failure (silencing later ones) and short- +/// circuits on definite failures. +struct AttrTypeAggregator { + DiagnosedSilenceableFailure overall = DiagnosedSilenceableFailure::success(); + bool stop = false; + + void merge(DiagnosedSilenceableFailure &&result) { + if (result.isDefiniteFailure()) { + overall = std::move(result); + stop = true; + return; + } + if (result.isSilenceableFailure()) { + if (overall.succeeded()) + overall = std::move(result); + else + (void)result.silence(); + } + } +}; + +/// Walks all distinct types reachable from operations under `root`: operation +/// result types, block argument types, and types nested inside operation +/// attributes. Invokes `visitor` on each distinct type with a `Location` near +/// its discovery point. When `filter` is non-null, it is consulted for each +/// named attribute of every visited operation; returning `false` skips the +/// attribute payload (e.g. to exclude `amdgcn.kernel`'s `arguments` ABI +/// metadata). +DiagnosedSilenceableFailure walkTypes( + Operation *root, + llvm::function_ref visitor, + NamedAttrFilter filter = nullptr) { + AttrTypeAggregator agg; + llvm::SmallPtrSet seenTypes; + llvm::SmallPtrSet seenAttrs; + Location currentLoc = root->getLoc(); + AttrTypeWalker walker; + + walker.addWalk([&](Type type) { + auto [it, inserted] = seenTypes.insert(type); + if (!inserted) + return WalkResult::skip(); + agg.merge(visitor(type, currentLoc)); + if (agg.stop) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + walker.addWalk([&](Attribute attr) { + auto [it, inserted] = seenAttrs.insert(attr); + if (!inserted) + return WalkResult::skip(); + return WalkResult::advance(); + }); + + root->walk([&](Operation *op) { + currentLoc = op->getLoc(); + for (OpResult result : op->getResults()) { + currentLoc = result.getLoc(); + if (walker.walk(result.getType()).wasInterrupted()) + return WalkResult::interrupt(); + } + for (Region ®ion : op->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + currentLoc = arg.getLoc(); + if (walker.walk(arg.getType()).wasInterrupted()) + return WalkResult::interrupt(); + } + } + } + currentLoc = op->getLoc(); + for (NamedAttribute attr : op->getAttrs()) { + if (filter && !filter(op, attr)) + continue; + if (walker.walk(attr.getValue()).wasInterrupted()) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return std::move(agg.overall); } +} // namespace + //===----------------------------------------------------------------------===// -// AllRegistersAllocatedAttr +// NoValueSemanticRegistersAttr //===----------------------------------------------------------------------===// -LogicalResult AllRegistersAllocatedAttr::verifyType( - function_ref emitError, Type type) const { - auto regType = dyn_cast(type); - if (!regType) - return success(); +DiagnosedSilenceableFailure +NoValueSemanticRegistersAttr::checkOperation(Operation *op) const { + return walkTypes( + op, + [](Type type, Location loc) -> DiagnosedSilenceableFailure { + auto regType = dyn_cast(type); + if (!regType || !regType.hasValueSemantics()) + return DiagnosedSilenceableFailure::success(); + return emitSilenceableFailure(loc) + << "normal form violation: register types with value " + "semantics are disallowed but found: " + << type; + }, + skipKernelAbiMetadata); +} - if (!regType.hasAllocatedSemantics()) - return emitError() << "normal form violation: all registers must have " - "allocated semantics but found: " - << type; +//===----------------------------------------------------------------------===// +// AllRegistersAllocatedAttr +//===----------------------------------------------------------------------===// - return success(); +DiagnosedSilenceableFailure +AllRegistersAllocatedAttr::checkOperation(Operation *op) const { + return walkTypes( + op, + [](Type type, Location loc) -> DiagnosedSilenceableFailure { + auto regType = dyn_cast(type); + if (!regType || regType.hasAllocatedSemantics()) + return DiagnosedSilenceableFailure::success(); + return emitSilenceableFailure(loc) + << "normal form violation: all registers must have " + "allocated semantics but found: " + << type; + }, + skipKernelAbiMetadata); } //===----------------------------------------------------------------------===// // NoRegCastOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoRegCastOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: lsir.reg_cast should not " - "survive past aster-to-amdgcn; this indicates an " - "incorrect lsir.to_reg or lsir.from_reg surviving " - "from high-level (hand-authored ?) IR"; - return success(); +DiagnosedSilenceableFailure +NoRegCastOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: lsir.reg_cast should not " + "survive past aster-to-amdgcn; this indicates an " + "incorrect lsir.to_reg or lsir.from_reg surviving " + "from high-level (hand-authored ?) IR"); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoLsirOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoLsirOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (op->getDialect() && op->getDialect()->getNamespace() == "lsir") - return emitError() << "normal form violation: LSIR dialect operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure NoLsirOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "lsir") + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: LSIR dialect operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoLsirComputeOpsAttr //===----------------------------------------------------------------------===// -LogicalResult NoLsirComputeOpsAttr::verifyOperation( - function_ref emitError, Operation *op) const { - if (!op->getDialect() || op->getDialect()->getNamespace() != "lsir") - return success(); - - // Allow control-flow ops (lowered by LegalizeCF) and copy (regalloc - // primitive). - if (isa(op)) - return success(); - - return emitError() << "normal form violation: LSIR compute/memory " - "operations are disallowed but found: " - << op->getName(); +DiagnosedSilenceableFailure +NoLsirComputeOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "lsir") + return WalkResult::advance(); + // Allow control-flow ops (lowered by LegalizeCF) and copy (regalloc + // primitive). + if (isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: LSIR compute/memory " + "operations are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoLsirControlOpsAttr //===----------------------------------------------------------------------===// -LogicalResult NoLsirControlOpsAttr::verifyOperation( - function_ref emitError, Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: LSIR control-flow " - "operations are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoLsirControlOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: LSIR control-flow " + "operations are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoScfOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoScfOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (op->getDialect() && op->getDialect()->getNamespace() == "scf") - return emitError() << "normal form violation: SCF dialect operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure NoScfOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "scf") + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: SCF dialect operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoCfBranchesAttr //===----------------------------------------------------------------------===// -LogicalResult -NoCfBranchesAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: cf.br/cf.cond_br operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoCfBranchesAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: cf.br/cf.cond_br operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoRegisterBlockArgsAttr //===----------------------------------------------------------------------===// -LogicalResult NoRegisterBlockArgsAttr::verifyOperation( - function_ref emitError, Operation *op) const { - for (Region ®ion : op->getRegions()) { - for (Block &block : region) { - for (BlockArgument arg : block.getArguments()) { - if (isa(arg.getType())) - return emitError() - << "normal form violation: block arguments with register " - "types are disallowed but found: " - << arg.getType(); +DiagnosedSilenceableFailure +NoRegisterBlockArgsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + for (Region ®ion : innerOp->getRegions()) { + for (Block &block : region) { + for (BlockArgument arg : block.getArguments()) { + if (!isa(arg.getType())) + continue; + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: block arguments with " + "register types are disallowed but found: " + << arg.getType()); + if (agg.stop) + return WalkResult::interrupt(); + // Only report once per op to mirror the original behavior. + return WalkResult::advance(); + } } } - } - return success(); + return WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoAffineOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoAffineOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (op->getDialect() && op->getDialect()->getNamespace() == "affine") - return emitError() << "normal form violation: affine dialect operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoAffineOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!innerOp->getDialect() || + innerOp->getDialect()->getNamespace() != "affine") + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: affine dialect operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // NoMetadataOpsAttr //===----------------------------------------------------------------------===// -LogicalResult -NoMetadataOpsAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: AMDGCN metadata operations " - "are disallowed but found: " - << op->getName(); - - return success(); +DiagnosedSilenceableFailure +NoMetadataOpsAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + if (!isa(innerOp)) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: AMDGCN metadata operations " + "are disallowed but found: " + << innerOp->getName()); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } //===----------------------------------------------------------------------===// // AllInlinedAttr //===----------------------------------------------------------------------===// -LogicalResult -AllInlinedAttr::verifyOperation(function_ref emitError, - Operation *op) const { - if (isa(op)) - return emitError() << "normal form violation: func.call operations " - "are disallowed (all functions should be inlined) " - "but found call to '" - << cast(op).getCallee() << "'"; - - return success(); +DiagnosedSilenceableFailure +AllInlinedAttr::checkOperation(Operation *op) const { + AttrTypeAggregator agg; + op->walk([&](Operation *innerOp) { + auto callOp = dyn_cast(innerOp); + if (!callOp) + return WalkResult::advance(); + agg.merge(emitSilenceableFailure(innerOp) + << "normal form violation: func.call operations " + "are disallowed (all functions should be inlined) " + "but found call to '" + << callOp.getCallee() << "'"); + return agg.stop ? WalkResult::interrupt() : WalkResult::advance(); + }); + return std::move(agg.overall); } diff --git a/lib/Dialect/AMDGCN/IR/CMakeLists.txt b/lib/Dialect/AMDGCN/IR/CMakeLists.txt index 4443a22e4..8cf2c79e8 100644 --- a/lib/Dialect/AMDGCN/IR/CMakeLists.txt +++ b/lib/Dialect/AMDGCN/IR/CMakeLists.txt @@ -24,7 +24,6 @@ add_mlir_dialect_library(AMDGCNDialect ASTERIR AsterInterfaces LSIRDialect - MLIRNormalFormDialect MLIRArithDialect MLIRControlFlowDialect MLIRControlFlowInterfaces @@ -33,4 +32,6 @@ add_mlir_dialect_library(AMDGCNDialect MLIRIR MLIRInferTypeOpInterface MLIRPtrDialect + MLIRTransformDialectInterfaces + MLIRTransformDialectUtils ) diff --git a/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt b/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt index 4e39717c3..7c1fd264c 100644 --- a/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt +++ b/lib/Dialect/AMDGCN/Transforms/CMakeLists.txt @@ -38,7 +38,6 @@ add_mlir_library(AMDGCNTransforms AMDGCNAnalysis AMDGCNDialect AsterUtilsDialect - MLIRNormalFormDialect ASTERAnalysis AsterTransforms AsterInterfaces @@ -49,6 +48,8 @@ add_mlir_library(AMDGCNTransforms MLIRPass MLIRSCFDialect MLIRSCFUtils + MLIRTransformDialectInterfaces + MLIRTransformDialectUtils MLIRTransforms MLIRUBDialect ) diff --git a/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp b/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp index 7ff314f2d..a68e21d87 100644 --- a/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp +++ b/lib/Dialect/AMDGCN/Transforms/LowLevelScheduler.cpp @@ -18,7 +18,8 @@ #include "aster/Dialect/AMDGCN/Transforms/Passes.h" #include "aster/Dialect/AsterUtils/IR/AsterUtilsAttrs.h" #include "aster/Dialect/AsterUtils/IR/AsterUtilsDialect.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Utils/DiagnosedSilenceableFailure.h" namespace mlir::aster { namespace amdgcn { @@ -43,8 +44,7 @@ struct LowLevelSchedulerPass MLIRContext *ctx = kernel.getContext(); auto allInlined = AllInlinedAttr::get(ctx); - if (failed(normalform::verifyNormalForm(kernel, allInlined, - /*emitDiagnostics=*/true))) + if (failed(allInlined.checkOperation(kernel).checkAndReport())) return signalPassFailure(); GenericSchedulerAttr compositeAttr = GenericSchedulerAttr::get( diff --git a/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp b/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp index 1a768dd31..032c5d9be 100644 --- a/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp +++ b/lib/Dialect/AMDGCN/Transforms/SetNormalForms.cpp @@ -11,8 +11,8 @@ #include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" #include "aster/Dialect/AMDGCN/Transforms/Passes.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" namespace mlir::aster { namespace amdgcn { @@ -35,15 +35,15 @@ struct SetNormalForms private: /// Parse a list of mnemonic strings into NormalFormAttrInterface attributes. /// Returns failure if any mnemonic is invalid. - FailureOr> + FailureOr> parseFormMnemonics(ArrayRef mnemonics); }; } // namespace -FailureOr> +FailureOr> SetNormalForms::parseFormMnemonics(ArrayRef mnemonics) { MLIRContext *ctx = &getContext(); - SmallVector attrs; + SmallVector attrs; for (const std::string &mnemonic : mnemonics) { std::string attrStr = "#amdgcn." + mnemonic; Attribute attr = mlir::parseAttribute(attrStr, ctx); @@ -53,7 +53,7 @@ SetNormalForms::parseFormMnemonics(ArrayRef mnemonics) { << "(tried to parse as '" << attrStr << "')"; return failure(); } - auto nfAttr = dyn_cast(attr); + auto nfAttr = dyn_cast(attr); if (!nfAttr) { getOperation()->emitError() << "attribute '" << attrStr diff --git a/lib/Dialect/CMakeLists.txt b/lib/Dialect/CMakeLists.txt index 68bb9f07d..407936f00 100644 --- a/lib/Dialect/CMakeLists.txt +++ b/lib/Dialect/CMakeLists.txt @@ -2,4 +2,3 @@ add_subdirectory(AMDGCN) add_subdirectory(AsterUtils) add_subdirectory(Layout) add_subdirectory(LSIR) -add_subdirectory(NormalForm) diff --git a/lib/Dialect/NormalForm/CMakeLists.txt b/lib/Dialect/NormalForm/CMakeLists.txt deleted file mode 100644 index 9f57627c3..000000000 --- a/lib/Dialect/NormalForm/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_subdirectory(IR) -add_subdirectory(Transforms) diff --git a/lib/Dialect/NormalForm/IR/CMakeLists.txt b/lib/Dialect/NormalForm/IR/CMakeLists.txt deleted file mode 100644 index 7ed5ba119..000000000 --- a/lib/Dialect/NormalForm/IR/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -add_mlir_dialect_library(MLIRNormalFormDialect - NormalFormDialect.cpp - NormalFormOps.cpp - NormalFormInterfaces.cpp - - DEPENDS - MLIRNormalFormIncGen - - LINK_LIBS PUBLIC - MLIRIR -) - -# Install the dialect library so Python can find it at runtime. -install(TARGETS MLIRNormalFormDialect - LIBRARY DESTINATION lib - ARCHIVE DESTINATION lib - RUNTIME DESTINATION bin -) diff --git a/lib/Dialect/NormalForm/IR/NormalFormDialect.cpp b/lib/Dialect/NormalForm/IR/NormalFormDialect.cpp deleted file mode 100644 index a8fe40fb5..000000000 --- a/lib/Dialect/NormalForm/IR/NormalFormDialect.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//===- NormalFormDialect.cpp - dialect definition -------------------------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Dialect.h" - -using namespace mlir; - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.cpp.inc" - -void normalform::NormalFormDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "aster/Dialect/NormalForm/IR/NormalFormOps.cpp.inc" - >(); -} diff --git a/lib/Dialect/NormalForm/IR/NormalFormInterfaces.cpp b/lib/Dialect/NormalForm/IR/NormalFormInterfaces.cpp deleted file mode 100644 index 1cd00493e..000000000 --- a/lib/Dialect/NormalForm/IR/NormalFormInterfaces.cpp +++ /dev/null @@ -1,16 +0,0 @@ -//===- NormalFormInterfaces.cpp -------------------------------------------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" - -using namespace mlir; - -#include "aster/Dialect/NormalForm/IR/NormalFormAttrInterfaces.cpp.inc" diff --git a/lib/Dialect/NormalForm/IR/NormalFormOps.cpp b/lib/Dialect/NormalForm/IR/NormalFormOps.cpp deleted file mode 100644 index 522044412..000000000 --- a/lib/Dialect/NormalForm/IR/NormalFormOps.cpp +++ /dev/null @@ -1,250 +0,0 @@ -//===- NormalFormOps.cpp --------------------------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h" -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/Value.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/WalkResult.h" - -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Diagnostics.h" - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Casting.h" -#include "llvm/Support/LogicalResult.h" - -using namespace mlir; -using namespace normalform; - -#define GET_OP_CLASSES -#include "aster/Dialect/NormalForm/IR/NormalFormOps.cpp.inc" - -//----------------------------------------------------------------------------- -// ModuleOp -//----------------------------------------------------------------------------- - -void normalform::ModuleOp::build(OpBuilder &builder, OperationState &state, - ArrayRef normalForms, - std::optional name) { - state.addRegion()->emplaceBlock(); - ArrayRef attributeArray = - ArrayRef(normalForms.begin(), normalForms.end()); - ArrayAttr normalFormsArray = builder.getArrayAttr(attributeArray); - // Use the tablegen-generated attribute name "normal_forms". - state.addAttribute(getNormalFormsAttrName(state.name), normalFormsArray); - if (name) { - state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), - builder.getStringAttr(*name)); - } -} - -/// Construct a module from the given context. -normalform::ModuleOp -normalform::ModuleOp::create(Location loc, - ArrayRef normalForms, - std::optional name) { - OpBuilder builder(loc->getContext()); - return ModuleOp::create(builder, loc, normalForms, name); -} - -LogicalResult normalform::verifyNormalForm( - Operation *root, NormalFormAttrInterface normalForm, bool emitDiagnostics, - const DenseSet *excludeAttrNames) { - SmallPtrSet seenTypes; - SmallPtrSet seenAttrs; - Location loc = root->getLoc(); - AttrTypeWalker walker; - - auto emitLocError = [&]() { - InFlightDiagnostic diag = mlir::emitError(loc); - if (!emitDiagnostics) - diag.abandon(); - - return diag; - }; - - auto visitType = [&](Type type) { - auto [it, inserted] = seenTypes.insert(type); - if (!inserted) - return WalkResult::skip(); - - if (llvm::failed(normalForm.verifyType(emitLocError, type))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }; - - auto visitAttr = [&](Attribute attr) { - auto [it, inserted] = seenAttrs.insert(attr); - if (!inserted) - return WalkResult::skip(); - - if (llvm::failed(normalForm.verifyAttribute(emitLocError, attr))) - return WalkResult::interrupt(); - - return WalkResult::advance(); - }; - - walker.addWalk(visitType); - walker.addWalk(visitAttr); - - auto visitOp = [&](Operation *op) { - loc = op->getLoc(); - - // TODO: skip when we reach another normalform.module which has normalform - // in its attributes - if (llvm::failed(normalForm.verifyOperation(emitLocError, op))) - return WalkResult::interrupt(); - - for (OpResult result : op->getResults()) { - loc = result.getLoc(); - WalkResult walkResult = walker.walk(result.getType()); - if (walkResult.wasInterrupted()) - return WalkResult::interrupt(); - } - - for (mlir::Region ®ion : op->getRegions()) { - for (mlir::Block &block : region) { - for (mlir::BlockArgument arg : block.getArguments()) { - loc = arg.getLoc(); - WalkResult walkResult = walker.walk(arg.getType()); - if (walkResult.wasInterrupted()) - return WalkResult::interrupt(); - } - } - } - - for (NamedAttribute attr : op->getAttrs()) { - if (excludeAttrNames && excludeAttrNames->contains(attr.getName())) - continue; - WalkResult walkResult = walker.walk(attr.getValue()); - if (walkResult.wasInterrupted()) - return WalkResult::interrupt(); - } - - return WalkResult::advance(); - }; - - WalkResult walkResult = root->walk(visitOp); - - return llvm::failure(walkResult.wasInterrupted()); -} - -LogicalResult -normalform::ModuleOp::verifyNormalForm(NormalFormAttrInterface normalForm, - bool emitDiagnostics) { - return normalform::verifyNormalForm(getOperation(), normalForm, - emitDiagnostics); -} - -bool normalform::ModuleOp::inferNormalForms( - ArrayRef normalForms) { - ArrayRef currentNormalForms = getNormalFormsAttr().getValue(); - SetVector normalFormSet; - normalFormSet.insert_range(currentNormalForms); - - bool changed = false; - for (NormalFormAttrInterface nf : normalForms) { - if (normalFormSet.contains(nf)) - continue; - - if (llvm::succeeded(verifyNormalForm(nf, /*emitDiagnostics*/ false))) { - normalFormSet.insert(nf); - changed = true; - } - } - - if (!changed) - return false; - - OpBuilder builder(getContext()); - ArrayAttr newNormalFormsAttr = - builder.getArrayAttr(normalFormSet.getArrayRef()); - setNormalFormsAttr(newNormalFormsAttr); - return true; -} - -bool normalform::ModuleOp::addNormalForms( - ArrayRef normalForms) { - if (normalForms.empty()) - return false; - - ArrayRef currentNormalForms = getNormalFormsAttr().getValue(); - SetVector normalFormSet; - normalFormSet.insert_range(currentNormalForms); - - bool changed = false; - for (NormalFormAttrInterface nf : normalForms) - changed |= normalFormSet.insert(nf); - - if (!changed) - return false; - - OpBuilder builder(getContext()); - ArrayAttr newNormalFormsAttr = - builder.getArrayAttr(normalFormSet.getArrayRef()); - setNormalFormsAttr(newNormalFormsAttr); - return true; -} - -bool normalform::ModuleOp::removeNormalForms( - ArrayRef normalForms) { - if (normalForms.empty()) - return false; - - ArrayRef currentNormalForms = getNormalFormsAttr().getValue(); - SetVector normalFormSet; - normalFormSet.insert_range(currentNormalForms); - - bool changed = false; - for (NormalFormAttrInterface nf : normalForms) - changed |= normalFormSet.remove(nf); - - if (!changed) - return false; - - OpBuilder builder(getContext()); - ArrayAttr newNormalFormsAttr = - builder.getArrayAttr(normalFormSet.getArrayRef()); - setNormalFormsAttr(newNormalFormsAttr); - return true; -} - -LogicalResult normalform::ModuleOp::verify() { - // Verify that normal form attributes are unique (set semantics). - ArrayAttr normalForms = getNormalFormsAttr(); - SmallPtrSet seenForms; - for (Attribute attr : normalForms) { - auto [it, inserted] = seenForms.insert(attr); - if (!inserted) - return emitOpError() << "contains duplicate normal form attribute: " - << attr; - } - return llvm::success(); -} - -LogicalResult normalform::ModuleOp::verifyRegions() { - ArrayRef normalFormsAttrs = getNormalForms().getValue(); - auto normalFormRange = - llvm::map_range(normalFormsAttrs, llvm::CastTo); - - for (NormalFormAttrInterface normalForm : normalFormRange) { - if (llvm::failed(verifyNormalForm(normalForm, /*emitDiagnostics*/ true))) { - return llvm::failure(); - } - } - return llvm::success(); -} diff --git a/lib/Dialect/NormalForm/Transforms/CMakeLists.txt b/lib/Dialect/NormalForm/Transforms/CMakeLists.txt deleted file mode 100644 index 663f5a8ff..000000000 --- a/lib/Dialect/NormalForm/Transforms/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -add_mlir_dialect_library(MLIRNormalFormTransforms - LowerNormalFormModule.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/aster/ - - DEPENDS - MLIRAsterNormalFormPassesIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRNormalFormDialect - MLIRParser - MLIRPass - MLIRRewrite - MLIRTransformUtils -) diff --git a/lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp b/lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp deleted file mode 100644 index 01311689b..000000000 --- a/lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp +++ /dev/null @@ -1,102 +0,0 @@ -//===- LowerNormalFormModule.cpp ------------------------------------------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" -#include "aster/Dialect/NormalForm/IR/NormalFormOps.h" -#include "aster/Dialect/NormalForm/Transforms/Passes.h" -#include "mlir/Transforms/WalkPatternRewriteDriver.h" - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/STLExtras.h" -#include - -#define GEN_PASS_DEF_LOWERNORMALFORMMODULEPASS -#include "aster/Dialect/NormalForm/Transforms/Passes.h.inc" - -using namespace mlir; - -namespace { - -//===----------------------------------------------------------------------===// -// LowerNormalFormModulePattern -//===----------------------------------------------------------------------===// - -/// Lower `normalform.module` to `builtin.module`, discarding normal form -/// attributes. -class LowerNormalFormModulePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(normalform::ModuleOp nfModule, - PatternRewriter &rewriter) const override { - // Check if parent is a builtin module - if so, inline contents into parent. - if (auto parentModule = dyn_cast(nfModule->getParentOp())) { - rewriter.setInsertionPoint(nfModule); - for (Operation &op : llvm::make_early_inc_range(*nfModule.getBody())) - rewriter.moveOpBefore(&op, nfModule); - rewriter.eraseOp(nfModule); - return success(); - } - - // Otherwise, create a new builtin module. - ModuleOp builtinModule = - ModuleOp::create(rewriter, nfModule.getLoc(), nfModule.getName()); - - // Move all blocks from the normalform module to the builtin module. - rewriter.inlineRegionBefore(nfModule.getRegion(), builtinModule.getBody()); - - // Remove the empty terminator block that was automatically added by the - // builder. - rewriter.eraseBlock(&builtinModule.getBodyRegion().back()); - - rewriter.eraseOp(nfModule); - return success(); - } -}; - -//===----------------------------------------------------------------------===// -// LowerNormalFormModulePass -//===----------------------------------------------------------------------===// - -struct LowerNormalFormModulePass - : public ::impl::LowerNormalFormModulePassBase { - using LowerNormalFormModulePassBase::LowerNormalFormModulePassBase; - - void runOnOperation() override { - Operation *root = getOperation(); - - if (auto rootModule = dyn_cast(root)) { - int64_t count = llvm::count_if(rootModule.getBody()->getOperations(), - llvm::IsaPred); - - if (count > 1) { - rootModule.emitError() - << "expected at most one top-level " - << normalform::ModuleOp::getOperationName() << ", found " << count; - return signalPassFailure(); - } - } - - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - - walkAndApplyPatterns(getOperation(), std::move(patterns)); - } -}; - -} // namespace - -std::unique_ptr normalform::createLowerNormalFormModulePass() { - return std::make_unique(); -} diff --git a/lib/Init.cpp b/lib/Init.cpp index 2f3cc9d72..189108899 100644 --- a/lib/Init.cpp +++ b/lib/Init.cpp @@ -17,8 +17,6 @@ #include "aster/Dialect/LSIR/IR/LSIRDialect.h" #include "aster/Dialect/Layout/IR/LayoutDialect.h" #include "aster/Dialect/Layout/Transforms/Passes.h" -#include "aster/Dialect/NormalForm/IR/NormalFormDialect.h" -#include "aster/Dialect/NormalForm/Transforms/Passes.h" #include "aster/Interfaces/UpstreamExternalModels.h" #include "aster/Transforms/Passes.h" #include "mlir/CAPI/IR.h" @@ -411,7 +409,6 @@ void mlir::aster::initDialects(DialectRegistry ®istry) { registry.insert(); registry.insert(); registry.insert(); - registry.insert(); registerUpstreamExternalModels(registry); if (contribRegisterFn) contribRegisterFn(registry); @@ -424,7 +421,6 @@ void mlir::aster::registerPasses() { aster::registerAsterPasses(); aster::registerCodeGenPasses(); layout::registerLayoutPasses(); - normalform::registerNormalFormPasses(); } /// diff --git a/llvm/LLVM_COMMIT b/llvm/LLVM_COMMIT index c4e2952b0..f3e340346 100644 --- a/llvm/LLVM_COMMIT +++ b/llvm/LLVM_COMMIT @@ -1 +1 @@ -3bbb7cc6de0100c101604a9e69aae9ef345368f2 +9deb1c631b11230787f0fb56583b17f060b194a0 diff --git a/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir b/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir index 4257bbe84..745246eea 100644 --- a/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir +++ b/test/Dialect/AMDGCN/IR/normal-forms-kernel.mlir @@ -19,4 +19,17 @@ amdgcn.module @test target = #amdgcn.target isa = #amdgcn.isa { ^bb0: amdgcn.end_kernel } + + // The kernel's `arguments` attribute carries ABI metadata whose register + // types are not subject to the no_value_semantic_registers check. + // CHECK: kernel @by_val_arg_metadata + // CHECK-SAME: attributes {normal_forms = [#amdgcn.no_value_semantic_registers] + amdgcn.kernel @by_val_arg_metadata arguments <[ + #amdgcn.by_val_arg + ]> attributes { + normal_forms = [#amdgcn.no_value_semantic_registers], + shared_memory_size = 0 : i32 + } { + amdgcn.end_kernel + } } diff --git a/test/Dialect/NormalForm/amdgcn-no-value-semantic-registers.mlir b/test/Dialect/NormalForm/amdgcn-no-value-semantic-registers.mlir deleted file mode 100644 index 60b6b4235..000000000 --- a/test/Dialect/NormalForm/amdgcn-no-value-semantic-registers.mlir +++ /dev/null @@ -1,52 +0,0 @@ -// RUN: aster-opt %s --split-input-file --verify-diagnostics - -normalform.module @value_vgpr [#amdgcn.no_value_semantic_registers] { - func.func @f() { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - %0 = amdgcn.alloca : !amdgcn.vgpr - return - } -} - -// ----- - -normalform.module @value_sgpr [#amdgcn.no_value_semantic_registers] { - func.func @f() { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - %0 = amdgcn.alloca : !amdgcn.sgpr - return - } -} - -// ----- - -normalform.module @value_in_func_arg [#amdgcn.no_value_semantic_registers] { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - func.func @f(%arg: !amdgcn.vgpr) { - return - } -} - -// ----- - -normalform.module @value_in_func_result [#amdgcn.no_value_semantic_registers] { - // expected-error @below {{normal form violation: register types with value semantics are disallowed but found}} - func.func @f() -> !amdgcn.vgpr { - %0 = amdgcn.alloca : !amdgcn.vgpr - return %0 : !amdgcn.vgpr - } -} - -// ----- - -// by_val_arg type is ABI metadata, not a value-semantic register in the body -amdgcn.module @by_val_arg_metadata target = #amdgcn.target isa = #amdgcn.isa { - amdgcn.kernel @test arguments <[ - #amdgcn.by_val_arg - ]> attributes { - normal_forms = [#amdgcn.no_value_semantic_registers], - shared_memory_size = 0 : i32 - } { - amdgcn.end_kernel - } -} diff --git a/test/Dialect/NormalForm/amdgcn-to-register-semantics-postcondition.mlir b/test/Dialect/NormalForm/amdgcn-to-register-semantics-postcondition.mlir deleted file mode 100644 index 6f6233eba..000000000 --- a/test/Dialect/NormalForm/amdgcn-to-register-semantics-postcondition.mlir +++ /dev/null @@ -1,25 +0,0 @@ -// Demonstrates that #amdgcn.no_value_semantic_registers is a post-condition -// of the amdgcn-to-register-semantics pass. -// -// Step 1: Run the pass on IR with value-semantic registers. -// Step 2: Wrap the output in a normalform.module and re-verify. -// -// RUN: aster-opt --amdgcn-to-register-semantics %s \ -// RUN: | sed '1s/^module/normalform.module [#amdgcn.no_value_semantic_registers]/' \ -// RUN: | aster-opt --verify-diagnostics - -func.func @value_to_unallocated() { - %0 = amdgcn.alloca : !amdgcn.vgpr - %1 = amdgcn.alloca : !amdgcn.vgpr - %2 = lsir.copy %1, %0 : !amdgcn.vgpr, !amdgcn.vgpr - %3 = amdgcn.test_inst outs %0 : (!amdgcn.vgpr) -> !amdgcn.vgpr - amdgcn.test_inst ins %3, %2 : (!amdgcn.vgpr, !amdgcn.vgpr) -> () - func.return -} - -func.func @mixed_types(%arg: i32) -> f32 { - %0 = amdgcn.alloca : !amdgcn.vgpr - %1 = amdgcn.alloca : !amdgcn.sgpr - %cst = arith.constant 0.0 : f32 - return %cst : f32 -} diff --git a/test/Dialect/NormalForm/lower-normalform-module-invalid.mlir b/test/Dialect/NormalForm/lower-normalform-module-invalid.mlir deleted file mode 100644 index 24d06367d..000000000 --- a/test/Dialect/NormalForm/lower-normalform-module-invalid.mlir +++ /dev/null @@ -1,19 +0,0 @@ -// RUN: aster-opt %s -lower-normalform-module --split-input-file --verify-diagnostics - -//----------------------------------------------------------------------------- -// Test that multiple top-level normalform.module operations are rejected. -//----------------------------------------------------------------------------- - -// expected-error @below {{expected at most one top-level normalform.module, found 2}} -module { - normalform.module [] { - func.func @foo() { - return - } - } - normalform.module [] { - func.func @bar() { - return - } - } -} diff --git a/test/Dialect/NormalForm/lower-normalform-module.mlir b/test/Dialect/NormalForm/lower-normalform-module.mlir deleted file mode 100644 index f0b3c8bdd..000000000 --- a/test/Dialect/NormalForm/lower-normalform-module.mlir +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: aster-opt %s -lower-normalform-module --mlir-print-local-scope --split-input-file | FileCheck %s - -//----------------------------------------------------------------------------- -// Test lowering of normalform.module to builtin.module. -//----------------------------------------------------------------------------- - -// Test that a top-level normalform.module is inlined into the root module. -// CHECK: module { -// CHECK-NOT: normalform.module -// CHECK-NOT: module { -// CHECK: func.func @inlined_into_root() -// CHECK: } -normalform.module [] { - func.func @inlined_into_root() { - return - } -} - -// ----- - -// Test that a named normalform.module is inlined into the root module. -// CHECK: module { -// CHECK-NOT: normalform.module -// CHECK-NOT: module { -// CHECK: func.func @from_named_module() -// CHECK: } -normalform.module @named [] { - func.func @from_named_module() { - return - } -} - -// ----- - -// Test that multiple operations are preserved when inlining. -// CHECK: module { -// CHECK: func.func @first() -// CHECK: func.func @second() -// CHECK: func.func @third() -// CHECK: } -normalform.module [] { - func.func @first() { - return - } - func.func @second() { - return - } - func.func @third() { - return - } -} diff --git a/test/Dialect/NormalForm/ops-invalid.mlir b/test/Dialect/NormalForm/ops-invalid.mlir deleted file mode 100644 index 308f635c3..000000000 --- a/test/Dialect/NormalForm/ops-invalid.mlir +++ /dev/null @@ -1,134 +0,0 @@ -// RUN: aster-opt %s --split-input-file --verify-diagnostics - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (single attribute). -//----------------------------------------------------------------------------- - -// no_index_types: index function argument. -normalform.module @no_index_types_arg [#aster_test.no_index_types] { - // expected-error @below {{normal form prohibits index types}} - func.func @f(%arg: index) { - return - } -} - -// ----- - -// no_index_types: index result type. -normalform.module @no_index_types_result [#aster_test.no_index_types] { - // expected-error @below {{normal form prohibits index types}} - func.func @f() -> index { - %0 = arith.constant 0 : index - return %0 : index - } -} - -// ----- - -// no_index_types: index-typed block argument in nested region (scf.for iter_arg). -// This tests that block arguments in nested regions are verified for type -// constraints. The error is triggered by the arith.constant result type, but -// the scf.for iter_arg and induction variable block arguments are also checked. -normalform.module @no_index_types_nested_block_arg [#aster_test.no_index_types] { - func.func @f() { - // expected-error @below {{normal form prohibits index types}} - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - %result = scf.for %iv = %c0 to %c10 step %c1 iter_args(%iter = %c0) -> (index) { - scf.yield %iter : index - } - return - } -} - -// ----- - -// no_invalid_ops: division operation. -normalform.module @no_invalid_ops [#aster_test.no_invalid_ops] { - func.func @f(%a: f32, %b: f32) -> f32 { - // expected-error @below {{normal form prohibits division operations}} - %0 = arith.divf %a, %b : f32 - return %0 : f32 - } -} - -// ----- - -// no_invalid_attrs: string attribute with value "invalid". -normalform.module @no_invalid_attrs [#aster_test.no_invalid_attrs] { - // expected-error @below {{normal form prohibits 'invalid' string attribute values}} - func.func @f() attributes {foo = "invalid"} { - return - } -} - -// ----- - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (multiple attributes). -//----------------------------------------------------------------------------- - -// Multiple test attributes: violation on no_invalid_ops. -normalform.module @multi_attrs_invalid_op [#aster_test.no_index_types, #aster_test.no_invalid_ops] { - func.func @f(%a: i32, %b: i32) -> i32 { - // expected-error @below {{normal form prohibits division operations}} - %0 = arith.divsi %a, %b : i32 - return %0 : i32 - } -} - -// ----- - -// Multiple test attributes: only no_invalid_attrs violation present. -normalform.module @multi_attrs_invalid_attr [#aster_test.no_index_types, #aster_test.no_invalid_attrs] { - // expected-error @below {{normal form prohibits 'invalid' string attribute values}} - func.func @f() attributes {x = "invalid"} { - return - } -} - -// ----- - -// Multiple test attributes: only no_index_types violation present. -normalform.module @multi_attrs_index_type [#aster_test.no_invalid_ops, #aster_test.no_index_types] { - // expected-error @below {{normal form prohibits index types}} - func.func @f(%arg: index) { - return - } -} - -// ----- - -// All three attributes: violation on no_invalid_ops. -normalform.module @all_attrs_invalid_op [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] { - func.func @f(%a: i32, %b: i32) -> i32 { - // expected-error @below {{normal form prohibits division operations}} - %0 = arith.divui %a, %b : i32 - return %0 : i32 - } -} - -// ----- - -// All three attributes: violation on no_invalid_attrs only. -normalform.module @all_attrs_invalid_attr [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] { - // expected-error @below {{normal form prohibits 'invalid' string attribute values}} - func.func @f() attributes {x = "invalid"} { - return - } -} - -// ----- - -//----------------------------------------------------------------------------- -// Duplicate attribute rejection. -//----------------------------------------------------------------------------- - -// Duplicate normal form attributes are rejected. -// expected-error @below {{contains duplicate normal form attribute}} -normalform.module @duplicate [#aster_test.no_index_types, #aster_test.no_index_types] { - func.func @f() { - return - } -} diff --git a/test/Dialect/NormalForm/ops.mlir b/test/Dialect/NormalForm/ops.mlir deleted file mode 100644 index 421472d99..000000000 --- a/test/Dialect/NormalForm/ops.mlir +++ /dev/null @@ -1,89 +0,0 @@ -// RUN: aster-opt %s | aster-opt | FileCheck %s -// RUN: aster-opt %s --mlir-print-op-generic | aster-opt | FileCheck %s - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (single attribute). -//----------------------------------------------------------------------------- - -// no_index_types passes when no index types are used. -// CHECK-LABEL: normalform.module @no_index_types_valid -// CHECK-SAME: [#aster_test.no_index_types] -normalform.module @no_index_types_valid [#aster_test.no_index_types] { - func.func @f(%arg: i32) -> f32 { - %cst = arith.constant 0.0 : f32 - return %cst : f32 - } -} - -// no_invalid_ops passes when no division operations are present. -// CHECK-LABEL: normalform.module @no_invalid_ops_valid -// CHECK-SAME: [#aster_test.no_invalid_ops] -normalform.module @no_invalid_ops_valid [#aster_test.no_invalid_ops] { - func.func @f(%a: f32, %b: f32) -> f32 { - %0 = arith.mulf %a, %b : f32 - return %0 : f32 - } -} - -// no_invalid_attrs passes when no "invalid" string attributes are present. -// CHECK-LABEL: normalform.module @no_invalid_attrs_valid -// CHECK-SAME: [#aster_test.no_invalid_attrs] -normalform.module @no_invalid_attrs_valid [#aster_test.no_invalid_attrs] { - func.func @f() attributes {foo = "valid", bar = 42 : i32} { - return - } -} - -// no_invalid_ops passes with scf.for block arguments (iter_args and induction -// variable). This tests that block arguments in nested regions are properly -// walked and verified. -// CHECK-LABEL: normalform.module @no_invalid_ops_nested_block_args -// CHECK-SAME: [#aster_test.no_invalid_ops] -normalform.module @no_invalid_ops_nested_block_args [#aster_test.no_invalid_ops] { - func.func @f() -> f32 { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c10 = arith.constant 10 : index - %init = arith.constant 0.0 : f32 - %result = scf.for %iv = %c0 to %c10 step %c1 iter_args(%acc = %init) -> (f32) { - %one = arith.constant 1.0 : f32 - %sum = arith.addf %acc, %one : f32 - scf.yield %sum : f32 - } - return %result : f32 - } -} - -//----------------------------------------------------------------------------- -// Test dialect normal form attribute tests (multiple attributes). -//----------------------------------------------------------------------------- - -// Two attributes pass with valid IR. -// CHECK-LABEL: normalform.module @two_attrs_valid -// CHECK-SAME: [#aster_test.no_index_types, #aster_test.no_invalid_ops] -normalform.module @two_attrs_valid [#aster_test.no_index_types, #aster_test.no_invalid_ops] { - func.func @f(%arg: i32) { - return - } -} - -// All three attributes pass with valid IR. -// CHECK-LABEL: normalform.module @all_attrs_valid -// CHECK-SAME: [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] -normalform.module @all_attrs_valid [#aster_test.no_index_types, #aster_test.no_invalid_ops, #aster_test.no_invalid_attrs] { - func.func @f(%arg: i32) attributes {foo = "valid"} { - return - } -} - -//----------------------------------------------------------------------------- -// Module without name. -//----------------------------------------------------------------------------- - -// Anonymous module with single attribute. -// CHECK-LABEL: normalform.module [#aster_test.no_invalid_ops] -normalform.module [#aster_test.no_invalid_ops] { - func.func @f() { - return - } -} diff --git a/test/lib/CMakeLists.txt b/test/lib/CMakeLists.txt index 63807d95d..396350b01 100644 --- a/test/lib/CMakeLists.txt +++ b/test/lib/CMakeLists.txt @@ -1,2 +1 @@ -add_subdirectory(Dialect) add_subdirectory(Pass) diff --git a/test/lib/Dialect/AsterTestDialect.cpp b/test/lib/Dialect/AsterTestDialect.cpp deleted file mode 100644 index 931221507..000000000 --- a/test/lib/Dialect/AsterTestDialect.cpp +++ /dev/null @@ -1,88 +0,0 @@ -//===- AsterTestDialect.cpp - test dialect --------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "AsterTestDialect.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" - -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace mlir; - -#include "AsterTestDialect.cpp.inc" - -// Test normal form attribute. -#define GET_ATTRDEF_CLASSES -#include "TestNormalFormAttr.cpp.inc" - -void mlir::aster::test::AsterTestDialect::initialize() { - addAttributes< -#define GET_ATTRDEF_LIST -#include "TestNormalFormAttr.cpp.inc" - >(); -}; - -namespace mlir::aster::test { -void registerAsterTestDialect(DialectRegistry ®istry) { - registry.insert(); -} -} // namespace mlir::aster::test - -using namespace mlir::aster::test; - -//----------------------------------------------------------------------------- -// NoIndexTypesAttr interface implementations. -//----------------------------------------------------------------------------- - -llvm::LogicalResult -NoIndexTypesAttr::verifyType(llvm::function_ref emitError, - Type type) const { - if (!type) - return llvm::success(); - - if (type.isIndex()) - return emitError() << "normal form prohibits index types"; - - return llvm::success(); -} - -//----------------------------------------------------------------------------- -// NoInvalidOpsAttr interface implementations. -//----------------------------------------------------------------------------- - -llvm::LogicalResult NoInvalidOpsAttr::verifyOperation( - llvm::function_ref emitError, Operation *op) const { - if (isa(op)) - return emitError() << "normal form prohibits division operations"; - - return llvm::success(); -} - -//----------------------------------------------------------------------------- -// NoInvalidAttrsAttr interface implementations. -//----------------------------------------------------------------------------- - -llvm::LogicalResult NoInvalidAttrsAttr::verifyAttribute( - llvm::function_ref emitError, Attribute attr) const { - if (!attr) - return llvm::success(); - - if (auto strAttr = llvm::dyn_cast(attr)) { - if (strAttr.getValue() == "invalid") - return emitError() - << "normal form prohibits 'invalid' string attribute values"; - } - - return llvm::success(); -} diff --git a/test/lib/Dialect/AsterTestDialect.h b/test/lib/Dialect/AsterTestDialect.h deleted file mode 100644 index 94938ad8c..000000000 --- a/test/lib/Dialect/AsterTestDialect.h +++ /dev/null @@ -1,25 +0,0 @@ -//===- AsterTestDialect.h - test dialect ----------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT_H -#define ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT_H - -#include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.h" -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/OpDefinition.h" - -#include "AsterTestDialect.h.inc" -#include "mlir/IR/Dialect.h" - -// Test normal form attribute. -#define GET_ATTRDEF_CLASSES -#include "TestNormalFormAttr.h.inc" - -#endif // ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT_H diff --git a/test/lib/Dialect/AsterTestDialect.td b/test/lib/Dialect/AsterTestDialect.td deleted file mode 100644 index 4a2899e31..000000000 --- a/test/lib/Dialect/AsterTestDialect.td +++ /dev/null @@ -1,24 +0,0 @@ -//===- AsterTestDialect.td - test dialect ---------------------------------===// -// -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -include "mlir/IR/DialectBase.td" -include "mlir/IR/OpBase.td" - -#ifndef ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT -#define ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT - -def AsterTestDialect : Dialect { - let name = "aster_test"; - let summary = "Dialect for testing Aster"; - let cppNamespace = "::mlir::aster::test"; - let useDefaultAttributePrinterParser = 1; -} - -#endif // ASTER_TEST_LIB_DIALECT_WATERTESTDIALECT diff --git a/test/lib/Dialect/CMakeLists.txt b/test/lib/Dialect/CMakeLists.txt deleted file mode 100644 index a5508b8e0..000000000 --- a/test/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS AsterTestDialect.td) -mlir_tablegen(AsterTestDialect.h.inc -gen-dialect-decls -dialect=aster_test) -mlir_tablegen(AsterTestDialect.cpp.inc -gen-dialect-defs -dialect=aster_test) -mlir_tablegen(AsterTestDialectOps.h.inc -gen-op-decls) -mlir_tablegen(AsterTestDialectOps.cpp.inc -gen-op-defs) -add_public_tablegen_target(MLIRAsterTestDialectIncGen) - -set(LLVM_TARGET_DEFINITIONS TestNormalFormAttr.td) -mlir_tablegen(TestNormalFormAttr.h.inc -gen-attrdef-decls -attrdefs-dialect=aster_test) -mlir_tablegen(TestNormalFormAttr.cpp.inc -gen-attrdef-defs -attrdefs-dialect=aster_test) -add_public_tablegen_target(MLIRTestNormalFormAttrIncGen) - -add_mlir_dialect_library(MLIRAsterTestDialect - AsterTestDialect.cpp - - EXCLUDE_FROM_LIBMLIR - - DEPENDS - MLIRTestNormalFormAttrIncGen - - LINK_LIBS - MLIRArithDialect - MLIRIR - MLIRNormalFormDialect -) - -target_include_directories(MLIRAsterTestDialect - PRIVATE - ${PROJECT_BINARY_DIR}/test/lib/Dialect -) diff --git a/test/lib/Dialect/TestNormalFormAttr.td b/test/lib/Dialect/TestNormalFormAttr.td deleted file mode 100644 index 46700bc6f..000000000 --- a/test/lib/Dialect/TestNormalFormAttr.td +++ /dev/null @@ -1,54 +0,0 @@ -//===- TestNormalFormAttrs.td - definitions for NormalForm tests ----------===// -// -// Copyright 2025 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef ASTER_TEST_LIB_DIALECT_TESTNORMALFORMATTR -#define ASTER_TEST_LIB_DIALECT_TESTNORMALFORMATTR - -include "mlir/IR/DialectBase.td" -include "mlir/IR/AttrTypeBase.td" -include "aster/Dialect/NormalForm/IR/NormalFormInterfaces.td" -include "AsterTestDialect.td" - -//----------------------------------------------------------------------------- -// Test normal form attributes for testing NormalFormAttrInterface. -// Each attribute tests a single invariant. -//----------------------------------------------------------------------------- - -def NoIndexTypesAttr : AttrDef]> { - let mnemonic = "no_index_types"; - let summary = "Normal form attribute that prohibits index types"; - let description = [{ - A test attribute implementing NormalFormAttrInterface that verifies - no index types are used in operations, block arguments, or results. - }]; -} - -def NoInvalidOpsAttr : AttrDef]> { - let mnemonic = "no_invalid_ops"; - let summary = "Normal form attribute that prohibits division operations"; - let description = [{ - A test attribute implementing NormalFormAttrInterface that verifies - no division operations (arith.divf, arith.divsi, arith.divui) are present. - }]; -} - -def NoInvalidAttrsAttr : AttrDef]> { - let mnemonic = "no_invalid_attrs"; - let summary = "Normal form attribute that prohibits 'invalid' string values"; - let description = [{ - A test attribute implementing NormalFormAttrInterface that verifies - no string attributes have the value "invalid". - }]; -} - -#endif // ASTER_TEST_LIB_DIALECT_TESTNORMALFORMATTR diff --git a/tools/aster-opt/CMakeLists.txt b/tools/aster-opt/CMakeLists.txt index cc0c28d11..4bcc63c0a 100644 --- a/tools/aster-opt/CMakeLists.txt +++ b/tools/aster-opt/CMakeLists.txt @@ -2,7 +2,6 @@ set(LIBS AMDGCNTransforms ASTERInit ASTERTestPass - MLIRAsterTestDialect MLIROptLib ) diff --git a/tools/aster-opt/aster-opt.cpp b/tools/aster-opt/aster-opt.cpp index 43ab1fb7d..6e2bb75f0 100644 --- a/tools/aster-opt/aster-opt.cpp +++ b/tools/aster-opt/aster-opt.cpp @@ -22,10 +22,6 @@ namespace mlir::aster { void registerTestPasses(); } // namespace mlir::aster -namespace mlir::aster::test { -void registerAsterTestDialect(DialectRegistry ®istry); -} // namespace mlir::aster::test - using namespace llvm; using namespace mlir; @@ -41,7 +37,6 @@ int main(int argc, char **argv) { aster::initDialects(registry); aster::registerPasses(); aster::registerTestPasses(); - aster::test::registerAsterTestDialect(registry); return mlir::asMainReturnCode(mlir::MlirOptMain( argc, argv, "aster modular optimizer driver\n", registry)); }