
Implement fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline #4276

Closed
assistant-librarian[bot] wants to merge 15 commits into develop from import/develop/ROCm_composable_kernel/pr-3669

Conversation

@assistant-librarian
Contributor

Proposed changes

Implement fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline. To make this work, we unify some of the required type conversions in a templated class, DetermineWarpPrecType, and we add a template parameter NumAccessA to the block GEMM's get_awarp_dstr_encoding. The new parameter lets us generate, at compile time, tile distribution encodings that are compatible with load_tile_transpose_convert, a new function that loads a tile transposed and converts its element type.

This PR depends on another PR that has not been merged yet: ROCm/composable_kernel#3505
The changes from that PR are also included in this one.
For this reason, this PR is opened as a draft.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered


🔁 Imported from ROCm/composable_kernel#3669
🧑‍💻 Originally authored by @SamiAario-AMD

Contributor

Copilot AI left a comment

Pull request overview

Implements additional mixed-precision GEMM combinations (fp16×fp8, bf16×fp8, fp8×fp16, fp8×bf16) for the CompV3 pipeline by unifying warp-precision type selection and introducing a general load+convert path that supports transpose+convert.

Changes:

  • Added DetermineWarpPrecType to centralize warp GEMM input type selection for mixed precision.
  • Replaced load_int4_tile / load_interleaved_pk_type with the new load_and_convert_tile abstraction across multiple ops/pipelines.
  • Extended warp GEMM distribution encoding APIs to be parameterizable by number of accesses (compile-time encodings for transpose+convert).
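As a rough illustration of the first bullet, a type-selection trait of this kind could look like the sketch below. This is not the PR's code: WarpPrecSketch and the tag types are hypothetical stand-ins for DetermineWarpPrecType and the ck_tile element types, the pass-through primary template is an assumption, and the fp16 x fp8 rule is only a guess at what the mixed-precision rules might be. Only the two packed-fp4 rules follow the comments quoted later in this review.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in tags; the real ck_tile types are not used here.
struct half_tag {};
struct fp8_tag {};
struct pk_fp4_tag {};

// Primary template (assumption): without a special rule, operands keep their types.
template <typename APrecType, typename BPrecType>
struct WarpPrecSketch
{
    using a_prec_type = APrecType;
    using b_prec_type = BPrecType;
};

// Packed fp4 on the A side: both warp operands use the B type.
template <typename BPrecType>
struct WarpPrecSketch<pk_fp4_tag, BPrecType>
{
    using a_prec_type = BPrecType;
    using b_prec_type = BPrecType;
};

// Packed fp4 on the B side: both warp operands use the A type.
template <typename APrecType>
struct WarpPrecSketch<APrecType, pk_fp4_tag>
{
    using a_prec_type = APrecType;
    using b_prec_type = APrecType;
};

// Assumed mixed-precision rule: fp16 x fp8 runs the warp GEMM in fp16.
template <>
struct WarpPrecSketch<half_tag, fp8_tag>
{
    using a_prec_type = half_tag;
    using b_prec_type = half_tag;
};
```

Keeping every rule in one trait means a pipeline only has to query a_prec_type and b_prec_type, instead of repeating the same conditionals at each call site.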

Reviewed changes

Copilot reviewed 59 out of 59 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| projects/composablekernel/include/ck_tile/ops/common/determine_warp_prec_type.hpp | Adds unified rules for selecting warp GEMM operand precision types (enables fp16/bf16↔fp8). |
| projects/composablekernel/include/ck_tile/ops/common/load_and_convert_tile.hpp | Introduces generalized tile load + convert (and packed-type handling) used across ops. |
| projects/composablekernel/include/ck_tile/core/tensor/load_tile_transpose.hpp | Adds out-param transpose load and a new transpose+convert helper needed for mixed precision. |
| projects/composablekernel/include/ck_tile/ops/gemm/** | Updates GEMM pipelines/blocks/warp attributes to use new type rules and encoding APIs. |
| projects/composablekernel/include/ck_tile/ops/**/*.hpp | Switches common include from interleaved pk loader to load+convert + warp-prec determination. |
| projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_util.hpp | Adjusts random input initialization range in GEMM tests. |
| projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp | Adds new CompV3 kernel type tuples covering fp16/bf16↔fp8 combinations. |
| projects/composablekernel/CHANGELOG.md | Documents added mixed-precision support. |
Comments suppressed due to low confidence (1)

projects/composablekernel/test/ck_tile/gemm/test_gemm_pipeline_util.hpp:1

  • FillUniformDistributionIntegerValue<...> is being given fractional bounds (-0.5, 0.5). If the implementation uses an integer distribution, these bounds may truncate/round to 0, producing all-zero inputs and weakening test coverage (or changing expected numeric behavior). Prefer a float/uniform-real filler for floating types, or use integer bounds like {-1, 1, seed} if the intent is small-magnitude integer samples.
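The truncation concern can be reproduced with the standard library alone. The sketch below assumes a filler that narrows its bounds to int before sampling. fill_integer_values is a hypothetical helper, not the ck_tile implementation, so whether FillUniformDistributionIntegerValue actually behaves this way would still need to be checked.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical integer-valued filler: the float bounds are truncated to int
// before sampling (an assumption about the kind of implementation at issue).
inline std::vector<float> fill_integer_values(float lo, float hi, unsigned seed, std::size_t n)
{
    std::mt19937 gen(seed);
    std::uniform_int_distribution<int> dist(static_cast<int>(lo), static_cast<int>(hi));
    std::vector<float> out(n);
    for(auto& v : out)
        v = static_cast<float>(dist(gen));
    return out;
}

// True when every sample is exactly zero.
inline bool all_zero(const std::vector<float>& values)
{
    for(float v : values)
        if(v != 0.0f)
            return false;
    return true;
}
```

With bounds (-0.5, 0.5) both casts yield 0, so the distribution can only ever return 0. With bounds (-1, 1) the samples are drawn from {-1, 0, 1}.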


using a_prec_type = ck_tile::half_t;
using b_prec_type = ck_tile::half_t;
};
}; // namespace ck_tile

Copilot AI Feb 25, 2026

namespace ck_tile { is closed with };, which is invalid C++ syntax for closing a namespace block and will fail to compile. Replace it with a plain } (no semicolon).
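For context, as far as I can tell a semicolon after a namespace's closing brace is parsed as an empty declaration in C++11 and later, so it is usually a style issue (flagged by warnings such as Clang's -Wextra-semi) rather than a hard error. The minimal sketch below uses a hypothetical namespace name and is expected to compile under a C++11-or-later compiler. Dropping the semicolon is still the cleaner choice.

```cpp
#include <cassert>
#include <type_traits>

namespace sketch_ns {
struct WarpTypes
{
    using a_prec_type = float;
    using b_prec_type = float;
};
}; // the extra ';' is an empty declaration at the enclosing scope (C++11 and later)
```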

{
    if constexpr(is_packed_type_v<typename WarpWindow::Base::DataType>)
    {
        ConverterLoader<typename WarpTile::DataType, UnaryOpSize>::load_interleaved_pk_type(

Copilot AI Feb 25, 2026

The packed-type path ignores the LoadTranspose template argument because ConverterLoader<..., UnaryOpSize> is instantiated with the default LoadTranspose = false. This means load_and_convert_tile<UnaryOpSize, true>(...) will silently take the packed-type path without triggering the intended transpose-not-supported check. Pass LoadTranspose into the ConverterLoader instantiation (or add an explicit static_assert(!LoadTranspose) in load_and_convert_tile when the source is packed).

Suggested change
- ConverterLoader<typename WarpTile::DataType, UnaryOpSize>::load_interleaved_pk_type(
+ ConverterLoader<typename WarpTile::DataType, UnaryOpSize, LoadTranspose>::load_interleaved_pk_type(
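The mechanism behind this comment is ordinary default-template-argument behaviour and can be shown in isolation. In the sketch below, ConverterLoaderSketch and the two caller functions are hypothetical names. They model only how a defaulted bool parameter is silently dropped when the caller does not forward it, not the real loader.

```cpp
#include <cassert>

// Hypothetical stand-in for the loader: the third parameter defaults to false.
template <typename T, int UnaryOpSize, bool LoadTranspose = false>
struct ConverterLoaderSketch
{
    static constexpr bool sees_transpose() { return LoadTranspose; }
};

// Caller that does not forward LoadTranspose: the loader sees the default.
template <int UnaryOpSize, bool LoadTranspose>
constexpr bool load_dropping_flag()
{
    return ConverterLoaderSketch<float, UnaryOpSize>::sees_transpose();
}

// Caller that forwards LoadTranspose: the loader sees the requested value.
template <int UnaryOpSize, bool LoadTranspose>
constexpr bool load_forwarding_flag()
{
    return ConverterLoaderSketch<float, UnaryOpSize, LoadTranspose>::sees_transpose();
}

static_assert(!load_dropping_flag<4, true>(), "flag is lost without forwarding");
static_assert(load_forwarding_flag<4, true>(), "flag survives when forwarded");
```

Because the dropped flag produces no diagnostic, any static_assert inside the loader that depends on LoadTranspose can never fire on the non-forwarding path.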

Comment on lines +606 to +616
using InputDataVec = array<InputDataType, vecLoadSize>;
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
    auto input_vec =
        trans_tensor.get_thread_buffer().template get_as<InputDataVec>(number<iAccess>{});

    // Element-wise type conversion
    // This will be unrolled by the compiler for each element in the vector
    static_for<0, vecLoadSize, 1>{}([&](auto iElem) {
        auto output_elem = type_convert<OutputDataType>(input_vec[iElem]);
        out_tensor.get_thread_buffer()[number<iAccess * vecLoadSize + iElem>{}] = output_elem;
    });

Copilot AI Feb 25, 2026

The transpose+convert implementation writes linearly into out_tensor.get_thread_buffer() without using the output distribution's y→d mapping (unlike load_tile_transpose_with_offset, which maps each vector to the correct destination via descriptors and set_vectorized_elements). If the output distribution's thread-buffer layout differs from the input distribution (which is especially likely when allowing different vector lengths), this will populate the wrong elements/ordering. Consider mirroring the mapping logic from load_tile_transpose_with_offset: compute the per-access y-index, map through y_out_desc, and write each converted vector to the correct destination using the output view’s vectorized setter.

Suggested change
- using InputDataVec = array<InputDataType, vecLoadSize>;
- static_for<0, num_of_access, 1>{}([&](auto iAccess) {
-     auto input_vec =
-         trans_tensor.get_thread_buffer().template get_as<InputDataVec>(number<iAccess>{});
-     // Element-wise type conversion
-     // This will be unrolled by the compiler for each element in the vector
-     static_for<0, vecLoadSize, 1>{}([&](auto iElem) {
-         auto output_elem = type_convert<OutputDataType>(input_vec[iElem]);
-         out_tensor.get_thread_buffer()[number<iAccess * vecLoadSize + iElem>{}] = output_elem;
-     });
+ using InputDataVec = array<InputDataType, vecLoadSize>;
+ using OutputDataVec = array<OutputDataType, vecLoadSize>;
+ static_for<0, num_of_access, 1>{}([&](auto iAccess) {
+     auto input_vec =
+         trans_tensor.get_thread_buffer().template get_as<InputDataVec>(number<iAccess>{});
+     // Element-wise type conversion into an output vector, then delegate
+     // placement to the output tensor's vectorized setter so that the
+     // output distribution's y->d mapping (y_out_desc) is honored.
+     OutputDataVec output_vec;
+     static_for<0, vecLoadSize, 1>{}([&](auto iElem) {
+         output_vec(iElem) = type_convert<OutputDataType>(input_vec[iElem]);
+     });
+     // Use the same vectorized write path as the non-convert transpose
+     // implementation, which maps each logical access to the correct
+     // destination in the output thread buffer according to y_out_desc.
+     out_tensor.template set_vectorized_elements<OutputDataType, vecLoadSize>(
+         number<iAccess>{}, output_vec);
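To make the ordering concern concrete, here is a small host-side sketch that compares a linear write with a write routed through a destination mapping. The interleaved dest_index layout is purely hypothetical and was chosen only so the two orders differ. It does not represent any actual ck_tile distribution or y_out_desc.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kNumAccess = 2; // accesses per thread (illustrative)
constexpr std::size_t kVecSize   = 4; // elements per access (illustrative)
constexpr std::size_t kBufSize   = kNumAccess * kVecSize;

using Buffer = std::array<int, kBufSize>;

// Hypothetical output layout: elements of different accesses are interleaved.
constexpr std::size_t dest_index(std::size_t access, std::size_t elem)
{
    return elem * kNumAccess + access;
}

// Linear write: assumes the output layout equals the input access order.
inline Buffer write_linear(const Buffer& in)
{
    Buffer out{};
    for(std::size_t a = 0; a < kNumAccess; ++a)
        for(std::size_t e = 0; e < kVecSize; ++e)
            out[a * kVecSize + e] = in[a * kVecSize + e];
    return out;
}

// Mapped write: each element goes where the output layout says it belongs.
inline Buffer write_mapped(const Buffer& in)
{
    Buffer out{};
    for(std::size_t a = 0; a < kNumAccess; ++a)
        for(std::size_t e = 0; e < kVecSize; ++e)
            out[dest_index(a, e)] = in[a * kVecSize + e];
    return out;
}
```

The two functions agree only when dest_index happens to be the identity mapping, which is the situation the review comment says cannot be assumed when input and output distributions differ.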

using b_prec_type = APrecType;
};

// For B x pk_fp4_raw_t, use the B type.

Copilot AI Feb 25, 2026

The comments for the pk_fp4_raw_t specializations don’t match the template argument order (line 64 is pk_fp4_raw_t x B, not B x pk_fp4_raw_t). Updating these comments would reduce confusion when extending the rule set.

Suggested change
- // For B x pk_fp4_raw_t, use the B type.
+ // For pk_fp4_raw_t x B, use the B type.

Comment on lines +70 to +72
// For A x pk_fp4_raw_t, use the A type.
template <typename APrecType>
struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_raw_t>

Copilot AI Feb 25, 2026

The comments for the pk_fp4_raw_t specializations don’t match the template argument order (line 64 is pk_fp4_raw_t x B, not B x pk_fp4_raw_t). Updating these comments would reduce confusion when extending the rule set.

@aosewski
Contributor

Please fix conflicts.

@SamiAario-AMD force-pushed the import/develop/ROCm_composable_kernel/pr-3669 branch from 9f46a7b to 3588979 on March 4, 2026 13:05
@SamiAario-AMD force-pushed the import/develop/ROCm_composable_kernel/pr-3669 branch from 54fdd39 to f396416 on March 12, 2026 14:18
@SamiAario-AMD force-pushed the import/develop/ROCm_composable_kernel/pr-3669 branch from f396416 to 5a304ff on March 17, 2026 10:00
@illsilin
Contributor

Hi @SamiAario-AMD, is this still a work in progress, or can it be closed?
