Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion lib/Dialect/AMDGCN/CodeGen/CodeGenPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Ptr/IR/PtrOps.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Transforms/DialectConversion.h"

Expand Down Expand Up @@ -80,6 +81,26 @@ struct AmdgpuMFMAOpPattern : public OpCodeGenPattern<amdgpu::MFMAOp> {
ConversionPatternRewriter &rewriter) const override;
};

//===----------------------------------------------------------------------===//
// ToElementsOpPattern
//===----------------------------------------------------------------------===//
/// Code generation pattern matching `vector::ToElementsOp`.
///
/// The rewrite itself is defined out of line in
/// `ToElementsOpPattern::matchAndRewrite`, which replaces the op with the
/// results of an `amdgcn::SplitRegisterRangeOp` built from the adaptor's
/// source operand.
struct ToElementsOpPattern : public OpCodeGenPattern<vector::ToElementsOp> {
// Inherit the constructors of the base pattern class.
using OpCodeGenPattern::OpCodeGenPattern;
/// Rewrites `op` using the operands exposed by `adaptor`; returns a
/// `LogicalResult` describing whether the rewrite was applied.
LogicalResult
matchAndRewrite(vector::ToElementsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

//===----------------------------------------------------------------------===//
// FromElementsOpPattern
//===----------------------------------------------------------------------===//
/// Code generation pattern matching `vector::FromElementsOp`.
///
/// The rewrite itself is defined out of line in
/// `FromElementsOpPattern::matchAndRewrite`, which replaces the op with the
/// result of an `amdgcn::MakeRegisterRangeOp` built from the adaptor's
/// element operands.
struct FromElementsOpPattern : public OpCodeGenPattern<vector::FromElementsOp> {
// Inherit the constructors of the base pattern class.
using OpCodeGenPattern::OpCodeGenPattern;
/// Rewrites `op` using the operands exposed by `adaptor`; returns a
/// `LogicalResult` describing whether the rewrite was applied.
LogicalResult
matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -348,6 +369,32 @@ LogicalResult AmdgpuMFMAOpPattern::matchAndRewrite(
return success();
}

//===----------------------------------------------------------------------===//
// ToElementsOpPattern
//===----------------------------------------------------------------------===//

/// Rewrites `vector.to_elements` into an `amdgcn::SplitRegisterRangeOp`.
///
/// The split op is created at the location of `op` from the source operand
/// provided by `adaptor`, and every result of `op` is replaced by the
/// corresponding result of the split. Always reports success.
LogicalResult ToElementsOpPattern::matchAndRewrite(
    vector::ToElementsOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Gather the inputs of the new op up front.
  auto loc = op.getLoc();
  auto source = adaptor.getSource();

  // Split the source and forward each piece to the users of `op`.
  auto pieces =
      amdgcn::SplitRegisterRangeOp::create(rewriter, loc, source).getResults();
  rewriter.replaceOp(op, pieces);
  return success();
}

//===----------------------------------------------------------------------===//
// FromElementsOpPattern
//===----------------------------------------------------------------------===//

/// Rewrites `vector.from_elements` into an `amdgcn::MakeRegisterRangeOp`.
///
/// The new op is created at the location of `op` from the element operands
/// provided by `adaptor`, and its single result replaces the result of `op`.
/// Always reports success.
LogicalResult FromElementsOpPattern::matchAndRewrite(
    vector::FromElementsOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // Gather the inputs of the new op up front.
  auto loc = op.getLoc();
  auto elements = adaptor.getElements();

  // Assemble the elements into one value and substitute it for `op`.
  auto rangeOp = amdgcn::MakeRegisterRangeOp::create(rewriter, loc, elements);
  rewriter.replaceOp(op, rangeOp.getResult());
  return success();
}

//===----------------------------------------------------------------------===//
// Internal functions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -469,12 +516,14 @@ void mlir::aster::amdgcn::populateCodeGenPatterns(CodeGenConverter &converter,
lsir::RegConstraintOp, ptr::LoadOp, ptr::StoreOp>();

target.addIllegalOp<amdgpu::MFMAOp>();
target.addIllegalOp<vector::ToElementsOp, vector::FromElementsOp>();

// Add the patterns.
patterns.add<IDDimOpPattern<aster_utils::ThreadIdOp, amdgcn::ThreadIdOp>,
IDDimOpPattern<aster_utils::BlockIdOp, amdgcn::BlockIdOp>,
IDDimOpPattern<aster_utils::BlockDimOp, amdgcn::BlockDimOp>,
IDDimOpPattern<aster_utils::GridDimOp, amdgcn::GridDimOp>,
PtrLoadOpPattern, PtrStoreOpPattern, PtrAddOpPattern,
AmdgpuMFMAOpPattern>(converter);
AmdgpuMFMAOpPattern, ToElementsOpPattern,
FromElementsOpPattern>(converter);
}
20 changes: 20 additions & 0 deletions test/CodeGen/amdgcn.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,23 @@ func.func private @test_mfma_f32_16x16x16_f16_attrs(%a: vector<4xf16>, %b: vecto
%result = amdgpu.mfma 16x16x16 %a * %b + %c {cbsz = 1 : i32, abid = 1 : i32, blocks = 1 : i32} blgp = bcast_second_32 : vector<4xf16>, vector<4xf16>, vector<4xf32>
return {abi = (!amdgcn.vgpr_range<[? + 2]>, !amdgcn.vgpr_range<[? + 2]>, !amdgcn.vgpr_range<[? + 4]>) -> !amdgcn.vgpr_range<[? + 4]>} %result : vector<4xf32>
}

// Expects vector.to_elements on a vector<4xf32> to become a single
// amdgcn.split_register_range with four results, all four being returned.
// CHECK-LABEL: func.func private @test_to_elements(
// CHECK: %[[SPLIT:.*]]:4 = amdgcn.split_register_range %{{.*}}
// CHECK: return %[[SPLIT]]#0, %[[SPLIT]]#1, %[[SPLIT]]#2, %[[SPLIT]]#3
// CHECK: }
func.func private @test_to_elements(%v: vector<4xf32>) -> (f32, f32, f32, f32)
attributes {abi = (!amdgcn.vgpr_range<[? + 4]>) -> (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr)} {
%0:4 = vector.to_elements %v : vector<4xf32>
return {abi = (!amdgcn.vgpr_range<[? + 4]>) -> (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr)} %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
}

// Expects vector.from_elements of four f32 values to become a single
// amdgcn.make_register_range taking four operands, whose result is returned.
// CHECK-LABEL: func.func private @test_from_elements(
// CHECK: %[[MAKE:.*]] = amdgcn.make_register_range %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
// CHECK: return %[[MAKE]]
// CHECK: }
func.func private @test_from_elements(%a: f32, %b: f32, %c: f32, %d: f32) -> vector<4xf32>
attributes {abi = (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr) -> !amdgcn.vgpr_range<[? + 4]>} {
%v = vector.from_elements %a, %b, %c, %d : vector<4xf32>
return {abi = (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr) -> !amdgcn.vgpr_range<[? + 4]>} %v : vector<4xf32>
}