From 0be5cc1d16b1717148037b82340682061d6a9fcc Mon Sep 17 00:00:00 2001 From: patryk-kaiser-ARM Date: Tue, 17 Mar 2026 17:13:35 +0000 Subject: [PATCH 1/6] [MLAS] Integrate KleidiAI BF16 SME2 Kernel Through Mlas SBGEMM Path (#26773) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description** This PR integrates Arm® KleidiAI™ SME2 BF16 kernel through MLAS SBGEMM path. Rework of https://github.com/microsoft/onnxruntime/pull/24346 **Motivation and Context** This kernel provides performance improvements on SME-enabled devices. --------- Signed-off-by: Patryk Kaiser --- cmake/onnxruntime_mlas.cmake | 1 + onnxruntime/core/mlas/inc/mlas.h | 82 +++- .../core/mlas/lib/kai_ukernel_interface.cpp | 12 +- .../core/mlas/lib/kai_ukernel_interface.h | 7 + .../core/mlas/lib/kleidiai/mlasi_kleidiai.h | 36 ++ .../mlas/lib/kleidiai/sbgemm_kleidiai.cpp | 449 ++++++++++++++++++ onnxruntime/core/mlas/lib/mlasi.h | 40 ++ onnxruntime/core/mlas/lib/platform.cpp | 8 + .../core/mlas/lib/sbconv_kernel_neon.cpp | 5 +- onnxruntime/core/mlas/lib/sbgemm.h | 81 +++- onnxruntime/core/providers/cpu/math/matmul.cc | 29 +- .../test/mlas/unittest/test_sbgemm.cpp | 57 ++- onnxruntime/test/mlas/unittest/test_sbgemm.h | 102 +++- 13 files changed, 865 insertions(+), 44 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f1b3b091bbc6e..5f1e166057b64 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -284,6 +284,7 @@ function(setup_kleidiai) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp ) diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 80817ff87d736..56849995656f3 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2005,7 +2005,9 @@ struct MLAS_SBGEMM_DATA_PARAMS { const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ - bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */ + bool ZeroMode = true; /**< when true: C = A*B + Bias (if Bias != nullptr); + when false: C += A*B and Bias is ignored */ + bool BIsPacked = false; /**< Whether B is pre-packed */ }; /** @@ -2015,40 +2017,84 @@ struct MLAS_SBGEMM_DATA_PARAMS { * Note: We only support uniform batching, so shapes and types of the * input must be same across all parameter blocks. * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] TransA Supplies the transpose operation for matrix A. + * @param[in] TransB Supplies the transpose operation for matrix B. + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] ThreadPool + * @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector + configuration options, else nullptr if the + default configuration should be used. * @return */ void MLASCALL -MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool); +MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SBGEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +); /** * @brief For bfloat16 precision GEMM, returns size of the * packing buffer needed for right hand side - * @param[in] N Number of columns - * @param[in] K Number of rows - * @return size of the packing buffer, - * 0 if operation not supported + * @param[in] TransA Supplies the transpose operation for matrix A. + * @param[in] TransB Supplies the transpose operation for matrix B. + * @param[in] BIsfp32 Is matrix B datatype FP32 + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector + configuration options, else nullptr if the + default configuration should be used. + * @return size of the packing buffer, + * 0 if operation not supported */ size_t MLASCALL -MlasSBGemmPackBSize(size_t N, size_t K); +MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +); /** * @brief For bfloat16 precision GEMM, convert the float matrix B * to blfoat16 precision and pack it into a packing buffer * - * @param[in] N Number of columns - * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B - * @param[out] PackedB Address of the packed matrix + * @param[in] TransA Supplies the transpose operation for matrix A. + * @param[in] TransB Supplies the transpose operation for matrix B. + * @param[in] BIsfp32 Is matrix B datatype FP32 + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + * @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector + configuration options, else nullptr if the + default configuration should be used. */ void MLASCALL -MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +MlasSBGemmConvertPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +); #endif /** diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp index 6a25447c43e09..6ee80594c6b49 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp @@ -30,9 +30,11 @@ #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" // SME2 kernels -// GEMM/QGEMM +// GEMM/QGEMM/SBGEMM #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" + // GEMV #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" @@ -227,6 +229,9 @@ const KaiF32IMatmulKernel imatmul_conv_sme = const KaiF32IMatmulKernel imatmul_conv_sme2 = KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa); +const KaiBF16SBgemmKernel sbgemm_gemm_sme2 = + KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa); + #if defined(ENABLE_QMX_KERNELS) const KaiF32IMatmulKernel imatmul_conv_qmx = KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_qmx_mopa); @@ -362,3 +367,8 @@ const KaiDynamicQGemmKernel& GetKleidiAIQGemmUKernel() { #endif // ENABLE_QMX_KERNELS } } + +const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel() { + // Currently only SME2 variant exists for bfloat16/SBGEMM kernel + return sbgemm_gemm_sme2; +} diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h index aa3be8b47cb88..155ecf1762b3b 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h @@ -12,6 +12,8 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h" + #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h" #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h" @@ -39,6 +41,8 @@ using KaiDynamicQGemmKernel = KaiMatmulKernel; +using KaiBF16SBgemmKernel = KaiMatmulKernel; + // Returns the selected Qnbit GEMM ukernel based on runtime CPU capabilities. const KaiQnbitGemmKernel& GetKleidiAIGemmUKernel(); @@ -56,3 +60,6 @@ const KaiF32SgemvKernel& GetKleidiAISGemvUKernel(); // Returns the selected FP32 IMATMUL ukernel used by the KleidiAI convolution implementation. const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel(); + +// Returns the selected BF16 SBGEMM ukernel used by the KleidiAI based on runtime CPU capabilities. +const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel(); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index f00efce7d5f88..a1aa241b89299 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -105,6 +105,42 @@ MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ); +#if defined(__aarch64__) && defined(__linux__) +size_t +MLASCALL +MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K + ); + +bool +MLASCALL +MlasSBGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); + +bool +MLASCALL +MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SBGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); +#endif + size_t MLASCALL MlasDynamicQGemmPackBSize( diff --git a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp new file mode 100644 index 0000000000000..f88af056aa156 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp @@ -0,0 +1,449 @@ +// +// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#if defined(__aarch64__) && defined(__linux__) + +#include +#include +#include +#include + +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" + +#include "mlas.h" + +#include "mlasi_kleidiai.h" +#include "kai_ukernel_interface.h" + +// Thread-local reusable buffers to reduce allocation overhead across tiles. +struct KaiTlsBuffersSbgemm { + std::vector output_tile; + std::vector bias_zero; + std::vector rhs_packed; + std::vector lhs_packed; +}; +static thread_local KaiTlsBuffersSbgemm g_kai_tls_sbgemm; + +const KaiBF16SBgemmKernel& sbgemm_gemm = GetKleidiAISBGemmUKernel(); + +/*++ +Routine Description: + Accumulate src into dst: dst[i,j] += src[i,j], respecting ldc. + +Arguments: + src - Pointer to the temporary A*B results (row-major, rows x cols). + rows - Number of rows in the tile. + cols - Number of columns in the tile. + dst - Pointer to the destination tile in C (row-major with leading dimension ldc). + ldc - Leading dimension of C (in elements). + +Notes: + Implements the accumulation path for SBGEMM when ZeroMode == false. +--*/ +static inline void AccumulateTile(const float* src, + size_t rows, + size_t cols, + float* dst, + size_t ldc) { + if (ldc == cols) { + // contiguous block in memory: add elementwise across whole block + size_t elems = rows * cols; + for (size_t i = 0; i < elems; ++i) { + dst[i] += src[i]; + } + } else { + // general case with row stride and a column offset + for (size_t i = 0; i < rows; ++i) { + const float* src_row = src + i * cols; + float* dst_row = dst + i * ldc; + for (size_t j = 0; j < cols; ++j) { + dst_row[j] += src_row[j]; + } + } + } +} + +/*++ +Routine Description: + Apply bias to a 2-D tile (rows x cols). + +Arguments: + src - Pointer to the temporary A*B results (row-major, rows x cols). + rows - Number of rows in the tile. + cols - Number of columns in the tile. + bias - Pointer to the bias vector or nullptr if no bias. + dst - Pointer to the destination tile in C (row-major with leading dimension ldc). + ldc - Leading dimension of C (in elements). + start_col - Starting column index of the tile (NIdx * n_step). + +Notes: + Uses a row by row memcpy path when no bias. +--*/ +static inline void ApplyBias2D(const float* src, + size_t rows, + size_t cols, + const float* bias, + float* dst, + size_t ldc, + size_t start_col) { + for (size_t i = 0; i < rows; ++i) { + const float* src_row = src + i * cols; + float* dst_row = dst + i * ldc; + + if (bias != nullptr) { + for (size_t j = 0; j < cols; ++j) { + dst_row[j] = src_row[j] + bias[start_col + j]; + } + } else { + // No bias but can't memcpy whole so needs to be done row by row. + memcpy(dst_row, src_row, cols * sizeof(float)); + } + } +} + +size_t +MLASCALL +ArmKleidiAI::MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K +) +/*++ + +Routine Description: + + This routine computes the length in bytes for the packed matrix B buffer. + +Arguments: + + TransA - Supplies the transpose operation on A matrix. + + TransB - Supplies the transpose operation on B matrix. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + +Return Value: + + Returns the size in bytes for the packed matrix B buffer. + +--*/ +{ + if (TransA != CblasNoTrans || TransB != CblasNoTrans || N == 0 || K == 0) { + KLEIDIAI_DEBUG_LOG("MlasSBGemmPackBSize returning 0 size. N=" << N << " K=" << K); + return 0; + } + // + // Compute the number of bytes required to hold the packed buffer. + // + size_t bytes = 0; + bytes = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(N, K); + + return bytes; +} + +bool +MLASCALL +ArmKleidiAI::MlasSBGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB +) +/*++ + +Routine Description: + + This routine packs the contents of matrix B to the destination buffer. The + destination buffer should be sized based on MlasSBGemmPackBSize(). For best + performance, the destination buffer should be aligned to the value returned + from MlasGetPreferredBufferAlignment(). + +Arguments: + + TransA - Supplies the transpose operation on A matrix. + + TransB - Supplies the transpose operation on B matrix. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + + B - Supplies the address of matrix B. + + ldb - Supplies the first dimension of matrix B. + + PackedB - Supplies the address of packed matrix B. + +Return Value: + + Returns true if the packing operation was handled by KleidiAI. + Returns false if the configuration requires a fallback to the default MLAS implementation. + +--*/ +{ + if (TransA != CblasNoTrans || TransB != CblasNoTrans || N == 0 || K == 0) { + KLEIDIAI_DEBUG_LOG("MlasSBGemmPackB one of N or K is 0, falling back to MLAS."); + return false; + } + + const size_t nr = sbgemm_gemm.ukernel.get_nr(); + const size_t kr = sbgemm_gemm.ukernel.get_kr(); + const size_t sr = sbgemm_gemm.ukernel.get_sr(); + + // Ensure size and zero the used span. + g_kai_tls_sbgemm.bias_zero.resize(N, 0.0f); + + kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls_sbgemm.bias_zero.data(), nullptr, PackedB, 0, nullptr); + + return true; +} + +bool +MLASCALL +ArmKleidiAI::MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SBGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool +) +/*++ + +Routine Description: + + This routine performs a bfloat16 batched matrix multiplication (SBGEMM) operation using KleidiAI kernels. + If packing is needed, it prepares the required buffers and invokes the + appropriate left-hand side (LHS) and right-hand side (RHS) pack functions. + +Arguments: + + TransA - Supplies the transpose operation on A matrix. + + TransB - Supplies the transpose operation on B matrix. + + M - Supplies the number of rows of matrix A and matrix C. + + N - Supplies the number of columns of matrix B and matrix C. + + K - Supplies the number of columns of matrix A and rows of matrix B. + + Data - Supplies a pointer to the MLAS_SBGEMM_DATA_PARAMS array containing per-batch input/output pointers and parameters. + + BatchSize - Supplies the number of independent GEMM computations to perform in the batch. + + ThreadPool - Supplies the thread pool to parallelize computation across batches and tiles. + +Return Value: + + Returns true if the GEMM operation was handled by KleidiAI. + Returns false if the configuration requires a fallback to the default MLAS implementation. + +--*/ +{ + if (TransA != CblasNoTrans || TransB != CblasNoTrans || K == 0) { + return false; + } + + if (M == 0 || N == 0 || BatchSize == 0) { + return true; + } + + size_t m_step = sbgemm_gemm.ukernel.get_m_step(); + size_t n_step = sbgemm_gemm.ukernel.get_n_step(); + + if ((M < m_step || N < n_step) && !Data->BIsPacked) { + // Fallback + return false; + } + + const size_t mr = sbgemm_gemm.ukernel.get_mr(); + const size_t kr = sbgemm_gemm.ukernel.get_kr(); + const size_t sr = sbgemm_gemm.ukernel.get_sr(); + + size_t LhsPackedStride = 0; + std::byte* LhsPackedData = nullptr; + + LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr); + + size_t lhs_resize = 0; + if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) + { + // size_t wraparound detected for LhsPackedStride, fallback to MLAS + return false; + } + + g_kai_tls_sbgemm.lhs_packed.resize(lhs_resize); + LhsPackedData = g_kai_tls_sbgemm.lhs_packed.data(); + + // RHS packed buffer: use TLS reusable vector to minimize allocations + size_t RhsPackedStride = 0; + std::byte* RhsPackedData = nullptr; + + // It is assumed all B batches require packing or not + if (Data[0].BIsPacked) { + // We have already decided the matmul variant we are using, before having values for M,N,K + MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_bf16p2vlx2_f32_sme" << " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr); + kai_run_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + }); + } else { + // Multithread pack lhs and rhs + RhsPackedStride = ArmKleidiAI::MlasSBGemmPackBSize(TransA, TransB, N, K); + size_t rhs_resize = 0; + if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) + { + // size_t wraparound detected for RhsPackedStride, fallback to MLAS + return false; + } + + g_kai_tls_sbgemm.rhs_packed.resize(rhs_resize); + RhsPackedData = g_kai_tls_sbgemm.rhs_packed.data(); + + MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) { + if (batch_idx & 0x1) { + batch_idx >>= 1; + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_bf16p2vlx2_f32_sme" + << " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr); + kai_run_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + } else { + batch_idx >>= 1; + std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]); + ArmKleidiAI::MlasSBGemmPackB(TransA, TransB, N, K, + reinterpret_cast(Data[batch_idx].B), + Data[batch_idx].ldb, RhsPackedPtr); + } + }); + } + + // tile iteration dimensions + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(M, m_step); // M + dim[2] = MlasDivRoundup(N, n_step); // N + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step); + + // update tile iterations + dim[1] = MlasDivRoundup(M, m_step); + dim[2] = MlasDivRoundup(N, n_step); + + // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop. + // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively. + size_t max_tile_elems = 0; + if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) { + // size_t wraparound detected for tile size, fallback to MLAS + return false; + } + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + // compute B,M,N index from iteration index + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = sbgemm_gemm.ukernel.get_rhs_packed_offset(NIdx * n_step, K); + + const std::byte* B_base = Data[0].BIsPacked + ? reinterpret_cast(Data[BIdx].B) + : (RhsPackedData + RhsPackedStride * BIdx); + auto BTile = reinterpret_cast(B_base + rhs_packed_offset); + + // Get lhs tile, A + const size_t lhs_packed_offset = sbgemm_gemm.ukernel.get_lhs_packed_offset(MIdx * m_step, K); + + const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx; + auto ATile = reinterpret_cast(A_base + lhs_packed_offset); + + auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = reinterpret_cast( + reinterpret_cast(Data[BIdx].C) + + MIdx * m_step * Data[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) + ); + + // Final output tile and bias pointers + float* dst_tile = reinterpret_cast(CTile); + const float* bias = Data[BIdx].Bias; + const size_t ldc = Data[BIdx].ldc; + + // Select output destination and strides once, then run_matmul exactly once. + const bool direct_to_c = ( + bias == nullptr && + Data[BIdx].ZeroMode && + TileSizeM != 0 && + TileSizeN != 0); + + float* out_tile = nullptr; + size_t out_row_stride_bytes = 0; + + if (direct_to_c) { + out_tile = dst_tile; + out_row_stride_bytes = ldc * sizeof(float); + } else { + // Compute into a temporary buffer for raw A*B result (TLS reusable buffer) + const size_t tile_elems = TileSizeM * TileSizeN; + g_kai_tls_sbgemm.output_tile.resize(tile_elems); + out_tile = g_kai_tls_sbgemm.output_tile.data(); + out_row_stride_bytes = TileSizeN * sizeof(float); + } + + KLEIDIAI_KERNEL_LOG(sbgemm_gemm.name + << " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K); + sbgemm_gemm.ukernel.run_matmul( + TileSizeM, + TileSizeN, + K, + ATile, BTile, out_tile, + out_row_stride_bytes, sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + + if (!direct_to_c) { + if (Data[BIdx].ZeroMode) { + ApplyBias2D(out_tile, TileSizeM, TileSizeN, bias, dst_tile, ldc, NIdx * n_step); + } else { + AccumulateTile(out_tile, TileSizeM, TileSizeN, dst_tile, ldc); + } + } + + if (Data[BIdx].OutputProcessor != nullptr) { + Data[BIdx].OutputProcessor->Process( + Data[BIdx].C, + MIdx * m_step, + NIdx * n_step, + TileSizeM, + TileSizeN, + Data[BIdx].ldc); + } + }); + return true; +} +#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 3c0ee29896cd9..954849fe90049 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -906,6 +906,39 @@ void const float* Bias, void* PackedB); +#if defined(__aarch64__) && defined(__linux__) +typedef +bool +(MLASCALL MLAS_SBGEMM_BATCH_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SBGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef +size_t +(MLASCALL MLAS_SBGEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef +bool +(MLASCALL MLAS_SBGEMM_PACK_B_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); +#endif + extern "C" { #if defined(MLAS_TARGET_AMD64_IX86) @@ -1364,6 +1397,13 @@ struct MLAS_PLATFORM { // MLAS Conv overrides MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; +#if defined(__aarch64__) && defined(__linux__) + // SBGemm overrides + MLAS_SBGEMM_BATCH_OVERRIDE* MlasSBGemmBatchOverride = nullptr; + MLAS_SBGEMM_PACK_B_SIZE_OVERRIDE* MlasSBGemmPackBSizeOverride = nullptr; + MLAS_SBGEMM_PACK_B_OVERRIDE* MlasSBGemmPackBOverride = nullptr; +#endif + #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 12dcd61b8840e..4e2c6cf1892e3 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -626,6 +626,14 @@ Return Value: this->MlasDynamicQGemmPackBOverride = ArmKleidiAI::MlasDynamicQGemmPackB; this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; this->MlasConvOverride = ArmKleidiAI::MlasConv; +#if defined(__aarch64__) && defined(__linux__) + // Currently only an SME2 variant of SBGEMM exists + if (ArmKleidiAI::UseSME2){ + this->MlasSBGemmBatchOverride = ArmKleidiAI::MlasSBGemmBatch; + this->MlasSBGemmPackBSizeOverride = ArmKleidiAI::MlasSBGemmPackBSize; + this->MlasSBGemmPackBOverride = ArmKleidiAI::MlasSBGemmPackB; + } +#endif } #endif diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp index f41b380b2a071..f3a17f30be166 100644 --- a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -44,6 +44,9 @@ MlasConvPointwiseBf16KernelNeon( const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; + mlas_backend_kernel_selector_config.use_kleidiai = ((KernelFlags & MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI) != 0); + const size_t StrideWidthElements = StrideWidth / sizeof(float); const size_t InputStrideElements = InputStride / sizeof(float); const size_t FilterStrideElements = FilterStride / sizeof(float); @@ -91,7 +94,7 @@ MlasConvPointwiseBf16KernelNeon( } } - MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr); + MlasSBGemmBatch(CblasNoTrans, CblasNoTrans, OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr, &mlas_backend_kernel_selector_config); if (ReluActivation) { const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index 5415cb3dc4406..559c8b48e78b6 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -299,11 +299,36 @@ MlasSBGemmGetDispatch() } size_t MLASCALL -MlasSBGemmPackBSize(size_t N, size_t K) +MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { // // Compute the number of bytes required to hold the packed buffer. // +#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr && + TransA == CBLAS_TRANSPOSE::CblasNoTrans && + TransB == CBLAS_TRANSPOSE::CblasNoTrans && + BIsfp32) { + size_t bytes_required; + bytes_required = GetMlasPlatform().MlasSBGemmPackBSizeOverride(TransA, TransB, N, K); + if (bytes_required != 0){ // If ArmKleidiAI::MlasSBGemmPackBSize ran to completion + return bytes_required; + } + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(BIsfp32); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + const auto* dispatch = MlasSBGemmGetDispatch(); if (dispatch == nullptr) return 0; @@ -322,8 +347,33 @@ MlasSBGemmPackBSize(size_t N, size_t K) } void MLASCALL -MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +MlasSBGemmConvertPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { +#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasSBGemmPackBOverride != nullptr && + TransA == CBLAS_TRANSPOSE::CblasNoTrans && + TransB == CBLAS_TRANSPOSE::CblasNoTrans && + BIsfp32 && + GetMlasPlatform().MlasSBGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(BIsfp32); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + const auto* dispatch = MlasSBGemmGetDispatch(); if (dispatch == nullptr) return; @@ -331,8 +381,33 @@ MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* Pac } void MLASCALL -MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SBGEMM_DATA_PARAMS* Data, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { +#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasSBGemmBatchOverride != nullptr && + TransA == CBLAS_TRANSPOSE::CblasNoTrans && + TransB == CBLAS_TRANSPOSE::CblasNoTrans && + Data->AIsfp32 && + (Data->BIsPacked || Data->BIsfp32) && + GetMlasPlatform().MlasSBGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchN, ThreadPool)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); if (dispatch == nullptr) return; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index e2846ebda8b9d..8a7795a81027d 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -137,10 +137,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const { #if defined(__aarch64__) && defined(__linux__) bool GemmPackBBfloat16(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, - TensorShape& b_shape) { + TensorShape& b_shape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. if (tensor_b.Shape().NumDimensions() != 2) { @@ -152,7 +154,12 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); - packed_b_size = MlasSBGemmPackBSize(N, K); + packed_b_size = MlasSBGemmPackBSize(trans_a ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + trans_b ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + true, + N, + K, + mlas_backend_kernel_selector_config); if (packed_b_size == 0) { return false; } @@ -164,11 +171,15 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, // buffer memory and we don not want it uninitialized and generate different hashes // if and when we try to cache this pre-packed buffer for sharing between sessions. memset(packed_b_data, 0, packed_b_size); - MlasSBGemmConvertPackB(N, + MlasSBGemmConvertPackB(trans_a ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + trans_b ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + true, + N, K, tensor_b.Data(), trans_b ? K : N, - packed_b_data); + packed_b_data, + mlas_backend_kernel_selector_config); return true; } #endif @@ -191,8 +202,8 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc dim2 = static_cast(b_shape[1]); } - if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { - is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + if (use_fastmath_mode_ && (trans_a_attr_ == 0) && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_, &mlas_backend_kernel_selector_config_); } else #endif { @@ -260,7 +271,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); #if defined(__aarch64__) && defined(__linux__) - if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + if (use_fastmath_mode_ && !trans_a && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { data[i].BIsfp32 = !(bool(packed_b_)); @@ -273,8 +284,10 @@ Status MatMul::Compute(OpKernelContext* ctx) const { data[i].ldc = N; data[i].Bias = nullptr; data[i].OutputProcessor = nullptr; + data[i].BIsPacked = static_cast(packed_b_); } - MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); + MlasSBGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, max_len, data.data(), thread_pool, &mlas_backend_kernel_selector_config_); } else #endif { diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp index f85fe97776dc1..2a8e01b9dda3a 100644 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -92,17 +92,62 @@ class SBGemmShortExecuteTest : public MlasTestFixture +class SBGemmAccumulateExecuteTest : public MlasTestFixture> { + public: + explicit SBGemmAccumulateExecuteTest(size_t M, size_t N, size_t K, size_t Batch) + : M_(M), N_(N), K_(K), Batch_(Batch) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->TestAccumulate(M_, N_, K_, Batch_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch) { + std::stringstream ss; + ss << "Accumulate/Batch" << Batch << "/M" << M << "xN" << N << "xK" << K; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSBGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SBGemmAccumulateExecuteTest(M, N, K, Batch); + }); + + return 1; + } + + static size_t RegisterAccumulateTests() { + size_t test_registered = 0; + test_registered += RegisterSingleTest(1, 1, 1, 1); + test_registered += RegisterSingleTest(7, 9, 13, 1); + test_registered += RegisterSingleTest(32, 32, 32, 1); + if (!Packed) { + test_registered += RegisterSingleTest(5, 7, 3, 4); + } + return test_registered; + } + + private: + size_t M_, N_, K_, Batch_; +}; + static size_t SBGemmRegistLongExecute() { size_t count = 0; count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += MlasLongExecuteTests>::RegisterLongExecute(); } if (GetMlasThreadPool() != nullptr) { count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += MlasLongExecuteTests>::RegisterLongExecute(); } } @@ -114,14 +159,18 @@ static size_t SBGemmRegistShortExecute() { size_t count = 0; count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); } if (GetMlasThreadPool() != nullptr) { count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); } } diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h index 13701e2e3de46..af5b1a34be8f0 100644 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.h +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -60,14 +60,15 @@ class MlasSBGemmTest : public MlasTestBase { MatrixGuardBuffer BufferFloatC; MLAS_THREADPOOL* threadpool_; - void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { - size_t PackedBSize = MlasSBGemmPackBSize(N, K); + void* PackB(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, const BType* B, size_t ldb) { + const bool BIsfp32 = std::is_same::value; + size_t PackedBSize = MlasSBGemmPackBSize(TransA, TransB, BIsfp32, N, K, nullptr); if (PackedBSize == 0) { return nullptr; } void* PackedB = BufferBPacked.GetBuffer(PackedBSize); if (std::is_same::value) { - MlasSBGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); + MlasSBGemmConvertPackB(TransA, TransB, BIsfp32, N, K, (const float*)B, ldb, PackedB, nullptr); } else { } return PackedB; @@ -83,7 +84,10 @@ class MlasSBGemmTest : public MlasTestBase { size_t ldb, const float* Bias, float* C, - size_t ldc) { + size_t ldc, + bool ZeroMode = true) { + constexpr CBLAS_TRANSPOSE TransA = CblasNoTrans; + constexpr CBLAS_TRANSPOSE TransB = CblasNoTrans; std::vector GemmParameters(BatchSize); for (size_t i = 0; i < GemmParameters.size(); i++) { @@ -99,10 +103,13 @@ class MlasSBGemmTest : public MlasTestBase { params.ldc = ldc; params.AIsfp32 = true; params.BIsfp32 = true; + params.BIsPacked = false; + params.ZeroMode = ZeroMode; if (Packed) { ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; - params.B = PackB(N, K, B, ldb); + params.B = PackB(TransA, TransB, N, K, B, ldb); + params.BIsPacked = true; params.ldb = 0; params.BIsfp32 = false; } else { @@ -111,7 +118,7 @@ class MlasSBGemmTest : public MlasTestBase { } } - MlasSBGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + MlasSBGemmBatch(TransA, TransB, M, N, K, BatchSize, GemmParameters.data(), threadpool_, nullptr); } void ReferenceSgemm(size_t M, @@ -186,12 +193,14 @@ class MlasSBGemmTest : public MlasTestBase { ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); const float cosine_similarity_threshold = 0.98; - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t batch = 0; batch < BatchSize; batch++) { for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { + for (size_t n = 0; n < N; n++) { + // Compute flat index to avoid desync if we break + const size_t f = batch * M * N + m * N + n; if (!(CloseEnough(float(C[f]), CReference[f]))) { float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); - if (abs(cos_sim) < cosine_similarity_threshold) { + if (std::abs(cos_sim) < cosine_similarity_threshold) { ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; } else { break; @@ -208,6 +217,81 @@ class MlasSBGemmTest : public MlasTestBase { } } + void TestAccumulate(size_t M, size_t N, size_t K, size_t BatchSize) { + AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + AType Atail[16]; + std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); + + BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + BType Btail[16]; + std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); + + const float* Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + float BiasTail[16]; + std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)); + + float* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, true); + ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); + + const float cosine_similarity_threshold = 0.98; + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + // Compute flat index to avoid desync if we break + const size_t f = batch * M * N + m * N + n; + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (std::abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + float* CNoBias = BufferFloatC.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + ReferenceSgemm(M, N, K, BatchSize, A, B, nullptr, CNoBias); + for (size_t i = 0, size = N * M * BatchSize; i < size; i++) { + CReference[i] += CNoBias[i]; + } + + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, false); + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + // Compute flat index to avoid desync if we break + const size_t f = batch * M * N + m * N + n; + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (std::abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; + ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; + ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)), 0) << "Bias buffer overwritten!"; + } + private: public: static const char* GetTestSuiteName() { From dcf23fb64a6eeb889e46e2f822aac6338a7e5689 Mon Sep 17 00:00:00 2001 From: Kevin Taha Date: Tue, 17 Mar 2026 10:57:10 -0700 Subject: [PATCH 2/6] Upgrade minimatch 3.1.2 to 3.1.4 (CVE-2026-27904) (#27667) Upgrading dependency to resolve CVE-2026-27904, which is lighting up some component governance issues with internal-MSFT builds of ORT. Co-authored-by: Kevin Taha Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- js/package-lock.json | 12 ++++++------ js/react_native/e2e/package-lock.json | 4 +++- js/react_native/package-lock.json | 4 +++- js/web/package-lock.json | 24 ++++++++++++------------ 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/js/package-lock.json b/js/package-lock.json index 0fca515b61238..22fb22757e94b 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4282,9 +4282,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "dependencies": { "brace-expansion": "^1.1.7" @@ -8760,9 +8760,9 @@ } }, "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "requires": { "brace-expansion": "^1.1.7" diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index 8e3f605deef7d..d8a273ef6825f 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -9895,7 +9895,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "license": "ISC", "dependencies": { "brace-expansion": "^1.1.7" diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index 3350e741a5632..6073725939e87 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -6302,7 +6302,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "license": "ISC", "dependencies": { diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 48877a85cf4c8..0e6d47f952c43 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1603,9 +1603,9 @@ } }, "node_modules/glob/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "dependencies": { "brace-expansion": "^1.1.7" @@ -2258,9 +2258,9 @@ } }, "node_modules/karma/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "dependencies": { "brace-expansion": "^1.1.7" @@ -4953,9 +4953,9 @@ } }, "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "requires": { "brace-expansion": "^1.1.7" @@ -5367,9 +5367,9 @@ } }, "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "requires": { "brace-expansion": "^1.1.7" From 672e3bbf52efb7268a36423b836761cc2cd8bcfe Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 17 Mar 2026 18:22:09 -0700 Subject: [PATCH 3/6] [webgpu] fix condition of DAWN_ENABLE_VULKAN and DAWN_ENABLE_D3D12 (#27694) ### Description fix condition of the following definitions: - DAWN_ENABLE_VULKAN - DAWN_ENABLE_D3D12 --- cmake/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a53fd24ea55a9..643abba135a70 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -992,10 +992,10 @@ if (onnxruntime_USE_WEBGPU) if (onnxruntime_USE_EXTERNAL_DAWN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1) endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + if ((onnxruntime_ENABLE_DAWN_BACKEND_VULKAN AND WIN32) OR LINUX OR ANDROID) list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_VULKAN=1) endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12 AND WIN32) list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1) endif() if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) From f773c90e2f28ed2d8f220a1246c888a715dbb8e4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 17 Mar 2026 20:04:25 -0700 Subject: [PATCH 4/6] [CUDA] DecoderMaskedMultiHeadAttention files consolidation (#27688) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This deletes 3 per-head-size .cu files and merges their content into a single file to avoid dependency during cuda compiling. Currently, masked_multihead_attention_kernel template is implemented in decoder_masked_multihead_attention_impl.cu‎. The other three .cu files use the masked_multihead_attention_kernel template but not include the implementation. That causes problem when they are built in cuda plugin ep. --- .../decoder_masked_multihead_attention_128.cu | 67 ------------------- .../decoder_masked_multihead_attention_32.cu | 67 ------------------- .../decoder_masked_multihead_attention_64.cu | 67 ------------------- ...decoder_masked_multihead_attention_impl.cu | 48 +++++++++++++ 4 files changed, 48 insertions(+), 201 deletions(-) delete mode 100644 onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu delete mode 100644 onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu delete mode 100644 onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu deleted file mode 100644 index d4e872f8ac165..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu +++ /dev/null @@ -1,67 +0,0 @@ -/* - * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer - * - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Modifications Copyright (c) Microsoft. -// Licensed under the MIT License. - -#include "decoder_masked_multihead_attention_impl.h" -#include "decoder_masked_multihead_attention_impl_utils.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -using namespace decoder_masked_self_attention_details; - -#define MMHA_LAUNCH_KERNEL( \ - T, QK, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ - size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { - constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; - int total_sequence_length = params.total_sequence_length; - - if (total_sequence_length < 32) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 4, THREADS_PER_VALUE, 64); - } else if (total_sequence_length < 2048) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 2, THREADS_PER_VALUE, 128); - } else { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 1, THREADS_PER_VALUE, 256); - } -} - -// Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu deleted file mode 100644 index 16f22b020ee1f..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu +++ /dev/null @@ -1,67 +0,0 @@ -/* - * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer - * - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Modifications Copyright (c) Microsoft. -// Licensed under the MIT License. - -#include "decoder_masked_multihead_attention_impl.h" -#include "decoder_masked_multihead_attention_impl_utils.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -using namespace decoder_masked_self_attention_details; - -#define MMHA_LAUNCH_KERNEL( \ - T, QK, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ - size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { - constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; - int total_sequence_length = params.total_sequence_length; - - if (total_sequence_length < 32) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 4, THREADS_PER_VALUE, 64); - } else if (total_sequence_length < 2048) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 2, THREADS_PER_VALUE, 128); - } else { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 1, THREADS_PER_VALUE, 256); - } -} - -// Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu deleted file mode 100644 index c933b0c6d2241..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu +++ /dev/null @@ -1,67 +0,0 @@ -/* - * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer - * - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Modifications Copyright (c) Microsoft. -// Licensed under the MIT License. - -#include "decoder_masked_multihead_attention_impl.h" -#include "decoder_masked_multihead_attention_impl_utils.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -using namespace decoder_masked_self_attention_details; - -#define MMHA_LAUNCH_KERNEL( \ - T, QK, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ - size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { - constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; - int total_sequence_length = params.total_sequence_length; - - if (total_sequence_length < 32) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 4, THREADS_PER_VALUE, 64); - } else if (total_sequence_length < 2048) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 2, THREADS_PER_VALUE, 128); - } else { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 1, THREADS_PER_VALUE, 256); - } -} - -// Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index efb48fee60772..b2fb743d28d92 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -819,6 +819,54 @@ template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +#define MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, HEAD_SIZE, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ + size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + masked_multihead_attention_kernel \ + <<>>(params) + +template +void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { + constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; + const int total_sequence_length = params.total_sequence_length; + + if (total_sequence_length < 32) { + MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, head_size, 4, THREADS_PER_VALUE, 64); + } else if (total_sequence_length < 2048) { + MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, head_size, 2, THREADS_PER_VALUE, 128); + } else { + MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, head_size, 1, THREADS_PER_VALUE, 256); + } +} + +#undef MMHA_LAUNCH_KERNEL_FOR_HEAD + +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); + +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); + +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime From f3cc7fff925b1bbcf01239cd99223d43423b9753 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 17 Mar 2026 20:14:33 -0700 Subject: [PATCH 5/6] Fix MLAS qgemm dispatch and kernel regressions in quantized conv tests (#27671) ## Description This PR fixes longstanding MLAS issues that were causing `NhwcTransformerTests.*` and `QDQTransformerTests.*` failures in quantized convolution paths (see https://github.com/microsoft/onnxruntime/issues/27670). The failures were not in the graph transformers themselves; they came from incorrect qgemm dispatch selection and broken backend kernel behavior in specific AVX2-VNNI and AMX paths. The fix removes incorrect `U8U8` dispatch upgrades, avoids a broken AVX2-VNNI row-panel fallback, and corrects the AMX `U8S8` 32-row kernel path. It also adds MLAS regression coverage for the conv-shaped qgemm dimensions that exposed the problems. ## Summary of Changes ### Dispatch Selection Fixes | File | Change | |------|--------| | `onnxruntime/core/mlas/lib/platform.cpp` | Remove three incorrect assignments that upgraded `GemmU8U8Dispatch` to `U8S8` dispatch objects in the AVXVNNI, AVX512VNNI, and AMX feature paths. | ### AVX2-VNNI Kernel Fix | File | Change | |------|--------| | `onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp` | Reduce `StrideM` from `6` to `4` for the `U8U8`, `S8S8`, and `S8U8` AVX2-VNNI qgemm dispatch objects so they never enter the legacy `>4` row fallback path. | ### AMX Kernel Fix | File | Change | |------|--------| | `onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp` | Replace the broken pipelined `CountM >= 32` `U8S8` AMX fast path with the same per-K tile update pattern already used by the working smaller-row path. | ### Regression Coverage | File | Change | |------|--------| | `onnxruntime/test/mlas/unittest/test_qgemm_fixture.h` | Add MLAS qgemm regression cases for conv-like shapes `6x30x207` and `169x30x207` in packed/non-packed and int32 or fp32 variants. | ## Root Cause There were three separate MLAS correctness issues: 1. `platform.cpp` was incorrectly overwriting `GemmU8U8Dispatch` with `U8S8` dispatch objects when newer CPU features were detected. That caused `U8U8` conv workloads to run through the wrong dispatch path. 2. The AVX2-VNNI qgemm dispatch objects advertised an `M` stride of `6`, but the assembly kernel only handled VNNI packing safely up to 4 rows. For 5- or 6-row panels it fell back to an older AVX2 path with incompatible packing and sign assumptions. 3. The AMX `U8S8` qgemm kernel had a bug in its `CountM >= 32` fast path. The smaller-row AMX path was correct, but the 32-row pipelined update logic produced wrong accumulators for conv-shaped workloads and caused the remaining QDQ/NHWC failures on AMX-capable hosts. ## Why This Fix - The `platform.cpp` cleanup restores the intended `U8U8` dispatch selection on feature-rich x86 hosts. - The AVX2-VNNI stride change is a targeted mitigation that avoids the known-bad legacy fallback until that assembly path is corrected. - The AMX kernel change keeps the AMX `U8S8` dispatch enabled, but replaces the broken 32-row implementation with a proven update pattern that matches the working smaller-row path. - The new MLAS regression tests cover the exact conv-derived qgemm shapes that exposed the bug, so future dispatch or kernel changes will fail at the MLAS layer before surfacing as transformer test regressions. ## Testing - `cd build/cuda/Release && ./onnxruntime_mlas_test --gtest_filter='QGemmU8S8_*169xN30xK207*:*QGemmU8S8_*6xN30xK207*'` - `cd build/cuda/Release && ./onnxruntime_test_all --gtest_filter='NhwcTransformerTests.*:QDQTransformerTests.*'` - Verified that the filtered transformer suite passes with AMX `U8S8` dispatch enabled. ## Motivation and Context These test failures had been present for a long time and were initially attributed to transformer rewrites because they surfaced in NHWC and QDQ test suites. Investigation showed that the optimized graphs were structurally correct and that the failures came from lower-level MLAS qgemm execution instead. Fixing the behavior in MLAS is the right layer because it restores correctness for both direct qgemm coverage and higher-level quantized conv paths. ## Checklist - [x] Tests added/updated - [x] No breaking changes - [x] CI passes --- onnxruntime/core/mlas/lib/platform.cpp | 3 - .../core/mlas/lib/qgemm_kernel_amx.cpp | 268 +++++------------- .../core/mlas/lib/qgemm_kernel_avx2.cpp | 6 +- .../test/mlas/unittest/test_qgemm_fixture.h | 6 + 4 files changed, 80 insertions(+), 203 deletions(-) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 4e2c6cf1892e3..ac3761d63bd20 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -445,7 +445,6 @@ Return Value: if ((Cpuid7_1[0] & 0x10) != 0) { - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; @@ -500,7 +499,6 @@ Return Value: if ((Cpuid7[2] & 0x800) != 0) { - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; @@ -540,7 +538,6 @@ Return Value: (Cpuid7[3] & 0b1 << 25) != 0 && (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) { if (MlasInitAMX()) { - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; } } diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp index 479a82e712c5e..bef8e1f800fd3 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp @@ -219,101 +219,6 @@ MlasGemmQuantThreadInit() } -static inline -void -InitHalfTileWithRowColSums( - int32_t* Tile, - const int32_t* rowsum_ptr, - const __m512i colsum, - const int32_t* c_ptr, - const size_t ldc, - bool ZeroMode - ) -{ - __m512i row0,row1,row2,row3,row4,row5,row6,row7; - row0 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[0])); - row1 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[1])); - row2 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[2])); - row3 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[3])); - row4 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[4])); - row5 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[5])); - row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6])); - row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7])); - if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); - } - _mm512_storeu_si512(Tile, row0); - _mm512_storeu_si512(Tile+16, row1); - _mm512_storeu_si512(Tile+32, row2); - _mm512_storeu_si512(Tile+48, row3); - _mm512_storeu_si512(Tile+64, row4); - _mm512_storeu_si512(Tile+80, row5); - _mm512_storeu_si512(Tile+96, row6); - _mm512_storeu_si512(Tile+112, row7); - //Tile += 128; - //rowsum_ptr+=8; - //c_ptr += ldc * 8; -} - -static inline -void -InitHalfTileWithRowColSumsZeroPoints( - int32_t* Tile, - const int32_t* rowsum_ptr, - const __m512i colsum, - const __m512i zeropoint, - const int32_t* c_ptr, - const size_t ldc, - bool ZeroMode - ) -{ - __m512i row0,row1,row2,row3,row4,row5,row6,row7; - row0 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[0])); - row1 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[1])); - row2 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[2])); - row3 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[3])); - row4 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[4])); - row5 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[5])); - row6 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[6])); - row7 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[7])); - row0 = _mm512_add_epi32(colsum, row0); - row1 = _mm512_add_epi32(colsum, row1); - row2 = _mm512_add_epi32(colsum, row2); - row3 = _mm512_add_epi32(colsum, row3); - row4 = _mm512_add_epi32(colsum, row4); - row5 = _mm512_add_epi32(colsum, row5); - row6 = _mm512_add_epi32(colsum, row6); - row7 = _mm512_add_epi32(colsum, row7); - if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); - } - _mm512_storeu_si512(Tile, row0); - _mm512_storeu_si512(Tile+16, row1); - _mm512_storeu_si512(Tile+32, row2); - _mm512_storeu_si512(Tile+48, row3); - _mm512_storeu_si512(Tile+64, row4); - _mm512_storeu_si512(Tile+80, row5); - _mm512_storeu_si512(Tile+96, row6); - _mm512_storeu_si512(Tile+112, row7); - //Tile += 128; - //rowsum_ptr+=8; - //c_ptr += ldc * 8; -} - static inline void @@ -636,178 +541,147 @@ MlasGemmQuantKernel( } - int32_t* c_blk = C; // C - beginning of the row + constexpr uint16_t FullMask = 0xFFFF; + int32_t* c_blk = C; int32_t* c16_blk = C + ldc * TILE_M; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* b_blk = B; // restart B + const MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* b_blk = B; const int32_t* col_sum_ptr = ColumnSumBuffer; const int32_t* zp_ptr = ZeroPointB; size_t n = CountN; for (; n >= 2 * TILE_N; n -= 2 * TILE_N) { - // Restart A from row start - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - - if (ZeroPointB != nullptr){ - __m512i colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); + col_sum_ptr += TILE_N; + if (ZeroPointB != nullptr) { __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; - tile_loadd(TMM0, b_blk, TILE_K); - InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - zeropoint = _mm512_loadu_si512(zp_ptr); + InitTileWithRowColSumsZeroPoints( + Tile4, TILE_M, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); + InitTileWithRowColSumsZeroPoints( + Tile5, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); + } else { + InitTileWithRowColSums(Tile4, TILE_M, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk, ldc); + InitTileWithRowColSums(Tile5, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); + } + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + + colsum = _mm512_loadu_si512(col_sum_ptr); + col_sum_ptr += TILE_N; + if (ZeroPointB != nullptr) { + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; - InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode); - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - tile_dpbusd(TMM4, TMM2, TMM0); - InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode); - InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); + InitTileWithRowColSumsZeroPoints( + Tile6, TILE_M, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); + InitTileWithRowColSumsZeroPoints( + Tile7, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); } else { - __m512i colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; + InitTileWithRowColSums(Tile6, TILE_M, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); + InitTileWithRowColSums(Tile7, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); + } + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + + const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; + const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; + for (size_t k = PackedCountK; k > 0; k -= TILE_K) { tile_loadd(TMM0, b_blk, TILE_K); - InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode); tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode); tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM4, TMM2, TMM0); - InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode); - InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); - } - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); - for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) { b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; - tile_dpbusd(TMM5, TMM3, TMM0); - tile_loadd(TMM0, b_blk, TILE_K); - tile_dpbusd(TMM6, TMM2, TMM1); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - tile_dpbusd(TMM7, TMM3, TMM1); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - tile_dpbusd(TMM4, TMM2, TMM0); } - tile_dpbusd(TMM5, TMM3, TMM0); - tile_dpbusd(TMM6, TMM2, TMM1); - tile_dpbusd(TMM7, TMM3, TMM1); - b_blk += PackedCountK * TILE_N + TILE_N * TILE_K; - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); c_blk += 2 * TILE_N; - tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); c16_blk += 2 * TILE_N; + b_blk += PackedCountK * TILE_N; } if (n != 0) { const uint16_t nmask_high = static_cast(nmasks >> 16); __m512i colsum = _mm512_maskz_loadu_epi32(static_cast(nmasks), col_sum_ptr); col_sum_ptr += TILE_N; - if (ZeroPointB != nullptr){ + if (ZeroPointB != nullptr) { __m512i zeropoint = _mm512_maskz_loadu_epi32(static_cast(nmasks), zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( - Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); InitTileWithRowColSumsZeroPoints( - Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); } else { InitTileWithRowColSums( - Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, - ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); InitTileWithRowColSums( - Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); } - if (nmask_high != 0){ + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + + if (nmask_high != 0) { colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr); - if (ZeroPointB != nullptr){ + if (ZeroPointB != nullptr) { __m512i zeropoint = _mm512_maskz_loadu_epi32(nmask_high, zp_ptr); InitTileWithRowColSumsZeroPoints( - Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); InitTileWithRowColSumsZeroPoints( - Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); } else { InitTileWithRowColSums( - Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, - ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); InitTileWithRowColSums( - Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); } + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - tile_loadd(TMM0, b_blk, TILE_K); + for (size_t k = PackedCountK; k > 0; k -= TILE_K) { + tile_loadd(TMM0, b_blk, TILE_K); tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM4, TMM2, TMM0); tile_dpbusd(TMM5, TMM3, TMM0); - if (nmask_high != 0){ + if (nmask_high != 0) { tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM6, TMM2, TMM1); tile_dpbusd(TMM7, TMM3, TMM1); - } + b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } - if ((static_cast(nmasks) & 0x8000) != 0){ - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + if ((static_cast(nmasks) & 0x8000) != 0) { + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); } else { - tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); - MoveTile(Tile4, TILE_M, static_cast(nmasks), c_blk, ldc); MoveTile(Tile5, TILE_M, static_cast(nmasks), c16_blk, ldc); } - if (nmask_high != 0){ - tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + if (nmask_high != 0) { + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc); MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc); } diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp index a6dbe8defd0e4..c3998dabd1d90 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp @@ -377,7 +377,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2Vnni = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedK, MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedStrides.K, - 6 // assembly kernel M stride + 4 // avoid the legacy >4-row fallback, which assumes non-VNNI packing }; // S8S8 AVX-VNNI-INT8 support @@ -450,7 +450,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchAvx2Vnni = { MlasGemmQuantCopyPackB, MLAS_GEMM_S8S8_KERNEL_AVX2::PackedK, MLAS_GEMM_S8S8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride + 4 // avoid the legacy >4-row fallback, which assumes non-VNNI packing }; // S8U8 AVX-VNNI-INT8 support @@ -523,5 +523,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8U8DispatchAvx2Vnni = { MlasGemmQuantCopyPackB, MLAS_GEMM_S8U8_KERNEL_AVX2::PackedK, MLAS_GEMM_S8U8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride + 4 // avoid the legacy >4-row fallback, which assumes non-VNNI packing }; diff --git a/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h index 40f688a16ecca..4170bcea6a1ea 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h @@ -94,6 +94,11 @@ class QgemmShortExecuteTest : public Ml test_registered += RegisterSingleTest(1, 32, b, 5, 0, 0); } } + // Conv-like qgemm shapes that exposed AMD64 dispatch bugs in the VNNI/AMX paths. + test_registered += RegisterSingleTest(6, 30, 207, 1, 183, 223); + test_registered += RegisterSingleTest(6, 30, 207, 1, 17); + test_registered += RegisterSingleTest(169, 30, 207, 1, 183, 223); + test_registered += RegisterSingleTest(169, 30, 207, 1, 17); test_registered += RegisterSingleTest(43, 500, 401, 1, 183, 223); test_registered += RegisterSingleTest(1023, 1023, 1023, 1, 5, 8); test_registered += RegisterSingleTest(1023, 1023, 1023, 1, 7); @@ -169,6 +174,7 @@ class QgemmShortExecuteTest : public Mlas for (size_t b = 1; b < 96; b++) { test_registered += RegisterSingleTest(1, b, 32, 0, 0); } + test_registered += RegisterSingleTest(169, 30, 207, 183, 223); test_registered += RegisterSingleTest(43, 503, 401, 183, 223); test_registered += RegisterSingleTest(1024, 1024, 256, 13, 15); From 8a08225c3b1c6fc4a29ad6ee6d7b50b8dd2a58d2 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 17 Mar 2026 20:29:02 -0700 Subject: [PATCH 6/6] [Build] Fix clang build issues for CPU and CUDA builds (#27669) ## Description This PR fixes clang-specific build failures that show up in both the standalone clang build and the CUDA clang build. It keeps the build-system changes targeted, prefers source fixes where the warnings indicate real type or declaration issues, and avoids broader warning suppression than necessary for the CUDA provider target. ## Summary of Changes ### Build System | File | Change | |------|--------| | `cmake/CMakeLists.txt` | Stop forwarding `-Wshorten-64-to-32` through CUDA host compilation where the GNU host compiler does not recognize it. | | `cmake/onnxruntime_providers_cuda.cmake` | Add targeted clang `-Wno-error` handling for warning classes that are currently triggered by CUDA provider code and third-party CUDA headers under clang. | ### CPU / Common clang fixes | File | Change | |------|--------| | `onnxruntime/core/common/cpuid_info.cc` | Replace the clang-incompatible `__builtin_cpu_supports("waitpkg")` path with the CPUID-bit check for TPAUSE detection. | | `onnxruntime/test/framework/allocation_planner_test.cc` | Refactor `typeid` assertions to avoid clang's potentially-evaluated-expression warning while keeping test coverage unchanged. | ### CUDA provider and contrib fixes | File | Change | |------|--------| | `onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h` | Mark the `IConsoleDumper` overrides explicitly while leaving CUDA-only overloads unchanged. | | `onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc` | Use `template` on the dependent `GetAttrOrDefault` call so clang parses it correctly. | | `onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc` | Make narrowing conversions to flash-attention parameter fields explicit. | | `onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc` | Make the `nbits_` conversion explicit when calling the CUDA helper. | | `onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc` | Restrict the GCC-only warning pragma so clang does not treat it as an unknown warning option. | | `onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc` | Fix explicit state-field assignments to use the actual `int` field type. | | `onnxruntime/core/providers/cuda/cuda_mempool_arena.h` | Remove an unused private field that clang flagged in the CUDA provider build. | ## Testing Tested CPU and CUDA 12.8 builds in Azure Linux with - clang 18.1.8 - gcc 13.2 - cmake 4.2.3 Example for CPU build: ``` export CC=clang export CXX=clang++ bash build.sh --config RelWithDebInfo --parallel --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=ON ``` ## Motivation and Context Clang is stricter than GCC/MSVC in a few areas that affect this tree: CUDA host flag forwarding, explicit narrowing, dependent template parsing, warnings emitted from third-party CUDA headers, and RTTI/typeid expressions in tests. The goal here is to keep the staged fix minimal and maintainable by correcting real source issues where practical and confining warning downgrades to the CUDA provider target where third-party header noise is currently unavoidable. --- cmake/CMakeLists.txt | 4 +- cmake/onnxruntime_providers_cuda.cmake | 34 ++++++ .../cuda/bert/flash_attention/flash_api.cc | 24 ++-- .../cuda/bert/group_query_attention.cc | 2 +- .../cuda/quantization/matmul_nbits.cc | 2 +- .../cuda/quantization/moe_quantization.cc | 4 +- .../transformers/generation_device_helper.cc | 8 +- .../contrib_ops/cuda/utils/dump_cuda_tensor.h | 18 ++- onnxruntime/core/common/cpuid_info.cc | 6 - .../core/providers/cuda/cuda_mempool_arena.h | 1 - .../test/framework/allocation_planner_test.cc | 110 ++++++++++-------- 11 files changed, 133 insertions(+), 80 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 643abba135a70..385342479913a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1189,7 +1189,9 @@ function(onnxruntime_set_compile_flags target_name) endif() if (onnxruntime_USE_CUDA) foreach(FLAG ${ORT_WARNING_FLAGS}) - target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") + if (NOT "${FLAG}" STREQUAL "-Wshorten-64-to-32") + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") + endif() endforeach() if (NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 4d1600b5fab7f..ff667d8f117e0 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -181,7 +181,9 @@ endif() foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) + if (NOT "${ORT_FLAG}" STREQUAL "-Wshorten-64-to-32") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") + endif() endforeach() # Note: The minimum required CUDA version is greater than 11.3. @@ -222,6 +224,38 @@ "$<$>:-Wno-reorder>") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "IBMClang") + foreach(CLANG_WARNING + braced-scalar-init + defaulted-function-deleted + inconsistent-missing-override + instantiation-after-specialization + logical-op-parentheses + mismatched-tags + shorten-64-to-32 + unneeded-internal-declaration + unknown-warning-option + unused-private-field + unused-variable) + target_compile_options(${target} PRIVATE "$<$>:-Wno-error=${CLANG_WARNING}>") + endforeach() + if (CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "Clang" OR CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "IBMClang") + foreach(CLANG_WARNING + braced-scalar-init + defaulted-function-deleted + inconsistent-missing-override + instantiation-after-specialization + logical-op-parentheses + mismatched-tags + shorten-64-to-32 + unneeded-internal-declaration + unknown-warning-option + unused-private-field + unused-variable) + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=${CLANG_WARNING}>") + endforeach() + endif() + endif() else() #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index fcc470b19a7b4..0d994d4060e6c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -106,16 +106,16 @@ void set_params_fprop(Flash_fwd_params& params, #pragma warning(disable : 4267) // Ignore conversion from 'size_t' to 'int', possible loss of data #pragma warning(disable : 4244) // Ignore conversion from 'double' to 'float', possible loss of data #endif - params.b = batch_size; - params.h = num_heads; - params.h_k = num_heads_k; - params.h_h_k_ratio = num_heads / num_heads_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = head_size; - params.d_rounded = head_size_rounded; + params.b = static_cast(batch_size); + params.h = static_cast(num_heads); + params.h_k = static_cast(num_heads_k); + params.h_h_k_ratio = static_cast(num_heads / num_heads_k); + params.seqlen_q = static_cast(seqlen_q); + params.seqlen_k = static_cast(seqlen_k); + params.seqlen_q_rounded = static_cast(seqlen_q_rounded); + params.seqlen_k_rounded = static_cast(seqlen_k_rounded); + params.d = static_cast(head_size); + params.d_rounded = static_cast(head_size_rounded); // Set the different scale values. if (softcap > 0.0) { @@ -136,10 +136,10 @@ void set_params_fprop(Flash_fwd_params& params, params.is_causal = false; } if (window_size_left < 0 && window_size_right >= 0) { - window_size_left = seqlen_k; + window_size_left = static_cast(seqlen_k); } if (window_size_left >= 0 && window_size_right < 0) { - window_size_right = seqlen_k; + window_size_right = static_cast(seqlen_k); } #if defined(_MSC_VER) #pragma warning(pop) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index a965e00f6a391..12c594375d2ce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -238,7 +238,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + static_cast(Info().template GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 26bed32e3ceb1..3b667bf68634c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -377,7 +377,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { if (TryMatMulNBits( - nbits_, + static_cast(nbits_), reinterpret_cast(Y->MutableData()), reinterpret_cast(a_data), blob_data, diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 79b25ba91ebbb..87ff5ecbfe18c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -162,7 +162,7 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, block_size_)); -#if defined(__GNUC__) +#if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters. #endif @@ -183,7 +183,7 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { GetDeviceProp()); } -#if defined(__GNUC__) +#if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic pop #endif } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index a3781c8e6cfa3..e2e60066ec36d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -684,10 +684,10 @@ CudaBeamSearchScorer::CudaBeamSearchScorer(const transformers::IGenerationParame CUDA_CALL_THROW(cudaEventCreate(&event_process_complete_.Get())); state_cpu_ = AllocateCPUPinned(); - state_cpu_->batch_size_ = static_cast(parameters.batch_size); - state_cpu_->num_beams_ = static_cast(parameters.num_beams); - state_cpu_->max_length_ = static_cast(parameters.max_length); - state_cpu_->num_return_sequences_ = static_cast(parameters.num_return_sequences); + state_cpu_->batch_size_ = static_cast(parameters.batch_size); + state_cpu_->num_beams_ = static_cast(parameters.num_beams); + state_cpu_->max_length_ = static_cast(parameters.max_length); + state_cpu_->num_return_sequences_ = static_cast(parameters.num_return_sequences); state_cpu_->pad_token_id_ = parameters.pad_token_id; state_cpu_->eos_token_id_ = parameters.eos_token_id; state_cpu_->early_stopping_ = parameters.early_stopping; diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 3937ce3948de9..fe3b5a5f81bfb 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -20,11 +20,11 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { void Print(const char* name, const OrtValue& value) const override; void Print(const std::string& value) const override; -#define CUDA_DUMPER_PRINT_TYPE(dtype) \ - void Print(const char* name, const dtype* tensor, int dim0, int dim1) const; \ - void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const; \ - void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const; \ - void Print(const char* name, const dtype* tensor, gsl::span& dims) const; +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const override; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const override; CUDA_DUMPER_PRINT_TYPE(int8_t) CUDA_DUMPER_PRINT_TYPE(uint8_t) @@ -35,6 +35,14 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { CUDA_DUMPER_PRINT_TYPE(BFloat16) CUDA_DUMPER_PRINT_TYPE(UInt4x2) CUDA_DUMPER_PRINT_TYPE(Int4x2) +#undef CUDA_DUMPER_PRINT_TYPE + +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const; + CUDA_DUMPER_PRINT_TYPE(half) CUDA_DUMPER_PRINT_TYPE(__nv_bfloat16) #undef CUDA_DUMPER_PRINT_TYPE diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 00711e416e4e3..5990013c925c5 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -162,13 +162,7 @@ void CPUIDInfo::X86Init() { // Check for TPAUSE CheckIntelResult check_intel = CheckIntel(); if (check_intel.is_intel) { -#ifdef __linux__ -#if !defined(__ANDROID__) - has_tpause_ = __builtin_cpu_supports("waitpkg") != 0; -#endif -#else has_tpause_ = (data[2] & (1 << 5)) != 0; -#endif } if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.h b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h index ec0e69bd33f3f..98f6abb0dd071 100644 --- a/onnxruntime/core/providers/cuda/cuda_mempool_arena.h +++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h @@ -155,7 +155,6 @@ class CudaMempoolArena final : public IArena { // ---- Pool/context configuration (immutable) ---- uint64_t pool_release_threshold_; size_t bytes_to_keep_on_shrink_; - size_t initial_pool_size_bytes_; const logging::Logger* logger_; cudaMemPool_t pool_{nullptr}; diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index c3d8f2631feb4..31eb7b7fad43b 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1261,6 +1261,21 @@ TEST_F(PlannerTest, LocationPlanningForImplicitInputsWithoutExplicitConsumersInM // EXPECT_EQ(para_graph_plan->allocation_plan[input_data_index].location.device.Type(), OrtDevice::GPU); } +void ExpectExecutionStepTypeContains(const SessionState& state, + size_t stream_idx, + size_t step_idx, + const char* expected_type_name, + const char* message) { + const auto* execution_plan = state.GetExecutionPlan(); + ASSERT_NE(execution_plan, nullptr) << message; + ASSERT_LT(stream_idx, execution_plan->execution_plan.size()) << message; + const auto& steps = execution_plan->execution_plan[stream_idx]->steps_; + ASSERT_LT(step_idx, steps.size()) << message; + const auto* step = steps[step_idx].get(); + ASSERT_NE(step, nullptr) << message; + EXPECT_NE(strstr(typeid(*step).name(), expected_type_name), nullptr) << message; +} + // Test MultiStream scenario for the graph: // node1(CPU ep)->node2(CPU ep)->node3(CUDA ep)->node4(CPU ep) TEST_F(PlannerTest, MultiStream) { @@ -1288,18 +1303,18 @@ TEST_F(PlannerTest, MultiStream) { EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams for CPU and CUDA separately"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 6) << "CPU stream has 6 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "LaunchKernelStep"), nullptr) << "1st step: LaunchKernelStep for node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3, no Activate/Wait step between node 2 and node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[3]).name(), "BarrierStep"), nullptr) << "3rd step: BarrierStep for node 4"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[4]).name(), "WaitOnEPStep"), nullptr) << "4th step: WaitOnEPStep for node 4"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[5]).name(), "LaunchKernelStep"), nullptr) << "5th step: LaunchKernelStep for node 4"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "LaunchKernelStep", "1st step: LaunchKernelStep for node 2"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3, no Activate/Wait step between node 2 and node 3"); + ExpectExecutionStepTypeContains(GetState(), 0, 3, "BarrierStep", "3rd step: BarrierStep for node 4"); + ExpectExecutionStepTypeContains(GetState(), 0, 4, "WaitOnEPStep", "4th step: WaitOnEPStep for node 4"); + ExpectExecutionStepTypeContains(GetState(), 0, 5, "LaunchKernelStep", "5th step: LaunchKernelStep for node 4"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 4) << "CUDA stream has 4 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "LaunchKernelStep"), nullptr) << "1st step: LaunchKernelStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "ActivateNotificationStep"), nullptr) << "2nd step: ActivateNofiticationStep by node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[3]).name(), "TriggerDownstreamStep"), nullptr) << "3rd step: TriggerDownstreamStep for node 4"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "LaunchKernelStep", "1st step: LaunchKernelStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "ActivateNotificationStep", "2nd step: ActivateNofiticationStep by node 3"); + ExpectExecutionStepTypeContains(GetState(), 1, 3, "TriggerDownstreamStep", "3rd step: TriggerDownstreamStep for node 4"); } // Test execution plan for the graph: @@ -1328,21 +1343,21 @@ TEST_F(PlannerTest, MultiStream1StreamWaitFor2Streams) { EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 3) << "3 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 3) << "stream 0 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 3) << "stream 1 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 2"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 2"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[2]->steps_.size(), 5) << "stream 2 has 5 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 3, for TriggerDownstreamStep in stream 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[1]).name(), "BarrierStep"), nullptr) << "1st step: BarrierStep for node 3, for TriggerDownstreamStep in stream 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[2]).name(), "WaitOnEPStep"), nullptr) << "2nd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[3]).name(), "WaitOnEPStep"), nullptr) << "3rd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[4]).name(), "LaunchKernelStep"), nullptr) << "4th step: LaunchKernelStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 2, 0, "BarrierStep", "0th step: BarrierStep for node 3, for TriggerDownstreamStep in stream 1"); + ExpectExecutionStepTypeContains(GetState(), 2, 1, "BarrierStep", "1st step: BarrierStep for node 3, for TriggerDownstreamStep in stream 2"); + ExpectExecutionStepTypeContains(GetState(), 2, 2, "WaitOnEPStep", "2nd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 1"); + ExpectExecutionStepTypeContains(GetState(), 2, 3, "WaitOnEPStep", "3rd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 2"); + ExpectExecutionStepTypeContains(GetState(), 2, 4, "LaunchKernelStep", "4th step: LaunchKernelStep for node 3"); } // Test execution plan for the graph: @@ -1353,16 +1368,16 @@ TEST_F(PlannerTest, MultiStreamCudaEPNodeCPUOutput) { MemcpyToHostInCuda_TransposeInCudaAndCpu("./testdata/multi_stream_models/memcpyToHost_same_stream_with_transpose.json"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 5) << "stream 0 has 5 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[3]).name(), "WaitOnEPStep"), nullptr) << "3rd step: WaitOnEPStep for node 3 in the same stream, as node 1's output is to CPU"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[4]).name(), "LaunchKernelStep"), nullptr) << "4th step: LaunchKernelStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 0, 3, "WaitOnEPStep", "3rd step: WaitOnEPStep for node 3 in the same stream, as node 1's output is to CPU"); + ExpectExecutionStepTypeContains(GetState(), 0, 4, "LaunchKernelStep", "4th step: LaunchKernelStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 3) << "stream 1 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "WaitOnEPStep"), nullptr) << "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "LaunchKernelStep"), nullptr) << "2nd step: LaunchKernelStep for node 2"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "WaitOnEPStep", "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "LaunchKernelStep", "2nd step: LaunchKernelStep for node 2"); } // Test execution plan for the graph: @@ -1389,14 +1404,14 @@ TEST_F(PlannerTest, MultiStreamMultiOutput) { EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 3) << "stream 0 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 2"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 2"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 3) << "stream 1 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "WaitOnEPStep"), nullptr) << "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "LaunchKernelStep"), nullptr) << "2nd step: LaunchKernelStep for node 2"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "WaitOnEPStep", "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "LaunchKernelStep", "2nd step: LaunchKernelStep for node 2"); } // Test execution plan for the graph: @@ -1427,19 +1442,19 @@ TEST_F(PlannerTest, MultiStream2NodesSameStreamConsumedBy1NodeInDifferentStream) EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 6) << "stream 0 has 6 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[3]).name(), "LaunchKernelStep"), nullptr) << "3rd step: LaunchKernelStep for node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[4]).name(), "ActivateNotificationStep"), nullptr) << "4th step: ActivateNofiticationStep by node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[5]).name(), "TriggerDownstreamStep"), nullptr) << "5th step: TriggerDownstreamStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 0, 3, "LaunchKernelStep", "3rd step: LaunchKernelStep for node 2"); + ExpectExecutionStepTypeContains(GetState(), 0, 4, "ActivateNotificationStep", "4th step: ActivateNofiticationStep by node 2"); + ExpectExecutionStepTypeContains(GetState(), 0, 5, "TriggerDownstreamStep", "5th step: TriggerDownstreamStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 5) << "stream 1 has 5 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 1, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "BarrierStep"), nullptr) << "1st step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "WaitOnEPStep"), nullptr) << "2nd step: WaitOnEPStep for node 1, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[3]).name(), "WaitOnEPStep"), nullptr) << "3rd step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[4]).name(), "LaunchKernelStep"), nullptr) << "4th step: LaunchKernelStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 1, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "BarrierStep", "1st step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "WaitOnEPStep", "2nd step: WaitOnEPStep for node 1, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 3, "WaitOnEPStep", "3rd step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 4, "LaunchKernelStep", "4th step: LaunchKernelStep for node 3"); } #endif @@ -2078,7 +2093,8 @@ TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { int gather_count = 0; ASSERT_GT(plan->execution_plan.size(), 1) << "Number of execution plans should be greater than 1"; for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { - if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { + const auto* step = plan->execution_plan[1]->steps_[i].get(); + if (strstr(typeid(*step).name(), "LaunchKernelStep")) { const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex()); if (node->OpType() == "Gather") gather_count++;