From 86c85dc0d9957123c7e8946e483d4882a579d747 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 21 Apr 2026 08:30:05 +0000 Subject: [PATCH 1/2] Adjust loads/store to wave size 32 --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 339 +++++++++++++++++- 1 file changed, 336 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 80f0504231..c547bfff1a 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -38,12 +38,23 @@ #include #include +#include "fbgemm_gpu/utils/warp_size.h" + /******************************************************************************/ +#if !defined(ROCM_WAVE32) && !defined(ROCM_WAVE64) +#error "split_embeddings_common.h requires either ROCM_WAVE32 or ROCM_WAVE64 to be defined. Include fbgemm_gpu/utils/warp_size.h before this header." +#endif + typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); typedef float floatx2_t __attribute__((ext_vector_type(2))); #define AMDGCN_BUFFER_RES_3 0x00027000 +#if defined(ROCM_WAVE32) +#define AMDGCN_WAVE_SIZE 32 +#define THREADS_PER_ROW 32 +#else #define AMDGCN_WAVE_SIZE 64 #define THREADS_PER_ROW 64 +#endif #define BLOCK_SIZE_ROCM 256 namespace fbgemm_gpu::rocm { @@ -184,8 +195,14 @@ struct load_row_per_warp { run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 64); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=2, load 2 halves per lane via half2 + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); +#else emb_data[0] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); +#endif } }; @@ -197,6 +214,11 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 128); *reinterpret_cast(emb_data) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); +#if defined(ROCM_WAVE32) + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 32) * sizeof(half2)); +#endif } }; @@ -206,10 +228,22 @@ struct load_row_per_warp { run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource( p_emb_table + row_index * 160, sizeof(half) * 160); - *reinterpret_cast(emb_data) = +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=5 — 2 half2 + 1 half per lane + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 32) * sizeof(half2)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( + emb_res, (lane_id + 128) * sizeof(half)); +#else + // wave64: dword_per_row=3 — 1 half2 + 1 half per lane + *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( emb_res, (lane_id + 128) * sizeof(half)); +#endif } }; @@ -219,10 +253,23 @@ struct load_row_per_warp { run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=6 — 3 half2 per lane + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 32) * sizeof(half2)); + *reinterpret_cast(&emb_data[4]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2)); +#else + // wave64: dword_per_row=3 — 1 half2 + 1 half per lane + *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( emb_res, (lane_id + 128) * sizeof(half)); +#endif } }; @@ -232,11 +279,27 @@ struct load_row_per_warp { run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 256); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=8 — 4 half2 per lane *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 32) * sizeof(half2)); + *reinterpret_cast(&emb_data[4]) = llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, (lane_id + 64) * sizeof(half2)); + *reinterpret_cast(&emb_data[6]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 96) * sizeof(half2)); +#else + // wave64: dword_per_row=4 — 2 half2 per lane + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2)); +#endif } }; @@ -246,6 +309,24 @@ struct load_row_per_warp { run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource( p_emb_table + row_index * 320, sizeof(half) * 320); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=10 — 5 half2 per lane + *reinterpret_cast(&emb_data[0]) = + llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); + *reinterpret_cast(&emb_data[2]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 32) * sizeof(half2)); + *reinterpret_cast(&emb_data[4]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 64) * sizeof(half2)); + *reinterpret_cast(&emb_data[6]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 96) * sizeof(half2)); + *reinterpret_cast(&emb_data[8]) = + llvm_amdgcn_raw_buffer_load_fp16x2( + emb_res, (lane_id + 128) * sizeof(half2)); +#else + // wave64: dword_per_row=5 — 2 half2 + 1 half per lane *reinterpret_cast(&emb_data[0]) = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); *reinterpret_cast(&emb_data[2]) = @@ -253,6 +334,7 @@ struct load_row_per_warp { emb_res, (lane_id + 64) * sizeof(half2)); emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( emb_res, (lane_id + 256) * sizeof(half)); +#endif } }; @@ -282,6 +364,11 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 64); emb_data[0] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=2 + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 32) * sizeof(float)); +#endif } }; @@ -296,8 +383,19 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 128); emb_data[0] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=4 + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 32) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 96) * sizeof(float)); +#else + // wave64: dword_per_row=2 emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 64) * sizeof(float)); +#endif } }; @@ -310,12 +408,27 @@ struct load_row_per_warp { int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource( p_emb_table + row_index * 160, sizeof(float) * 160); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=5 + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 32) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 96) * sizeof(float)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); +#else + // wave64: dword_per_row=3 emb_data[0] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 64) * sizeof(float)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 128) * sizeof(float)); +#endif } }; @@ -328,12 +441,29 @@ struct load_row_per_warp { int lane_id) { int32x4_t emb_res = amdgcn_make_buffer_resource(p_emb_table + row_index * 192); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=6 + emb_data[0] = + llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 32) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 96) * sizeof(float)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); + emb_data[5] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 160) * sizeof(float)); +#else + // wave64: dword_per_row=3 emb_data[0] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 64) * sizeof(float)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 128) * sizeof(float)); +#endif } }; @@ -348,12 +478,31 @@ struct load_row_per_warp { amdgcn_make_buffer_resource(p_emb_table + row_index * 256); emb_data[0] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=8 + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 32) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 96) * sizeof(float)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); + emb_data[5] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 160) * sizeof(float)); + emb_data[6] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 192) * sizeof(float)); + emb_data[7] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 224) * sizeof(float)); +#else + // wave64: dword_per_row=4 emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 64) * sizeof(float)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 128) * sizeof(float)); emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 192) * sizeof(float)); +#endif } }; @@ -368,6 +517,28 @@ struct load_row_per_warp { p_emb_table + row_index * 320, sizeof(float) * 320); emb_data[0] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=10 + emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 32) * sizeof(float)); + emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 64) * sizeof(float)); + emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 96) * sizeof(float)); + emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 128) * sizeof(float)); + emb_data[5] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 160) * sizeof(float)); + emb_data[6] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 192) * sizeof(float)); + emb_data[7] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 224) * sizeof(float)); + emb_data[8] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 256) * sizeof(float)); + emb_data[9] = llvm_amdgcn_raw_buffer_load_fp32( + emb_res, (lane_id + 288) * sizeof(float)); +#else + // wave64: dword_per_row=5 emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 64) * sizeof(float)); emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( @@ -376,6 +547,7 @@ struct load_row_per_warp { emb_res, (lane_id + 192) * sizeof(float)); emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( emb_res, (lane_id + 256) * sizeof(float)); +#endif } }; @@ -430,7 +602,13 @@ template <> struct store_row_per_warp { static __device__ void run(const half* acc, half* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=2 + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); +#else llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); +#endif } }; @@ -440,6 +618,13 @@ struct store_row_per_warp { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp16x2( *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=4 + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 2), + out_res, + (lane_id + 32) * sizeof(half2)); +#endif } }; @@ -450,8 +635,19 @@ struct store_row_per_warp { amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); llvm_amdgcn_raw_buffer_store_fp16x2( *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=5 — 2 half2 + 1 half + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 2), + out_res, + (lane_id + 32) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16( + acc[4], out_res, (lane_id + 128) * sizeof(half)); +#else + // wave64: dword_per_row=3 — 1 half2 + 1 half llvm_amdgcn_raw_buffer_store_fp16( acc[2], out_res, (lane_id + 128) * sizeof(half)); +#endif } }; @@ -461,8 +657,21 @@ struct store_row_per_warp { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp16x2( *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=6 — 3 half2 + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 2), + out_res, + (lane_id + 32) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 4), + out_res, + (lane_id + 64) * sizeof(half2)); +#else + // wave64: dword_per_row=3 — 1 half2 + 1 half llvm_amdgcn_raw_buffer_store_fp16( acc[2], out_res, (lane_id + 128) * sizeof(half)); +#endif } }; @@ -472,10 +681,27 @@ struct store_row_per_warp { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp16x2( *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=8 — 4 half2 llvm_amdgcn_raw_buffer_store_fp16x2( *reinterpret_cast(acc + 2), out_res, + (lane_id + 32) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 4), + out_res, (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 6), + out_res, + (lane_id + 96) * sizeof(half2)); +#else + // wave64: dword_per_row=4 — 2 half2 + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 2), + out_res, + (lane_id + 64) * sizeof(half2)); +#endif } }; @@ -486,12 +712,33 @@ struct store_row_per_warp { amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); llvm_amdgcn_raw_buffer_store_fp16x2( *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=10 — 5 half2 + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 2), + out_res, + (lane_id + 32) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 4), + out_res, + (lane_id + 64) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 6), + out_res, + (lane_id + 96) * sizeof(half2)); + llvm_amdgcn_raw_buffer_store_fp16x2( + *reinterpret_cast(acc + 8), + out_res, + (lane_id + 128) * sizeof(half2)); +#else + // wave64: dword_per_row=5 — 2 half2 + 1 half llvm_amdgcn_raw_buffer_store_fp16x2( *reinterpret_cast(acc + 2), out_res, (lane_id + 64) * sizeof(half2)); llvm_amdgcn_raw_buffer_store_fp16( acc[4], out_res, (lane_id + 256) * sizeof(half)); +#endif } }; @@ -511,6 +758,11 @@ struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=2 + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 32) * sizeof(float)); +#endif } }; @@ -519,8 +771,19 @@ struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=4 + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 32) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[3], out_res, (lane_id + 96) * sizeof(float)); +#else + // wave64: dword_per_row=2 llvm_amdgcn_raw_buffer_store_fp32( acc[1], out_res, (lane_id + 64) * sizeof(float)); +#endif } }; @@ -530,10 +793,23 @@ struct store_row_per_warp { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=5 + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 32) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[3], out_res, (lane_id + 96) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 128) * sizeof(float)); +#else + // wave64: dword_per_row=3 llvm_amdgcn_raw_buffer_store_fp32( acc[1], out_res, (lane_id + 64) * sizeof(float)); llvm_amdgcn_raw_buffer_store_fp32( acc[2], out_res, (lane_id + 128) * sizeof(float)); +#endif } }; @@ -542,10 +818,25 @@ struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=6 + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 32) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[3], out_res, (lane_id + 96) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[5], out_res, (lane_id + 160) * sizeof(float)); +#else + // wave64: dword_per_row=3 llvm_amdgcn_raw_buffer_store_fp32( acc[1], out_res, (lane_id + 64) * sizeof(float)); llvm_amdgcn_raw_buffer_store_fp32( acc[2], out_res, (lane_id + 128) * sizeof(float)); +#endif } }; @@ -554,12 +845,31 @@ struct store_row_per_warp { static __device__ void run(const float* acc, float* p_output, int lane_id) { int32x4_t out_res = amdgcn_make_buffer_resource(p_output); llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=8 + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 32) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[3], out_res, (lane_id + 96) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[5], out_res, (lane_id + 160) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[6], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[7], out_res, (lane_id + 224) * sizeof(float)); +#else + // wave64: dword_per_row=4 llvm_amdgcn_raw_buffer_store_fp32( acc[1], out_res, (lane_id + 64) * sizeof(float)); llvm_amdgcn_raw_buffer_store_fp32( acc[2], out_res, (lane_id + 128) * sizeof(float)); llvm_amdgcn_raw_buffer_store_fp32( acc[3], out_res, (lane_id + 192) * sizeof(float)); +#endif } }; @@ -569,6 +879,28 @@ struct store_row_per_warp { int32x4_t out_res = amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); +#if defined(ROCM_WAVE32) + // wave32: dword_per_row=10 + llvm_amdgcn_raw_buffer_store_fp32( + acc[1], out_res, (lane_id + 32) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[2], out_res, (lane_id + 64) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[3], out_res, (lane_id + 96) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[4], out_res, (lane_id + 128) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[5], out_res, (lane_id + 160) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[6], out_res, (lane_id + 192) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[7], out_res, (lane_id + 224) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[8], out_res, (lane_id + 256) * sizeof(float)); + llvm_amdgcn_raw_buffer_store_fp32( + acc[9], out_res, (lane_id + 288) * sizeof(float)); +#else + // wave64: dword_per_row=5 llvm_amdgcn_raw_buffer_store_fp32( acc[1], out_res, (lane_id + 64) * sizeof(float)); llvm_amdgcn_raw_buffer_store_fp32( @@ -577,6 +909,7 @@ struct store_row_per_warp { acc[3], out_res, (lane_id + 192) * sizeof(float)); llvm_amdgcn_raw_buffer_store_fp32( acc[4], out_res, (lane_id + 256) * sizeof(float)); +#endif } }; @@ -726,7 +1059,7 @@ __device__ __forceinline__ void generic_dpp_reduction(data_t& result) { // Use corresponding assebly instruction for dpp reduction in case // of trivial operation with an option to use custom operation -template +template __device__ __forceinline__ void dpp_reduction(data_t& result) { #if defined(__gfx942__) || defined(__gfx90a__) || defined(__gfx950__) if constexpr (std::is_same_v) { From d0eddf20c7d2106016e00701522e99948df6e2a5 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 21 Apr 2026 08:30:27 +0000 Subject: [PATCH 2/2] Add new arch in is_supported_cdna() --- fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h index cc6d9f4d9e..31d0b8b8c2 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/cdna_guard.h @@ -47,7 +47,7 @@ namespace fbgemm_gpu::rocm { [[nodiscard]] inline bool is_supported_cdna() { static const std::unordered_set supported_archs{ - "gfx942", "gfx90a", "gfx950"}; + "gfx942", "gfx90a", "gfx950", "gfx1250"}; int device_id = 0; HIP_CHECK(hipGetDevice(&device_id)); hipDeviceProp_t dev_props;