16 commits
a7f454a
Add AIR-to-AMDGCN lowering passes for mlir-air contrib
erwei-xilinx Apr 8, 2026
9a3807b
[mlir-air] Wire DMA→LDS lowering; fix K-accumulation and multi-wavefr…
erwei-xilinx Apr 8, 2026
9a8bece
[mlir-air] Clean up link deps: use CMake transitive resolution
erwei-xilinx Apr 8, 2026
e82320b
[mlir-air] Clean up transform script and ConvertLinalgToAMDGCN
erwei-xilinx Apr 8, 2026
d20501f
[mlir-air] Rename ConvertLinalgToAMDGCN → ConvertToAMDGCNLibraryCalls
erwei-xilinx Apr 8, 2026
deebce2
[mlir-air] Remove dead channel handling from ConvertToAMDGCNLibraryCalls
erwei-xilinx Apr 8, 2026
21b05bd
[mlir-air] Refactor ConvertToAMDGCNLibraryCalls to use rewrite patterns
erwei-xilinx Apr 8, 2026
3244c68
[mlir-air] Add padded matmul E2E test (40x40, host-padded to 48x48)
erwei-xilinx Apr 8, 2026
d20bc7e
[mlir-air] WIP: Compiler-level padding with LDS C accumulation
erwei-xilinx Apr 8, 2026
bf652bb
[mlir-air] Fix PreloadLibrary infinite loop + WIP padded matmul
erwei-xilinx Apr 8, 2026
9ca7081
[mlir-air] Use pad_tiling_interface for padded matmul (40→48)
erwei-xilinx Apr 8, 2026
b3e494e
[mlir-air] Device-side padded matmul via air-split-launch-for-padding
erwei-xilinx Apr 8, 2026
e79cbb4
Revert dead kittens code, keep only fill_lds_16x64_b
erwei-xilinx Apr 8, 2026
465d671
Remove unused PromoteAllocsToFuncArgs pass
erwei-xilinx Apr 8, 2026
1b6d448
Add copyright headers to new MLIR test files
erwei-xilinx Apr 8, 2026
e8c6943
Add missing copyright headers to new files
erwei-xilinx Apr 8, 2026
13 changes: 13 additions & 0 deletions CMakeLists.txt
@@ -144,5 +144,18 @@ add_subdirectory(test)
option(ASTER_ENABLE_MLIR_AIR "Build mlir-air contrib (linalg->AMDGCN)" OFF)
if(ASTER_ENABLE_MLIR_AIR)
message(STATUS "mlir-air enabled")
# Upstream Xilinx/mlir-air: AIR dialect + conversion passes (no AIE, no GPU backend).
set(AIR_ENABLE_AIE OFF CACHE BOOL "" FORCE)
set(AIR_ENABLE_GPU OFF CACHE BOOL "" FORCE)
# Stub targets expected by upstream mlir-air CMakeLists.
add_custom_target(air-headers)
# Make upstream AIR .td files and generated headers visible to mlir-tblgen
# and C++ compilation. Must come before add_subdirectory so the directory
# property is set when mlir_tablegen() calls get_directory_property().
include_directories(${CMAKE_SOURCE_DIR}/third_party/mlir-air/mlir/include)
include_directories(${CMAKE_BINARY_DIR}/third_party/mlir-air/mlir/include)
# Only build AIR lib+include, not its test suite (check-all not available here).
add_subdirectory(third_party/mlir-air/mlir/include)
add_subdirectory(third_party/mlir-air/mlir/lib)
add_subdirectory(contrib/mlir-air)
endif()
15 changes: 15 additions & 0 deletions contrib/kittens/library/lds_16x64_b.mlir
@@ -138,4 +138,19 @@ amdgcn.library @kittens_lds_16x64_b isa = [#amdgcn.isa<cdna3>] {
return %future : !future_lds_read
}

// Zero-fill a 16x64_b LDS tile (1024 bytes).
// Each thread writes 16 bytes of zeros at its assigned positions.
func.func private @fill_lds_16x64_b(%tile_base: index) {
%zero = func.call @alloc_vgprx2() : () -> !vx2
%addr_lo, %addr_hi = func.call @compute_lds_write_addrs_16x64_b(%tile_base)
: (index) -> (index, index)
%t0 = func.call @write_vx2_to_lds_at(%zero, %addr_lo)
: (!vx2, index) -> !future_lds_write
%t1 = func.call @write_vx2_to_lds_at(%zero, %addr_hi)
: (!vx2, index) -> !future_lds_write
amdgcn.wait deps %t0 : !future_lds_write
amdgcn.wait deps %t1 : !future_lds_write
return
}
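
// Example use (hypothetical caller, shown for illustration only): clear the
// tile at LDS offset 0 before accumulating partial results into it.
//
//   %c0 = arith.constant 0 : index
//   func.call @fill_lds_16x64_b(%c0) : (index) -> ()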

}
14 changes: 13 additions & 1 deletion contrib/mlir-air/CMakeLists.txt
@@ -1,9 +1,20 @@
add_mlir_library(MlirAirLib
lib/ConvertLinalgToAMDGCN.cpp
lib/AirToAMDGCN.cpp
lib/ConvertToAMDGCNLibraryCalls.cpp
lib/ConvertMemSpaceToAMDGCN.cpp
lib/Init.cpp
lib/Pipelines.cpp

LINK_LIBS PUBLIC
AIRConversionPasses
AIRDialect
MLIRBufferizationDialect
MLIRBufferizationTransformOps
MLIRBufferizationTransforms
MLIRTensorInferTypeOpInterfaceImpl
MLIRTensorTransforms
AIRTransformOps
AIRTransformPasses
ASTERInit
AsterTransforms
AsterCodeGen
@@ -13,6 +24,7 @@ add_mlir_library(MlirAirLib
MLIRAffineDialect
MLIRAffineTransforms
MLIRFuncDialect
MLIRGPUDialect
MLIRLinalgDialect
MLIRLinalgTransformOps
MLIRLinalgTransforms
324 changes: 324 additions & 0 deletions contrib/mlir-air/lib/AirToAMDGCN.cpp
@@ -0,0 +1,324 @@
// Copyright 2026 The ASTER Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

//===- AirToAMDGCN.cpp - Lower AIR hierarchy ops to AMDGCN IR ------------===//
//
// Lowers air.launch, air.segment, air.herd, air.execute, and air.wait_all
// to flat AMDGCN-compatible IR:
//
// air.launch IDs -> gpu.block_id (workgroup IDs)
// air.herd tile IDs -> gpu.thread_id / 64 (wavefront index within workgroup)
// air.segment -> inline body (transparent wrapper)
// air.execute -> inline body (strip async)
// air.wait_all -> erase
//
// air.channel.put/get are not lowered by this pass; any that are present
// (e.g. after air-dma-to-channel) only have their async dependencies stripped.
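//
// For example, a herd tile id used in the body is replaced by the wavefront
// index derived from the thread id; the generated ops are (a sketch, value
// names illustrative):
//
//   %tid = gpu.thread_id x
//   %c64 = arith.constant 64 : index
//   %wf  = arith.divui %tid, %c64 : index
//
// Launch and segment ids are replaced by gpu.block_id in the same way.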
//===----------------------------------------------------------------------===//

#include "air/Dialect/AIR/AIRDialect.h"
#include "aster/Interfaces/ModuleOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

/// Return the gpu::Dimension corresponding to an integer index.
static gpu::Dimension dimFromIndex(unsigned i) {
switch (i) {
case 0:
return gpu::Dimension::x;
case 1:
return gpu::Dimension::y;
case 2:
return gpu::Dimension::z;
default:
llvm_unreachable("invalid dimension index");
}
}

/// Clone all ops from `src` into `builder`'s insertion point, applying
/// `mapping`. Ops whose type matches any of `SkipOps...` are skipped.
template <typename... SkipOps>
static void cloneBodyOps(OpBuilder &builder, Block &src, IRMapping &mapping) {
for (auto &op : src.getOperations()) {
if ((isa<SkipOps>(op) || ...))
continue;
builder.clone(op, mapping);
}
}

/// Strip async_dependencies from a channel op and drop its async_token result.
static void stripAsyncFromChannelOp(Operation *op) {
if (auto asyncOp = dyn_cast<xilinx::air::AsyncOpInterface>(op)) {
// Remove all async dependency operands.
while (asyncOp.getAsyncDependencies().size() > 0)
asyncOp.eraseAsyncDependency(0);
// If the op has an async_token result, replace uses and drop it.
if (auto token = asyncOp.getAsyncToken()) {
token.replaceAllUsesWith(Value());
// We can't actually remove a result from an existing op, but since all
// token users will be erased (wait_all, other async deps), replacing
// with null is sufficient — dead token uses are cleaned up below.
}
}
}

// ---------------------------------------------------------------------------
// Pass
// ---------------------------------------------------------------------------

struct AirToAMDGCN
: public PassWrapper<AirToAMDGCN,
InterfacePass<aster::ModuleOpInterface>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AirToAMDGCN)
StringRef getArgument() const override { return "air-to-amdgcn"; }
StringRef getDescription() const override {
return "Lower AIR hierarchy ops to AMDGCN-compatible IR";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<gpu::GPUDialect>();
}

void runOnOperation() override {
Operation *moduleOp = getOperation();
OpBuilder builder(moduleOp->getContext());

// -----------------------------------------------------------------------
// Phase 1: Strip async tokens.
// -----------------------------------------------------------------------

// Strip async from channel ops (they survive this pass).
moduleOp->walk([&](xilinx::air::ChannelPutOp op) {
stripAsyncFromChannelOp(op);
});
moduleOp->walk([&](xilinx::air::ChannelGetOp op) {
stripAsyncFromChannelOp(op);
});

// Inline air.execute: splice body, replace results, erase.
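// For example (a sketch of the air syntax, value names illustrative):
//   %tok, %val = air.execute -> (index) {
//     %x = arith.addi %a, %b : index
//     air.execute_terminator %x : index
//   }
// becomes just the body op, with %val rewired to %x; %tok's uses are nulled
// out and its users (air.wait_all, async dependency lists) are removed by the
// surrounding phases.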
SmallVector<xilinx::air::ExecuteOp> executes;
moduleOp->walk([&](xilinx::air::ExecuteOp op) { executes.push_back(op); });
for (auto execOp : executes) {
Block &body = execOp.getBody();
auto terminator =
cast<xilinx::air::ExecuteTerminatorOp>(body.getTerminator());

// Replace non-token results (results 1..N) with terminator operands.
for (unsigned i = 0; i < terminator.getNumOperands(); ++i)
execOp.getResult(i + 1).replaceAllUsesWith(terminator.getOperand(i));

// Replace async token result (result 0) — users will be erased.
execOp.getResult(0).replaceAllUsesWith(Value());

// Splice body ops (except terminator) before the execute op.
auto &parentOps = execOp->getBlock()->getOperations();
auto &bodyOps = body.getOperations();
auto beforeTerminator = terminator->getIterator();
parentOps.splice(execOp->getIterator(), bodyOps, bodyOps.begin(),
beforeTerminator);

terminator->erase();
execOp->erase();
}

// Erase air.wait_all ops.
SmallVector<xilinx::air::WaitAllOp> waitAlls;
moduleOp->walk(
[&](xilinx::air::WaitAllOp op) { waitAlls.push_back(op); });
for (auto op : waitAlls) {
if (auto token = op.getAsyncToken())
token.replaceAllUsesWith(Value());
op->erase();
}

// -----------------------------------------------------------------------
// Phase 2: Inline hierarchy ops (inside-out).
// -----------------------------------------------------------------------

// --- air.herd -> wavefront index (gpu.thread_id / wavefront_size) ---
// Each herd tile is a wavefront (64 threads cooperating collectively).
// Herd tile ID = wavefront index within the workgroup.
SmallVector<xilinx::air::HerdOp> herds;
moduleOp->walk([&](xilinx::air::HerdOp op) { herds.push_back(op); });
for (auto herd : herds) {
builder.setInsertionPoint(herd);
Block &body = herd.getBody().front();
unsigned numDims = herd.getNumDims();
IRMapping mapping;

// Map tile IDs to wavefront index = gpu.thread_id / 64.
auto ids = herd.getIds();
Value wavefrontSize = arith::ConstantIndexOp::create(
builder, herd.getLoc(), 64);
for (unsigned i = 0; i < numDims; ++i) {
Value threadId = gpu::ThreadIdOp::create(builder, herd.getLoc(),
dimFromIndex(i));
Value wavefrontId = arith::DivUIOp::create(builder, herd.getLoc(),
threadId, wavefrontSize);
mapping.map(ids[i], wavefrontId);
}

// Map tile sizes to size operands.
auto sizeArgs = herd.getSize();
auto sizeOperands = herd.getSizeOperands();
for (unsigned i = 0; i < numDims; ++i)
mapping.map(sizeArgs[i], sizeOperands[i]);

// Map kernel arguments to kernel operands.
auto kernelArgs = herd.getKernelArguments();
auto kernelOperands = herd.getKernelOperands();
for (unsigned i = 0; i < kernelArgs.size(); ++i)
mapping.map(kernelArgs[i], kernelOperands[i]);

// Clone body.
cloneBodyOps<xilinx::air::HerdTerminatorOp>(builder, body, mapping);

// Replace async token if present.
if (auto token = herd.getAsyncToken())
token.replaceAllUsesWith(Value());
herd->erase();
}

// --- air.segment -> inline ---
SmallVector<xilinx::air::SegmentOp> segments;
moduleOp->walk(
[&](xilinx::air::SegmentOp op) { segments.push_back(op); });
for (auto segment : segments) {
builder.setInsertionPoint(segment);
Block &body = segment.getBody().front();
unsigned numDims = segment.getNumDims();
IRMapping mapping;

// Map segment IDs (if any) to gpu.block_id ops.
auto ids = segment.getIds();
auto sizeArgs = segment.getSize();
auto sizeOperands = segment.getSizeOperands();
for (unsigned i = 0; i < numDims; ++i) {
Value blockId = gpu::BlockIdOp::create(builder, segment.getLoc(),
dimFromIndex(i));
mapping.map(ids[i], blockId);
mapping.map(sizeArgs[i], sizeOperands[i]);
}

// Map kernel arguments to kernel operands.
auto kernelArgs = segment.getKernelArguments();
auto kernelOperands = segment.getKernelOperands();
for (unsigned i = 0; i < kernelArgs.size(); ++i)
mapping.map(kernelArgs[i], kernelOperands[i]);

cloneBodyOps<xilinx::air::SegmentTerminatorOp>(builder, body, mapping);

if (auto token = segment.getAsyncToken())
token.replaceAllUsesWith(Value());
segment->erase();
}

// --- air.launch -> gpu.block_id ---
SmallVector<xilinx::air::LaunchOp> launches;
moduleOp->walk(
[&](xilinx::air::LaunchOp op) { launches.push_back(op); });
for (auto launch : launches) {
builder.setInsertionPoint(launch);
Block &body = launch.getBody().front();
unsigned numDims = launch.getNumDims();
IRMapping mapping;

// Map launch IDs to gpu.block_id ops.
auto ids = launch.getIds();
auto sizeArgs = launch.getSize();
auto sizeOperands = launch.getSizeOperands();
for (unsigned i = 0; i < numDims; ++i) {
Value blockId = gpu::BlockIdOp::create(builder, launch.getLoc(),
dimFromIndex(i));
mapping.map(ids[i], blockId);
mapping.map(sizeArgs[i], sizeOperands[i]);
}

// Map kernel arguments to kernel operands.
auto kernelArgs = launch.getKernelArguments();
auto kernelOperands = launch.getKernelOperands();
for (unsigned i = 0; i < kernelArgs.size(); ++i)
mapping.map(kernelArgs[i], kernelOperands[i]);

cloneBodyOps<xilinx::air::LaunchTerminatorOp>(builder, body, mapping);

if (auto token = launch.getAsyncToken())
token.replaceAllUsesWith(Value());
launch->erase();
}

// --- scf.parallel -> wavefront ID inlining ---
// After air-dma-to-channel, hoisted channel.put/get ops are wrapped in
// scf.parallel loops that iterate over the herd tile space. These must
// be inlined: replace induction variables with gpu.thread_id / 64
// (wavefront index), then splice body ops in place of the parallel.
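// For example (a sketch; channel name and operand lists are illustrative):
//   scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
//     air.channel.put @chan[%i, %j] (...)
//     scf.reduce
//   }
// is flattened so that %i and %j become functions of the wavefront index and
// the body ops are spliced in place of the loop.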
SmallVector<scf::ParallelOp> parallels;
moduleOp->walk(
[&](scf::ParallelOp op) { parallels.push_back(op); });
for (auto parallel : parallels) {
builder.setInsertionPoint(parallel);
Location loc = parallel.getLoc();
Block &body = parallel.getRegion().front();

// Compute wavefront ID for each induction variable.
// All dimensions are derived from thread_id_x (1D wavefront layout).
Value wavefrontSize =
arith::ConstantIndexOp::create(builder, loc, 64);
Value threadIdX =
gpu::ThreadIdOp::create(builder, loc, gpu::Dimension::x);
Value wavefrontId =
arith::DivUIOp::create(builder, loc, threadIdX, wavefrontSize);

auto ivs = parallel.getInductionVars();
if (ivs.size() == 1) {
// 1D parallel: iv = wavefrontId directly.
ivs[0].replaceAllUsesWith(wavefrontId);
} else if (ivs.size() == 2) {
// 2D parallel: decompose wavefrontId into (x, y).
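// e.g. with upper bound (2, 2), wavefrontId 0..3 maps to
// (x, y) = (0,0), (1,0), (0,1), (1,1).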
auto ub = parallel.getUpperBound();
Value ubX = ub[0];
Value ivX = arith::RemUIOp::create(builder, loc, wavefrontId, ubX);
Value ivY = arith::DivUIOp::create(builder, loc, wavefrontId, ubX);
ivs[0].replaceAllUsesWith(ivX);
ivs[1].replaceAllUsesWith(ivY);
}

// Replace init values / results. scf.parallel with reduce produces
// results from scf.reduce. Replace with init values (no actual
// reduction needed — each wavefront runs independently).
for (unsigned i = 0; i < parallel.getNumResults(); ++i)
parallel.getResult(i).replaceAllUsesWith(parallel.getInitVals()[i]);

// Splice body ops (skip the yield/reduce terminator) before parallel.
auto &parentOps = parallel->getBlock()->getOperations();
auto &bodyOps = body.getOperations();
// Locate the terminator; the splice below stops just before it.
auto termIt = body.getTerminator()->getIterator();
parentOps.splice(parallel->getIterator(), bodyOps, bodyOps.begin(),
termIt);

parallel->erase();
}
}
};

} // namespace

namespace mlir::aster::mlir_air {
std::unique_ptr<Pass> createAirToAMDGCN() {
return std::make_unique<AirToAMDGCN>();
}
} // namespace mlir::aster::mlir_air