From a7f454a5bb6fd3d2b4c85f7196ec25552d319cfc Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:11 +0000 Subject: [PATCH 01/16] Add AIR-to-AMDGCN lowering passes for mlir-air contrib MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Full AIR pipeline from tensor linalg to AMDGCN assembly. Pipeline: transform (tile/pad/promote/bufferize) → air-par-to-herd → one-shot-bufferize → air-par-to-launch → air-copy-to-dma → air-dma-to-channel → air-to-amdgcn → convert-memspace-to-amdgcn → convert-linalg-to-amdgcn → preload + inline → assembly. Design: - air.herd = wavefront (thread_id/64), air.launch = workgroup (block_id) - Base ptr stays sgpr, tile offsets passed separately (kittens pattern) - Per-wavefront LDS: alloc numWavefronts * size, offset by wave_id * size - scf.parallel from channel hoisting inlined with wavefront IDs - Global memref.copy/fill on non-LDS buffers eliminated Assembly generates correctly. GPU E2E test has numerical error: root cause identified as LDS cache key mismatch — air-to-amdgcn clones herd body ops (creating new SSA values), but channel gets reference original (pre-clone) alloc values. The ldsCache maps original allocs but matmul uses cloned allocs, causing separate LDS allocations that don't share data with the channel copies. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- CMakeLists.txt | 13 + contrib/mlir-air/CMakeLists.txt | 13 + contrib/mlir-air/lib/AirToAMDGCN.cpp | 324 +++++++++++++++ .../lib/ConvertAirChannelToAMDGCN.cpp | 314 +++++++++++++++ .../mlir-air/lib/ConvertLinalgToAMDGCN.cpp | 373 +++++++++++++++++- .../mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp | 177 +++++++++ contrib/mlir-air/lib/Init.cpp | 72 ++++ .../mlir-air/test/air-to-amdgcn-matmul.mlir | 220 +++++++++++ .../test/integration/test_air_matmul_e2e.py | 137 +++++++ contrib/mlir-air/tools/mlir-air-opt.cpp | 5 +- 10 files changed, 1626 insertions(+), 22 deletions(-) create mode 100644 contrib/mlir-air/lib/AirToAMDGCN.cpp create mode 100644 contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp create mode 100644 contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp create mode 100644 contrib/mlir-air/test/air-to-amdgcn-matmul.mlir create mode 100644 contrib/mlir-air/test/integration/test_air_matmul_e2e.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 14debaad1..ba1937b54 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -144,5 +144,18 @@ add_subdirectory(test) option(ASTER_ENABLE_MLIR_AIR "Build mlir-air contrib (linalg->AMDGCN)" OFF) if(ASTER_ENABLE_MLIR_AIR) message(STATUS "mlir-air enabled") + # Upstream Xilinx/mlir-air: AIR dialect + conversion passes (no AIE, no GPU backend). + set(AIR_ENABLE_AIE OFF CACHE BOOL "" FORCE) + set(AIR_ENABLE_GPU OFF CACHE BOOL "" FORCE) + # Stub targets expected by upstream mlir-air CMakeLists. + add_custom_target(air-headers) + # Make upstream AIR .td files and generated headers visible to mlir-tblgen + # and C++ compilation. Must come before add_subdirectory so the directory + # property is set when mlir_tablegen() calls get_directory_property(). + include_directories(${CMAKE_SOURCE_DIR}/third_party/mlir-air/mlir/include) + include_directories(${CMAKE_BINARY_DIR}/third_party/mlir-air/mlir/include) + # Only build AIR lib+include, not its test suite (check-all not available here). + add_subdirectory(third_party/mlir-air/mlir/include) + add_subdirectory(third_party/mlir-air/mlir/lib) add_subdirectory(contrib/mlir-air) endif() diff --git a/contrib/mlir-air/CMakeLists.txt b/contrib/mlir-air/CMakeLists.txt index b65fca9d4..33a30c8f6 100644 --- a/contrib/mlir-air/CMakeLists.txt +++ b/contrib/mlir-air/CMakeLists.txt @@ -1,9 +1,21 @@ add_mlir_library(MlirAirLib + lib/AirToAMDGCN.cpp + lib/ConvertAirChannelToAMDGCN.cpp lib/ConvertLinalgToAMDGCN.cpp + lib/ConvertMemSpaceToAMDGCN.cpp lib/Init.cpp lib/Pipelines.cpp LINK_LIBS PUBLIC + AIRConversionPasses + AIRDialect + MLIRBufferizationDialect + MLIRBufferizationTransformOps + MLIRBufferizationTransforms + MLIRTensorInferTypeOpInterfaceImpl + MLIRTensorTransforms + AIRTransformOps + AIRTransformPasses ASTERInit AsterTransforms AsterCodeGen @@ -13,6 +25,7 @@ add_mlir_library(MlirAirLib MLIRAffineDialect MLIRAffineTransforms MLIRFuncDialect + MLIRGPUDialect MLIRLinalgDialect MLIRLinalgTransformOps MLIRLinalgTransforms diff --git a/contrib/mlir-air/lib/AirToAMDGCN.cpp b/contrib/mlir-air/lib/AirToAMDGCN.cpp new file mode 100644 index 000000000..cd135716d --- /dev/null +++ b/contrib/mlir-air/lib/AirToAMDGCN.cpp @@ -0,0 +1,324 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- AirToAMDGCN.cpp - Lower AIR hierarchy ops to AMDGCN IR ------------===// +// +// Lowers air.launch, air.segment, air.herd, air.execute, and air.wait_all +// to flat AMDGCN-compatible IR: +// +// air.launch IDs -> gpu.block_id (workgroup IDs) +// air.herd tile IDs -> gpu.thread_id / 64 (wavefront index within workgroup) +// air.segment -> inline body (transparent wrapper) +// air.execute -> inline body (strip async) +// air.wait_all -> erase +// +// air.channel.put/get are preserved for the convert-air-channel-to-amdgcn pass. +//===----------------------------------------------------------------------===// + +#include "air/Dialect/AIR/AIRDialect.h" +#include "aster/Interfaces/ModuleOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Map gpu::Dimension enum from an integer index. +static gpu::Dimension dimFromIndex(unsigned i) { + switch (i) { + case 0: + return gpu::Dimension::x; + case 1: + return gpu::Dimension::y; + case 2: + return gpu::Dimension::z; + default: + llvm_unreachable("invalid dimension index"); + } +} + +/// Clone all ops from `src` into `builder`'s insertion point, applying +/// `mapping`. Ops whose type matches any of `SkipOps...` are skipped. +template +static void cloneBodyOps(OpBuilder &builder, Block &src, IRMapping &mapping) { + for (auto &op : src.getOperations()) { + if ((isa(op) || ...)) + continue; + builder.clone(op, mapping); + } +} + +/// Strip async_dependencies from a channel op and drop its async_token result. +static void stripAsyncFromChannelOp(Operation *op) { + if (auto asyncOp = dyn_cast(op)) { + // Remove all async dependency operands. + while (asyncOp.getAsyncDependencies().size() > 0) + asyncOp.eraseAsyncDependency(0); + // If the op has an async_token result, replace uses and drop it. + if (auto token = asyncOp.getAsyncToken()) { + token.replaceAllUsesWith(Value()); + // We can't actually remove a result from an existing op, but since all + // token users will be erased (wait_all, other async deps), replacing + // with null is sufficient — dead token uses are cleaned up below. + } + } +} + +// --------------------------------------------------------------------------- +// Pass +// --------------------------------------------------------------------------- + +struct AirToAMDGCN + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AirToAMDGCN) + StringRef getArgument() const override { return "air-to-amdgcn"; } + StringRef getDescription() const override { + return "Lower AIR hierarchy ops to AMDGCN-compatible IR"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + Operation *moduleOp = getOperation(); + OpBuilder builder(moduleOp->getContext()); + + // ----------------------------------------------------------------------- + // Phase 1: Strip async tokens. + // ----------------------------------------------------------------------- + + // Strip async from channel ops (they survive this pass). + moduleOp->walk([&](xilinx::air::ChannelPutOp op) { + stripAsyncFromChannelOp(op); + }); + moduleOp->walk([&](xilinx::air::ChannelGetOp op) { + stripAsyncFromChannelOp(op); + }); + + // Inline air.execute: splice body, replace results, erase. + SmallVector executes; + moduleOp->walk([&](xilinx::air::ExecuteOp op) { executes.push_back(op); }); + for (auto execOp : executes) { + Block &body = execOp.getBody(); + auto terminator = + cast(body.getTerminator()); + + // Replace non-token results (results 1..N) with terminator operands. + for (unsigned i = 0; i < terminator.getNumOperands(); ++i) + execOp.getResult(i + 1).replaceAllUsesWith(terminator.getOperand(i)); + + // Replace async token result (result 0) — users will be erased. + execOp.getResult(0).replaceAllUsesWith(Value()); + + // Splice body ops (except terminator) before the execute op. + auto &parentOps = execOp->getBlock()->getOperations(); + auto &bodyOps = body.getOperations(); + auto beforeTerminator = terminator->getIterator(); + parentOps.splice(execOp->getIterator(), bodyOps, bodyOps.begin(), + beforeTerminator); + + terminator->erase(); + execOp->erase(); + } + + // Erase air.wait_all ops. + SmallVector waitAlls; + moduleOp->walk( + [&](xilinx::air::WaitAllOp op) { waitAlls.push_back(op); }); + for (auto op : waitAlls) { + if (auto token = op.getAsyncToken()) + token.replaceAllUsesWith(Value()); + op->erase(); + } + + // ----------------------------------------------------------------------- + // Phase 2: Inline hierarchy ops (inside-out). + // ----------------------------------------------------------------------- + + // --- air.herd -> wavefront index (gpu.thread_id / wavefront_size) --- + // Each herd tile is a wavefront (64 threads cooperating collectively). + // Herd tile ID = wavefront index within the workgroup. + SmallVector herds; + moduleOp->walk([&](xilinx::air::HerdOp op) { herds.push_back(op); }); + for (auto herd : herds) { + builder.setInsertionPoint(herd); + Block &body = herd.getBody().front(); + unsigned numDims = herd.getNumDims(); + IRMapping mapping; + + // Map tile IDs to wavefront index = gpu.thread_id / 64. + auto ids = herd.getIds(); + Value wavefrontSize = arith::ConstantIndexOp::create( + builder, herd.getLoc(), 64); + for (unsigned i = 0; i < numDims; ++i) { + Value threadId = gpu::ThreadIdOp::create(builder, herd.getLoc(), + dimFromIndex(i)); + Value wavefrontId = arith::DivUIOp::create(builder, herd.getLoc(), + threadId, wavefrontSize); + mapping.map(ids[i], wavefrontId); + } + + // Map tile sizes to size operands. + auto sizeArgs = herd.getSize(); + auto sizeOperands = herd.getSizeOperands(); + for (unsigned i = 0; i < numDims; ++i) + mapping.map(sizeArgs[i], sizeOperands[i]); + + // Map kernel arguments to kernel operands. + auto kernelArgs = herd.getKernelArguments(); + auto kernelOperands = herd.getKernelOperands(); + for (unsigned i = 0; i < kernelArgs.size(); ++i) + mapping.map(kernelArgs[i], kernelOperands[i]); + + // Clone body. + cloneBodyOps(builder, body, mapping); + + // Replace async token if present. + if (auto token = herd.getAsyncToken()) + token.replaceAllUsesWith(Value()); + herd->erase(); + } + + // --- air.segment -> inline --- + SmallVector segments; + moduleOp->walk( + [&](xilinx::air::SegmentOp op) { segments.push_back(op); }); + for (auto segment : segments) { + builder.setInsertionPoint(segment); + Block &body = segment.getBody().front(); + unsigned numDims = segment.getNumDims(); + IRMapping mapping; + + // Map segment IDs (if any) to gpu.block_id ops. + auto ids = segment.getIds(); + auto sizeArgs = segment.getSize(); + auto sizeOperands = segment.getSizeOperands(); + for (unsigned i = 0; i < numDims; ++i) { + Value blockId = gpu::BlockIdOp::create(builder, segment.getLoc(), + dimFromIndex(i)); + mapping.map(ids[i], blockId); + mapping.map(sizeArgs[i], sizeOperands[i]); + } + + // Map kernel arguments to kernel operands. + auto kernelArgs = segment.getKernelArguments(); + auto kernelOperands = segment.getKernelOperands(); + for (unsigned i = 0; i < kernelArgs.size(); ++i) + mapping.map(kernelArgs[i], kernelOperands[i]); + + cloneBodyOps(builder, body, mapping); + + if (auto token = segment.getAsyncToken()) + token.replaceAllUsesWith(Value()); + segment->erase(); + } + + // --- air.launch -> gpu.block_id --- + SmallVector launches; + moduleOp->walk( + [&](xilinx::air::LaunchOp op) { launches.push_back(op); }); + for (auto launch : launches) { + builder.setInsertionPoint(launch); + Block &body = launch.getBody().front(); + unsigned numDims = launch.getNumDims(); + IRMapping mapping; + + // Map launch IDs to gpu.block_id ops. + auto ids = launch.getIds(); + auto sizeArgs = launch.getSize(); + auto sizeOperands = launch.getSizeOperands(); + for (unsigned i = 0; i < numDims; ++i) { + Value blockId = gpu::BlockIdOp::create(builder, launch.getLoc(), + dimFromIndex(i)); + mapping.map(ids[i], blockId); + mapping.map(sizeArgs[i], sizeOperands[i]); + } + + // Map kernel arguments to kernel operands. + auto kernelArgs = launch.getKernelArguments(); + auto kernelOperands = launch.getKernelOperands(); + for (unsigned i = 0; i < kernelArgs.size(); ++i) + mapping.map(kernelArgs[i], kernelOperands[i]); + + cloneBodyOps(builder, body, mapping); + + if (auto token = launch.getAsyncToken()) + token.replaceAllUsesWith(Value()); + launch->erase(); + } + + // --- scf.parallel -> wavefront ID inlining --- + // After air-dma-to-channel, hoisted channel.put/get ops are wrapped in + // scf.parallel loops that iterate over the herd tile space. These must + // be inlined: replace induction variables with gpu.thread_id / 64 + // (wavefront index), then splice body ops in place of the parallel. + SmallVector parallels; + moduleOp->walk( + [&](scf::ParallelOp op) { parallels.push_back(op); }); + for (auto parallel : parallels) { + builder.setInsertionPoint(parallel); + Location loc = parallel.getLoc(); + Block &body = parallel.getRegion().front(); + + // Compute wavefront ID for each induction variable. + // All dimensions are derived from thread_id_x (1D wavefront layout). + Value wavefrontSize = + arith::ConstantIndexOp::create(builder, loc, 64); + Value threadIdX = + gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); + Value wavefrontId = + arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); + + auto ivs = parallel.getInductionVars(); + if (ivs.size() == 1) { + // 1D parallel: iv = wavefrontId directly. + ivs[0].replaceAllUsesWith(wavefrontId); + } else if (ivs.size() == 2) { + // 2D parallel: decompose wavefrontId into (x, y). + auto ub = parallel.getUpperBound(); + Value ubX = ub[0]; + Value ivX = arith::RemUIOp::create(builder, loc, wavefrontId, ubX); + Value ivY = arith::DivUIOp::create(builder, loc, wavefrontId, ubX); + ivs[0].replaceAllUsesWith(ivX); + ivs[1].replaceAllUsesWith(ivY); + } + + // Replace init values / results. scf.parallel with reduce produces + // results from scf.reduce. Replace with init values (no actual + // reduction needed — each wavefront runs independently). + for (unsigned i = 0; i < parallel.getNumResults(); ++i) + parallel.getResult(i).replaceAllUsesWith(parallel.getInitVals()[i]); + + // Splice body ops (skip the yield/reduce terminator) before parallel. + auto &parentOps = parallel->getBlock()->getOperations(); + auto &bodyOps = body.getOperations(); + // Find the last non-terminator op. + auto termIt = body.getTerminator()->getIterator(); + parentOps.splice(parallel->getIterator(), bodyOps, bodyOps.begin(), + termIt); + + parallel->erase(); + } + } +}; + +} // namespace + +namespace mlir::aster::mlir_air { +std::unique_ptr createAirToAMDGCN() { + return std::make_unique(); +} +} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp new file mode 100644 index 000000000..6b4af7b09 --- /dev/null +++ b/contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp @@ -0,0 +1,314 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- ConvertAirChannelToAMDGCN.cpp - air.channel -> library calls -------===// +// +// Lowers air.channel.put and air.channel.get to AMDGCN library function calls. +// +// For each air.channel.put/get: +// - The memref operand is decomposed the same way as in ConvertLinalgToAMDGCN: +// * Global memref -> (!sgpr<[?+2]>, byte_stride: index) +// * Promoted buffer (memref.view of memref.alloca with memory space) +// -> LDS byte offset (index) +// - A func.call to a named library function is emitted: +// copy__(src_args..., dst_args...) +// where the shape comes from the channel's memref type. +// +// The channel.put sends data INTO the channel (producer side). +// The channel.get receives data FROM the channel (consumer side). +// Together they represent a point-to-point copy: the put's src is the copy +// source, the get's dst is the copy destination. +// +// This pass matches put/get pairs by channel name and emits a single copy call +// at the get site (the consumer), erasing both ops and the channel declaration. +//===----------------------------------------------------------------------===// + +#include "air/Dialect/AIR/AIRDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNTypes.h" +#include "aster/Dialect/LSIR/IR/LSIRDialect.h" +#include "aster/Dialect/LSIR/IR/LSIROps.h" +#include "aster/Interfaces/ModuleOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/Dialect/Ptr/IR/PtrTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::aster; +using namespace mlir::aster::amdgcn; + +namespace { + +// --------------------------------------------------------------------------- +// Utilities (shared with ConvertLinalgToAMDGCN.cpp) +// --------------------------------------------------------------------------- + +static std::string buildCopyFuncName(MemRefType ty) { + std::string name; + llvm::raw_string_ostream os(name); + os << "copy"; + Type elt = ty.getElementType(); + if (elt.isF16()) + os << "_f16"; + else if (elt.isF32()) + os << "_f32"; + else if (elt.isBF16()) + os << "_bf16"; + else + os << "_unknown"; + auto shape = ty.getShape(); + for (size_t i = 0; i < shape.size(); ++i) + os << (i == 0 ? "_" : "x") << shape[i]; + return name; +} + +static void ensureDecl(OpBuilder &builder, Block &block, Location loc, + StringRef name, FunctionType funcTy) { + for (auto &op : block) + if (auto fn = dyn_cast(&op)) + if (fn.getName() == name) + return; + auto savedIP = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(&block); + auto decl = func::FuncOp::create(builder, loc, name, funcTy); + decl.setPrivate(); + builder.restoreInsertionPoint(savedIP); +} + +static bool isPromotedBuffer(Value v) { + if (auto viewOp = v.getDefiningOp()) { + if (auto allocaOp = viewOp.getSource().getDefiningOp()) + return allocaOp.getMemref().getType().getMemorySpace() != nullptr; + } + if (auto allocOp = v.getDefiningOp()) + return allocOp.getMemref().getType().getMemorySpace() != nullptr; + return false; +} + +static Value emitLDSOffset(OpBuilder &builder, Location loc, Value memrefVal, + DenseMap &ldsCache) { + auto it = ldsCache.find(memrefVal); + if (it != ldsCache.end()) + return it->second; + + int64_t sizeBytes = 0; + Value byteShift; + if (auto viewOp = memrefVal.getDefiningOp()) { + auto allocaOp = viewOp.getSource().getDefiningOp(); + sizeBytes = allocaOp.getMemref().getType().getNumElements(); + byteShift = viewOp.getByteShift(); + } else if (auto allocOp = memrefVal.getDefiningOp()) { + auto mrTy = allocOp.getMemref().getType(); + unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); + sizeBytes = mrTy.getNumElements() * eltBits / 8; + } + auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), + sizeBytes, /*alignment=*/16, + /*offset=*/IntegerAttr{}); + auto ldsOffset = + GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); + Value result = ldsOffset.getResult(); + if (byteShift) + result = builder.create(loc, result, byteShift); + ldsCache[memrefVal] = result; + return result; +} + +static std::pair +decomposeGlobalMemref(OpBuilder &builder, Location loc, Value memref) { + auto mrTy = cast(memref.getType()); + unsigned eltBytes = mrTy.getElementType().getIntOrFloatBitWidth() / 8; + auto metadata = + memref::ExtractStridedMetadataOp::create(builder, loc, memref); + Value baseBuffer = metadata.getBaseBuffer(); + Value offset = metadata.getOffset(); + Value leadingStride = metadata.getStrides()[0]; + Value eltSize = arith::ConstantIndexOp::create(builder, loc, eltBytes); + Value byteStride = + arith::MulIOp::create(builder, loc, leadingStride, eltSize); + Value byteOffset = arith::MulIOp::create(builder, loc, offset, eltSize); + auto addrSpace = cast(mrTy.getMemorySpace()); + auto ptrTy = ptr::PtrType::get(builder.getContext(), addrSpace); + Value ptrVal = ptr::ToPtrOp::create(builder, loc, ptrTy, baseBuffer); + auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), + /*size=*/2, /*alignment=*/2); + Value rawPtr = lsir::ToRegOp::create(builder, loc, sx2Ty, ptrVal); + Value ptrFromReg = lsir::FromRegOp::create(builder, loc, ptrTy, rawPtr); + Value adjusted = + ptr::PtrAddOp::create(builder, loc, ptrTy, ptrFromReg, byteOffset); + Value result = lsir::ToRegOp::create(builder, loc, sx2Ty, adjusted); + return {result, byteStride}; +} + +/// Emit decomposed args for a memref operand (either LDS offset or global ptr). +static void emitDecomposedArgs(OpBuilder &builder, Location loc, Value memref, + SmallVectorImpl &callArgs, + SmallVectorImpl &argTypes, + DenseMap &ldsCache) { + auto indexTy = builder.getIndexType(); + auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), + /*size=*/2, /*alignment=*/2); + if (isPromotedBuffer(memref)) { + callArgs.push_back(emitLDSOffset(builder, loc, memref, ldsCache)); + argTypes.push_back(indexTy); + } else { + auto [ptrVal, byteStride] = decomposeGlobalMemref(builder, loc, memref); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + } +} + +// --------------------------------------------------------------------------- +// Pass +// --------------------------------------------------------------------------- + +struct ConvertAirChannelToAMDGCN + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertAirChannelToAMDGCN) + StringRef getArgument() const override { + return "convert-air-channel-to-amdgcn"; + } + StringRef getDescription() const override { + return "Convert air.channel.put/get pairs to AMDGCN library calls"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + Operation *moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + Operation *declParent = moduleOp; + if (isa(moduleOp)) + moduleOp->walk([&](amdgcn::ModuleOp m) { declParent = m; }); + auto &declBlock = declParent->getRegion(0).front(); + + OpBuilder builder(ctx); + SmallVector toErase; + DenseMap ldsCache; + + // --------------------------------------------------------------- + // Path 1: Convert air.dma_memcpy_nd directly (no channels). + // --------------------------------------------------------------- + moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma) { + Value dst = dma.getDstMemref(); + Value src = dma.getSrcMemref(); + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) + return; + + builder.setInsertionPoint(dma); + Location loc = dma.getLoc(); + + std::string name = buildCopyFuncName(dstTy); + + SmallVector callArgs; + SmallVector argTypes; + + // src: if the DMA has src offsets/sizes/strides, create a subview. + Value srcForDecompose = src; + auto srcOffsets = dma.getSrcOffsets(); + auto srcSizes = dma.getSrcSizes(); + auto srcStrides = dma.getSrcStrides(); + if (!srcOffsets.empty()) { + SmallVector offsets, sizes, strides; + for (auto v : srcOffsets) + offsets.push_back(v); + for (auto v : srcSizes) + sizes.push_back(v); + for (auto v : srcStrides) + strides.push_back(v); + srcForDecompose = memref::SubViewOp::create( + builder, loc, src, offsets, sizes, strides); + } + emitDecomposedArgs(builder, loc, srcForDecompose, callArgs, argTypes, + ldsCache); + // dst args. + emitDecomposedArgs(builder, loc, dst, callArgs, argTypes, ldsCache); + + auto funcTy = builder.getFunctionType(argTypes, {}); + ensureDecl(builder, declBlock, loc, name, funcTy); + func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); + + toErase.push_back(dma); + }); + + // --------------------------------------------------------------- + // Path 2: Convert air.channel.put/get pairs (if channels present). + // --------------------------------------------------------------- + DenseMap> putsByChannel; + moduleOp->walk([&](xilinx::air::ChannelPutOp put) { + putsByChannel[put.getChanName()].push_back(put); + }); + + moduleOp->walk([&](xilinx::air::ChannelGetOp get) { + StringRef chanName = get.getChanName(); + auto it = putsByChannel.find(chanName); + if (it == putsByChannel.end() || it->second.empty()) + return; + + xilinx::air::ChannelPutOp put = it->second.front(); + + Value src = put.getSrc(); + Value dst = get.getDst(); + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) + return; + + builder.setInsertionPoint(get); + Location loc = get.getLoc(); + + std::string name = buildCopyFuncName(dstTy); + + SmallVector callArgs; + SmallVector argTypes; + + // src args. + emitDecomposedArgs(builder, loc, src, callArgs, argTypes, ldsCache); + // dst args. + emitDecomposedArgs(builder, loc, dst, callArgs, argTypes, ldsCache); + + auto funcTy = builder.getFunctionType(argTypes, {}); + ensureDecl(builder, declBlock, loc, name, funcTy); + func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); + + toErase.push_back(get); + toErase.push_back(put); + }); + + for (auto *op : toErase) + op->erase(); + + // Clean up channel declarations that are now unused. + SmallVector deadChannels; + moduleOp->walk([&](xilinx::air::ChannelOp chan) { + deadChannels.push_back(chan); + }); + for (auto *op : deadChannels) + op->erase(); + } +}; + +} // namespace + +namespace mlir::aster::mlir_air { +std::unique_ptr createConvertAirChannelToAMDGCN() { + return std::make_unique(); +} +} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp index 62b204e4b..4030c2b4a 100644 --- a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp +++ b/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp @@ -6,9 +6,14 @@ //===- ConvertLinalgToAMDGCN.cpp - linalg ops -> AMDGCN library calls -----===// +#include "air/Dialect/AIR/AIRDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNEnums.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNTypes.h" +#include "aster/Dialect/LSIR/IR/LSIRDialect.h" #include "aster/Dialect/LSIR/IR/LSIROps.h" #include "aster/Interfaces/ModuleOpInterface.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -16,6 +21,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" #include "mlir/Dialect/Ptr/IR/PtrOps.h" #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/Builders.h" @@ -60,38 +66,69 @@ static void ensureDecl(OpBuilder &builder, Block &block, Location loc, builder.restoreInsertionPoint(savedIP); } -/// Check if a memref value comes from promote to shared memory -/// (memref.view(memref.alloca) with workgroup address space). +/// Check if a memref value comes from promote to shared memory. +/// Matches two patterns: +/// 1. memref.view(memref.alloca) with non-default memory space (from promote) +/// 2. memref.alloc with non-default memory space (from bufferize_to_allocation) static bool isPromotedBuffer(Value v) { - auto viewOp = v.getDefiningOp(); - if (!viewOp) - return false; - auto allocaOp = viewOp.getSource().getDefiningOp(); - if (!allocaOp) + // Pattern 1: memref.view(memref.alloca) — from transform.structured.promote. + if (auto viewOp = v.getDefiningOp()) { + if (auto allocaOp = viewOp.getSource().getDefiningOp()) { + return allocaOp.getMemref().getType().getMemorySpace() != nullptr; + } + } + // Pattern 2: memref.alloc with L1/local memory space — + // from bufferize_to_allocation. + if (auto allocOp = v.getDefiningOp()) { + auto memSpace = allocOp.getMemref().getType().getMemorySpace(); + if (!memSpace) + return false; + // Integer memory space 2 = L1 (AIR convention). + if (auto intAttr = dyn_cast(memSpace)) + return intAttr.getInt() == 2; + // #amdgcn.addr_space = LDS. + if (auto addrSpace = dyn_cast(memSpace)) + return addrSpace.getSpace() == amdgcn::AddressSpaceKind::Local; return false; - auto memSpace = allocaOp.getMemref().getType().getMemorySpace(); - return memSpace != nullptr; + } + return false; } /// Emit amdgcn.alloc_lds + get_lds_offset for a promoted buffer. -/// Uses a cache so the same promoted buffer (same memref.view(memref.alloca)) -/// gets the same LDS region for both write (copy) and read (matmul). +/// Uses a cache so the same promoted buffer gets the same LDS region +/// for both write (copy) and read (matmul). static Value emitLDSOffset(OpBuilder &builder, Location loc, Value memrefVal, DenseMap &ldsCache) { auto it = ldsCache.find(memrefVal); if (it != ldsCache.end()) return it->second; - auto viewOp = memrefVal.getDefiningOp(); - auto allocaOp = viewOp.getSource().getDefiningOp(); - int64_t sizeBytes = allocaOp.getMemref().getType().getNumElements(); + int64_t sizeBytes = 0; + Value byteShift; + + // Pattern 1: memref.view(memref.alloca) — from promote. + if (auto viewOp = memrefVal.getDefiningOp()) { + auto allocaOp = viewOp.getSource().getDefiningOp(); + sizeBytes = allocaOp.getMemref().getType().getNumElements(); + byteShift = viewOp.getByteShift(); + } + // Pattern 2: memref.alloc — from bufferize_to_allocation. + else if (auto allocOp = memrefVal.getDefiningOp()) { + auto mrTy = allocOp.getMemref().getType(); + unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); + sizeBytes = mrTy.getNumElements() * eltBits / 8; + } + auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), sizeBytes, /*alignment=*/16, /*offset=*/IntegerAttr{}); auto ldsOffset = GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); - Value result = builder.create(loc, ldsOffset.getResult(), - viewOp.getByteShift()); + + Value result = ldsOffset.getResult(); + if (byteShift) + result = builder.create(loc, result, byteShift); + ldsCache[memrefVal] = result; return result; } @@ -145,6 +182,14 @@ static void replaceWithCall(OpBuilder &builder, Block &declBlock, Operation *op, StringRef namePrefix, SmallVector &toErase, DenseMap &ldsCache) { + // Only convert ops that involve at least one promoted (LDS) buffer. + bool hasPromotedOperand = false; + for (Value operand : op->getOperands()) + if (isPromotedBuffer(operand)) + hasPromotedOperand = true; + if (!hasPromotedOperand) + return; + auto indexTy = builder.getIndexType(); SmallVector callArgs; SmallVector argTypes; @@ -170,13 +215,42 @@ static void replaceWithCall(OpBuilder &builder, Block &declBlock, Operation *op, callArgs.push_back(emitLDSOffset(builder, loc, operand, ldsCache)); argTypes.push_back(indexTy); } else { - // Decompose global memref into (!sx2, byte_stride) + // Global memref: if this is a subview, decompose the BASE memref + // (clean sgpr) and pass tile offsets separately. This avoids + // baking wavefront-varying offsets into the pointer. + Value baseMemref = operand; + SmallVector tileOffsets; + if (auto svOp = operand.getDefiningOp()) { + baseMemref = svOp.getSource(); + for (auto off : svOp.getMixedOffsets()) { + if (auto val = dyn_cast(off)) + tileOffsets.push_back(val); + else + tileOffsets.push_back(arith::ConstantIndexOp::create( + builder, loc, + cast(off.get()).getInt())); + } + } auto [ptrVal, byteStride] = - decomposeGlobalMemref(builder, loc, operand); + decomposeGlobalMemref(builder, loc, baseMemref); callArgs.push_back(ptrVal); argTypes.push_back(sx2Ty); callArgs.push_back(byteStride); argTypes.push_back(indexTy); + // Pass tile offsets (or zeros if no subview). + if (tileOffsets.empty()) { + auto rank = mrTy.getRank(); + for (int64_t i = 0; i < rank; ++i) { + callArgs.push_back( + arith::ConstantIndexOp::create(builder, loc, 0)); + argTypes.push_back(indexTy); + } + } else { + for (auto off : tileOffsets) { + callArgs.push_back(off); + argTypes.push_back(indexTy); + } + } } } else { callArgs.push_back(operand); @@ -198,6 +272,11 @@ struct ConvertLinalgToAMDGCN StringRef getDescription() const override { return "Convert tiled linalg ops to AMDGCN library calls"; } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } void runOnOperation() override { Operation *moduleOp = getOperation(); @@ -211,17 +290,130 @@ struct ConvertLinalgToAMDGCN SmallVector toErase; DenseMap ldsCache; + // Pre-allocate LDS at function entry for channel get destinations FIRST, + // before processing linalg ops. This ensures the matmul (which shares the + // same memref.alloc as the channel.get) hits the cache and uses the + // function-entry LDS offset instead of creating one inside the loop. + DenseMap> putsByChannel; + moduleOp->walk([&](xilinx::air::ChannelPutOp put) { + putsByChannel[put.getChanName()].push_back(put); + }); + // Determine number of wavefronts from channel array dimensions. + int64_t numWavefronts = 1; + moduleOp->walk([&](xilinx::air::ChannelOp chan) { + auto sizes = chan.getSize(); + int64_t total = 1; + for (auto s : sizes) + total *= cast(s).getInt(); + if (total > numWavefronts) + numWavefronts = total; + }); + + // Pre-allocate LDS for channel get destinations (Global→LDS direction). + // Each wavefront gets its own LDS region: alloc numWavefronts * size, + // offset by wavefront_id * size. + moduleOp->walk([&](xilinx::air::ChannelGetOp get) { + Value dst = get.getDst(); + if (!isPromotedBuffer(dst)) + return; + auto funcOp = get->getParentOfType(); + if (!funcOp || funcOp.empty()) + return; + auto savedIP = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(&funcOp.front()); + Location loc = funcOp.getLoc(); + + // Get per-wavefront tile size. + int64_t tileSizeBytes = 0; + if (auto allocOp = dst.getDefiningOp()) { + auto mrTy = allocOp.getMemref().getType(); + unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); + tileSizeBytes = mrTy.getNumElements() * eltBits / 8; + } + + // Allocate numWavefronts * tileSizeBytes. + auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), + numWavefronts * tileSizeBytes, + /*alignment=*/16, + /*offset=*/IntegerAttr{}); + auto ldsBaseOffset = + GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); + + // Per-wavefront offset: base + wavefront_id * tileSizeBytes. + Value wavefrontSize = + arith::ConstantIndexOp::create(builder, loc, 64); + Value threadIdX = + gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); + Value wavefrontId = + arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); + Value tileSizeVal = + arith::ConstantIndexOp::create(builder, loc, tileSizeBytes); + Value wavefrontOffset = + arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); + Value adjustedOffset = builder.create( + loc, ldsBaseOffset.getResult(), wavefrontOffset); + + ldsCache[dst] = adjustedOffset; + builder.restoreInsertionPoint(savedIP); + }); + // Pre-allocate LDS for channel put sources (LDS→Global direction). + moduleOp->walk([&](xilinx::air::ChannelPutOp put) { + Value src = put.getSrc(); + if (!isPromotedBuffer(src)) + return; + if (ldsCache.count(src)) + return; // Already allocated (shared with channel get). + auto funcOp = put->getParentOfType(); + if (!funcOp || funcOp.empty()) + return; + auto savedIP = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(&funcOp.front()); + Location loc = funcOp.getLoc(); + + int64_t tileSizeBytes = 0; + if (auto allocOp = src.getDefiningOp()) { + auto mrTy = allocOp.getMemref().getType(); + unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); + tileSizeBytes = mrTy.getNumElements() * eltBits / 8; + } + + auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), + numWavefronts * tileSizeBytes, + /*alignment=*/16, + /*offset=*/IntegerAttr{}); + auto ldsBaseOffset = + GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); + + Value wavefrontSize = + arith::ConstantIndexOp::create(builder, loc, 64); + Value threadIdX = + gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); + Value wavefrontId = + arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); + Value tileSizeVal = + arith::ConstantIndexOp::create(builder, loc, tileSizeBytes); + Value wavefrontOffset = + arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); + Value adjustedOffset = builder.create( + loc, ldsBaseOffset.getResult(), wavefrontOffset); + + ldsCache[src] = adjustedOffset; + builder.restoreInsertionPoint(savedIP); + }); + + // Now process linalg ops — they'll hit the ldsCache for shared allocs. moduleOp->walk([&](linalg::FillOp op) { replaceWithCall(builder, declBlock, op, "fill", toErase, ldsCache); }); moduleOp->walk([&](linalg::CopyOp op) { replaceWithCall(builder, declBlock, op, "copy", toErase, ldsCache); }); + moduleOp->walk([&](memref::CopyOp op) { + replaceWithCall(builder, declBlock, op, "copy", toErase, ldsCache); + }); moduleOp->walk([&](linalg::MatmulOp op) { replaceWithCall(builder, declBlock, op, "mfma_matmul", toErase, ldsCache); }); - // Also handle linalg.generic with matmul-like semantics (e.g., - // matmul_transpose_b expressed as generic with (m,n,k)->(m,k),(n,k),(m,n)). moduleOp->walk([&](linalg::GenericOp op) { if (op.getNumDpsInputs() == 2 && op.getNumDpsInits() == 1 && op.getNumReductionLoops() == 1) @@ -229,9 +421,150 @@ struct ConvertLinalgToAMDGCN ldsCache); }); + // Emit copy call at each put site (where global src operands dominate). + moduleOp->walk([&](xilinx::air::ChannelGetOp get) { + StringRef chanName = get.getChanName(); + auto it = putsByChannel.find(chanName); + if (it == putsByChannel.end() || it->second.empty()) + return; + xilinx::air::ChannelPutOp put = it->second.front(); + + Value dst = get.getDst(); + auto dstTy = dyn_cast(dst.getType()); + if (!dstTy) + return; + + Value src = put.getSrc(); + bool srcIsLDS = isPromotedBuffer(src); + bool dstIsLDS = isPromotedBuffer(dst); + + // Determine direction and emit at the appropriate site. + // Global→LDS: emit at put site (global src operands dominate). + // LDS→Global: emit at get site (global dst operands dominate). + if (srcIsLDS && !dstIsLDS) { + // LDS→Global (C write-back): emit at get site. + builder.setInsertionPoint(get); + } else { + // Global→LDS (A/B copy): emit at put site. + builder.setInsertionPoint(put); + } + Location loc = builder.getInsertionPoint()->getLoc(); + + // Use the L1 (smaller) memref type for the function name. + auto namingTy = srcIsLDS ? cast(src.getType()) : dstTy; + std::string name = buildFuncName("copy", namingTy); + + auto indexTy = builder.getIndexType(); + auto sx2Ty = amdgcn::SGPRType::get(ctx, Register(), + /*size=*/2, /*alignment=*/2); + SmallVector callArgs; + SmallVector argTypes; + + // Decompose src side. + if (srcIsLDS) { + // LDS src. + assert(ldsCache.count(src) && "LDS offset not pre-allocated for channel put src"); + callArgs.push_back(ldsCache[src]); + argTypes.push_back(indexTy); + } else { + // Global src: decompose BASE memref (not subview) to get a clean + // sgpr pointer. Pass the channel's tile offsets separately so the + // library function handles them (kittens pattern). + auto [ptrVal, byteStride] = + decomposeGlobalMemref(builder, loc, src); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + // Tile offsets from the channel put (element-level indices). + auto putOffsets = put.getSrcOffsets(); + if (putOffsets.size() >= 2) { + callArgs.push_back(putOffsets[0]); // row offset + argTypes.push_back(indexTy); + callArgs.push_back(putOffsets[1]); // col offset + argTypes.push_back(indexTy); + } else { + Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); + callArgs.push_back(c0); + argTypes.push_back(indexTy); + callArgs.push_back(c0); + argTypes.push_back(indexTy); + } + } + + // Decompose dst side. + if (dstIsLDS) { + // LDS dst. + assert(ldsCache.count(dst) && "LDS offset not pre-allocated for channel get dst"); + callArgs.push_back(ldsCache[dst]); + argTypes.push_back(indexTy); + } else { + // Global dst: decompose BASE memref, pass offsets separately. + auto [ptrVal, byteStride] = + decomposeGlobalMemref(builder, loc, dst); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + auto getOffsets = get.getDstOffsets(); + if (getOffsets.size() >= 2) { + callArgs.push_back(getOffsets[0]); + argTypes.push_back(indexTy); + callArgs.push_back(getOffsets[1]); + argTypes.push_back(indexTy); + } else { + Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); + callArgs.push_back(c0); + argTypes.push_back(indexTy); + callArgs.push_back(c0); + argTypes.push_back(indexTy); + } + } + + auto funcTy = builder.getFunctionType(argTypes, {}); + ensureDecl(builder, declBlock, loc, name, funcTy); + func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); + + toErase.push_back(put); + toErase.push_back(get); + }); + + // Clean up channel declarations. + moduleOp->walk([&](xilinx::air::ChannelOp chan) { + toErase.push_back(chan); + }); + for (auto *op : toErase) op->erase(); + // Erase linalg.fill on global buffers — the library handles zero-init. + SmallVector globalFills; + moduleOp->walk([&](linalg::FillOp fill) { + for (Value out : fill.getDpsInits()) { + if (auto mrTy = dyn_cast(out.getType())) + if (!isPromotedBuffer(out)) + globalFills.push_back(fill); + } + }); + for (auto fill : globalFills) + fill->erase(); + + // Eliminate global→global memref.copy by forwarding the destination. + // This handles the `memref.copy %alloc, %arg` from materialize_in_destination. + moduleOp->walk([&](memref::CopyOp copy) { + Value src = copy.getSource(); + Value dst = copy.getTarget(); + if (isPromotedBuffer(src) || isPromotedBuffer(dst)) + return; + // Both are global — replace all uses of src with dst and erase. + if (auto allocOp = src.getDefiningOp()) { + src.replaceAllUsesWith(dst); + copy->erase(); + if (allocOp->use_empty()) + allocOp->erase(); + } + }); + // DCE unused alloca/view/dealloc. SmallVector deadOps; moduleOp->walk([&](Operation *op) { diff --git a/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp new file mode 100644 index 000000000..b47497036 --- /dev/null +++ b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp @@ -0,0 +1,177 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- ConvertMemSpaceToAMDGCN.cpp - integer memspace -> amdgcn addr_space ===// +// +// Converts AIR-style integer memory spaces on memref types to AMDGCN address +// space attributes: +// +// null / 0 → #amdgcn.addr_space +// 2 (L1) → #amdgcn.addr_space +// +// This pass bridges the gap between the upstream AIR pipeline (which uses +// integer memory spaces) and aster's AMDGCN decomposition passes (which +// require #amdgcn.addr_space attributes for pointer generation). +// +// Run after air-to-amdgcn and before convert-air-channel-to-amdgcn. +//===----------------------------------------------------------------------===// + +#include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" +#include "aster/Dialect/AMDGCN/IR/AMDGCNEnums.h" +#include "aster/Interfaces/ModuleOpInterface.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::aster::amdgcn; + +namespace { + +/// Map an integer memory space to an AMDGCN AddressSpaceAttr. +/// Returns nullptr if no mapping is needed (e.g., already an AMDGCN attr). +static AddressSpaceAttr mapMemorySpace(Attribute memSpace, + MLIRContext *ctx) { + // null or integer 0 → global + if (!memSpace) { + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Global, + AccessKind::ReadWrite); + } + if (auto intAttr = dyn_cast(memSpace)) { + unsigned space = intAttr.getInt(); + switch (space) { + case 0: + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Global, + AccessKind::ReadWrite); + case 2: + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Local, + AccessKind::ReadWrite); + default: + // Unknown integer space — map to global as fallback. + return AddressSpaceAttr::get(ctx, AddressSpaceKind::Global, + AccessKind::ReadWrite); + } + } + // Already a non-integer attribute (e.g., #amdgcn.addr_space) — no change. + return {}; +} + +/// Convert a MemRefType's memory space if needed. +static MemRefType convertMemRefType(MemRefType ty, MLIRContext *ctx) { + auto newSpace = mapMemorySpace(ty.getMemorySpace(), ctx); + if (!newSpace) + return ty; + // Use the AffineMap overload since MemRefLayoutAttrInterface overload + // doesn't exist for some MLIR versions. + Attribute spaceAttr = newSpace; + return MemRefType::get(ty.getShape(), ty.getElementType(), + ty.getLayout().getAffineMap(), spaceAttr); +} + +struct ConvertMemSpaceToAMDGCN + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertMemSpaceToAMDGCN) + StringRef getArgument() const override { + return "convert-memspace-to-amdgcn"; + } + StringRef getDescription() const override { + return "Convert integer memory spaces to #amdgcn.addr_space attributes"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + Operation *moduleOp = getOperation(); + MLIRContext *ctx = &getContext(); + + // Collect all func.func ops to update. + SmallVector funcs; + moduleOp->walk([&](func::FuncOp f) { funcs.push_back(f); }); + + for (auto funcOp : funcs) { + // Update function signature. + auto funcTy = funcOp.getFunctionType(); + bool changed = false; + + SmallVector newInputs; + for (auto ty : funcTy.getInputs()) { + if (auto mrTy = dyn_cast(ty)) { + auto newTy = convertMemRefType(mrTy, ctx); + newInputs.push_back(newTy); + if (newTy != mrTy) + changed = true; + } else { + newInputs.push_back(ty); + } + } + + SmallVector newResults; + for (auto ty : funcTy.getResults()) { + if (auto mrTy = dyn_cast(ty)) { + auto newTy = convertMemRefType(mrTy, ctx); + newResults.push_back(newTy); + if (newTy != mrTy) + changed = true; + } else { + newResults.push_back(ty); + } + } + + if (changed) { + funcOp.setType(FunctionType::get(ctx, newInputs, newResults)); + // Update block argument types. + if (!funcOp.empty()) { + Block &entry = funcOp.front(); + for (unsigned i = 0; i < entry.getNumArguments(); ++i) { + if (auto mrTy = dyn_cast(entry.getArgument(i).getType())) { + auto newTy = convertMemRefType(mrTy, ctx); + if (newTy != mrTy) + entry.getArgument(i).setType(newTy); + } + } + } + } + + // Walk all ops inside the function and update memref types. + funcOp->walk([&](Operation *op) { + // Update operand types (handled transitively via result types). + // Update result types. + for (unsigned i = 0; i < op->getNumResults(); ++i) { + auto ty = op->getResult(i).getType(); + if (auto mrTy = dyn_cast(ty)) { + auto newTy = convertMemRefType(mrTy, ctx); + if (newTy != mrTy) + op->getResult(i).setType(newTy); + } + } + // Update block argument types in regions. + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto arg : block.getArguments()) { + if (auto mrTy = dyn_cast(arg.getType())) { + auto newTy = convertMemRefType(mrTy, ctx); + if (newTy != mrTy) + arg.setType(newTy); + } + } + } + } + }); + } + } +}; + +} // namespace + +namespace mlir::aster::mlir_air { +std::unique_ptr createConvertMemSpaceToAMDGCN() { + return std::make_unique(); +} +} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index b3d11c0fd..7a0226dc5 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -6,8 +6,39 @@ //===- Init.cpp - mlir-air dialect and pass registration ------------------===// +#include "air/Conversion/ConvertToAIRPass.h" +#include "air/Dialect/AIR/AIRDialect.h" +#include "air/Dialect/AIR/AIRTransformOps.h" +#include "air/Transform/AIRDmaToChannel.h" + +// Tablegen-generated per-pass registration for upstream AIR passes. +namespace air_conv_reg { +#define GEN_PASS_REGISTRATION_COPYTODMA +#define GEN_PASS_REGISTRATION_PARALLELTOHERD +#define GEN_PASS_REGISTRATION_PARALLELTOLAUNCH +#define GEN_PASS_REGISTRATION_AIRWRAPFUNCWITHPARALLELPASS +#include "air/Conversion/Passes.h.inc" +} // namespace air_conv_reg + +namespace air_xform_reg { +#define GEN_PASS_REGISTRATION_DMATOCHANNEL +#include "air/Transform/Passes.h.inc" +} // namespace air_xform_reg + #include "aster/Init.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" @@ -20,15 +51,33 @@ namespace mlir::aster::mlir_air { +std::unique_ptr createAirToAMDGCN(); +std::unique_ptr createConvertAirChannelToAMDGCN(); std::unique_ptr createConvertLinalgToAMDGCN(); +std::unique_ptr createConvertMemSpaceToAMDGCN(); void registerPipelines(); void registerAll(DialectRegistry ®istry) { + // AIR dialect. + registry.insert(); + // Dialects needed for linalg tiling + transform dialect. + registry.insert(); registry.insert(); registry.insert(); + // Bufferization interface models. + arith::registerBufferizableOpInterfaceExternalModels(registry); + bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + linalg::registerBufferizableOpInterfaceExternalModels(registry); + linalg::registerSubsetOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + // Transform dialect extensions. + bufferization::registerTransformDialectExtension(registry); linalg::registerTransformDialectExtension(registry); scf::registerTransformDialectExtension(registry); @@ -36,11 +85,33 @@ void registerAll(DialectRegistry ®istry) { linalg::registerTilingInterfaceExternalModels(registry); // Upstream passes. + bufferization::registerBufferizationPasses(); registerLinalgPasses(); memref::registerMemRefPasses(); transform::registerInterpreterPass(); + // AIR transform ops extension (air.transform.*). + xilinx::air::registerTransformDialectExtension(registry); + + // Upstream doesn't declare airDialect as a dependent of the transform + // extension — add it so par_to_herd can create air.herd ops. + registry.addExtension( + +[](MLIRContext *ctx, transform::TransformDialect *dialect) { + ctx->getOrLoadDialect(); + }); + + // Upstream AIR passes (tablegen-generated registration). + air_conv_reg::registerCopyToDma(); // air-copy-to-dma + air_conv_reg::registerParallelToHerd(); // air-par-to-herd + air_conv_reg::registerParallelToLaunch(); // air-par-to-launch + air_conv_reg::registerAIRWrapFuncWithParallelPass(); // air-wrap-func-with-parallel + air_xform_reg::registerDmaToChannel(); // air-dma-to-channel + + // Aster-specific passes. + registerPass([] { return createAirToAMDGCN(); }); registerPass([] { return createConvertLinalgToAMDGCN(); }); + registerPass([] { return createConvertAirChannelToAMDGCN(); }); + registerPass([] { return createConvertMemSpaceToAMDGCN(); }); // mlir-air pipelines. registerPipelines(); @@ -50,4 +121,5 @@ void registerAll(DialectRegistry ®istry) { // are available in aster-opt and the Python bindings when linked. static int _register = (mlir::aster::registerContribDialects(registerAll), 0); + } // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir new file mode 100644 index 000000000..40c2eeab8 --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir @@ -0,0 +1,220 @@ +// RUN: mlir-air-opt %s \ +// RUN: --transform-interpreter --canonicalize --cse \ +// RUN: --air-par-to-launch="has-air-segment=true" --canonicalize --cse \ +// RUN: --air-copy-to-dma \ +// RUN: --air-dma-to-channel \ +// RUN: --air-to-amdgcn --canonicalize \ +// RUN: --convert-memspace-to-amdgcn \ +// RUN: --convert-linalg-to-amdgcn \ +// RUN: --amdgcn-preload-library="library-paths=%p/../../../mlir_kernels/library/common/register-init.mlir,%p/../../../mlir_kernels/library/common/indexing.mlir,%p/../../../mlir_kernels/library/common/indexing_ptr.mlir,%p/../../../mlir_kernels/library/common/futures.mlir,%p/../../../contrib/kittens/library/compute_16x16_f16.mlir,%p/../../../contrib/kittens/library/global_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_mfma_16x64_b.mlir" \ +// RUN: --inline --symbol-dce --canonicalize \ +// RUN: --mlir-air-to-asm \ +// RUN: | aster-translate --mlir-to-asm \ +// RUN: | FileCheck %s + +// CHECK-LABEL: matmul_f16_64x64: +// CHECK: global_load_dwordx4 +// CHECK: ds_write_b64 +// CHECK: ds_read_b64 +// CHECK: v_mfma_f32_16x16x16_f16 +// CHECK: global_store_dword +// CHECK: s_endpgm + +// Real AIR pipeline (adapted from xrt/12, tile-using-pad path): +// +// 1. linalg.generic on tensors (64x64 matmul, no AIR ops) +// 2. transform: tile_using_forall (2x1 herd) → tile_using_for (compute) +// → pad → bufferize_to_allocation (L1) → one_shot_bufferize +// → forall_to_parallel → par_to_herd +// 3. air-copy-to-dma → air-dma-to-channel +// 4. air-to-amdgcn (herd → wavefront index) +// 5. convert-memspace → convert-linalg → preload → asm + +!sx2 = !amdgcn.sgpr<[? + 2]> +!vx2 = !amdgcn.vgpr<[? + 2]> +!ax4 = !amdgcn.agpr<[? + 4]> +!lds_write_token = !amdgcn.write_token +!future_lds_read = !aster_utils.struct> +!future_global_read = !aster_utils.struct> + +module attributes {transform.with_named_sequence} { + amdgcn.library @linalg_lib isa = [#amdgcn.isa] { + func.func private @zero_C() -> !ax4 + func.func private @mfma_f32_16x16x16_f16(!vx2, !vx2, !ax4) -> !ax4 + func.func private @store_global_C_mfma_f32_16x16x16_f16( + !ax4, !aster_utils.any, index, index, index) + func.func private @prepare_ptr(!sx2) -> !aster_utils.any + func.func private @load_global_tile_16x64_b( + !aster_utils.any, index, index, index) -> !future_global_read + func.func private @store_global_tile_to_lds_16x64_b( + index, !future_global_read) -> (!lds_write_token, !lds_write_token) + func.func private @load_lds_A_swizzled( + index, index, index) -> !future_lds_read + func.func private @load_lds_B_swizzled( + index, index, index) -> !future_lds_read + func.func private @get_lds_read_value_vx2(!future_lds_read) -> !vx2 + + func.func private @copy_f16_16x32( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %lds_dst: index) { + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %gfut = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col_offset, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + return + } + + func.func private @mfma_matmul_f16_16x32( + %lds_A: index, %lds_B: index, + %C_ptr: !sx2, %C_stride: index, + %C_row_offset: index, %C_col_offset: index) { + %C_prepared = func.call @prepare_ptr(%C_ptr) : (!sx2) -> !aster_utils.any + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %acc = func.call @zero_C() : () -> !ax4 + %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 + %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 + %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) + : (!vx2, !vx2, !ax4) -> !ax4 + %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 + %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 + %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) + : (!vx2, !vx2, !ax4) -> !ax4 + func.call @store_global_C_mfma_f32_16x16x16_f16( + %acc1, %C_prepared, %C_row_offset, %C_col_offset, %C_stride) + : (!ax4, !aster_utils.any, index, index, index) -> () + return + } + + func.func private @fill_f16_16x32(%val: f16, %lds_dst: index) { return } + } + + amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { + // 64x64 tensor-based matmul. No AIR ops — transform generates hierarchy. + func.func @matmul_f16_64x64( + %A: memref<64x64xf16>, %B: memref<64x64xf16>, %C: memref<64x64xf32>) + attributes {gpu.kernel} { + %cst = arith.constant 0.000000e+00 : f32 + %a = bufferization.to_tensor %A restrict writable : memref<64x64xf16> to tensor<64x64xf16> + %b = bufferization.to_tensor %B restrict writable : memref<64x64xf16> to tensor<64x64xf16> + %empty = tensor.empty() : tensor<64x64xf32> + %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<64x64xf32>) -> tensor<64x64xf32> + // matmul_transpose_b: C[m,n] += A[m,k] * B[n,k] + %result = linalg.generic { + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%a, %b : tensor<64x64xf16>, tensor<64x64xf16>) + outs(%fill : tensor<64x64xf32>) { + ^bb0(%av: f16, %bv: f16, %cv: f32): + %a_ext = arith.extf %av : f16 to f32 + %b_ext = arith.extf %bv : f16 to f32 + %prod = arith.mulf %a_ext, %b_ext : f32 + %sum = arith.addf %cv, %prod : f32 + linalg.yield %sum : f32 + } -> tensor<64x64xf32> + bufferization.materialize_in_destination %result in writable %C + : (tensor<64x64xf32>, memref<64x64xf32>) -> () + return + } + } + + // Transform adapted from xrt/12 (tile-using-pad, no packing). + transform.named_sequence @__transform_main( + %arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + + // Outer tiling: 2x1 forall on M (becomes 2-wavefront air.herd). + // 64/32 = 2 iterations — non-trivial, survives canonicalization. + %outer_tiled, %outer_forall = + transform.structured.tile_using_forall %matmul + tile_sizes [32, 0, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Compute tiling inside forall: 16x16 tiles, K=32. + %tiled, %lm, %ln, %lk = transform.structured.tile_using_for %outer_tiled + tile_sizes [16, 16, 32] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op, !transform.any_op) + + // Pad A and B only (not C — the matmul stores C directly to global). + %padded, %pad, %copy_back = transform.structured.pad %tiled { + padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], + padding_dimensions = [0, 1, 2], + pack_paddings = [1, 1, 0], + nofold_flags = [1, 1, 0], + copy_back_op = "linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad + : (!transform.any_op) -> !transform.any_op + + // Promote A,B,C pads to L1 (memory_space=2) via bufferize_to_allocation. + %padded_lhs = transform.get_producer_of_operand %padded[0] + : (!transform.any_op) -> (!transform.any_op) + %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} + : !transform.any_op + + %padded_rhs = transform.get_producer_of_operand %padded[1] + : (!transform.any_op) -> (!transform.any_op) + %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} + : !transform.any_op + + // Canonicalize. + %func_0 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_0 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_0 : !transform.any_op + + // One-shot bufferize. + %func_1 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %func_buf = transform.bufferization.one_shot_bufferize %func_1 { + allow_return_allocs_from_loops = true + } : (!transform.any_op) -> !transform.any_op + + // Cleanup. + %func_2 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_2 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_2 : !transform.any_op + %func_3 = transform.air.remove_uninitialized_copy %func_2 + : (!transform.any_op) -> !transform.any_op + + // Convert outer forall → parallel → air.herd (now on memrefs). + %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %parallel = transform.loop.forall_to_parallel %forall_2 + : (!transform.any_op) -> !transform.any_op + %herd = transform.air.par_to_herd %parallel + : (!transform.any_op) -> !transform.any_op + + transform.yield + } +} diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py new file mode 100644 index 000000000..e0019c69e --- /dev/null +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -0,0 +1,137 @@ +"""E2E matmul test exercising the real AIR lowering path. + +Pipeline: + mlir-air-opt (preprocess): + --transform-interpreter + --air-par-to-herd (forall → herd, each tile = 1 wavefront) + --one-shot-bufferize + --air-par-to-launch (outer parallel → launch) + --air-copy-to-dma (memref.copy → air.dma_memcpy_nd) + --air-dma-to-channel (DMA → air.channel.put/get) + --air-to-amdgcn (flatten hierarchy, herd → wavefront index) + --convert-memspace-to-amdgcn (integer memspace → #amdgcn.addr_space) + --convert-linalg-to-amdgcn (linalg + channels → library calls) + then aster pipeline: + --preload → inline → mlir-air-to-asm +""" + +import os +import shutil +import subprocess + +import numpy as np +import pytest + +from aster.execution.helpers import compile_and_run + +MCPU = "gfx942" + +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_MLIR_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul.mlir") +_LIBRARY_DIR = os.path.join( + _THIS_DIR, "..", "..", "..", "..", "mlir_kernels", "library" +) +_KITTENS_DIR = os.path.join( + _THIS_DIR, "..", "..", "..", "..", "contrib", "kittens", "library" +) + +_LIBRARY_PATHS = [ + os.path.join(_LIBRARY_DIR, "common", f) + for f in [ + "register-init.mlir", + "indexing.mlir", + "indexing_ptr.mlir", + "futures.mlir", + ] +] + [ + os.path.join(_KITTENS_DIR, f) + for f in [ + "compute_16x16_f16.mlir", + "global_16x64_b.mlir", + "lds_16x64_b.mlir", + "lds_mfma_16x64_b.mlir", + ] +] + + +def _find_mlir_air_opt(): + """Find the mlir-air-opt binary.""" + build_path = os.path.join( + _THIS_DIR, "..", "..", "..", "..", "build", "bin", "mlir-air-opt" + ) + if os.path.isfile(build_path): + return os.path.abspath(build_path) + path = shutil.which("mlir-air-opt") + if path: + return path + pytest.skip("mlir-air-opt not found") + + +def _air_preprocess(mlir_text): + """Run the full AIR lowering pipeline before handing to aster.""" + opt = _find_mlir_air_opt() + result = subprocess.run( + [ + opt, + "--transform-interpreter", + "--air-par-to-herd", + "--canonicalize", "--cse", + "--one-shot-bufferize", + "--canonicalize", "--cse", + "--air-par-to-launch=has-air-segment=true", + "--canonicalize", "--cse", + "--air-copy-to-dma", + "--air-dma-to-channel", + "--air-to-amdgcn", + "--canonicalize", + "--convert-memspace-to-amdgcn", + "--convert-linalg-to-amdgcn", + ], + input=mlir_text, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError( + f"mlir-air-opt AIR preprocessing failed:\n{result.stderr}" + ) + return result.stdout + + +def _post_air_pipeline(library_paths): + libs = ",".join(library_paths) + return ( + "builtin.module(" + "canonicalize," + f"amdgcn-preload-library{{library-paths={libs}}}," + "inline, symbol-dce, canonicalize," + "mlir-air-to-asm)" + ) + + +class TestAirMatmulE2E: + + def test_matmul_64x64(self): + M, N, K = 64, 64, 64 + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np.float16) + B_KxN = (np.random.randn(K, N) * 0.1).astype(np.float16) + B_T = np.ascontiguousarray(B_KxN.T) + # C must be zero-initialized (fill is erased by convert-linalg-to-amdgcn; + # the library's zero_C handles accumulator init per tile). + C = np.zeros(M * N, dtype=np.float32) + + compile_and_run( + file_name=_MLIR_FILE, + kernel_name="matmul_f16_64x64", + input_data=[A.flatten(), B_T.flatten()], + output_data=[C], + pass_pipeline=_post_air_pipeline(_LIBRARY_PATHS), + library_paths=[], + grid_dim=(1, 1, 1), + block_dim=(128, 1, 1), # 2 wavefronts (2x1 herd) + preprocess=_air_preprocess, + ) + + expected = (A.astype(np.float32) @ B_KxN.astype(np.float32)).flatten() + np.testing.assert_allclose(C, expected, rtol=1e-2, atol=1e-2) diff --git a/contrib/mlir-air/tools/mlir-air-opt.cpp b/contrib/mlir-air/tools/mlir-air-opt.cpp index 874bdc2b5..6fd335091 100644 --- a/contrib/mlir-air/tools/mlir-air-opt.cpp +++ b/contrib/mlir-air/tools/mlir-air-opt.cpp @@ -8,11 +8,12 @@ #include "aster/Init.h" #include "mlir/IR/Dialect.h" +#include "mlir/Pass/PassRegistry.h" #include "mlir/Tools/mlir-opt/MlirOptMain.h" namespace mlir::aster::mlir_air { void registerAll(DialectRegistry ®istry); -} +} // namespace mlir::aster::mlir_air int main(int argc, char **argv) { mlir::DialectRegistry registry; @@ -23,7 +24,7 @@ int main(int argc, char **argv) { mlir::aster::registerUpstreamMLIRExternalModels(registry); mlir::aster::initDialects(registry); mlir::aster::registerPasses(); - // mlir-air additions. + // mlir-air additions (registers dialects, passes, pipelines). mlir::aster::mlir_air::registerAll(registry); return mlir::asMainReturnCode( mlir::MlirOptMain(argc, argv, "mlir-air optimizer driver\n", registry)); From 9a3807b0e264fc82854ff4ca3f5b7820123a8b6f Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 02/16] =?UTF-8?q?[mlir-air]=20Wire=20DMA=E2=86=92LDS=20low?= =?UTF-8?q?ering;=20fix=20K-accumulation=20and=20multi-wavefront=20LDS?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove air-dma-to-channel from the pipeline: channels are not mappable to GPU fabric when herds map to wavefronts (no hardware FIFO backpressure). DMAs stay inside the herd body so copy and compute remain colocated. - Add DMA→library-call lowering in ConvertLinalgToAMDGCN: detects num-wavefronts from IR (wavefrontId affine.apply coefficient), allocates nWf*tileSizeBytes of LDS and strides per-wavefront to prevent LDS collision. - Fix K-accumulation: tile_using_for tile_sizes [16,16,0] (no K tiling) so the library's zero_C is called exactly once per output tile. - Add copy_f16_16x64 and mfma_matmul_f16_16x64 library functions to handle K=64 in a single call (4 MFMA panels from two 1024-byte LDS blocks). - Delete ConvertAirChannelToAMDGCN pass (unused, no pipeline invokes it). E2E test (64x64 f16 matmul on gfx942) passes: rtol/atol 1e-2. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Erwei Wang --- contrib/mlir-air/CMakeLists.txt | 7 +- contrib/mlir-air/lib/AirToAMDGCN.cpp | 2 +- .../lib/ConvertAirChannelToAMDGCN.cpp | 314 ------------------ .../mlir-air/lib/ConvertLinalgToAMDGCN.cpp | 163 +++++++++ .../mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp | 2 +- contrib/mlir-air/lib/Init.cpp | 2 - .../mlir-air/test/air-to-amdgcn-matmul.mlir | 100 +++++- .../test/integration/test_air_matmul_e2e.py | 4 +- tools/aster-shlib/CMakeLists.txt | 9 + 9 files changed, 276 insertions(+), 327 deletions(-) delete mode 100644 contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp diff --git a/contrib/mlir-air/CMakeLists.txt b/contrib/mlir-air/CMakeLists.txt index 33a30c8f6..17c35c96e 100644 --- a/contrib/mlir-air/CMakeLists.txt +++ b/contrib/mlir-air/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_library(MlirAirLib lib/AirToAMDGCN.cpp - lib/ConvertAirChannelToAMDGCN.cpp lib/ConvertLinalgToAMDGCN.cpp lib/ConvertMemSpaceToAMDGCN.cpp lib/Init.cpp @@ -47,6 +46,12 @@ target_link_libraries(mlir-air-opt PRIVATE ASTERInit AMDGCNTransforms MLIROptLib + # Re-list archives whose symbols are referenced by MlirAirLib (Init.cpp, + # AirToAMDGCN.cpp) but resolved after the linker's first pass. + AIRDialect + AIRTransformOps + MLIRTensorInferTypeOpInterfaceImpl + MLIRBufferizationTransformOps ) mlir_check_all_link_libraries(mlir-air-opt) install(TARGETS mlir-air-opt) diff --git a/contrib/mlir-air/lib/AirToAMDGCN.cpp b/contrib/mlir-air/lib/AirToAMDGCN.cpp index cd135716d..0ef45d72d 100644 --- a/contrib/mlir-air/lib/AirToAMDGCN.cpp +++ b/contrib/mlir-air/lib/AirToAMDGCN.cpp @@ -15,7 +15,7 @@ // air.execute -> inline body (strip async) // air.wait_all -> erase // -// air.channel.put/get are preserved for the convert-air-channel-to-amdgcn pass. +// air.channel.put/get are not expected after air-to-amdgcn (air-dma-to-channel not used). //===----------------------------------------------------------------------===// #include "air/Dialect/AIR/AIRDialect.h" diff --git a/contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp deleted file mode 100644 index 6b4af7b09..000000000 --- a/contrib/mlir-air/lib/ConvertAirChannelToAMDGCN.cpp +++ /dev/null @@ -1,314 +0,0 @@ -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -//===- ConvertAirChannelToAMDGCN.cpp - air.channel -> library calls -------===// -// -// Lowers air.channel.put and air.channel.get to AMDGCN library function calls. -// -// For each air.channel.put/get: -// - The memref operand is decomposed the same way as in ConvertLinalgToAMDGCN: -// * Global memref -> (!sgpr<[?+2]>, byte_stride: index) -// * Promoted buffer (memref.view of memref.alloca with memory space) -// -> LDS byte offset (index) -// - A func.call to a named library function is emitted: -// copy__(src_args..., dst_args...) -// where the shape comes from the channel's memref type. -// -// The channel.put sends data INTO the channel (producer side). -// The channel.get receives data FROM the channel (consumer side). -// Together they represent a point-to-point copy: the put's src is the copy -// source, the get's dst is the copy destination. -// -// This pass matches put/get pairs by channel name and emits a single copy call -// at the get site (the consumer), erasing both ops and the channel declaration. -//===----------------------------------------------------------------------===// - -#include "air/Dialect/AIR/AIRDialect.h" -#include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" -#include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" -#include "aster/Dialect/AMDGCN/IR/AMDGCNTypes.h" -#include "aster/Dialect/LSIR/IR/LSIRDialect.h" -#include "aster/Dialect/LSIR/IR/LSIROps.h" -#include "aster/Interfaces/ModuleOpInterface.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Ptr/IR/PtrDialect.h" -#include "mlir/Dialect/Ptr/IR/PtrOps.h" -#include "mlir/Dialect/Ptr/IR/PtrTypes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; -using namespace mlir::aster; -using namespace mlir::aster::amdgcn; - -namespace { - -// --------------------------------------------------------------------------- -// Utilities (shared with ConvertLinalgToAMDGCN.cpp) -// --------------------------------------------------------------------------- - -static std::string buildCopyFuncName(MemRefType ty) { - std::string name; - llvm::raw_string_ostream os(name); - os << "copy"; - Type elt = ty.getElementType(); - if (elt.isF16()) - os << "_f16"; - else if (elt.isF32()) - os << "_f32"; - else if (elt.isBF16()) - os << "_bf16"; - else - os << "_unknown"; - auto shape = ty.getShape(); - for (size_t i = 0; i < shape.size(); ++i) - os << (i == 0 ? "_" : "x") << shape[i]; - return name; -} - -static void ensureDecl(OpBuilder &builder, Block &block, Location loc, - StringRef name, FunctionType funcTy) { - for (auto &op : block) - if (auto fn = dyn_cast(&op)) - if (fn.getName() == name) - return; - auto savedIP = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(&block); - auto decl = func::FuncOp::create(builder, loc, name, funcTy); - decl.setPrivate(); - builder.restoreInsertionPoint(savedIP); -} - -static bool isPromotedBuffer(Value v) { - if (auto viewOp = v.getDefiningOp()) { - if (auto allocaOp = viewOp.getSource().getDefiningOp()) - return allocaOp.getMemref().getType().getMemorySpace() != nullptr; - } - if (auto allocOp = v.getDefiningOp()) - return allocOp.getMemref().getType().getMemorySpace() != nullptr; - return false; -} - -static Value emitLDSOffset(OpBuilder &builder, Location loc, Value memrefVal, - DenseMap &ldsCache) { - auto it = ldsCache.find(memrefVal); - if (it != ldsCache.end()) - return it->second; - - int64_t sizeBytes = 0; - Value byteShift; - if (auto viewOp = memrefVal.getDefiningOp()) { - auto allocaOp = viewOp.getSource().getDefiningOp(); - sizeBytes = allocaOp.getMemref().getType().getNumElements(); - byteShift = viewOp.getByteShift(); - } else if (auto allocOp = memrefVal.getDefiningOp()) { - auto mrTy = allocOp.getMemref().getType(); - unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); - sizeBytes = mrTy.getNumElements() * eltBits / 8; - } - auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), - sizeBytes, /*alignment=*/16, - /*offset=*/IntegerAttr{}); - auto ldsOffset = - GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); - Value result = ldsOffset.getResult(); - if (byteShift) - result = builder.create(loc, result, byteShift); - ldsCache[memrefVal] = result; - return result; -} - -static std::pair -decomposeGlobalMemref(OpBuilder &builder, Location loc, Value memref) { - auto mrTy = cast(memref.getType()); - unsigned eltBytes = mrTy.getElementType().getIntOrFloatBitWidth() / 8; - auto metadata = - memref::ExtractStridedMetadataOp::create(builder, loc, memref); - Value baseBuffer = metadata.getBaseBuffer(); - Value offset = metadata.getOffset(); - Value leadingStride = metadata.getStrides()[0]; - Value eltSize = arith::ConstantIndexOp::create(builder, loc, eltBytes); - Value byteStride = - arith::MulIOp::create(builder, loc, leadingStride, eltSize); - Value byteOffset = arith::MulIOp::create(builder, loc, offset, eltSize); - auto addrSpace = cast(mrTy.getMemorySpace()); - auto ptrTy = ptr::PtrType::get(builder.getContext(), addrSpace); - Value ptrVal = ptr::ToPtrOp::create(builder, loc, ptrTy, baseBuffer); - auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), - /*size=*/2, /*alignment=*/2); - Value rawPtr = lsir::ToRegOp::create(builder, loc, sx2Ty, ptrVal); - Value ptrFromReg = lsir::FromRegOp::create(builder, loc, ptrTy, rawPtr); - Value adjusted = - ptr::PtrAddOp::create(builder, loc, ptrTy, ptrFromReg, byteOffset); - Value result = lsir::ToRegOp::create(builder, loc, sx2Ty, adjusted); - return {result, byteStride}; -} - -/// Emit decomposed args for a memref operand (either LDS offset or global ptr). -static void emitDecomposedArgs(OpBuilder &builder, Location loc, Value memref, - SmallVectorImpl &callArgs, - SmallVectorImpl &argTypes, - DenseMap &ldsCache) { - auto indexTy = builder.getIndexType(); - auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), - /*size=*/2, /*alignment=*/2); - if (isPromotedBuffer(memref)) { - callArgs.push_back(emitLDSOffset(builder, loc, memref, ldsCache)); - argTypes.push_back(indexTy); - } else { - auto [ptrVal, byteStride] = decomposeGlobalMemref(builder, loc, memref); - callArgs.push_back(ptrVal); - argTypes.push_back(sx2Ty); - callArgs.push_back(byteStride); - argTypes.push_back(indexTy); - } -} - -// --------------------------------------------------------------------------- -// Pass -// --------------------------------------------------------------------------- - -struct ConvertAirChannelToAMDGCN - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertAirChannelToAMDGCN) - StringRef getArgument() const override { - return "convert-air-channel-to-amdgcn"; - } - StringRef getDescription() const override { - return "Convert air.channel.put/get pairs to AMDGCN library calls"; - } - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - } - - void runOnOperation() override { - Operation *moduleOp = getOperation(); - MLIRContext *ctx = &getContext(); - - Operation *declParent = moduleOp; - if (isa(moduleOp)) - moduleOp->walk([&](amdgcn::ModuleOp m) { declParent = m; }); - auto &declBlock = declParent->getRegion(0).front(); - - OpBuilder builder(ctx); - SmallVector toErase; - DenseMap ldsCache; - - // --------------------------------------------------------------- - // Path 1: Convert air.dma_memcpy_nd directly (no channels). - // --------------------------------------------------------------- - moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma) { - Value dst = dma.getDstMemref(); - Value src = dma.getSrcMemref(); - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) - return; - - builder.setInsertionPoint(dma); - Location loc = dma.getLoc(); - - std::string name = buildCopyFuncName(dstTy); - - SmallVector callArgs; - SmallVector argTypes; - - // src: if the DMA has src offsets/sizes/strides, create a subview. - Value srcForDecompose = src; - auto srcOffsets = dma.getSrcOffsets(); - auto srcSizes = dma.getSrcSizes(); - auto srcStrides = dma.getSrcStrides(); - if (!srcOffsets.empty()) { - SmallVector offsets, sizes, strides; - for (auto v : srcOffsets) - offsets.push_back(v); - for (auto v : srcSizes) - sizes.push_back(v); - for (auto v : srcStrides) - strides.push_back(v); - srcForDecompose = memref::SubViewOp::create( - builder, loc, src, offsets, sizes, strides); - } - emitDecomposedArgs(builder, loc, srcForDecompose, callArgs, argTypes, - ldsCache); - // dst args. - emitDecomposedArgs(builder, loc, dst, callArgs, argTypes, ldsCache); - - auto funcTy = builder.getFunctionType(argTypes, {}); - ensureDecl(builder, declBlock, loc, name, funcTy); - func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); - - toErase.push_back(dma); - }); - - // --------------------------------------------------------------- - // Path 2: Convert air.channel.put/get pairs (if channels present). - // --------------------------------------------------------------- - DenseMap> putsByChannel; - moduleOp->walk([&](xilinx::air::ChannelPutOp put) { - putsByChannel[put.getChanName()].push_back(put); - }); - - moduleOp->walk([&](xilinx::air::ChannelGetOp get) { - StringRef chanName = get.getChanName(); - auto it = putsByChannel.find(chanName); - if (it == putsByChannel.end() || it->second.empty()) - return; - - xilinx::air::ChannelPutOp put = it->second.front(); - - Value src = put.getSrc(); - Value dst = get.getDst(); - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) - return; - - builder.setInsertionPoint(get); - Location loc = get.getLoc(); - - std::string name = buildCopyFuncName(dstTy); - - SmallVector callArgs; - SmallVector argTypes; - - // src args. - emitDecomposedArgs(builder, loc, src, callArgs, argTypes, ldsCache); - // dst args. - emitDecomposedArgs(builder, loc, dst, callArgs, argTypes, ldsCache); - - auto funcTy = builder.getFunctionType(argTypes, {}); - ensureDecl(builder, declBlock, loc, name, funcTy); - func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); - - toErase.push_back(get); - toErase.push_back(put); - }); - - for (auto *op : toErase) - op->erase(); - - // Clean up channel declarations that are now unused. - SmallVector deadChannels; - moduleOp->walk([&](xilinx::air::ChannelOp chan) { - deadChannels.push_back(chan); - }); - for (auto *op : deadChannels) - op->erase(); - } -}; - -} // namespace - -namespace mlir::aster::mlir_air { -std::unique_ptr createConvertAirChannelToAMDGCN() { - return std::make_unique(); -} -} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp index 4030c2b4a..8e654360f 100644 --- a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp +++ b/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp @@ -278,6 +278,7 @@ struct ConvertLinalgToAMDGCN registry.insert(); } + void runOnOperation() override { Operation *moduleOp = getOperation(); MLIRContext *ctx = &getContext(); @@ -401,6 +402,168 @@ struct ConvertLinalgToAMDGCN builder.restoreInsertionPoint(savedIP); }); + // Detect numWavefronts from the IR. + // air-to-amdgcn emits: wavefrontId = gpu.thread_id x / 64 + // The affine.apply for M row uses: #map()[%loop_var, %wavefrontId] + // = loop_var + wavefrontId * herdTileM + // numWavefronts = totalM / herdTileM. + // We find herdTileM from the coefficient of the wavefrontId symbol in the + // affine map by scanning AffineApplyOp users of the divui result. + int64_t detectedNumWavefronts = 1; + moduleOp->walk([&](arith::DivUIOp divOp) { + if (detectedNumWavefronts > 1) + return; + auto threadId = divOp.getLhs().getDefiningOp(); + if (!threadId || threadId.getDimension() != gpu::Dimension::x) + return; + auto cst = divOp.getRhs().getDefiningOp(); + if (!cst || cst.value() != 64) + return; + // wavefrontId = divOp.getResult(). Find its use in affine.apply. + for (Operation *user : divOp.getResult().getUsers()) { + auto applyOp = dyn_cast(user); + if (!applyOp || applyOp.getAffineMap().getNumResults() != 1) + continue; + // Find the position of wavefrontId in the operand list. + unsigned wfPos = 0; + bool found = false; + for (auto [idx, operand] : llvm::enumerate(applyOp.getMapOperands())) { + if (operand == divOp.getResult()) { + wfPos = idx; + found = true; + break; + } + } + if (!found) + continue; + // Extract coefficient of symbol/dim at wfPos in the affine map. + AffineMap map = applyOp.getAffineMap(); + // The map operands are pure symbols in this context. + int64_t stride = 0; + AffineExpr expr = map.getResult(0); + expr.walk([&](AffineExpr e) { + auto mul = dyn_cast(e); + if (!mul || mul.getKind() != AffineExprKind::Mul) + return; + auto sym = dyn_cast(mul.getLHS()); + auto con = dyn_cast(mul.getRHS()); + if (!sym || !con) + return; + if (sym.getPosition() == wfPos) + stride = con.getValue(); + }); + if (stride <= 0) + continue; + // Get totalM from DMA src memref. + moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma2) { + if (detectedNumWavefronts > 1) + return; + Value src = dma2.getSrcMemref(); + auto srcTy = dyn_cast(src.getType()); + if (!srcTy || srcTy.getRank() < 1 || srcTy.getDimSize(0) <= 0) + return; + detectedNumWavefronts = srcTy.getDimSize(0) / stride; + }); + } + }); + + // Pre-allocate LDS for air.dma_memcpy_nd destinations (no-channel path). + // Must run before linalg op processing so matmul hits the same ldsCache. + // Allocate numWavefronts * tileSizeBytes and stripe by wavefront_id. + moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma) { + Value dst = dma.getDstMemref(); + if (!isPromotedBuffer(dst) || ldsCache.count(dst)) + return; + auto funcOp = dma->getParentOfType(); + if (!funcOp || funcOp.empty()) + return; + auto savedIP = builder.saveInsertionPoint(); + builder.setInsertionPointToStart(&funcOp.front()); + Location loc = funcOp.getLoc(); + + int64_t tileSizeBytes = 0; + if (auto allocOp = dst.getDefiningOp()) { + auto mrTy = allocOp.getMemref().getType(); + unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); + tileSizeBytes = mrTy.getNumElements() * eltBits / 8; + } + int64_t nWf = detectedNumWavefronts; + auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), + nWf * tileSizeBytes, /*alignment=*/16, + /*offset=*/IntegerAttr{}); + auto ldsBaseOffset = + GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); + Value result = ldsBaseOffset.getResult(); + if (nWf > 1) { + Value wavefrontSize = arith::ConstantIndexOp::create(builder, loc, 64); + Value threadIdX = + gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); + Value wavefrontId = + arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); + Value tileSizeVal = + arith::ConstantIndexOp::create(builder, loc, tileSizeBytes); + Value wavefrontOffset = + arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); + result = arith::AddIOp::create(builder, loc, result, wavefrontOffset); + } + ldsCache[dst] = result; + builder.restoreInsertionPoint(savedIP); + }); + + // Convert air.dma_memcpy_nd directly (Global→LDS only; no channels). + // Emit: copy__(base_sgpr, stride, row_off, col_off, lds_dst) + moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma) { + Value dst = dma.getDstMemref(); + Value src = dma.getSrcMemref(); + bool dstIsLDS = isPromotedBuffer(dst); + if (!dstIsLDS) + return; // Only handle Global→LDS DMAs here. + + auto dstTy = cast(dst.getType()); + builder.setInsertionPoint(dma); + Location loc = dma.getLoc(); + + std::string name = buildFuncName("copy", dstTy); + + auto indexTy = builder.getIndexType(); + auto sx2Ty = amdgcn::SGPRType::get(ctx, Register(), /*size=*/2, + /*alignment=*/2); + SmallVector callArgs; + SmallVector argTypes; + + // Decompose BASE src memref → (sgpr_base, byte_stride). + auto [ptrVal, byteStride] = decomposeGlobalMemref(builder, loc, src); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + + // Src tile offsets (row, col) from DMA operands. + auto srcOffsets = dma.getSrcOffsets(); + if (srcOffsets.size() >= 2) { + callArgs.push_back(srcOffsets[0]); + argTypes.push_back(indexTy); + callArgs.push_back(srcOffsets[1]); + argTypes.push_back(indexTy); + } else { + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + callArgs.push_back(zero); + argTypes.push_back(indexTy); + callArgs.push_back(zero); + argTypes.push_back(indexTy); + } + + // LDS dst offset from cache. + assert(ldsCache.count(dst) && "DMA dst LDS not pre-allocated"); + callArgs.push_back(ldsCache[dst]); + argTypes.push_back(indexTy); + + auto funcTy = builder.getFunctionType(argTypes, {}); + ensureDecl(builder, declBlock, loc, name, funcTy); + func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); + toErase.push_back(dma); + }); + // Now process linalg ops — they'll hit the ldsCache for shared allocs. moduleOp->walk([&](linalg::FillOp op) { replaceWithCall(builder, declBlock, op, "fill", toErase, ldsCache); diff --git a/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp index b47497036..a03666913 100644 --- a/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp +++ b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp @@ -16,7 +16,7 @@ // integer memory spaces) and aster's AMDGCN decomposition passes (which // require #amdgcn.addr_space attributes for pointer generation). // -// Run after air-to-amdgcn and before convert-air-channel-to-amdgcn. +// Run after air-to-amdgcn and before convert-linalg-to-amdgcn. //===----------------------------------------------------------------------===// #include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index 7a0226dc5..0c98a92ab 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -52,7 +52,6 @@ namespace air_xform_reg { namespace mlir::aster::mlir_air { std::unique_ptr createAirToAMDGCN(); -std::unique_ptr createConvertAirChannelToAMDGCN(); std::unique_ptr createConvertLinalgToAMDGCN(); std::unique_ptr createConvertMemSpaceToAMDGCN(); void registerPipelines(); @@ -110,7 +109,6 @@ void registerAll(DialectRegistry ®istry) { // Aster-specific passes. registerPass([] { return createAirToAMDGCN(); }); registerPass([] { return createConvertLinalgToAMDGCN(); }); - registerPass([] { return createConvertAirChannelToAMDGCN(); }); registerPass([] { return createConvertMemSpaceToAMDGCN(); }); // mlir-air pipelines. diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir index 40c2eeab8..1fbe03652 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir @@ -2,7 +2,6 @@ // RUN: --transform-interpreter --canonicalize --cse \ // RUN: --air-par-to-launch="has-air-segment=true" --canonicalize --cse \ // RUN: --air-copy-to-dma \ -// RUN: --air-dma-to-channel \ // RUN: --air-to-amdgcn --canonicalize \ // RUN: --convert-memspace-to-amdgcn \ // RUN: --convert-linalg-to-amdgcn \ @@ -101,6 +100,94 @@ module attributes {transform.with_named_sequence} { } func.func private @fill_f16_16x32(%val: f16, %lds_dst: index) { return } + + // Copy a 16x64 f16 tile from global memory to LDS. + // Loads two consecutive 16x32-element (= 16x64-byte) panels at col and col+32. + func.func private @copy_f16_16x64( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %lds_dst: index) { + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %col1 = arith.addi %col_offset, %c32 : index + %lds_dst1 = arith.addi %lds_dst, %c1024 : index + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %gfut0 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col_offset, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut0) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + %gfut1 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col1, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t2, %t3 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst1, %gfut1) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + amdgcn.wait deps %t2 : !lds_write_token + amdgcn.wait deps %t3 : !lds_write_token + return + } + + // 16x16 output matmul using two 16x32 A tiles and two 16x32 B tiles in LDS. + // lds_A points to the first 16x32 block; lds_A2 = lds_A + 1024 is the second. + // Similarly for lds_B. Reads 4 panels at k_byte_offsets 0,32 within each block. + func.func private @mfma_matmul_f16_16x64( + %lds_A: index, %lds_B: index, + %C_ptr: !sx2, %C_stride: index, + %C_row_offset: index, %C_col_offset: index) { + %C_prepared = func.call @prepare_ptr(%C_ptr) : (!sx2) -> !aster_utils.any + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + // Second half of K: LDS base + 1024 bytes. + %lds_A2 = arith.addi %lds_A, %c1024 : index + %lds_B2 = arith.addi %lds_B, %c1024 : index + %acc = func.call @zero_C() : () -> !ax4 + // K panel 0 (k=0..15): from first block, offset 0. + %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 + %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 + %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) + : (!vx2, !vx2, !ax4) -> !ax4 + // K panel 1 (k=16..31): from first block, offset 32. + %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 + %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 + %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) + : (!vx2, !vx2, !ax4) -> !ax4 + // K panel 2 (k=32..47): from second block, offset 0. + %A2f = func.call @load_lds_A_swizzled(%lds_A2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A2 = func.call @get_lds_read_value_vx2(%A2f) : (!future_lds_read) -> !vx2 + %B2f = func.call @load_lds_B_swizzled(%lds_B2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B2 = func.call @get_lds_read_value_vx2(%B2f) : (!future_lds_read) -> !vx2 + %acc2 = func.call @mfma_f32_16x16x16_f16(%A2, %B2, %acc1) + : (!vx2, !vx2, !ax4) -> !ax4 + // K panel 3 (k=48..63): from second block, offset 32. + %A3f = func.call @load_lds_A_swizzled(%lds_A2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A3 = func.call @get_lds_read_value_vx2(%A3f) : (!future_lds_read) -> !vx2 + %B3f = func.call @load_lds_B_swizzled(%lds_B2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B3 = func.call @get_lds_read_value_vx2(%B3f) : (!future_lds_read) -> !vx2 + %acc3 = func.call @mfma_f32_16x16x16_f16(%A3, %B3, %acc2) + : (!vx2, !vx2, !ax4) -> !ax4 + func.call @store_global_C_mfma_f32_16x16x16_f16( + %acc3, %C_prepared, %C_row_offset, %C_col_offset, %C_stride) + : (!ax4, !aster_utils.any, index, index, index) -> () + return + } + + func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { return } } amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { @@ -149,11 +236,14 @@ module attributes {transform.with_named_sequence} { tile_sizes [32, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - // Compute tiling inside forall: 16x16 tiles, K=32. - %tiled, %lm, %ln, %lk = transform.structured.tile_using_for %outer_tiled - tile_sizes [16, 16, 32] + // Compute tiling inside forall: 16x16 M×N tiles, no K tiling. + // mfma_matmul_f16_16x32 calls zero_C internally so K must not be tiled + // across loop iterations (each call would reset the accumulator). + // The library internally handles K by reading two 16x16 panels from LDS. + %tiled, %lm, %ln = transform.structured.tile_using_for %outer_tiled + tile_sizes [16, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, - !transform.any_op, !transform.any_op) + !transform.any_op) // Pad A and B only (not C — the matmul stores C directly to global). %padded, %pad, %copy_back = transform.structured.pad %tiled { diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index e0019c69e..f4ad28aac 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -7,10 +7,9 @@ --one-shot-bufferize --air-par-to-launch (outer parallel → launch) --air-copy-to-dma (memref.copy → air.dma_memcpy_nd) - --air-dma-to-channel (DMA → air.channel.put/get) --air-to-amdgcn (flatten hierarchy, herd → wavefront index) --convert-memspace-to-amdgcn (integer memspace → #amdgcn.addr_space) - --convert-linalg-to-amdgcn (linalg + channels → library calls) + --convert-linalg-to-amdgcn (air.dma_memcpy_nd + linalg ops → library calls) then aster pipeline: --preload → inline → mlir-air-to-asm """ @@ -81,7 +80,6 @@ def _air_preprocess(mlir_text): "--air-par-to-launch=has-air-segment=true", "--canonicalize", "--cse", "--air-copy-to-dma", - "--air-dma-to-channel", "--air-to-amdgcn", "--canonicalize", "--convert-memspace-to-amdgcn", diff --git a/tools/aster-shlib/CMakeLists.txt b/tools/aster-shlib/CMakeLists.txt index 1cd08a933..1e08bea88 100644 --- a/tools/aster-shlib/CMakeLists.txt +++ b/tools/aster-shlib/CMakeLists.txt @@ -52,9 +52,18 @@ if(ASTER_ENABLE_MLIR_AIR) -Wl,--whole-archive $ -Wl,--no-whole-archive) endif() target_link_libraries(ASTER PRIVATE + AIRDialect + AIRTransformOps + AIRTransformPasses + AIRConversionPasses + MLIRBufferizationDialect + MLIRBufferizationTransformOps + MLIRBufferizationTransforms MLIRLinalgDialect MLIRLinalgTransformOps MLIRLinalgTransforms + MLIRTensorInferTypeOpInterfaceImpl + MLIRTensorTransforms MLIRTransformDialect MLIRTransformDialectTransforms MLIRSCFTransformOps From 9a8bece6105ed9dc42102a81f26cefebd2fe1a1b Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 03/16] [mlir-air] Clean up link deps: use CMake transitive resolution Remove manually duplicated AIR/MLIR link deps from mlir-air-opt and aster-shlib. MlirAirLib declares them as LINK_LIBS PUBLIC, so CMake propagates them automatically. The aster-shlib --whole-archive hack bypassed this; fixed by also linking MlirAirLib normally for dep resolution. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- contrib/mlir-air/CMakeLists.txt | 6 ------ tools/aster-shlib/CMakeLists.txt | 20 +++----------------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/contrib/mlir-air/CMakeLists.txt b/contrib/mlir-air/CMakeLists.txt index 17c35c96e..6fee25a34 100644 --- a/contrib/mlir-air/CMakeLists.txt +++ b/contrib/mlir-air/CMakeLists.txt @@ -46,12 +46,6 @@ target_link_libraries(mlir-air-opt PRIVATE ASTERInit AMDGCNTransforms MLIROptLib - # Re-list archives whose symbols are referenced by MlirAirLib (Init.cpp, - # AirToAMDGCN.cpp) but resolved after the linker's first pass. - AIRDialect - AIRTransformOps - MLIRTensorInferTypeOpInterfaceImpl - MLIRBufferizationTransformOps ) mlir_check_all_link_libraries(mlir-air-opt) install(TARGETS mlir-air-opt) diff --git a/tools/aster-shlib/CMakeLists.txt b/tools/aster-shlib/CMakeLists.txt index 1e08bea88..7ca64ef16 100644 --- a/tools/aster-shlib/CMakeLists.txt +++ b/tools/aster-shlib/CMakeLists.txt @@ -51,23 +51,9 @@ if(ASTER_ENABLE_MLIR_AIR) target_link_libraries(ASTER PRIVATE -Wl,--whole-archive $ -Wl,--no-whole-archive) endif() - target_link_libraries(ASTER PRIVATE - AIRDialect - AIRTransformOps - AIRTransformPasses - AIRConversionPasses - MLIRBufferizationDialect - MLIRBufferizationTransformOps - MLIRBufferizationTransforms - MLIRLinalgDialect - MLIRLinalgTransformOps - MLIRLinalgTransforms - MLIRTensorInferTypeOpInterfaceImpl - MLIRTensorTransforms - MLIRTransformDialect - MLIRTransformDialectTransforms - MLIRSCFTransformOps - ) + # Normal link for CMake transitive dep resolution (--whole-archive + # bypasses CMake's LINK_LIBS propagation). + target_link_libraries(ASTER PRIVATE MlirAirLib) endif() # Place ASTER.dylib/.so alongside the Python extensions in the build tree so # that @loader_path RPATH on the extensions resolves at build time (stubgen From e82320b6ed7dd31977a79a9a03d9a9b419a63b16 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 04/16] [mlir-air] Clean up transform script and ConvertLinalgToAMDGCN - Split transform sequence into separate file (air-to-amdgcn-matmul-transform.mlir) and load via --transform-preload-library, matching upstream MLIR convention. No transform ops in the payload or output; no cleanup code needed in passes. - Register transform::registerPreloadLibraryPass() in Init.cpp. - Remove emit_dealloc from bufferize_to_allocation (deallocs not needed). - Remove manual DCE for alloca/view/dealloc from ConvertLinalgToAMDGCN. - Remove global->global memref.copy forwarding hack (fill writes to %C directly via bufferization.to_tensor, so no temp alloc/copy pattern). - Remove transform dialect cleanup from ConvertLinalgToAMDGCN. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- .../mlir-air/lib/ConvertLinalgToAMDGCN.cpp | 56 ++---------- contrib/mlir-air/lib/Init.cpp | 1 + .../test/air-to-amdgcn-matmul-transform.mlir | 88 ++++++++++++++++++ .../mlir-air/test/air-to-amdgcn-matmul.mlir | 91 +------------------ .../test/integration/test_air_matmul_e2e.py | 8 +- 5 files changed, 104 insertions(+), 140 deletions(-) create mode 100644 contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir diff --git a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp index 8e654360f..8a808364b 100644 --- a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp +++ b/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp @@ -700,61 +700,19 @@ struct ConvertLinalgToAMDGCN for (auto *op : toErase) op->erase(); - // Erase linalg.fill on global buffers — the library handles zero-init. + // Erase linalg.fill on global (non-LDS) buffers. + // The library's zero_C handles accumulator init, so the fill is redundant. + // It must be erased because the aster backend cannot lower linalg.fill + // on global memrefs. SmallVector globalFills; moduleOp->walk([&](linalg::FillOp fill) { - for (Value out : fill.getDpsInits()) { - if (auto mrTy = dyn_cast(out.getType())) - if (!isPromotedBuffer(out)) - globalFills.push_back(fill); - } + for (Value out : fill.getDpsInits()) + if (isa(out.getType()) && !isPromotedBuffer(out)) + globalFills.push_back(fill); }); for (auto fill : globalFills) fill->erase(); - // Eliminate global→global memref.copy by forwarding the destination. - // This handles the `memref.copy %alloc, %arg` from materialize_in_destination. - moduleOp->walk([&](memref::CopyOp copy) { - Value src = copy.getSource(); - Value dst = copy.getTarget(); - if (isPromotedBuffer(src) || isPromotedBuffer(dst)) - return; - // Both are global — replace all uses of src with dst and erase. - if (auto allocOp = src.getDefiningOp()) { - src.replaceAllUsesWith(dst); - copy->erase(); - if (allocOp->use_empty()) - allocOp->erase(); - } - }); - - // DCE unused alloca/view/dealloc. - SmallVector deadOps; - moduleOp->walk([&](Operation *op) { - if (isa(op) || - (isa(op) && op->use_empty())) - deadOps.push_back(op); - }); - for (auto *op : deadOps) - op->erase(); - deadOps.clear(); - moduleOp->walk([&](memref::AllocaOp op) { - if (op->use_empty()) - deadOps.push_back(op); - }); - for (auto *op : deadOps) - op->erase(); - - // Erase transform dialect ops. - if (auto builtinMod = dyn_cast(moduleOp)) { - SmallVector transformOps; - for (auto &op : builtinMod.getBody()->getOperations()) - if (op.getDialect() && op.getDialect()->getNamespace() == "transform") - transformOps.push_back(&op); - for (auto *op : transformOps) - op->erase(); - builtinMod->removeAttr("transform.with_named_sequence"); - } } }; diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index 0c98a92ab..068c2480a 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -88,6 +88,7 @@ void registerAll(DialectRegistry ®istry) { registerLinalgPasses(); memref::registerMemRefPasses(); transform::registerInterpreterPass(); + transform::registerPreloadLibraryPass(); // AIR transform ops extension (air.transform.*). xilinx::air::registerTransformDialectExtension(registry); diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir new file mode 100644 index 000000000..448abf26a --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir @@ -0,0 +1,88 @@ +// Transform sequence for 64x64 matmul: tile, pad, bufferize, map to AIR herd. +// Adapted from xrt/12 (tile-using-pad, no packing). + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + + // Outer tiling: 2x1 forall on M (becomes 2-wavefront air.herd). + // 64/32 = 2 iterations — non-trivial, survives canonicalization. + %outer_tiled, %outer_forall = + transform.structured.tile_using_forall %matmul + tile_sizes [32, 0, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Compute tiling inside forall: 16x16 M×N tiles, no K tiling. + // mfma_matmul_f16_16x32 calls zero_C internally so K must not be tiled + // across loop iterations (each call would reset the accumulator). + // The library internally handles K by reading two 16x16 panels from LDS. + %tiled, %lm, %ln = transform.structured.tile_using_for %outer_tiled + tile_sizes [16, 16, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + + // Pad A and B only (not C — the matmul stores C directly to global). + %padded, %pad, %copy_back = transform.structured.pad %tiled { + padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], + padding_dimensions = [0, 1, 2], + pack_paddings = [1, 1, 0], + nofold_flags = [1, 1, 0], + copy_back_op = "linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad + : (!transform.any_op) -> !transform.any_op + + // Promote A,B pads to L1 (memory_space=2) via bufferize_to_allocation. + %padded_lhs = transform.get_producer_of_operand %padded[0] + : (!transform.any_op) -> (!transform.any_op) + %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + + %padded_rhs = transform.get_producer_of_operand %padded[1] + : (!transform.any_op) -> (!transform.any_op) + %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + + // Canonicalize. + %func_0 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_0 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_0 : !transform.any_op + + // One-shot bufferize. + %func_1 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %func_buf = transform.bufferization.one_shot_bufferize %func_1 { + allow_return_allocs_from_loops = true + } : (!transform.any_op) -> !transform.any_op + + // Cleanup. + %func_2 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_2 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_2 : !transform.any_op + %func_3 = transform.air.remove_uninitialized_copy %func_2 + : (!transform.any_op) -> !transform.any_op + + // Convert outer forall → parallel → air.herd (now on memrefs). + %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %parallel = transform.loop.forall_to_parallel %forall_2 + : (!transform.any_op) -> !transform.any_op + %herd = transform.air.par_to_herd %parallel + : (!transform.any_op) -> !transform.any_op + + transform.yield + } +} diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir index 1fbe03652..cf53bd290 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir @@ -1,4 +1,5 @@ // RUN: mlir-air-opt %s \ +// RUN: --transform-preload-library="transform-library-paths=%p/air-to-amdgcn-matmul-transform.mlir" \ // RUN: --transform-interpreter --canonicalize --cse \ // RUN: --air-par-to-launch="has-air-segment=true" --canonicalize --cse \ // RUN: --air-copy-to-dma \ @@ -36,7 +37,7 @@ !future_lds_read = !aster_utils.struct> !future_global_read = !aster_utils.struct> -module attributes {transform.with_named_sequence} { +module { amdgcn.library @linalg_lib isa = [#amdgcn.isa] { func.func private @zero_C() -> !ax4 func.func private @mfma_f32_16x16x16_f16(!vx2, !vx2, !ax4) -> !ax4 @@ -198,8 +199,8 @@ module attributes {transform.with_named_sequence} { %cst = arith.constant 0.000000e+00 : f32 %a = bufferization.to_tensor %A restrict writable : memref<64x64xf16> to tensor<64x64xf16> %b = bufferization.to_tensor %B restrict writable : memref<64x64xf16> to tensor<64x64xf16> - %empty = tensor.empty() : tensor<64x64xf32> - %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<64x64xf32>) -> tensor<64x64xf32> + %c = bufferization.to_tensor %C restrict writable : memref<64x64xf32> to tensor<64x64xf32> + %fill = linalg.fill ins(%cst : f32) outs(%c : tensor<64x64xf32>) -> tensor<64x64xf32> // matmul_transpose_b: C[m,n] += A[m,k] * B[n,k] %result = linalg.generic { indexing_maps = [ @@ -223,88 +224,4 @@ module attributes {transform.with_named_sequence} { } } - // Transform adapted from xrt/12 (tile-using-pad, no packing). - transform.named_sequence @__transform_main( - %arg0: !transform.any_op {transform.readonly}) { - %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 - : (!transform.any_op) -> !transform.any_op - - // Outer tiling: 2x1 forall on M (becomes 2-wavefront air.herd). - // 64/32 = 2 iterations — non-trivial, survives canonicalization. - %outer_tiled, %outer_forall = - transform.structured.tile_using_forall %matmul - tile_sizes [32, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Compute tiling inside forall: 16x16 M×N tiles, no K tiling. - // mfma_matmul_f16_16x32 calls zero_C internally so K must not be tiled - // across loop iterations (each call would reset the accumulator). - // The library internally handles K by reading two 16x16 panels from LDS. - %tiled, %lm, %ln = transform.structured.tile_using_for %outer_tiled - tile_sizes [16, 16, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, - !transform.any_op) - - // Pad A and B only (not C — the matmul stores C directly to global). - %padded, %pad, %copy_back = transform.structured.pad %tiled { - padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], - padding_dimensions = [0, 1, 2], - pack_paddings = [1, 1, 0], - nofold_flags = [1, 1, 0], - copy_back_op = "linalg.copy" - } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, - !transform.any_op) - %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad - : (!transform.any_op) -> !transform.any_op - - // Promote A,B,C pads to L1 (memory_space=2) via bufferize_to_allocation. - %padded_lhs = transform.get_producer_of_operand %padded[0] - : (!transform.any_op) -> (!transform.any_op) - %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs - {memory_space = 2, bufferize_destination_only, emit_dealloc} - : !transform.any_op - - %padded_rhs = transform.get_producer_of_operand %padded[1] - : (!transform.any_op) -> (!transform.any_op) - %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs - {memory_space = 2, bufferize_destination_only, emit_dealloc} - : !transform.any_op - - // Canonicalize. - %func_0 = transform.structured.match ops{["func.func"]} in %arg0 - : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_0 { - transform.apply_patterns.linalg.tiling_canonicalization - transform.apply_patterns.scf.for_loop_canonicalization - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_0 : !transform.any_op - - // One-shot bufferize. - %func_1 = transform.structured.match ops{["func.func"]} in %arg0 - : (!transform.any_op) -> !transform.any_op - %func_buf = transform.bufferization.one_shot_bufferize %func_1 { - allow_return_allocs_from_loops = true - } : (!transform.any_op) -> !transform.any_op - - // Cleanup. - %func_2 = transform.structured.match ops{["func.func"]} in %arg0 - : (!transform.any_op) -> !transform.any_op - transform.apply_patterns to %func_2 { - transform.apply_patterns.canonicalization - } : !transform.any_op - transform.apply_cse to %func_2 : !transform.any_op - %func_3 = transform.air.remove_uninitialized_copy %func_2 - : (!transform.any_op) -> !transform.any_op - - // Convert outer forall → parallel → air.herd (now on memrefs). - %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 - : (!transform.any_op) -> !transform.any_op - %parallel = transform.loop.forall_to_parallel %forall_2 - : (!transform.any_op) -> !transform.any_op - %herd = transform.air.par_to_herd %parallel - : (!transform.any_op) -> !transform.any_op - - transform.yield - } } diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index f4ad28aac..dde70e1d9 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -27,6 +27,7 @@ _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _MLIR_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul.mlir") +_TRANSFORM_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul-transform.mlir") _LIBRARY_DIR = os.path.join( _THIS_DIR, "..", "..", "..", "..", "mlir_kernels", "library" ) @@ -72,6 +73,7 @@ def _air_preprocess(mlir_text): result = subprocess.run( [ opt, + f"--transform-preload-library=transform-library-paths={_TRANSFORM_FILE}", "--transform-interpreter", "--air-par-to-herd", "--canonicalize", "--cse", @@ -86,13 +88,11 @@ def _air_preprocess(mlir_text): "--convert-linalg-to-amdgcn", ], input=mlir_text, - capture_output=True, - text=True, + capture_output=True, text=True, ) if result.returncode != 0: raise RuntimeError( - f"mlir-air-opt AIR preprocessing failed:\n{result.stderr}" - ) + f"mlir-air-opt AIR preprocessing failed:\n{result.stderr}") return result.stdout From d20501fd43ef1a83f5ed03039ca6c6ea799ea073 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 05/16] =?UTF-8?q?[mlir-air]=20Rename=20ConvertLinalgToAMDG?= =?UTF-8?q?CN=20=E2=86=92=20ConvertToAMDGCNLibraryCalls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pass converts linalg ops, air.dma_memcpy_nd, and air.channel.put/get to AMDGCN library calls — not just linalg. Rename to reflect its scope: convert-linalg-to-amdgcn → convert-to-amdgcn-library-calls Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- contrib/mlir-air/CMakeLists.txt | 2 +- contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp | 2 +- ...MDGCN.cpp => ConvertToAMDGCNLibraryCalls.cpp} | 16 ++++++++-------- contrib/mlir-air/lib/Init.cpp | 4 ++-- contrib/mlir-air/test/air-to-amdgcn-matmul.mlir | 2 +- .../test/integration/test_air_matmul_e2e.py | 6 +++--- .../test/integration/test_linalg_matmul_e2e.py | 2 +- contrib/mlir-air/test/linalg-to-amdgcn.mlir | 2 +- 8 files changed, 18 insertions(+), 18 deletions(-) rename contrib/mlir-air/lib/{ConvertLinalgToAMDGCN.cpp => ConvertToAMDGCNLibraryCalls.cpp} (98%) diff --git a/contrib/mlir-air/CMakeLists.txt b/contrib/mlir-air/CMakeLists.txt index 6fee25a34..234399296 100644 --- a/contrib/mlir-air/CMakeLists.txt +++ b/contrib/mlir-air/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_library(MlirAirLib lib/AirToAMDGCN.cpp - lib/ConvertLinalgToAMDGCN.cpp + lib/ConvertToAMDGCNLibraryCalls.cpp lib/ConvertMemSpaceToAMDGCN.cpp lib/Init.cpp lib/Pipelines.cpp diff --git a/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp index a03666913..c20c75f3f 100644 --- a/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp +++ b/contrib/mlir-air/lib/ConvertMemSpaceToAMDGCN.cpp @@ -16,7 +16,7 @@ // integer memory spaces) and aster's AMDGCN decomposition passes (which // require #amdgcn.addr_space attributes for pointer generation). // -// Run after air-to-amdgcn and before convert-linalg-to-amdgcn. +// Run after air-to-amdgcn and before convert-to-amdgcn-library-calls. //===----------------------------------------------------------------------===// #include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" diff --git a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp similarity index 98% rename from contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp rename to contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp index 8a808364b..a508c7899 100644 --- a/contrib/mlir-air/lib/ConvertLinalgToAMDGCN.cpp +++ b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp @@ -4,7 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -//===- ConvertLinalgToAMDGCN.cpp - linalg ops -> AMDGCN library calls -----===// +//===- ConvertToAMDGCNLibraryCalls.cpp - ops -> AMDGCN library calls ------===// #include "air/Dialect/AIR/AIRDialect.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" @@ -264,13 +264,13 @@ static void replaceWithCall(OpBuilder &builder, Block &declBlock, Operation *op, toErase.push_back(op); } -struct ConvertLinalgToAMDGCN - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertLinalgToAMDGCN) - StringRef getArgument() const override { return "convert-linalg-to-amdgcn"; } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToAMDGCNLibraryCalls) + StringRef getArgument() const override { return "convert-to-amdgcn-library-calls"; } StringRef getDescription() const override { - return "Convert tiled linalg ops to AMDGCN library calls"; + return "Convert linalg/AIR ops to AMDGCN library calls"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); @@ -719,7 +719,7 @@ struct ConvertLinalgToAMDGCN } // namespace namespace mlir::aster::mlir_air { -std::unique_ptr createConvertLinalgToAMDGCN() { - return std::make_unique(); +std::unique_ptr createConvertToAMDGCNLibraryCalls() { + return std::make_unique(); } } // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index 068c2480a..13e748b5d 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -52,7 +52,7 @@ namespace air_xform_reg { namespace mlir::aster::mlir_air { std::unique_ptr createAirToAMDGCN(); -std::unique_ptr createConvertLinalgToAMDGCN(); +std::unique_ptr createConvertToAMDGCNLibraryCalls(); std::unique_ptr createConvertMemSpaceToAMDGCN(); void registerPipelines(); @@ -109,7 +109,7 @@ void registerAll(DialectRegistry ®istry) { // Aster-specific passes. registerPass([] { return createAirToAMDGCN(); }); - registerPass([] { return createConvertLinalgToAMDGCN(); }); + registerPass([] { return createConvertToAMDGCNLibraryCalls(); }); registerPass([] { return createConvertMemSpaceToAMDGCN(); }); // mlir-air pipelines. diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir index cf53bd290..7f9a90b5b 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir @@ -5,7 +5,7 @@ // RUN: --air-copy-to-dma \ // RUN: --air-to-amdgcn --canonicalize \ // RUN: --convert-memspace-to-amdgcn \ -// RUN: --convert-linalg-to-amdgcn \ +// RUN: --convert-to-amdgcn-library-calls \ // RUN: --amdgcn-preload-library="library-paths=%p/../../../mlir_kernels/library/common/register-init.mlir,%p/../../../mlir_kernels/library/common/indexing.mlir,%p/../../../mlir_kernels/library/common/indexing_ptr.mlir,%p/../../../mlir_kernels/library/common/futures.mlir,%p/../../../contrib/kittens/library/compute_16x16_f16.mlir,%p/../../../contrib/kittens/library/global_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_mfma_16x64_b.mlir" \ // RUN: --inline --symbol-dce --canonicalize \ // RUN: --mlir-air-to-asm \ diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index dde70e1d9..fc2d45e6c 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -9,7 +9,7 @@ --air-copy-to-dma (memref.copy → air.dma_memcpy_nd) --air-to-amdgcn (flatten hierarchy, herd → wavefront index) --convert-memspace-to-amdgcn (integer memspace → #amdgcn.addr_space) - --convert-linalg-to-amdgcn (air.dma_memcpy_nd + linalg ops → library calls) + --convert-to-amdgcn-library-calls (air.dma_memcpy_nd + linalg ops → library calls) then aster pipeline: --preload → inline → mlir-air-to-asm """ @@ -85,7 +85,7 @@ def _air_preprocess(mlir_text): "--air-to-amdgcn", "--canonicalize", "--convert-memspace-to-amdgcn", - "--convert-linalg-to-amdgcn", + "--convert-to-amdgcn-library-calls", ], input=mlir_text, capture_output=True, text=True, @@ -115,7 +115,7 @@ def test_matmul_64x64(self): A = (np.random.randn(M, K) * 0.1).astype(np.float16) B_KxN = (np.random.randn(K, N) * 0.1).astype(np.float16) B_T = np.ascontiguousarray(B_KxN.T) - # C must be zero-initialized (fill is erased by convert-linalg-to-amdgcn; + # C must be zero-initialized (fill is erased by convert-to-amdgcn-library-calls; # the library's zero_C handles accumulator init per tile). C = np.zeros(M * N, dtype=np.float32) diff --git a/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py b/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py index 8bd7c1fc6..f31c3bb41 100644 --- a/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_linalg_matmul_e2e.py @@ -40,7 +40,7 @@ def _mlir_air_pipeline(library_paths): return ( "builtin.module(" "transform-interpreter, canonicalize," - "convert-linalg-to-amdgcn," + "convert-to-amdgcn-library-calls," f"amdgcn-preload-library{{library-paths={libs}}}," "inline, symbol-dce, canonicalize," "mlir-air-to-asm)" diff --git a/contrib/mlir-air/test/linalg-to-amdgcn.mlir b/contrib/mlir-air/test/linalg-to-amdgcn.mlir index c1fe93e48..266b4238b 100644 --- a/contrib/mlir-air/test/linalg-to-amdgcn.mlir +++ b/contrib/mlir-air/test/linalg-to-amdgcn.mlir @@ -1,6 +1,6 @@ // RUN: mlir-air-opt %s \ // RUN: --transform-interpreter --canonicalize \ -// RUN: --convert-linalg-to-amdgcn \ +// RUN: --convert-to-amdgcn-library-calls \ // RUN: --amdgcn-preload-library="library-paths=%p/../../../mlir_kernels/library/common/register-init.mlir,%p/../../../mlir_kernels/library/common/indexing.mlir,%p/../../../mlir_kernels/library/common/indexing_ptr.mlir,%p/../../../mlir_kernels/library/common/futures.mlir,%p/../../../contrib/kittens/library/compute_16x16_f16.mlir,%p/../../../contrib/kittens/library/global_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_16x64_b.mlir,%p/../../../contrib/kittens/library/lds_mfma_16x64_b.mlir" \ // RUN: --inline --symbol-dce --canonicalize \ // RUN: --mlir-air-to-asm \ From deebce2a096e1a5f923469c424a0d99f159c30d7 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 06/16] [mlir-air] Remove dead channel handling from ConvertToAMDGCNLibraryCalls air-dma-to-channel is not in the pipeline, so ChannelGetOp/ChannelPutOp/ ChannelOp are never present. Remove ~210 lines of dead channel pre-allocation, conversion, and cleanup code. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- .../lib/ConvertToAMDGCNLibraryCalls.cpp | 224 ------------------ 1 file changed, 224 deletions(-) diff --git a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp index a508c7899..4c8e65426 100644 --- a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp +++ b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp @@ -291,117 +291,6 @@ struct ConvertToAMDGCNLibraryCalls SmallVector toErase; DenseMap ldsCache; - // Pre-allocate LDS at function entry for channel get destinations FIRST, - // before processing linalg ops. This ensures the matmul (which shares the - // same memref.alloc as the channel.get) hits the cache and uses the - // function-entry LDS offset instead of creating one inside the loop. - DenseMap> putsByChannel; - moduleOp->walk([&](xilinx::air::ChannelPutOp put) { - putsByChannel[put.getChanName()].push_back(put); - }); - // Determine number of wavefronts from channel array dimensions. - int64_t numWavefronts = 1; - moduleOp->walk([&](xilinx::air::ChannelOp chan) { - auto sizes = chan.getSize(); - int64_t total = 1; - for (auto s : sizes) - total *= cast(s).getInt(); - if (total > numWavefronts) - numWavefronts = total; - }); - - // Pre-allocate LDS for channel get destinations (Global→LDS direction). - // Each wavefront gets its own LDS region: alloc numWavefronts * size, - // offset by wavefront_id * size. - moduleOp->walk([&](xilinx::air::ChannelGetOp get) { - Value dst = get.getDst(); - if (!isPromotedBuffer(dst)) - return; - auto funcOp = get->getParentOfType(); - if (!funcOp || funcOp.empty()) - return; - auto savedIP = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(&funcOp.front()); - Location loc = funcOp.getLoc(); - - // Get per-wavefront tile size. - int64_t tileSizeBytes = 0; - if (auto allocOp = dst.getDefiningOp()) { - auto mrTy = allocOp.getMemref().getType(); - unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); - tileSizeBytes = mrTy.getNumElements() * eltBits / 8; - } - - // Allocate numWavefronts * tileSizeBytes. - auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), - numWavefronts * tileSizeBytes, - /*alignment=*/16, - /*offset=*/IntegerAttr{}); - auto ldsBaseOffset = - GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); - - // Per-wavefront offset: base + wavefront_id * tileSizeBytes. - Value wavefrontSize = - arith::ConstantIndexOp::create(builder, loc, 64); - Value threadIdX = - gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); - Value wavefrontId = - arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); - Value tileSizeVal = - arith::ConstantIndexOp::create(builder, loc, tileSizeBytes); - Value wavefrontOffset = - arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); - Value adjustedOffset = builder.create( - loc, ldsBaseOffset.getResult(), wavefrontOffset); - - ldsCache[dst] = adjustedOffset; - builder.restoreInsertionPoint(savedIP); - }); - // Pre-allocate LDS for channel put sources (LDS→Global direction). - moduleOp->walk([&](xilinx::air::ChannelPutOp put) { - Value src = put.getSrc(); - if (!isPromotedBuffer(src)) - return; - if (ldsCache.count(src)) - return; // Already allocated (shared with channel get). - auto funcOp = put->getParentOfType(); - if (!funcOp || funcOp.empty()) - return; - auto savedIP = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(&funcOp.front()); - Location loc = funcOp.getLoc(); - - int64_t tileSizeBytes = 0; - if (auto allocOp = src.getDefiningOp()) { - auto mrTy = allocOp.getMemref().getType(); - unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); - tileSizeBytes = mrTy.getNumElements() * eltBits / 8; - } - - auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), - numWavefronts * tileSizeBytes, - /*alignment=*/16, - /*offset=*/IntegerAttr{}); - auto ldsBaseOffset = - GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); - - Value wavefrontSize = - arith::ConstantIndexOp::create(builder, loc, 64); - Value threadIdX = - gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); - Value wavefrontId = - arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); - Value tileSizeVal = - arith::ConstantIndexOp::create(builder, loc, tileSizeBytes); - Value wavefrontOffset = - arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); - Value adjustedOffset = builder.create( - loc, ldsBaseOffset.getResult(), wavefrontOffset); - - ldsCache[src] = adjustedOffset; - builder.restoreInsertionPoint(savedIP); - }); - // Detect numWavefronts from the IR. // air-to-amdgcn emits: wavefrontId = gpu.thread_id x / 64 // The affine.apply for M row uses: #map()[%loop_var, %wavefrontId] @@ -584,119 +473,6 @@ struct ConvertToAMDGCNLibraryCalls ldsCache); }); - // Emit copy call at each put site (where global src operands dominate). - moduleOp->walk([&](xilinx::air::ChannelGetOp get) { - StringRef chanName = get.getChanName(); - auto it = putsByChannel.find(chanName); - if (it == putsByChannel.end() || it->second.empty()) - return; - xilinx::air::ChannelPutOp put = it->second.front(); - - Value dst = get.getDst(); - auto dstTy = dyn_cast(dst.getType()); - if (!dstTy) - return; - - Value src = put.getSrc(); - bool srcIsLDS = isPromotedBuffer(src); - bool dstIsLDS = isPromotedBuffer(dst); - - // Determine direction and emit at the appropriate site. - // Global→LDS: emit at put site (global src operands dominate). - // LDS→Global: emit at get site (global dst operands dominate). - if (srcIsLDS && !dstIsLDS) { - // LDS→Global (C write-back): emit at get site. - builder.setInsertionPoint(get); - } else { - // Global→LDS (A/B copy): emit at put site. - builder.setInsertionPoint(put); - } - Location loc = builder.getInsertionPoint()->getLoc(); - - // Use the L1 (smaller) memref type for the function name. - auto namingTy = srcIsLDS ? cast(src.getType()) : dstTy; - std::string name = buildFuncName("copy", namingTy); - - auto indexTy = builder.getIndexType(); - auto sx2Ty = amdgcn::SGPRType::get(ctx, Register(), - /*size=*/2, /*alignment=*/2); - SmallVector callArgs; - SmallVector argTypes; - - // Decompose src side. - if (srcIsLDS) { - // LDS src. - assert(ldsCache.count(src) && "LDS offset not pre-allocated for channel put src"); - callArgs.push_back(ldsCache[src]); - argTypes.push_back(indexTy); - } else { - // Global src: decompose BASE memref (not subview) to get a clean - // sgpr pointer. Pass the channel's tile offsets separately so the - // library function handles them (kittens pattern). - auto [ptrVal, byteStride] = - decomposeGlobalMemref(builder, loc, src); - callArgs.push_back(ptrVal); - argTypes.push_back(sx2Ty); - callArgs.push_back(byteStride); - argTypes.push_back(indexTy); - // Tile offsets from the channel put (element-level indices). - auto putOffsets = put.getSrcOffsets(); - if (putOffsets.size() >= 2) { - callArgs.push_back(putOffsets[0]); // row offset - argTypes.push_back(indexTy); - callArgs.push_back(putOffsets[1]); // col offset - argTypes.push_back(indexTy); - } else { - Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); - callArgs.push_back(c0); - argTypes.push_back(indexTy); - callArgs.push_back(c0); - argTypes.push_back(indexTy); - } - } - - // Decompose dst side. - if (dstIsLDS) { - // LDS dst. - assert(ldsCache.count(dst) && "LDS offset not pre-allocated for channel get dst"); - callArgs.push_back(ldsCache[dst]); - argTypes.push_back(indexTy); - } else { - // Global dst: decompose BASE memref, pass offsets separately. - auto [ptrVal, byteStride] = - decomposeGlobalMemref(builder, loc, dst); - callArgs.push_back(ptrVal); - argTypes.push_back(sx2Ty); - callArgs.push_back(byteStride); - argTypes.push_back(indexTy); - auto getOffsets = get.getDstOffsets(); - if (getOffsets.size() >= 2) { - callArgs.push_back(getOffsets[0]); - argTypes.push_back(indexTy); - callArgs.push_back(getOffsets[1]); - argTypes.push_back(indexTy); - } else { - Value c0 = arith::ConstantIndexOp::create(builder, loc, 0); - callArgs.push_back(c0); - argTypes.push_back(indexTy); - callArgs.push_back(c0); - argTypes.push_back(indexTy); - } - } - - auto funcTy = builder.getFunctionType(argTypes, {}); - ensureDecl(builder, declBlock, loc, name, funcTy); - func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); - - toErase.push_back(put); - toErase.push_back(get); - }); - - // Clean up channel declarations. - moduleOp->walk([&](xilinx::air::ChannelOp chan) { - toErase.push_back(chan); - }); - for (auto *op : toErase) op->erase(); From 21b05bd66b75c3bf0b0f1f181e204a5dab51056e Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 07/16] [mlir-air] Refactor ConvertToAMDGCNLibraryCalls to use rewrite patterns MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace imperative IR walks with OpRewritePattern + applyPatternsGreedily: - DmaToLibraryCall: air.dma_memcpy_nd → copy library call - LinalgToLibraryCall: templated for fill/copy/matmul ops - GenericMatmulToLibraryCall: linalg.generic with matmul semantics - EraseGlobalFill: erase fill on global buffers (library handles zero-init) Merge LDS pre-allocation into emitLDSOffset with wavefront striping and function-entry insertion, eliminating the separate preallocateDmaLDS walk. Pattern ordering no longer matters — whichever pattern fires first creates the correct LDS allocation in the shared cache. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- .../lib/ConvertToAMDGCNLibraryCalls.cpp | 633 +++++++++--------- 1 file changed, 309 insertions(+), 324 deletions(-) diff --git a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp index 4c8e65426..0af73bd78 100644 --- a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp +++ b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp @@ -9,7 +9,6 @@ #include "air/Dialect/AIR/AIRDialect.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNAttrs.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNDialect.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNEnums.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNOps.h" #include "aster/Dialect/AMDGCN/IR/AMDGCNTypes.h" @@ -19,6 +18,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Ptr/IR/PtrDialect.h" @@ -26,7 +26,9 @@ #include "mlir/Dialect/Ptr/IR/PtrTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::aster; @@ -34,6 +36,10 @@ using namespace mlir::aster::amdgcn; namespace { +// --------------------------------------------------------------------------- +// Utilities +// --------------------------------------------------------------------------- + static std::string buildFuncName(StringRef prefix, MemRefType ty) { std::string name; llvm::raw_string_ostream os(name); @@ -66,27 +72,17 @@ static void ensureDecl(OpBuilder &builder, Block &block, Location loc, builder.restoreInsertionPoint(savedIP); } -/// Check if a memref value comes from promote to shared memory. -/// Matches two patterns: -/// 1. memref.view(memref.alloca) with non-default memory space (from promote) -/// 2. memref.alloc with non-default memory space (from bufferize_to_allocation) static bool isPromotedBuffer(Value v) { - // Pattern 1: memref.view(memref.alloca) — from transform.structured.promote. if (auto viewOp = v.getDefiningOp()) { - if (auto allocaOp = viewOp.getSource().getDefiningOp()) { + if (auto allocaOp = viewOp.getSource().getDefiningOp()) return allocaOp.getMemref().getType().getMemorySpace() != nullptr; - } } - // Pattern 2: memref.alloc with L1/local memory space — - // from bufferize_to_allocation. if (auto allocOp = v.getDefiningOp()) { auto memSpace = allocOp.getMemref().getType().getMemorySpace(); if (!memSpace) return false; - // Integer memory space 2 = L1 (AIR convention). if (auto intAttr = dyn_cast(memSpace)) return intAttr.getInt() == 2; - // #amdgcn.addr_space = LDS. if (auto addrSpace = dyn_cast(memSpace)) return addrSpace.getSpace() == amdgcn::AddressSpaceKind::Local; return false; @@ -94,181 +90,355 @@ static bool isPromotedBuffer(Value v) { return false; } -/// Emit amdgcn.alloc_lds + get_lds_offset for a promoted buffer. -/// Uses a cache so the same promoted buffer gets the same LDS region -/// for both write (copy) and read (matmul). +// --------------------------------------------------------------------------- +// Shared context for patterns (populated by analysis, read by patterns). +// --------------------------------------------------------------------------- + +struct ConversionContext { + Block *declBlock = nullptr; + int64_t numWavefronts = 1; + DenseMap ldsCache; +}; + +/// Emit amdgcn.alloc_lds + get_lds_offset for a promoted buffer at function +/// entry. When numWavefronts > 1, allocates nWf * tileSizeBytes and adds a +/// per-wavefront offset (wavefrontId * tileSizeBytes). +/// Uses ldsCache so the same buffer gets the same LDS region regardless of +/// which pattern fires first. static Value emitLDSOffset(OpBuilder &builder, Location loc, Value memrefVal, - DenseMap &ldsCache) { - auto it = ldsCache.find(memrefVal); - if (it != ldsCache.end()) + ConversionContext &convCtx) { + auto it = convCtx.ldsCache.find(memrefVal); + if (it != convCtx.ldsCache.end()) return it->second; int64_t sizeBytes = 0; Value byteShift; - - // Pattern 1: memref.view(memref.alloca) — from promote. if (auto viewOp = memrefVal.getDefiningOp()) { auto allocaOp = viewOp.getSource().getDefiningOp(); sizeBytes = allocaOp.getMemref().getType().getNumElements(); byteShift = viewOp.getByteShift(); - } - // Pattern 2: memref.alloc — from bufferize_to_allocation. - else if (auto allocOp = memrefVal.getDefiningOp()) { + } else if (auto allocOp = memrefVal.getDefiningOp()) { auto mrTy = allocOp.getMemref().getType(); unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); sizeBytes = mrTy.getNumElements() * eltBits / 8; } + // Insert at function entry so LDS allocation dominates all uses. + auto funcOp = memrefVal.getParentRegion()->getParentOfType(); + auto savedIP = builder.saveInsertionPoint(); + if (funcOp) + builder.setInsertionPointToStart(&funcOp.front()); + + int64_t nWf = convCtx.numWavefronts; auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), - sizeBytes, /*alignment=*/16, + nWf * sizeBytes, /*alignment=*/16, /*offset=*/IntegerAttr{}); auto ldsOffset = GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); - Value result = ldsOffset.getResult(); + + if (nWf > 1) { + Value wavefrontSize = arith::ConstantIndexOp::create(builder, loc, 64); + Value threadIdX = + gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); + Value wavefrontId = + arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); + Value tileSizeVal = + arith::ConstantIndexOp::create(builder, loc, sizeBytes); + Value wavefrontOffset = + arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); + result = arith::AddIOp::create(builder, loc, result, wavefrontOffset); + } + if (byteShift) result = builder.create(loc, result, byteShift); - ldsCache[memrefVal] = result; + builder.restoreInsertionPoint(savedIP); + convCtx.ldsCache[memrefVal] = result; return result; } -/// Decompose a global memref into (!sx2, byte_stride: index). -/// Emits: extract_strided_metadata -> ptr.to_ptr -> lsir.to_reg -> ptr_add. static std::pair decomposeGlobalMemref(OpBuilder &builder, Location loc, Value memref) { auto mrTy = cast(memref.getType()); unsigned eltBytes = mrTy.getElementType().getIntOrFloatBitWidth() / 8; - - // extract_strided_metadata -> (base_memref, offset, sizes..., strides...) auto metadata = memref::ExtractStridedMetadataOp::create(builder, loc, memref); Value baseBuffer = metadata.getBaseBuffer(); Value offset = metadata.getOffset(); - // Leading stride is strides[0] (row stride in elements). Value leadingStride = metadata.getStrides()[0]; - - // byte_stride = leading_stride * elt_bytes Value eltSize = arith::ConstantIndexOp::create(builder, loc, eltBytes); Value byteStride = arith::MulIOp::create(builder, loc, leadingStride, eltSize); - - // byte_offset = offset * elt_bytes Value byteOffset = arith::MulIOp::create(builder, loc, offset, eltSize); - - // ptr.to_ptr base_memref -> !ptr.ptr auto addrSpace = cast(mrTy.getMemorySpace()); auto ptrTy = ptr::PtrType::get(builder.getContext(), addrSpace); Value ptrVal = ptr::ToPtrOp::create(builder, loc, ptrTy, baseBuffer); - - // lsir.to_reg ptr -> !sx2 auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), /*size=*/2, /*alignment=*/2); Value rawPtr = lsir::ToRegOp::create(builder, loc, sx2Ty, ptrVal); - - // Add byte offset: from_reg -> ptr_add -> to_reg Value ptrFromReg = lsir::FromRegOp::create(builder, loc, ptrTy, rawPtr); Value adjusted = ptr::PtrAddOp::create(builder, loc, ptrTy, ptrFromReg, byteOffset); Value result = lsir::ToRegOp::create(builder, loc, sx2Ty, adjusted); - return {result, byteStride}; } -/// Replace a linalg op with a library call. -/// Global memrefs -> decomposed (!sx2, byte_stride) args. -/// Promoted buffers -> index (LDS offset). -static void replaceWithCall(OpBuilder &builder, Block &declBlock, Operation *op, - StringRef namePrefix, - SmallVector &toErase, - DenseMap &ldsCache) { - // Only convert ops that involve at least one promoted (LDS) buffer. - bool hasPromotedOperand = false; - for (Value operand : op->getOperands()) - if (isPromotedBuffer(operand)) - hasPromotedOperand = true; - if (!hasPromotedOperand) - return; - - auto indexTy = builder.getIndexType(); - SmallVector callArgs; - SmallVector argTypes; - - MemRefType namingType; - for (Value operand : op->getOperands()) - if (auto mrTy = dyn_cast(operand.getType())) - if (!namingType) - namingType = mrTy; - if (!namingType) - return; - std::string name = buildFuncName(namePrefix, namingType); - - builder.setInsertionPoint(op); - Location loc = op->getLoc(); +// --------------------------------------------------------------------------- +// Patterns +// --------------------------------------------------------------------------- + +/// Convert air.dma_memcpy_nd (Global→LDS) to a library call. +struct DmaToLibraryCall + : public OpRewritePattern { + ConversionContext &convCtx; + + DmaToLibraryCall(MLIRContext *ctx, ConversionContext &convCtx) + : OpRewritePattern(ctx), convCtx(convCtx) {} + + LogicalResult matchAndRewrite(xilinx::air::DmaMemcpyNdOp dma, + PatternRewriter &rewriter) const override { + Value dst = dma.getDstMemref(); + if (!isPromotedBuffer(dst)) + return failure(); + + Value src = dma.getSrcMemref(); + auto dstTy = cast(dst.getType()); + Location loc = dma.getLoc(); + std::string name = buildFuncName("copy", dstTy); + + auto indexTy = rewriter.getIndexType(); + auto sx2Ty = amdgcn::SGPRType::get(rewriter.getContext(), Register(), + /*size=*/2, /*alignment=*/2); + SmallVector callArgs; + SmallVector argTypes; + + auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, src); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + + auto srcOffsets = dma.getSrcOffsets(); + if (srcOffsets.size() >= 2) { + callArgs.push_back(srcOffsets[0]); + argTypes.push_back(indexTy); + callArgs.push_back(srcOffsets[1]); + argTypes.push_back(indexTy); + } else { + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + callArgs.push_back(zero); + argTypes.push_back(indexTy); + callArgs.push_back(zero); + argTypes.push_back(indexTy); + } - auto sx2Ty = amdgcn::SGPRType::get(builder.getContext(), Register(), - /*size=*/2, /*alignment=*/2); + callArgs.push_back(emitLDSOffset(rewriter, loc, dst, convCtx)); + argTypes.push_back(indexTy); - for (Value operand : op->getOperands()) { - if (auto mrTy = dyn_cast(operand.getType())) { - if (isPromotedBuffer(operand)) { - callArgs.push_back(emitLDSOffset(builder, loc, operand, ldsCache)); - argTypes.push_back(indexTy); - } else { - // Global memref: if this is a subview, decompose the BASE memref - // (clean sgpr) and pass tile offsets separately. This avoids - // baking wavefront-varying offsets into the pointer. - Value baseMemref = operand; - SmallVector tileOffsets; - if (auto svOp = operand.getDefiningOp()) { - baseMemref = svOp.getSource(); - for (auto off : svOp.getMixedOffsets()) { - if (auto val = dyn_cast(off)) - tileOffsets.push_back(val); - else - tileOffsets.push_back(arith::ConstantIndexOp::create( - builder, loc, - cast(off.get()).getInt())); - } - } - auto [ptrVal, byteStride] = - decomposeGlobalMemref(builder, loc, baseMemref); - callArgs.push_back(ptrVal); - argTypes.push_back(sx2Ty); - callArgs.push_back(byteStride); - argTypes.push_back(indexTy); - // Pass tile offsets (or zeros if no subview). - if (tileOffsets.empty()) { - auto rank = mrTy.getRank(); - for (int64_t i = 0; i < rank; ++i) { - callArgs.push_back( - arith::ConstantIndexOp::create(builder, loc, 0)); - argTypes.push_back(indexTy); - } + auto funcTy = rewriter.getFunctionType(argTypes, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, name, funcTy); + func::CallOp::create(rewriter, loc, name, TypeRange{}, callArgs); + rewriter.eraseOp(dma); + return success(); + } +}; + +/// Convert a linalg op (or memref.copy) with at least one promoted (LDS) +/// operand to a library call. +template +struct LinalgToLibraryCall : public OpRewritePattern { + ConversionContext &convCtx; + StringRef namePrefix; + + LinalgToLibraryCall(MLIRContext *ctx, ConversionContext &convCtx, + StringRef namePrefix) + : OpRewritePattern(ctx), convCtx(convCtx), + namePrefix(namePrefix) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + bool hasPromoted = false; + for (Value operand : op->getOperands()) + if (isPromotedBuffer(operand)) + hasPromoted = true; + if (!hasPromoted) + return failure(); + + MemRefType namingType; + for (Value operand : op->getOperands()) + if (auto mrTy = dyn_cast(operand.getType())) + if (!namingType) + namingType = mrTy; + if (!namingType) + return failure(); + + std::string name = buildFuncName(namePrefix, namingType); + Location loc = op->getLoc(); + auto indexTy = rewriter.getIndexType(); + auto sx2Ty = amdgcn::SGPRType::get(rewriter.getContext(), Register(), + /*size=*/2, /*alignment=*/2); + SmallVector callArgs; + SmallVector argTypes; + + for (Value operand : op->getOperands()) { + if (auto mrTy = dyn_cast(operand.getType())) { + if (isPromotedBuffer(operand)) { + callArgs.push_back( + emitLDSOffset(rewriter, loc, operand, convCtx)); + argTypes.push_back(indexTy); } else { - for (auto off : tileOffsets) { - callArgs.push_back(off); - argTypes.push_back(indexTy); + Value baseMemref = operand; + SmallVector tileOffsets; + if (auto svOp = operand.getDefiningOp()) { + baseMemref = svOp.getSource(); + for (auto off : svOp.getMixedOffsets()) { + if (auto val = dyn_cast(off)) + tileOffsets.push_back(val); + else + tileOffsets.push_back(arith::ConstantIndexOp::create( + rewriter, loc, + cast(off.get()).getInt())); + } + } + auto [ptrVal, byteStride] = + decomposeGlobalMemref(rewriter, loc, baseMemref); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + if (tileOffsets.empty()) { + for (int64_t i = 0; i < mrTy.getRank(); ++i) { + callArgs.push_back( + arith::ConstantIndexOp::create(rewriter, loc, 0)); + argTypes.push_back(indexTy); + } + } else { + for (auto off : tileOffsets) { + callArgs.push_back(off); + argTypes.push_back(indexTy); + } } } + } else { + callArgs.push_back(operand); + argTypes.push_back(operand.getType()); } - } else { - callArgs.push_back(operand); - argTypes.push_back(operand.getType()); } + + auto funcTy = rewriter.getFunctionType(argTypes, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, name, funcTy); + func::CallOp::create(rewriter, loc, name, TypeRange{}, callArgs); + rewriter.eraseOp(op); + return success(); } +}; - auto funcTy = builder.getFunctionType(argTypes, {}); - ensureDecl(builder, declBlock, loc, name, funcTy); - func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); - toErase.push_back(op); +/// Match linalg.generic with matmul semantics (2 inputs, 1 output, 1 reduction). +struct GenericMatmulToLibraryCall : public OpRewritePattern { + ConversionContext &convCtx; + + GenericMatmulToLibraryCall(MLIRContext *ctx, ConversionContext &convCtx) + : OpRewritePattern(ctx), convCtx(convCtx) {} + + LogicalResult matchAndRewrite(linalg::GenericOp op, + PatternRewriter &rewriter) const override { + if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1 || + op.getNumReductionLoops() != 1) + return failure(); + // Delegate to the generic linalg pattern. + LinalgToLibraryCall inner( + rewriter.getContext(), convCtx, "mfma_matmul"); + return inner.matchAndRewrite(op, rewriter); + } +}; + +/// Erase linalg.fill on global (non-LDS) buffers. +/// The library's zero_C handles accumulator init. +struct EraseGlobalFill : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::FillOp fill, + PatternRewriter &rewriter) const override { + for (Value out : fill.getDpsInits()) + if (isa(out.getType()) && !isPromotedBuffer(out)) { + rewriter.eraseOp(fill); + return success(); + } + return failure(); + } +}; + +// --------------------------------------------------------------------------- +// Analysis: detect numWavefronts from IR and pre-allocate LDS. +// --------------------------------------------------------------------------- + +static int64_t detectNumWavefronts(Operation *moduleOp) { + int64_t result = 1; + moduleOp->walk([&](arith::DivUIOp divOp) { + if (result > 1) + return; + auto threadId = divOp.getLhs().getDefiningOp(); + if (!threadId || threadId.getDimension() != gpu::Dimension::x) + return; + auto cst = divOp.getRhs().getDefiningOp(); + if (!cst || cst.value() != 64) + return; + for (Operation *user : divOp.getResult().getUsers()) { + auto applyOp = dyn_cast(user); + if (!applyOp || applyOp.getAffineMap().getNumResults() != 1) + continue; + unsigned wfPos = 0; + bool found = false; + for (auto [idx, operand] : llvm::enumerate(applyOp.getMapOperands())) { + if (operand == divOp.getResult()) { + wfPos = idx; + found = true; + break; + } + } + if (!found) + continue; + int64_t stride = 0; + AffineExpr expr = applyOp.getAffineMap().getResult(0); + expr.walk([&](AffineExpr e) { + auto mul = dyn_cast(e); + if (!mul || mul.getKind() != AffineExprKind::Mul) + return; + auto sym = dyn_cast(mul.getLHS()); + auto con = dyn_cast(mul.getRHS()); + if (!sym || !con) + return; + if (sym.getPosition() == wfPos) + stride = con.getValue(); + }); + if (stride <= 0) + continue; + moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma2) { + if (result > 1) + return; + Value src = dma2.getSrcMemref(); + auto srcTy = dyn_cast(src.getType()); + if (!srcTy || srcTy.getRank() < 1 || srcTy.getDimSize(0) <= 0) + return; + result = srcTy.getDimSize(0) / stride; + }); + } + }); + return result; } +// --------------------------------------------------------------------------- +// Pass +// --------------------------------------------------------------------------- + struct ConvertToAMDGCNLibraryCalls : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToAMDGCNLibraryCalls) - StringRef getArgument() const override { return "convert-to-amdgcn-library-calls"; } + StringRef getArgument() const override { + return "convert-to-amdgcn-library-calls"; + } StringRef getDescription() const override { return "Convert linalg/AIR ops to AMDGCN library calls"; } @@ -278,217 +448,32 @@ struct ConvertToAMDGCNLibraryCalls registry.insert(); } - void runOnOperation() override { Operation *moduleOp = getOperation(); MLIRContext *ctx = &getContext(); + // Find the declaration block (inside amdgcn.module if present). Operation *declParent = moduleOp; if (isa(moduleOp)) moduleOp->walk([&](amdgcn::ModuleOp m) { declParent = m; }); - auto &declBlock = declParent->getRegion(0).front(); - OpBuilder builder(ctx); - SmallVector toErase; - DenseMap ldsCache; - - // Detect numWavefronts from the IR. - // air-to-amdgcn emits: wavefrontId = gpu.thread_id x / 64 - // The affine.apply for M row uses: #map()[%loop_var, %wavefrontId] - // = loop_var + wavefrontId * herdTileM - // numWavefronts = totalM / herdTileM. - // We find herdTileM from the coefficient of the wavefrontId symbol in the - // affine map by scanning AffineApplyOp users of the divui result. - int64_t detectedNumWavefronts = 1; - moduleOp->walk([&](arith::DivUIOp divOp) { - if (detectedNumWavefronts > 1) - return; - auto threadId = divOp.getLhs().getDefiningOp(); - if (!threadId || threadId.getDimension() != gpu::Dimension::x) - return; - auto cst = divOp.getRhs().getDefiningOp(); - if (!cst || cst.value() != 64) - return; - // wavefrontId = divOp.getResult(). Find its use in affine.apply. - for (Operation *user : divOp.getResult().getUsers()) { - auto applyOp = dyn_cast(user); - if (!applyOp || applyOp.getAffineMap().getNumResults() != 1) - continue; - // Find the position of wavefrontId in the operand list. - unsigned wfPos = 0; - bool found = false; - for (auto [idx, operand] : llvm::enumerate(applyOp.getMapOperands())) { - if (operand == divOp.getResult()) { - wfPos = idx; - found = true; - break; - } - } - if (!found) - continue; - // Extract coefficient of symbol/dim at wfPos in the affine map. - AffineMap map = applyOp.getAffineMap(); - // The map operands are pure symbols in this context. - int64_t stride = 0; - AffineExpr expr = map.getResult(0); - expr.walk([&](AffineExpr e) { - auto mul = dyn_cast(e); - if (!mul || mul.getKind() != AffineExprKind::Mul) - return; - auto sym = dyn_cast(mul.getLHS()); - auto con = dyn_cast(mul.getRHS()); - if (!sym || !con) - return; - if (sym.getPosition() == wfPos) - stride = con.getValue(); - }); - if (stride <= 0) - continue; - // Get totalM from DMA src memref. - moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma2) { - if (detectedNumWavefronts > 1) - return; - Value src = dma2.getSrcMemref(); - auto srcTy = dyn_cast(src.getType()); - if (!srcTy || srcTy.getRank() < 1 || srcTy.getDimSize(0) <= 0) - return; - detectedNumWavefronts = srcTy.getDimSize(0) / stride; - }); - } - }); - - // Pre-allocate LDS for air.dma_memcpy_nd destinations (no-channel path). - // Must run before linalg op processing so matmul hits the same ldsCache. - // Allocate numWavefronts * tileSizeBytes and stripe by wavefront_id. - moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma) { - Value dst = dma.getDstMemref(); - if (!isPromotedBuffer(dst) || ldsCache.count(dst)) - return; - auto funcOp = dma->getParentOfType(); - if (!funcOp || funcOp.empty()) - return; - auto savedIP = builder.saveInsertionPoint(); - builder.setInsertionPointToStart(&funcOp.front()); - Location loc = funcOp.getLoc(); - - int64_t tileSizeBytes = 0; - if (auto allocOp = dst.getDefiningOp()) { - auto mrTy = allocOp.getMemref().getType(); - unsigned eltBits = mrTy.getElementType().getIntOrFloatBitWidth(); - tileSizeBytes = mrTy.getNumElements() * eltBits / 8; - } - int64_t nWf = detectedNumWavefronts; - auto ldsAlloc = AllocLDSOp::create(builder, loc, /*dynamic_size=*/Value(), - nWf * tileSizeBytes, /*alignment=*/16, - /*offset=*/IntegerAttr{}); - auto ldsBaseOffset = - GetLDSOffsetOp::create(builder, loc, builder.getIndexType(), ldsAlloc); - Value result = ldsBaseOffset.getResult(); - if (nWf > 1) { - Value wavefrontSize = arith::ConstantIndexOp::create(builder, loc, 64); - Value threadIdX = - gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x); - Value wavefrontId = - arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize); - Value tileSizeVal = - arith::ConstantIndexOp::create(builder, loc, tileSizeBytes); - Value wavefrontOffset = - arith::MulIOp::create(builder, loc, wavefrontId, tileSizeVal); - result = arith::AddIOp::create(builder, loc, result, wavefrontOffset); - } - ldsCache[dst] = result; - builder.restoreInsertionPoint(savedIP); - }); - - // Convert air.dma_memcpy_nd directly (Global→LDS only; no channels). - // Emit: copy__(base_sgpr, stride, row_off, col_off, lds_dst) - moduleOp->walk([&](xilinx::air::DmaMemcpyNdOp dma) { - Value dst = dma.getDstMemref(); - Value src = dma.getSrcMemref(); - bool dstIsLDS = isPromotedBuffer(dst); - if (!dstIsLDS) - return; // Only handle Global→LDS DMAs here. - - auto dstTy = cast(dst.getType()); - builder.setInsertionPoint(dma); - Location loc = dma.getLoc(); - - std::string name = buildFuncName("copy", dstTy); - - auto indexTy = builder.getIndexType(); - auto sx2Ty = amdgcn::SGPRType::get(ctx, Register(), /*size=*/2, - /*alignment=*/2); - SmallVector callArgs; - SmallVector argTypes; - - // Decompose BASE src memref → (sgpr_base, byte_stride). - auto [ptrVal, byteStride] = decomposeGlobalMemref(builder, loc, src); - callArgs.push_back(ptrVal); - argTypes.push_back(sx2Ty); - callArgs.push_back(byteStride); - argTypes.push_back(indexTy); - - // Src tile offsets (row, col) from DMA operands. - auto srcOffsets = dma.getSrcOffsets(); - if (srcOffsets.size() >= 2) { - callArgs.push_back(srcOffsets[0]); - argTypes.push_back(indexTy); - callArgs.push_back(srcOffsets[1]); - argTypes.push_back(indexTy); - } else { - Value zero = arith::ConstantIndexOp::create(builder, loc, 0); - callArgs.push_back(zero); - argTypes.push_back(indexTy); - callArgs.push_back(zero); - argTypes.push_back(indexTy); - } - - // LDS dst offset from cache. - assert(ldsCache.count(dst) && "DMA dst LDS not pre-allocated"); - callArgs.push_back(ldsCache[dst]); - argTypes.push_back(indexTy); - - auto funcTy = builder.getFunctionType(argTypes, {}); - ensureDecl(builder, declBlock, loc, name, funcTy); - func::CallOp::create(builder, loc, name, TypeRange{}, callArgs); - toErase.push_back(dma); - }); - - // Now process linalg ops — they'll hit the ldsCache for shared allocs. - moduleOp->walk([&](linalg::FillOp op) { - replaceWithCall(builder, declBlock, op, "fill", toErase, ldsCache); - }); - moduleOp->walk([&](linalg::CopyOp op) { - replaceWithCall(builder, declBlock, op, "copy", toErase, ldsCache); - }); - moduleOp->walk([&](memref::CopyOp op) { - replaceWithCall(builder, declBlock, op, "copy", toErase, ldsCache); - }); - moduleOp->walk([&](linalg::MatmulOp op) { - replaceWithCall(builder, declBlock, op, "mfma_matmul", toErase, ldsCache); - }); - moduleOp->walk([&](linalg::GenericOp op) { - if (op.getNumDpsInputs() == 2 && op.getNumDpsInits() == 1 && - op.getNumReductionLoops() == 1) - replaceWithCall(builder, declBlock, op, "mfma_matmul", toErase, - ldsCache); - }); - - for (auto *op : toErase) - op->erase(); - - // Erase linalg.fill on global (non-LDS) buffers. - // The library's zero_C handles accumulator init, so the fill is redundant. - // It must be erased because the aster backend cannot lower linalg.fill - // on global memrefs. - SmallVector globalFills; - moduleOp->walk([&](linalg::FillOp fill) { - for (Value out : fill.getDpsInits()) - if (isa(out.getType()) && !isPromotedBuffer(out)) - globalFills.push_back(fill); - }); - for (auto fill : globalFills) - fill->erase(); + ConversionContext convCtx; + convCtx.declBlock = &declParent->getRegion(0).front(); + convCtx.numWavefronts = detectNumWavefronts(moduleOp); + + // Apply conversion patterns. + RewritePatternSet patterns(ctx); + patterns.add(ctx, convCtx); + patterns.add>(ctx, convCtx, "fill"); + patterns.add>(ctx, convCtx, "copy"); + patterns.add>(ctx, convCtx, "copy"); + patterns.add>(ctx, convCtx, + "mfma_matmul"); + patterns.add(ctx, convCtx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(moduleOp, std::move(patterns)))) + signalPassFailure(); } }; From 3244c6845a39d996a4f3286e4d4c2c48776296b8 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 08/16] [mlir-air] Add padded matmul E2E test (40x40, host-padded to 48x48) Demonstrates non-tile-aligned matmul: actual dimensions 40x40x64, host pads inputs/output to tile-aligned 48x48x64. Kernel operates on full tiles; padding zeros produce zero output in the padded region. Host extracts valid C[0:40, 0:40] for verification. - New payload: air-to-amdgcn-matmul-padded.mlir (48x48x64 kernel) - New transform: tile_using_forall [16,0,0] (3 wavefronts) - New test: test_matmul_padded_40x40 in test_air_matmul_e2e.py - Register tensor ValueBounds and SubsetOp interfaces in Init.cpp (needed for tensor.extract_slice in padding transforms) - Refactor _air_preprocess into _air_preprocess_with_files for reuse Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- contrib/mlir-air/lib/Init.cpp | 4 + ...air-to-amdgcn-matmul-padded-transform.mlir | 87 +++++++++++ .../test/air-to-amdgcn-matmul-padded.mlir | 143 ++++++++++++++++++ .../test/integration/test_air_matmul_e2e.py | 51 ++++++- 4 files changed, 283 insertions(+), 2 deletions(-) create mode 100644 contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir create mode 100644 contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index 13e748b5d..5e8d392b8 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -36,6 +36,8 @@ namespace air_xform_reg { #include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h" @@ -74,6 +76,8 @@ void registerAll(DialectRegistry ®istry) { scf::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerBufferizableOpInterfaceExternalModels(registry); tensor::registerInferTypeOpInterfaceExternalModels(registry); + tensor::registerSubsetOpInterfaceExternalModels(registry); + tensor::registerValueBoundsOpInterfaceExternalModels(registry); // Transform dialect extensions. bufferization::registerTransformDialectExtension(registry); diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir new file mode 100644 index 000000000..c7bcdb708 --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir @@ -0,0 +1,87 @@ +// Transform sequence for 48x48x64 matmul (host-padded from 40x40x64). +// tile_using_forall [16,0,0] → 3 wavefronts (48/16=3, exact). +// tile_using_for [16,16,0] → 16x16 compute tiles, untiled K. +// All tiles are full (48 is a multiple of 16). No boundary padding needed. + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg0: !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + + // Outer tiling: forall on M, one wavefront per 16 M rows. + %outer_tiled, %outer_forall = + transform.structured.tile_using_forall %matmul + tile_sizes [16, 0, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Compute tiling: 16x16 output tiles, K untiled. + %tiled, %lm, %ln = transform.structured.tile_using_for %outer_tiled + tile_sizes [16, 16, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + + // Pad A and B to tile size (16). Boundary tiles (8 rows/cols) get + // zero-padded to 16. + %padded, %pad, %copy_back = transform.structured.pad %tiled { + padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], + padding_dimensions = [0, 1, 2], + pack_paddings = [1, 1, 0], + nofold_flags = [1, 1, 0], + copy_back_op = "linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, + !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad + : (!transform.any_op) -> !transform.any_op + + // Promote padded A,B to LDS (memory_space=2). + %padded_lhs = transform.get_producer_of_operand %padded[0] + : (!transform.any_op) -> (!transform.any_op) + %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + + %padded_rhs = transform.get_producer_of_operand %padded[1] + : (!transform.any_op) -> (!transform.any_op) + %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only} + : !transform.any_op + + // Canonicalize. + %func_0 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_0 { + transform.apply_patterns.linalg.tiling_canonicalization + transform.apply_patterns.scf.for_loop_canonicalization + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_0 : !transform.any_op + + // One-shot bufferize. + %func_1 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %func_buf = transform.bufferization.one_shot_bufferize %func_1 { + allow_return_allocs_from_loops = true + } : (!transform.any_op) -> !transform.any_op + + // Cleanup. + %func_2 = transform.structured.match ops{["func.func"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %func_2 { + transform.apply_patterns.canonicalization + } : !transform.any_op + transform.apply_cse to %func_2 : !transform.any_op + %func_3 = transform.air.remove_uninitialized_copy %func_2 + : (!transform.any_op) -> !transform.any_op + + // Convert outer forall → parallel → air.herd. + %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %parallel = transform.loop.forall_to_parallel %forall_2 + : (!transform.any_op) -> !transform.any_op + %herd = transform.air.par_to_herd %parallel + : (!transform.any_op) -> !transform.any_op + + transform.yield + } +} diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir new file mode 100644 index 000000000..c1d1f448b --- /dev/null +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir @@ -0,0 +1,143 @@ +// Padded matmul E2E test: actual dimensions M=40, N=40, K=64. +// Host pads inputs to tile-aligned sizes: A=48x64, B=48x64, C=48x48. +// Kernel operates on full tile-aligned dimensions; padding zeros produce +// zero output in the padded region. Host extracts valid C[0:40, 0:40]. + +!sx2 = !amdgcn.sgpr<[? + 2]> +!vx2 = !amdgcn.vgpr<[? + 2]> +!ax4 = !amdgcn.agpr<[? + 4]> +!lds_write_token = !amdgcn.write_token +!future_lds_read = !aster_utils.struct> +!future_global_read = !aster_utils.struct> + +module { + amdgcn.library @linalg_lib isa = [#amdgcn.isa] { + func.func private @zero_C() -> !ax4 + func.func private @mfma_f32_16x16x16_f16(!vx2, !vx2, !ax4) -> !ax4 + func.func private @store_global_C_mfma_f32_16x16x16_f16( + !ax4, !aster_utils.any, index, index, index) + func.func private @prepare_ptr(!sx2) -> !aster_utils.any + func.func private @load_global_tile_16x64_b( + !aster_utils.any, index, index, index) -> !future_global_read + func.func private @store_global_tile_to_lds_16x64_b( + index, !future_global_read) -> (!lds_write_token, !lds_write_token) + func.func private @load_lds_A_swizzled( + index, index, index) -> !future_lds_read + func.func private @load_lds_B_swizzled( + index, index, index) -> !future_lds_read + func.func private @get_lds_read_value_vx2(!future_lds_read) -> !vx2 + + func.func private @copy_f16_16x64( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %lds_dst: index) { + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %col1 = arith.addi %col_offset, %c32 : index + %lds_dst1 = arith.addi %lds_dst, %c1024 : index + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %gfut0 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col_offset, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut0) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + %gfut1 = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col1, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t2, %t3 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst1, %gfut1) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + amdgcn.wait deps %t2 : !lds_write_token + amdgcn.wait deps %t3 : !lds_write_token + return + } + + func.func private @mfma_matmul_f16_16x64( + %lds_A: index, %lds_B: index, + %C_ptr: !sx2, %C_stride: index, + %C_row_offset: index, %C_col_offset: index) { + %C_prepared = func.call @prepare_ptr(%C_ptr) : (!sx2) -> !aster_utils.any + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %lds_A2 = arith.addi %lds_A, %c1024 : index + %lds_B2 = arith.addi %lds_B, %c1024 : index + %acc = func.call @zero_C() : () -> !ax4 + %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 + %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 + %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) + : (!vx2, !vx2, !ax4) -> !ax4 + %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 + %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 + %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) + : (!vx2, !vx2, !ax4) -> !ax4 + %A2f = func.call @load_lds_A_swizzled(%lds_A2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A2 = func.call @get_lds_read_value_vx2(%A2f) : (!future_lds_read) -> !vx2 + %B2f = func.call @load_lds_B_swizzled(%lds_B2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B2 = func.call @get_lds_read_value_vx2(%B2f) : (!future_lds_read) -> !vx2 + %acc2 = func.call @mfma_f32_16x16x16_f16(%A2, %B2, %acc1) + : (!vx2, !vx2, !ax4) -> !ax4 + %A3f = func.call @load_lds_A_swizzled(%lds_A2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A3 = func.call @get_lds_read_value_vx2(%A3f) : (!future_lds_read) -> !vx2 + %B3f = func.call @load_lds_B_swizzled(%lds_B2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B3 = func.call @get_lds_read_value_vx2(%B3f) : (!future_lds_read) -> !vx2 + %acc3 = func.call @mfma_f32_16x16x16_f16(%A3, %B3, %acc2) + : (!vx2, !vx2, !ax4) -> !ax4 + func.call @store_global_C_mfma_f32_16x16x16_f16( + %acc3, %C_prepared, %C_row_offset, %C_col_offset, %C_stride) + : (!ax4, !aster_utils.any, index, index, index) -> () + return + } + + func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { return } + func.func private @fill_f16_16x32(%val: f16, %lds_dst: index) { return } + } + + amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { + // 48x48x64 matmul on tile-aligned dimensions. + // Host pads actual 40x40 data to these sizes. + func.func @matmul_f16_48x48( + %A: memref<48x64xf16>, %B: memref<48x64xf16>, %C: memref<48x48xf32>) + attributes {gpu.kernel} { + %cst = arith.constant 0.000000e+00 : f32 + %a = bufferization.to_tensor %A restrict writable : memref<48x64xf16> to tensor<48x64xf16> + %b = bufferization.to_tensor %B restrict writable : memref<48x64xf16> to tensor<48x64xf16> + %c = bufferization.to_tensor %C restrict writable : memref<48x48xf32> to tensor<48x48xf32> + %fill = linalg.fill ins(%cst : f32) outs(%c : tensor<48x48xf32>) -> tensor<48x48xf32> + // matmul_transpose_b: C[m,n] += A[m,k] * B[n,k] + %result = linalg.generic { + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%a, %b : tensor<48x64xf16>, tensor<48x64xf16>) + outs(%fill : tensor<48x48xf32>) { + ^bb0(%av: f16, %bv: f16, %cv: f32): + %a_ext = arith.extf %av : f16 to f32 + %b_ext = arith.extf %bv : f16 to f32 + %prod = arith.mulf %a_ext, %b_ext : f32 + %sum = arith.addf %cv, %prod : f32 + linalg.yield %sum : f32 + } -> tensor<48x48xf32> + bufferization.materialize_in_destination %result in writable %C + : (tensor<48x48xf32>, memref<48x48xf32>) -> () + return + } + } +} diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index fc2d45e6c..766aa2439 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -28,6 +28,10 @@ _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _MLIR_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul.mlir") _TRANSFORM_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul-transform.mlir") +_PADDED_MLIR_FILE = os.path.join(_THIS_DIR, "..", "air-to-amdgcn-matmul-padded.mlir") +_PADDED_TRANSFORM_FILE = os.path.join( + _THIS_DIR, "..", "air-to-amdgcn-matmul-padded-transform.mlir" +) _LIBRARY_DIR = os.path.join( _THIS_DIR, "..", "..", "..", "..", "mlir_kernels", "library" ) @@ -67,13 +71,13 @@ def _find_mlir_air_opt(): pytest.skip("mlir-air-opt not found") -def _air_preprocess(mlir_text): +def _air_preprocess_with_files(mlir_text, transform_file): """Run the full AIR lowering pipeline before handing to aster.""" opt = _find_mlir_air_opt() result = subprocess.run( [ opt, - f"--transform-preload-library=transform-library-paths={_TRANSFORM_FILE}", + f"--transform-preload-library=transform-library-paths={transform_file}", "--transform-interpreter", "--air-par-to-herd", "--canonicalize", "--cse", @@ -96,6 +100,10 @@ def _air_preprocess(mlir_text): return result.stdout +def _air_preprocess(mlir_text): + return _air_preprocess_with_files(mlir_text, _TRANSFORM_FILE) + + def _post_air_pipeline(library_paths): libs = ",".join(library_paths) return ( @@ -133,3 +141,42 @@ def test_matmul_64x64(self): expected = (A.astype(np.float32) @ B_KxN.astype(np.float32)).flatten() np.testing.assert_allclose(C, expected, rtol=1e-2, atol=1e-2) + + def test_matmul_padded_40x40(self): + """Matmul with non-tile-aligned dimensions: actual 40x40, padded to 48x48.""" + M_actual, N_actual, K = 40, 40, 64 + M_pad, N_pad = 48, 48 # next multiple of 16 + + np.random.seed(42) + A_actual = (np.random.randn(M_actual, K) * 0.1).astype(np.float16) + B_actual = (np.random.randn(N_actual, K) * 0.1).astype(np.float16) + + # Pad A and B to tile-aligned sizes (zero-fill padding region). + A_padded = np.zeros((M_pad, K), dtype=np.float16) + A_padded[:M_actual, :] = A_actual + B_padded = np.zeros((N_pad, K), dtype=np.float16) + B_padded[:N_actual, :] = B_actual + + C = np.zeros(M_pad * N_pad, dtype=np.float32) + + def padded_preprocess(mlir_text): + return _air_preprocess_with_files( + mlir_text, _PADDED_TRANSFORM_FILE) + + compile_and_run( + file_name=_PADDED_MLIR_FILE, + kernel_name="matmul_f16_48x48", + input_data=[A_padded.flatten(), B_padded.flatten()], + output_data=[C], + pass_pipeline=_post_air_pipeline(_LIBRARY_PATHS), + library_paths=[], + grid_dim=(1, 1, 1), + block_dim=(192, 1, 1), # 3 wavefronts (3x1 herd) + preprocess=padded_preprocess, + ) + + # Extract valid 40x40 region and compare. + C_2d = C.reshape(M_pad, N_pad) + C_valid = C_2d[:M_actual, :N_actual].flatten() + expected = (A_actual.astype(np.float32) @ B_actual.T.astype(np.float32)).flatten() + np.testing.assert_allclose(C_valid, expected, rtol=1e-2, atol=1e-2) From d20bc7e6a35721897540112b1a950b45f968d70e Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 09/16] [mlir-air] WIP: Compiler-level padding with LDS C accumulation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Infrastructure for non-tile-aligned matmul (40x40x64) with compiler-level padding. Transform pads all three operands (A, B, C) to tile size 16. C accumulates in LDS, then copies back to global (only valid region). New library functions: - store_lds_C_mfma_f32_16x16x16_f16: AGPR→LDS via ds_write_b32 - mfma_matmul_lds_c_f16_16x64: matmul with all-LDS operands - fill_f32_16x16, copy_f32_16x16: LDS C init and global→LDS copy - store_global_f32_16x16: LDS→global C writeback Pass changes: - DmaToLibraryCall: handle both global→LDS and LDS→global directions - GenericMatmulToLibraryCall: _lds_c suffix when output is in LDS - Register air-override-memref-memory-space pass - Register tensor ValueBounds/SubsetOp interfaces BLOCKING: amdgcn-preload-library hangs on the padded test's IR when library functions with bodies (store_global_f32_16x16 etc.) interact with preloaded external library functions. The 64x64 test still passes. Needs investigation of the preload pass's internal inlining behavior. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- .../kittens/library/compute_16x16_f16.mlir | 76 ++++++++ .../lib/ConvertToAMDGCNLibraryCalls.cpp | 85 ++++++--- contrib/mlir-air/lib/Init.cpp | 3 + ...air-to-amdgcn-matmul-padded-transform.mlir | 41 ++-- .../test/air-to-amdgcn-matmul-padded.mlir | 175 ++++++++++++++++-- .../test/integration/test_air_matmul_e2e.py | 67 ++++--- 6 files changed, 366 insertions(+), 81 deletions(-) diff --git a/contrib/kittens/library/compute_16x16_f16.mlir b/contrib/kittens/library/compute_16x16_f16.mlir index 125e8e938..9a78b2475 100644 --- a/contrib/kittens/library/compute_16x16_f16.mlir +++ b/contrib/kittens/library/compute_16x16_f16.mlir @@ -17,6 +17,7 @@ amdgcn.library @kittens_compute_16x16_f16 isa = [#amdgcn.isa] { // From register-init.mlir func.func private @init_agprx4(i32) -> !ax4 + func.func private @alloc_vgpr() -> !amdgcn.vgpr // From indexing.mlir func.func private @mfma_index_C_16x16_f32() -> !index_pair func.func private @mfma_c_16x16_f32_byte_offset(index, index, index, index, index, index, index) -> index @@ -99,4 +100,79 @@ amdgcn.library @kittens_compute_16x16_f16 isa = [#amdgcn.isa] { return } + + //===--------------------------------------------------------------------===// + // C tile store to LDS (AGPR → LDS, row-major 16x16 f32) + //===--------------------------------------------------------------------===// + + // Store a 16x16 f32 C tile from AGPRs to LDS in row-major layout. + // LDS layout: 16 rows × 16 cols × 4 bytes = 1024 bytes, stride = 64 bytes/row. + // Each thread writes 4 f32 values at its MFMA C fragment positions. + func.func private @store_lds_C_mfma_f32_16x16x16_f16(%tile: !rt_C_f32, %lds_base: index) { + %mfma_idx = func.call @mfma_index_C_16x16_f32() : () -> !index_pair + %col, %row_base = aster_utils.struct_extract %mfma_idx ["i", "j"] : !index_pair -> index, index + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0_i32 = arith.constant 0 : i32 + + %a0, %a1, %a2, %a3 = amdgcn.split_register_range %tile : !ax4 + %col_off = arith.muli %col, %c4 : index + + // Manually unrolled: 4 AGPR stores to LDS at consecutive rows. + %row0 = arith.addi %row_base, %c0 : index + %row1 = arith.addi %row_base, %c1 : index + %row2 = arith.addi %row_base, %c2 : index + %row3 = arith.addi %row_base, %c3 : index + + // Row 0 + %r0off = arith.muli %row0, %c64 : index + %b0 = arith.addi %r0off, %col_off : index + %addr0 = arith.addi %lds_base, %b0 : index + %ai0 = arith.index_cast %addr0 : index to i32 + %av0 = lsir.to_reg %ai0 : i32 -> !amdgcn.vgpr + %dv0 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %vv0 = lsir.copy %dv0, %a0 : !amdgcn.vgpr, !a + amdgcn.store ds_write_b32 data %vv0 addr %av0 offset c(%c0_i32) + : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token + + // Row 1 + %r1off = arith.muli %row1, %c64 : index + %b1 = arith.addi %r1off, %col_off : index + %addr1 = arith.addi %lds_base, %b1 : index + %ai1 = arith.index_cast %addr1 : index to i32 + %av1 = lsir.to_reg %ai1 : i32 -> !amdgcn.vgpr + %dv1 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %vv1 = lsir.copy %dv1, %a1 : !amdgcn.vgpr, !a + amdgcn.store ds_write_b32 data %vv1 addr %av1 offset c(%c0_i32) + : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token + + // Row 2 + %r2off = arith.muli %row2, %c64 : index + %b2 = arith.addi %r2off, %col_off : index + %addr2 = arith.addi %lds_base, %b2 : index + %ai2 = arith.index_cast %addr2 : index to i32 + %av2 = lsir.to_reg %ai2 : i32 -> !amdgcn.vgpr + %dv2 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %vv2 = lsir.copy %dv2, %a2 : !amdgcn.vgpr, !a + amdgcn.store ds_write_b32 data %vv2 addr %av2 offset c(%c0_i32) + : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token + + // Row 3 + %r3off = arith.muli %row3, %c64 : index + %b3 = arith.addi %r3off, %col_off : index + %addr3 = arith.addi %lds_base, %b3 : index + %ai3 = arith.index_cast %addr3 : index to i32 + %av3 = lsir.to_reg %ai3 : i32 -> !amdgcn.vgpr + %dv3 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %vv3 = lsir.copy %dv3, %a3 : !amdgcn.vgpr, !a + amdgcn.store ds_write_b32 data %vv3 addr %av3 offset c(%c0_i32) + : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token + + return + } } diff --git a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp index 0af73bd78..58e20abca 100644 --- a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp +++ b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp @@ -188,7 +188,8 @@ decomposeGlobalMemref(OpBuilder &builder, Location loc, Value memref) { // Patterns // --------------------------------------------------------------------------- -/// Convert air.dma_memcpy_nd (Global→LDS) to a library call. +/// Convert air.dma_memcpy_nd to a library call. +/// Handles both directions: Global→LDS and LDS→Global. struct DmaToLibraryCall : public OpRewritePattern { ConversionContext &convCtx; @@ -199,43 +200,70 @@ struct DmaToLibraryCall LogicalResult matchAndRewrite(xilinx::air::DmaMemcpyNdOp dma, PatternRewriter &rewriter) const override { Value dst = dma.getDstMemref(); - if (!isPromotedBuffer(dst)) + Value src = dma.getSrcMemref(); + bool dstIsLDS = isPromotedBuffer(dst); + bool srcIsLDS = isPromotedBuffer(src); + if (!dstIsLDS && !srcIsLDS) return failure(); - Value src = dma.getSrcMemref(); - auto dstTy = cast(dst.getType()); Location loc = dma.getLoc(); - std::string name = buildFuncName("copy", dstTy); - auto indexTy = rewriter.getIndexType(); auto sx2Ty = amdgcn::SGPRType::get(rewriter.getContext(), Register(), /*size=*/2, /*alignment=*/2); SmallVector callArgs; SmallVector argTypes; - auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, src); - callArgs.push_back(ptrVal); - argTypes.push_back(sx2Ty); - callArgs.push_back(byteStride); - argTypes.push_back(indexTy); - - auto srcOffsets = dma.getSrcOffsets(); - if (srcOffsets.size() >= 2) { - callArgs.push_back(srcOffsets[0]); + // Use the LDS memref type for function naming. + // Append direction suffix to distinguish global→LDS from LDS→global. + auto ldsTy = cast(dstIsLDS ? dst.getType() : src.getType()); + std::string name = buildFuncName( + srcIsLDS ? "store_global" : "copy", ldsTy); + + if (dstIsLDS && !srcIsLDS) { + // Global→LDS: copy(global_ptr, stride, row, col, lds_dst) + auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, src); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); argTypes.push_back(indexTy); - callArgs.push_back(srcOffsets[1]); + auto srcOffsets = dma.getSrcOffsets(); + if (srcOffsets.size() >= 2) { + callArgs.push_back(srcOffsets[0]); + callArgs.push_back(srcOffsets[1]); + } else { + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + callArgs.push_back(zero); + callArgs.push_back(zero); + } argTypes.push_back(indexTy); - } else { - Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); - callArgs.push_back(zero); argTypes.push_back(indexTy); - callArgs.push_back(zero); + callArgs.push_back(emitLDSOffset(rewriter, loc, dst, convCtx)); + argTypes.push_back(indexTy); + } else if (srcIsLDS && !dstIsLDS) { + // LDS→Global: copy(lds_src, global_ptr, stride, row, col) + callArgs.push_back(emitLDSOffset(rewriter, loc, src, convCtx)); argTypes.push_back(indexTy); + auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, dst); + callArgs.push_back(ptrVal); + argTypes.push_back(sx2Ty); + callArgs.push_back(byteStride); + argTypes.push_back(indexTy); + auto dstOffsets = dma.getDstOffsets(); + if (dstOffsets.size() >= 2) { + callArgs.push_back(dstOffsets[0]); + callArgs.push_back(dstOffsets[1]); + } else { + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + callArgs.push_back(zero); + callArgs.push_back(zero); + } + argTypes.push_back(indexTy); + argTypes.push_back(indexTy); + } else { + // LDS→LDS: not supported. + return failure(); } - callArgs.push_back(emitLDSOffset(rewriter, loc, dst, convCtx)); - argTypes.push_back(indexTy); - auto funcTy = rewriter.getFunctionType(argTypes, {}); ensureDecl(rewriter, *convCtx.declBlock, loc, name, funcTy); func::CallOp::create(rewriter, loc, name, TypeRange{}, callArgs); @@ -346,9 +374,16 @@ struct GenericMatmulToLibraryCall : public OpRewritePattern { if (op.getNumDpsInputs() != 2 || op.getNumDpsInits() != 1 || op.getNumReductionLoops() != 1) return failure(); - // Delegate to the generic linalg pattern. + // Use _lds_c suffix when the output operand is in LDS. + StringRef prefix = "mfma_matmul"; + bool outputIsLDS = false; + for (Value out : op.getDpsInits()) + if (isPromotedBuffer(out)) + outputIsLDS = true; + std::string prefixStr = + outputIsLDS ? "mfma_matmul_lds_c" : "mfma_matmul"; LinalgToLibraryCall inner( - rewriter.getContext(), convCtx, "mfma_matmul"); + rewriter.getContext(), convCtx, prefixStr); return inner.matchAndRewrite(op, rewriter); } }; diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index 5e8d392b8..a415a5ec8 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -10,6 +10,7 @@ #include "air/Dialect/AIR/AIRDialect.h" #include "air/Dialect/AIR/AIRTransformOps.h" #include "air/Transform/AIRDmaToChannel.h" +#include "air/Transform/AIRMiscPasses.h" // Tablegen-generated per-pass registration for upstream AIR passes. namespace air_conv_reg { @@ -22,6 +23,7 @@ namespace air_conv_reg { namespace air_xform_reg { #define GEN_PASS_REGISTRATION_DMATOCHANNEL +#define GEN_PASS_REGISTRATION_AIROVERRIDEMEMREFMEMORYSPACE #include "air/Transform/Passes.h.inc" } // namespace air_xform_reg @@ -110,6 +112,7 @@ void registerAll(DialectRegistry ®istry) { air_conv_reg::registerParallelToLaunch(); // air-par-to-launch air_conv_reg::registerAIRWrapFuncWithParallelPass(); // air-wrap-func-with-parallel air_xform_reg::registerDmaToChannel(); // air-dma-to-channel + air_xform_reg::registerAIROverrideMemRefMemorySpace(); // Aster-specific passes. registerPass([] { return createAirToAMDGCN(); }); diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir index c7bcdb708..f0a37c842 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir @@ -1,7 +1,9 @@ -// Transform sequence for 48x48x64 matmul (host-padded from 40x40x64). -// tile_using_forall [16,0,0] → 3 wavefronts (48/16=3, exact). +// Transform sequence for padded matmul: M=40, N=40, K=64. +// tile_using_forall [16,0,0] → 3 wavefronts (ceil(40/16)=3). // tile_using_for [16,16,0] → 16x16 compute tiles, untiled K. -// All tiles are full (48 is a multiple of 16). No boundary padding needed. +// Boundary tiles (8 rows or 8 cols) are padded to 16. +// bufferize_to_allocation on pad ops BEFORE DPS rewrite uses the +// dedicated PadOp allocation path that handles the full pad result. module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main( @@ -21,31 +23,36 @@ module attributes {transform.with_named_sequence} { : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - // Pad A and B to tile size (16). Boundary tiles (8 rows/cols) get - // zero-padded to 16. + // Pad all operands. Boundary tiles get zero-padded to 16. %padded, %pad, %copy_back = transform.structured.pad %tiled { padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], padding_dimensions = [0, 1, 2], - pack_paddings = [1, 1, 0], - nofold_flags = [1, 1, 0], - copy_back_op = "linalg.copy" + pack_paddings = [1, 1, 1], + nofold_flags = [1, 1, 1] } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad - : (!transform.any_op) -> !transform.any_op - // Promote padded A,B to LDS (memory_space=2). + // Pre-allocate pad buffers with explicit memory spaces BEFORE DPS + // rewrite. The PadOp-specific path in bufferize_to_allocation handles + // the full pad (fill + copy) in one allocation. + // A, B → LDS (memory_space=2). C → global (memory_space=0 is default, + // but we don't pre-allocate C — it writes directly to the global subview). %padded_lhs = transform.get_producer_of_operand %padded[0] : (!transform.any_op) -> (!transform.any_op) %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs - {memory_space = 2, bufferize_destination_only} - : !transform.any_op + {memory_space = 2} : !transform.any_op %padded_rhs = transform.get_producer_of_operand %padded[1] : (!transform.any_op) -> (!transform.any_op) %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs - {memory_space = 2, bufferize_destination_only} - : !transform.any_op + {memory_space = 2} : !transform.any_op + + // C pad: allocate in LDS (memory_space=2). The matmul computes the full + // 16x16 tile in LDS, then copies back only the valid region to global C. + %padded_out = transform.get_producer_of_operand %padded[2] + : (!transform.any_op) -> (!transform.any_op) + %buf_c, %new_c = transform.structured.bufferize_to_allocation %padded_out + {memory_space = 2} : !transform.any_op // Canonicalize. %func_0 = transform.structured.match ops{["func.func"]} in %arg0 @@ -74,13 +81,11 @@ module attributes {transform.with_named_sequence} { %func_3 = transform.air.remove_uninitialized_copy %func_2 : (!transform.any_op) -> !transform.any_op - // Convert outer forall → parallel → air.herd. + // Convert outer forall → parallel (par_to_herd runs as pipeline pass). %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op %parallel = transform.loop.forall_to_parallel %forall_2 : (!transform.any_op) -> !transform.any_op - %herd = transform.air.par_to_herd %parallel - : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir index c1d1f448b..1bd91776d 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir @@ -1,7 +1,7 @@ -// Padded matmul E2E test: actual dimensions M=40, N=40, K=64. -// Host pads inputs to tile-aligned sizes: A=48x64, B=48x64, C=48x48. -// Kernel operates on full tile-aligned dimensions; padding zeros produce -// zero output in the padded region. Host extracts valid C[0:40, 0:40]. +// Padded matmul: actual M=40, N=40, K=64. +// Kernel operates on actual (non-tile-aligned) dimensions. +// The transform pads boundary tiles to full 16-element tiles. +// Host over-allocates C (48*48 elements) to accommodate OOB boundary stores. !sx2 = !amdgcn.sgpr<[? + 2]> !vx2 = !amdgcn.vgpr<[? + 2]> @@ -26,6 +26,12 @@ module { func.func private @load_lds_B_swizzled( index, index, index) -> !future_lds_read func.func private @get_lds_read_value_vx2(!future_lds_read) -> !vx2 + func.func private @fill_lds_16x64_b(index) + func.func private @store_lds_C_mfma_f32_16x16x16_f16(!ax4, index) + func.func private @mfma_index_C_16x16_f32() -> !aster_utils.struct + func.func private @mfma_c_16x16_f32_byte_offset(index, index, index, index, index, index, index) -> index + func.func private @global_addr_from_offset(!sx2, index) -> !vx2 + func.func private @alloc_vgpr() -> !amdgcn.vgpr func.func private @copy_f16_16x64( %src_ptr: !sx2, %src_stride: index, @@ -105,20 +111,155 @@ module { func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { return } func.func private @fill_f16_16x32(%val: f16, %lds_dst: index) { return } + + // Zero-fill 16x16 f32 LDS tile (1024 bytes). + // Uses MFMA C fragment layout: each thread zeros its 4 positions. + func.func private @fill_f32_16x16(%val: f32, %lds_dst: index) { + func.call @fill_lds_16x64_b(%lds_dst) : (index) -> () + return + } + + // Copy 16x16 f32 tile from global to LDS. + // Reuses the 16x64-byte tile load (16x16 f32 = 16 rows × 64 bytes). + func.func private @copy_f32_16x16( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %lds_dst: index) { + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %gfut = func.call @load_global_tile_16x64_b( + %ptr, %row_offset, %col_offset, %src_stride) + : (!aster_utils.any, index, index, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + return + } + + // MFMA matmul with all operands in LDS (including C accumulator). + // Stores result to LDS C via store_lds_C (not global_store). + func.func private @mfma_matmul_lds_c_f16_16x64( + %lds_A: index, %lds_B: index, %lds_C: index) { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %lds_A2 = arith.addi %lds_A, %c1024 : index + %lds_B2 = arith.addi %lds_B, %c1024 : index + %acc = func.call @zero_C() : () -> !ax4 + %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 + %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 + %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) + : (!vx2, !vx2, !ax4) -> !ax4 + %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 + %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 + %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) + : (!vx2, !vx2, !ax4) -> !ax4 + %A2f = func.call @load_lds_A_swizzled(%lds_A2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %A2 = func.call @get_lds_read_value_vx2(%A2f) : (!future_lds_read) -> !vx2 + %B2f = func.call @load_lds_B_swizzled(%lds_B2, %c0, %c2) + : (index, index, index) -> !future_lds_read + %B2 = func.call @get_lds_read_value_vx2(%B2f) : (!future_lds_read) -> !vx2 + %acc2 = func.call @mfma_f32_16x16x16_f16(%A2, %B2, %acc1) + : (!vx2, !vx2, !ax4) -> !ax4 + %A3f = func.call @load_lds_A_swizzled(%lds_A2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %A3 = func.call @get_lds_read_value_vx2(%A3f) : (!future_lds_read) -> !vx2 + %B3f = func.call @load_lds_B_swizzled(%lds_B2, %c32, %c2) + : (index, index, index) -> !future_lds_read + %B3 = func.call @get_lds_read_value_vx2(%B3f) : (!future_lds_read) -> !vx2 + %acc3 = func.call @mfma_f32_16x16x16_f16(%A3, %B3, %acc2) + : (!vx2, !vx2, !ax4) -> !ax4 + // Store to LDS C (not global). + func.call @store_lds_C_mfma_f32_16x16x16_f16(%acc3, %lds_C) + : (!ax4, index) -> () + return + } + + // Copy 16x16 f32 tile from LDS to global (C writeback). + // Uses MFMA C fragment layout: each thread reads its 4 values from LDS + // and writes to global memory. + func.func private @store_global_f32_16x16( + %lds_src: index, + %dst_ptr: !sx2, %dst_stride: index, + %row_offset: index, %col_offset: index) { + // Read from LDS and write to global using the MFMA C fragment layout. + %dst_prepared = func.call @prepare_ptr(%dst_ptr) : (!sx2) -> !aster_utils.any + %mfma_idx = func.call @mfma_index_C_16x16_f32() : () -> !aster_utils.struct + %col, %row_base = aster_utils.struct_extract %mfma_idx ["i", "j"] + : !aster_utils.struct -> index, index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %c64 = arith.constant 64 : index + %elt_size = arith.constant 4 : index + %col_byte = arith.muli %col, %elt_size : index + // Manually unrolled: 4 iterations for MFMA C fragment (4 consecutive rows). + // Avoids {aster.constexpr} loop which causes --inline canonicalize hang. + %off0 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c0) : (index, index, index, index, index, index, index) -> index + %off1 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c1) : (index, index, index, index, index, index, index) -> index + %off2 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c2) : (index, index, index, index, index, index, index) -> index + %off3 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c3) : (index, index, index, index, index, index, index) -> index + %addr0 = func.call @global_addr_from_offset(%dst_ptr, %off0) : (!sx2, index) -> !vx2 + %addr1 = func.call @global_addr_from_offset(%dst_ptr, %off1) : (!sx2, index) -> !vx2 + %addr2 = func.call @global_addr_from_offset(%dst_ptr, %off2) : (!sx2, index) -> !vx2 + %addr3 = func.call @global_addr_from_offset(%dst_ptr, %off3) : (!sx2, index) -> !vx2 + // Read 4 f32 values from LDS at MFMA C fragment positions. + %row0_off = arith.muli %row_base, %c64 : index + %lds0 = arith.addi %lds_src, %row0_off : index + %lds0b = arith.addi %lds0, %col_byte : index + %lds1b = arith.addi %lds0b, %c64 : index + %lds2b = arith.addi %lds1b, %c64 : index + %lds3b = arith.addi %lds2b, %c64 : index + %la0 = arith.index_cast %lds0b : index to i32 + %la1 = arith.index_cast %lds1b : index to i32 + %la2 = arith.index_cast %lds2b : index to i32 + %la3 = arith.index_cast %lds3b : index to i32 + %lv0 = lsir.to_reg %la0 : i32 -> !amdgcn.vgpr + %lv1 = lsir.to_reg %la1 : i32 -> !amdgcn.vgpr + %lv2 = lsir.to_reg %la2 : i32 -> !amdgcn.vgpr + %lv3 = lsir.to_reg %la3 : i32 -> !amdgcn.vgpr + %c0_i32 = arith.constant 0 : i32 + %d0 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %d1 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %d2 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %d3 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr + %v0, %t0 = amdgcn.load ds_read_b32 dest %d0 addr %lv0 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token + %v1, %t1 = amdgcn.load ds_read_b32 dest %d1 addr %lv1 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token + %v2, %t2 = amdgcn.load ds_read_b32 dest %d2 addr %lv2 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token + %v3, %t3 = amdgcn.load ds_read_b32 dest %d3 addr %lv3 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token + amdgcn.wait deps %t0 : !amdgcn.read_token + amdgcn.wait deps %t1 : !amdgcn.read_token + amdgcn.wait deps %t2 : !amdgcn.read_token + amdgcn.wait deps %t3 : !amdgcn.read_token + // Fire-and-forget global stores. + amdgcn.store global_store_dword data %v0 addr %addr0 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token + amdgcn.store global_store_dword data %v1 addr %addr1 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token + amdgcn.store global_store_dword data %v2 addr %addr2 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token + amdgcn.store global_store_dword data %v3 addr %addr3 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token + return + } } amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { - // 48x48x64 matmul on tile-aligned dimensions. - // Host pads actual 40x40 data to these sizes. - func.func @matmul_f16_48x48( - %A: memref<48x64xf16>, %B: memref<48x64xf16>, %C: memref<48x48xf32>) + func.func @matmul_f16_40x40( + %A: memref<40x64xf16>, %B: memref<40x64xf16>, %C: memref<40x40xf32>) attributes {gpu.kernel} { %cst = arith.constant 0.000000e+00 : f32 - %a = bufferization.to_tensor %A restrict writable : memref<48x64xf16> to tensor<48x64xf16> - %b = bufferization.to_tensor %B restrict writable : memref<48x64xf16> to tensor<48x64xf16> - %c = bufferization.to_tensor %C restrict writable : memref<48x48xf32> to tensor<48x48xf32> - %fill = linalg.fill ins(%cst : f32) outs(%c : tensor<48x48xf32>) -> tensor<48x48xf32> - // matmul_transpose_b: C[m,n] += A[m,k] * B[n,k] + %a = bufferization.to_tensor %A restrict writable : memref<40x64xf16> to tensor<40x64xf16> + %b = bufferization.to_tensor %B restrict writable : memref<40x64xf16> to tensor<40x64xf16> + %c = bufferization.to_tensor %C restrict writable : memref<40x40xf32> to tensor<40x40xf32> + %fill = linalg.fill ins(%cst : f32) outs(%c : tensor<40x40xf32>) -> tensor<40x40xf32> %result = linalg.generic { indexing_maps = [ affine_map<(m, n, k) -> (m, k)>, @@ -126,17 +267,17 @@ module { affine_map<(m, n, k) -> (m, n)> ], iterator_types = ["parallel", "parallel", "reduction"] - } ins(%a, %b : tensor<48x64xf16>, tensor<48x64xf16>) - outs(%fill : tensor<48x48xf32>) { + } ins(%a, %b : tensor<40x64xf16>, tensor<40x64xf16>) + outs(%fill : tensor<40x40xf32>) { ^bb0(%av: f16, %bv: f16, %cv: f32): %a_ext = arith.extf %av : f16 to f32 %b_ext = arith.extf %bv : f16 to f32 %prod = arith.mulf %a_ext, %b_ext : f32 %sum = arith.addf %cv, %prod : f32 linalg.yield %sum : f32 - } -> tensor<48x48xf32> + } -> tensor<40x40xf32> bufferization.materialize_in_destination %result in writable %C - : (tensor<48x48xf32>, memref<48x48xf32>) -> () + : (tensor<40x40xf32>, memref<40x40xf32>) -> () return } } diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index 766aa2439..f0adaea02 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -143,30 +143,55 @@ def test_matmul_64x64(self): np.testing.assert_allclose(C, expected, rtol=1e-2, atol=1e-2) def test_matmul_padded_40x40(self): - """Matmul with non-tile-aligned dimensions: actual 40x40, padded to 48x48.""" - M_actual, N_actual, K = 40, 40, 64 - M_pad, N_pad = 48, 48 # next multiple of 16 + """Matmul with non-tile-aligned dimensions: actual 40x40x64. - np.random.seed(42) - A_actual = (np.random.randn(M_actual, K) * 0.1).astype(np.float16) - B_actual = (np.random.randn(N_actual, K) * 0.1).astype(np.float16) + Kernel operates on actual dimensions (memref<40x64xf16>, memref<40x40xf32>). + Transform pads boundary tiles to 16. Host over-allocates C to 48*48 + to accommodate OOB stores from boundary tiles. + """ + M, N, K = 40, 40, 64 + M_pad = 48 # next multiple of 16, for C over-allocation - # Pad A and B to tile-aligned sizes (zero-fill padding region). - A_padded = np.zeros((M_pad, K), dtype=np.float16) - A_padded[:M_actual, :] = A_actual - B_padded = np.zeros((N_pad, K), dtype=np.float16) - B_padded[:N_actual, :] = B_actual + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np.float16) + B_T = (np.random.randn(N, K) * 0.1).astype(np.float16) - C = np.zeros(M_pad * N_pad, dtype=np.float32) + # Over-allocate C: kernel writes full 16x16 tiles at boundaries, + # going beyond 40x40. Allocate 48*48 elements but pass as 40x40 memref. + C = np.zeros(M_pad * M_pad, dtype=np.float32) def padded_preprocess(mlir_text): - return _air_preprocess_with_files( - mlir_text, _PADDED_TRANSFORM_FILE) + opt = _find_mlir_air_opt() + result = subprocess.run( + [ + opt, + f"--transform-preload-library=transform-library-paths={_PADDED_TRANSFORM_FILE}", + "--transform-interpreter", + "--canonicalize", "--cse", + # Set memory_space=2 on padding allocs (no memory space → L1). + "--air-override-memref-memory-space=scope=func memory-space=2", + "--air-par-to-herd", + "--canonicalize", "--cse", + "--air-par-to-launch=has-air-segment=true", + "--canonicalize", "--cse", + "--air-copy-to-dma", + "--air-to-amdgcn", + "--canonicalize", + "--convert-memspace-to-amdgcn", + "--convert-to-amdgcn-library-calls", + ], + input=mlir_text, + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError( + f"mlir-air-opt padded preprocessing failed:\n{result.stderr}") + return result.stdout compile_and_run( file_name=_PADDED_MLIR_FILE, - kernel_name="matmul_f16_48x48", - input_data=[A_padded.flatten(), B_padded.flatten()], + kernel_name="matmul_f16_40x40", + input_data=[A.flatten(), B_T.flatten()], output_data=[C], pass_pipeline=_post_air_pipeline(_LIBRARY_PATHS), library_paths=[], @@ -175,8 +200,8 @@ def padded_preprocess(mlir_text): preprocess=padded_preprocess, ) - # Extract valid 40x40 region and compare. - C_2d = C.reshape(M_pad, N_pad) - C_valid = C_2d[:M_actual, :N_actual].flatten() - expected = (A_actual.astype(np.float32) @ B_actual.T.astype(np.float32)).flatten() - np.testing.assert_allclose(C_valid, expected, rtol=1e-2, atol=1e-2) + # Extract valid 40x40 region (C is over-allocated as flat 48*48). + # The kernel writes with stride=40, so reinterpret accordingly. + C_2d = C[:M * M_pad].reshape(-1, M_pad)[:M, :N].flatten() + expected = (A.astype(np.float32) @ B_T.T.astype(np.float32)).flatten() + np.testing.assert_allclose(C_2d, expected, rtol=1e-2, atol=1e-2) From bf652bb3694ea8f8945677377d20e4ddbd18e99d Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 10/16] [mlir-air] Fix PreloadLibrary infinite loop + WIP padded matmul MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix a bug in PreloadLibrary where declarationsToReplace replaces a declaration with another declaration from the library map, returning true forever. Guard: only replace if the library entry has a body. WIP: compiler-level padded matmul (40x40x64) with LDS C accumulation. Infrastructure works through convert-to-amdgcn-library-calls, but the aster backend rejects non-thread-uniform loops from non-divisible tiling. New: fill_lds_16x64_b in lds_16x64_b.mlir for LDS zero-fill. BLOCKING: "only thread-uniform loops are supported" — the scf.for loop from tile_using_for with affine.min bounds varies per wavefront. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- contrib/kittens/library/lds_16x64_b.mlir | 18 ++++++++++++++++++ .../AMDGCN/Transforms/PreloadLibrary.cpp | 5 ++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/contrib/kittens/library/lds_16x64_b.mlir b/contrib/kittens/library/lds_16x64_b.mlir index 8b4679ea0..a644e1a8c 100644 --- a/contrib/kittens/library/lds_16x64_b.mlir +++ b/contrib/kittens/library/lds_16x64_b.mlir @@ -138,4 +138,22 @@ amdgcn.library @kittens_lds_16x64_b isa = [#amdgcn.isa] { return %future : !future_lds_read } + //===--------------------------------------------------------------------===// + // Zero-fill a 16x64_b LDS tile (1024 bytes). + //===--------------------------------------------------------------------===// + + // Each thread writes 16 bytes of zeros at its assigned positions. + func.func private @fill_lds_16x64_b(%tile_base: index) { + %zero = func.call @alloc_vgprx2() : () -> !vx2 + %addr_lo, %addr_hi = func.call @compute_lds_write_addrs_16x64_b(%tile_base) + : (index) -> (index, index) + %t0 = func.call @write_vx2_to_lds_at(%zero, %addr_lo) + : (!vx2, index) -> !future_lds_write + %t1 = func.call @write_vx2_to_lds_at(%zero, %addr_hi) + : (!vx2, index) -> !future_lds_write + amdgcn.wait deps %t0 : !future_lds_write + amdgcn.wait deps %t1 : !future_lds_write + return + } + } diff --git a/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp b/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp index b46f9111a..13982acc8 100644 --- a/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp +++ b/lib/Dialect/AMDGCN/Transforms/PreloadLibrary.cpp @@ -138,7 +138,10 @@ bool PreloadLibrary::processModule( for (auto funcOp : module.getOps()) { if (funcOp.isDeclaration()) { StringRef name = funcOp.getSymName(); - if (libraryFunctions.contains(name)) { + auto it = libraryFunctions.find(name); + // Only replace if the library has a definition (not another declaration). + if (it != libraryFunctions.end() && + !it->second->getRegions().front().empty()) { declarationsToReplace.push_back(funcOp); } } From 9ca7081761101fa61948c94747cca33a1e67f8a4 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 11/16] =?UTF-8?q?[mlir-air]=20Use=20pad=5Ftiling=5Finterfa?= =?UTF-8?q?ce=20for=20padded=20matmul=20(40=E2=86=9248)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace per-tile padding approach with pad_tiling_interface which pads the entire iteration domain from 40→48 BEFORE tiling. This eliminates: - affine.min bounds (all tiles are full, 48 % 16 == 0) - Non-uniform loops rejected by amdgcn-convert-scf-control-flow - Dynamic allocs without memory_space - C-in-LDS complexity (C stays in global, same as 64x64 test) The padded allocs (48x64, 48x48) are at function level, outside the herd. Library functions are identical to the 64x64 test. BLOCKING: lsir.reg_cast normal form violation in aster backend from padded alloc pointer decomposition. 64x64 test still passes. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- ...air-to-amdgcn-matmul-padded-transform.mlir | 57 ++++--- .../test/air-to-amdgcn-matmul-padded.mlir | 157 ++---------------- .../test/integration/test_air_matmul_e2e.py | 45 +---- 3 files changed, 47 insertions(+), 212 deletions(-) diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir index f0a37c842..b8415522d 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir @@ -1,9 +1,6 @@ -// Transform sequence for padded matmul: M=40, N=40, K=64. -// tile_using_forall [16,0,0] → 3 wavefronts (ceil(40/16)=3). -// tile_using_for [16,16,0] → 16x16 compute tiles, untiled K. -// Boundary tiles (8 rows or 8 cols) are padded to 16. -// bufferize_to_allocation on pad ops BEFORE DPS rewrite uses the -// dedicated PadOp allocation path that handles the full pad result. +// Transform sequence for padded matmul: actual M=40, N=40, K=64. +// Uses pad_tiling_interface to pad 40→48 (next multiple of 16) BEFORE tiling. +// After padding, all dimensions are tile-aligned → no affine.min, uniform loops. module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main( @@ -11,48 +8,52 @@ module attributes {transform.with_named_sequence} { %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op - // Outer tiling: forall on M, one wavefront per 16 M rows. + // Step 1: Pad the iteration domain to tile-aligned sizes. + // M: 40→48, N: 40→48, K: unchanged (already 64, divisible by any tile). + %padded_matmul, %pad_op = transform.structured.pad_tiling_interface %matmul + to padding_sizes [16, 16, 0] pad_to_multiple_of + { padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32] } + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + + // Step 2: Outer tiling — 3 wavefronts (48/16=3, exact). %outer_tiled, %outer_forall = - transform.structured.tile_using_forall %matmul + transform.structured.tile_using_forall %padded_matmul tile_sizes [16, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - // Compute tiling: 16x16 output tiles, K untiled. + // Step 3: Compute tiling — 16x16 output tiles, K untiled. + // 48 % 16 == 0 → no affine.min, all tiles are full. %tiled, %lm, %ln = transform.structured.tile_using_for %outer_tiled tile_sizes [16, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) - // Pad all operands. Boundary tiles get zero-padded to 16. + // Step 4: Pad A and B operands to promote to LDS. + // All tiles are full (no boundary tiles) so pad is a no-op on shapes + // but nofold forces allocation for the copy to LDS. %padded, %pad, %copy_back = transform.structured.pad %tiled { padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], padding_dimensions = [0, 1, 2], - pack_paddings = [1, 1, 1], - nofold_flags = [1, 1, 1] + pack_paddings = [1, 1, 0], + nofold_flags = [1, 1, 0], + copy_back_op = "linalg.copy" } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad + : (!transform.any_op) -> !transform.any_op - // Pre-allocate pad buffers with explicit memory spaces BEFORE DPS - // rewrite. The PadOp-specific path in bufferize_to_allocation handles - // the full pad (fill + copy) in one allocation. - // A, B → LDS (memory_space=2). C → global (memory_space=0 is default, - // but we don't pre-allocate C — it writes directly to the global subview). + // Step 5: Promote padded A,B to LDS (memory_space=2). %padded_lhs = transform.get_producer_of_operand %padded[0] : (!transform.any_op) -> (!transform.any_op) %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs - {memory_space = 2} : !transform.any_op + {memory_space = 2, bufferize_destination_only} + : !transform.any_op %padded_rhs = transform.get_producer_of_operand %padded[1] : (!transform.any_op) -> (!transform.any_op) %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs - {memory_space = 2} : !transform.any_op - - // C pad: allocate in LDS (memory_space=2). The matmul computes the full - // 16x16 tile in LDS, then copies back only the valid region to global C. - %padded_out = transform.get_producer_of_operand %padded[2] - : (!transform.any_op) -> (!transform.any_op) - %buf_c, %new_c = transform.structured.bufferize_to_allocation %padded_out - {memory_space = 2} : !transform.any_op + {memory_space = 2, bufferize_destination_only} + : !transform.any_op // Canonicalize. %func_0 = transform.structured.match ops{["func.func"]} in %arg0 @@ -81,11 +82,13 @@ module attributes {transform.with_named_sequence} { %func_3 = transform.air.remove_uninitialized_copy %func_2 : (!transform.any_op) -> !transform.any_op - // Convert outer forall → parallel (par_to_herd runs as pipeline pass). + // Convert outer forall → parallel → air.herd. %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op %parallel = transform.loop.forall_to_parallel %forall_2 : (!transform.any_op) -> !transform.any_op + %herd = transform.air.par_to_herd %parallel + : (!transform.any_op) -> !transform.any_op transform.yield } diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir index 1bd91776d..044798122 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir @@ -1,7 +1,7 @@ // Padded matmul: actual M=40, N=40, K=64. -// Kernel operates on actual (non-tile-aligned) dimensions. -// The transform pads boundary tiles to full 16-element tiles. -// Host over-allocates C (48*48 elements) to accommodate OOB boundary stores. +// pad_tiling_interface pads 40→48 at the iteration domain level. +// After padding, all tiles are full (48 % 16 == 0). No boundary tiles. +// C output is 48x48 (padded); host extracts valid 40x40 region. !sx2 = !amdgcn.sgpr<[? + 2]> !vx2 = !amdgcn.vgpr<[? + 2]> @@ -26,13 +26,9 @@ module { func.func private @load_lds_B_swizzled( index, index, index) -> !future_lds_read func.func private @get_lds_read_value_vx2(!future_lds_read) -> !vx2 - func.func private @fill_lds_16x64_b(index) - func.func private @store_lds_C_mfma_f32_16x16x16_f16(!ax4, index) - func.func private @mfma_index_C_16x16_f32() -> !aster_utils.struct - func.func private @mfma_c_16x16_f32_byte_offset(index, index, index, index, index, index, index) -> index - func.func private @global_addr_from_offset(!sx2, index) -> !vx2 - func.func private @alloc_vgpr() -> !amdgcn.vgpr + // Same library functions as the 64x64 test — all tiles are full after + // pad_tiling_interface pads 40→48. func.func private @copy_f16_16x64( %src_ptr: !sx2, %src_stride: index, %row_offset: index, %col_offset: index, @@ -111,147 +107,12 @@ module { func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { return } func.func private @fill_f16_16x32(%val: f16, %lds_dst: index) { return } - - // Zero-fill 16x16 f32 LDS tile (1024 bytes). - // Uses MFMA C fragment layout: each thread zeros its 4 positions. - func.func private @fill_f32_16x16(%val: f32, %lds_dst: index) { - func.call @fill_lds_16x64_b(%lds_dst) : (index) -> () - return - } - - // Copy 16x16 f32 tile from global to LDS. - // Reuses the 16x64-byte tile load (16x16 f32 = 16 rows × 64 bytes). - func.func private @copy_f32_16x16( - %src_ptr: !sx2, %src_stride: index, - %row_offset: index, %col_offset: index, - %lds_dst: index) { - %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any - %gfut = func.call @load_global_tile_16x64_b( - %ptr, %row_offset, %col_offset, %src_stride) - : (!aster_utils.any, index, index, index) -> !future_global_read - %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut) - : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) - amdgcn.wait deps %t0 : !lds_write_token - amdgcn.wait deps %t1 : !lds_write_token - return - } - - // MFMA matmul with all operands in LDS (including C accumulator). - // Stores result to LDS C via store_lds_C (not global_store). - func.func private @mfma_matmul_lds_c_f16_16x64( - %lds_A: index, %lds_B: index, %lds_C: index) { - %c0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c32 = arith.constant 32 : index - %c1024 = arith.constant 1024 : index - %lds_A2 = arith.addi %lds_A, %c1024 : index - %lds_B2 = arith.addi %lds_B, %c1024 : index - %acc = func.call @zero_C() : () -> !ax4 - %A0f = func.call @load_lds_A_swizzled(%lds_A, %c0, %c2) - : (index, index, index) -> !future_lds_read - %A0 = func.call @get_lds_read_value_vx2(%A0f) : (!future_lds_read) -> !vx2 - %B0f = func.call @load_lds_B_swizzled(%lds_B, %c0, %c2) - : (index, index, index) -> !future_lds_read - %B0 = func.call @get_lds_read_value_vx2(%B0f) : (!future_lds_read) -> !vx2 - %acc0 = func.call @mfma_f32_16x16x16_f16(%A0, %B0, %acc) - : (!vx2, !vx2, !ax4) -> !ax4 - %A1f = func.call @load_lds_A_swizzled(%lds_A, %c32, %c2) - : (index, index, index) -> !future_lds_read - %A1 = func.call @get_lds_read_value_vx2(%A1f) : (!future_lds_read) -> !vx2 - %B1f = func.call @load_lds_B_swizzled(%lds_B, %c32, %c2) - : (index, index, index) -> !future_lds_read - %B1 = func.call @get_lds_read_value_vx2(%B1f) : (!future_lds_read) -> !vx2 - %acc1 = func.call @mfma_f32_16x16x16_f16(%A1, %B1, %acc0) - : (!vx2, !vx2, !ax4) -> !ax4 - %A2f = func.call @load_lds_A_swizzled(%lds_A2, %c0, %c2) - : (index, index, index) -> !future_lds_read - %A2 = func.call @get_lds_read_value_vx2(%A2f) : (!future_lds_read) -> !vx2 - %B2f = func.call @load_lds_B_swizzled(%lds_B2, %c0, %c2) - : (index, index, index) -> !future_lds_read - %B2 = func.call @get_lds_read_value_vx2(%B2f) : (!future_lds_read) -> !vx2 - %acc2 = func.call @mfma_f32_16x16x16_f16(%A2, %B2, %acc1) - : (!vx2, !vx2, !ax4) -> !ax4 - %A3f = func.call @load_lds_A_swizzled(%lds_A2, %c32, %c2) - : (index, index, index) -> !future_lds_read - %A3 = func.call @get_lds_read_value_vx2(%A3f) : (!future_lds_read) -> !vx2 - %B3f = func.call @load_lds_B_swizzled(%lds_B2, %c32, %c2) - : (index, index, index) -> !future_lds_read - %B3 = func.call @get_lds_read_value_vx2(%B3f) : (!future_lds_read) -> !vx2 - %acc3 = func.call @mfma_f32_16x16x16_f16(%A3, %B3, %acc2) - : (!vx2, !vx2, !ax4) -> !ax4 - // Store to LDS C (not global). - func.call @store_lds_C_mfma_f32_16x16x16_f16(%acc3, %lds_C) - : (!ax4, index) -> () - return - } - - // Copy 16x16 f32 tile from LDS to global (C writeback). - // Uses MFMA C fragment layout: each thread reads its 4 values from LDS - // and writes to global memory. - func.func private @store_global_f32_16x16( - %lds_src: index, - %dst_ptr: !sx2, %dst_stride: index, - %row_offset: index, %col_offset: index) { - // Read from LDS and write to global using the MFMA C fragment layout. - %dst_prepared = func.call @prepare_ptr(%dst_ptr) : (!sx2) -> !aster_utils.any - %mfma_idx = func.call @mfma_index_C_16x16_f32() : () -> !aster_utils.struct - %col, %row_base = aster_utils.struct_extract %mfma_idx ["i", "j"] - : !aster_utils.struct -> index, index - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c64 = arith.constant 64 : index - %elt_size = arith.constant 4 : index - %col_byte = arith.muli %col, %elt_size : index - // Manually unrolled: 4 iterations for MFMA C fragment (4 consecutive rows). - // Avoids {aster.constexpr} loop which causes --inline canonicalize hang. - %off0 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c0) : (index, index, index, index, index, index, index) -> index - %off1 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c1) : (index, index, index, index, index, index, index) -> index - %off2 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c2) : (index, index, index, index, index, index, index) -> index - %off3 = func.call @mfma_c_16x16_f32_byte_offset(%row_offset, %col_offset, %row_base, %col, %dst_stride, %elt_size, %c3) : (index, index, index, index, index, index, index) -> index - %addr0 = func.call @global_addr_from_offset(%dst_ptr, %off0) : (!sx2, index) -> !vx2 - %addr1 = func.call @global_addr_from_offset(%dst_ptr, %off1) : (!sx2, index) -> !vx2 - %addr2 = func.call @global_addr_from_offset(%dst_ptr, %off2) : (!sx2, index) -> !vx2 - %addr3 = func.call @global_addr_from_offset(%dst_ptr, %off3) : (!sx2, index) -> !vx2 - // Read 4 f32 values from LDS at MFMA C fragment positions. - %row0_off = arith.muli %row_base, %c64 : index - %lds0 = arith.addi %lds_src, %row0_off : index - %lds0b = arith.addi %lds0, %col_byte : index - %lds1b = arith.addi %lds0b, %c64 : index - %lds2b = arith.addi %lds1b, %c64 : index - %lds3b = arith.addi %lds2b, %c64 : index - %la0 = arith.index_cast %lds0b : index to i32 - %la1 = arith.index_cast %lds1b : index to i32 - %la2 = arith.index_cast %lds2b : index to i32 - %la3 = arith.index_cast %lds3b : index to i32 - %lv0 = lsir.to_reg %la0 : i32 -> !amdgcn.vgpr - %lv1 = lsir.to_reg %la1 : i32 -> !amdgcn.vgpr - %lv2 = lsir.to_reg %la2 : i32 -> !amdgcn.vgpr - %lv3 = lsir.to_reg %la3 : i32 -> !amdgcn.vgpr - %c0_i32 = arith.constant 0 : i32 - %d0 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %d1 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %d2 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %d3 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %v0, %t0 = amdgcn.load ds_read_b32 dest %d0 addr %lv0 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token - %v1, %t1 = amdgcn.load ds_read_b32 dest %d1 addr %lv1 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token - %v2, %t2 = amdgcn.load ds_read_b32 dest %d2 addr %lv2 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token - %v3, %t3 = amdgcn.load ds_read_b32 dest %d3 addr %lv3 offset c(%c0_i32) : dps(!amdgcn.vgpr) ins(!amdgcn.vgpr, i32) -> !amdgcn.read_token - amdgcn.wait deps %t0 : !amdgcn.read_token - amdgcn.wait deps %t1 : !amdgcn.read_token - amdgcn.wait deps %t2 : !amdgcn.read_token - amdgcn.wait deps %t3 : !amdgcn.read_token - // Fire-and-forget global stores. - amdgcn.store global_store_dword data %v0 addr %addr0 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token - amdgcn.store global_store_dword data %v1 addr %addr1 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token - amdgcn.store global_store_dword data %v2 addr %addr2 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token - amdgcn.store global_store_dword data %v3 addr %addr3 : ins(!amdgcn.vgpr, !vx2) -> !amdgcn.write_token - return - } } amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { + // Kernel operates on actual 40x40 dimensions. + // pad_tiling_interface in the transform pads to 48x48 internally. + // Host over-allocates C (48*48 elements) to fit padded stores. func.func @matmul_f16_40x40( %A: memref<40x64xf16>, %B: memref<40x64xf16>, %C: memref<40x40xf32>) attributes {gpu.kernel} { @@ -260,6 +121,8 @@ module { %b = bufferization.to_tensor %B restrict writable : memref<40x64xf16> to tensor<40x64xf16> %c = bufferization.to_tensor %C restrict writable : memref<40x40xf32> to tensor<40x40xf32> %fill = linalg.fill ins(%cst : f32) outs(%c : tensor<40x40xf32>) -> tensor<40x40xf32> + // matmul_transpose_b on actual 40-element M/N dims. + // pad_tiling_interface will pad these to 48 before tiling. %result = linalg.generic { indexing_maps = [ affine_map<(m, n, k) -> (m, k)>, diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index f0adaea02..03f2bd280 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -145,48 +145,20 @@ def test_matmul_64x64(self): def test_matmul_padded_40x40(self): """Matmul with non-tile-aligned dimensions: actual 40x40x64. - Kernel operates on actual dimensions (memref<40x64xf16>, memref<40x40xf32>). - Transform pads boundary tiles to 16. Host over-allocates C to 48*48 - to accommodate OOB stores from boundary tiles. + pad_tiling_interface pads M,N from 40→48 at the iteration domain level. + After padding, all tiles are full (48 % 16 == 0). No affine.min bounds. + The padding is internal to the tensor computation; the C memref stays 40x40. """ M, N, K = 40, 40, 64 - M_pad = 48 # next multiple of 16, for C over-allocation np.random.seed(42) A = (np.random.randn(M, K) * 0.1).astype(np.float16) B_T = (np.random.randn(N, K) * 0.1).astype(np.float16) - - # Over-allocate C: kernel writes full 16x16 tiles at boundaries, - # going beyond 40x40. Allocate 48*48 elements but pass as 40x40 memref. - C = np.zeros(M_pad * M_pad, dtype=np.float32) + C = np.zeros(M * N, dtype=np.float32) def padded_preprocess(mlir_text): - opt = _find_mlir_air_opt() - result = subprocess.run( - [ - opt, - f"--transform-preload-library=transform-library-paths={_PADDED_TRANSFORM_FILE}", - "--transform-interpreter", - "--canonicalize", "--cse", - # Set memory_space=2 on padding allocs (no memory space → L1). - "--air-override-memref-memory-space=scope=func memory-space=2", - "--air-par-to-herd", - "--canonicalize", "--cse", - "--air-par-to-launch=has-air-segment=true", - "--canonicalize", "--cse", - "--air-copy-to-dma", - "--air-to-amdgcn", - "--canonicalize", - "--convert-memspace-to-amdgcn", - "--convert-to-amdgcn-library-calls", - ], - input=mlir_text, - capture_output=True, text=True, - ) - if result.returncode != 0: - raise RuntimeError( - f"mlir-air-opt padded preprocessing failed:\n{result.stderr}") - return result.stdout + return _air_preprocess_with_files( + mlir_text, _PADDED_TRANSFORM_FILE) compile_and_run( file_name=_PADDED_MLIR_FILE, @@ -200,8 +172,5 @@ def padded_preprocess(mlir_text): preprocess=padded_preprocess, ) - # Extract valid 40x40 region (C is over-allocated as flat 48*48). - # The kernel writes with stride=40, so reinterpret accordingly. - C_2d = C[:M * M_pad].reshape(-1, M_pad)[:M, :N].flatten() expected = (A.astype(np.float32) @ B_T.T.astype(np.float32)).flatten() - np.testing.assert_allclose(C_2d, expected, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(C, expected, rtol=1e-2, atol=1e-2) From b3e494e9bce0b7fe021633012859a352bc3bb298 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 12/16] [mlir-air] Device-side padded matmul via air-split-launch-for-padding Add E2E test for non-tile-aligned matmul (M=40, N=40, K=64) with device-side padding. The kernel writes air.launch directly with a 3x3 grid and dynamic-sized DMA copies (arith.minui for boundary tiles). The upstream air-split-launch-for-padding pass (split-mode=single-launch, pad-location=source) produces a single launch with scf.if on block indices. Boundary DMAs get pad_after attributes, which ConvertToAMDGCNLibraryCalls converts to fill_f16_16x64 (zero LDS) + copy_f16_16x64_padded (clamped global loads via arith.minui). Key changes: - ConvertToAMDGCNLibraryCalls: detect pad_after on DMAs, emit fill + padded copy for boundary tiles - PromoteAllocsToFuncArgs: new pass to promote nested memref.alloc (including inside scf.if) to kernel function arguments - Init.cpp: register upstream air-split-launch-for-padding (extracted from AIE-gated code in mlir-air #1496), scf::SCFDialect - indexing_ptr.mlir: add make_raw_buffer_rsrc_bounded - Update mlir-air submodule to 39ee8fc6 (includes #1496 and #1499) Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- contrib/mlir-air/CMakeLists.txt | 1 + .../lib/ConvertToAMDGCNLibraryCalls.cpp | 88 +++++++++- contrib/mlir-air/lib/Init.cpp | 6 + .../mlir-air/lib/PromoteAllocsToFuncArgs.cpp | 133 +++++++++++++++ ...air-to-amdgcn-matmul-padded-transform.mlir | 55 +++---- .../test/air-to-amdgcn-matmul-padded.mlir | 155 ++++++++++++++---- .../test/integration/test_air_matmul_e2e.py | 63 +++++-- mlir_kernels/library/common/indexing_ptr.mlir | 10 ++ third_party/mlir-air | 1 + 9 files changed, 432 insertions(+), 80 deletions(-) create mode 100644 contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp create mode 160000 third_party/mlir-air diff --git a/contrib/mlir-air/CMakeLists.txt b/contrib/mlir-air/CMakeLists.txt index 234399296..7808bb1c9 100644 --- a/contrib/mlir-air/CMakeLists.txt +++ b/contrib/mlir-air/CMakeLists.txt @@ -4,6 +4,7 @@ add_mlir_library(MlirAirLib lib/ConvertMemSpaceToAMDGCN.cpp lib/Init.cpp lib/Pipelines.cpp + lib/PromoteAllocsToFuncArgs.cpp LINK_LIBS PUBLIC AIRConversionPasses diff --git a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp index 58e20abca..fdbdb2a5f 100644 --- a/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp +++ b/contrib/mlir-air/lib/ConvertToAMDGCNLibraryCalls.cpp @@ -220,7 +220,91 @@ struct DmaToLibraryCall srcIsLDS ? "store_global" : "copy", ldsTy); if (dstIsLDS && !srcIsLDS) { - // Global→LDS: copy(global_ptr, stride, row, col, lds_dst) + // Detect boundary tile padding from either: + // (a) pad_after attribute (set by air-split-launch-for-padding), or + // (b) DMA src_sizes[0] < dst memref dim 0 (from transform tensor.pad). + auto padAfterAttr = dma->getAttrOfType("pad_after"); + bool hasPadding = false; + int32_t rowPad = 0; // pad rows appended after the valid region + if (padAfterAttr) { + auto padArr = padAfterAttr.asArrayRef(); + if (!padArr.empty() && padArr[0] > 0) { + hasPadding = true; + rowPad = padArr[0]; + } + } + // Also detect from DMA sizes: if src_sizes[0] is a constant < dst dim 0, + // this is a partial copy into a padded LDS buffer (from tensor.pad path). + if (!hasPadding) { + auto srcSizes = dma.getSrcSizes(); + if (!srcSizes.empty()) { + auto srcRowOpt = getConstantIntValue(srcSizes[0]); + int64_t dstDim0 = ldsTy.getDimSize(0); + if (srcRowOpt && dstDim0 > 0 && *srcRowOpt < dstDim0) { + hasPadding = true; + rowPad = dstDim0 - *srcRowOpt; + } + } + } + + Value ldsOffset = emitLDSOffset(rewriter, loc, dst, convCtx); + + if (hasPadding) { + // When padding is detected from pad_after attribute (air-split-launch), + // emit fill explicitly. When detected from DMA sizes (tensor.pad path), + // the linalg.fill from transform already handles the zero-fill and is + // converted separately by LinalgToLibraryCall. + if (padAfterAttr) { + std::string fillName = buildFuncName("fill", ldsTy); + Type f16Ty = rewriter.getF16Type(); + Value zeroF16 = arith::ConstantOp::create( + rewriter, loc, f16Ty, + rewriter.getF16FloatAttr(0.0f)); + auto fillTy = rewriter.getFunctionType({f16Ty, indexTy}, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, fillName, fillTy); + func::CallOp::create(rewriter, loc, fillName, TypeRange{}, + ValueRange{zeroF16, ldsOffset}); + } + + // Partial copy — copy only the valid rows. + // copy_f16_16x64_padded(global_ptr, stride, row, col, actual_rows, lds_dst) + // The library function reads actual_rows rows from global (not all 16). + auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, src); + auto srcSizes = dma.getSrcSizes(); + // actual_rows = src_sizes[0] (set by split-launch pass to actualLast). + Value actualRows; + if (!srcSizes.empty()) + actualRows = srcSizes[0]; + else + actualRows = arith::ConstantIndexOp::create( + rewriter, loc, ldsTy.getDimSize(0) - rowPad); + + std::string copyName = buildFuncName("copy", ldsTy) + "_padded"; + SmallVector copyArgs = {ptrVal, byteStride}; + SmallVector copyArgTypes = {sx2Ty, indexTy}; + auto srcOffsets = dma.getSrcOffsets(); + if (srcOffsets.size() >= 2) { + copyArgs.push_back(srcOffsets[0]); + copyArgs.push_back(srcOffsets[1]); + } else { + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + copyArgs.push_back(zero); + copyArgs.push_back(zero); + } + copyArgTypes.push_back(indexTy); + copyArgTypes.push_back(indexTy); + copyArgs.push_back(actualRows); + copyArgTypes.push_back(indexTy); + copyArgs.push_back(ldsOffset); + copyArgTypes.push_back(indexTy); + auto copyTy = rewriter.getFunctionType(copyArgTypes, {}); + ensureDecl(rewriter, *convCtx.declBlock, loc, copyName, copyTy); + func::CallOp::create(rewriter, loc, copyName, TypeRange{}, copyArgs); + rewriter.eraseOp(dma); + return success(); + } + + // Non-padded: Global→LDS: copy(global_ptr, stride, row, col, lds_dst) auto [ptrVal, byteStride] = decomposeGlobalMemref(rewriter, loc, src); callArgs.push_back(ptrVal); argTypes.push_back(sx2Ty); @@ -237,7 +321,7 @@ struct DmaToLibraryCall } argTypes.push_back(indexTy); argTypes.push_back(indexTy); - callArgs.push_back(emitLDSOffset(rewriter, loc, dst, convCtx)); + callArgs.push_back(ldsOffset); argTypes.push_back(indexTy); } else if (srcIsLDS && !dstIsLDS) { // LDS→Global: copy(lds_src, global_ptr, stride, row, col) diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index a415a5ec8..3e9903f11 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -11,6 +11,7 @@ #include "air/Dialect/AIR/AIRTransformOps.h" #include "air/Transform/AIRDmaToChannel.h" #include "air/Transform/AIRMiscPasses.h" +#include "air/Transform/AIRSplitLaunchForPadding.h" // Tablegen-generated per-pass registration for upstream AIR passes. namespace air_conv_reg { @@ -24,6 +25,7 @@ namespace air_conv_reg { namespace air_xform_reg { #define GEN_PASS_REGISTRATION_DMATOCHANNEL #define GEN_PASS_REGISTRATION_AIROVERRIDEMEMREFMEMORYSPACE +#define GEN_PASS_REGISTRATION_AIRSPLITLAUNCHFORPADDING #include "air/Transform/Passes.h.inc" } // namespace air_xform_reg @@ -58,6 +60,7 @@ namespace mlir::aster::mlir_air { std::unique_ptr createAirToAMDGCN(); std::unique_ptr createConvertToAMDGCNLibraryCalls(); std::unique_ptr createConvertMemSpaceToAMDGCN(); +std::unique_ptr createPromoteAllocsToFuncArgs(); void registerPipelines(); void registerAll(DialectRegistry ®istry) { @@ -67,6 +70,7 @@ void registerAll(DialectRegistry ®istry) { // Dialects needed for linalg tiling + transform dialect. registry.insert(); registry.insert(); + registry.insert(); registry.insert(); // Bufferization interface models. @@ -113,11 +117,13 @@ void registerAll(DialectRegistry ®istry) { air_conv_reg::registerAIRWrapFuncWithParallelPass(); // air-wrap-func-with-parallel air_xform_reg::registerDmaToChannel(); // air-dma-to-channel air_xform_reg::registerAIROverrideMemRefMemorySpace(); + air_xform_reg::registerAIRSplitLaunchForPadding(); // air-split-launch-for-padding (upstream, with single-launch GPU mode) // Aster-specific passes. registerPass([] { return createAirToAMDGCN(); }); registerPass([] { return createConvertToAMDGCNLibraryCalls(); }); registerPass([] { return createConvertMemSpaceToAMDGCN(); }); + registerPass([] { return createPromoteAllocsToFuncArgs(); }); // mlir-air pipelines. registerPipelines(); diff --git a/contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp b/contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp new file mode 100644 index 000000000..154db0f72 --- /dev/null +++ b/contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp @@ -0,0 +1,133 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +//===- PromoteAllocsToFuncArgs.cpp - alloc → function argument promotion --===// + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { + +static void promoteAllocsInFunc(func::FuncOp funcOp) { + if (!funcOp->hasAttr("gpu.kernel")) + return; + + SmallVector allocsToPromote; + // Walk ALL allocs in the function, including those nested in scf.if + // or other control flow (e.g., from air-split-launch-for-padding). + funcOp.walk([&](memref::AllocOp allocOp) { + if (allocOp.getMemref().getType().getMemorySpace()) + return; + if (!allocOp.getMemref().getType().hasStaticShape()) + return; + allocsToPromote.push_back(allocOp); + }); + + if (allocsToPromote.empty()) + return; + + // Deduplicate allocs with the same type: share a single promoted arg. + // Multiple allocs of the same type (e.g., from split-launch scf.if + // branches) can safely share one buffer since only one branch executes. + DenseMap typeToArgIdx; + SmallVector uniqueAllocs; + SmallVector allocArgMapping; // maps each alloc → arg index + + auto funcTy = funcOp.getFunctionType(); + SmallVector newArgTypes(funcTy.getInputs()); + for (auto allocOp : allocsToPromote) { + auto ty = allocOp.getMemref().getType(); + auto it = typeToArgIdx.find(ty); + if (it == typeToArgIdx.end()) { + unsigned idx = newArgTypes.size(); + typeToArgIdx[ty] = idx; + newArgTypes.push_back(ty); + uniqueAllocs.push_back(allocOp); + allocArgMapping.push_back(idx); + } else { + allocArgMapping.push_back(it->second); + } + } + + funcOp.setFunctionType( + FunctionType::get(funcOp.getContext(), newArgTypes, funcTy.getResults())); + + auto &entryBlock = funcOp.getBody().front(); + // Add one block arg per unique type. + SmallVector newArgs; + for (auto allocOp : uniqueAllocs) { + auto newArg = + entryBlock.addArgument(allocOp.getMemref().getType(), allocOp.getLoc()); + newArgs.push_back(newArg); + } + // Replace each alloc with the corresponding shared arg. + for (unsigned i = 0; i < allocsToPromote.size(); ++i) { + unsigned argIdx = allocArgMapping[i] - funcTy.getNumInputs(); + allocsToPromote[i].getResult().replaceAllUsesWith(newArgs[argIdx]); + } + + // Erase the allocs and their initialization ops (fill, copy into padded + // buffer). The host is responsible for pre-initializing the workspace. + for (auto allocOp : llvm::reverse(allocsToPromote)) { + for (auto user : + llvm::make_early_inc_range(allocOp.getResult().getUsers())) { + if (isa(user)) + user->erase(); + } + allocOp->erase(); + } + + // Erase linalg.map (zero-fill) and memref.copy (padding init copy) + // on the promoted workspace args. The host pre-initializes them. + // Walk ALL ops (including nested in scf.if) to find operations on + // promoted workspace args. + DenseSet promotedArgs(newArgs.begin(), newArgs.end()); + SmallVector toErase; + funcOp.walk([&](Operation *op) { + if (auto mapOp = dyn_cast(op)) { + for (Value out : mapOp.getDpsInits()) + if (promotedArgs.contains(out)) + toErase.push_back(op); + } else if (auto copyOp = dyn_cast(op)) { + // Only erase copies INTO workspace (dst is workspace). + // Keep copies FROM workspace to global (the result copy-back). + Value dst = copyOp.getTarget(); + auto dstSv = dst.getDefiningOp(); + bool dstIsWs = promotedArgs.contains(dst) || + (dstSv && promotedArgs.contains(dstSv.getSource())); + if (dstIsWs) + toErase.push_back(op); + } + }); + for (auto *op : llvm::reverse(toErase)) + op->erase(); +} + +struct PromoteAllocsToFuncArgs + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PromoteAllocsToFuncArgs) + StringRef getArgument() const override { + return "promote-allocs-to-func-args"; + } + StringRef getDescription() const override { + return "Promote function-level memref.alloc to function arguments"; + } + void runOnOperation() override { + getOperation()->walk([](func::FuncOp f) { promoteAllocsInFunc(f); }); + } +}; + +} // namespace + +namespace mlir::aster::mlir_air { +std::unique_ptr createPromoteAllocsToFuncArgs() { + return std::make_unique(); +} +} // namespace mlir::aster::mlir_air diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir index b8415522d..312f5b48f 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir @@ -1,6 +1,15 @@ -// Transform sequence for padded matmul: actual M=40, N=40, K=64. -// Uses pad_tiling_interface to pad 40→48 (next multiple of 16) BEFORE tiling. -// After padding, all dimensions are tile-aligned → no affine.min, uniform loops. +// Transform sequence for padded matmul with device-side padding. +// A/B are 40x64 (actual), NOT padded. C is 48x48 (padded for safe stores). +// +// The matmul is 40x40x64. tile_sizes [16, 16, 0] creates a 3x3 grid: +// - Interior tiles (2x2): full 16x16 output, full 16x64 A/B loads +// - M-boundary (row 2): 8xN output, 8x64 A loads +// - N-boundary (col 2): Mx8 output, 8x64 B loads +// - Corner (2,2): 8x8 output, 8x64 A and B loads +// +// After air-split-launch-for-padding, boundary DMAs get pad_after attribute. +// ConvertToAMDGCNLibraryCalls emits fill_f16_16x64 + copy_f16_16x64_padded +// for boundary tiles. module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main( @@ -8,47 +17,31 @@ module attributes {transform.with_named_sequence} { %matmul = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op - // Step 1: Pad the iteration domain to tile-aligned sizes. - // M: 40→48, N: 40→48, K: unchanged (already 64, divisible by any tile). - %padded_matmul, %pad_op = transform.structured.pad_tiling_interface %matmul - to padding_sizes [16, 16, 0] pad_to_multiple_of - { padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32] } - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 2: Outer tiling — 3 wavefronts (48/16=3, exact). + // Tile 40x40 matmul into 3x3 grid. Boundary tiles have 8 rows/cols. %outer_tiled, %outer_forall = - transform.structured.tile_using_forall %padded_matmul - tile_sizes [16, 0, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - - // Step 3: Compute tiling — 16x16 output tiles, K untiled. - // 48 % 16 == 0 → no affine.min, all tiles are full. - %tiled, %lm, %ln = transform.structured.tile_using_for %outer_tiled + transform.structured.tile_using_forall %matmul tile_sizes [16, 16, 0] - : (!transform.any_op) -> (!transform.any_op, !transform.any_op, - !transform.any_op) + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) - // Step 4: Pad A and B operands to promote to LDS. - // All tiles are full (no boundary tiles) so pad is a no-op on shapes - // but nofold forces allocation for the copy to LDS. - %padded, %pad, %copy_back = transform.structured.pad %tiled { + // Pad A,B for LDS promotion (16x64 tiles; K untiled). + // nofold ensures ALL tiles (including full interior tiles) go through + // the pad→alloc→copy path, so all get air.dma_memcpy_nd ops. + %padded, %pad, %copy_back = transform.structured.pad %outer_tiled { padding_values = [0.0 : f16, 0.0 : f16, 0.0 : f32], padding_dimensions = [0, 1, 2], pack_paddings = [1, 1, 0], - nofold_flags = [1, 1, 0], - copy_back_op = "linalg.copy" + nofold_flags = [1, 1, 0] } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad : (!transform.any_op) -> !transform.any_op - // Step 5: Promote padded A,B to LDS (memory_space=2). + // Promote padded A,B to LDS (memory_space = 2 = shared/LDS). %padded_lhs = transform.get_producer_of_operand %padded[0] : (!transform.any_op) -> (!transform.any_op) %buf_a, %new_a = transform.structured.bufferize_to_allocation %padded_lhs {memory_space = 2, bufferize_destination_only} : !transform.any_op - %padded_rhs = transform.get_producer_of_operand %padded[1] : (!transform.any_op) -> (!transform.any_op) %buf_b, %new_b = transform.structured.bufferize_to_allocation %padded_rhs @@ -65,7 +58,7 @@ module attributes {transform.with_named_sequence} { } : !transform.any_op transform.apply_cse to %func_0 : !transform.any_op - // One-shot bufferize. + // Bufferize. %func_1 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op %func_buf = transform.bufferization.one_shot_bufferize %func_1 { @@ -82,12 +75,12 @@ module attributes {transform.with_named_sequence} { %func_3 = transform.air.remove_uninitialized_copy %func_2 : (!transform.any_op) -> !transform.any_op - // Convert outer forall → parallel → air.herd. + // forall → parallel → launch (M and N both in launch). %forall_2 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op %parallel = transform.loop.forall_to_parallel %forall_2 : (!transform.any_op) -> !transform.any_op - %herd = transform.air.par_to_herd %parallel + %launch = transform.air.par_to_launch %parallel : (!transform.any_op) -> !transform.any_op transform.yield diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir index 044798122..cc26c0063 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir @@ -1,7 +1,15 @@ // Padded matmul: actual M=40, N=40, K=64. -// pad_tiling_interface pads 40→48 at the iteration domain level. -// After padding, all tiles are full (48 % 16 == 0). No boundary tiles. -// C output is 48x48 (padded); host extracts valid 40x40 region. +// Device-side padding via air-split-launch-for-padding. +// +// The kernel has air.launch (3x3 grid) written directly — no transform, +// no air-wrap-func-with-parallel. Each block computes one 16x16 output tile. +// +// A, B: actual 40x64, C: 48x48 (padded for safe boundary stores). +// Boundary tiles (last row/col) copy fewer rows from global, with LDS +// zero-filled first. split-launch adds pad_after on boundary DMAs. +// ConvertToAMDGCNLibraryCalls emits fill + copy_padded for those. +// +// Host extracts valid C[0:40, 0:40] after kernel. !sx2 = !amdgcn.sgpr<[? + 2]> !vx2 = !amdgcn.vgpr<[? + 2]> @@ -26,9 +34,11 @@ module { func.func private @load_lds_B_swizzled( index, index, index) -> !future_lds_read func.func private @get_lds_read_value_vx2(!future_lds_read) -> !vx2 + func.func private @thread_tile_pos_16x64_b() -> (index, index) + func.func private @tiled_row_byte_off(index, index, index, index, index, index) -> index + func.func private @load_global_at_byte_off(!aster_utils.any, index) -> !future_global_read + func.func private @fill_lds_16x64_b(index) - // Same library functions as the 64x64 test — all tiles are full after - // pad_tiling_interface pads 40→48. func.func private @copy_f16_16x64( %src_ptr: !sx2, %src_stride: index, %row_offset: index, %col_offset: index, @@ -55,6 +65,43 @@ module { return } + func.func private @copy_f16_16x64_padded( + %src_ptr: !sx2, %src_stride: index, + %row_offset: index, %col_offset: index, + %actual_rows: index, + %lds_dst: index) { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c1024 = arith.constant 1024 : index + %ptr = func.call @prepare_ptr(%src_ptr) : (!sx2) -> !aster_utils.any + %my_row, %my_col_byte = func.call @thread_tile_pos_16x64_b() + : () -> (index, index) + %max_row = arith.subi %actual_rows, %c1 : index + %clamped_row = arith.minui %my_row, %max_row : index + %byte_off0 = func.call @tiled_row_byte_off( + %row_offset, %clamped_row, %col_offset, %my_col_byte, %src_stride, %c2) + : (index, index, index, index, index, index) -> index + %gfut0 = func.call @load_global_at_byte_off(%ptr, %byte_off0) + : (!aster_utils.any, index) -> !future_global_read + %t0, %t1 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst, %gfut0) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + %col1 = arith.addi %col_offset, %c32 : index + %lds_dst1 = arith.addi %lds_dst, %c1024 : index + %byte_off1 = func.call @tiled_row_byte_off( + %row_offset, %clamped_row, %col1, %my_col_byte, %src_stride, %c2) + : (index, index, index, index, index, index) -> index + %gfut1 = func.call @load_global_at_byte_off(%ptr, %byte_off1) + : (!aster_utils.any, index) -> !future_global_read + %t2, %t3 = func.call @store_global_tile_to_lds_16x64_b(%lds_dst1, %gfut1) + : (index, !future_global_read) -> (!lds_write_token, !lds_write_token) + amdgcn.wait deps %t0 : !lds_write_token + amdgcn.wait deps %t1 : !lds_write_token + amdgcn.wait deps %t2 : !lds_write_token + amdgcn.wait deps %t3 : !lds_write_token + return + } + func.func private @mfma_matmul_f16_16x64( %lds_A: index, %lds_B: index, %C_ptr: !sx2, %C_stride: index, @@ -105,42 +152,78 @@ module { return } - func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { return } - func.func private @fill_f16_16x32(%val: f16, %lds_dst: index) { return } + func.func private @fill_f16_16x64(%val: f16, %lds_dst: index) { + func.call @fill_lds_16x64_b(%lds_dst) : (index) -> () + return + } } amdgcn.module @matmul_mod target = #amdgcn.target isa = #amdgcn.isa { - // Kernel operates on actual 40x40 dimensions. - // pad_tiling_interface in the transform pads to 48x48 internally. - // Host over-allocates C (48*48 elements) to fit padded stores. func.func @matmul_f16_40x40( - %A: memref<40x64xf16>, %B: memref<40x64xf16>, %C: memref<40x40xf32>) + %A: memref<40x64xf16>, %B: memref<40x64xf16>, %C: memref<48x48xf32>) attributes {gpu.kernel} { - %cst = arith.constant 0.000000e+00 : f32 - %a = bufferization.to_tensor %A restrict writable : memref<40x64xf16> to tensor<40x64xf16> - %b = bufferization.to_tensor %B restrict writable : memref<40x64xf16> to tensor<40x64xf16> - %c = bufferization.to_tensor %C restrict writable : memref<40x40xf32> to tensor<40x40xf32> - %fill = linalg.fill ins(%cst : f32) outs(%c : tensor<40x40xf32>) -> tensor<40x40xf32> - // matmul_transpose_b on actual 40-element M/N dims. - // pad_tiling_interface will pad these to 48 before tiling. - %result = linalg.generic { - indexing_maps = [ - affine_map<(m, n, k) -> (m, k)>, - affine_map<(m, n, k) -> (n, k)>, - affine_map<(m, n, k) -> (m, n)> - ], - iterator_types = ["parallel", "parallel", "reduction"] - } ins(%a, %b : tensor<40x64xf16>, tensor<40x64xf16>) - outs(%fill : tensor<40x40xf32>) { - ^bb0(%av: f16, %bv: f16, %cv: f32): - %a_ext = arith.extf %av : f16 to f32 - %b_ext = arith.extf %bv : f16 to f32 - %prod = arith.mulf %a_ext, %b_ext : f32 - %sum = arith.addf %cv, %prod : f32 - linalg.yield %sum : f32 - } -> tensor<40x40xf32> - bufferization.materialize_in_destination %result in writable %C - : (tensor<40x40xf32>, memref<40x40xf32>) -> () + %c3 = arith.constant 3 : index + + air.launch (%m_id, %n_id) in (%m_sz=%c3, %n_sz=%c3) + args(%a=%A, %b=%B, %c=%C) + : memref<40x64xf16>, memref<40x64xf16>, memref<48x48xf32> + attributes {air.actual_sizes = array} { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c40 = arith.constant 40 : index + %cst_f16 = arith.constant 0.000000e+00 : f16 + + // Tile offsets from launch block indices. + %m_off = arith.muli %m_id, %c16 : index + %n_off = arith.muli %n_id, %c16 : index + + // A tile: copy min(16, 40-m_off) rows from global to LDS. + %m_rem = arith.subi %c40, %m_off : index + %m_size = arith.minui %c16, %m_rem : index + %lds_a = memref.alloc() : memref<16x64xf16, 2> + linalg.fill ins(%cst_f16 : f16) outs(%lds_a : memref<16x64xf16, 2>) + %a_sub = memref.subview %a[%m_off, 0] [%m_size, 64] [1, 1] + : memref<40x64xf16> to memref> + %lds_a_sub = memref.subview %lds_a[0, 0] [%m_size, 64] [1, 1] + : memref<16x64xf16, 2> to memref, 2> + memref.copy %a_sub, %lds_a_sub + : memref> + to memref, 2> + + // B tile: copy min(16, 40-n_off) rows from global to LDS. + %n_rem = arith.subi %c40, %n_off : index + %n_size = arith.minui %c16, %n_rem : index + %lds_b = memref.alloc() : memref<16x64xf16, 2> + linalg.fill ins(%cst_f16 : f16) outs(%lds_b : memref<16x64xf16, 2>) + %b_sub = memref.subview %b[%n_off, 0] [%n_size, 64] [1, 1] + : memref<40x64xf16> to memref> + %lds_b_sub = memref.subview %lds_b[0, 0] [%n_size, 64] [1, 1] + : memref<16x64xf16, 2> to memref, 2> + memref.copy %b_sub, %lds_b_sub + : memref> + to memref, 2> + + // Matmul on full 16x64 LDS tiles → 16x16 output written to C. + // C is 48x48 so writing 16x16 at any (m_off, n_off) is safe. + %c_sub = memref.subview %c[%m_off, %n_off] [16, 16] [1, 1] + : memref<48x48xf32> to memref<16x16xf32, strided<[48, 1], offset: ?>> + linalg.generic { + indexing_maps = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (n, k)>, + affine_map<(m, n, k) -> (m, n)> + ], + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%lds_a, %lds_b : memref<16x64xf16, 2>, memref<16x64xf16, 2>) + outs(%c_sub : memref<16x16xf32, strided<[48, 1], offset: ?>>) { + ^bb0(%av: f16, %bv: f16, %cv: f32): + %a_ext = arith.extf %av : f16 to f32 + %b_ext = arith.extf %bv : f16 to f32 + %prod = arith.mulf %a_ext, %b_ext : f32 + %sum = arith.addf %cv, %prod : f32 + linalg.yield %sum : f32 + } + } return } } diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index 03f2bd280..d0cb059aa 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -145,32 +145,73 @@ def test_matmul_64x64(self): def test_matmul_padded_40x40(self): """Matmul with non-tile-aligned dimensions: actual 40x40x64. - pad_tiling_interface pads M,N from 40→48 at the iteration domain level. - After padding, all tiles are full (48 % 16 == 0). No affine.min bounds. - The padding is internal to the tensor computation; the C memref stays 40x40. + Device-side padding via air-split-launch-for-padding: + - A, B are actual 40x64 (NO host-side padding) + - C is 48x48 (padded to next multiple of tile=16 for safe boundary stores) + - The split-launch pass produces a single 3x3 launch with scf.if + on block indices to select between interior and boundary bodies + - Boundary tiles use fill_f16_16x64 (zero-fill LDS) + + copy_f16_16x64_padded (clamped global loads via arith.minui) + - Valid result is in C[0:40, 0:40] + + Grid: 3x3 blocks (ceil(40/16)=3 per dim), one wavefront per block. """ M, N, K = 40, 40, 64 + M_pad, N_pad = 48, 48 # next multiple of tile size 16 (for C only) np.random.seed(42) A = (np.random.randn(M, K) * 0.1).astype(np.float16) B_T = (np.random.randn(N, K) * 0.1).astype(np.float16) - C = np.zeros(M * N, dtype=np.float32) + + # C is padded to 48x48 so boundary 16x16 stores don't go OOB. + C_pad = np.zeros(M_pad * N_pad, dtype=np.float32) def padded_preprocess(mlir_text): - return _air_preprocess_with_files( - mlir_text, _PADDED_TRANSFORM_FILE) + opt = _find_mlir_air_opt() + dump_dir = "/mnt/m2m_nobackup/erweiw/aster/ir_dumps" + os.makedirs(dump_dir, exist_ok=True) + + result = subprocess.run( + [ + opt, + "--air-copy-to-dma", + "--air-split-launch-for-padding=split-mode=single-launch pad-location=source", + "--canonicalize", "--cse", + "--air-to-amdgcn", + "--canonicalize", + "--convert-memspace-to-amdgcn", + "--convert-to-amdgcn-library-calls", + ], + input=mlir_text, + capture_output=True, text=True, + ) + if result.returncode != 0: + raise RuntimeError( + f"mlir-air-opt padded preprocessing failed:\n{result.stderr}") + return result.stdout + + from aster.execution.core import InOutArray compile_and_run( file_name=_PADDED_MLIR_FILE, kernel_name="matmul_f16_40x40", - input_data=[A.flatten(), B_T.flatten()], - output_data=[C], + # Kernel args: (A: 40x64, B: 40x64, C: 48x48) + # No workspace buffers — single-tile kernel, no promoted allocs. + input_data=[ + A.flatten(), # arg0: A 40x64 (read-only) + B_T.flatten(), # arg1: B 40x64 (read-only) + InOutArray(C_pad), # arg2: C 48x48 (read-write) + ], + output_data=[], pass_pipeline=_post_air_pipeline(_LIBRARY_PATHS), library_paths=[], - grid_dim=(1, 1, 1), - block_dim=(192, 1, 1), # 3 wavefronts (3x1 herd) + grid_dim=(3, 3, 1), # ceil(40/16)=3 per dim, 9 blocks total + block_dim=(64, 1, 1), # 1 wavefront per block preprocess=padded_preprocess, ) + # Extract valid 40x40 region from padded 48x48 output. + C_2d = C_pad.reshape(M_pad, N_pad) + C_valid = C_2d[:M, :N].flatten() expected = (A.astype(np.float32) @ B_T.T.astype(np.float32)).flatten() - np.testing.assert_allclose(C, expected, rtol=1e-2, atol=1e-2) + np.testing.assert_allclose(C_valid, expected, rtol=1e-2, atol=1e-2) diff --git a/mlir_kernels/library/common/indexing_ptr.mlir b/mlir_kernels/library/common/indexing_ptr.mlir index e800b9043..1e620c1d9 100644 --- a/mlir_kernels/library/common/indexing_ptr.mlir +++ b/mlir_kernels/library/common/indexing_ptr.mlir @@ -69,4 +69,14 @@ amdgcn.library @common_indexing_ptr { : (!sx2, !s, i32) -> !sx4 return %rsrc : !sx4 } + + // Bounded variant: set num_records to the actual buffer size in bytes. + // OOB loads return zero, OOB stores are silently dropped. + func.func private @make_raw_buffer_rsrc_bounded(%base: !sx2, %num_bytes: !s) -> !sx4 { + %c0_stride = arith.constant 0 : i32 + %rsrc = amdgcn.make_buffer_rsrc %base, %num_bytes, %c0_stride, + cache_swizzle = false, swizzle_enable = false, flags = 131072 + : (!sx2, !s, i32) -> !sx4 + return %rsrc : !sx4 + } } diff --git a/third_party/mlir-air b/third_party/mlir-air new file mode 160000 index 000000000..39ee8fc6c --- /dev/null +++ b/third_party/mlir-air @@ -0,0 +1 @@ +Subproject commit 39ee8fc6c9fe45057e2ad86dc3dd078e00e8d132 From e79cbb45f63083137678e6419732b6f15e3b2f5a Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 13/16] Revert dead kittens code, keep only fill_lds_16x64_b MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove store_lds_C_mfma_f32_16x16x16_f16 (AGPR→LDS C tile store) from compute_16x16_f16.mlir — unused since the padded matmul writes directly to global C. Keep fill_lds_16x64_b in lds_16x64_b.mlir as it is needed by the device-side padding path (zero-fill LDS before partial copy). Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- .../kittens/library/compute_16x16_f16.mlir | 76 ------------------- contrib/kittens/library/lds_16x64_b.mlir | 3 - 2 files changed, 79 deletions(-) diff --git a/contrib/kittens/library/compute_16x16_f16.mlir b/contrib/kittens/library/compute_16x16_f16.mlir index 9a78b2475..125e8e938 100644 --- a/contrib/kittens/library/compute_16x16_f16.mlir +++ b/contrib/kittens/library/compute_16x16_f16.mlir @@ -17,7 +17,6 @@ amdgcn.library @kittens_compute_16x16_f16 isa = [#amdgcn.isa] { // From register-init.mlir func.func private @init_agprx4(i32) -> !ax4 - func.func private @alloc_vgpr() -> !amdgcn.vgpr // From indexing.mlir func.func private @mfma_index_C_16x16_f32() -> !index_pair func.func private @mfma_c_16x16_f32_byte_offset(index, index, index, index, index, index, index) -> index @@ -100,79 +99,4 @@ amdgcn.library @kittens_compute_16x16_f16 isa = [#amdgcn.isa] { return } - - //===--------------------------------------------------------------------===// - // C tile store to LDS (AGPR → LDS, row-major 16x16 f32) - //===--------------------------------------------------------------------===// - - // Store a 16x16 f32 C tile from AGPRs to LDS in row-major layout. - // LDS layout: 16 rows × 16 cols × 4 bytes = 1024 bytes, stride = 64 bytes/row. - // Each thread writes 4 f32 values at its MFMA C fragment positions. - func.func private @store_lds_C_mfma_f32_16x16x16_f16(%tile: !rt_C_f32, %lds_base: index) { - %mfma_idx = func.call @mfma_index_C_16x16_f32() : () -> !index_pair - %col, %row_base = aster_utils.struct_extract %mfma_idx ["i", "j"] : !index_pair -> index, index - - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c2 = arith.constant 2 : index - %c3 = arith.constant 3 : index - %c4 = arith.constant 4 : index - %c64 = arith.constant 64 : index - %c0_i32 = arith.constant 0 : i32 - - %a0, %a1, %a2, %a3 = amdgcn.split_register_range %tile : !ax4 - %col_off = arith.muli %col, %c4 : index - - // Manually unrolled: 4 AGPR stores to LDS at consecutive rows. - %row0 = arith.addi %row_base, %c0 : index - %row1 = arith.addi %row_base, %c1 : index - %row2 = arith.addi %row_base, %c2 : index - %row3 = arith.addi %row_base, %c3 : index - - // Row 0 - %r0off = arith.muli %row0, %c64 : index - %b0 = arith.addi %r0off, %col_off : index - %addr0 = arith.addi %lds_base, %b0 : index - %ai0 = arith.index_cast %addr0 : index to i32 - %av0 = lsir.to_reg %ai0 : i32 -> !amdgcn.vgpr - %dv0 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %vv0 = lsir.copy %dv0, %a0 : !amdgcn.vgpr, !a - amdgcn.store ds_write_b32 data %vv0 addr %av0 offset c(%c0_i32) - : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token - - // Row 1 - %r1off = arith.muli %row1, %c64 : index - %b1 = arith.addi %r1off, %col_off : index - %addr1 = arith.addi %lds_base, %b1 : index - %ai1 = arith.index_cast %addr1 : index to i32 - %av1 = lsir.to_reg %ai1 : i32 -> !amdgcn.vgpr - %dv1 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %vv1 = lsir.copy %dv1, %a1 : !amdgcn.vgpr, !a - amdgcn.store ds_write_b32 data %vv1 addr %av1 offset c(%c0_i32) - : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token - - // Row 2 - %r2off = arith.muli %row2, %c64 : index - %b2 = arith.addi %r2off, %col_off : index - %addr2 = arith.addi %lds_base, %b2 : index - %ai2 = arith.index_cast %addr2 : index to i32 - %av2 = lsir.to_reg %ai2 : i32 -> !amdgcn.vgpr - %dv2 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %vv2 = lsir.copy %dv2, %a2 : !amdgcn.vgpr, !a - amdgcn.store ds_write_b32 data %vv2 addr %av2 offset c(%c0_i32) - : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token - - // Row 3 - %r3off = arith.muli %row3, %c64 : index - %b3 = arith.addi %r3off, %col_off : index - %addr3 = arith.addi %lds_base, %b3 : index - %ai3 = arith.index_cast %addr3 : index to i32 - %av3 = lsir.to_reg %ai3 : i32 -> !amdgcn.vgpr - %dv3 = func.call @alloc_vgpr() : () -> !amdgcn.vgpr - %vv3 = lsir.copy %dv3, %a3 : !amdgcn.vgpr, !a - amdgcn.store ds_write_b32 data %vv3 addr %av3 offset c(%c0_i32) - : ins(!amdgcn.vgpr, !amdgcn.vgpr, i32) -> !amdgcn.write_token - - return - } } diff --git a/contrib/kittens/library/lds_16x64_b.mlir b/contrib/kittens/library/lds_16x64_b.mlir index a644e1a8c..6e0704bcd 100644 --- a/contrib/kittens/library/lds_16x64_b.mlir +++ b/contrib/kittens/library/lds_16x64_b.mlir @@ -138,10 +138,7 @@ amdgcn.library @kittens_lds_16x64_b isa = [#amdgcn.isa] { return %future : !future_lds_read } - //===--------------------------------------------------------------------===// // Zero-fill a 16x64_b LDS tile (1024 bytes). - //===--------------------------------------------------------------------===// - // Each thread writes 16 bytes of zeros at its assigned positions. func.func private @fill_lds_16x64_b(%tile_base: index) { %zero = func.call @alloc_vgprx2() : () -> !vx2 From 465d671692d0199d2c8989fe078d6f4902c1ad27 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 14/16] Remove unused PromoteAllocsToFuncArgs pass No longer needed: the upstream single-launch mode in air-split-launch-for-padding shares LDS allocs across scf.if branches instead of duplicating them, and the padded matmul writes directly to global C without temp allocs. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- contrib/mlir-air/CMakeLists.txt | 1 - contrib/mlir-air/lib/Init.cpp | 2 - .../mlir-air/lib/PromoteAllocsToFuncArgs.cpp | 133 ------------------ 3 files changed, 136 deletions(-) delete mode 100644 contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp diff --git a/contrib/mlir-air/CMakeLists.txt b/contrib/mlir-air/CMakeLists.txt index 7808bb1c9..234399296 100644 --- a/contrib/mlir-air/CMakeLists.txt +++ b/contrib/mlir-air/CMakeLists.txt @@ -4,7 +4,6 @@ add_mlir_library(MlirAirLib lib/ConvertMemSpaceToAMDGCN.cpp lib/Init.cpp lib/Pipelines.cpp - lib/PromoteAllocsToFuncArgs.cpp LINK_LIBS PUBLIC AIRConversionPasses diff --git a/contrib/mlir-air/lib/Init.cpp b/contrib/mlir-air/lib/Init.cpp index 3e9903f11..6ba97744a 100644 --- a/contrib/mlir-air/lib/Init.cpp +++ b/contrib/mlir-air/lib/Init.cpp @@ -60,7 +60,6 @@ namespace mlir::aster::mlir_air { std::unique_ptr createAirToAMDGCN(); std::unique_ptr createConvertToAMDGCNLibraryCalls(); std::unique_ptr createConvertMemSpaceToAMDGCN(); -std::unique_ptr createPromoteAllocsToFuncArgs(); void registerPipelines(); void registerAll(DialectRegistry ®istry) { @@ -123,7 +122,6 @@ void registerAll(DialectRegistry ®istry) { registerPass([] { return createAirToAMDGCN(); }); registerPass([] { return createConvertToAMDGCNLibraryCalls(); }); registerPass([] { return createConvertMemSpaceToAMDGCN(); }); - registerPass([] { return createPromoteAllocsToFuncArgs(); }); // mlir-air pipelines. registerPipelines(); diff --git a/contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp b/contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp deleted file mode 100644 index 154db0f72..000000000 --- a/contrib/mlir-air/lib/PromoteAllocsToFuncArgs.cpp +++ /dev/null @@ -1,133 +0,0 @@ -// Copyright 2026 The ASTER Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -//===- PromoteAllocsToFuncArgs.cpp - alloc → function argument promotion --===// - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Pass/Pass.h" - -using namespace mlir; - -namespace { - -static void promoteAllocsInFunc(func::FuncOp funcOp) { - if (!funcOp->hasAttr("gpu.kernel")) - return; - - SmallVector allocsToPromote; - // Walk ALL allocs in the function, including those nested in scf.if - // or other control flow (e.g., from air-split-launch-for-padding). - funcOp.walk([&](memref::AllocOp allocOp) { - if (allocOp.getMemref().getType().getMemorySpace()) - return; - if (!allocOp.getMemref().getType().hasStaticShape()) - return; - allocsToPromote.push_back(allocOp); - }); - - if (allocsToPromote.empty()) - return; - - // Deduplicate allocs with the same type: share a single promoted arg. - // Multiple allocs of the same type (e.g., from split-launch scf.if - // branches) can safely share one buffer since only one branch executes. - DenseMap typeToArgIdx; - SmallVector uniqueAllocs; - SmallVector allocArgMapping; // maps each alloc → arg index - - auto funcTy = funcOp.getFunctionType(); - SmallVector newArgTypes(funcTy.getInputs()); - for (auto allocOp : allocsToPromote) { - auto ty = allocOp.getMemref().getType(); - auto it = typeToArgIdx.find(ty); - if (it == typeToArgIdx.end()) { - unsigned idx = newArgTypes.size(); - typeToArgIdx[ty] = idx; - newArgTypes.push_back(ty); - uniqueAllocs.push_back(allocOp); - allocArgMapping.push_back(idx); - } else { - allocArgMapping.push_back(it->second); - } - } - - funcOp.setFunctionType( - FunctionType::get(funcOp.getContext(), newArgTypes, funcTy.getResults())); - - auto &entryBlock = funcOp.getBody().front(); - // Add one block arg per unique type. - SmallVector newArgs; - for (auto allocOp : uniqueAllocs) { - auto newArg = - entryBlock.addArgument(allocOp.getMemref().getType(), allocOp.getLoc()); - newArgs.push_back(newArg); - } - // Replace each alloc with the corresponding shared arg. - for (unsigned i = 0; i < allocsToPromote.size(); ++i) { - unsigned argIdx = allocArgMapping[i] - funcTy.getNumInputs(); - allocsToPromote[i].getResult().replaceAllUsesWith(newArgs[argIdx]); - } - - // Erase the allocs and their initialization ops (fill, copy into padded - // buffer). The host is responsible for pre-initializing the workspace. - for (auto allocOp : llvm::reverse(allocsToPromote)) { - for (auto user : - llvm::make_early_inc_range(allocOp.getResult().getUsers())) { - if (isa(user)) - user->erase(); - } - allocOp->erase(); - } - - // Erase linalg.map (zero-fill) and memref.copy (padding init copy) - // on the promoted workspace args. The host pre-initializes them. - // Walk ALL ops (including nested in scf.if) to find operations on - // promoted workspace args. - DenseSet promotedArgs(newArgs.begin(), newArgs.end()); - SmallVector toErase; - funcOp.walk([&](Operation *op) { - if (auto mapOp = dyn_cast(op)) { - for (Value out : mapOp.getDpsInits()) - if (promotedArgs.contains(out)) - toErase.push_back(op); - } else if (auto copyOp = dyn_cast(op)) { - // Only erase copies INTO workspace (dst is workspace). - // Keep copies FROM workspace to global (the result copy-back). - Value dst = copyOp.getTarget(); - auto dstSv = dst.getDefiningOp(); - bool dstIsWs = promotedArgs.contains(dst) || - (dstSv && promotedArgs.contains(dstSv.getSource())); - if (dstIsWs) - toErase.push_back(op); - } - }); - for (auto *op : llvm::reverse(toErase)) - op->erase(); -} - -struct PromoteAllocsToFuncArgs - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PromoteAllocsToFuncArgs) - StringRef getArgument() const override { - return "promote-allocs-to-func-args"; - } - StringRef getDescription() const override { - return "Promote function-level memref.alloc to function arguments"; - } - void runOnOperation() override { - getOperation()->walk([](func::FuncOp f) { promoteAllocsInFunc(f); }); - } -}; - -} // namespace - -namespace mlir::aster::mlir_air { -std::unique_ptr createPromoteAllocsToFuncArgs() { - return std::make_unique(); -} -} // namespace mlir::aster::mlir_air From 1b6d448175d5d081c5b4b68b9f6a7d09917f17d6 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 15/16] Add copyright headers to new MLIR test files Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- .../test/air-to-amdgcn-matmul-padded-transform.mlir | 6 ++++++ contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir index 312f5b48f..c7970ef64 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded-transform.mlir @@ -1,3 +1,9 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // Transform sequence for padded matmul with device-side padding. // A/B are 40x64 (actual), NOT padded. C is 48x48 (padded for safe stores). // diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir index cc26c0063..71ec3cac0 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-padded.mlir @@ -1,3 +1,9 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // Padded matmul: actual M=40, N=40, K=64. // Device-side padding via air-split-launch-for-padding. // From e8c694309d52cef6f4a761453fe95a706aff63b0 Mon Sep 17 00:00:00 2001 From: Erwei Wang Date: Wed, 8 Apr 2026 23:21:12 +0000 Subject: [PATCH 16/16] Add missing copyright headers to new files Co-Authored-By: Claude Opus 4.6 Signed-off-by: Erwei Wang --- contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir | 6 ++++++ contrib/mlir-air/test/air-to-amdgcn-matmul.mlir | 6 ++++++ contrib/mlir-air/test/integration/test_air_matmul_e2e.py | 6 ++++++ 3 files changed, 18 insertions(+) diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir index 448abf26a..01a7bd0c0 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul-transform.mlir @@ -1,3 +1,9 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // Transform sequence for 64x64 matmul: tile, pad, bufferize, map to AIR herd. // Adapted from xrt/12 (tile-using-pad, no packing). diff --git a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir index 7f9a90b5b..40d51c5ce 100644 --- a/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir +++ b/contrib/mlir-air/test/air-to-amdgcn-matmul.mlir @@ -1,3 +1,9 @@ +// Copyright 2026 The ASTER Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // RUN: mlir-air-opt %s \ // RUN: --transform-preload-library="transform-library-paths=%p/air-to-amdgcn-matmul-transform.mlir" \ // RUN: --transform-interpreter --canonicalize --cse \ diff --git a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py index d0cb059aa..5d8fdd016 100644 --- a/contrib/mlir-air/test/integration/test_air_matmul_e2e.py +++ b/contrib/mlir-air/test/integration/test_air_matmul_e2e.py @@ -1,3 +1,9 @@ +# Copyright 2026 The ASTER Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + """E2E matmul test exercising the real AIR lowering path. Pipeline: