Implement fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline#4276
Implement fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline#4276assistant-librarian[bot] wants to merge 15 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
Implements additional mixed-precision GEMM combinations (fp16×fp8, bf16×fp8, fp8×fp16, fp8×bf16) for the CompV3 pipeline by unifying warp-precision type selection and introducing a general load+convert path that supports transpose+convert.
Changes:
- Added
DetermineWarpPrecTypeto centralize warp GEMM input type selection for mixed precision. - Replaced
load_int4_tile/load_interleaved_pk_typewith the newload_and_convert_tileabstraction across multiple ops/pipelines. - Extended warp GEMM distribution encoding APIs to be parameterizable by number of accesses (compile-time encodings for transpose+convert).
Reviewed changes
Copilot reviewed 59 out of 59 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp | Adds unified rules for selecting warp GEMM operand precision types (enables fp16/bf16↔fp8). |
| projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp | Introduces generalized tile load + convert (and packed-type handling) used across ops. |
| projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp | Adds out-param transpose load and a new transpose+convert helper needed for mixed precision. |
| projects/composablekernel/include/ck_tile/ops/gemm/** | Updates GEMM pipelines/blocks/warp attributes to use new type rules and encoding APIs. |
| projects/composablekernel/include/ck_tile/ops/**/*.hpp | Switches common include from interleaved pk loader to load+convert + warp-prec determination. |
| projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_util.hpp | Adjusts random input initialization range in GEMM tests. |
| projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp | Adds new CompV3 kernel type tuples covering fp16/bf16↔fp8 combinations. |
| projects/composablekernel/CHANGELOG.md | Documents added mixed-precision support. |
Comments suppressed due to low confidence (1)
projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_util.hpp:1
FillUniformDistributionIntegerValue<...>is being given fractional bounds (-0.5, 0.5). If the implementation uses an integer distribution, these bounds may truncate/round to 0, producing all-zero inputs and weakening test coverage (or changing expected numeric behavior). Prefer a float/uniform-real filler for floating types, or use integer bounds like{-1, 1, seed}if the intent is small-magnitude integer samples.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| using a_prec_type = ck_tile::half_t; | ||
| using b_prec_type = ck_tile::half_t; | ||
| }; | ||
| }; // namespace ck_tile |
There was a problem hiding this comment.
namespace ck_tile { is closed with };, which is invalid C++ syntax for closing a namespace block and will fail to compile. Replace it with a plain } (no semicolon).
| { | ||
| if constexpr(is_packed_type_v<typename WarpWindow::Base::DataType>) | ||
| { | ||
| ConverterLoader<typename WarpTile::DataType, UnaryOpSize>::load_interleaved_pk_type( |
There was a problem hiding this comment.
The packed-type path ignores the LoadTranspose template argument because ConverterLoader<..., UnaryOpSize> is instantiated with the default LoadTranspose = false. This means load_and_convert_tile<UnaryOpSize, true>(...) will silently take the packed-type path without triggering the intended transpose-not-supported check. Pass LoadTranspose into the ConverterLoader instantiation (or add an explicit static_assert(!LoadTranspose) in load_and_convert_tile when the source is packed).
| ConverterLoader<typename WarpTile::DataType, UnaryOpSize>::load_interleaved_pk_type( | |
| ConverterLoader<typename WarpTile::DataType, UnaryOpSize, LoadTranspose>::load_interleaved_pk_type( |
| using InputDataVec = array<InputDataType, vecLoadSize>; | ||
| static_for<0, num_of_access, 1>{}([&](auto iAccess) { | ||
| auto input_vec = | ||
| trans_tensor.get_thread_buffer().template get_as<InputDataVec>(number<iAccess>{}); | ||
|
|
||
| // Element-wise type conversion | ||
| // This will be unrolled by the compiler for each element in the vector | ||
| static_for<0, vecLoadSize, 1>{}([&](auto iElem) { | ||
| auto output_elem = type_convert<OutputDataType>(input_vec[iElem]); | ||
| out_tensor.get_thread_buffer()[number<iAccess * vecLoadSize + iElem>{}] = output_elem; | ||
| }); |
There was a problem hiding this comment.
The transpose+convert implementation writes linearly into out_tensor.get_thread_buffer() without using the output distribution's y→d mapping (unlike load_tile_transpose_with_offset, which maps each vector to the correct destination via descriptors and set_vectorized_elements). If the output distribution's thread-buffer layout differs from the input distribution (which is especially likely when allowing different vector lengths), this will populate the wrong elements/ordering. Consider mirroring the mapping logic from load_tile_transpose_with_offset: compute the per-access y-index, map through y_out_desc, and write each converted vector to the correct destination using the output view’s vectorized setter.
| using InputDataVec = array<InputDataType, vecLoadSize>; | |
| static_for<0, num_of_access, 1>{}([&](auto iAccess) { | |
| auto input_vec = | |
| trans_tensor.get_thread_buffer().template get_as<InputDataVec>(number<iAccess>{}); | |
| // Element-wise type conversion | |
| // This will be unrolled by the compiler for each element in the vector | |
| static_for<0, vecLoadSize, 1>{}([&](auto iElem) { | |
| auto output_elem = type_convert<OutputDataType>(input_vec[iElem]); | |
| out_tensor.get_thread_buffer()[number<iAccess * vecLoadSize + iElem>{}] = output_elem; | |
| }); | |
| using InputDataVec = array<InputDataType, vecLoadSize>; | |
| using OutputDataVec = array<OutputDataType, vecLoadSize>; | |
| static_for<0, num_of_access, 1>{}([&](auto iAccess) { | |
| auto input_vec = | |
| trans_tensor.get_thread_buffer().template get_as<InputDataVec>(number<iAccess>{}); | |
| // Element-wise type conversion into an output vector, then delegate | |
| // placement to the output tensor's vectorized setter so that the | |
| // output distribution's y->d mapping (y_out_desc) is honored. | |
| OutputDataVec output_vec; | |
| static_for<0, vecLoadSize, 1>{}([&](auto iElem) { | |
| output_vec(iElem) = type_convert<OutputDataType>(input_vec[iElem]); | |
| }); | |
| // Use the same vectorized write path as the non-convert transpose | |
| // implementation, which maps each logical access to the correct | |
| // destination in the output thread buffer according to y_out_desc. | |
| out_tensor.template set_vectorized_elements<OutputDataType, vecLoadSize>( | |
| number<iAccess>{}, output_vec); |
| using b_prec_type = APrecType; | ||
| }; | ||
|
|
||
| // For B x pk_fp4_raw_t, use the B type. |
There was a problem hiding this comment.
The comments for the pk_fp4_raw_t specializations don’t match the template argument order (line 64 is pk_fp4_raw_t x B, not B x pk_fp4_raw_t). Updating these comments would reduce confusion when extending the rule set.
| // For B x pk_fp4_raw_t, use the B type. | |
| // For pk_fp4_raw_t x B, use the B type. |
| // For A x pk_fp4_raw_t, use the A type. | ||
| template <typename APrecType> | ||
| struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_raw_t> |
There was a problem hiding this comment.
The comments for the pk_fp4_raw_t specializations don’t match the template argument order (line 64 is pk_fp4_raw_t x B, not B x pk_fp4_raw_t). Updating these comments would reduce confusion when extending the rule set.
|
Please fix conflicts. |
9f46a7b to
3588979
Compare
54fdd39 to
f396416
Compare
… from kernel parameters
…:get_warp_dstr_encoding
…defines the A and B types - This is for improved clarity and finer control of the datatypes to use
f396416 to
5a304ff
Compare
|
Hi @SamiAario-AMD , is this still work in progress or can it be closed? |
Proposed changes
Implement fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline. To make this work, we unify some of the type conversions needed into a templated class
DetermineWarpPrecType, and we template-parameterizeget_awarp_dstr_encodingof the block GEMM withNumAccessA. The latter is so that we can generate tile distribution encodings during compile time that are compatible with a new functionload_tile_transpose_convertthat we introduce, which is used for transposed loading and converting of tiles.This PR depends on this other PR which has not been merged already: ROCm/composable_kernel#3505
The changes introduced by this other PR are also in this one.
For this reason, this PR is opened as a draft.
Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
🔁 Imported from ROCm/composable_kernel#3669
🧑💻 Originally authored by @SamiAario-AMD