diff --git a/lib/Dialect/AMDGCN/CodeGen/CodeGenPatterns.cpp b/lib/Dialect/AMDGCN/CodeGen/CodeGenPatterns.cpp index d5f6a6b5..183804ed 100644 --- a/lib/Dialect/AMDGCN/CodeGen/CodeGenPatterns.cpp +++ b/lib/Dialect/AMDGCN/CodeGen/CodeGenPatterns.cpp @@ -21,6 +21,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" @@ -80,6 +81,26 @@ struct AmdgpuMFMAOpPattern : public OpCodeGenPattern { ConversionPatternRewriter &rewriter) const override; }; +//===----------------------------------------------------------------------===// +// ToElementsOpPattern +//===----------------------------------------------------------------------===// +struct ToElementsOpPattern : public OpCodeGenPattern { + using OpCodeGenPattern::OpCodeGenPattern; + LogicalResult + matchAndRewrite(vector::ToElementsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// FromElementsOpPattern +//===----------------------------------------------------------------------===// +struct FromElementsOpPattern : public OpCodeGenPattern { + using OpCodeGenPattern::OpCodeGenPattern; + LogicalResult + matchAndRewrite(vector::FromElementsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -348,6 +369,32 @@ LogicalResult AmdgpuMFMAOpPattern::matchAndRewrite( return success(); } +//===----------------------------------------------------------------------===// +// ToElementsOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult ToElementsOpPattern::matchAndRewrite( + vector::ToElementsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto splitOp = amdgcn::SplitRegisterRangeOp::create( + rewriter, op.getLoc(), adaptor.getSource()); + rewriter.replaceOp(op, splitOp.getResults()); + return success(); +} + +//===----------------------------------------------------------------------===// +// FromElementsOpPattern +//===----------------------------------------------------------------------===// + +LogicalResult FromElementsOpPattern::matchAndRewrite( + vector::FromElementsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto makeOp = amdgcn::MakeRegisterRangeOp::create( + rewriter, op.getLoc(), adaptor.getElements()); + rewriter.replaceOp(op, makeOp.getResult()); + return success(); +} + //===----------------------------------------------------------------------===// // Internal functions //===----------------------------------------------------------------------===// @@ -469,6 +516,7 @@ void mlir::aster::amdgcn::populateCodeGenPatterns(CodeGenConverter &converter, lsir::RegConstraintOp, ptr::LoadOp, ptr::StoreOp>(); target.addIllegalOp(); + target.addIllegalOp(); // Add the patterns. patterns.add, @@ -476,5 +524,6 @@ void mlir::aster::amdgcn::populateCodeGenPatterns(CodeGenConverter &converter, IDDimOpPattern, IDDimOpPattern, PtrLoadOpPattern, PtrStoreOpPattern, PtrAddOpPattern, - AmdgpuMFMAOpPattern>(converter); + AmdgpuMFMAOpPattern, ToElementsOpPattern, + FromElementsOpPattern>(converter); } diff --git a/test/CodeGen/amdgcn.mlir b/test/CodeGen/amdgcn.mlir index 14b8d594..70cc6980 100644 --- a/test/CodeGen/amdgcn.mlir +++ b/test/CodeGen/amdgcn.mlir @@ -76,3 +76,23 @@ func.func private @test_mfma_f32_16x16x16_f16_attrs(%a: vector<4xf16>, %b: vecto %result = amdgpu.mfma 16x16x16 %a * %b + %c {cbsz = 1 : i32, abid = 1 : i32, blocks = 1 : i32} blgp = bcast_second_32 : vector<4xf16>, vector<4xf16>, vector<4xf32> return {abi = (!amdgcn.vgpr_range<[? + 2]>, !amdgcn.vgpr_range<[? + 2]>, !amdgcn.vgpr_range<[? + 4]>) -> !amdgcn.vgpr_range<[? + 4]>} %result : vector<4xf32> } + +// CHECK-LABEL: func.func private @test_to_elements( +// CHECK: %[[SPLIT:.*]]:4 = amdgcn.split_register_range %{{.*}} +// CHECK: return %[[SPLIT]]#0, %[[SPLIT]]#1, %[[SPLIT]]#2, %[[SPLIT]]#3 +// CHECK: } +func.func private @test_to_elements(%v: vector<4xf32>) -> (f32, f32, f32, f32) + attributes {abi = (!amdgcn.vgpr_range<[? + 4]>) -> (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr)} { + %0:4 = vector.to_elements %v : vector<4xf32> + return {abi = (!amdgcn.vgpr_range<[? + 4]>) -> (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr)} %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 +} + +// CHECK-LABEL: func.func private @test_from_elements( +// CHECK: %[[MAKE:.*]] = amdgcn.make_register_range %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} +// CHECK: return %[[MAKE]] +// CHECK: } +func.func private @test_from_elements(%a: f32, %b: f32, %c: f32, %d: f32) -> vector<4xf32> + attributes {abi = (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr) -> !amdgcn.vgpr_range<[? + 4]>} { + %v = vector.from_elements %a, %b, %c, %d : vector<4xf32> + return {abi = (!amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr, !amdgcn.vgpr) -> !amdgcn.vgpr_range<[? + 4]>} %v : vector<4xf32> +}