diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a53fd24ea55a9..385342479913a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -992,10 +992,10 @@ if (onnxruntime_USE_WEBGPU) if (onnxruntime_USE_EXTERNAL_DAWN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_EXTERNAL_DAWN=1) endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) + if ((onnxruntime_ENABLE_DAWN_BACKEND_VULKAN AND WIN32) OR LINUX OR ANDROID) list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_VULKAN=1) endif() - if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12) + if (onnxruntime_ENABLE_DAWN_BACKEND_D3D12 AND WIN32) list(APPEND ORT_PROVIDER_FLAGS -DDAWN_ENABLE_D3D12=1) endif() if (onnxruntime_ENABLE_PIX_FOR_WEBGPU_EP) @@ -1189,7 +1189,9 @@ function(onnxruntime_set_compile_flags target_name) endif() if (onnxruntime_USE_CUDA) foreach(FLAG ${ORT_WARNING_FLAGS}) - target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") + if (NOT "${FLAG}" STREQUAL "-Wshorten-64-to-32") + target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") + endif() endforeach() if (NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f1b3b091bbc6e..5f1e166057b64 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -284,6 +284,7 @@ function(setup_kleidiai) target_sources(onnxruntime_mlas PRIVATE ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/sbgemm_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp ) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 4d1600b5fab7f..ff667d8f117e0 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -181,7 +181,9 @@ endif() foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) + if (NOT "${ORT_FLAG}" STREQUAL "-Wshorten-64-to-32") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") + endif() endforeach() # Note: The minimum required CUDA version is greater than 11.3. @@ -222,6 +224,38 @@ "$<$>:-Wno-reorder>") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "IBMClang") + foreach(CLANG_WARNING + braced-scalar-init + defaulted-function-deleted + inconsistent-missing-override + instantiation-after-specialization + logical-op-parentheses + mismatched-tags + shorten-64-to-32 + unneeded-internal-declaration + unknown-warning-option + unused-private-field + unused-variable) + target_compile_options(${target} PRIVATE "$<$>:-Wno-error=${CLANG_WARNING}>") + endforeach() + if (CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "Clang" OR CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CUDA_HOST_COMPILER_ID STREQUAL "IBMClang") + foreach(CLANG_WARNING + braced-scalar-init + defaulted-function-deleted + inconsistent-missing-override + instantiation-after-specialization + logical-op-parentheses + mismatched-tags + shorten-64-to-32 + unneeded-internal-declaration + unknown-warning-option + unused-private-field + unused-variable) + target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=${CLANG_WARNING}>") + endforeach() + endif() + endif() else() #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") diff --git a/js/package-lock.json b/js/package-lock.json index 0fca515b61238..22fb22757e94b 100644 --- a/js/package-lock.json +++ b/js/package-lock.json @@ -4282,9 +4282,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "dependencies": { "brace-expansion": "^1.1.7" @@ -8760,9 +8760,9 @@ } }, "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "requires": { "brace-expansion": "^1.1.7" diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index 8e3f605deef7d..d8a273ef6825f 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -9895,7 +9895,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "license": "ISC", "dependencies": { "brace-expansion": "^1.1.7" diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index 3350e741a5632..6073725939e87 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -6302,7 +6302,9 @@ } }, "node_modules/minimatch": { - "version": "3.1.2", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "license": "ISC", "dependencies": { diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 48877a85cf4c8..0e6d47f952c43 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1603,9 +1603,9 @@ } }, "node_modules/glob/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "dependencies": { "brace-expansion": "^1.1.7" @@ -2258,9 +2258,9 @@ } }, "node_modules/karma/node_modules/minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "dependencies": { "brace-expansion": "^1.1.7" @@ -4953,9 +4953,9 @@ } }, "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "requires": { "brace-expansion": "^1.1.7" @@ -5367,9 +5367,9 @@ } }, "minimatch": { - "version": "3.1.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz", - "integrity": "sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==", + "version": "3.1.4", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.4.tgz", + "integrity": "sha512-twmL+S8+7yIsE9wsqgzU3E8/LumN3M3QELrBZ20OdmQ9jB2JvW5oZtBEmft84k/Gs5CG9mqtWc6Y9vW+JEzGxw==", "dev": true, "requires": { "brace-expansion": "^1.1.7" diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu deleted file mode 100644 index d4e872f8ac165..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_128.cu +++ /dev/null @@ -1,67 +0,0 @@ -/* - * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer - * - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Modifications Copyright (c) Microsoft. -// Licensed under the MIT License. - -#include "decoder_masked_multihead_attention_impl.h" -#include "decoder_masked_multihead_attention_impl_utils.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -using namespace decoder_masked_self_attention_details; - -#define MMHA_LAUNCH_KERNEL( \ - T, QK, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ - size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { - constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; - int total_sequence_length = params.total_sequence_length; - - if (total_sequence_length < 32) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 4, THREADS_PER_VALUE, 64); - } else if (total_sequence_length < 2048) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 2, THREADS_PER_VALUE, 128); - } else { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 1, THREADS_PER_VALUE, 256); - } -} - -// Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu deleted file mode 100644 index 16f22b020ee1f..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_32.cu +++ /dev/null @@ -1,67 +0,0 @@ -/* - * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer - * - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Modifications Copyright (c) Microsoft. -// Licensed under the MIT License. - -#include "decoder_masked_multihead_attention_impl.h" -#include "decoder_masked_multihead_attention_impl_utils.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -using namespace decoder_masked_self_attention_details; - -#define MMHA_LAUNCH_KERNEL( \ - T, QK, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ - size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { - constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; - int total_sequence_length = params.total_sequence_length; - - if (total_sequence_length < 32) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 4, THREADS_PER_VALUE, 64); - } else if (total_sequence_length < 2048) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 2, THREADS_PER_VALUE, 128); - } else { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 1, THREADS_PER_VALUE, 256); - } -} - -// Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu deleted file mode 100644 index c933b0c6d2241..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_64.cu +++ /dev/null @@ -1,67 +0,0 @@ -/* - * The implementation of this file is based on code provided by https://github.com/NVIDIA/FasterTransformer - * - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Modifications Copyright (c) Microsoft. -// Licensed under the MIT License. - -#include "decoder_masked_multihead_attention_impl.h" -#include "decoder_masked_multihead_attention_impl_utils.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -using namespace decoder_masked_self_attention_details; - -#define MMHA_LAUNCH_KERNEL( \ - T, QK, head_size, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ - size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - masked_multihead_attention_kernel \ - <<>>(params) - -template -void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { - constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; - int total_sequence_length = params.total_sequence_length; - - if (total_sequence_length < 32) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 4, THREADS_PER_VALUE, 64); - } else if (total_sequence_length < 2048) { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 2, THREADS_PER_VALUE, 128); - } else { - MMHA_LAUNCH_KERNEL(T, QK, head_size, 1, THREADS_PER_VALUE, 256); - } -} - -// Instantiate templates -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); -template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters&, cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index efb48fee60772..b2fb743d28d92 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -819,6 +819,54 @@ template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); template void __global__ masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentionParameters params); +#define MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, HEAD_SIZE, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK) \ + size_t dynamic_block_memory = CalcDynamicBlockMemory(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + masked_multihead_attention_kernel \ + <<>>(params) + +template +void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream) { + constexpr int THREADS_PER_VALUE = ThreadsPerValue::value; + const int total_sequence_length = params.total_sequence_length; + + if (total_sequence_length < 32) { + MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, head_size, 4, THREADS_PER_VALUE, 64); + } else if (total_sequence_length < 2048) { + MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, head_size, 2, THREADS_PER_VALUE, 128); + } else { + MMHA_LAUNCH_KERNEL_FOR_HEAD(T, QK, head_size, 1, THREADS_PER_VALUE, 256); + } +} + +#undef MMHA_LAUNCH_KERNEL_FOR_HEAD + +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); + +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); + +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); +template void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParameters& params, cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index fcc470b19a7b4..0d994d4060e6c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -106,16 +106,16 @@ void set_params_fprop(Flash_fwd_params& params, #pragma warning(disable : 4267) // Ignore conversion from 'size_t' to 'int', possible loss of data #pragma warning(disable : 4244) // Ignore conversion from 'double' to 'float', possible loss of data #endif - params.b = batch_size; - params.h = num_heads; - params.h_k = num_heads_k; - params.h_h_k_ratio = num_heads / num_heads_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = head_size; - params.d_rounded = head_size_rounded; + params.b = static_cast(batch_size); + params.h = static_cast(num_heads); + params.h_k = static_cast(num_heads_k); + params.h_h_k_ratio = static_cast(num_heads / num_heads_k); + params.seqlen_q = static_cast(seqlen_q); + params.seqlen_k = static_cast(seqlen_k); + params.seqlen_q_rounded = static_cast(seqlen_q_rounded); + params.seqlen_k_rounded = static_cast(seqlen_k_rounded); + params.d = static_cast(head_size); + params.d_rounded = static_cast(head_size_rounded); // Set the different scale values. if (softcap > 0.0) { @@ -136,10 +136,10 @@ void set_params_fprop(Flash_fwd_params& params, params.is_causal = false; } if (window_size_left < 0 && window_size_right >= 0) { - window_size_left = seqlen_k; + window_size_left = static_cast(seqlen_k); } if (window_size_left >= 0 && window_size_right < 0) { - window_size_right = seqlen_k; + window_size_right = static_cast(seqlen_k); } #if defined(_MSC_VER) #pragma warning(pop) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index a965e00f6a391..12c594375d2ce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -238,7 +238,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) cons // GQA CUDA uses either flash attention or memory efficient attention. Neither kernel supports returning the QK output. ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckNoQKOutput( context->OutputCount(), - static_cast(Info().GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); + static_cast(Info().template GetAttrOrDefault("qk_output", static_cast(QKOutputType::NO_OUTPUT))))); if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 26bed32e3ceb1..3b667bf68634c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -377,7 +377,7 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { if (TryMatMulNBits( - nbits_, + static_cast(nbits_), reinterpret_cast(Y->MutableData()), reinterpret_cast(a_data), blob_data, diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 79b25ba91ebbb..87ff5ecbfe18c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -162,7 +162,7 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, block_size_)); -#if defined(__GNUC__) +#if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters. #endif @@ -183,7 +183,7 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { GetDeviceProp()); } -#if defined(__GNUC__) +#if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic pop #endif } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index a3781c8e6cfa3..e2e60066ec36d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -684,10 +684,10 @@ CudaBeamSearchScorer::CudaBeamSearchScorer(const transformers::IGenerationParame CUDA_CALL_THROW(cudaEventCreate(&event_process_complete_.Get())); state_cpu_ = AllocateCPUPinned(); - state_cpu_->batch_size_ = static_cast(parameters.batch_size); - state_cpu_->num_beams_ = static_cast(parameters.num_beams); - state_cpu_->max_length_ = static_cast(parameters.max_length); - state_cpu_->num_return_sequences_ = static_cast(parameters.num_return_sequences); + state_cpu_->batch_size_ = static_cast(parameters.batch_size); + state_cpu_->num_beams_ = static_cast(parameters.num_beams); + state_cpu_->max_length_ = static_cast(parameters.max_length); + state_cpu_->num_return_sequences_ = static_cast(parameters.num_return_sequences); state_cpu_->pad_token_id_ = parameters.pad_token_id; state_cpu_->eos_token_id_ = parameters.eos_token_id; state_cpu_->early_stopping_ = parameters.early_stopping; diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h index 3937ce3948de9..fe3b5a5f81bfb 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.h @@ -20,11 +20,11 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { void Print(const char* name, const OrtValue& value) const override; void Print(const std::string& value) const override; -#define CUDA_DUMPER_PRINT_TYPE(dtype) \ - void Print(const char* name, const dtype* tensor, int dim0, int dim1) const; \ - void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const; \ - void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const; \ - void Print(const char* name, const dtype* tensor, gsl::span& dims) const; +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const override; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const override; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const override; CUDA_DUMPER_PRINT_TYPE(int8_t) CUDA_DUMPER_PRINT_TYPE(uint8_t) @@ -35,6 +35,14 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::IConsoleDumper { CUDA_DUMPER_PRINT_TYPE(BFloat16) CUDA_DUMPER_PRINT_TYPE(UInt4x2) CUDA_DUMPER_PRINT_TYPE(Int4x2) +#undef CUDA_DUMPER_PRINT_TYPE + +#define CUDA_DUMPER_PRINT_TYPE(dtype) \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2) const; \ + void Print(const char* name, const dtype* tensor, int dim0, int dim1, int dim2, int dim3) const; \ + void Print(const char* name, const dtype* tensor, gsl::span& dims) const; + CUDA_DUMPER_PRINT_TYPE(half) CUDA_DUMPER_PRINT_TYPE(__nv_bfloat16) #undef CUDA_DUMPER_PRINT_TYPE diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 00711e416e4e3..5990013c925c5 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -162,13 +162,7 @@ void CPUIDInfo::X86Init() { // Check for TPAUSE CheckIntelResult check_intel = CheckIntel(); if (check_intel.is_intel) { -#ifdef __linux__ -#if !defined(__ANDROID__) - has_tpause_ = __builtin_cpu_supports("waitpkg") != 0; -#endif -#else has_tpause_ = (data[2] & (1 << 5)) != 0; -#endif } if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 80817ff87d736..56849995656f3 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -2005,7 +2005,9 @@ struct MLAS_SBGEMM_DATA_PARAMS { const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ - bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */ + bool ZeroMode = true; /**< when true: C = A*B + Bias (if Bias != nullptr); + when false: C += A*B and Bias is ignored */ + bool BIsPacked = false; /**< Whether B is pre-packed */ }; /** @@ -2015,40 +2017,84 @@ struct MLAS_SBGEMM_DATA_PARAMS { * Note: We only support uniform batching, so shapes and types of the * input must be same across all parameter blocks. * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] TransA Supplies the transpose operation for matrix A. + * @param[in] TransB Supplies the transpose operation for matrix B. + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] ThreadPool + * @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector + configuration options, else nullptr if the + default configuration should be used. * @return */ void MLASCALL -MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool); +MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SBGEMM_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +); /** * @brief For bfloat16 precision GEMM, returns size of the * packing buffer needed for right hand side - * @param[in] N Number of columns - * @param[in] K Number of rows - * @return size of the packing buffer, - * 0 if operation not supported + * @param[in] TransA Supplies the transpose operation for matrix A. + * @param[in] TransB Supplies the transpose operation for matrix B. + * @param[in] BIsfp32 Is matrix B datatype FP32 + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector + configuration options, else nullptr if the + default configuration should be used. + * @return size of the packing buffer, + * 0 if operation not supported */ size_t MLASCALL -MlasSBGemmPackBSize(size_t N, size_t K); +MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +); /** * @brief For bfloat16 precision GEMM, convert the float matrix B * to blfoat16 precision and pack it into a packing buffer * - * @param[in] N Number of columns - * @param[in] K Number of rows - * @param[in] B Address of matrix B - * @param[in] ldb leading dimension of input matrix B - * @param[out] PackedB Address of the packed matrix + * @param[in] TransA Supplies the transpose operation for matrix A. + * @param[in] TransB Supplies the transpose operation for matrix B. + * @param[in] BIsfp32 Is matrix B datatype FP32 + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + * @param[in] BackendKernelSelectorConfig Supplies the backend kernel selector + configuration options, else nullptr if the + default configuration should be used. */ void MLASCALL -MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +MlasSBGemmConvertPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +); #endif /** diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp index 6a25447c43e09..6ee80594c6b49 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp @@ -30,9 +30,11 @@ #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" // SME2 kernels -// GEMM/QGEMM +// GEMM/QGEMM/SBGEMM #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" + // GEMV #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h" @@ -227,6 +229,9 @@ const KaiF32IMatmulKernel imatmul_conv_sme = const KaiF32IMatmulKernel imatmul_conv_sme2 = KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa); +const KaiBF16SBgemmKernel sbgemm_gemm_sme2 = + KAI_WRAP_UKERNEL_RUN_MATMUL_11(matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa); + #if defined(ENABLE_QMX_KERNELS) const KaiF32IMatmulKernel imatmul_conv_qmx = KAI_WRAP_UKERNEL_RUN_IMATMUL_PACKED_7(imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_qmx_mopa); @@ -362,3 +367,8 @@ const KaiDynamicQGemmKernel& GetKleidiAIQGemmUKernel() { #endif // ENABLE_QMX_KERNELS } } + +const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel() { + // Currently only SME2 variant exists for bfloat16/SBGEMM kernel + return sbgemm_gemm_sme2; +} diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h index aa3be8b47cb88..155ecf1762b3b 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h @@ -12,6 +12,8 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h" + #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h" #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p_f32p_interface.h" @@ -39,6 +41,8 @@ using KaiDynamicQGemmKernel = KaiMatmulKernel; +using KaiBF16SBgemmKernel = KaiMatmulKernel; + // Returns the selected Qnbit GEMM ukernel based on runtime CPU capabilities. const KaiQnbitGemmKernel& GetKleidiAIGemmUKernel(); @@ -56,3 +60,6 @@ const KaiF32SgemvKernel& GetKleidiAISGemvUKernel(); // Returns the selected FP32 IMATMUL ukernel used by the KleidiAI convolution implementation. const KaiF32IMatmulKernel& GetKleidiAIF32IMatmulUKernel(); + +// Returns the selected BF16 SBGEMM ukernel used by the KleidiAI based on runtime CPU capabilities. +const KaiBF16SBgemmKernel& GetKleidiAISBGemmUKernel(); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index f00efce7d5f88..a1aa241b89299 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -105,6 +105,42 @@ MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ); +#if defined(__aarch64__) && defined(__linux__) +size_t +MLASCALL +MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K + ); + +bool +MLASCALL +MlasSBGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB + ); + +bool +MLASCALL +MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SBGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); +#endif + size_t MLASCALL MlasDynamicQGemmPackBSize( diff --git a/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp new file mode 100644 index 0000000000000..f88af056aa156 --- /dev/null +++ b/onnxruntime/core/mlas/lib/kleidiai/sbgemm_kleidiai.cpp @@ -0,0 +1,449 @@ +// +// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#if defined(__aarch64__) && defined(__linux__) + +#include +#include +#include +#include + +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" +#include "kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.h" + +#include "mlas.h" + +#include "mlasi_kleidiai.h" +#include "kai_ukernel_interface.h" + +// Thread-local reusable buffers to reduce allocation overhead across tiles. +struct KaiTlsBuffersSbgemm { + std::vector output_tile; + std::vector bias_zero; + std::vector rhs_packed; + std::vector lhs_packed; +}; +static thread_local KaiTlsBuffersSbgemm g_kai_tls_sbgemm; + +const KaiBF16SBgemmKernel& sbgemm_gemm = GetKleidiAISBGemmUKernel(); + +/*++ +Routine Description: + Accumulate src into dst: dst[i,j] += src[i,j], respecting ldc. + +Arguments: + src - Pointer to the temporary A*B results (row-major, rows x cols). + rows - Number of rows in the tile. + cols - Number of columns in the tile. + dst - Pointer to the destination tile in C (row-major with leading dimension ldc). + ldc - Leading dimension of C (in elements). + +Notes: + Implements the accumulation path for SBGEMM when ZeroMode == false. +--*/ +static inline void AccumulateTile(const float* src, + size_t rows, + size_t cols, + float* dst, + size_t ldc) { + if (ldc == cols) { + // contiguous block in memory: add elementwise across whole block + size_t elems = rows * cols; + for (size_t i = 0; i < elems; ++i) { + dst[i] += src[i]; + } + } else { + // general case with row stride and a column offset + for (size_t i = 0; i < rows; ++i) { + const float* src_row = src + i * cols; + float* dst_row = dst + i * ldc; + for (size_t j = 0; j < cols; ++j) { + dst_row[j] += src_row[j]; + } + } + } +} + +/*++ +Routine Description: + Apply bias to a 2-D tile (rows x cols). + +Arguments: + src - Pointer to the temporary A*B results (row-major, rows x cols). + rows - Number of rows in the tile. + cols - Number of columns in the tile. + bias - Pointer to the bias vector or nullptr if no bias. + dst - Pointer to the destination tile in C (row-major with leading dimension ldc). + ldc - Leading dimension of C (in elements). + start_col - Starting column index of the tile (NIdx * n_step). + +Notes: + Uses a row by row memcpy path when no bias. +--*/ +static inline void ApplyBias2D(const float* src, + size_t rows, + size_t cols, + const float* bias, + float* dst, + size_t ldc, + size_t start_col) { + for (size_t i = 0; i < rows; ++i) { + const float* src_row = src + i * cols; + float* dst_row = dst + i * ldc; + + if (bias != nullptr) { + for (size_t j = 0; j < cols; ++j) { + dst_row[j] = src_row[j] + bias[start_col + j]; + } + } else { + // No bias but can't memcpy whole so needs to be done row by row. + memcpy(dst_row, src_row, cols * sizeof(float)); + } + } +} + +size_t +MLASCALL +ArmKleidiAI::MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K +) +/*++ + +Routine Description: + + This routine computes the length in bytes for the packed matrix B buffer. + +Arguments: + + TransA - Supplies the transpose operation on A matrix. + + TransB - Supplies the transpose operation on B matrix. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + +Return Value: + + Returns the size in bytes for the packed matrix B buffer. + +--*/ +{ + if (TransA != CblasNoTrans || TransB != CblasNoTrans || N == 0 || K == 0) { + KLEIDIAI_DEBUG_LOG("MlasSBGemmPackBSize returning 0 size. N=" << N << " K=" << K); + return 0; + } + // + // Compute the number of bytes required to hold the packed buffer. + // + size_t bytes = 0; + bytes = kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(N, K); + + return bytes; +} + +bool +MLASCALL +ArmKleidiAI::MlasSBGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB +) +/*++ + +Routine Description: + + This routine packs the contents of matrix B to the destination buffer. The + destination buffer should be sized based on MlasSBGemmPackBSize(). For best + performance, the destination buffer should be aligned to the value returned + from MlasGetPreferredBufferAlignment(). + +Arguments: + + TransA - Supplies the transpose operation on A matrix. + + TransB - Supplies the transpose operation on B matrix. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + + B - Supplies the address of matrix B. + + ldb - Supplies the first dimension of matrix B. + + PackedB - Supplies the address of packed matrix B. + +Return Value: + + Returns true if the packing operation was handled by KleidiAI. + Returns false if the configuration requires a fallback to the default MLAS implementation. + +--*/ +{ + if (TransA != CblasNoTrans || TransB != CblasNoTrans || N == 0 || K == 0) { + KLEIDIAI_DEBUG_LOG("MlasSBGemmPackB one of N or K is 0, falling back to MLAS."); + return false; + } + + const size_t nr = sbgemm_gemm.ukernel.get_nr(); + const size_t kr = sbgemm_gemm.ukernel.get_kr(); + const size_t sr = sbgemm_gemm.ukernel.get_sr(); + + // Ensure size and zero the used span. + g_kai_tls_sbgemm.bias_zero.resize(N, 0.0f); + + kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls_sbgemm.bias_zero.data(), nullptr, PackedB, 0, nullptr); + + return true; +} + +bool +MLASCALL +ArmKleidiAI::MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SBGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool +) +/*++ + +Routine Description: + + This routine performs a bfloat16 batched matrix multiplication (SBGEMM) operation using KleidiAI kernels. + If packing is needed, it prepares the required buffers and invokes the + appropriate left-hand side (LHS) and right-hand side (RHS) pack functions. + +Arguments: + + TransA - Supplies the transpose operation on A matrix. + + TransB - Supplies the transpose operation on B matrix. + + M - Supplies the number of rows of matrix A and matrix C. + + N - Supplies the number of columns of matrix B and matrix C. + + K - Supplies the number of columns of matrix A and rows of matrix B. + + Data - Supplies a pointer to the MLAS_SBGEMM_DATA_PARAMS array containing per-batch input/output pointers and parameters. + + BatchSize - Supplies the number of independent GEMM computations to perform in the batch. + + ThreadPool - Supplies the thread pool to parallelize computation across batches and tiles. + +Return Value: + + Returns true if the GEMM operation was handled by KleidiAI. + Returns false if the configuration requires a fallback to the default MLAS implementation. + +--*/ +{ + if (TransA != CblasNoTrans || TransB != CblasNoTrans || K == 0) { + return false; + } + + if (M == 0 || N == 0 || BatchSize == 0) { + return true; + } + + size_t m_step = sbgemm_gemm.ukernel.get_m_step(); + size_t n_step = sbgemm_gemm.ukernel.get_n_step(); + + if ((M < m_step || N < n_step) && !Data->BIsPacked) { + // Fallback + return false; + } + + const size_t mr = sbgemm_gemm.ukernel.get_mr(); + const size_t kr = sbgemm_gemm.ukernel.get_kr(); + const size_t sr = sbgemm_gemm.ukernel.get_sr(); + + size_t LhsPackedStride = 0; + std::byte* LhsPackedData = nullptr; + + LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr); + + size_t lhs_resize = 0; + if (mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize)) + { + // size_t wraparound detected for LhsPackedStride, fallback to MLAS + return false; + } + + g_kai_tls_sbgemm.lhs_packed.resize(lhs_resize); + LhsPackedData = g_kai_tls_sbgemm.lhs_packed.data(); + + // RHS packed buffer: use TLS reusable vector to minimize allocations + size_t RhsPackedStride = 0; + std::byte* RhsPackedData = nullptr; + + // It is assumed all B batches require packing or not + if (Data[0].BIsPacked) { + // We have already decided the matmul variant we are using, before having values for M,N,K + MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_bf16p2vlx2_f32_sme" << " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr); + kai_run_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + }); + } else { + // Multithread pack lhs and rhs + RhsPackedStride = ArmKleidiAI::MlasSBGemmPackBSize(TransA, TransB, N, K); + size_t rhs_resize = 0; + if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize)) + { + // size_t wraparound detected for RhsPackedStride, fallback to MLAS + return false; + } + + g_kai_tls_sbgemm.rhs_packed.resize(rhs_resize); + RhsPackedData = g_kai_tls_sbgemm.rhs_packed.data(); + + MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) { + if (batch_idx & 0x1) { + batch_idx >>= 1; + std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]); + KLEIDIAI_KERNEL_LOG("kai_run_lhs_pack_bf16p2vlx2_f32_sme" + << " M=" << M << " K=" << K << " mr=" << mr << " kr=" << kr << " sr=" << sr); + kai_run_lhs_pack_bf16p2vlx2_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr); + } else { + batch_idx >>= 1; + std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]); + ArmKleidiAI::MlasSBGemmPackB(TransA, TransB, N, K, + reinterpret_cast(Data[batch_idx].B), + Data[batch_idx].ldb, RhsPackedPtr); + } + }); + } + + // tile iteration dimensions + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(M, m_step); // M + dim[2] = MlasDivRoundup(N, n_step); // N + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // scale required tiles over available tile processors + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // compute new step sizes + m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step); + + // update tile iterations + dim[1] = MlasDivRoundup(M, m_step); + dim[2] = MlasDivRoundup(N, n_step); + + // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop. + // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively. + size_t max_tile_elems = 0; + if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) { + // size_t wraparound detected for tile size, fallback to MLAS + return false; + } + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + // compute B,M,N index from iteration index + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = sbgemm_gemm.ukernel.get_rhs_packed_offset(NIdx * n_step, K); + + const std::byte* B_base = Data[0].BIsPacked + ? reinterpret_cast(Data[BIdx].B) + : (RhsPackedData + RhsPackedStride * BIdx); + auto BTile = reinterpret_cast(B_base + rhs_packed_offset); + + // Get lhs tile, A + const size_t lhs_packed_offset = sbgemm_gemm.ukernel.get_lhs_packed_offset(MIdx * m_step, K); + + const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx; + auto ATile = reinterpret_cast(A_base + lhs_packed_offset); + + auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step; + + // Get result tile, C + auto CTile = reinterpret_cast( + reinterpret_cast(Data[BIdx].C) + + MIdx * m_step * Data[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) + ); + + // Final output tile and bias pointers + float* dst_tile = reinterpret_cast(CTile); + const float* bias = Data[BIdx].Bias; + const size_t ldc = Data[BIdx].ldc; + + // Select output destination and strides once, then run_matmul exactly once. + const bool direct_to_c = ( + bias == nullptr && + Data[BIdx].ZeroMode && + TileSizeM != 0 && + TileSizeN != 0); + + float* out_tile = nullptr; + size_t out_row_stride_bytes = 0; + + if (direct_to_c) { + out_tile = dst_tile; + out_row_stride_bytes = ldc * sizeof(float); + } else { + // Compute into a temporary buffer for raw A*B result (TLS reusable buffer) + const size_t tile_elems = TileSizeM * TileSizeN; + g_kai_tls_sbgemm.output_tile.resize(tile_elems); + out_tile = g_kai_tls_sbgemm.output_tile.data(); + out_row_stride_bytes = TileSizeN * sizeof(float); + } + + KLEIDIAI_KERNEL_LOG(sbgemm_gemm.name + << " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K); + sbgemm_gemm.ukernel.run_matmul( + TileSizeM, + TileSizeN, + K, + ATile, BTile, out_tile, + out_row_stride_bytes, sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + + if (!direct_to_c) { + if (Data[BIdx].ZeroMode) { + ApplyBias2D(out_tile, TileSizeM, TileSizeN, bias, dst_tile, ldc, NIdx * n_step); + } else { + AccumulateTile(out_tile, TileSizeM, TileSizeN, dst_tile, ldc); + } + } + + if (Data[BIdx].OutputProcessor != nullptr) { + Data[BIdx].OutputProcessor->Process( + Data[BIdx].C, + MIdx * m_step, + NIdx * n_step, + TileSizeM, + TileSizeN, + Data[BIdx].ldc); + } + }); + return true; +} +#endif diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 3c0ee29896cd9..954849fe90049 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -906,6 +906,39 @@ void const float* Bias, void* PackedB); +#if defined(__aarch64__) && defined(__linux__) +typedef +bool +(MLASCALL MLAS_SBGEMM_BATCH_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SBGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef +size_t +(MLASCALL MLAS_SBGEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef +bool +(MLASCALL MLAS_SBGEMM_PACK_B_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); +#endif + extern "C" { #if defined(MLAS_TARGET_AMD64_IX86) @@ -1364,6 +1397,13 @@ struct MLAS_PLATFORM { // MLAS Conv overrides MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; +#if defined(__aarch64__) && defined(__linux__) + // SBGemm overrides + MLAS_SBGEMM_BATCH_OVERRIDE* MlasSBGemmBatchOverride = nullptr; + MLAS_SBGEMM_PACK_B_SIZE_OVERRIDE* MlasSBGemmPackBSizeOverride = nullptr; + MLAS_SBGEMM_PACK_B_OVERRIDE* MlasSBGemmPackBOverride = nullptr; +#endif + #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 12dcd61b8840e..ac3761d63bd20 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -445,7 +445,6 @@ Return Value: if ((Cpuid7_1[0] & 0x10) != 0) { - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; @@ -500,7 +499,6 @@ Return Value: if ((Cpuid7[2] & 0x800) != 0) { - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Vnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; @@ -540,7 +538,6 @@ Return Value: (Cpuid7[3] & 0b1 << 25) != 0 && (xcr0 & XFEATURE_MASK_XTILE) == XFEATURE_MASK_XTILE) { if (MlasInitAMX()) { - this->GemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; } } @@ -626,6 +623,14 @@ Return Value: this->MlasDynamicQGemmPackBOverride = ArmKleidiAI::MlasDynamicQGemmPackB; this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; this->MlasConvOverride = ArmKleidiAI::MlasConv; +#if defined(__aarch64__) && defined(__linux__) + // Currently only an SME2 variant of SBGEMM exists + if (ArmKleidiAI::UseSME2){ + this->MlasSBGemmBatchOverride = ArmKleidiAI::MlasSBGemmBatch; + this->MlasSBGemmPackBSizeOverride = ArmKleidiAI::MlasSBGemmPackBSize; + this->MlasSBGemmPackBOverride = ArmKleidiAI::MlasSBGemmPackB; + } +#endif } #endif diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp index 479a82e712c5e..bef8e1f800fd3 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_amx.cpp @@ -219,101 +219,6 @@ MlasGemmQuantThreadInit() } -static inline -void -InitHalfTileWithRowColSums( - int32_t* Tile, - const int32_t* rowsum_ptr, - const __m512i colsum, - const int32_t* c_ptr, - const size_t ldc, - bool ZeroMode - ) -{ - __m512i row0,row1,row2,row3,row4,row5,row6,row7; - row0 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[0])); - row1 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[1])); - row2 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[2])); - row3 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[3])); - row4 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[4])); - row5 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[5])); - row6 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[6])); - row7 = _mm512_add_epi32(colsum, _mm512_set1_epi32(rowsum_ptr[7])); - if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); - } - _mm512_storeu_si512(Tile, row0); - _mm512_storeu_si512(Tile+16, row1); - _mm512_storeu_si512(Tile+32, row2); - _mm512_storeu_si512(Tile+48, row3); - _mm512_storeu_si512(Tile+64, row4); - _mm512_storeu_si512(Tile+80, row5); - _mm512_storeu_si512(Tile+96, row6); - _mm512_storeu_si512(Tile+112, row7); - //Tile += 128; - //rowsum_ptr+=8; - //c_ptr += ldc * 8; -} - -static inline -void -InitHalfTileWithRowColSumsZeroPoints( - int32_t* Tile, - const int32_t* rowsum_ptr, - const __m512i colsum, - const __m512i zeropoint, - const int32_t* c_ptr, - const size_t ldc, - bool ZeroMode - ) -{ - __m512i row0,row1,row2,row3,row4,row5,row6,row7; - row0 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[0])); - row1 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[1])); - row2 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[2])); - row3 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[3])); - row4 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[4])); - row5 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[5])); - row6 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[6])); - row7 = _mm512_mullo_epi32(zeropoint, _mm512_set1_epi32(rowsum_ptr[7])); - row0 = _mm512_add_epi32(colsum, row0); - row1 = _mm512_add_epi32(colsum, row1); - row2 = _mm512_add_epi32(colsum, row2); - row3 = _mm512_add_epi32(colsum, row3); - row4 = _mm512_add_epi32(colsum, row4); - row5 = _mm512_add_epi32(colsum, row5); - row6 = _mm512_add_epi32(colsum, row6); - row7 = _mm512_add_epi32(colsum, row7); - if (!ZeroMode){ - row0 = _mm512_add_epi32(row0, _mm512_loadu_si512(c_ptr)); - row1 = _mm512_add_epi32(row1, _mm512_loadu_si512(c_ptr+ldc)); - row2 = _mm512_add_epi32(row2, _mm512_loadu_si512(c_ptr+ldc*2)); - row3 = _mm512_add_epi32(row3, _mm512_loadu_si512(c_ptr+ldc*3)); - row4 = _mm512_add_epi32(row4, _mm512_loadu_si512(c_ptr+ldc*4)); - row5 = _mm512_add_epi32(row5, _mm512_loadu_si512(c_ptr+ldc*5)); - row6 = _mm512_add_epi32(row6, _mm512_loadu_si512(c_ptr+ldc*6)); - row7 = _mm512_add_epi32(row7, _mm512_loadu_si512(c_ptr+ldc*7)); - } - _mm512_storeu_si512(Tile, row0); - _mm512_storeu_si512(Tile+16, row1); - _mm512_storeu_si512(Tile+32, row2); - _mm512_storeu_si512(Tile+48, row3); - _mm512_storeu_si512(Tile+64, row4); - _mm512_storeu_si512(Tile+80, row5); - _mm512_storeu_si512(Tile+96, row6); - _mm512_storeu_si512(Tile+112, row7); - //Tile += 128; - //rowsum_ptr+=8; - //c_ptr += ldc * 8; -} - static inline void @@ -636,178 +541,147 @@ MlasGemmQuantKernel( } - int32_t* c_blk = C; // C - beginning of the row + constexpr uint16_t FullMask = 0xFFFF; + int32_t* c_blk = C; int32_t* c16_blk = C + ldc * TILE_M; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* b_blk = B; // restart B + const MLAS_GEMM_U8S8_KERNEL_AMX::PackedBType* b_blk = B; const int32_t* col_sum_ptr = ColumnSumBuffer; const int32_t* zp_ptr = ZeroPointB; size_t n = CountN; for (; n >= 2 * TILE_N; n -= 2 * TILE_N) { - // Restart A from row start - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; - const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - - if (ZeroPointB != nullptr){ - __m512i colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; + __m512i colsum = _mm512_loadu_si512(col_sum_ptr); + col_sum_ptr += TILE_N; + if (ZeroPointB != nullptr) { __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; - tile_loadd(TMM0, b_blk, TILE_K); - InitHalfTileWithRowColSumsZeroPoints(Tile4, RowSumBuffer, colsum, zeropoint, c_blk, ldc, ZeroMode); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSumsZeroPoints(Tile4+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitHalfTileWithRowColSumsZeroPoints(Tile5, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk, ldc, ZeroMode); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSumsZeroPoints(Tile5+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - zeropoint = _mm512_loadu_si512(zp_ptr); + InitTileWithRowColSumsZeroPoints( + Tile4, TILE_M, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); + InitTileWithRowColSumsZeroPoints( + Tile5, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); + } else { + InitTileWithRowColSums(Tile4, TILE_M, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk, ldc); + InitTileWithRowColSums(Tile5, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); + } + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + + colsum = _mm512_loadu_si512(col_sum_ptr); + col_sum_ptr += TILE_N; + if (ZeroPointB != nullptr) { + __m512i zeropoint = _mm512_loadu_si512(zp_ptr); zp_ptr += TILE_N; - InitHalfTileWithRowColSumsZeroPoints(Tile6, RowSumBuffer, colsum, zeropoint, c_blk+TILE_N, ldc, ZeroMode); - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - InitHalfTileWithRowColSumsZeroPoints(Tile6+128, RowSumBuffer+8, colsum, zeropoint, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); - tile_dpbusd(TMM4, TMM2, TMM0); - InitHalfTileWithRowColSumsZeroPoints(Tile7, RowSumBuffer+TILE_M, colsum, zeropoint, c16_blk+TILE_N, ldc, ZeroMode); - InitHalfTileWithRowColSumsZeroPoints(Tile7+128, RowSumBuffer+TILE_M+8, colsum, zeropoint, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); + InitTileWithRowColSumsZeroPoints( + Tile6, TILE_M, FullMask, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); + InitTileWithRowColSumsZeroPoints( + Tile7, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); } else { - __m512i colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; + InitTileWithRowColSums(Tile6, TILE_M, FullMask, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); + InitTileWithRowColSums(Tile7, TILE_M, FullMask, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); + } + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + + const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; + const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; + for (size_t k = PackedCountK; k > 0; k -= TILE_K) { tile_loadd(TMM0, b_blk, TILE_K); - InitHalfTileWithRowColSums(Tile4, RowSumBuffer, colsum, c_blk, ldc, ZeroMode); tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSums(Tile4+128, RowSumBuffer+8, colsum, c_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); - InitHalfTileWithRowColSums(Tile5, RowSumBuffer+TILE_M, colsum, c16_blk, ldc, ZeroMode); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - InitHalfTileWithRowColSums(Tile5+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8, ldc, ZeroMode); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); - colsum = _mm512_loadu_si512(col_sum_ptr); - col_sum_ptr += TILE_N; - InitHalfTileWithRowColSums(Tile6, RowSumBuffer, colsum, c_blk+TILE_N, ldc, ZeroMode); tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - InitHalfTileWithRowColSums(Tile6+128, RowSumBuffer+8, colsum, c_blk+ldc*8+TILE_N, ldc, ZeroMode); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); + tile_dpbusd(TMM4, TMM2, TMM0); - InitHalfTileWithRowColSums(Tile7, RowSumBuffer+TILE_M, colsum, c16_blk+TILE_N, ldc, ZeroMode); - InitHalfTileWithRowColSums(Tile7+128, RowSumBuffer+TILE_M+8, colsum, c16_blk+ldc*8+TILE_N, ldc, ZeroMode); - } - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + tile_dpbusd(TMM5, TMM3, TMM0); + tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM7, TMM3, TMM1); - for (size_t k = PackedCountK - TILE_K; k > 0; k -= TILE_K) { b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; - tile_dpbusd(TMM5, TMM3, TMM0); - tile_loadd(TMM0, b_blk, TILE_K); - tile_dpbusd(TMM6, TMM2, TMM1); - tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); - tile_dpbusd(TMM7, TMM3, TMM1); - tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - tile_dpbusd(TMM4, TMM2, TMM0); } - tile_dpbusd(TMM5, TMM3, TMM0); - tile_dpbusd(TMM6, TMM2, TMM1); - tile_dpbusd(TMM7, TMM3, TMM1); - b_blk += PackedCountK * TILE_N + TILE_N * TILE_K; - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); tile_stored(TMM6, (void*)(c_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); c_blk += 2 * TILE_N; - tile_stored(TMM7, (void*)(c16_blk + TILE_N), static_cast(ldc * sizeof(int32_t))); c16_blk += 2 * TILE_N; + b_blk += PackedCountK * TILE_N; } if (n != 0) { const uint16_t nmask_high = static_cast(nmasks >> 16); __m512i colsum = _mm512_maskz_loadu_epi32(static_cast(nmasks), col_sum_ptr); col_sum_ptr += TILE_N; - if (ZeroPointB != nullptr){ + if (ZeroPointB != nullptr) { __m512i zeropoint = _mm512_maskz_loadu_epi32(static_cast(nmasks), zp_ptr); zp_ptr += TILE_N; InitTileWithRowColSumsZeroPoints( - Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk, ldc); InitTileWithRowColSumsZeroPoints( - Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk, ldc); } else { InitTileWithRowColSums( - Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, - ZeroMode, c_blk, ldc); - tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + Tile4, TILE_M, static_cast(nmasks), RowSumBuffer, colsum, ZeroMode, c_blk, ldc); InitTileWithRowColSums( - Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk, ldc); - tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + Tile5, TILE_M, static_cast(nmasks), RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk, ldc); } - if (nmask_high != 0){ + tile_loadd(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_loadd(TMM5, Tile5, TILE_N * sizeof(int32_t)); + + if (nmask_high != 0) { colsum = _mm512_maskz_loadu_epi32(nmask_high, col_sum_ptr); - if (ZeroPointB != nullptr){ + if (ZeroPointB != nullptr) { __m512i zeropoint = _mm512_maskz_loadu_epi32(nmask_high, zp_ptr); InitTileWithRowColSumsZeroPoints( - Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, - zeropoint, ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, zeropoint, ZeroMode, c_blk + TILE_N, ldc); InitTileWithRowColSumsZeroPoints( - Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, - zeropoint, ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, zeropoint, ZeroMode, c16_blk + TILE_N, ldc); } else { InitTileWithRowColSums( - Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, - ZeroMode, c_blk + TILE_N, ldc); - tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + Tile6, TILE_M, nmask_high, RowSumBuffer, colsum, ZeroMode, c_blk + TILE_N, ldc); InitTileWithRowColSums( - Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, - ZeroMode, c16_blk + TILE_N, ldc); - tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); + Tile7, TILE_M, nmask_high, RowSumBuffer + TILE_M, colsum, ZeroMode, c16_blk + TILE_N, ldc); } + tile_loadd(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_loadd(TMM7, Tile7, TILE_N * sizeof(int32_t)); } const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_blk = A; const MLAS_GEMM_U8S8_KERNEL_AMX::PackedAType* a_next_blk = A + PackedCountK * TILE_M; - for (size_t k = PackedCountK; k > 0; k -=TILE_K) { - tile_loadd(TMM0, b_blk, TILE_K); + for (size_t k = PackedCountK; k > 0; k -= TILE_K) { + tile_loadd(TMM0, b_blk, TILE_K); tile_loadd(TMM2, a_blk, static_cast(PackedCountK)); tile_loadd(TMM3, a_next_blk, static_cast(PackedCountK)); - tile_dpbusd(TMM4, TMM2, TMM0); + tile_dpbusd(TMM4, TMM2, TMM0); tile_dpbusd(TMM5, TMM3, TMM0); - if (nmask_high != 0){ + if (nmask_high != 0) { tile_loadd(TMM1, (void*)(b_blk + PackedCountK * TILE_N), TILE_K); - tile_dpbusd(TMM6, TMM2, TMM1); + tile_dpbusd(TMM6, TMM2, TMM1); tile_dpbusd(TMM7, TMM3, TMM1); - } + b_blk += TILE_N * TILE_K; a_blk += TILE_K; a_next_blk += TILE_K; } - if ((static_cast(nmasks) & 0x8000) != 0){ - tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); - tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); + if ((static_cast(nmasks) & 0x8000) != 0) { + tile_stored(TMM4, c_blk, static_cast(ldc * sizeof(int32_t))); + tile_stored(TMM5, c16_blk, static_cast(ldc * sizeof(int32_t))); } else { - tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); - MoveTile(Tile4, TILE_M, static_cast(nmasks), c_blk, ldc); MoveTile(Tile5, TILE_M, static_cast(nmasks), c16_blk, ldc); } - if (nmask_high != 0){ - tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); - tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + if (nmask_high != 0) { + tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); MoveTile(Tile6, TILE_M, nmask_high, c_blk + TILE_N, ldc); MoveTile(Tile7, TILE_M, nmask_high, c16_blk + TILE_N, ldc); } diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp index a6dbe8defd0e4..c3998dabd1d90 100644 --- a/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_avx2.cpp @@ -377,7 +377,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8U8DispatchAvx2Vnni = { MlasGemmQuantCopyPackB, MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedK, MLAS_GEMM_U8U8_KERNEL_AVX2VNNI::PackedStrides.K, - 6 // assembly kernel M stride + 4 // avoid the legacy >4-row fallback, which assumes non-VNNI packing }; // S8S8 AVX-VNNI-INT8 support @@ -450,7 +450,7 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchAvx2Vnni = { MlasGemmQuantCopyPackB, MLAS_GEMM_S8S8_KERNEL_AVX2::PackedK, MLAS_GEMM_S8S8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride + 4 // avoid the legacy >4-row fallback, which assumes non-VNNI packing }; // S8U8 AVX-VNNI-INT8 support @@ -523,5 +523,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8U8DispatchAvx2Vnni = { MlasGemmQuantCopyPackB, MLAS_GEMM_S8U8_KERNEL_AVX2::PackedK, MLAS_GEMM_S8U8_KERNEL_AVX2::PackedStrides.K, - 6 // assembly kernel M stride + 4 // avoid the legacy >4-row fallback, which assumes non-VNNI packing }; diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp index f41b380b2a071..f3a17f30be166 100644 --- a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -44,6 +44,9 @@ MlasConvPointwiseBf16KernelNeon( const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config; + mlas_backend_kernel_selector_config.use_kleidiai = ((KernelFlags & MLAS_CONV_KERNEL_MLAS_ARM_USE_KLEIDIAI) != 0); + const size_t StrideWidthElements = StrideWidth / sizeof(float); const size_t InputStrideElements = InputStride / sizeof(float); const size_t FilterStrideElements = FilterStride / sizeof(float); @@ -91,7 +94,7 @@ MlasConvPointwiseBf16KernelNeon( } } - MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr); + MlasSBGemmBatch(CblasNoTrans, CblasNoTrans, OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr, &mlas_backend_kernel_selector_config); if (ReluActivation) { const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index 5415cb3dc4406..559c8b48e78b6 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -299,11 +299,36 @@ MlasSBGemmGetDispatch() } size_t MLASCALL -MlasSBGemmPackBSize(size_t N, size_t K) +MlasSBGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { // // Compute the number of bytes required to hold the packed buffer. // +#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasSBGemmPackBSizeOverride != nullptr && + TransA == CBLAS_TRANSPOSE::CblasNoTrans && + TransB == CBLAS_TRANSPOSE::CblasNoTrans && + BIsfp32) { + size_t bytes_required; + bytes_required = GetMlasPlatform().MlasSBGemmPackBSizeOverride(TransA, TransB, N, K); + if (bytes_required != 0){ // If ArmKleidiAI::MlasSBGemmPackBSize ran to completion + return bytes_required; + } + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(BIsfp32); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + const auto* dispatch = MlasSBGemmGetDispatch(); if (dispatch == nullptr) return 0; @@ -322,8 +347,33 @@ MlasSBGemmPackBSize(size_t N, size_t K) } void MLASCALL -MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +MlasSBGemmConvertPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + bool BIsfp32, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { +#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasSBGemmPackBOverride != nullptr && + TransA == CBLAS_TRANSPOSE::CblasNoTrans && + TransB == CBLAS_TRANSPOSE::CblasNoTrans && + BIsfp32 && + GetMlasPlatform().MlasSBGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(BIsfp32); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + const auto* dispatch = MlasSBGemmGetDispatch(); if (dispatch == nullptr) return; @@ -331,8 +381,33 @@ MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* Pac } void MLASCALL -MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +MlasSBGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const MLAS_SBGEMM_DATA_PARAMS* Data, + MLAS_THREADPOOL* ThreadPool, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* BackendKernelSelectorConfig +) { +#if defined(USE_KLEIDIAI) && !defined(MLAS_USE_ARM_NEON_NCHWC) + if ((!BackendKernelSelectorConfig || BackendKernelSelectorConfig->use_kleidiai) && + GetMlasPlatform().MlasSBGemmBatchOverride != nullptr && + TransA == CBLAS_TRANSPOSE::CblasNoTrans && + TransB == CBLAS_TRANSPOSE::CblasNoTrans && + Data->AIsfp32 && + (Data->BIsPacked || Data->BIsfp32) && + GetMlasPlatform().MlasSBGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchN, ThreadPool)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + MLAS_UNREFERENCED_PARAMETER(BackendKernelSelectorConfig); + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); if (dispatch == nullptr) return; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index e2846ebda8b9d..8a7795a81027d 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -137,10 +137,12 @@ Status MatMul::Compute(OpKernelContext* ctx) const { #if defined(__aarch64__) && defined(__linux__) bool GemmPackBBfloat16(AllocatorPtr& alloc, const Tensor& tensor_b, + bool trans_a, bool trans_b, IAllocatorUniquePtr& packed_b, size_t& packed_b_size, - TensorShape& b_shape) { + TensorShape& b_shape, + const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config) { // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. if (tensor_b.Shape().NumDimensions() != 2) { @@ -152,7 +154,12 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); - packed_b_size = MlasSBGemmPackBSize(N, K); + packed_b_size = MlasSBGemmPackBSize(trans_a ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + trans_b ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + true, + N, + K, + mlas_backend_kernel_selector_config); if (packed_b_size == 0) { return false; } @@ -164,11 +171,15 @@ bool GemmPackBBfloat16(AllocatorPtr& alloc, // buffer memory and we don not want it uninitialized and generate different hashes // if and when we try to cache this pre-packed buffer for sharing between sessions. memset(packed_b_data, 0, packed_b_size); - MlasSBGemmConvertPackB(N, + MlasSBGemmConvertPackB(trans_a ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + trans_b ? CBLAS_TRANSPOSE::CblasTrans : CBLAS_TRANSPOSE::CblasNoTrans, + true, + N, K, tensor_b.Data(), trans_b ? K : N, - packed_b_data); + packed_b_data, + mlas_backend_kernel_selector_config); return true; } #endif @@ -191,8 +202,8 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc dim2 = static_cast(b_shape[1]); } - if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { - is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + if (use_fastmath_mode_ && (trans_a_attr_ == 0) && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_, &mlas_backend_kernel_selector_config_); } else #endif { @@ -260,7 +271,7 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); #if defined(__aarch64__) && defined(__linux__) - if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + if (use_fastmath_mode_ && !trans_a && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { data[i].BIsfp32 = !(bool(packed_b_)); @@ -273,8 +284,10 @@ Status MatMul::Compute(OpKernelContext* ctx) const { data[i].ldc = N; data[i].Bias = nullptr; data[i].OutputProcessor = nullptr; + data[i].BIsPacked = static_cast(packed_b_); } - MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); + MlasSBGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, max_len, data.data(), thread_pool, &mlas_backend_kernel_selector_config_); } else #endif { diff --git a/onnxruntime/core/providers/cuda/cuda_mempool_arena.h b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h index ec0e69bd33f3f..98f6abb0dd071 100644 --- a/onnxruntime/core/providers/cuda/cuda_mempool_arena.h +++ b/onnxruntime/core/providers/cuda/cuda_mempool_arena.h @@ -155,7 +155,6 @@ class CudaMempoolArena final : public IArena { // ---- Pool/context configuration (immutable) ---- uint64_t pool_release_threshold_; size_t bytes_to_keep_on_shrink_; - size_t initial_pool_size_bytes_; const logging::Logger* logger_; cudaMemPool_t pool_{nullptr}; diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index c3d8f2631feb4..31eb7b7fad43b 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1261,6 +1261,21 @@ TEST_F(PlannerTest, LocationPlanningForImplicitInputsWithoutExplicitConsumersInM // EXPECT_EQ(para_graph_plan->allocation_plan[input_data_index].location.device.Type(), OrtDevice::GPU); } +void ExpectExecutionStepTypeContains(const SessionState& state, + size_t stream_idx, + size_t step_idx, + const char* expected_type_name, + const char* message) { + const auto* execution_plan = state.GetExecutionPlan(); + ASSERT_NE(execution_plan, nullptr) << message; + ASSERT_LT(stream_idx, execution_plan->execution_plan.size()) << message; + const auto& steps = execution_plan->execution_plan[stream_idx]->steps_; + ASSERT_LT(step_idx, steps.size()) << message; + const auto* step = steps[step_idx].get(); + ASSERT_NE(step, nullptr) << message; + EXPECT_NE(strstr(typeid(*step).name(), expected_type_name), nullptr) << message; +} + // Test MultiStream scenario for the graph: // node1(CPU ep)->node2(CPU ep)->node3(CUDA ep)->node4(CPU ep) TEST_F(PlannerTest, MultiStream) { @@ -1288,18 +1303,18 @@ TEST_F(PlannerTest, MultiStream) { EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams for CPU and CUDA separately"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 6) << "CPU stream has 6 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "LaunchKernelStep"), nullptr) << "1st step: LaunchKernelStep for node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3, no Activate/Wait step between node 2 and node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[3]).name(), "BarrierStep"), nullptr) << "3rd step: BarrierStep for node 4"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[4]).name(), "WaitOnEPStep"), nullptr) << "4th step: WaitOnEPStep for node 4"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[5]).name(), "LaunchKernelStep"), nullptr) << "5th step: LaunchKernelStep for node 4"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "LaunchKernelStep", "1st step: LaunchKernelStep for node 2"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3, no Activate/Wait step between node 2 and node 3"); + ExpectExecutionStepTypeContains(GetState(), 0, 3, "BarrierStep", "3rd step: BarrierStep for node 4"); + ExpectExecutionStepTypeContains(GetState(), 0, 4, "WaitOnEPStep", "4th step: WaitOnEPStep for node 4"); + ExpectExecutionStepTypeContains(GetState(), 0, 5, "LaunchKernelStep", "5th step: LaunchKernelStep for node 4"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 4) << "CUDA stream has 4 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "LaunchKernelStep"), nullptr) << "1st step: LaunchKernelStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "ActivateNotificationStep"), nullptr) << "2nd step: ActivateNofiticationStep by node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[3]).name(), "TriggerDownstreamStep"), nullptr) << "3rd step: TriggerDownstreamStep for node 4"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "LaunchKernelStep", "1st step: LaunchKernelStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "ActivateNotificationStep", "2nd step: ActivateNofiticationStep by node 3"); + ExpectExecutionStepTypeContains(GetState(), 1, 3, "TriggerDownstreamStep", "3rd step: TriggerDownstreamStep for node 4"); } // Test execution plan for the graph: @@ -1328,21 +1343,21 @@ TEST_F(PlannerTest, MultiStream1StreamWaitFor2Streams) { EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 3) << "3 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 3) << "stream 0 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 3) << "stream 1 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 2"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 2"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[2]->steps_.size(), 5) << "stream 2 has 5 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 3, for TriggerDownstreamStep in stream 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[1]).name(), "BarrierStep"), nullptr) << "1st step: BarrierStep for node 3, for TriggerDownstreamStep in stream 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[2]).name(), "WaitOnEPStep"), nullptr) << "2nd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[3]).name(), "WaitOnEPStep"), nullptr) << "3rd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[2]->steps_[4]).name(), "LaunchKernelStep"), nullptr) << "4th step: LaunchKernelStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 2, 0, "BarrierStep", "0th step: BarrierStep for node 3, for TriggerDownstreamStep in stream 1"); + ExpectExecutionStepTypeContains(GetState(), 2, 1, "BarrierStep", "1st step: BarrierStep for node 3, for TriggerDownstreamStep in stream 2"); + ExpectExecutionStepTypeContains(GetState(), 2, 2, "WaitOnEPStep", "2nd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 1"); + ExpectExecutionStepTypeContains(GetState(), 2, 3, "WaitOnEPStep", "3rd step: WaitOnEPStep for node 3, for ActivateNotificationStep in stream 2"); + ExpectExecutionStepTypeContains(GetState(), 2, 4, "LaunchKernelStep", "4th step: LaunchKernelStep for node 3"); } // Test execution plan for the graph: @@ -1353,16 +1368,16 @@ TEST_F(PlannerTest, MultiStreamCudaEPNodeCPUOutput) { MemcpyToHostInCuda_TransposeInCudaAndCpu("./testdata/multi_stream_models/memcpyToHost_same_stream_with_transpose.json"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 5) << "stream 0 has 5 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[3]).name(), "WaitOnEPStep"), nullptr) << "3rd step: WaitOnEPStep for node 3 in the same stream, as node 1's output is to CPU"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[4]).name(), "LaunchKernelStep"), nullptr) << "4th step: LaunchKernelStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 0, 3, "WaitOnEPStep", "3rd step: WaitOnEPStep for node 3 in the same stream, as node 1's output is to CPU"); + ExpectExecutionStepTypeContains(GetState(), 0, 4, "LaunchKernelStep", "4th step: LaunchKernelStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 3) << "stream 1 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "WaitOnEPStep"), nullptr) << "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "LaunchKernelStep"), nullptr) << "2nd step: LaunchKernelStep for node 2"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "WaitOnEPStep", "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "LaunchKernelStep", "2nd step: LaunchKernelStep for node 2"); } // Test execution plan for the graph: @@ -1389,14 +1404,14 @@ TEST_F(PlannerTest, MultiStreamMultiOutput) { EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 3) << "stream 0 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 2"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 2"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 3) << "stream 1 has 3 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "WaitOnEPStep"), nullptr) << "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "LaunchKernelStep"), nullptr) << "2nd step: LaunchKernelStep for node 2"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "WaitOnEPStep", "1st step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "LaunchKernelStep", "2nd step: LaunchKernelStep for node 2"); } // Test execution plan for the graph: @@ -1427,19 +1442,19 @@ TEST_F(PlannerTest, MultiStream2NodesSameStreamConsumedBy1NodeInDifferentStream) EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan.size(), 2) << "2 logic streams"; EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[0]->steps_.size(), 6) << "stream 0 has 6 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[0]).name(), "LaunchKernelStep"), nullptr) << "0th step: LaunchKernelStep for node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[1]).name(), "ActivateNotificationStep"), nullptr) << "1st step: ActivateNofiticationStep by node 1"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[2]).name(), "TriggerDownstreamStep"), nullptr) << "2nd step: TriggerDownstreamStep for node 3"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[3]).name(), "LaunchKernelStep"), nullptr) << "3rd step: LaunchKernelStep for node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[4]).name(), "ActivateNotificationStep"), nullptr) << "4th step: ActivateNofiticationStep by node 2"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[0]->steps_[5]).name(), "TriggerDownstreamStep"), nullptr) << "5th step: TriggerDownstreamStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 0, 0, "LaunchKernelStep", "0th step: LaunchKernelStep for node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 1, "ActivateNotificationStep", "1st step: ActivateNofiticationStep by node 1"); + ExpectExecutionStepTypeContains(GetState(), 0, 2, "TriggerDownstreamStep", "2nd step: TriggerDownstreamStep for node 3"); + ExpectExecutionStepTypeContains(GetState(), 0, 3, "LaunchKernelStep", "3rd step: LaunchKernelStep for node 2"); + ExpectExecutionStepTypeContains(GetState(), 0, 4, "ActivateNotificationStep", "4th step: ActivateNofiticationStep by node 2"); + ExpectExecutionStepTypeContains(GetState(), 0, 5, "TriggerDownstreamStep", "5th step: TriggerDownstreamStep for node 3"); EXPECT_EQ(GetState().GetExecutionPlan()->execution_plan[1]->steps_.size(), 5) << "stream 1 has 5 steps"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[0]).name(), "BarrierStep"), nullptr) << "0th step: BarrierStep for node 1, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[1]).name(), "BarrierStep"), nullptr) << "1st step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[2]).name(), "WaitOnEPStep"), nullptr) << "2nd step: WaitOnEPStep for node 1, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[3]).name(), "WaitOnEPStep"), nullptr) << "3rd step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"; - EXPECT_NE(strstr(typeid(*GetState().GetExecutionPlan()->execution_plan[1]->steps_[4]).name(), "LaunchKernelStep"), nullptr) << "4th step: LaunchKernelStep for node 3"; + ExpectExecutionStepTypeContains(GetState(), 1, 0, "BarrierStep", "0th step: BarrierStep for node 1, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 1, "BarrierStep", "1st step: BarrierStep for node 2, for TriggerDownstreamStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 2, "WaitOnEPStep", "2nd step: WaitOnEPStep for node 1, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 3, "WaitOnEPStep", "3rd step: WaitOnEPStep for node 2, for ActivateNotificationStep in stream 0"); + ExpectExecutionStepTypeContains(GetState(), 1, 4, "LaunchKernelStep", "4th step: LaunchKernelStep for node 3"); } #endif @@ -2078,7 +2093,8 @@ TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { int gather_count = 0; ASSERT_GT(plan->execution_plan.size(), 1) << "Number of execution plans should be greater than 1"; for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { - if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { + const auto* step = plan->execution_plan[1]->steps_[i].get(); + if (strstr(typeid(*step).name(), "LaunchKernelStep")) { const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex()); if (node->OpType() == "Gather") gather_count++; diff --git a/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h index 40f688a16ecca..4170bcea6a1ea 100644 --- a/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h +++ b/onnxruntime/test/mlas/unittest/test_qgemm_fixture.h @@ -94,6 +94,11 @@ class QgemmShortExecuteTest : public Ml test_registered += RegisterSingleTest(1, 32, b, 5, 0, 0); } } + // Conv-like qgemm shapes that exposed AMD64 dispatch bugs in the VNNI/AMX paths. + test_registered += RegisterSingleTest(6, 30, 207, 1, 183, 223); + test_registered += RegisterSingleTest(6, 30, 207, 1, 17); + test_registered += RegisterSingleTest(169, 30, 207, 1, 183, 223); + test_registered += RegisterSingleTest(169, 30, 207, 1, 17); test_registered += RegisterSingleTest(43, 500, 401, 1, 183, 223); test_registered += RegisterSingleTest(1023, 1023, 1023, 1, 5, 8); test_registered += RegisterSingleTest(1023, 1023, 1023, 1, 7); @@ -169,6 +174,7 @@ class QgemmShortExecuteTest : public Mlas for (size_t b = 1; b < 96; b++) { test_registered += RegisterSingleTest(1, b, 32, 0, 0); } + test_registered += RegisterSingleTest(169, 30, 207, 183, 223); test_registered += RegisterSingleTest(43, 503, 401, 183, 223); test_registered += RegisterSingleTest(1024, 1024, 256, 13, 15); diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp index f85fe97776dc1..2a8e01b9dda3a 100644 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -92,17 +92,62 @@ class SBGemmShortExecuteTest : public MlasTestFixture +class SBGemmAccumulateExecuteTest : public MlasTestFixture> { + public: + explicit SBGemmAccumulateExecuteTest(size_t M, size_t N, size_t K, size_t Batch) + : M_(M), N_(N), K_(K), Batch_(Batch) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->TestAccumulate(M_, N_, K_, Batch_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch) { + std::stringstream ss; + ss << "Accumulate/Batch" << Batch << "/M" << M << "xN" << N << "xK" << K; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSBGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SBGemmAccumulateExecuteTest(M, N, K, Batch); + }); + + return 1; + } + + static size_t RegisterAccumulateTests() { + size_t test_registered = 0; + test_registered += RegisterSingleTest(1, 1, 1, 1); + test_registered += RegisterSingleTest(7, 9, 13, 1); + test_registered += RegisterSingleTest(32, 32, 32, 1); + if (!Packed) { + test_registered += RegisterSingleTest(5, 7, 3, 4); + } + return test_registered; + } + + private: + size_t M_, N_, K_, Batch_; +}; + static size_t SBGemmRegistLongExecute() { size_t count = 0; count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += MlasLongExecuteTests>::RegisterLongExecute(); } if (GetMlasThreadPool() != nullptr) { count += MlasLongExecuteTests>::RegisterLongExecute(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += MlasLongExecuteTests>::RegisterLongExecute(); } } @@ -114,14 +159,18 @@ static size_t SBGemmRegistShortExecute() { size_t count = 0; count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); } if (GetMlasThreadPool() != nullptr) { count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); - if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); + if (MlasSBGemmPackBSize(CblasNoTrans, CblasNoTrans, true, 128, 128, nullptr) > 0) { count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + count += SBGemmAccumulateExecuteTest::RegisterAccumulateTests(); } } diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h index 13701e2e3de46..af5b1a34be8f0 100644 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.h +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -60,14 +60,15 @@ class MlasSBGemmTest : public MlasTestBase { MatrixGuardBuffer BufferFloatC; MLAS_THREADPOOL* threadpool_; - void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { - size_t PackedBSize = MlasSBGemmPackBSize(N, K); + void* PackB(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, const BType* B, size_t ldb) { + const bool BIsfp32 = std::is_same::value; + size_t PackedBSize = MlasSBGemmPackBSize(TransA, TransB, BIsfp32, N, K, nullptr); if (PackedBSize == 0) { return nullptr; } void* PackedB = BufferBPacked.GetBuffer(PackedBSize); if (std::is_same::value) { - MlasSBGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); + MlasSBGemmConvertPackB(TransA, TransB, BIsfp32, N, K, (const float*)B, ldb, PackedB, nullptr); } else { } return PackedB; @@ -83,7 +84,10 @@ class MlasSBGemmTest : public MlasTestBase { size_t ldb, const float* Bias, float* C, - size_t ldc) { + size_t ldc, + bool ZeroMode = true) { + constexpr CBLAS_TRANSPOSE TransA = CblasNoTrans; + constexpr CBLAS_TRANSPOSE TransB = CblasNoTrans; std::vector GemmParameters(BatchSize); for (size_t i = 0; i < GemmParameters.size(); i++) { @@ -99,10 +103,13 @@ class MlasSBGemmTest : public MlasTestBase { params.ldc = ldc; params.AIsfp32 = true; params.BIsfp32 = true; + params.BIsPacked = false; + params.ZeroMode = ZeroMode; if (Packed) { ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; - params.B = PackB(N, K, B, ldb); + params.B = PackB(TransA, TransB, N, K, B, ldb); + params.BIsPacked = true; params.ldb = 0; params.BIsfp32 = false; } else { @@ -111,7 +118,7 @@ class MlasSBGemmTest : public MlasTestBase { } } - MlasSBGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + MlasSBGemmBatch(TransA, TransB, M, N, K, BatchSize, GemmParameters.data(), threadpool_, nullptr); } void ReferenceSgemm(size_t M, @@ -186,12 +193,14 @@ class MlasSBGemmTest : public MlasTestBase { ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); const float cosine_similarity_threshold = 0.98; - for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t batch = 0; batch < BatchSize; batch++) { for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++, f++) { + for (size_t n = 0; n < N; n++) { + // Compute flat index to avoid desync if we break + const size_t f = batch * M * N + m * N + n; if (!(CloseEnough(float(C[f]), CReference[f]))) { float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); - if (abs(cos_sim) < cosine_similarity_threshold) { + if (std::abs(cos_sim) < cosine_similarity_threshold) { ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; } else { break; @@ -208,6 +217,81 @@ class MlasSBGemmTest : public MlasTestBase { } } + void TestAccumulate(size_t M, size_t N, size_t K, size_t BatchSize) { + AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + AType Atail[16]; + std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); + + BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + BType Btail[16]; + std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); + + const float* Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + float BiasTail[16]; + std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)); + + float* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, true); + ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); + + const float cosine_similarity_threshold = 0.98; + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + // Compute flat index to avoid desync if we break + const size_t f = batch * M * N + m * N + n; + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (std::abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + float* CNoBias = BufferFloatC.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + ReferenceSgemm(M, N, K, BatchSize, A, B, nullptr, CNoBias); + for (size_t i = 0, size = N * M * BatchSize; i < size; i++) { + CReference[i] += CNoBias[i]; + } + + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N, false); + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + // Compute flat index to avoid desync if we break + const size_t f = batch * M * N + m * N + n; + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (std::abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; + ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; + ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)), 0) << "Bias buffer overwritten!"; + } + private: public: static const char* GetTestSuiteName() {