From 759e6e20e800bd11333be83845b64fe1fc766f25 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Mon, 5 Jan 2026 12:27:25 +0000 Subject: [PATCH 01/12] Add padding support with transpose Also move check before writing storing is_src_valid during reading --- ...ead_group_tensor_slice_transfer_global.hpp | 66 +++++++++++++++++-- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 4 -- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 8 +-- 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 701c786c86a..036d5c274fa 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -160,6 +160,8 @@ struct ThreadGroupTransferGlobal // check if src element is valid const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + oob_thread_scratch_. + template SetAsType(vgpr_data_idx_seq, is_src_valid); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -195,14 +197,13 @@ struct ThreadGroupTransferGlobal using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; - using vector_t = typename vector_type_maker::type::type; - dst_vector_type op_r_v; // Load data from memory in src_vector first + auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0; src_vector_container src_vector = src_vector_container{grid_buf.template Get( - src_coord_.GetOffset(), true)}; + index, true)}; // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { @@ -213,9 +214,9 @@ struct ThreadGroupTransferGlobal // store result in dvgpr_ (static array holding loaded data). // At this point data is already converted to DstData type and // the elementwise operation has been applied - dvgpr_.template SetAsType( + src_dvgpr_.template SetAsType( vgpr_data_idx_seq, - is_src_valid ? op_r_v.template AsType()[I0] : vector_t(0)); + op_r_v.template AsType()[I0]); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -248,6 +249,39 @@ struct ThreadGroupTransferGlobal container_reorder_given_new2old(src_access_lengths, src_dim_access_order); constexpr auto ordered_fwd_step = StepsPerIteration{}; + // OOB check + static_ford{}([&](auto ordered_src_access_idx) { + // calculate src data index and make sequence + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order); + }(); + + // make sequence to access vgpr data. Add zero as last element of src_data_idx_seq + constexpr auto vgpr_data_idx_seq = generate_sequence_v2( + [&](auto i) { + if constexpr(i.value < src_data_idx.Size()) + { + return Number{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + auto op_r = src_dvgpr_.template GetAsType(vgpr_data_idx_seq); + const bool is_src_valid = + oob_thread_scratch_.template GetAsType(vgpr_data_idx_seq); + auto op_r_v = is_src_valid ? op_r : dst_vector_t(0); + dst_dvgpr_.template SetAsType(vgpr_data_idx_seq, op_r_v); + }); + // make forward steps // forward step for each iteration just add 1 const auto dst_forward_steps = generate_tuple( @@ -352,7 +386,7 @@ struct ThreadGroupTransferGlobal dst_buf.template Set( dst_coord_.GetOffset(), true, - dvgpr_.template GetAsType(vgpr_data_idx_seq)); + dst_dvgpr_.template GetAsType(vgpr_data_idx_seq)); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -389,6 +423,14 @@ struct ThreadGroupTransferGlobal return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); } + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto access_lengths_as_tuple = container_push_back( + sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}); + + return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); + } + static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){}; using ThreadScratchData = StaticTensorTupleOfVectorBuffer; - ThreadScratchData dvgpr_; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; + using OOBThreadScratch = StaticTensorTupleOfVectorBuffer; + + ThreadScratchData src_dvgpr_; + ThreadScratchData dst_dvgpr_; + OOBThreadScratch oob_thread_scratch_; SrcCoord src_coord_; DstCoord dst_coord_; const ElementwiseOperation element_op_; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index e47bb37a899..caf468d6cbb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -132,10 +132,6 @@ struct ABTransferWaveTiles index_t, index_t) { - // Notes: padding is currently not supported with transpose - static_assert(!((PadMN || PadK) && ABDoTranspose), - "padding is currently not supported with transpose"); - const index_t MN_grid = !PadMN ? sizeMN : MNPad; const index_t K_grid = !PadK ? sizeK : KPad; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 5431c054fa7..262a3435b91 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -368,16 +368,12 @@ struct GridwiseGemm_wmma_cshuffle_v3_base #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && + ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; static constexpr bool IsBWaveTransferApplicable = !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && + BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; static constexpr bool IsWaveTileInterleavedFitting = From faf2a206b2fd608de39db913f692f2f512af86ad Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 6 Jan 2026 08:31:27 +0000 Subject: [PATCH 02/12] Add/modify instances to use wave transfer for gemm universal Condition is changed so now the vectorsize of vmem reading and lds writing must be equal to 8 in order to use the wave transfer --- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp | 1 - .../device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp | 3 ++- .../device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp | 5 +++-- .../device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 5 +++-- .../device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 2 +- .../device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 4 ++-- .../device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 4 ++-- 7 files changed, 13 insertions(+), 11 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 262a3435b91..2a81b3fb1ae 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -364,7 +364,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // Limitations of the current implementation: // - no multiAB - // - GemmSpecialization Default with transpose #ifdef __gfx12__ static constexpr bool IsAWaveTransferApplicable = !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp index d79fe9bfa3a..d7b654a3456 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -47,7 +47,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp index e284cbbb833..7d7966c47fc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -40,7 +40,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 6195d40f872..2f631994807 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -52,7 +52,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index e51bec3dfb0..b50e37cf0a8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -44,7 +44,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index 66ba1e38301..4651068d860 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -40,9 +40,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 8eccccf354a..4dcbaccaa48 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, From 27881942ac3bd29df6e9c4d0da78e222c78dd3fb Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 6 Jan 2026 08:33:30 +0000 Subject: [PATCH 03/12] Fix clang format --- ...ead_group_tensor_slice_transfer_global.hpp | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 036d5c274fa..1c322fe4a73 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -160,8 +160,7 @@ struct ThreadGroupTransferGlobal // check if src element is valid const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - oob_thread_scratch_. - template SetAsType(vgpr_data_idx_seq, is_src_valid); + oob_thread_scratch_.template SetAsType(vgpr_data_idx_seq, is_src_valid); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -201,9 +200,8 @@ struct ThreadGroupTransferGlobal // Load data from memory in src_vector first auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0; - src_vector_container src_vector = - src_vector_container{grid_buf.template Get( - index, true)}; + src_vector_container src_vector = src_vector_container{ + grid_buf.template Get(index, true)}; // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { @@ -214,9 +212,8 @@ struct ThreadGroupTransferGlobal // store result in dvgpr_ (static array holding loaded data). // At this point data is already converted to DstData type and // the elementwise operation has been applied - src_dvgpr_.template SetAsType( - vgpr_data_idx_seq, - op_r_v.template AsType()[I0]); + src_dvgpr_.template SetAsType(vgpr_data_idx_seq, + op_r_v.template AsType()[I0]); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -425,8 +422,8 @@ struct ThreadGroupTransferGlobal __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - constexpr auto access_lengths_as_tuple = container_push_back( - sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}); + constexpr auto access_lengths_as_tuple = + container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}); return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); } @@ -441,10 +438,10 @@ struct ThreadGroupTransferGlobal static constexpr auto src_oob_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; using OOBThreadScratch = StaticTensorTupleOfVectorBuffer; + bool, + 1, + decltype(src_oob_thread_scratch_desc_), + true>; ThreadScratchData src_dvgpr_; ThreadScratchData dst_dvgpr_; From 08e6df2dd04b0022e35ba3547e5db0b1a016aca3 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 6 Jan 2026 11:03:26 +0000 Subject: [PATCH 04/12] Modify example --- example/01_gemm/gemm_wmma_fp16_v3.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp index 5b10edd681a..3b3b0fec16f 100644 --- a/example/01_gemm/gemm_wmma_fp16_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -19,22 +19,22 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, - PassThrough, PassThrough, PassThrough, GemmDefault, + PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, + 1, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; From 8e909b35131a21bc80ba8cc4793a017f63d861e4 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 6 Jan 2026 15:56:18 +0000 Subject: [PATCH 05/12] Fix bwd data --- ...ce_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index bbf62d5fbec..17bd37edee9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -451,7 +451,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 AComputeType, BComputeType, false, - false>; + false, + false, + true>; #define GridwiseGemmCTransposeTemplateParameters \ ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ @@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \ CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \ - AComputeType, false, false + AComputeType, false, false, false, true using GridwiseGemmCTranspose = std::conditional_t Date: Wed, 7 Jan 2026 14:07:58 +0000 Subject: [PATCH 06/12] Add restriction for wave transfer with padding and transpose Add test case which shows this limitation --- ...tched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 36 +++++++++++++++++ ...e_batched_gemm_wmma_cshuffle_v3_common.hpp | 21 ++++++++++ ..._multiple_d_layernorm_wmma_cshuffle_v3.hpp | 36 +++++++++++++++++ .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 36 +++++++++++++++++ .../device_gemm_wmma_cshuffle_v3_common.hpp | 36 +++++++++++++++++ .../impl/device_gemm_wmma_cshuffle_v3r1.hpp | 36 +++++++++++++++++ ...e_grouped_gemm_wmma_splitk_cshuffle_v3.hpp | 39 ++++++++++++++++++- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 34 +++++++++++----- .../test_gemm_universal_ut_cases_fp16.inc | 8 ++-- 9 files changed, 268 insertions(+), 14 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index 126d107725d..fcf2c98d362 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -606,6 +606,42 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 return false; } + if constexpr(GridwiseGemm::AWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "ABlockTransferDstScalarPerVector_AK1! K: " + << arg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(GridwiseGemm::BWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "BBlockTransferDstScalarPerVector_BK1! K: " + << arg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp index 59a820861c3..f1646e22951 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -455,6 +455,27 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common return false; } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index f0216c3f711..77f76645b57 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -701,6 +701,42 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 return false; } + if constexpr(GridwiseGemmWelford::AWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.KRaw_ % GridwiseGemmWelford::ABlockTransferDstScalarPerVector_AK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "ABlockTransferDstScalarPerVector_AK1! K: " + << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(GridwiseGemmWelford::BWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.KRaw_ % GridwiseGemmWelford::BBlockTransferDstScalarPerVector_BK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "BBlockTransferDstScalarPerVector_BK1! K: " + << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + typename GridwiseGemmWelford::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 317c4073df9..664ccdabc45 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -456,6 +456,42 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera return false; } + if constexpr(GridwiseGemm::AWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.KRaw_ % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "ABlockTransferDstScalarPerVector_AK1! K: " + << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(GridwiseGemm::BWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.KRaw_ % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "BBlockTransferDstScalarPerVector_BK1! K: " + << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index e96ec58cba3..738e3fb23f8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -421,6 +421,42 @@ struct DeviceGemm_Wmma_CShuffleV3_Common } } + if constexpr(GridwiseGemm::AWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "ABlockTransferDstScalarPerVector_AK1! K: " + << arg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(GridwiseGemm::BWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "BBlockTransferDstScalarPerVector_BK1! K: " + << arg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + return GridwiseGemm::CheckValidity(arg); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index e09c69d052a..3dd71db33b3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -393,6 +393,42 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1::value)) + { + if(ck::is_gfx12_supported() && + !(arg.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "ABlockTransferDstScalarPerVector_AK1! K: " + << arg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(GridwiseGemm::BWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(arg.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "BBlockTransferDstScalarPerVector_BK1! K: " + << arg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + return GridwiseGemm::CheckValidity( *dynamic_cast(&arg)); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 39024d39e43..d70acc8cbef 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -704,7 +704,44 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK::value)) + { + if(ck::is_gfx12_supported() && + !(a.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "ABlockTransferDstScalarPerVector_AK1! K: " + << a.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(GridwiseGemm::BWaveTransferApplicable() && + !(is_same::value)) + { + if(ck::is_gfx12_supported() && + !(a.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of " + "BBlockTransferDstScalarPerVector_BK1! K: " + << a.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + bool group_arg_valid = GridwiseGemm::CheckValidity(a); if(not group_arg_valid) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 2a81b3fb1ae..fe0c06aeb04 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -296,6 +296,13 @@ template {}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -362,18 +369,27 @@ struct GridwiseGemm_wmma_cshuffle_v3_base WmmaSelector::selected_wmma .wave_size; + __host__ __device__ static constexpr bool AWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && + !IsBPreShuffled; + } + + __host__ __device__ static constexpr bool BWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + } + // Limitations of the current implementation: // - no multiAB #ifdef __gfx12__ - static constexpr bool IsAWaveTransferApplicable = - !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; - - static constexpr bool IsBWaveTransferApplicable = - !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable(); + + static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable(); static constexpr bool IsWaveTileInterleavedFitting = (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize); diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc index 25d95cda3df..01d7d5a5fdd 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -125,7 +125,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_NK, MidLargeM) TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -139,7 +139,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -153,7 +153,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -169,7 +169,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; From 6ca366f97ff1c941caddb1bd41ea7dc6f95b52a6 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 8 Jan 2026 09:19:21 +0000 Subject: [PATCH 07/12] Fix validity checks 8 bit types --- ...tched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 36 ++++-------- ..._multiple_d_layernorm_wmma_cshuffle_v3.hpp | 38 ++++--------- .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 38 ++++--------- .../device_gemm_wmma_cshuffle_v3_common.hpp | 36 ++++-------- .../impl/device_gemm_wmma_cshuffle_v3r1.hpp | 36 ++++-------- ...e_grouped_gemm_wmma_splitk_cshuffle_v3.hpp | 36 ++++-------- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 57 ++++++++++++++++--- 7 files changed, 114 insertions(+), 163 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index fcf2c98d362..ae247f4e31b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -606,40 +606,24 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 return false; } - if constexpr(GridwiseGemm::AWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) { - if(ck::is_gfx12_supported() && - !(arg.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "ABlockTransferDstScalarPerVector_AK1! K: " - << arg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } - if constexpr(GridwiseGemm::BWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) { - if(ck::is_gfx12_supported() && - !(arg.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "BBlockTransferDstScalarPerVector_BK1! K: " - << arg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } return GridwiseGemm::CheckValidity(arg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index 77f76645b57..81f505b594d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -701,40 +701,26 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 return false; } - if constexpr(GridwiseGemmWelford::AWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) { - if(ck::is_gfx12_supported() && - !(arg.KRaw_ % GridwiseGemmWelford::ABlockTransferDstScalarPerVector_AK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "ABlockTransferDstScalarPerVector_AK1! K: " - << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } - if constexpr(GridwiseGemmWelford::BWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) { - if(ck::is_gfx12_supported() && - !(arg.KRaw_ % GridwiseGemmWelford::BBlockTransferDstScalarPerVector_BK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "BBlockTransferDstScalarPerVector_BK1! K: " - << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } typename GridwiseGemmWelford::Argument gemm_arg{ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 664ccdabc45..28c9f2bddcc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -456,40 +456,26 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera return false; } - if constexpr(GridwiseGemm::AWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) { - if(ck::is_gfx12_supported() && - !(arg.KRaw_ % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "ABlockTransferDstScalarPerVector_AK1! K: " - << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } - if constexpr(GridwiseGemm::BWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) { - if(ck::is_gfx12_supported() && - !(arg.KRaw_ % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "BBlockTransferDstScalarPerVector_BK1! K: " - << arg.KRaw_ << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 738e3fb23f8..c09befa717d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -421,40 +421,24 @@ struct DeviceGemm_Wmma_CShuffleV3_Common } } - if constexpr(GridwiseGemm::AWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) { - if(ck::is_gfx12_supported() && - !(arg.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "ABlockTransferDstScalarPerVector_AK1! K: " - << arg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } - if constexpr(GridwiseGemm::BWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) { - if(ck::is_gfx12_supported() && - !(arg.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "BBlockTransferDstScalarPerVector_BK1! K: " - << arg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } return GridwiseGemm::CheckValidity(arg); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index 3dd71db33b3..377f7929795 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -393,40 +393,24 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) { - if(ck::is_gfx12_supported() && - !(arg.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "ABlockTransferDstScalarPerVector_AK1! K: " - << arg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } - if constexpr(GridwiseGemm::BWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) { - if(ck::is_gfx12_supported() && - !(arg.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "BBlockTransferDstScalarPerVector_BK1! K: " - << arg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } return GridwiseGemm::CheckValidity( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index d70acc8cbef..99a18e07fc1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -706,40 +706,24 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(a.M, a.K)) { - if(ck::is_gfx12_supported() && - !(a.K % GridwiseGemm::ABlockTransferDstScalarPerVector_AK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "ABlockTransferDstScalarPerVector_AK1! K: " - << a.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } - if constexpr(GridwiseGemm::BWaveTransferApplicable() && - !(is_same::value)) + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(a.N, a.K)) { - if(ck::is_gfx12_supported() && - !(a.K % GridwiseGemm::BBlockTransferDstScalarPerVector_BK1_ == 0)) + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of " - "BBlockTransferDstScalarPerVector_BK1! K: " - << a.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; } + return false; } bool group_arg_valid = GridwiseGemm::CheckValidity(a); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index fe0c06aeb04..7980d7e5ba4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -3,6 +3,7 @@ #pragma once +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include @@ -296,13 +297,6 @@ template {}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -993,6 +987,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return de_grid_desc_mblock_mperblock_nblock_nperblock; } + // Conditions for Wave Transfer with transpose: + // - 16 bit type: K % 8 == 0 (4 subtiles of 8x8) + // - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16) + __host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K) + { + if constexpr(AWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % ABlockTransferDstScalarPerVector_AK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + pass &= !(sizeof(ADataType_) == 1 && + !(M % (2 * ABlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + + __host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K) + { + if constexpr(BWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % BBlockTransferDstScalarPerVector_BK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + pass &= !(sizeof(BDataType_) == 1 && + !(N % (2 * BBlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ static constexpr bool CheckValidity(const Argument& karg, From 1a390585822f350c3500691f526075a17209242c Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 9 Jan 2026 15:59:37 +0000 Subject: [PATCH 08/12] Add validity check gemm_bias_add_reduce --- ..._gemm_bias_add_reduce_wmma_cshuffle_v3.hpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index e8e3b69cb5f..85ca16b2932 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, From bbc96980f00f8e02c1bcc40e007c6aaa33dacea1 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 13 Jan 2026 09:07:53 +0000 Subject: [PATCH 09/12] Add validity check grouped gemm tile loop --- ..._multiple_d_wmma_cshuffle_tile_loop_v3.hpp | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp index 5ae9eaf8aca..6b5776c4eb1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_wmma_cshuffle_tile_loop_v3.hpp @@ -503,6 +503,29 @@ struct DeviceGroupedGemmMultipleD_Wmma_CShuffle_TileLoop_V3 bool supported = true; for(index_t i = 0; i < arg.group_count_; ++i) { + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer( + arg.gemm_descs_[i].M_, arg.gemm_descs_[i].K_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer( + arg.gemm_descs_[i].N_, arg.gemm_descs_[i].K_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + std::array placeholder_p_ds_grid{}; std::array stride_Ds; std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); From ebb6670d8ad381fb3b35e351fbe8dd43633e6883 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 22 Jan 2026 09:34:40 +0000 Subject: [PATCH 10/12] Fix validity checks new flavours --- ...ontraction_multiple_d_wmma_cshuffle_v3.hpp | 20 +++++++++++++++++ ...e_batched_gemm_reduce_wmma_cshuffle_v3.hpp | 22 +++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp index 47ef2e339d9..b59357ffe93 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp @@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3 return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + // check vector access static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp index 227a8aedd94..593a9084981 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp @@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, From 4efadddfeb0f418ac62ba128363db6d69f0b30ee Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 22 Jan 2026 09:59:35 +0000 Subject: [PATCH 11/12] Minor fixes --- ..._grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 8 ++++---- .../gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index 17bd37edee9..dfdfd53725f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -450,10 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 BlkGemmPipelineVer, AComputeType, BComputeType, - false, - false, - false, - true>; + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer #define GridwiseGemmCTransposeTemplateParameters \ ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 7980d7e5ba4..bcf131003c2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -3,7 +3,6 @@ #pragma once -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) #include #include From baa0893f141f397cffe512099ba2b5d2c9897d5c Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Mon, 26 Jan 2026 08:45:18 +0000 Subject: [PATCH 12/12] Fix clang format --- .../device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp index f1646e22951..fb1ca3127e2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -455,7 +455,6 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common return false; } - if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))