Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions include/PTO/IR/PTOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1445,33 +1445,39 @@ def TPushToAicOp : PTO_TOp<"tpush_to_aic"> {
}];
}

def TPopFromAicOp : PTO_TOp<"tpop_from_aic"> {
def TPopFromAicOp : PTO_TOp<"tpop_from_aic", [AttrSizedOperandSegments]> {
let summary = "Frontend C2V consumer pop in Vector kernel";

let arguments = (ins
Optional<Index>:$valid_row,
Optional<Index>:$valid_col,
I8Attr:$split
);

let results = (outs PTODpsType:$tile);
let hasVerifier = 1;

let assemblyFormat = [{
`{` `split` `=` $split `}` attr-dict `->` qualified(type($tile))
(`(` $valid_row^ `,` $valid_col `)`)? `{` `split` `=` $split `}` attr-dict
`->` qualified(type($tile))
}];
}

def TPopFromAivOp : PTO_TOp<"tpop_from_aiv"> {
def TPopFromAivOp : PTO_TOp<"tpop_from_aiv", [AttrSizedOperandSegments]> {
let summary = "Frontend V2C consumer pop in Cube kernel";

let arguments = (ins
Optional<Index>:$valid_row,
Optional<Index>:$valid_col,
I8Attr:$split
);

let results = (outs PTODpsType:$tile);
let hasVerifier = 1;

let assemblyFormat = [{
`{` `split` `=` $split `}` attr-dict `->` qualified(type($tile))
(`(` $valid_row^ `,` $valid_col `)`)? `{` `split` `=` $split `}` attr-dict
`->` qualified(type($tile))
}];
}

Expand Down
35 changes: 30 additions & 5 deletions lib/PTO/IR/PTO.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6994,7 +6994,8 @@ static bool isLocallyBoundTileSource(Value value) {
if (!value || isa<BlockArgument>(value))
return false;

if (isa<AllocTileOp, BindTileOp, PointerCastOp>(value.getDefiningOp()))
if (isa<AllocTileOp, DeclareTileOp, BindTileOp, PointerCastOp>(
value.getDefiningOp()))
return true;

if (auto bitcast = value.getDefiningOp<BitcastOp>())
Expand Down Expand Up @@ -10334,6 +10335,32 @@ static LogicalResult verifyFrontendSplitOp(Operation *op,
return verifySplitAttr(op, split);
}

/// Shared verifier body for the frontend pop ops.
///
/// First delegates to verifyFrontendSplitOp with the op's split attribute,
/// the expected kernel kind and the kernel name. It then validates the
/// optional valid_row / valid_col operands:
///   * either both operands are present or both are absent;
///   * when present, the tile result must be a TileBufType whose
///     hasDynamicValid() query returns true.
///
/// \param op          the pop op being verified.
/// \param expected    kernel kind forwarded to verifyFrontendSplitOp.
/// \param kernelName  kernel name forwarded to verifyFrontendSplitOp.
/// \returns success() if every check passes; otherwise failure, with an op
///          error emitted for the operand/result checks performed here.
template <typename FrontendPopOpT>
static LogicalResult verifyFrontendPopOp(FrontendPopOpT op,
                                         FunctionKernelKind expected,
                                         StringRef kernelName) {
  LogicalResult splitStatus = verifyFrontendSplitOp(
      op.getOperation(), expected, kernelName, op.getSplit());
  if (failed(splitStatus))
    return failure();

  Value rowOperand = op.getValidRow();
  Value colOperand = op.getValidCol();

  // Exactly one of the two operands being present is malformed.
  if (static_cast<bool>(rowOperand) != static_cast<bool>(colOperand))
    return op.emitOpError(
        "expects valid_row and valid_col operands to be provided together");

  // Past the check above both operands share the same presence, so testing
  // one of them is enough: nothing more to verify when neither is given.
  if (!rowOperand)
    return success();

  auto tileBufTy = dyn_cast<TileBufType>(op.getTile().getType());
  if (!tileBufTy)
    return op.emitOpError(
        "expects tile result to be !pto.tile_buf when valid_row/valid_col operands are provided");

  if (tileBufTy.hasDynamicValid())
    return success();

  return op.emitOpError(
      "expects tile result to have dynamic validShape (?, ?) when valid_row/valid_col operands are provided");
}

static LogicalResult verifyPipeShape(Operation *op, int8_t dirMask, int32_t slotSize,
int32_t slotNum,
std::optional<int32_t> flagBase) {
Expand Down Expand Up @@ -10468,13 +10495,11 @@ LogicalResult TPushToAicOp::verify() {
}

LogicalResult TPopFromAicOp::verify() {
return verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Vector,
"vector", getSplit());
return verifyFrontendPopOp(*this, FunctionKernelKind::Vector, "vector");
}

LogicalResult TPopFromAivOp::verify() {
return verifyFrontendSplitOp(getOperation(), FunctionKernelKind::Cube,
"cube", getSplit());
return verifyFrontendPopOp(*this, FunctionKernelKind::Cube, "cube");
}

LogicalResult TFreeFromAicOp::verify() {
Expand Down
8 changes: 8 additions & 0 deletions lib/PTO/Transforms/PTOLowerFrontendPipeOpsPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,10 @@ static LogicalResult lowerFrontendDataOps(func::FuncOp funcOp,
}
auto decl = rewriter.create<DeclareTileOp>(pop.getLoc(),
pop.getTile().getType());
if (pop.getValidRow() && pop.getValidCol()) {
rewriter.create<SetValidShapeOp>(pop.getLoc(), decl.getTile(),
pop.getValidRow(), pop.getValidCol());
}
rewriter.create<TPopOp>(pop.getLoc(), decl.getTile(), handles.c2vPipe,
pop.getSplitAttr());
rewriter.replaceOp(pop, decl.getTile());
Expand All @@ -270,6 +274,10 @@ static LogicalResult lowerFrontendDataOps(func::FuncOp funcOp,
}
auto decl = rewriter.create<DeclareTileOp>(pop.getLoc(),
pop.getTile().getType());
if (pop.getValidRow() && pop.getValidCol()) {
rewriter.create<SetValidShapeOp>(pop.getLoc(), decl.getTile(),
pop.getValidRow(), pop.getValidCol());
}
rewriter.create<TPopOp>(pop.getLoc(), decl.getTile(), handles.v2cPipe,
pop.getSplitAttr());
rewriter.replaceOp(pop, decl.getTile());
Expand Down
20 changes: 19 additions & 1 deletion lib/PTO/Transforms/PTOToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9123,10 +9123,28 @@ struct PTOBindTileToEmitC : public OpConversionPattern<pto::BindTileOp> {
};

if (op.getSource().getDefiningOp<pto::DeclareTileMemRefOp>()) {
auto hasFollowingSetValidShape = [&]() {
for (Operation *user : op->getUsers()) {
auto setValidShape = dyn_cast<pto::SetValidShapeOp>(user);
if (!setValidShape)
continue;
if (setValidShape.getSource() != op.getResult())
continue;
return true;
}
return false;
};

FailureOr<TileBuildSpec> tileSpec = buildTileSpec();
if (failed(tileSpec))
return failure();
rewriter.replaceOp(op, buildTileValue(*tileSpec));
TileBuildSpec declSpec = *tileSpec;
if (op->hasAttr(kForceDynamicValidShapeAttrName) &&
hasFollowingSetValidShape()) {
declSpec.useConstructor = false;
declSpec.constructorArgs.clear();
}
rewriter.replaceOp(op, buildTileValue(declSpec));
return success();
}

Expand Down
59 changes: 59 additions & 0 deletions test/basic/tpush_tpop_dynamic_validshape_a5.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5

// Positive case: both frontend pop ops receive explicit (valid_row, valid_col)
// index operands and produce a tile_buf whose v_row / v_col are dynamic (?).
// The FileCheck lines after this module expect, in each kernel, a Tile
// declaration followed by a SetValidShape call and then the TPOP call.
module {
// Cube-side kernel: takes the two valid-shape values as function arguments.
func.func @cube_kernel(%vr: index, %vc: index)
attributes {pto.kernel_kind = #pto.kernel_kind<cube>} {
%v2c_local = pto.reserve_buffer {
name = "v2c_fifo",
size = 4096,
location = #pto.address_space<mat>,
auto = true
} -> i32
%c2v_import = pto.import_reserved_buffer {
name = "c2v_fifo",
peer_func = @vector_kernel
} -> i32
pto.aic_initialize_pipe {dir_mask = 3, slot_size = 1024}
(c2v_consumer_buf = %c2v_import : i32,
v2c_consumer_buf = %v2c_local : i32)

// Pop with explicit valid-shape operands; result type has v_row=?, v_col=?.
%mat_tile = pto.tpop_from_aiv(%vr, %vc) {split = 2}
-> !pto.tile_buf<loc=mat, dtype=f32, rows=16, cols=64, v_row=?, v_col=?, blayout=col_major, slayout=row_major, fractal=512, pad=0>
pto.tfree_from_aiv {split = 2}
return
}

// Vector-side kernel: mirror of the cube kernel using the c2v direction.
func.func @vector_kernel(%vr: index, %vc: index)
attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
%c2v_local = pto.reserve_buffer {
name = "c2v_fifo",
size = 4096,
location = #pto.address_space<vec>,
auto = true
} -> i32
%v2c_import = pto.import_reserved_buffer {
name = "v2c_fifo",
peer_func = @cube_kernel
} -> i32
pto.aiv_initialize_pipe {dir_mask = 3, slot_size = 1024}
(c2v_consumer_buf = %c2v_local : i32,
v2c_consumer_buf = %v2c_import : i32)

// Pop with explicit valid-shape operands; result type has v_row=?, v_col=?.
%recv_tile = pto.tpop_from_aic(%vr, %vc) {split = 2}
-> !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=64, v_row=?, v_col=?, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tfree_from_aic {split = 2}
return
}
}

// A5-LABEL: AICORE void cube_kernel(
// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4>(
// A5: Tile<TileType::Mat, float, 16, 64, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> [[CUBE_TILE:v[0-9]+]];
// A5: [[CUBE_TILE]].SetValidShape({{v[0-9]+}}, {{v[0-9]+}});
// A5: TPOP<TPipe<0, Direction::DIR_BOTH, 1024, 4>, Tile<TileType::Mat, float, 16, 64, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>, TileSplitAxis::TILE_LEFT_RIGHT>({{v[0-9]+}}, [[CUBE_TILE]]);

// A5-LABEL: AICORE void vector_kernel(
// A5: auto {{v[0-9]+}} = TPipe<0, Direction::DIR_BOTH, 1024, 4>(
// A5: Tile<TileType::Vec, float, 16, 64, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null> [[VEC_TILE:v[0-9]+]];
// A5: [[VEC_TILE]].SetValidShape({{v[0-9]+}}, {{v[0-9]+}});
// A5: TPOP<TPipe<0, Direction::DIR_BOTH, 1024, 4>, Tile<TileType::Vec, float, 16, 64, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>, TileSplitAxis::TILE_LEFT_RIGHT>({{v[0-9]+}}, [[VEC_TILE]]);
55 changes: 55 additions & 0 deletions test/basic/tpush_tpop_dynamic_validshape_default_a5.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5

// Default case: the frontend pop ops are written WITHOUT the optional
// (valid_row, valid_col) operands, while the result tile_buf still has dynamic
// v_row / v_col (?). The FileCheck lines after this module expect the Tile to
// be built through its two-argument constructor form and expect no
// SetValidShape call before the TPOP call.
module {
// Cube-side kernel: no valid-shape function arguments.
func.func @cube_kernel() attributes {pto.kernel_kind = #pto.kernel_kind<cube>} {
%v2c_local = pto.reserve_buffer {
name = "v2c_fifo",
size = 4096,
location = #pto.address_space<mat>,
auto = true
} -> i32
%c2v_import = pto.import_reserved_buffer {
name = "c2v_fifo",
peer_func = @vector_kernel
} -> i32
pto.aic_initialize_pipe {dir_mask = 3, slot_size = 1024}
(c2v_consumer_buf = %c2v_import : i32,
v2c_consumer_buf = %v2c_local : i32)

// Operand-less form of the pop op (the pre-existing syntax).
%mat_tile = pto.tpop_from_aiv {split = 2}
-> !pto.tile_buf<loc=mat, dtype=f32, rows=16, cols=64, v_row=?, v_col=?, blayout=col_major, slayout=row_major, fractal=512, pad=0>
pto.tfree_from_aiv {split = 2}
return
}

// Vector-side kernel: mirror of the cube kernel using the c2v direction.
func.func @vector_kernel() attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
%c2v_local = pto.reserve_buffer {
name = "c2v_fifo",
size = 4096,
location = #pto.address_space<vec>,
auto = true
} -> i32
%v2c_import = pto.import_reserved_buffer {
name = "v2c_fifo",
peer_func = @cube_kernel
} -> i32
pto.aiv_initialize_pipe {dir_mask = 3, slot_size = 1024}
(c2v_consumer_buf = %c2v_local : i32,
v2c_consumer_buf = %v2c_import : i32)

// Operand-less form of the pop op (the pre-existing syntax).
%recv_tile = pto.tpop_from_aic {split = 2}
-> !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=64, v_row=?, v_col=?, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tfree_from_aic {split = 2}
return
}
}

// A5-LABEL: AICORE void cube_kernel(
// A5: Tile<TileType::Mat, float, 16, 64, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null> {{v[0-9]+}} = Tile<TileType::Mat, float, 16, 64, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>({{v[0-9]+}}, {{v[0-9]+}});
// A5-NOT: SetValidShape
// A5: TPOP<TPipe<0, Direction::DIR_BOTH, 1024, 4>, Tile<TileType::Mat, float, 16, 64, BLayout::ColMajor, -1, -1, SLayout::RowMajor, 512, PadValue::Null, CompactMode::Null>, TileSplitAxis::TILE_LEFT_RIGHT>(

// A5-LABEL: AICORE void vector_kernel(
// A5: Tile<TileType::Vec, float, 16, 64, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null> {{v[0-9]+}} = Tile<TileType::Vec, float, 16, 64, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>({{v[0-9]+}}, {{v[0-9]+}});
// A5-NOT: SetValidShape
// A5: TPOP<TPipe<0, Direction::DIR_BOTH, 1024, 4>, Tile<TileType::Vec, float, 16, 64, BLayout::RowMajor, -1, -1, SLayout::NoneBox, 512, PadValue::Null, CompactMode::Null>, TileSplitAxis::TILE_LEFT_RIGHT>(
44 changes: 44 additions & 0 deletions test/basic/tpush_tpop_dynamic_validshape_invalid.pto
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s

// Negative case: the RUN line above uses `not`, so the tool is expected to
// fail. The vector kernel passes (valid_row, valid_col) operands to
// pto.tpop_from_aic while the result tile_buf has a static valid shape
// (v_row=16, v_col=64); the expected diagnostic after this module is the
// "dynamic validShape (?, ?)" op error.
module {
// Cube-side kernel: only sets up the pipe; it contains no pop op.
func.func @cube_kernel() attributes {pto.kernel_kind = #pto.kernel_kind<cube>} {
%v2c_local = pto.reserve_buffer {
name = "v2c_fifo",
size = 4096,
location = #pto.address_space<mat>,
auto = true
} -> i32
%c2v_import = pto.import_reserved_buffer {
name = "c2v_fifo",
peer_func = @vector_kernel
} -> i32
pto.aic_initialize_pipe {dir_mask = 3, slot_size = 1024}
(c2v_consumer_buf = %c2v_import : i32,
v2c_consumer_buf = %v2c_local : i32)
return
}

// Vector-side kernel: holds the op that should be rejected.
func.func @vector_kernel(%vr: index, %vc: index)
attributes {pto.kernel_kind = #pto.kernel_kind<vector>} {
%c2v_local = pto.reserve_buffer {
name = "c2v_fifo",
size = 4096,
location = #pto.address_space<vec>,
auto = true
} -> i32
%v2c_import = pto.import_reserved_buffer {
name = "v2c_fifo",
peer_func = @cube_kernel
} -> i32
pto.aiv_initialize_pipe {dir_mask = 3, slot_size = 1024}
(c2v_consumer_buf = %c2v_local : i32,
v2c_consumer_buf = %v2c_import : i32)

// Operands are supplied but the result type's valid shape is static.
%recv_tile = pto.tpop_from_aic(%vr, %vc) {split = 0}
-> !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=64, v_row=16, v_col=64, blayout=row_major, slayout=none_box, fractal=512, pad=0>
pto.tfree_from_aic {split = 0}
return
}
}

// CHECK: error: 'pto.tpop_from_aic' op expects tile result to have dynamic validShape (?, ?) when valid_row/valid_col operands are provided
Loading