diff --git a/CMakeLists.txt b/CMakeLists.txt index 14debaad1..ba1937b54 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -144,5 +144,18 @@ add_subdirectory(test) option(ASTER_ENABLE_MLIR_AIR "Build mlir-air contrib (linalg->AMDGCN)" OFF) if(ASTER_ENABLE_MLIR_AIR) message(STATUS "mlir-air enabled") + # Upstream Xilinx/mlir-air: AIR dialect + conversion passes (no AIE, no GPU backend). + set(AIR_ENABLE_AIE OFF CACHE BOOL "" FORCE) + set(AIR_ENABLE_GPU OFF CACHE BOOL "" FORCE) + # Stub targets expected by upstream mlir-air CMakeLists. + add_custom_target(air-headers) + # Make upstream AIR .td files and generated headers visible to mlir-tblgen + # and C++ compilation. Must come before add_subdirectory so the directory + # property is set when mlir_tablegen() calls get_directory_property(). + include_directories(${CMAKE_SOURCE_DIR}/third_party/mlir-air/mlir/include) + include_directories(${CMAKE_BINARY_DIR}/third_party/mlir-air/mlir/include) + # Only build AIR lib+include, not its test suite (check-all not available here). + add_subdirectory(third_party/mlir-air/mlir/include) + add_subdirectory(third_party/mlir-air/mlir/lib) add_subdirectory(contrib/mlir-air) endif() diff --git a/contrib/kittens/library/lds_16x64_b.mlir b/contrib/kittens/library/lds_16x64_b.mlir index 8b4679ea0..6e0704bcd 100644 --- a/contrib/kittens/library/lds_16x64_b.mlir +++ b/contrib/kittens/library/lds_16x64_b.mlir @@ -138,4 +138,19 @@ amdgcn.library @kittens_lds_16x64_b isa = [#amdgcn.isa] { return %future : !future_lds_read } + // Zero-fill a 16x64_b LDS tile (1024 bytes). + // Each thread writes 16 bytes of zeros at its assigned positions. + func.func private @fill_lds_16x64_b(%tile_base: index) { + %zero = func.call @alloc_vgprx2() : () -> !vx2 + %addr_lo, %addr_hi = func.call @compute_lds_write_addrs_16x64_b(%tile_base) + : (index) -> (index, index) + %t0 = func.call @write_vx2_to_lds_at(%zero, %addr_lo) + : (!vx2, index) -> !future_lds_write + %t1 = func.call @write_vx2_to_lds_at(%zero, %addr_hi) + : (!vx2, index) -> !future_lds_write + amdgcn.wait deps %t0 : !future_lds_write + amdgcn.wait deps %t1 : !future_lds_write + return + } + } diff --git a/contrib/mlir-air/CMakeLists.txt b/contrib/mlir-air/CMakeLists.txt index b65fca9d4..234399296 100644 --- a/contrib/mlir-air/CMakeLists.txt +++ b/contrib/mlir-air/CMakeLists.txt @@ -1,9 +1,20 @@ add_mlir_library(MlirAirLib - lib/ConvertLinalgToAMDGCN.cpp + lib/AirToAMDGCN.cpp + lib/ConvertToAMDGCNLibraryCalls.cpp + lib/ConvertMemSpaceToAMDGCN.cpp lib/Init.cpp lib/Pipelines.cpp LINK_LIBS PUBLIC + AIRConversionPasses + AIRDialect + MLIRBufferizationDialect + MLIRBufferizationTransformOps + MLIRBufferizationTransforms + MLIRTensorInferTypeOpInterfaceImpl + MLIRTensorTransforms + AIRTransformOps + AIRTransformPasses ASTERInit AsterTransforms AsterCodeGen @@ -13,6 +24,7 @@ add_mlir_library(MlirAirLib MLIRAffineDialect MLIRAffineTransforms MLIRFuncDialect + MLIRGPUDialect MLIRLinalgDialect MLIRLinalgTransformOps MLIRLinalgTransforms diff --git a/contrib/mlir-air/lib/AirToAMDGCN.cpp b/contrib/mlir-air/lib/AirToAMDGCN.cpp new file mode 100644 index 000000000..0ef45d72d --- /dev/null +++ b/contrib/mlir-air/lib/AirToAMDGCN.cpp @@ -0,0 +1,324 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- AirToAMDGCN.cpp - Lower AIR hierarchy ops to AMDGCN IR ------------===// +// +// Lowers air.launch, air.segment, air.herd, air.execute, and air.wait_all +// to flat AMDGCN-compatible IR: +// +// air.launch IDs -> gpu.block_id (workgroup IDs) +// air.herd tile IDs -> gpu.thread_id / 64 (wavefront index within workgroup) +// air.segment -> inline body (transparent wrapper) +// air.execute -> inline body (strip async) +// air.wait_all -> erase +// +// air.channel.put/get are not expected after air-to-amdgcn (air-dma-to-channel not used). +//===----------------------------------------------------------------------===// + +#include "air/Dialect/AIR/AIRDialect.h" +#include "aster/Interfaces/ModuleOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Map gpu::Dimension enum from an integer index. +static gpu::Dimension dimFromIndex(unsigned i) { + switch (i) { + case 0: + return gpu::Dimension::x; + case 1: + return gpu::Dimension::y; + case 2: + return gpu::Dimension::z; + default: + llvm_unreachable("invalid dimension index"); + } +} + +/// Clone all ops from `src` into `builder`'s insertion point, applying +/// `mapping`. Ops whose type matches any of `SkipOps...` are skipped. +template +static void cloneBodyOps(OpBuilder &builder, Block &src, IRMapping &mapping) { + for (auto &op : src.getOperations()) { + if ((isa(op) || ...)) + continue; + builder.clone(op, mapping); + } +} + +/// Strip async_dependencies from a channel op and drop its async_token result. +static void stripAsyncFromChannelOp(Operation *op) { + if (auto asyncOp = dyn_cast(op)) { + // Remove all async dependency operands. + while (asyncOp.getAsyncDependencies().size() > 0) + asyncOp.eraseAsyncDependency(0); + // If the op has an async_token result, replace uses and drop it. + if (auto token = asyncOp.getAsyncToken()) { + token.replaceAllUsesWith(Value()); + // We can't actually remove a result from an existing op, but since all + // token users will be erased (wait_all, other async deps), replacing + // with null is sufficient — dead token uses are cleaned up below. + } + } +} + +// --------------------------------------------------------------------------- +// Pass +// --------------------------------------------------------------------------- + +struct AirToAMDGCN + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AirToAMDGCN) + StringRef getArgument() const override { return "air-to-amdgcn"; } + StringRef getDescription() const override { + return "Lower AIR hierarchy ops to AMDGCN-compatible IR"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + Operation *moduleOp = getOperation(); + OpBuilder builder(moduleOp->getContext()); + + // ----------------------------------------------------------------------- + // Phase 1: Strip async tokens. + // ----------------------------------------------------------------------- + + // Strip async from channel ops (they survive this pass). 
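+    // Editor's illustrative sketch (AIR syntax abbreviated, not from this
+    // patch): for an async channel op such as
+    //   %t = air.channel.put async [%dep] @chan[] (%buf[] [] []) : (...)
+    // the [%dep] dependency list is removed and remaining uses of %t are
+    // nulled out; the token consumers (e.g. air.wait_all) are erased below.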
+ moduleOp->walk([&](xilinx::air::ChannelPutOp op) { + stripAsyncFromChannelOp(op); + }); + moduleOp->walk([&](xilinx::air::ChannelGetOp op) { + stripAsyncFromChannelOp(op); + }); + + // Inline air.execute: splice body, replace results, erase. + SmallVector executes; + moduleOp->walk([&](xilinx::air::ExecuteOp op) { executes.push_back(op); }); + for (auto execOp : executes) { + Block &body = execOp.getBody(); + auto terminator = + cast(body.getTerminator()); + + // Replace non-token results (results 1..N) with terminator operands. + for (unsigned i = 0; i < terminator.getNumOperands(); ++i) + execOp.getResult(i + 1).replaceAllUsesWith(terminator.getOperand(i)); + + // Replace async token result (result 0) — users will be erased. + execOp.getResult(0).replaceAllUsesWith(Value()); + + // Splice body ops (except terminator) before the execute op. + auto &parentOps = execOp->getBlock()->getOperations(); + auto &bodyOps = body.getOperations(); + auto beforeTerminator = terminator->getIterator(); + parentOps.splice(execOp->getIterator(), bodyOps, bodyOps.begin(), + beforeTerminator); + + terminator->erase(); + execOp->erase(); + } + + // Erase air.wait_all ops. + SmallVector waitAlls; + moduleOp->walk( + [&](xilinx::air::WaitAllOp op) { waitAlls.push_back(op); }); + for (auto op : waitAlls) { + if (auto token = op.getAsyncToken()) + token.replaceAllUsesWith(Value()); + op->erase(); + } + + // ----------------------------------------------------------------------- + // Phase 2: Inline hierarchy ops (inside-out). + // ----------------------------------------------------------------------- + + // --- air.herd -> wavefront index (gpu.thread_id / wavefront_size) --- + // Each herd tile is a wavefront (64 threads cooperating collectively). + // Herd tile ID = wavefront index within the workgroup. + SmallVector herds; + moduleOp->walk([&](xilinx::air::HerdOp op) { herds.push_back(op); }); + for (auto herd : herds) { + builder.setInsertionPoint(herd); + Block &body = herd.getBody().front(); + unsigned numDims = herd.getNumDims(); + IRMapping mapping; + + // Map tile IDs to wavefront index = gpu.thread_id / 64. + auto ids = herd.getIds(); + Value wavefrontSize = arith::ConstantIndexOp::create( + builder, herd.getLoc(), 64); + for (unsigned i = 0; i < numDims; ++i) { + Value threadId = gpu::ThreadIdOp::create(builder, herd.getLoc(), + dimFromIndex(i)); + Value wavefrontId = arith::DivUIOp::create(builder, herd.getLoc(), + threadId, wavefrontSize); + mapping.map(ids[i], wavefrontId); + } + + // Map tile sizes to size operands. + auto sizeArgs = herd.getSize(); + auto sizeOperands = herd.getSizeOperands(); + for (unsigned i = 0; i < numDims; ++i) + mapping.map(sizeArgs[i], sizeOperands[i]); + + // Map kernel arguments to kernel operands. + auto kernelArgs = herd.getKernelArguments(); + auto kernelOperands = herd.getKernelOperands(); + for (unsigned i = 0; i < kernelArgs.size(); ++i) + mapping.map(kernelArgs[i], kernelOperands[i]); + + // Clone body. + cloneBodyOps(builder, body, mapping); + + // Replace async token if present. + if (auto token = herd.getAsyncToken()) + token.replaceAllUsesWith(Value()); + herd->erase(); + } + + // --- air.segment -> inline --- + SmallVector segments; + moduleOp->walk( + [&](xilinx::air::SegmentOp op) { segments.push_back(op); }); + for (auto segment : segments) { + builder.setInsertionPoint(segment); + Block &body = segment.getBody().front(); + unsigned numDims = segment.getNumDims(); + IRMapping mapping; + + // Map segment IDs (if any) to gpu.block_id ops. 
+ auto ids = segment.getIds(); + auto sizeArgs = segment.getSize(); + auto sizeOperands = segment.getSizeOperands(); + for (unsigned i = 0; i < numDims; ++i) { + Value blockId = gpu::BlockIdOp::create(builder, segment.getLoc(), + dimFromIndex(i)); + mapping.map(ids[i], blockId); + mapping.map(sizeArgs[i], sizeOperands[i]); + } + + // Map kernel arguments to kernel operands. + auto kernelArgs = segment.getKernelArguments(); + auto kernelOperands = segment.getKernelOperands(); + for (unsigned i = 0; i < kernelArgs.size(); ++i) + mapping.map(kernelArgs[i], kernelOperands[i]); + + cloneBodyOps(builder, body, mapping); + + if (auto token = segment.getAsyncToken()) + token.replaceAllUsesWith(Value()); + segment->erase(); + } + + // --- air.launch -> gpu.block_id --- + SmallVector launches; + moduleOp->walk( + [&](xilinx::air::LaunchOp op) { launches.push_back(op); }); + for (auto launch : launches) { + builder.setInsertionPoint(launch); + Block &body = launch.getBody().front(); + unsigned numDims = launch.getNumDims(); + IRMapping mapping; + + // Map launch IDs to gpu.block_id ops. + auto ids = launch.getIds(); + auto sizeArgs = launch.getSize(); + auto sizeOperands = launch.getSizeOperands(); + for (unsigned i = 0; i < numDims; ++i) { + Value blockId = gpu::BlockIdOp::create(builder, launch.getLoc(), + dimFromIndex(i)); + mapping.map(ids[i], blockId); + mapping.map(sizeArgs[i], sizeOperands[i]); + } + + // Map kernel arguments to kernel operands. + auto kernelArgs = launch.getKernelArguments(); + auto kernelOperands = launch.getKernelOperands(); + for (unsigned i = 0; i < kernelArgs.size(); ++i) + mapping.map(kernelArgs[i], kernelOperands[i]); + + cloneBodyOps(builder, body, mapping); + + if (auto token = launch.getAsyncToken()) + token.replaceAllUsesWith(Value()); + launch->erase(); + } + + // --- scf.parallel -> wavefront ID inlining --- + // After air-dma-to-channel, hoisted channel.put/get ops are wrapped in + // scf.parallel loops that iterate over the herd tile space. These must + // be inlined: replace induction variables with gpu.thread_id / 64 + // (wavefront index), then splice body ops in place of the parallel. + SmallVector parallels; + moduleOp->walk( + [&](scf::ParallelOp op) { parallels.push_back(op); }); + for (auto parallel : parallels) { + builder.setInsertionPoint(parallel); + Location loc = parallel.getLoc(); + Block &body = parallel.getRegion().front(); + + // Compute wavefront ID for each induction variable. + // All dimensions are derived from thread_id_x (1D wavefront layout). + Value wavefrontSize = + arith::ConstantIndexOp::create(builder, loc, 64); + Value threadIdX = + gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); + Value wavefrontId = + arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); + + auto ivs = parallel.getInductionVars(); + if (ivs.size() == 1) { + // 1D parallel: iv = wavefrontId directly. + ivs[0].replaceAllUsesWith(wavefrontId); + } else if (ivs.size() == 2) { + // 2D parallel: decompose wavefrontId into (x, y). + auto ub = parallel.getUpperBound(); + Value ubX = ub[0]; + Value ivX = arith::RemUIOp::create(builder, loc, wavefrontId, ubX); + Value ivY = arith::DivUIOp::create(builder, loc, wavefrontId, ubX); + ivs[0].replaceAllUsesWith(ivX); + ivs[1].replaceAllUsesWith(ivY); + } + + // Replace init values / results. scf.parallel with reduce produces + // results from scf.reduce. Replace with init values (no actual + // reduction needed — each wavefront runs independently). 
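+      // Editor's worked example (abbreviated scf syntax): for
+      //   %r = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init (%acc) -> f32 { ... scf.reduce ... }
+      // uses of %r are redirected to %acc; the scf.reduce terminator is left
+      // behind by the splice below and removed together with the loop.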
+ for (unsigned i = 0; i < parallel.getNumResults(); ++i) + parallel.getResult(i).replaceAllUsesWith(parallel.getInitVals()[i]); + + // Splice body ops (skip the yield/reduce terminator) before parallel. + auto &parentOps = parallel->getBlock()->getOperations(); + auto &bodyOps = body.getOperations(); + // Find the last non-terminator op. + auto termIt = body.getTerminator()->getIterator(); + parentOps.splice(parallel->getIterator(), bodyOps, bodyOps.begin(), + termIt); + + parallel->erase(); + } + } +}; + +} // namespace + +namespace mlir::aster::mlir_air { +std::unique_ptr createAirToAMDGCN() { + return std::make_unique(); +} +} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp deleted file mode 100644 index 62b204e4b..000000000 --- a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp +++ /dev/null @@ -1,271 +0,0 @@ -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -//===- ConvertLinalgToAMDGCN.cpp - linalg ops -> AMDGCN library calls -----===// - -#include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" -#include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" -#include "aster/Dialect/AMDGCN/IR/AMDGCNTypes.h" -#include "aster/Dialect/LSIR/IR/LSIROps.h" -#include "aster/Interfaces/ModuleOpInterface.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Ptr/IR/PtrOps.h" -#include "mlir/Dialect/Ptr/IR/PtrTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; -using namespace mlir::aster; -using namespace mlir::aster::amdgcn; - -namespace { - -static std::string buildFuncName(StringRef prefix, MemRefType ty) { - std::string name; - llvm::raw_string_ostream os(name); - os << prefix; - Type elt = ty.getElementType(); - if (elt.isF16()) - os << "_f16"; - else if (elt.isF32()) - os << "_f32"; - else if (elt.isBF16()) - os << "_bf16"; - else - os << "_unknown"; - auto shape = ty.getShape(); - for (size_t i = 0; i < shape.size(); ++i) - os << (i == 0 ? "_" : "x") << shape[i]; - return name; -} - -static void ensureDecl(OpBuilder &builder, Block &block, Location loc, - StringRef name, FunctionType funcTy) { - for (auto &op : block) - if (auto fn = dyn_cast(&op)) - if (fn.getName() == name) - return; - auto savedIP = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(&block); - auto decl = func::FuncOp::create(builder, loc, name, funcTy); - decl.setPrivate(); - builder.restoreInsertionPoint(savedIP); -} - -/// Check if a memref value comes from promote to shared memory -/// (memref.view(memref.alloca) with workgroup address space). -static bool isPromotedBuffer(Value v) { - auto viewOp = v.getDefiningOp(); - if (!viewOp) - return false; - auto allocaOp = viewOp.getSource().getDefiningOp(); - if (!allocaOp) - return false; - auto memSpace = allocaOp.getMemref().getType().getMemorySpace(); - return memSpace != nullptr; -} - -/// Emit amdgcn.alloc_lds + get_lds_offset for a promoted buffer. -/// Uses a cache so the same promoted buffer (same memref.view(memref.alloca)) -/// gets the same LDS region for both write (copy) and read (matmul). 
-static Value emitLDSOffset(OpBuilder &builder, Location loc, Value memrefVal, - DenseMap &ldsCache) { - auto it = ldsCache.find(memrefVal); - if (it != ldsCache.end()) - return it->second; - - auto viewOp = memrefVal.getDefiningOp(); - auto allocaOp = viewOp.getSource().getDefiningOp(); - int64_t sizeBytes = allocaOp.getMemref().getType().getNumElements(); - auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), - sizeBytes, /*alignment=*/16, - /*offset=*/IntegerAttr{}); - auto ldsOffset = - GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); - Value result = builder.create(loc, ldsOffset.getResult(), - viewOp.getByteShift()); - ldsCache[memrefVal] = result; - return result; -} - -/// Decompose a global memref into (!sx2, byte_stride: index). -/// Emits: extract_strided_metadata -> ptr.to_ptr -> lsir.to_reg -> ptr_add. -static std::pair -decomposeGlobalMemref(OpBuilder &builder, Location loc, Value memref) { - auto mrTy = cast(memref.getType()); - unsigned eltBytes = mrTy.getElementType().getIntOrFloatBitWidth() / 8; - - // extract_strided_metadata -> (base_memref, offset, sizes..., strides...) - auto metadata = - memref::ExtractStridedMetadataOp::create(builder, loc, memref); - Value baseBuffer = metadata.getBaseBuffer(); - Value offset = metadata.getOffset(); - // Leading stride is strides[0] (row stride in elements). - Value leadingStride = metadata.getStrides()[0]; - - // byte_stride = leading_stride * elt_bytes - Value eltSize = arith::ConstantIndexOp::create(builder, loc, eltBytes); - Value byteStride = - arith::MulIOp::create(builder, loc, leadingStride, eltSize); - - // byte_offset = offset * elt_bytes - Value byteOffset = arith::MulIOp::create(builder, loc, offset, eltSize); - - // ptr.to_ptr base_memref -> !ptr.ptr - auto addrSpace = cast(mrTy.getMemorySpace()); - auto ptrTy = ptr::PtrType::get(builder.getContext(), addrSpace); - Value ptrVal = ptr::ToPtrOp::create(builder, loc, ptrTy, baseBuffer); - - // lsir.to_reg ptr -> !sx2 - auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), - /*size=*/2, /*alignment=*/2); - Value rawPtr = lsir::ToRegOp::create(builder, loc, sx2Ty, ptrVal); - - // Add byte offset: from_reg -> ptr_add -> to_reg - Value ptrFromReg = lsir::FromRegOp::create(builder, loc, ptrTy, rawPtr); - Value adjusted = - ptr::PtrAddOp::create(builder, loc, ptrTy, ptrFromReg, byteOffset); - Value result = lsir::ToRegOp::create(builder, loc, sx2Ty, adjusted); - - return {result, byteStride}; -} - -/// Replace a linalg op with a library call. -/// Global memrefs -> decomposed (!sx2, byte_stride) args. -/// Promoted buffers -> index (LDS offset). 
-static void replaceWithCall(OpBuilder &builder, Block &declBlock, Operation *op, - StringRef namePrefix, - SmallVector &toErase, - DenseMap &ldsCache) { - auto indexTy = builder.getIndexType(); - SmallVector callArgs; - SmallVector argTypes; - - MemRefType namingType; - for (Value operand : op->getOperands()) - if (auto mrTy = dyn_cast(operand.getType())) - if (!namingType) - namingType = mrTy; - if (!namingType) - return; - std::string name = buildFuncName(namePrefix, namingType); - - builder.setInsertionPoint(op); - Location loc = op->getLoc(); - - auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), - /*size=*/2, /*alignment=*/2); - - for (Value operand : op->getOperands()) { - if (auto mrTy = dyn_cast(operand.getType())) { - if (isPromotedBuffer(operand)) { - callArgs.push_back(emitLDSOffset(builder, loc, operand, ldsCache)); - argTypes.push_back(indexTy); - } else { - // Decompose global memref into (!sx2, byte_stride) - auto [ptrVal, byteStride] = - decomposeGlobalMemref(builder, loc, operand); - callArgs.push_back(ptrVal); - argTypes.push_back(sx2Ty); - callArgs.push_back(byteStride); - argTypes.push_back(indexTy); - } - } else { - callArgs.push_back(operand); - argTypes.push_back(operand.getType()); - } - } - - auto funcTy = builder.getFunctionType(argTypes, {}); - ensureDecl(builder, declBlock, loc, name, funcTy); - func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); - toErase.push_back(op); -} - -struct ConvertLinalgToAMDGCN - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertLinalgToAMDGCN) - StringRef getArgument() const override { return "convert-linalg-to-amdgcn"; } - StringRef getDescription() const override { - return "Convert tiled linalg ops to AMDGCN library calls"; - } - - void runOnOperation() override { - Operation *moduleOp = getOperation(); - MLIRContext *ctx = &getContext(); - - Operation *declParent = moduleOp; - if (isa(moduleOp)) - moduleOp->walk([&](amdgcn::ModuleOp m) { declParent = m; }); - auto &declBlock = declParent->getRegion(0).front(); - OpBuilder builder(ctx); - SmallVector toErase; - DenseMap ldsCache; - - moduleOp->walk([&](linalg::FillOp op) { - replaceWithCall(builder, declBlock, op, "fill", toErase, ldsCache); - }); - moduleOp->walk([&](linalg::CopyOp op) { - replaceWithCall(builder, declBlock, op, "copy", toErase, ldsCache); - }); - moduleOp->walk([&](linalg::MatmulOp op) { - replaceWithCall(builder, declBlock, op, "mfma_matmul", toErase, ldsCache); - }); - // Also handle linalg.generic with matmul-like semantics (e.g., - // matmul_transpose_b expressed as generic with (m,n,k)->(m,k),(n,k),(m,n)). - moduleOp->walk([&](linalg::GenericOp op) { - if (op.getNumDpsInputs() == 2 && op.getNumDpsInits() == 1 && - op.getNumReductionLoops() == 1) - replaceWithCall(builder, declBlock, op, "mfma_matmul", toErase, - ldsCache); - }); - - for (auto *op : toErase) - op->erase(); - - // DCE unused alloca/view/dealloc. - SmallVector deadOps; - moduleOp->walk([&](Operation *op) { - if (isa(op) || - (isa(op) && op->use_empty())) - deadOps.push_back(op); - }); - for (auto *op : deadOps) - op->erase(); - deadOps.clear(); - moduleOp->walk([&](memref::AllocaOp op) { - if (op->use_empty()) - deadOps.push_back(op); - }); - for (auto *op : deadOps) - op->erase(); - - // Erase transform dialect ops. 
- if (auto builtinMod = dyn_cast(moduleOp)) { - SmallVector transformOps; - for (auto &op : builtinMod.getBody()->getOperations()) - if (op.getDialect() && op.getDialect()->getNamespace() == "transform") - transformOps.push_back(&op); - for (auto *op : transformOps) - op->erase(); - builtinMod->removeAttr("transform.with_named_sequence"); - } - } -}; - -} // namespace - -namespace mlir::aster::mlir_air { -std::unique_ptr createConvertLinalgToAMDGCN() { - return std::make_unique(); -} -} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp new file mode 100644 index 000000000..c20c75f3f --- /dev/null +++ b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp @@ -0,0 +1,177 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- ConvertMemSpaceToAMDGCN.cpp - integer memspace -> amdgcn addr_space ===// +// +// Converts AIR-style integer memory spaces on memref types to AMDGCN address +// space attributes: +// +// null / 0 → #amdgcn.addr_space +// 2 (L1) → #amdgcn.addr_space +// +// This pass bridges the gap between the upstream AIR pipeline (which uses +// integer memory spaces) and aster's AMDGCN decomposition passes (which +// require #amdgcn.addr_space attributes for pointer generation). +// +// Run after air-to-amdgcn and before convert-to-amdgcn-library-calls. +//===----------------------------------------------------------------------===// + +#include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNEnums.h" +#include "aster/Interfaces/ModuleOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::aster::amdgcn; + +namespace { + +/// Map an integer memory space to an AMDGCN AddressSpaceAttr. +/// Returns nullptr if no mapping is needed (e.g., already an AMDGCN attr). +static AddressSpaceAttr mapMemorySpace(Attribute memSpace, + MLIRContext *ctx) { + // null or integer 0 → global + if (!memSpace) { + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Global, + AccessKind::ReadWrite); + } + if (auto intAttr = dyn_cast(memSpace)) { + unsigned space = intAttr.getInt(); + switch (space) { + case 0: + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Global, + AccessKind::ReadWrite); + case 2: + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Local, + AccessKind::ReadWrite); + default: + // Unknown integer space — map to global as fallback. + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Global, + AccessKind::ReadWrite); + } + } + // Already a non-integer attribute (e.g., #amdgcn.addr_space) — no change. + return {}; +} + +/// Convert a MemRefType's memory space if needed. +static MemRefType convertMemRefType(MemRefType ty, MLIRContext *ctx) { + auto newSpace = mapMemorySpace(ty.getMemorySpace(), ctx); + if (!newSpace) + return ty; + // Use the AffineMap overload since MemRefLayoutAttrInterface overload + // doesn't exist for some MLIR versions. 
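+  // Editor's illustrative example (attribute mnemonics assumed): a buffer
+  // typed memref<16x64xf16, 2> is rebuilt as an LDS memref such as
+  // memref<16x64xf16, #amdgcn.addr_space<local>>, and a memref with no
+  // memory space is rebuilt with the global address space.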
+ Attribute spaceAttr = newSpace; + return MemRefType::get(ty.getShape(), ty.getElementType(), + ty.getLayout().getAffineMap(), spaceAttr); +} + +struct ConvertMemSpaceToAMDGCN + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertMemSpaceToAMDGCN) + StringRef getArgument() const override { + return "convert-memspace-to-amdgcn"; + } + StringRef getDescription() const override { + return "Convert integer memory spaces to #amdgcn.addr_space attributes"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + Operation *moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + // Collect all func.func ops to update. + SmallVector funcs; + moduleOp->walk([&](func::FuncOp f) { funcs.push_back(f); }); + + for (auto funcOp : funcs) { + // Update function signature. + auto funcTy = funcOp.getFunctionType(); + bool changed = false; + + SmallVector newInputs; + for (auto ty : funcTy.getInputs()) { + if (auto mrTy = dyn_cast(ty)) { + auto newTy = convertMemRefType(mrTy, ctx); + newInputs.push_back(newTy); + if (newTy != mrTy) + changed = true; + } else { + newInputs.push_back(ty); + } + } + + SmallVector newResults; + for (auto ty : funcTy.getResults()) { + if (auto mrTy = dyn_cast(ty)) { + auto newTy = convertMemRefType(mrTy, ctx); + newResults.push_back(newTy); + if (newTy != mrTy) + changed = true; + } else { + newResults.push_back(ty); + } + } + + if (changed) { + funcOp.setType(FunctionType::get(ctx, newInputs, newResults)); + // Update block argument types. + if (!funcOp.empty()) { + Block &entry = funcOp.front(); + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (auto mrTy = dyn_cast(entry.getArgument(i).getType())) { + auto newTy = convertMemRefType(mrTy, ctx); + if (newTy != mrTy) + entry.getArgument(i).setType(newTy); + } + } + } + } + + // Walk all ops inside the function and update memref types. + funcOp->walk([&](Operation *op) { + // Update operand types (handled transitively via result types). + // Update result types. + for (unsigned i = 0; i < op->getNumResults(); ++i) { + auto ty = op->getResult(i).getType(); + if (auto mrTy = dyn_cast(ty)) { + auto newTy = convertMemRefType(mrTy, ctx); + if (newTy != mrTy) + op->getResult(i).setType(newTy); + } + } + // Update block argument types in regions. + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto arg : block.getArguments()) { + if (auto mrTy = dyn_cast(arg.getType())) { + auto newTy = convertMemRefType(mrTy, ctx); + if (newTy != mrTy) + arg.setType(newTy); + } + } + } + } + }); + } + } +}; + +} // namespace + +namespace mlir::aster::mlir_air { +std::unique_ptr createConvertMemSpaceToAMDGCN() { + return std::make_unique(); +} +} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp new file mode 100644 index 000000000..fdbdb2a5f --- /dev/null +++ b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp @@ -0,0 +1,605 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- ConvertToAMDGCNLibraryCalls.cpp - ops -> AMDGCN library calls ------===// + +#include "air/Dialect/AIR/AIRDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNEnums.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNTypes.h" +#include "aster/Dialect/LSIR/IR/LSIRDialect.h" +#include "aster/Dialect/LSIR/IR/LSIROps.h" +#include "aster/Interfaces/ModuleOpInterface.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::aster; +using namespace mlir::aster::amdgcn; + +namespace { + +// --------------------------------------------------------------------------- +// Utilities +// --------------------------------------------------------------------------- + +static std::string buildFuncName(StringRef prefix, MemRefType ty) { + std::string name; + llvm::raw_string_ostream os(name); + os << prefix; + Type elt = ty.getElementType(); + if (elt.isF16()) + os << "_f16"; + else if (elt.isF32()) + os << "_f32"; + else if (elt.isBF16()) + os << "_bf16"; + else + os << "_unknown"; + auto shape = ty.getShape(); + for (size_t i = 0; i < shape.size(); ++i) + os << (i == 0 ? "_" : "x") << shape[i]; + return name; +} + +static void ensureDecl(OpBuilder &builder, Block &block, Location loc, + StringRef name, FunctionType funcTy) { + for (auto &op : block) + if (auto fn = dyn_cast(&op)) + if (fn.getName() == name) + return; + auto savedIP = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(&block); + auto decl = func::FuncOp::create(builder, loc, name, funcTy); + decl.setPrivate(); + builder.restoreInsertionPoint(savedIP); +} + +static bool isPromotedBuffer(Value v) { + if (auto viewOp = v.getDefiningOp()) { + if (auto allocaOp = viewOp.getSource().getDefiningOp()) + return allocaOp.getMemref().getType().getMemorySpace() != nullptr; + } + if (auto allocOp = v.getDefiningOp()) { + auto memSpace = allocOp.getMemref().getType().getMemorySpace(); + if (!memSpace) + return false; + if (auto intAttr = dyn_cast(memSpace)) + return intAttr.getInt() == 2; + if (auto addrSpace = dyn_cast(memSpace)) + return addrSpace.getSpace() == amdgcn::AddressSpaceKind::Local; + return false; + } + return false; +} + +// --------------------------------------------------------------------------- +// Shared context for patterns (populated by analysis, read by patterns). +// --------------------------------------------------------------------------- + +struct ConversionContext { + Block *declBlock = nullptr; + int64_t numWavefronts = 1; + DenseMap ldsCache; +}; + +/// Emit amdgcn.alloc_lds + get_lds_offset for a promoted buffer at function +/// entry. When numWavefronts > 1, allocates nWf * tileSizeBytes and adds a +/// per-wavefront offset (wavefrontId * tileSizeBytes). 
+/// Uses ldsCache so the same buffer gets the same LDS region regardless of +/// which pattern fires first. +static Value emitLDSOffset(OpBuilder &builder, Location loc, Value memrefVal, + ConversionContext &convCtx) { + auto it = convCtx.ldsCache.find(memrefVal); + if (it != convCtx.ldsCache.end()) + return it->second; + + int64_t sizeBytes = 0; + Value byteShift; + if (auto viewOp = memrefVal.getDefiningOp()) { + auto allocaOp = viewOp.getSource().getDefiningOp(); + sizeBytes = allocaOp.getMemref().getType().getNumElements(); + byteShift = viewOp.getByteShift(); + } else if (auto allocOp = memrefVal.getDefiningOp()) { + auto mrTy = allocOp.getMemref().getType(); + unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); + sizeBytes = mrTy.getNumElements() * eltBits / 8; + } + + // Insert at function entry so LDS allocation dominates all uses. + auto funcOp = memrefVal.getParentRegion()->getParentOfType(); + auto savedIP = builder.saveInsertionPoint(); + if (funcOp) + builder.setInsertionPointToStart(&funcOp.front()); + + int64_t nWf = convCtx.numWavefronts; + auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), + nWf * sizeBytes, /*alignment=*/16, + /*offset=*/IntegerAttr{}); + auto ldsOffset = + GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); + Value result = ldsOffset.getResult(); + + if (nWf > 1) { + Value wavefrontSize = arith::ConstantIndexOp::create(builder, loc, 64); + Value threadIdX = + gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); + Value wavefrontId = + arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); + Value tileSizeVal = + arith::ConstantIndexOp::create(builder, loc, sizeBytes); + Value wavefrontOffset = + arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); + result = arith::AddIOp::create(builder, loc, result, wavefrontOffset); + } + + if (byteShift) + result = builder.create(loc, result, byteShift); + + builder.restoreInsertionPoint(savedIP); + convCtx.ldsCache[memrefVal] = result; + return result; +} + +static std::pair +decomposeGlobalMemref(OpBuilder &builder, Location loc, Value memref) { + auto mrTy = cast(memref.getType()); + unsigned eltBytes = mrTy.getElementType().getIntOrFloatBitWidth() / 8; + auto metadata = + memref::ExtractStridedMetadataOp::create(builder, loc, memref); + Value baseBuffer = metadata.getBaseBuffer(); + Value offset = metadata.getOffset(); + Value leadingStride = metadata.getStrides()[0]; + Value eltSize = arith::ConstantIndexOp::create(builder, loc, eltBytes); + Value byteStride = + arith::MulIOp::create(builder, loc, leadingStride, eltSize); + Value byteOffset = arith::MulIOp::create(builder, loc, offset, eltSize); + auto addrSpace = cast(mrTy.getMemorySpace()); + auto ptrTy = ptr::PtrType::get(builder.getContext(), addrSpace); + Value ptrVal = ptr::ToPtrOp::create(builder, loc, ptrTy, baseBuffer); + auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), + /*size=*/2, /*alignment=*/2); + Value rawPtr = lsir::ToRegOp::create(builder, loc, sx2Ty, ptrVal); + Value ptrFromReg = lsir::FromRegOp::create(builder, loc, ptrTy, rawPtr); + Value adjusted = + ptr::PtrAddOp::create(builder, loc, ptrTy, ptrFromReg, byteOffset); + Value result = lsir::ToRegOp::create(builder, loc, sx2Ty, adjusted); + return {result, byteStride}; +} + +// --------------------------------------------------------------------------- +// Patterns +// --------------------------------------------------------------------------- + +/// Convert air.dma_memcpy_nd to a library 
call. +/// Handles both directions: Global→LDS and LDS→Global. +struct DmaToLibraryCall + : public OpRewritePattern { + ConversionContext &convCtx; + + DmaToLibraryCall(MLIRContext *ctx, ConversionContext &convCtx) + : OpRewritePattern(ctx), convCtx(convCtx) {} + + LogicalResult matchAndRewrite(xilinx::air::DmaMemcpyNdOp dma, + PatternRewriter &rewriter) const override { + Value dst = dma.getDstMemref(); + Value src = dma.getSrcMemref(); + bool dstIsLDS = isPromotedBuffer(dst); + bool srcIsLDS = isPromotedBuffer(src); + if (!dstIsLDS && !srcIsLDS) + return failure(); + + Location loc = dma.getLoc(); + auto indexTy = rewriter.getIndexType(); + auto sx2Ty = amdgcn::SGPRType::get(rewriter.getContext(), Register(), + /*size=*/2, /*alignment=*/2); + SmallVector callArgs; + SmallVector argTypes; + + // Use the LDS memref type for function naming. + // Append direction suffix to distinguish global→LDS from LDS→global. + auto ldsTy = cast(dstIsLDS ? dst.getType() : src.getType()); + std::string name = buildFuncName( + srcIsLDS ? "store_global" : "copy", ldsTy); + + if (dstIsLDS && !srcIsLDS) { + // Detect boundary tile padding from either: + // (a) pad_after attribute (set by air-split-launch-for-padding), or + // (b) DMA src_sizes[0] < dst memref dim 0 (from transform tensor.pad). + auto padAfterAttr = dma->getAttrOfType("pad_after"); + bool hasPadding = false; + int32_t rowPad = 0; // pad rows appended after the valid region + if (padAfterAttr) { + auto padArr = padAfterAttr.asArrayRef(); + if (!padArr.empty() && padArr[0] > 0) { + hasPadding = true; + rowPad = padArr[0]; + } + } + // Also detect from DMA sizes: if src_sizes[0] is a constant < dst dim 0, + // this is a partial copy into a padded LDS buffer (from tensor.pad path). + if (!hasPadding) { + auto srcSizes = dma.getSrcSizes(); + if (!srcSizes.empty()) { + auto srcRowOpt = getConstantIntValue(srcSizes[0]); + int64_t dstDim0 = ldsTy.getDimSize(0); + if (srcRowOpt && dstDim0 > 0 && *srcRowOpt < dstDim0) { + hasPadding = true; + rowPad = dstDim0 - *srcRowOpt; + } + } + } + + Value ldsOffset = emitLDSOffset(rewriter, loc, dst, convCtx); + + if (hasPadding) { + // When padding is detected from pad_after attribute (air-split-launch), + // emit fill explicitly. When detected from DMA sizes (tensor.pad path), + // the linalg.fill from transform already handles the zero-fill and is + // converted separately by LinalgToLibraryCall. + if (padAfterAttr) { + std::string fillName = buildFuncName("fill", ldsTy); + Type f16Ty = rewriter.getF16Type(); + Value zeroF16 = arith::ConstantOp::create( + rewriter, loc, f16Ty, + rewriter.getF16FloatAttr(0.0f)); + auto fillTy = rewriter.getFunctionType({f16Ty, indexTy}, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, fillName, fillTy); + func::CallOp::create(rewriter, loc, fillName, TypeRange{}, + ValueRange{zeroF16, ldsOffset}); + } + + // Partial copy — copy only the valid rows. + // copy_f16_16x64_padded(global_ptr, stride, row, col, actual_rows, lds_dst) + // The library function reads actual_rows rows from global (not all 16). + auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, src); + auto srcSizes = dma.getSrcSizes(); + // actual_rows = src_sizes[0] (set by split-launch pass to actualLast). 
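+        // Editor's worked example (shapes from the padded matmul test): for a
+        // boundary tile with a 16x64xf16 LDS dst and src_sizes[0] == 8,
+        // rowPad = 16 - 8 = 8 and actual_rows = 8, so this branch emits
+        //   fill_f16_16x64(0.0, %lds_off)        // pad_after path only
+        //   copy_f16_16x64_padded(%ptr, %stride, %row, %col, %c8, %lds_off)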
+ Value actualRows; + if (!srcSizes.empty()) + actualRows = srcSizes[0]; + else + actualRows = arith::ConstantIndexOp::create( + rewriter, loc, ldsTy.getDimSize(0) - rowPad); + + std::string copyName = buildFuncName("copy", ldsTy) + "_padded"; + SmallVector copyArgs = {ptrVal, byteStride}; + SmallVector copyArgTypes = {sx2Ty, indexTy}; + auto srcOffsets = dma.getSrcOffsets(); + if (srcOffsets.size() >= 2) { + copyArgs.push_back(srcOffsets[0]); + copyArgs.push_back(srcOffsets[1]); + } else { + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + copyArgs.push_back(zero); + copyArgs.push_back(zero); + } + copyArgTypes.push_back(indexTy); + copyArgTypes.push_back(indexTy); + copyArgs.push_back(actualRows); + copyArgTypes.push_back(indexTy); + copyArgs.push_back(ldsOffset); + copyArgTypes.push_back(indexTy); + auto copyTy = rewriter.getFunctionType(copyArgTypes, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, copyName, copyTy); + func::CallOp::create(rewriter, loc, copyName, TypeRange{}, copyArgs); + rewriter.eraseOp(dma); + return success(); + } + + // Non-padded: Global→LDS: copy(global_ptr, stride, row, col, lds_dst) + auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, src); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + auto srcOffsets = dma.getSrcOffsets(); + if (srcOffsets.size() >= 2) { + callArgs.push_back(srcOffsets[0]); + callArgs.push_back(srcOffsets[1]); + } else { + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + callArgs.push_back(zero); + callArgs.push_back(zero); + } + argTypes.push_back(indexTy); + argTypes.push_back(indexTy); + callArgs.push_back(ldsOffset); + argTypes.push_back(indexTy); + } else if (srcIsLDS && !dstIsLDS) { + // LDS→Global: copy(lds_src, global_ptr, stride, row, col) + callArgs.push_back(emitLDSOffset(rewriter, loc, src, convCtx)); + argTypes.push_back(indexTy); + auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, dst); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + auto dstOffsets = dma.getDstOffsets(); + if (dstOffsets.size() >= 2) { + callArgs.push_back(dstOffsets[0]); + callArgs.push_back(dstOffsets[1]); + } else { + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + callArgs.push_back(zero); + callArgs.push_back(zero); + } + argTypes.push_back(indexTy); + argTypes.push_back(indexTy); + } else { + // LDS→LDS: not supported. + return failure(); + } + + auto funcTy = rewriter.getFunctionType(argTypes, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, name, funcTy); + func::CallOp::create(rewriter, loc, name, TypeRange{}, callArgs); + rewriter.eraseOp(dma); + return success(); + } +}; + +/// Convert a linalg op (or memref.copy) with at least one promoted (LDS) +/// operand to a library call. 
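+/// Editor's illustrative example (operand types assumed): a copy such as
+///   memref.copy %glob_subview, %lds_buf : memref<16x64xf16, strided<...>> to memref<16x64xf16, 2>
+/// becomes
+///   func.call @copy_f16_16x64(%ptr, %byte_stride, %row_off, %col_off, %lds_off)
+/// where the global operand is decomposed into an !sx2 pointer, a byte
+/// stride, and tile offsets, and the LDS operand becomes a byte offset.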
+template +struct LinalgToLibraryCall : public OpRewritePattern { + ConversionContext &convCtx; + StringRef namePrefix; + + LinalgToLibraryCall(MLIRContext *ctx, ConversionContext &convCtx, + StringRef namePrefix) + : OpRewritePattern(ctx), convCtx(convCtx), + namePrefix(namePrefix) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + bool hasPromoted = false; + for (Value operand : op->getOperands()) + if (isPromotedBuffer(operand)) + hasPromoted = true; + if (!hasPromoted) + return failure(); + + MemRefType namingType; + for (Value operand : op->getOperands()) + if (auto mrTy = dyn_cast(operand.getType())) + if (!namingType) + namingType = mrTy; + if (!namingType) + return failure(); + + std::string name = buildFuncName(namePrefix, namingType); + Location loc = op->getLoc(); + auto indexTy = rewriter.getIndexType(); + auto sx2Ty = amdgcn::SGPRType::get(rewriter.getContext(), Register(), + /*size=*/2, /*alignment=*/2); + SmallVector callArgs; + SmallVector argTypes; + + for (Value operand : op->getOperands()) { + if (auto mrTy = dyn_cast(operand.getType())) { + if (isPromotedBuffer(operand)) { + callArgs.push_back( + emitLDSOffset(rewriter, loc, operand, convCtx)); + argTypes.push_back(indexTy); + } else { + Value baseMemref = operand; + SmallVector tileOffsets; + if (auto svOp = operand.getDefiningOp()) { + baseMemref = svOp.getSource(); + for (auto off : svOp.getMixedOffsets()) { + if (auto val = dyn_cast(off)) + tileOffsets.push_back(val); + else + tileOffsets.push_back(arith::ConstantIndexOp::create( + rewriter, loc, + cast(off.get()).getInt())); + } + } + auto [ptrVal, byteStride] = + decomposeGlobalMemref(rewriter, loc, baseMemref); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + if (tileOffsets.empty()) { + for (int64_t i = 0; i < mrTy.getRank(); ++i) { + callArgs.push_back( + arith::ConstantIndexOp::create(rewriter, loc, 0)); + argTypes.push_back(indexTy); + } + } else { + for (auto off : tileOffsets) { + callArgs.push_back(off); + argTypes.push_back(indexTy); + } + } + } + } else { + callArgs.push_back(operand); + argTypes.push_back(operand.getType()); + } + } + + auto funcTy = rewriter.getFunctionType(argTypes, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, name, funcTy); + func::CallOp::create(rewriter, loc, name, TypeRange{}, callArgs); + rewriter.eraseOp(op); + return success(); + } +}; + +/// Match linalg.generic with matmul semantics (2 inputs, 1 output, 1 reduction). +struct GenericMatmulToLibraryCall : public OpRewritePattern { + ConversionContext &convCtx; + + GenericMatmulToLibraryCall(MLIRContext *ctx, ConversionContext &convCtx) + : OpRewritePattern(ctx), convCtx(convCtx) {} + + LogicalResult matchAndRewrite(linalg::GenericOp op, + PatternRewriter &rewriter) const override { + if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1 || + op.getNumReductionLoops() != 1) + return failure(); + // Use _lds_c suffix when the output operand is in LDS. + StringRef prefix = "mfma_matmul"; + bool outputIsLDS = false; + for (Value out : op.getDpsInits()) + if (isPromotedBuffer(out)) + outputIsLDS = true; + std::string prefixStr = + outputIsLDS ? "mfma_matmul_lds_c" : "mfma_matmul"; + LinalgToLibraryCall inner( + rewriter.getContext(), convCtx, prefixStr); + return inner.matchAndRewrite(op, rewriter); + } +}; + +/// Erase linalg.fill on global (non-LDS) buffers. +/// The library's zero_C handles accumulator init. 
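+/// For example (editor's illustration, shapes from the padded matmul test):
+///   linalg.fill ins(%cst : f32) outs(%C : memref<48x48xf32>)
+/// on the global C buffer is erased rather than lowered to a call.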
+struct EraseGlobalFill : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::FillOp fill, + PatternRewriter &rewriter) const override { + for (Value out : fill.getDpsInits()) + if (isa(out.getType()) && !isPromotedBuffer(out)) { + rewriter.eraseOp(fill); + return success(); + } + return failure(); + } +}; + +// --------------------------------------------------------------------------- +// Analysis: detect numWavefronts from IR and pre-allocate LDS. +// --------------------------------------------------------------------------- + +static int64_t detectNumWavefronts(Operation *moduleOp) { + int64_t result = 1; + moduleOp->walk([&](arith::DivUIOp divOp) { + if (result > 1) + return; + auto threadId = divOp.getLhs().getDefiningOp(); + if (!threadId || threadId.getDimension() != gpu::Dimension::x) + return; + auto cst = divOp.getRhs().getDefiningOp(); + if (!cst || cst.value() != 64) + return; + for (Operation *user : divOp.getResult().getUsers()) { + auto applyOp = dyn_cast(user); + if (!applyOp || applyOp.getAffineMap().getNumResults() != 1) + continue; + unsigned wfPos = 0; + bool found = false; + for (auto [idx, operand] : llvm::enumerate(applyOp.getMapOperands())) { + if (operand == divOp.getResult()) { + wfPos = idx; + found = true; + break; + } + } + if (!found) + continue; + int64_t stride = 0; + AffineExpr expr = applyOp.getAffineMap().getResult(0); + expr.walk([&](AffineExpr e) { + auto mul = dyn_cast(e); + if (!mul || mul.getKind() != AffineExprKind::Mul) + return; + auto sym = dyn_cast(mul.getLHS()); + auto con = dyn_cast(mul.getRHS()); + if (!sym || !con) + return; + if (sym.getPosition() == wfPos) + stride = con.getValue(); + }); + if (stride <= 0) + continue; + moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma2) { + if (result > 1) + return; + Value src = dma2.getSrcMemref(); + auto srcTy = dyn_cast(src.getType()); + if (!srcTy || srcTy.getRank() < 1 || srcTy.getDimSize(0) <= 0) + return; + result = srcTy.getDimSize(0) / stride; + }); + } + }); + return result; +} + +// --------------------------------------------------------------------------- +// Pass +// --------------------------------------------------------------------------- + +struct ConvertToAMDGCNLibraryCalls + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToAMDGCNLibraryCalls) + StringRef getArgument() const override { + return "convert-to-amdgcn-library-calls"; + } + StringRef getDescription() const override { + return "Convert linalg/AIR ops to AMDGCN library calls"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + Operation *moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + // Find the declaration block (inside amdgcn.module if present). + Operation *declParent = moduleOp; + if (isa(moduleOp)) + moduleOp->walk([&](amdgcn::ModuleOp m) { declParent = m; }); + + ConversionContext convCtx; + convCtx.declBlock = &declParent->getRegion(0).front(); + convCtx.numWavefronts = detectNumWavefronts(moduleOp); + + // Apply conversion patterns. 
+ RewritePatternSet patterns(ctx); + patterns.add(ctx, convCtx); + patterns.add>(ctx, convCtx, "fill"); + patterns.add>(ctx, convCtx, "copy"); + patterns.add>(ctx, convCtx, "copy"); + patterns.add>(ctx, convCtx, + "mfma_matmul"); + patterns.add(ctx, convCtx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +namespace mlir::aster::mlir_air { +std::unique_ptr createConvertToAMDGCNLibraryCalls() { + return std::make_unique(); +} +} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index b3d11c0fd..6ba97744a 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -6,8 +6,45 @@ //===- Init.cpp - mlir-air dialect and pass registration ------------------===// +#include "air/Conversion/ConvertToAIRPass.h" +#include "air/Dialect/AIR/AIRDialect.h" +#include "air/Dialect/AIR/AIRTransformOps.h" +#include "air/Transform/AIRDmaToChannel.h" +#include "air/Transform/AIRMiscPasses.h" +#include "air/Transform/AIRSplitLaunchForPadding.h" + +// Tablegen-generated per-pass registration for upstream AIR passes. +namespace air_conv_reg { +#define GEN_PASS_REGISTRATION_COPYTODMA +#define GEN_PASS_REGISTRATION_PARALLELTOHERD +#define GEN_PASS_REGISTRATION_PARALLELTOLAUNCH +#define GEN_PASS_REGISTRATION_AIRWRAPFUNCWITHPARALLELPASS +#include "air/Conversion/Passes.h.inc" +} // namespace air_conv_reg + +namespace air_xform_reg { +#define GEN_PASS_REGISTRATION_DMATOCHANNEL +#define GEN_PASS_REGISTRATION_AIROVERRIDEMEMREFMEMORYSPACE +#define GEN_PASS_REGISTRATION_AIRSPLITLAUNCHFORPADDING +#include "air/Transform/Passes.h.inc" +} // namespace air_xform_reg + #include "aster/Init.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" @@ -20,15 +57,35 @@ namespace mlir::aster::mlir_air { -std::unique_ptr createConvertLinalgToAMDGCN(); +std::unique_ptr createAirToAMDGCN(); +std::unique_ptr createConvertToAMDGCNLibraryCalls(); +std::unique_ptr createConvertMemSpaceToAMDGCN(); void registerPipelines(); void registerAll(DialectRegistry ®istry) { + // AIR dialect. + registry.insert(); + // Dialects needed for linalg tiling + transform dialect. + registry.insert(); registry.insert(); + registry.insert(); registry.insert(); + // Bufferization interface models. 
+ arith::registerBufferizableOpInterfaceExternalModels(registry); + bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + linalg::registerBufferizableOpInterfaceExternalModels(registry); + linalg::registerSubsetOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + tensor::registerSubsetOpInterfaceExternalModels(registry); + tensor::registerValueBoundsOpInterfaceExternalModels(registry); + // Transform dialect extensions. + bufferization::registerTransformDialectExtension(registry); linalg::registerTransformDialectExtension(registry); scf::registerTransformDialectExtension(registry); @@ -36,11 +93,35 @@ void registerAll(DialectRegistry ®istry) { linalg::registerTilingInterfaceExternalModels(registry); // Upstream passes. + bufferization::registerBufferizationPasses(); registerLinalgPasses(); memref::registerMemRefPasses(); transform::registerInterpreterPass(); + transform::registerPreloadLibraryPass(); - registerPass([] { return createConvertLinalgToAMDGCN(); }); + // AIR transform ops extension (air.transform.*). + xilinx::air::registerTransformDialectExtension(registry); + + // Upstream doesn't declare airDialect as a dependent of the transform + // extension — add it so par_to_herd can create air.herd ops. + registry.addExtension( + +[](MLIRContext *ctx, transform::TransformDialect *dialect) { + ctx->getOrLoadDialect(); + }); + + // Upstream AIR passes (tablegen-generated registration). + air_conv_reg::registerCopyToDma(); // air-copy-to-dma + air_conv_reg::registerParallelToHerd(); // air-par-to-herd + air_conv_reg::registerParallelToLaunch(); // air-par-to-launch + air_conv_reg::registerAIRWrapFuncWithParallelPass(); // air-wrap-func-with-parallel + air_xform_reg::registerDmaToChannel(); // air-dma-to-channel + air_xform_reg::registerAIROverrideMemRefMemorySpace(); + air_xform_reg::registerAIRSplitLaunchForPadding(); // air-split-launch-for-padding (upstream, with single-launch GPU mode) + + // Aster-specific passes. + registerPass([] { return createAirToAMDGCN(); }); + registerPass([] { return createConvertToAMDGCNLibraryCalls(); }); + registerPass([] { return createConvertMemSpaceToAMDGCN(); }); // mlir-air pipelines. registerPipelines(); @@ -50,4 +131,5 @@ void registerAll(DialectRegistry ®istry) { // are available in aster-opt and the Python bindings when linked. static int _register = (mlir::aster::registerContribDialects(registerAll), 0); + } // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir new file mode 100644 index 000000000..c7970ef64 --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir @@ -0,0 +1,94 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Transform sequence for padded matmul with device-side padding. +// A/B are 40x64 (actual), NOT padded. C is 48x48 (padded for safe stores). +// +// The matmul is 40x40x64. 
tile_sizes [16, 16, 0] creates a 3x3 grid: +// - Interior tiles (2x2): full 16x16 output, full 16x64 A/B loads +// - M-boundary (row 2): 8xN output, 8x64 A loads +// - N-boundary (col 2): Mx8 output, 8x64 B loads +// - Corner (2,2): 8x8 output, 8x64 A and B loads +// +// After air-split-launch-for-padding, boundary DMAs get pad_after attribute. +// ConvertToAMDGCNLibraryCalls emits fill_f16_16x64 + copy_f16_16x64_padded +// for boundary tiles. + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + + // Tile 40x40 matmul into 3x3 grid. Boundary tiles have 8 rows/cols. + %outer_tiled, %outer_forall = + transform.structured.tile_using_forall %matmul + tile_sizes [16, 16, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Pad A,B for LDS promotion (16x64 tiles; K untiled). + // nofold ensures ALL tiles (including full interior tiles) go through + // the pad→alloc→copy path, so all get air.dma_memcpy_nd ops. + %padded, %pad, %copy_back = transform.structured.pad %outer_tiled { + padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], + padding_dimensions = [0, 1, 2], + pack_paddings = [1, 1, 0], + nofold_flags = [1, 1, 0] + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad + : (!transform.any_op) -> !transform.any_op + + // Promote padded A,B to LDS (memory_space = 2 = shared/LDS). + %padded_lhs = transform.get_producer_of_operand %padded[0] + : (!transform.any_op) -> (!transform.any_op) + %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + %padded_rhs = transform.get_producer_of_operand %padded[1] + : (!transform.any_op) -> (!transform.any_op) + %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + + // Canonicalize. + %func_0 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_0 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_0 : !transform.any_op + + // Bufferize. + %func_1 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %func_buf = transform.bufferization.one_shot_bufferize %func_1 { + allow_return_allocs_from_loops = true + } : (!transform.any_op) -> !transform.any_op + + // Cleanup. + %func_2 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_2 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_2 : !transform.any_op + %func_3 = transform.air.remove_uninitialized_copy %func_2 + : (!transform.any_op) -> !transform.any_op + + // forall → parallel → launch (M and N both in launch). 
+ %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %parallel = transform.loop.forall_to_parallel %forall_2 + : (!transform.any_op) -> !transform.any_op + %launch = transform.air.par_to_launch %parallel + : (!transform.any_op) -> !transform.any_op + + transform.yield + } +} diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir new file mode 100644 index 000000000..71ec3cac0 --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir @@ -0,0 +1,236 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Padded matmul: actual M=40, N=40, K=64. +// Device-side padding via air-split-launch-for-padding. +// +// The kernel has air.launch (3x3 grid) written directly — no transform, +// no air-wrap-func-with-parallel. Each block computes one 16x16 output tile. +// +// A, B: actual 40x64, C: 48x48 (padded for safe boundary stores). +// Boundary tiles (last row/col) copy fewer rows from global, with LDS +// zero-filled first. split-launch adds pad_after on boundary DMAs. +// ConvertToAMDGCNLibraryCalls emits fill + copy_padded for those. +// +// Host extracts valid C[0:40, 0:40] after kernel. + +!sx2 = !amdgcn.sgpr<[? + 2]> +!vx2 = !amdgcn.vgpr<[? + 2]> +!ax4 = !amdgcn.agpr<[? + 4]> +!lds_write_token = !amdgcn.write_token +!future_lds_read = !aster_utils.struct> +!future_global_read = !aster_utils.struct> + +module { + amdgcn.library @linalg_lib isa = [#amdgcn.isa] { + func.func private @zero_C() -> !ax4 + func.func private @mfma_f32_16x16x16_f16(!vx2, !vx2, !ax4) -> !ax4 + func.func private @store_global_C_mfma_f32_16x16x16_f16( + !ax4, !aster_utils.any, index, index, index) + func.func private @prepare_ptr(!sx2) -> !aster_utils.any + func.func private @load_global_tile_16x64_b( + !aster_utils.any, index, index, index) -> !future_global_read + func.func private @store_global_tile_to_lds_16x64_b( + index, !future_global_read) -> (!lds_write_token, !lds_write_token) + func.func private @load_lds_A_swizzled( + index, index, index) -> !future_lds_read + func.func private @load_lds_B_swizzled( + index, index, index) -> !future_lds_read + func.func private @get_lds_read_value_vx2(!future_lds_read) -> !vx2 + func.func private @thread_tile_pos_16x64_b() -> (index, index) + func.func private @tiled_row_byte_off(index, index, index, index, index, index) -> index + func.func private @load_global_at_byte_off(!aster_utils.any, index) -> !future_global_read + func.func private @fill_lds_16x64_b(index) + + func.func private @copy_f16_16x64( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %lds_dst: index) { + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %col1 = arith.addi %col_offset, %c32 : index + %lds_dst1 = arith.addi %lds_dst, %c1024 : index + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %gfut0 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col_offset, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut0) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + %gfut1 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col1, %src_stride) + : 
(!aster_utils.any, index, index, index) -> !future_global_read + %t2, %t3 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst1, %gfut1) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + amdgcn.wait deps %t2 : !lds_write_token + amdgcn.wait deps %t3 : !lds_write_token + return + } + + func.func private @copy_f16_16x64_padded( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %actual_rows: index, + %lds_dst: index) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %my_row, %my_col_byte = func.call @thread_tile_pos_16x64_b() + : () -> (index, index) + %max_row = arith.subi %actual_rows, %c1 : index + %clamped_row = arith.minui %my_row, %max_row : index + %byte_off0 = func.call @tiled_row_byte_off( + %row_offset, %clamped_row, %col_offset, %my_col_byte, %src_stride, %c2) + : (index, index, index, index, index, index) -> index + %gfut0 = func.call @load_global_at_byte_off(%ptr, %byte_off0) + : (!aster_utils.any, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut0) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + %col1 = arith.addi %col_offset, %c32 : index + %lds_dst1 = arith.addi %lds_dst, %c1024 : index + %byte_off1 = func.call @tiled_row_byte_off( + %row_offset, %clamped_row, %col1, %my_col_byte, %src_stride, %c2) + : (index, index, index, index, index, index) -> index + %gfut1 = func.call @load_global_at_byte_off(%ptr, %byte_off1) + : (!aster_utils.any, index) -> !future_global_read + %t2, %t3 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst1, %gfut1) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + amdgcn.wait deps %t2 : !lds_write_token + amdgcn.wait deps %t3 : !lds_write_token + return + } + + func.func private @mfma_matmul_f16_16x64( + %lds_A: index, %lds_B: index, + %C_ptr: !sx2, %C_stride: index, + %C_row_offset: index, %C_col_offset: index) { + %C_prepared = func.call @prepare_ptr(%C_ptr) : (!sx2) -> !aster_utils.any + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %lds_A2 = arith.addi %lds_A, %c1024 : index + %lds_B2 = arith.addi %lds_B, %c1024 : index + %acc = func.call @zero_C() : () -> !ax4 + %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 + %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 + %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) + : (!vx2, !vx2, !ax4) -> !ax4 + %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 + %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 + %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) + : (!vx2, !vx2, !ax4) -> !ax4 + %A2f = 
func.call @load_lds_A_swizzled(%lds_A2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A2 = func.call @get_lds_read_value_vx2(%A2f) : (!future_lds_read) -> !vx2 + %B2f = func.call @load_lds_B_swizzled(%lds_B2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B2 = func.call @get_lds_read_value_vx2(%B2f) : (!future_lds_read) -> !vx2 + %acc2 = func.call @mfma_f32_16x16x16_f16(%A2, %B2, %acc1) + : (!vx2, !vx2, !ax4) -> !ax4 + %A3f = func.call @load_lds_A_swizzled(%lds_A2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A3 = func.call @get_lds_read_value_vx2(%A3f) : (!future_lds_read) -> !vx2 + %B3f = func.call @load_lds_B_swizzled(%lds_B2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B3 = func.call @get_lds_read_value_vx2(%B3f) : (!future_lds_read) -> !vx2 + %acc3 = func.call @mfma_f32_16x16x16_f16(%A3, %B3, %acc2) + : (!vx2, !vx2, !ax4) -> !ax4 + func.call @store_global_C_mfma_f32_16x16x16_f16( + %acc3, %C_prepared, %C_row_offset, %C_col_offset, %C_stride) + : (!ax4, !aster_utils.any, index, index, index) -> () + return + } + + func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { + func.call @fill_lds_16x64_b(%lds_dst) : (index) -> () + return + } + } + + amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { + func.func @matmul_f16_40x40( + %A: memref<40x64xf16>, %B: memref<40x64xf16>, %C: memref<48x48xf32>) + attributes {gpu.kernel} { + %c3 = arith.constant 3 : index + + air.launch (%m_id, %n_id) in (%m_sz=%c3, %n_sz=%c3) + args(%a=%A, %b=%B, %c=%C) + : memref<40x64xf16>, memref<40x64xf16>, memref<48x48xf32> + attributes {air.actual_sizes = array} { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %cst_f16 = arith.constant 0.000000e+00 : f16 + + // Tile offsets from launch block indices. + %m_off = arith.muli %m_id, %c16 : index + %n_off = arith.muli %n_id, %c16 : index + + // A tile: copy min(16, 40-m_off) rows from global to LDS. + %m_rem = arith.subi %c40, %m_off : index + %m_size = arith.minui %c16, %m_rem : index + %lds_a = memref.alloc() : memref<16x64xf16, 2> + linalg.fill ins(%cst_f16 : f16) outs(%lds_a : memref<16x64xf16, 2>) + %a_sub = memref.subview %a[%m_off, 0] [%m_size, 64] [1, 1] + : memref<40x64xf16> to memref> + %lds_a_sub = memref.subview %lds_a[0, 0] [%m_size, 64] [1, 1] + : memref<16x64xf16, 2> to memref, 2> + memref.copy %a_sub, %lds_a_sub + : memref> + to memref, 2> + + // B tile: copy min(16, 40-n_off) rows from global to LDS. + %n_rem = arith.subi %c40, %n_off : index + %n_size = arith.minui %c16, %n_rem : index + %lds_b = memref.alloc() : memref<16x64xf16, 2> + linalg.fill ins(%cst_f16 : f16) outs(%lds_b : memref<16x64xf16, 2>) + %b_sub = memref.subview %b[%n_off, 0] [%n_size, 64] [1, 1] + : memref<40x64xf16> to memref> + %lds_b_sub = memref.subview %lds_b[0, 0] [%n_size, 64] [1, 1] + : memref<16x64xf16, 2> to memref, 2> + memref.copy %b_sub, %lds_b_sub + : memref> + to memref, 2> + + // Matmul on full 16x64 LDS tiles → 16x16 output written to C. + // C is 48x48 so writing 16x16 at any (m_off, n_off) is safe. 
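+        // (m_off and n_off are in {0, 16, 32} for the 3x3 launch, and 32 + 16 = 48,
+        // so the 16x16 subview below always stays inside the padded 48x48 C.)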
+ %c_sub = memref.subview %c[%m_off, %n_off] [16, 16] [1, 1] + : memref<48x48xf32> to memref<16x16xf32, strided<[48, 1], offset: ?>> + linalg.generic { + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lds_a, %lds_b : memref<16x64xf16, 2>, memref<16x64xf16, 2>) + outs(%c_sub : memref<16x16xf32, strided<[48, 1], offset: ?>>) { + ^bb0(%av: f16, %bv: f16, %cv: f32): + %a_ext = arith.extf %av : f16 to f32 + %b_ext = arith.extf %bv : f16 to f32 + %prod = arith.mulf %a_ext, %b_ext : f32 + %sum = arith.addf %cv, %prod : f32 + linalg.yield %sum : f32 + } + } + return + } + } +} diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir new file mode 100644 index 000000000..01a7bd0c0 --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir @@ -0,0 +1,94 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Transform sequence for 64x64 matmul: tile, pad, bufferize, map to AIR herd. +// Adapted from xrt/12 (tile-using-pad, no packing). + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + + // Outer tiling: 2x1 forall on M (becomes 2-wavefront air.herd). + // 64/32 = 2 iterations — non-trivial, survives canonicalization. + %outer_tiled, %outer_forall = + transform.structured.tile_using_forall %matmul + tile_sizes [32, 0, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Compute tiling inside forall: 16x16 M×N tiles, no K tiling. + // mfma_matmul_f16_16x32 calls zero_C internally so K must not be tiled + // across loop iterations (each call would reset the accumulator). + // The library internally handles K by reading two 16x16 panels from LDS. + %tiled, %lm, %ln = transform.structured.tile_using_for %outer_tiled + tile_sizes [16, 16, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + + // Pad A and B only (not C — the matmul stores C directly to global). + %padded, %pad, %copy_back = transform.structured.pad %tiled { + padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], + padding_dimensions = [0, 1, 2], + pack_paddings = [1, 1, 0], + nofold_flags = [1, 1, 0], + copy_back_op = "linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad + : (!transform.any_op) -> !transform.any_op + + // Promote A,B pads to L1 (memory_space=2) via bufferize_to_allocation. + %padded_lhs = transform.get_producer_of_operand %padded[0] + : (!transform.any_op) -> (!transform.any_op) + %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + + %padded_rhs = transform.get_producer_of_operand %padded[1] + : (!transform.any_op) -> (!transform.any_op) + %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + + // Canonicalize. 
+ %func_0 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_0 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_0 : !transform.any_op + + // One-shot bufferize. + %func_1 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %func_buf = transform.bufferization.one_shot_bufferize %func_1 { + allow_return_allocs_from_loops = true + } : (!transform.any_op) -> !transform.any_op + + // Cleanup. + %func_2 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_2 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_2 : !transform.any_op + %func_3 = transform.air.remove_uninitialized_copy %func_2 + : (!transform.any_op) -> !transform.any_op + + // Convert outer forall → parallel → air.herd (now on memrefs). + %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %parallel = transform.loop.forall_to_parallel %forall_2 + : (!transform.any_op) -> !transform.any_op + %herd = transform.air.par_to_herd %parallel + : (!transform.any_op) -> !transform.any_op + + transform.yield + } +} diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir new file mode 100644 index 000000000..40d51c5ce --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir @@ -0,0 +1,233 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// RUN: mlir-air-opt %s \ +// RUN: --transform-preload-library="transform-library-paths=%p/air-to-amdgcn-matmul-transform.mlir" \ +// RUN: --transform-interpreter --canonicalize --cse \ +// RUN: --air-par-to-launch="has-air-segment=true" --canonicalize --cse \ +// RUN: --air-copy-to-dma \ +// RUN: --air-to-amdgcn --canonicalize \ +// RUN: --convert-memspace-to-amdgcn \ +// RUN: --convert-to-amdgcn-library-calls \ +// RUN: --amdgcn-preload-library="library-paths=%p/../../../mlir_kernels/library/common/register-init.mlir,%p/../../../mlir_kernels/library/common/indexing.mlir,%p/../../../mlir_kernels/library/common/indexing_ptr.mlir,%p/../../../mlir_kernels/library/common/futures.mlir,%p/../../../contrib/kittens/library/compute_16x16_f16.mlir,%p/../../../contrib/kittens/library/global_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_mfma_16x64_b.mlir" \ +// RUN: --inline --symbol-dce --canonicalize \ +// RUN: --mlir-air-to-asm \ +// RUN: | aster-translate --mlir-to-asm \ +// RUN: | FileCheck %s + +// CHECK-LABEL: matmul_f16_64x64: +// CHECK: global_load_dwordx4 +// CHECK: ds_write_b64 +// CHECK: ds_read_b64 +// CHECK: v_mfma_f32_16x16x16_f16 +// CHECK: global_store_dword +// CHECK: s_endpgm + +// Real AIR pipeline (adapted from xrt/12, tile-using-pad path): +// +// 1. linalg.generic on tensors (64x64 matmul, no AIR ops) +// 2. transform: tile_using_forall (2x1 herd) → tile_using_for (compute) +// → pad → bufferize_to_allocation (L1) → one_shot_bufferize +// → forall_to_parallel → par_to_herd +// 3. 
air-copy-to-dma → air-dma-to-channel +// 4. air-to-amdgcn (herd → wavefront index) +// 5. convert-memspace → convert-linalg → preload → asm + +!sx2 = !amdgcn.sgpr<[? + 2]> +!vx2 = !amdgcn.vgpr<[? + 2]> +!ax4 = !amdgcn.agpr<[? + 4]> +!lds_write_token = !amdgcn.write_token +!future_lds_read = !aster_utils.struct> +!future_global_read = !aster_utils.struct> + +module { + amdgcn.library @linalg_lib isa = [#amdgcn.isa] { + func.func private @zero_C() -> !ax4 + func.func private @mfma_f32_16x16x16_f16(!vx2, !vx2, !ax4) -> !ax4 + func.func private @store_global_C_mfma_f32_16x16x16_f16( + !ax4, !aster_utils.any, index, index, index) + func.func private @prepare_ptr(!sx2) -> !aster_utils.any + func.func private @load_global_tile_16x64_b( + !aster_utils.any, index, index, index) -> !future_global_read + func.func private @store_global_tile_to_lds_16x64_b( + index, !future_global_read) -> (!lds_write_token, !lds_write_token) + func.func private @load_lds_A_swizzled( + index, index, index) -> !future_lds_read + func.func private @load_lds_B_swizzled( + index, index, index) -> !future_lds_read + func.func private @get_lds_read_value_vx2(!future_lds_read) -> !vx2 + + func.func private @copy_f16_16x32( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %lds_dst: index) { + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %gfut = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col_offset, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + return + } + + func.func private @mfma_matmul_f16_16x32( + %lds_A: index, %lds_B: index, + %C_ptr: !sx2, %C_stride: index, + %C_row_offset: index, %C_col_offset: index) { + %C_prepared = func.call @prepare_ptr(%C_ptr) : (!sx2) -> !aster_utils.any + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %acc = func.call @zero_C() : () -> !ax4 + %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 + %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 + %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) + : (!vx2, !vx2, !ax4) -> !ax4 + %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 + %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 + %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) + : (!vx2, !vx2, !ax4) -> !ax4 + func.call @store_global_C_mfma_f32_16x16x16_f16( + %acc1, %C_prepared, %C_row_offset, %C_col_offset, %C_stride) + : (!ax4, !aster_utils.any, index, index, index) -> () + return + } + + func.func private @fill_f16_16x32(%val: f16, %lds_dst: index) { return } + + // Copy a 16x64 f16 tile from global memory to LDS. + // Loads two consecutive 16x32-element (= 16x64-byte) panels at col and col+32. 
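+    // Each panel occupies 16 rows x 32 f16 x 2 bytes = 1024 bytes of LDS, which is
+    // why the second panel is written at lds_dst + 1024.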
+ func.func private @copy_f16_16x64( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %lds_dst: index) { + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %col1 = arith.addi %col_offset, %c32 : index + %lds_dst1 = arith.addi %lds_dst, %c1024 : index + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %gfut0 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col_offset, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut0) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + %gfut1 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col1, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t2, %t3 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst1, %gfut1) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + amdgcn.wait deps %t2 : !lds_write_token + amdgcn.wait deps %t3 : !lds_write_token + return + } + + // 16x16 output matmul using two 16x32 A tiles and two 16x32 B tiles in LDS. + // lds_A points to the first 16x32 block; lds_A2 = lds_A + 1024 is the second. + // Similarly for lds_B. Reads 4 panels at k_byte_offsets 0,32 within each block. + func.func private @mfma_matmul_f16_16x64( + %lds_A: index, %lds_B: index, + %C_ptr: !sx2, %C_stride: index, + %C_row_offset: index, %C_col_offset: index) { + %C_prepared = func.call @prepare_ptr(%C_ptr) : (!sx2) -> !aster_utils.any + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + // Second half of K: LDS base + 1024 bytes. + %lds_A2 = arith.addi %lds_A, %c1024 : index + %lds_B2 = arith.addi %lds_B, %c1024 : index + %acc = func.call @zero_C() : () -> !ax4 + // K panel 0 (k=0..15): from first block, offset 0. + %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 + %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 + %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) + : (!vx2, !vx2, !ax4) -> !ax4 + // K panel 1 (k=16..31): from first block, offset 32. + %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 + %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 + %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) + : (!vx2, !vx2, !ax4) -> !ax4 + // K panel 2 (k=32..47): from second block, offset 0. 
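+      // (Each K panel covers 16 f16 values = 32 bytes, hence the 0 and 32 byte
+      // offsets passed to the swizzled LDS loads.)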
+ %A2f = func.call @load_lds_A_swizzled(%lds_A2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A2 = func.call @get_lds_read_value_vx2(%A2f) : (!future_lds_read) -> !vx2 + %B2f = func.call @load_lds_B_swizzled(%lds_B2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B2 = func.call @get_lds_read_value_vx2(%B2f) : (!future_lds_read) -> !vx2 + %acc2 = func.call @mfma_f32_16x16x16_f16(%A2, %B2, %acc1) + : (!vx2, !vx2, !ax4) -> !ax4 + // K panel 3 (k=48..63): from second block, offset 32. + %A3f = func.call @load_lds_A_swizzled(%lds_A2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A3 = func.call @get_lds_read_value_vx2(%A3f) : (!future_lds_read) -> !vx2 + %B3f = func.call @load_lds_B_swizzled(%lds_B2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B3 = func.call @get_lds_read_value_vx2(%B3f) : (!future_lds_read) -> !vx2 + %acc3 = func.call @mfma_f32_16x16x16_f16(%A3, %B3, %acc2) + : (!vx2, !vx2, !ax4) -> !ax4 + func.call @store_global_C_mfma_f32_16x16x16_f16( + %acc3, %C_prepared, %C_row_offset, %C_col_offset, %C_stride) + : (!ax4, !aster_utils.any, index, index, index) -> () + return + } + + func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { return } + } + + amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { + // 64x64 tensor-based matmul. No AIR ops — transform generates hierarchy. + func.func @matmul_f16_64x64( + %A: memref<64x64xf16>, %B: memref<64x64xf16>, %C: memref<64x64xf32>) + attributes {gpu.kernel} { + %cst = arith.constant 0.000000e+00 : f32 + %a = bufferization.to_tensor %A restrict writable : memref<64x64xf16> to tensor<64x64xf16> + %b = bufferization.to_tensor %B restrict writable : memref<64x64xf16> to tensor<64x64xf16> + %c = bufferization.to_tensor %C restrict writable : memref<64x64xf32> to tensor<64x64xf32> + %fill = linalg.fill ins(%cst : f32) outs(%c : tensor<64x64xf32>) -> tensor<64x64xf32> + // matmul_transpose_b: C[m,n] += A[m,k] * B[n,k] + %result = linalg.generic { + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%a, %b : tensor<64x64xf16>, tensor<64x64xf16>) + outs(%fill : tensor<64x64xf32>) { + ^bb0(%av: f16, %bv: f16, %cv: f32): + %a_ext = arith.extf %av : f16 to f32 + %b_ext = arith.extf %bv : f16 to f32 + %prod = arith.mulf %a_ext, %b_ext : f32 + %sum = arith.addf %cv, %prod : f32 + linalg.yield %sum : f32 + } -> tensor<64x64xf32> + bufferization.materialize_in_destination %result in writable %C + : (tensor<64x64xf32>, memref<64x64xf32>) -> () + return + } + } + +} diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py new file mode 100644 index 000000000..5d8fdd016 --- /dev/null +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -0,0 +1,223 @@ +# Copyright 2026 The ASTER Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""E2E matmul test exercising the real AIR lowering path. 
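+Covers an aligned 64x64 matmul (transform-scripted tiling, 2-wavefront herd) and a
+40x40x64 matmul that uses device-side padding via air-split-launch-for-padding.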
+ +Pipeline: + mlir-air-opt (preprocess): + --transform-interpreter + --air-par-to-herd (forall → herd, each tile = 1 wavefront) + --one-shot-bufferize + --air-par-to-launch (outer parallel → launch) + --air-copy-to-dma (memref.copy → air.dma_memcpy_nd) + --air-to-amdgcn (flatten hierarchy, herd → wavefront index) + --convert-memspace-to-amdgcn (integer memspace → #amdgcn.addr_space) + --convert-to-amdgcn-library-calls (air.dma_memcpy_nd + linalg ops → library calls) + then aster pipeline: + --preload → inline → mlir-air-to-asm +""" + +import os +import shutil +import subprocess + +import numpy as np +import pytest + +from aster.execution.helpers import compile_and_run + +MCPU = "gfx942" + +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_MLIR_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul.mlir") +_TRANSFORM_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul-transform.mlir") +_PADDED_MLIR_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul-padded.mlir") +_PADDED_TRANSFORM_FILE = os.path.join( + _THIS_DIR, "..", "air-to-amdgcn-matmul-padded-transform.mlir" +) +_LIBRARY_DIR = os.path.join( + _THIS_DIR, "..", "..", "..", "..", "mlir_kernels", "library" +) +_KITTENS_DIR = os.path.join( + _THIS_DIR, "..", "..", "..", "..", "contrib", "kittens", "library" +) + +_LIBRARY_PATHS = [ + os.path.join(_LIBRARY_DIR, "common", f) + for f in [ + "register-init.mlir", + "indexing.mlir", + "indexing_ptr.mlir", + "futures.mlir", + ] +] + [ + os.path.join(_KITTENS_DIR, f) + for f in [ + "compute_16x16_f16.mlir", + "global_16x64_b.mlir", + "lds_16x64_b.mlir", + "lds_mfma_16x64_b.mlir", + ] +] + + +def _find_mlir_air_opt(): + """Find the mlir-air-opt binary.""" + build_path = os.path.join( + _THIS_DIR, "..", "..", "..", "..", "build", "bin", "mlir-air-opt" + ) + if os.path.isfile(build_path): + return os.path.abspath(build_path) + path = shutil.which("mlir-air-opt") + if path: + return path + pytest.skip("mlir-air-opt not found") + + +def _air_preprocess_with_files(mlir_text, transform_file): + """Run the full AIR lowering pipeline before handing to aster.""" + opt = _find_mlir_air_opt() + result = subprocess.run( + [ + opt, + f"--transform-preload-library=transform-library-paths={transform_file}", + "--transform-interpreter", + "--air-par-to-herd", + "--canonicalize", "--cse", + "--one-shot-bufferize", + "--canonicalize", "--cse", + "--air-par-to-launch=has-air-segment=true", + "--canonicalize", "--cse", + "--air-copy-to-dma", + "--air-to-amdgcn", + "--canonicalize", + "--convert-memspace-to-amdgcn", + "--convert-to-amdgcn-library-calls", + ], + input=mlir_text, + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError( + f"mlir-air-opt AIR preprocessing failed:\n{result.stderr}") + return result.stdout + + +def _air_preprocess(mlir_text): + return _air_preprocess_with_files(mlir_text, _TRANSFORM_FILE) + + +def _post_air_pipeline(library_paths): + libs = ",".join(library_paths) + return ( + "builtin.module(" + "canonicalize," + f"amdgcn-preload-library{{library-paths={libs}}}," + "inline, symbol-dce, canonicalize," + "mlir-air-to-asm)" + ) + + +class TestAirMatmulE2E: + + def test_matmul_64x64(self): + M, N, K = 64, 64, 64 + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np.float16) + B_KxN = (np.random.randn(K, N) * 0.1).astype(np.float16) + B_T = np.ascontiguousarray(B_KxN.T) + # C must be zero-initialized (fill is erased by convert-to-amdgcn-library-calls; + # the library's zero_C handles accumulator init per tile). 
+ C = np.zeros(M * N, dtype=np.float32) + + compile_and_run( + file_name=_MLIR_FILE, + kernel_name="matmul_f16_64x64", + input_data=[A.flatten(), B_T.flatten()], + output_data=[C], + pass_pipeline=_post_air_pipeline(_LIBRARY_PATHS), + library_paths=[], + grid_dim=(1, 1, 1), + block_dim=(128, 1, 1), # 2 wavefronts (2x1 herd) + preprocess=_air_preprocess, + ) + + expected = (A.astype(np.float32) @ B_KxN.astype(np.float32)).flatten() + np.testing.assert_allclose(C, expected, rtol=1e-2, atol=1e-2) + + def test_matmul_padded_40x40(self): + """Matmul with non-tile-aligned dimensions: actual 40x40x64. + + Device-side padding via air-split-launch-for-padding: + - A, B are actual 40x64 (NO host-side padding) + - C is 48x48 (padded to next multiple of tile=16 for safe boundary stores) + - The split-launch pass produces a single 3x3 launch with scf.if + on block indices to select between interior and boundary bodies + - Boundary tiles use fill_f16_16x64 (zero-fill LDS) + + copy_f16_16x64_padded (clamped global loads via arith.minui) + - Valid result is in C[0:40, 0:40] + + Grid: 3x3 blocks (ceil(40/16)=3 per dim), one wavefront per block. + """ + M, N, K = 40, 40, 64 + M_pad, N_pad = 48, 48 # next multiple of tile size 16 (for C only) + + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np.float16) + B_T = (np.random.randn(N, K) * 0.1).astype(np.float16) + + # C is padded to 48x48 so boundary 16x16 stores don't go OOB. + C_pad = np.zeros(M_pad * N_pad, dtype=np.float32) + + def padded_preprocess(mlir_text): + opt = _find_mlir_air_opt() + dump_dir = "/mnt/m2m_nobackup/erweiw/aster/ir_dumps" + os.makedirs(dump_dir, exist_ok=True) + + result = subprocess.run( + [ + opt, + "--air-copy-to-dma", + "--air-split-launch-for-padding=split-mode=single-launch pad-location=source", + "--canonicalize", "--cse", + "--air-to-amdgcn", + "--canonicalize", + "--convert-memspace-to-amdgcn", + "--convert-to-amdgcn-library-calls", + ], + input=mlir_text, + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError( + f"mlir-air-opt padded preprocessing failed:\n{result.stderr}") + return result.stdout + + from aster.execution.core import InOutArray + + compile_and_run( + file_name=_PADDED_MLIR_FILE, + kernel_name="matmul_f16_40x40", + # Kernel args: (A: 40x64, B: 40x64, C: 48x48) + # No workspace buffers — single-tile kernel, no promoted allocs. + input_data=[ + A.flatten(), # arg0: A 40x64 (read-only) + B_T.flatten(), # arg1: B 40x64 (read-only) + InOutArray(C_pad), # arg2: C 48x48 (read-write) + ], + output_data=[], + pass_pipeline=_post_air_pipeline(_LIBRARY_PATHS), + library_paths=[], + grid_dim=(3, 3, 1), # ceil(40/16)=3 per dim, 9 blocks total + block_dim=(64, 1, 1), # 1 wavefront per block + preprocess=padded_preprocess, + ) + + # Extract valid 40x40 region from padded 48x48 output. 
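+        # (Rows/cols 40..47 of C_pad are produced by the padding lanes of the
+        # boundary tiles and are deliberately excluded from the check.)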
+ C_2d = C_pad.reshape(M_pad, N_pad) + C_valid = C_2d[:M, :N].flatten() + expected = (A.astype(np.float32) @ B_T.T.astype(np.float32)).flatten() + np.testing.assert_allclose(C_valid, expected, rtol=1e-2, atol=1e-2) diff --git a/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py b/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py index 8bd7c1fc6..f31c3bb41 100644 --- a/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py @@ -40,7 +40,7 @@ def _mlir_air_pipeline(library_paths): return ( "builtin.module(" "transform-interpreter, canonicalize," - "convert-linalg-to-amdgcn," + "convert-to-amdgcn-library-calls," f"amdgcn-preload-library{{library-paths={libs}}}," "inline, symbol-dce, canonicalize," "mlir-air-to-asm)" diff --git a/contrib/mlir-air/test/linalg-to-amdgcn.mlir b/contrib/mlir-air/test/linalg-to-amdgcn.mlir index c1fe93e48..266b4238b 100644 --- a/contrib/mlir-air/test/linalg-to-amdgcn.mlir +++ b/contrib/mlir-air/test/linalg-to-amdgcn.mlir @@ -1,6 +1,6 @@ // RUN: mlir-air-opt %s \ // RUN: --transform-interpreter --canonicalize \ -// RUN: --convert-linalg-to-amdgcn \ +// RUN: --convert-to-amdgcn-library-calls \ // RUN: --amdgcn-preload-library="library-paths=%p/../../../mlir_kernels/library/common/register-init.mlir,%p/../../../mlir_kernels/library/common/indexing.mlir,%p/../../../mlir_kernels/library/common/indexing_ptr.mlir,%p/../../../mlir_kernels/library/common/futures.mlir,%p/../../../contrib/kittens/library/compute_16x16_f16.mlir,%p/../../../contrib/kittens/library/global_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_mfma_16x64_b.mlir" \ // RUN: --inline --symbol-dce --canonicalize \ // RUN: --mlir-air-to-asm \ diff --git a/contrib/mlir-air/tools/mlir-air-opt.cpp b/contrib/mlir-air/tools/mlir-air-opt.cpp index 874bdc2b5..6fd335091 100644 --- a/contrib/mlir-air/tools/mlir-air-opt.cpp +++ b/contrib/mlir-air/tools/mlir-air-opt.cpp @@ -8,11 +8,12 @@ #include "aster/Init.h" #include "mlir/IR/Dialect.h" +#include "mlir/Pass/PassRegistry.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" namespace mlir::aster::mlir_air { void registerAll(DialectRegistry ®istry); -} +} // namespace mlir::aster::mlir_air int main(int argc, char **argv) { mlir::DialectRegistry registry; @@ -23,7 +24,7 @@ int main(int argc, char **argv) { mlir::aster::registerUpstreamMLIRExternalModels(registry); mlir::aster::initDialects(registry); mlir::aster::registerPasses(); - // mlir-air additions. + // mlir-air additions (registers dialects, passes, pipelines). mlir::aster::mlir_air::registerAll(registry); return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "mlir-air optimizer driver\n", registry)); diff --git a/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp b/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp index b46f9111a..13982acc8 100644 --- a/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp +++ b/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp @@ -138,7 +138,10 @@ bool PreloadLibrary::processModule( for (auto funcOp : module.getOps()) { if (funcOp.isDeclaration()) { StringRef name = funcOp.getSymName(); - if (libraryFunctions.contains(name)) { + auto it = libraryFunctions.find(name); + // Only replace if the library has a definition (not another declaration). 
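+      // (A func.func declaration has an empty body region, so a non-empty first
+      // region marks a genuine definition in the library.)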
+ if (it != libraryFunctions.end() && + !it->second->getRegions().front().empty()) { declarationsToReplace.push_back(funcOp); } } diff --git a/mlir_kernels/library/common/indexing_ptr.mlir b/mlir_kernels/library/common/indexing_ptr.mlir index e800b9043..1e620c1d9 100644 --- a/mlir_kernels/library/common/indexing_ptr.mlir +++ b/mlir_kernels/library/common/indexing_ptr.mlir @@ -69,4 +69,14 @@ amdgcn.library @common_indexing_ptr { : (!sx2, !s, i32) -> !sx4 return %rsrc : !sx4 } + + // Bounded variant: set num_records to the actual buffer size in bytes. + // OOB loads return zero, OOB stores are silently dropped. + func.func private @make_raw_buffer_rsrc_bounded(%base: !sx2, %num_bytes: !s) -> !sx4 { + %c0_stride = arith.constant 0 : i32 + %rsrc = amdgcn.make_buffer_rsrc %base, %num_bytes, %c0_stride, + cache_swizzle = false, swizzle_enable = false, flags = 131072 + : (!sx2, !s, i32) -> !sx4 + return %rsrc : !sx4 + } } diff --git a/third_party/mlir-air b/third_party/mlir-air new file mode 160000 index 000000000..39ee8fc6c --- /dev/null +++ b/third_party/mlir-air @@ -0,0 +1 @@ +Subproject commit 39ee8fc6c9fe45057e2ad86dc3dd078e00e8d132 diff --git a/tools/aster-shlib/CMakeLists.txt b/tools/aster-shlib/CMakeLists.txt index 1cd08a933..7ca64ef16 100644 --- a/tools/aster-shlib/CMakeLists.txt +++ b/tools/aster-shlib/CMakeLists.txt @@ -51,14 +51,9 @@ if(ASTER_ENABLE_MLIR_AIR) target_link_libraries(ASTER PRIVATE -Wl,--whole-archive $ -Wl,--no-whole-archive) endif() - target_link_libraries(ASTER PRIVATE - MLIRLinalgDialect - MLIRLinalgTransformOps - MLIRLinalgTransforms - MLIRTransformDialect - MLIRTransformDialectTransforms - MLIRSCFTransformOps - ) + # Normal link for CMake transitive dep resolution (--whole-archive + # bypasses CMake's LINK_LIBS propagation). + target_link_libraries(ASTER PRIVATE MlirAirLib) endif() # Place ASTER.dylib/.so alongside the Python extensions in the build tree so # that @loader_path RPATH on the extensions resolves at build time (stubgen