From 2d16467ba374933a57ca6992d7b2ec9f0f0b1ccc Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Mon, 13 Apr 2026 16:36:37 +0800 Subject: [PATCH 1/2] fix: restore tile shape verifier helper --- lib/PTO/IR/PTO.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 8bdf4f7a..19f3c100 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -101,6 +101,9 @@ static SmallVector getShapeVec(Type ty); static SmallVector getValidShapeVec(Type ty); static SmallVector getValidShapeVec(Value value); static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name); +static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName); static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName); @@ -2290,6 +2293,17 @@ static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs return success(); } +static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, + StringRef lhsName, + StringRef rhsName) { + if (failed(verifyTileBufSameElemType(op, lhs, rhs, lhsName, rhsName))) + return failure(); + if (getShapeVec(lhs) != getShapeVec(rhs)) + return op->emitOpError() << "expects " << lhsName << " and " << rhsName + << " to have the same shape"; + return success(); +} + static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName) { if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) From 11645c07e953f15685a771d2f0e306c9230daac7 Mon Sep 17 00:00:00 2001 From: HecreReed <821896444@qq.com> Date: Mon, 13 Apr 2026 17:10:16 +0800 Subject: [PATCH 2/2] fix(verifier): use elem-type check for arg reduction tmp --- lib/PTO/IR/PTO.cpp | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index 19f3c100..c5063d42 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -101,9 +101,6 @@ static SmallVector getShapeVec(Type ty); static SmallVector getValidShapeVec(Type ty); static SmallVector getValidShapeVec(Value value); static LogicalResult verifyTileBufCommon(Operation *op, Type ty, StringRef name); -static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName); static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName); @@ -2293,17 +2290,6 @@ static LogicalResult verifyTileBufSameElemType(Operation *op, Type lhs, Type rhs return success(); } -static LogicalResult verifyTileBufSameShapeAndElem(Operation *op, Type lhs, Type rhs, - StringRef lhsName, - StringRef rhsName) { - if (failed(verifyTileBufSameElemType(op, lhs, rhs, lhsName, rhsName))) - return failure(); - if (getShapeVec(lhs) != getShapeVec(rhs)) - return op->emitOpError() << "expects " << lhsName << " and " << rhsName - << " to have the same shape"; - return success(); -} - static LogicalResult verifyTileBufSameValidShape(Operation *op, Type lhs, Type rhs, StringRef lhsName, StringRef rhsName) { if (!isTileLikeType(lhs) || !isTileLikeType(rhs)) @@ -3426,7 +3412,7 @@ LogicalResult pto::TColArgMaxOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyColArgReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, @@ -3493,7 +3479,7 @@ LogicalResult pto::TColArgMinOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyColArgReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (failed(verifyColReductionValidRegion(*this, srcTy, dstTy, @@ -7812,7 +7798,7 @@ mlir::LogicalResult mlir::pto::TRowArgMaxOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (failed(verifyRowReductionValidRegion(*this, srcTy, dstTy))) @@ -7886,7 +7872,7 @@ mlir::LogicalResult mlir::pto::TRowArgMinOp::verify() { failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) return failure(); - if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + if (failed(verifyTileBufSameElemType(*this, srcTy, tmpTy, "src", "tmp")) || failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) return failure(); if (failed(verifyRowReductionValidRegion(*this, srcTy, dstTy)))