Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
35f8c39
account for different layouts
kahmed10 Feb 17, 2026
901e040
update to row major when copying
kahmed10 Feb 17, 2026
84b68cf
more cleanup and update requirements
kahmed10 Feb 18, 2026
6e7708e
formatting
kahmed10 Feb 18, 2026
00fc04c
fix requirements.txt
kahmed10 Feb 18, 2026
7a123f9
add debug and fix requirements
kahmed10 Feb 18, 2026
4b1285b
refactor and update requirements
kahmed10 Feb 20, 2026
51290e9
update lowering
kahmed10 Feb 20, 2026
892aa7d
more refactoring
kahmed10 Feb 20, 2026
2d1d0dc
refactor
kahmed10 Feb 24, 2026
8f78587
test eigen integration
kahmed10 Feb 24, 2026
03ea48f
remove debug env var
kahmed10 Feb 24, 2026
10b9802
exclude non MPL headers
kahmed10 Feb 24, 2026
cc1a7e4
manual merge
kahmed10 Feb 24, 2026
459f903
update changelog
kahmed10 Feb 24, 2026
3484fd3
update cmake
kahmed10 Feb 24, 2026
40a7acb
update based on review comments
kahmed10 Mar 4, 2026
40d877a
update review comments
kahmed10 Mar 13, 2026
1878a6e
manual merge
kahmed10 Mar 21, 2026
a42a97f
fix license and format
kahmed10 Mar 21, 2026
069aee9
Update src/include/migraphx/gemm.hpp
pfultz2 Mar 22, 2026
6ed7ac5
Merge branch 'develop' into eigen_gemm_impl
causten Mar 24, 2026
6a483a0
handle int8 to accumulate to int32
kahmed10 Mar 24, 2026
8464a9c
formatting
kahmed10 Mar 24, 2026
205f4c0
manual merge
kahmed10 Mar 24, 2026
41e068f
manual merge
kahmed10 Mar 24, 2026
426ac74
fix rocmlir hash
kahmed10 Mar 24, 2026
f270fd6
update quant_dot 4 args case and added test
kahmed10 Apr 8, 2026
9dbe0db
manual merge
kahmed10 Apr 8, 2026
d3c8916
cleanup and license fix
kahmed10 Apr 9, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Full documentation for MIGraphX is available at
* Added `auto_pad` attribute support for the ONNX `ConvTranspose` operator, supporting `SAME_UPPER`, `SAME_LOWER`, and `VALID` padding modes for static shapes (#4638).
* Added a dedicated logger for MIGraphX.
* [Linux] Use HSA API to query number of chiplets for architectures where this is applicable (ex. gfx90a).
* Added Eigen third-party headers for ref GEMMs (#4631).
* Added a fuse_horizontal pass which batches independent cross embedding gather instructions (#4599).
* Added GPU JIT `Resize` kernel (#4553).
* Added environment variable `MIGRAPHX_SKIP_BENCHMARKING` which when enabled, skips tuning of MIGraphX and rocMLIR kernels (#4628).
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ else()
option(MIGRAPHX_USE_COMPOSABLEKERNEL "Enable MIGraphX to use composable kernel JIT library" ON)
endif()

option(MIGRAPHX_USE_EIGEN "Enable Eigen for optimized ref GEMM" ON)

include(ROCMSetupVersion)

option(BUILD_DEV "Build for development purpose only" OFF)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off -DMSGPACK_BUILD_EXAMPLES=Off -DCMAKE_POLICY_VERSION_MINIMUM=3.5
sqlite3@3.50.4 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@ad0db05b040bacda751c65c705261b8a0a7ed25d --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On -DBUILD_TESTING=Off
https://gitlab.com/libeigen/eigen/-/archive/5.0.1/eigen-5.0.1.tar.gz -DBUILD_TESTING=Off -DEIGEN_BUILD_DOC=Off
ROCm/rocMLIR@2d0afe0deb7d7ee7a1913e7a3f85ad91d489d6b5 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
20 changes: 15 additions & 5 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ add_library(migraphx
fuse_pointwise.cpp
fuse_pointwise_reduce.cpp
fuse_reduce.cpp
gemm.cpp
generate.cpp
graphviz.cpp
inline_module.cpp
Expand Down Expand Up @@ -344,9 +345,22 @@ if(NOT WIN32)
target_link_libraries(migraphx PRIVATE -ldl)
endif()

target_include_directories(migraphx SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)
target_link_libraries(migraphx PUBLIC Threads::Threads)

# Optionally back the reference GEMM with Eigen. MIGRAPHX_USE_EIGEN is defined
# PUBLIC (0 or 1) in every configuration so that src/gemm.cpp can dispatch on it
# with a plain #if, regardless of whether Eigen was actually found.
# (Fixed: removed stray review-UI text that had been pasted into this block and
# would have broken the CMake script.)
if(MIGRAPHX_USE_EIGEN)
    find_package(Eigen3)
    if(Eigen3_FOUND)
        target_compile_definitions(migraphx PUBLIC MIGRAPHX_USE_EIGEN=1)
        # Restrict Eigen to its MPL2-licensed code paths only.
        target_compile_definitions(migraphx PRIVATE EIGEN_MPL2_ONLY)
        target_link_libraries(migraphx PRIVATE Eigen3::Eigen)
    else()
        message(STATUS "Eigen not found, disabling MIGRAPHX_USE_EIGEN")
        target_compile_definitions(migraphx PUBLIC MIGRAPHX_USE_EIGEN=0)
    endif()
else()
    target_compile_definitions(migraphx PUBLIC MIGRAPHX_USE_EIGEN=0)
endif()

if(MIGRAPHX_HAS_EXECUTORS AND ParallelSTL_USES_TBB)
list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE TBB)
endif()
Expand Down Expand Up @@ -423,10 +437,6 @@ else()
target_compile_definitions(migraphx_all_targets INTERFACE MIGRAPHX_USE_MIOPEN=0)
endif()

if(HAVE_HALF_EXPR)
target_compile_definitions(migraphx PUBLIC -DHAS_HALF_V1)
endif()

if(BUILD_DEV)
target_compile_definitions(migraphx PUBLIC -DBUILD_DEV)
endif()
Expand Down
204 changes: 204 additions & 0 deletions src/gemm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gemm.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <algorithm>
#include <numeric>
#include <vector>

#if MIGRAPHX_USE_EIGEN
#include <Eigen/Core>
#endif

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace {

// Naive reference GEMM: computes C = A * B one output element at a time,
// accumulating in double. The trailing two dims of each tensor are the matrix
// rows/columns; any leading dims are treated as batch dims via the multi-index.
// [[maybe_unused]] because this is only called in the non-Eigen build.
template <class T, class U>
[[maybe_unused]] void gemm_naive(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat)
{
    const auto& out_shape      = cmat.get_shape();
    std::size_t rank           = out_shape.ndim();
    std::size_t row_axis       = rank - 2;
    std::size_t col_axis       = rank - 1;
    auto inner_dim             = amat.get_shape().lens()[col_axis];

    // One parallel task per output element; each task reduces over the
    // shared inner dimension.
    par_for(out_shape.elements(), [&](auto elem) {
        auto out_idx = out_shape.multi(elem);
        auto lhs_idx = out_idx;
        auto rhs_idx = out_idx;
        double acc   = 0.0;
        for(std::size_t r = 0; r < inner_dim; r++)
        {
            lhs_idx[col_axis] = r;
            rhs_idx[row_axis] = r;
            acc += static_cast<double>(amat(lhs_idx.begin(), lhs_idx.end())) *
                   static_cast<double>(bmat(rhs_idx.begin(), rhs_idx.end()));
        }
        cmat(out_idx.begin(), out_idx.end()) = static_cast<T>(acc);
    });
}

#if MIGRAPHX_USE_EIGEN

using eigen_row_major = Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using eigen_stride = Eigen::Stride<Eigen::Dynamic, Eigen::Dynamic>;

// Splits an N-D matrix shape into an "outer" batch shape (all leading dims,
// with their strides) and an "inner" 2-D matrix shape (the trailing two dims),
// so batched GEMMs can be processed one 2-D slice at a time.
struct batch_slicer
{
    // explicit: a single-argument constructor would otherwise allow an
    // accidental implicit shape -> batch_slicer conversion.
    explicit batch_slicer(const shape& mat_shape)
    {
        auto n_batch_dims = mat_shape.ndim() - 2;
        // The trailing two dims (with their original strides) form the matrix,
        // so transposed/non-contiguous layouts are preserved in the slice.
        inner_shape = shape{mat_shape.type(),
                            {mat_shape.lens().end() - 2, mat_shape.lens().end()},
                            {mat_shape.strides().end() - 2, mat_shape.strides().end()}};
        if(n_batch_dims > 0)
        {
            // Leading dims become the batch shape; left empty for a plain
            // 2-D matrix.
            outer_shape =
                shape{mat_shape.type(),
                      {mat_shape.lens().begin(), mat_shape.lens().begin() + n_batch_dims},
                      {mat_shape.strides().begin(), mat_shape.strides().begin() + n_batch_dims}};
        }
    }

    // Returns a non-owning 2-D view of the given batch inside the full tensor.
    template <class T>
    tensor_view<T> extract(tensor_view<T> view, std::size_t batch) const
    {
        std::size_t offset = 0;
        if(not outer_shape.lens().empty())
            offset = outer_shape.index(batch); // linear batch id -> element offset
        return make_view(inner_shape, view.data() + offset);
    }

    // Number of 2-D slices; a plain matrix counts as a single batch.
    std::size_t num_batches() const
    {
        if(outer_shape.lens().empty())
            return 1;
        return outer_shape.elements();
    }

    shape inner_shape; // trailing 2-D matrix shape (lens + strides)
    shape outer_shape; // leading batch dims; empty for a 2-D input
};

// Wraps the trailing 2-D slice of a tensor_view in an Eigen::Map without
// copying. Row/column strides come straight from the MIGraphX shape, so
// Eigen's dynamic strides handle transposed or otherwise non-dense layouts.
template <class T>
auto make_eigen_map(tensor_view<T> view)
{
    const auto& vs = view.get_shape();
    auto rank      = vs.ndim();

    Eigen::Index nrows = vs.lens()[rank - 2];
    Eigen::Index ncols = vs.lens()[rank - 1];
    eigen_stride strides{static_cast<Eigen::Index>(vs.strides()[rank - 2]),
                         static_cast<Eigen::Index>(vs.strides()[rank - 1])};

    using matrix_t = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
    return Eigen::Map<matrix_t, Eigen::Unaligned, eigen_stride>{
        view.data(), nrows, ncols, strides};
}

// Batched matrix multiply via Eigen: for each batch, maps the 2-D slices of
// A, B and C and evaluates C = A * B in place. noalias() is safe because the
// output buffer is distinct from both inputs. Batches run in parallel.
template <class T>
void eigen_multiply(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat)
{
    batch_slicer c_batches(cmat.get_shape());
    batch_slicer a_batches(amat.get_shape());
    batch_slicer b_batches(bmat.get_shape());

    par_for(c_batches.num_batches(), [&](auto i) {
        auto c = make_eigen_map(c_batches.extract(cmat, i));
        auto a = make_eigen_map(a_batches.extract(amat, i));
        auto b = make_eigen_map(b_batches.extract(bmat, i));
        c.noalias() = a * b;
    });
}

// Runs the Eigen GEMM on type-converted copies of the inputs: data is widened
// to AccType (e.g. int8 -> int64, half -> float) into dense row-major buffers,
// multiplied there, and the result narrowed back into the output view.
template <class AccType, class T, class U>
void gemm_eigen_with_copy(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat)
{
    auto acc_type = shape::get_type<AccType>{};

    // Widen A into a dense standard-layout buffer of AccType.
    std::vector<AccType> lhs_buf(amat.get_shape().elements());
    std::copy(amat.begin(), amat.end(), lhs_buf.begin());
    auto lhs = make_view(amat.get_shape().as_standard().with_type(acc_type), lhs_buf.data());

    // Same for B.
    std::vector<AccType> rhs_buf(bmat.get_shape().elements());
    std::copy(bmat.begin(), bmat.end(), rhs_buf.begin());
    auto rhs = make_view(bmat.get_shape().as_standard().with_type(acc_type), rhs_buf.data());

    // Zero-initialized accumulator buffer for the result.
    std::vector<AccType> out_buf(cmat.get_shape().elements(), AccType{0});
    auto out = make_view(cmat.get_shape().as_standard().with_type(acc_type), out_buf.data());

    eigen_multiply(out, lhs, rhs);

    // NOTE(review): this pairs the dense buffer's standard element order with
    // cmat's own iteration order; assumes cmat has a standard layout — confirm.
    std::copy(out_buf.begin(), out_buf.end(), cmat.begin());
}

// Compile-time dispatch to the right Eigen path for the element types:
// - matching float/double in and out: multiply the views directly, no copies
// - integral inputs (e.g. quantized int8 dot): widen and accumulate in int64_t
// - anything else (half types, mixed precision): widen and accumulate in float
template <class T, class U>
void gemm_eigen(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat)
{
    constexpr bool native_float_path =
        std::is_same<T, U>{} and (std::is_same<T, float>{} or std::is_same<T, double>{});

    if constexpr(native_float_path)
        eigen_multiply(cmat, amat, bmat);
    else if constexpr(std::is_integral<U>{})
        gemm_eigen_with_copy<int64_t>(cmat, amat, bmat);
    else
        gemm_eigen_with_copy<float>(cmat, amat, bmat);
}

#endif

// Visits the type-erased arguments with their concrete element types and
// forwards the typed tensor_views to the supplied GEMM implementation.
// visit_all requires A and B to share an element type; C is visited separately
// so its type may differ (e.g. int32 output from int8 inputs).
template <class Visitor>
void gemm_ref_visit(const argument& c_arg, const argument& a_arg, const argument& b_arg, Visitor v)
{
    c_arg.visit([&](auto out) {
        visit_all(a_arg, b_arg)([&](auto lhs, auto rhs) { v(out, lhs, rhs); });
    });
}

} // namespace

// Reference GEMM entry point: computes C = A * B over (optionally batched)
// type-erased arguments. Dispatches at compile time to the Eigen-backed
// implementation when MIGRAPHX_USE_EIGEN was enabled at build configuration,
// otherwise to the naive reference loop.
void gemm(const argument& c_arg, const argument& a_arg, const argument& b_arg)
{
#if MIGRAPHX_USE_EIGEN
    gemm_ref_visit(
        c_arg, a_arg, b_arg, [](auto cmat, auto amat, auto bmat) { gemm_eigen(cmat, amat, bmat); });
#else
    gemm_ref_visit(
        c_arg, a_arg, b_arg, [](auto cmat, auto amat, auto bmat) { gemm_naive(cmat, amat, bmat); });
#endif
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
33 changes: 3 additions & 30 deletions src/include/migraphx/gemm.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -25,39 +25,12 @@
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP

#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/tensor_view.hpp>
#include <migraphx/argument.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class T, class U, class F>
void gemm(tensor_view<T> cmat, tensor_view<U> amat, tensor_view<U> bmat, F alpha, F beta)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];

assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);
auto cs = cmat.get_shape();

par_for(cs.elements(), [&](auto i) {
auto c_idx = cs.multi(i);
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += static_cast<double>(amat(a_idx.begin(), a_idx.end())) *
static_cast<double>(bmat(b_idx.begin(), b_idx.end()));
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
}
MIGRAPHX_EXPORT void gemm(const argument& c_arg, const argument& a_arg, const argument& b_arg);

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
7 changes: 3 additions & 4 deletions src/include/migraphx/op/dot.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -125,9 +125,8 @@ struct dot

argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result = argument{dyn_out.computed_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
argument result{dyn_out.computed_shape};
gemm(result, args[0], args[1]);
return result;
}
};
Expand Down
Loading
Loading