Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 67 additions & 56 deletions include/fusilli/node/matmul_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,52 @@ namespace fusilli {
// Infer the output shape of a matrix multiplication operation from the input
// shapes. For matrices A [..., M, K] and B [..., K, N], the output is [..., M,
// N].
inline std::vector<int64_t>
getMatmulInferredOutputShape(const std::vector<int64_t> &aDim,
const std::vector<int64_t> &bDim) {
constexpr int64_t kNonBatchRank = 2;
size_t rank = aDim.size();
assert(rank == bDim.size() && "Input tensors must have the same rank");
assert(rank >= kNonBatchRank && "Input tensors must have rank >= 2");

std::vector<int64_t> cDim(rank);

// Handle batch dimensions (broadcast if necessary)
size_t batchDims = rank - kNonBatchRank;
for (size_t i = 0; i < batchDims; ++i) {
int64_t aDimVal = aDim[i];
int64_t bDimVal = bDim[i];
// Use the maximum of the two dimensions (broadcasting rule)
assert((aDimVal % bDimVal == 0 || bDimVal % aDimVal == 0) &&
"Incompatible dimensions for broadcasting");
cDim[i] = std::max<int64_t>(aDimVal, bDimVal);
inline ErrorOr<std::vector<int64_t>>
tryGetMatmulInferredOutputShape(const std::vector<int64_t> &aDim,
const std::vector<int64_t> &bDim) {
constexpr size_t kNonBatchRank = 2;
size_t aRank = aDim.size();
size_t bRank = bDim.size();
FUSILLI_RETURN_ERROR_IF(aRank < kNonBatchRank || bRank < kNonBatchRank,
ErrorCode::InvalidAttribute,
"Matmul input tensors must have rank >= 2");

std::vector<int64_t> aBatchDim(aDim.begin(), aDim.end() - kNonBatchRank);
std::vector<int64_t> bBatchDim(bDim.begin(), bDim.end() - kNonBatchRank);
size_t cBatchRank = std::max(aBatchDim.size(), bBatchDim.size());
std::vector<int64_t> cDim(cBatchRank + kNonBatchRank, 1);

// Broadcast batch dimensions using PyTorch/NumPy right-aligned semantics.
for (size_t offset = 0; offset < cBatchRank; ++offset) {
int64_t aDimVal = offset < aBatchDim.size()
? aBatchDim[aBatchDim.size() - 1 - offset]
: 1;
int64_t bDimVal = offset < bBatchDim.size()
? bBatchDim[bBatchDim.size() - 1 - offset]
: 1;
FUSILLI_RETURN_ERROR_IF(
aDimVal != bDimVal && aDimVal != 1 && bDimVal != 1,
ErrorCode::InvalidAttribute,
"Matmul input tensors A and B have incompatible batch dimensions for "
"broadcasting at right-aligned batch index " +
std::to_string(cBatchRank - 1 - offset) + ": A has dim=" +
std::to_string(aDimVal) + ", B has dim=" + std::to_string(bDimVal));
cDim[cBatchRank - 1 - offset] = std::max<int64_t>(aDimVal, bDimVal);
}

// Matrix dimensions: M from A, N from B
cDim[rank - 2] = aDim[rank - 2]; // M
cDim[rank - 1] = bDim[rank - 1]; // N
cDim[cBatchRank] = aDim[aRank - 2]; // M
cDim[cBatchRank + 1] = bDim[bRank - 1]; // N

return cDim;
return ok(std::move(cDim));
}

inline std::vector<int64_t>
getMatmulInferredOutputShape(const std::vector<int64_t> &aDim,
const std::vector<int64_t> &bDim) {
auto cDim = tryGetMatmulInferredOutputShape(aDim, bDim);
assert(isOk(cDim) && "Invalid matmul input dimensions");
return *cDim;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -108,20 +128,14 @@ class MatmulNode : public NodeCRTP<MatmulNode> {
size_t bRank = bT->getDim().size();

// Rank checks on input tensors (must be at least rank 2).
constexpr int64_t kNonBatchRank = 2;
constexpr size_t kNonBatchRank = 2;
FUSILLI_RETURN_ERROR_IF(
aRank < kNonBatchRank, ErrorCode::InvalidAttribute,
"Matmul input tensor A must have a rank of at least 2");
FUSILLI_RETURN_ERROR_IF(
bRank < kNonBatchRank, ErrorCode::InvalidAttribute,
"Matmul input tensor B must have a rank of at least 2");

// Check that input tensors have the same rank.
FUSILLI_RETURN_ERROR_IF(
aRank != bRank, ErrorCode::InvalidAttribute,
"Matmul input tensors A and B must have the same rank: A has rank=" +
std::to_string(aRank) + ", B has rank=" + std::to_string(bRank));

// Check that inner dimensions match (K dimension).
const std::vector<int64_t> &aDim = aT->getDim();
const std::vector<int64_t> &bDim = bT->getDim();
Expand All @@ -135,19 +149,9 @@ class MatmulNode : public NodeCRTP<MatmulNode> {
std::to_string(aK) + ", B has K=" + std::to_string(bK));

// Check that batch dimensions are broadcastable.
// Since both inputs have the same rank, we can directly compare batch dims.
size_t batchDims = aRank - kNonBatchRank;
for (size_t i = 0; i < batchDims; ++i) {
int64_t aDimVal = aDim[i];
int64_t bDimVal = bDim[i];
FUSILLI_RETURN_ERROR_IF(
!(aDimVal % bDimVal == 0 || bDimVal % aDimVal == 0),
ErrorCode::InvalidAttribute,
"Matmul input tensors A and B have incompatible batch dimensions for "
"broadcasting at index " +
std::to_string(i) + ": A has dim=" + std::to_string(aDimVal) +
", B has dim=" + std::to_string(bDimVal));
}
FUSILLI_ASSIGN_OR_RETURN(auto inferredCDim,
tryGetMatmulInferredOutputShape(aDim, bDim));
(void)inferredCDim;

FUSILLI_CHECK_ERROR(checkBatchDims(aT, "A"));
FUSILLI_CHECK_ERROR(checkBatchDims(bT, "B"));
Expand All @@ -164,10 +168,12 @@ class MatmulNode : public NodeCRTP<MatmulNode> {
if (aT->getDataType() != bT->getDataType()) {
constexpr int64_t kMixedPrecisionRequiredRank = 3;
FUSILLI_RETURN_ERROR_IF(
aRank != kMixedPrecisionRequiredRank, ErrorCode::InvalidAttribute,
aRank != kMixedPrecisionRequiredRank ||
bRank != kMixedPrecisionRequiredRank,
ErrorCode::InvalidAttribute,
"Mixed precision matmul is only supported when input tensors A and B "
"are of rank 3 (single batch dim): A and B have rank=" +
std::to_string(aRank));
"are of rank 3 (single batch dim): A has rank=" +
std::to_string(aRank) + ", B has rank=" + std::to_string(bRank));
FUSILLI_RETURN_ERROR_IF(
aDim[0] != bDim[0], ErrorCode::InvalidAttribute,
"Mixed precision matmul input tensors A and B must have exactly "
Expand All @@ -192,15 +198,18 @@ class MatmulNode : public NodeCRTP<MatmulNode> {
const std::vector<int64_t> &aDim = aT->getDim();
const std::vector<int64_t> &bDim = bT->getDim();

const std::vector<int64_t> &cDim = cT->getDim();
const std::vector<int64_t> &cStride = cT->getStride();

// Infer shape of output tensor.
if (cDim.empty())
cT->setDim(getMatmulInferredOutputShape(aDim, bDim));
if (cT->getDim().empty()) {
FUSILLI_ASSIGN_OR_RETURN(auto inferredCDim,
tryGetMatmulInferredOutputShape(aDim, bDim));
cT->setDim(inferredCDim);
}

// Output stride is contiguous (row-major) when unspecified.
if (cStride.empty()) {
const std::vector<int64_t> &cDim = cT->getDim();
cT->setStride(
generateStrideFromDim(cDim, getContiguousStrideOrder(cDim.size())));
}
Expand All @@ -219,17 +228,19 @@ class MatmulNode : public NodeCRTP<MatmulNode> {
size_t cRank = cT->getDim().size();

// Rank checks
constexpr int64_t kNonBatchRank = 2;
constexpr size_t kNonBatchRank = 2;
FUSILLI_RETURN_ERROR_IF(
cRank < kNonBatchRank, ErrorCode::InvalidAttribute,
"Matmul output tensor C must have a rank of at least 2");

FUSILLI_RETURN_ERROR_IF(
cT->getDim() !=
getMatmulInferredOutputShape(aT->getDim(), bT->getDim()),
ErrorCode::InvalidAttribute,
"Matmul output tensor C dimensions do not match the expected shapes "
"inferred based on the input dimensions");
FUSILLI_ASSIGN_OR_RETURN(
auto inferredCDim,
tryGetMatmulInferredOutputShape(aT->getDim(), bT->getDim()));
FUSILLI_RETURN_ERROR_IF(cT->getDim() != inferredCDim,
ErrorCode::InvalidAttribute,
"Matmul output tensor C dimensions do not match "
"the expected shapes inferred based on the input "
"dimensions");
FUSILLI_CHECK_ERROR(checkBatchDims(cT, "C"));
return ok();
}
Expand All @@ -239,7 +250,7 @@ class MatmulNode : public NodeCRTP<MatmulNode> {
// This is equivalent to checking that perm[i] == i for all batch dims.
ErrorObject checkBatchDims(const std::shared_ptr<TensorAttr> &tensor,
const std::string &name) const {
constexpr int64_t kNonBatchRank = 2;
constexpr size_t kNonBatchRank = 2;
size_t batchDims = tensor->getDim().size() - kNonBatchRank;
std::vector<int64_t> perm = tensor->getLogicalToPhysicalPermuteOrder();
for (size_t i = 0; i < batchDims; ++i) {
Expand Down
88 changes: 72 additions & 16 deletions tests/test_matmul_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ TEST_CASE("getMatmulInferredOutputShape", "[matmul_node]") {
std::vector<int64_t>{8, 16, 64});
REQUIRE(getMatmulInferredOutputShape({8, 16, 32}, {1, 32, 64}) ==
std::vector<int64_t>{8, 16, 64});
REQUIRE(getMatmulInferredOutputShape({4, 16, 32}, {8, 32, 64}) ==
std::vector<int64_t>{8, 16, 64});
REQUIRE(getMatmulInferredOutputShape({1, 8, 16, 32}, {4, 1, 32, 64}) ==
std::vector<int64_t>{4, 8, 16, 64});
REQUIRE(getMatmulInferredOutputShape({16, 32}, {8, 32, 64}) ==
std::vector<int64_t>{8, 16, 64});
REQUIRE(getMatmulInferredOutputShape({5, 1, 8, 16, 32}, {4, 1, 32, 64}) ==
std::vector<int64_t>{5, 4, 8, 16, 64});
}

TEST_CASE("MatmulNode getName correctly propagates the attribute name",
Expand Down Expand Up @@ -147,6 +149,27 @@ TEST_CASE("MatmulNode inferPropertiesNode when C is under-specified",
REQUIRE(cT->getStride() == std::vector<int64_t>{n, 1});
}

TEST_CASE("MatmulNode inferPropertiesNode when only C stride is unspecified",
"[matmul_node]") {
Context ctx;
MatmulAttr attr;

int64_t batch = 8, m = 16, k = 32, n = 64;

attr.setA(std::make_shared<TensorAttr>(
TensorAttr().setDim({batch, m, k}).setStride({m * k, k, 1})));
attr.setB(std::make_shared<TensorAttr>(
TensorAttr().setDim({1, k, n}).setStride({k * n, n, 1})));
attr.setC(std::make_shared<TensorAttr>(TensorAttr().setDim({batch, m, n})));

MatmulNode node(std::move(attr), ctx);
FUSILLI_REQUIRE_OK(node.inferPropertiesNode());

auto cT = node.matmulAttr.getC();
REQUIRE(cT->getDim() == std::vector<int64_t>{batch, m, n});
REQUIRE(cT->getStride() == std::vector<int64_t>{m * n, n, 1});
}

TEST_CASE("MatmulNode inferPropertiesNode with batched matrices",
"[matmul_node]") {
Context ctx;
Expand Down Expand Up @@ -364,7 +387,9 @@ TEST_CASE("MatmulNode rank checks", "[matmul_node]") {
"Matmul input tensor B must have a rank of at least 2");
}

SECTION("Input tensors must have the same rank") {
SECTION("Input tensors may have different batch ranks") {
int64_t m = 16, n = 64;

auto aT = std::make_shared<TensorAttr>(
TensorAttr().setDim({16, 32}).setStride({32, 1}).setName("A_rank2"));

Expand All @@ -373,19 +398,19 @@ TEST_CASE("MatmulNode rank checks", "[matmul_node]") {
.setStride({32 * 64ll, 64, 1})
.setName("B_rank3"));

auto cT = std::make_shared<TensorAttr>(
TensorAttr().setDim({16, 64}).setStride({64, 1}).setName("C"));
auto cT = std::make_shared<TensorAttr>();

attr.setA(aT).setB(bT).setC(cT);

MatmulNode node(std::move(attr), ctx);

auto status = node.preValidateNode();
REQUIRE(isError(status));
REQUIRE(status.getCode() == ErrorCode::InvalidAttribute);
REQUIRE(status.getMessage() ==
"Matmul input tensors A and B must have the same rank: A has "
"rank=2, B has rank=3");
FUSILLI_REQUIRE_OK(node.preValidateNode());
FUSILLI_REQUIRE_OK(node.inferPropertiesNode());
FUSILLI_REQUIRE_OK(node.postValidateNode());

auto inferredCT = node.matmulAttr.getC();
REQUIRE(inferredCT->getDim() == std::vector<int64_t>{8, m, n});
REQUIRE(inferredCT->getStride() == std::vector<int64_t>{m * n, n, 1});
}

SECTION("Output C must be at least rank 2") {
Expand Down Expand Up @@ -485,10 +510,11 @@ TEST_CASE("MatmulNode broadcasting dimension compatibility checks",
REQUIRE(status.getCode() == ErrorCode::InvalidAttribute);
REQUIRE(status.getMessage() ==
"Matmul input tensors A and B have incompatible batch dimensions "
"for broadcasting at index 0: A has dim=3, B has dim=5");
"for broadcasting at right-aligned batch index 0: A has dim=3, "
"B has dim=5");
}

SECTION("Compatible batch dimensions - one divides the other") {
SECTION("Incompatible batch dimensions - one divides the other") {
int64_t batchA = 8, batchB = 4, m = 16, k = 32, n = 64;

auto aT = std::make_shared<TensorAttr>(
Expand All @@ -499,9 +525,38 @@ TEST_CASE("MatmulNode broadcasting dimension compatibility checks",
auto cT = std::make_shared<TensorAttr>();
attr.setA(aT).setB(bT).setC(cT);
MatmulNode node(std::move(attr), ctx);

auto status = node.preValidateNode();
REQUIRE(isError(status));
REQUIRE(status.getCode() == ErrorCode::InvalidAttribute);
REQUIRE(status.getMessage() ==
"Matmul input tensors A and B have incompatible batch dimensions "
"for broadcasting at right-aligned batch index 0: A has dim=8, "
"B has dim=4");
}

SECTION("Compatible unequal-rank batch dimensions") {
int64_t b1 = 5, b2 = 8, m = 16, k = 32, n = 64;

auto aT = std::make_shared<TensorAttr>(
TensorAttr()
.setDim({b1, 1, b2, m, k})
.setStride({b2 * m * k, b2 * m * k, m * k, k, 1}));
auto bT = std::make_shared<TensorAttr>(
TensorAttr().setDim({4, 1, k, n}).setStride({k * n, k * n, n, 1}));

auto cT = std::make_shared<TensorAttr>();
attr.setA(aT).setB(bT).setC(cT);
MatmulNode node(std::move(attr), ctx);

FUSILLI_REQUIRE_OK(node.preValidateNode());
FUSILLI_REQUIRE_OK(node.inferPropertiesNode());
FUSILLI_REQUIRE_OK(node.postValidateNode());

auto inferredCT = node.matmulAttr.getC();
REQUIRE(inferredCT->getDim() == std::vector<int64_t>{b1, 4, b2, m, n});
REQUIRE(inferredCT->getStride() ==
std::vector<int64_t>{4 * b2 * m * n, b2 * m * n, m * n, n, 1});
}

SECTION("Incompatible multi-dimensional batch") {
Expand All @@ -525,7 +580,8 @@ TEST_CASE("MatmulNode broadcasting dimension compatibility checks",
REQUIRE(status.getCode() == ErrorCode::InvalidAttribute);
REQUIRE(status.getMessage() ==
"Matmul input tensors A and B have incompatible batch dimensions "
"for broadcasting at index 1: A has dim=3, B has dim=5");
"for broadcasting at right-aligned batch index 1: A has dim=3, "
"B has dim=5");
}
}

Expand Down Expand Up @@ -708,7 +764,7 @@ TEST_CASE("MatmulNode mixed precision constraints", "[matmul_node]") {
REQUIRE(status.getCode() == ErrorCode::InvalidAttribute);
REQUIRE(status.getMessage() ==
"Mixed precision matmul is only supported when input tensors A and "
"B are of rank 3 (single batch dim): A and B have rank=2");
"B are of rank 3 (single batch dim): A has rank=2, B has rank=2");
}

SECTION("Mixed precision 4D matmul (2 batch dims) - fail") {
Expand Down Expand Up @@ -736,7 +792,7 @@ TEST_CASE("MatmulNode mixed precision constraints", "[matmul_node]") {
REQUIRE(status.getCode() == ErrorCode::InvalidAttribute);
REQUIRE(status.getMessage() ==
"Mixed precision matmul is only supported when input tensors A and "
"B are of rank 3 (single batch dim): A and B have rank=4");
"B are of rank 3 (single batch dim): A has rank=4, B has rank=4");
}

SECTION("Mixed precision with broadcast batch dim - fail") {
Expand Down
Loading