From d783d88ee8f3ccf751dd756e6acd07475bce395b Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 27 Apr 2026 05:08:10 +0000 Subject: [PATCH] fix: Align matmul batch broadcasting semantics Use PyTorch-style right-aligned broadcasting for matmul batch dimensions while preserving the existing rank >= 2 tensor contract. Reject divisible but non-broadcastable batch dimensions, allow unequal batch-rank inputs, and infer default output strides from the finalized output dimensions. Co-Authored-By: GPT 5.5 Signed-off-by: Sambhav Jain --- include/fusilli/node/matmul_node.h | 123 ++++++++++++++++------------- tests/test_matmul_node.cpp | 88 +++++++++++++++++---- 2 files changed, 139 insertions(+), 72 deletions(-) diff --git a/include/fusilli/node/matmul_node.h b/include/fusilli/node/matmul_node.h index 376634e5..f6aae2a8 100644 --- a/include/fusilli/node/matmul_node.h +++ b/include/fusilli/node/matmul_node.h @@ -37,32 +37,52 @@ namespace fusilli { // Infer the output shape of a matrix multiplication operation from the input // shapes. For matrices A [..., M, K] and B [..., K, N], the output is [..., M, // N]. -inline std::vector -getMatmulInferredOutputShape(const std::vector &aDim, - const std::vector &bDim) { - constexpr int64_t kNonBatchRank = 2; - size_t rank = aDim.size(); - assert(rank == bDim.size() && "Input tensors must have the same rank"); - assert(rank >= kNonBatchRank && "Input tensors must have rank >= 2"); - - std::vector cDim(rank); - - // Handle batch dimensions (broadcast if necessary) - size_t batchDims = rank - kNonBatchRank; - for (size_t i = 0; i < batchDims; ++i) { - int64_t aDimVal = aDim[i]; - int64_t bDimVal = bDim[i]; - // Use the maximum of the two dimensions (broadcasting rule) - assert((aDimVal % bDimVal == 0 || bDimVal % aDimVal == 0) && - "Incompatible dimensions for broadcasting"); - cDim[i] = std::max(aDimVal, bDimVal); +inline ErrorOr> +tryGetMatmulInferredOutputShape(const std::vector &aDim, + const std::vector &bDim) { + constexpr size_t kNonBatchRank = 2; + size_t aRank = aDim.size(); + size_t bRank = bDim.size(); + FUSILLI_RETURN_ERROR_IF(aRank < kNonBatchRank || bRank < kNonBatchRank, + ErrorCode::InvalidAttribute, + "Matmul input tensors must have rank >= 2"); + + std::vector aBatchDim(aDim.begin(), aDim.end() - kNonBatchRank); + std::vector bBatchDim(bDim.begin(), bDim.end() - kNonBatchRank); + size_t cBatchRank = std::max(aBatchDim.size(), bBatchDim.size()); + std::vector cDim(cBatchRank + kNonBatchRank, 1); + + // Broadcast batch dimensions using PyTorch/NumPy right-aligned semantics. + for (size_t offset = 0; offset < cBatchRank; ++offset) { + int64_t aDimVal = offset < aBatchDim.size() + ? aBatchDim[aBatchDim.size() - 1 - offset] + : 1; + int64_t bDimVal = offset < bBatchDim.size() + ? bBatchDim[bBatchDim.size() - 1 - offset] + : 1; + FUSILLI_RETURN_ERROR_IF( + aDimVal != bDimVal && aDimVal != 1 && bDimVal != 1, + ErrorCode::InvalidAttribute, + "Matmul input tensors A and B have incompatible batch dimensions for " + "broadcasting at right-aligned batch index " + + std::to_string(cBatchRank - 1 - offset) + ": A has dim=" + + std::to_string(aDimVal) + ", B has dim=" + std::to_string(bDimVal)); + cDim[cBatchRank - 1 - offset] = std::max(aDimVal, bDimVal); } // Matrix dimensions: M from A, N from B - cDim[rank - 2] = aDim[rank - 2]; // M - cDim[rank - 1] = bDim[rank - 1]; // N + cDim[cBatchRank] = aDim[aRank - 2]; // M + cDim[cBatchRank + 1] = bDim[bRank - 1]; // N - return cDim; + return ok(std::move(cDim)); +} + +inline std::vector +getMatmulInferredOutputShape(const std::vector &aDim, + const std::vector &bDim) { + auto cDim = tryGetMatmulInferredOutputShape(aDim, bDim); + assert(isOk(cDim) && "Invalid matmul input dimensions"); + return *cDim; } //===----------------------------------------------------------------------===// @@ -108,7 +128,7 @@ class MatmulNode : public NodeCRTP { size_t bRank = bT->getDim().size(); // Rank checks on input tensors (must be at least rank 2). - constexpr int64_t kNonBatchRank = 2; + constexpr size_t kNonBatchRank = 2; FUSILLI_RETURN_ERROR_IF( aRank < kNonBatchRank, ErrorCode::InvalidAttribute, "Matmul input tensor A must have a rank of at least 2"); @@ -116,12 +136,6 @@ class MatmulNode : public NodeCRTP { bRank < kNonBatchRank, ErrorCode::InvalidAttribute, "Matmul input tensor B must have a rank of at least 2"); - // Check that input tensors have the same rank. - FUSILLI_RETURN_ERROR_IF( - aRank != bRank, ErrorCode::InvalidAttribute, - "Matmul input tensors A and B must have the same rank: A has rank=" + - std::to_string(aRank) + ", B has rank=" + std::to_string(bRank)); - // Check that inner dimensions match (K dimension). const std::vector &aDim = aT->getDim(); const std::vector &bDim = bT->getDim(); @@ -135,19 +149,9 @@ class MatmulNode : public NodeCRTP { std::to_string(aK) + ", B has K=" + std::to_string(bK)); // Check that batch dimensions are broadcastable. - // Since both inputs have the same rank, we can directly compare batch dims. - size_t batchDims = aRank - kNonBatchRank; - for (size_t i = 0; i < batchDims; ++i) { - int64_t aDimVal = aDim[i]; - int64_t bDimVal = bDim[i]; - FUSILLI_RETURN_ERROR_IF( - !(aDimVal % bDimVal == 0 || bDimVal % aDimVal == 0), - ErrorCode::InvalidAttribute, - "Matmul input tensors A and B have incompatible batch dimensions for " - "broadcasting at index " + - std::to_string(i) + ": A has dim=" + std::to_string(aDimVal) + - ", B has dim=" + std::to_string(bDimVal)); - } + FUSILLI_ASSIGN_OR_RETURN(auto inferredCDim, + tryGetMatmulInferredOutputShape(aDim, bDim)); + (void)inferredCDim; FUSILLI_CHECK_ERROR(checkBatchDims(aT, "A")); FUSILLI_CHECK_ERROR(checkBatchDims(bT, "B")); @@ -164,10 +168,12 @@ class MatmulNode : public NodeCRTP { if (aT->getDataType() != bT->getDataType()) { constexpr int64_t kMixedPrecisionRequiredRank = 3; FUSILLI_RETURN_ERROR_IF( - aRank != kMixedPrecisionRequiredRank, ErrorCode::InvalidAttribute, + aRank != kMixedPrecisionRequiredRank || + bRank != kMixedPrecisionRequiredRank, + ErrorCode::InvalidAttribute, "Mixed precision matmul is only supported when input tensors A and B " - "are of rank 3 (single batch dim): A and B have rank=" + - std::to_string(aRank)); + "are of rank 3 (single batch dim): A has rank=" + + std::to_string(aRank) + ", B has rank=" + std::to_string(bRank)); FUSILLI_RETURN_ERROR_IF( aDim[0] != bDim[0], ErrorCode::InvalidAttribute, "Mixed precision matmul input tensors A and B must have exactly " @@ -192,15 +198,18 @@ class MatmulNode : public NodeCRTP { const std::vector &aDim = aT->getDim(); const std::vector &bDim = bT->getDim(); - const std::vector &cDim = cT->getDim(); const std::vector &cStride = cT->getStride(); // Infer shape of output tensor. - if (cDim.empty()) - cT->setDim(getMatmulInferredOutputShape(aDim, bDim)); + if (cT->getDim().empty()) { + FUSILLI_ASSIGN_OR_RETURN(auto inferredCDim, + tryGetMatmulInferredOutputShape(aDim, bDim)); + cT->setDim(inferredCDim); + } // Output stride is contiguous (row-major) when unspecified. if (cStride.empty()) { + const std::vector &cDim = cT->getDim(); cT->setStride( generateStrideFromDim(cDim, getContiguousStrideOrder(cDim.size()))); } @@ -219,17 +228,19 @@ class MatmulNode : public NodeCRTP { size_t cRank = cT->getDim().size(); // Rank checks - constexpr int64_t kNonBatchRank = 2; + constexpr size_t kNonBatchRank = 2; FUSILLI_RETURN_ERROR_IF( cRank < kNonBatchRank, ErrorCode::InvalidAttribute, "Matmul output tensor C must have a rank of at least 2"); - FUSILLI_RETURN_ERROR_IF( - cT->getDim() != - getMatmulInferredOutputShape(aT->getDim(), bT->getDim()), - ErrorCode::InvalidAttribute, - "Matmul output tensor C dimensions do not match the expected shapes " - "inferred based on the input dimensions"); + FUSILLI_ASSIGN_OR_RETURN( + auto inferredCDim, + tryGetMatmulInferredOutputShape(aT->getDim(), bT->getDim())); + FUSILLI_RETURN_ERROR_IF(cT->getDim() != inferredCDim, + ErrorCode::InvalidAttribute, + "Matmul output tensor C dimensions do not match " + "the expected shapes inferred based on the input " + "dimensions"); FUSILLI_CHECK_ERROR(checkBatchDims(cT, "C")); return ok(); } @@ -239,7 +250,7 @@ class MatmulNode : public NodeCRTP { // This is equivalent to checking that perm[i] == i for all batch dims. ErrorObject checkBatchDims(const std::shared_ptr &tensor, const std::string &name) const { - constexpr int64_t kNonBatchRank = 2; + constexpr size_t kNonBatchRank = 2; size_t batchDims = tensor->getDim().size() - kNonBatchRank; std::vector perm = tensor->getLogicalToPhysicalPermuteOrder(); for (size_t i = 0; i < batchDims; ++i) { diff --git a/tests/test_matmul_node.cpp b/tests/test_matmul_node.cpp index 6f86dfab..15d8c247 100644 --- a/tests/test_matmul_node.cpp +++ b/tests/test_matmul_node.cpp @@ -27,10 +27,12 @@ TEST_CASE("getMatmulInferredOutputShape", "[matmul_node]") { std::vector{8, 16, 64}); REQUIRE(getMatmulInferredOutputShape({8, 16, 32}, {1, 32, 64}) == std::vector{8, 16, 64}); - REQUIRE(getMatmulInferredOutputShape({4, 16, 32}, {8, 32, 64}) == - std::vector{8, 16, 64}); REQUIRE(getMatmulInferredOutputShape({1, 8, 16, 32}, {4, 1, 32, 64}) == std::vector{4, 8, 16, 64}); + REQUIRE(getMatmulInferredOutputShape({16, 32}, {8, 32, 64}) == + std::vector{8, 16, 64}); + REQUIRE(getMatmulInferredOutputShape({5, 1, 8, 16, 32}, {4, 1, 32, 64}) == + std::vector{5, 4, 8, 16, 64}); } TEST_CASE("MatmulNode getName correctly propagates the attribute name", @@ -147,6 +149,27 @@ TEST_CASE("MatmulNode inferPropertiesNode when C is under-specified", REQUIRE(cT->getStride() == std::vector{n, 1}); } +TEST_CASE("MatmulNode inferPropertiesNode when only C stride is unspecified", + "[matmul_node]") { + Context ctx; + MatmulAttr attr; + + int64_t batch = 8, m = 16, k = 32, n = 64; + + attr.setA(std::make_shared( + TensorAttr().setDim({batch, m, k}).setStride({m * k, k, 1}))); + attr.setB(std::make_shared( + TensorAttr().setDim({1, k, n}).setStride({k * n, n, 1}))); + attr.setC(std::make_shared(TensorAttr().setDim({batch, m, n}))); + + MatmulNode node(std::move(attr), ctx); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + + auto cT = node.matmulAttr.getC(); + REQUIRE(cT->getDim() == std::vector{batch, m, n}); + REQUIRE(cT->getStride() == std::vector{m * n, n, 1}); +} + TEST_CASE("MatmulNode inferPropertiesNode with batched matrices", "[matmul_node]") { Context ctx; @@ -364,7 +387,9 @@ TEST_CASE("MatmulNode rank checks", "[matmul_node]") { "Matmul input tensor B must have a rank of at least 2"); } - SECTION("Input tensors must have the same rank") { + SECTION("Input tensors may have different batch ranks") { + int64_t m = 16, n = 64; + auto aT = std::make_shared( TensorAttr().setDim({16, 32}).setStride({32, 1}).setName("A_rank2")); @@ -373,19 +398,19 @@ TEST_CASE("MatmulNode rank checks", "[matmul_node]") { .setStride({32 * 64ll, 64, 1}) .setName("B_rank3")); - auto cT = std::make_shared( - TensorAttr().setDim({16, 64}).setStride({64, 1}).setName("C")); + auto cT = std::make_shared(); attr.setA(aT).setB(bT).setC(cT); MatmulNode node(std::move(attr), ctx); - auto status = node.preValidateNode(); - REQUIRE(isError(status)); - REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); - REQUIRE(status.getMessage() == - "Matmul input tensors A and B must have the same rank: A has " - "rank=2, B has rank=3"); + FUSILLI_REQUIRE_OK(node.preValidateNode()); + FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); + FUSILLI_REQUIRE_OK(node.postValidateNode()); + + auto inferredCT = node.matmulAttr.getC(); + REQUIRE(inferredCT->getDim() == std::vector{8, m, n}); + REQUIRE(inferredCT->getStride() == std::vector{m * n, n, 1}); } SECTION("Output C must be at least rank 2") { @@ -485,10 +510,11 @@ TEST_CASE("MatmulNode broadcasting dimension compatibility checks", REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); REQUIRE(status.getMessage() == "Matmul input tensors A and B have incompatible batch dimensions " - "for broadcasting at index 0: A has dim=3, B has dim=5"); + "for broadcasting at right-aligned batch index 0: A has dim=3, " + "B has dim=5"); } - SECTION("Compatible batch dimensions - one divides the other") { + SECTION("Incompatible batch dimensions - one divides the other") { int64_t batchA = 8, batchB = 4, m = 16, k = 32, n = 64; auto aT = std::make_shared( @@ -499,9 +525,38 @@ TEST_CASE("MatmulNode broadcasting dimension compatibility checks", auto cT = std::make_shared(); attr.setA(aT).setB(bT).setC(cT); MatmulNode node(std::move(attr), ctx); + + auto status = node.preValidateNode(); + REQUIRE(isError(status)); + REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); + REQUIRE(status.getMessage() == + "Matmul input tensors A and B have incompatible batch dimensions " + "for broadcasting at right-aligned batch index 0: A has dim=8, " + "B has dim=4"); + } + + SECTION("Compatible unequal-rank batch dimensions") { + int64_t b1 = 5, b2 = 8, m = 16, k = 32, n = 64; + + auto aT = std::make_shared( + TensorAttr() + .setDim({b1, 1, b2, m, k}) + .setStride({b2 * m * k, b2 * m * k, m * k, k, 1})); + auto bT = std::make_shared( + TensorAttr().setDim({4, 1, k, n}).setStride({k * n, k * n, n, 1})); + + auto cT = std::make_shared(); + attr.setA(aT).setB(bT).setC(cT); + MatmulNode node(std::move(attr), ctx); + FUSILLI_REQUIRE_OK(node.preValidateNode()); FUSILLI_REQUIRE_OK(node.inferPropertiesNode()); FUSILLI_REQUIRE_OK(node.postValidateNode()); + + auto inferredCT = node.matmulAttr.getC(); + REQUIRE(inferredCT->getDim() == std::vector{b1, 4, b2, m, n}); + REQUIRE(inferredCT->getStride() == + std::vector{4 * b2 * m * n, b2 * m * n, m * n, n, 1}); } SECTION("Incompatible multi-dimensional batch") { @@ -525,7 +580,8 @@ TEST_CASE("MatmulNode broadcasting dimension compatibility checks", REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); REQUIRE(status.getMessage() == "Matmul input tensors A and B have incompatible batch dimensions " - "for broadcasting at index 1: A has dim=3, B has dim=5"); + "for broadcasting at right-aligned batch index 1: A has dim=3, " + "B has dim=5"); } } @@ -708,7 +764,7 @@ TEST_CASE("MatmulNode mixed precision constraints", "[matmul_node]") { REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); REQUIRE(status.getMessage() == "Mixed precision matmul is only supported when input tensors A and " - "B are of rank 3 (single batch dim): A and B have rank=2"); + "B are of rank 3 (single batch dim): A has rank=2, B has rank=2"); } SECTION("Mixed precision 4D matmul (2 batch dims) - fail") { @@ -736,7 +792,7 @@ TEST_CASE("MatmulNode mixed precision constraints", "[matmul_node]") { REQUIRE(status.getCode() == ErrorCode::InvalidAttribute); REQUIRE(status.getMessage() == "Mixed precision matmul is only supported when input tensors A and " - "B are of rank 3 (single batch dim): A and B have rank=4"); + "B are of rank 3 (single batch dim): A has rank=4, B has rank=4"); } SECTION("Mixed precision with broadcast batch dim - fail") {