From ab03f2dcc0e9d889ccce9192cbe60fe911d62a26 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Tue, 25 Nov 2025 11:09:42 +0000 Subject: [PATCH 1/8] split_embeddings_common: adds generalized load and store functions for warp-per-row kernel with half datatype --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index d7d6954d00..9d92324f68 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -160,6 +160,58 @@ struct load_row_per_warp { } }; +template + struct load_row_per_warp + { + static __device__ void run( + half *emb_data, + index_t row_index, + const half *p_emb_table, + int lane_id) + { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim, sizeof(half) * embedding_dim); + + int offset = 0; + int reg_idx = 0; + + int dim_remaining = embedding_dim; + + // vector load as many elements as possible + constexpr int num_vector_ops = embedding_dim / 128; + + #pragma unroll + for(int i = 0; i < num_vector_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half2); + + half2 val = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, voffset); + // Unpack into register array + emb_data[reg_idx] = val.x; + emb_data[reg_idx + 1] = val.y; + + offset += 128; + reg_idx += 2; + dim_remaining -= 128; + } + + // load remaining elements (scalar loads) + constexpr int tail_start = num_vector_ops * 128; + constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; + + #pragma unroll + for(int i = 0; i < num_scalar_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half); + + emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, voffset); + offset += 64; + reg_idx += 1; + dim_remaining -= 64; + } + } + }; + template struct load_row_per_warp { static __device__ void @@ -477,6 +529,53 @@ struct store_row_per_warp { } }; +template +struct store_row_per_warp { + static __device__ void run(const half* acc, half* p_output, int lane_id) { + int32x4_t out_res = + amdgcn_make_buffer_resource(p_output, sizeof(half) * embedding_dim); + + int offset = 0; + int reg_idx = 0; + + int dim_remaining = embedding_dim; + + // vector store as many elements as possible + constexpr int num_vector_ops = embedding_dim / 128; + + #pragma unroll + for(int i = 0; i < num_vector_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half2); + + // Pack two half values into half2 for vectorized store + half2 val; + val.x = acc[reg_idx]; + val.y = acc[reg_idx + 1]; + llvm_amdgcn_raw_buffer_store_fp16x2(val, out_res, voffset); + + offset += 128; + reg_idx += 2; + dim_remaining -= 128; + } + + // store remaining elements (scalar stores) + constexpr int tail_start = num_vector_ops * 128; + constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; + + #pragma unroll + for(int i = 0; i < num_scalar_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half); + + llvm_amdgcn_raw_buffer_store_fp16(acc[reg_idx], out_res, voffset); + offset += 64; + reg_idx += 1; + dim_remaining -= 64; + } + } +}; + template struct store_row_per_warp { static __device__ void From 8c71f5f772de76efd97f54d7abdc214147b19d57 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Tue, 25 Nov 2025 14:02:07 +0000 Subject: [PATCH 2/8] split_embeddings_common: generalizes load and store function for warp-per-row kernel to float and half datatypes --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 154 ++++++++++-------- 1 file changed, 88 insertions(+), 66 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 9d92324f68..0f38566c84 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -160,54 +160,65 @@ struct load_row_per_warp { } }; -template - struct load_row_per_warp +template + requires(std::is_same_v || std::is_same_v) + struct load_row_per_warp { static __device__ void run( - half *emb_data, + emb_t *emb_data, index_t row_index, - const half *p_emb_table, + const emb_t *p_emb_table, int lane_id) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim, sizeof(half) * embedding_dim); + amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim, sizeof(emb_t) * embedding_dim); int offset = 0; int reg_idx = 0; - int dim_remaining = embedding_dim; - - // vector load as many elements as possible - constexpr int num_vector_ops = embedding_dim / 128; - - #pragma unroll - for(int i = 0; i < num_vector_ops; i++) - { - int voffset = (offset + lane_id) * sizeof(half2); + if constexpr (std::is_same_v) { + // For half: vector load 128 elements per iteration (64 threads * 2 halfs) + constexpr int num_vector_ops = embedding_dim / 128; + + #pragma unroll + for(int i = 0; i < num_vector_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half2); + + half2 val = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, voffset); + // Unpack into register array + emb_data[reg_idx] = val.x; + emb_data[reg_idx + 1] = val.y; + + offset += 128; + reg_idx += 2; + } - half2 val = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, voffset); - // Unpack into register array - emb_data[reg_idx] = val.x; - emb_data[reg_idx + 1] = val.y; + // load remaining elements (scalar loads) + constexpr int tail_start = num_vector_ops * 128; + constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; - offset += 128; - reg_idx += 2; - dim_remaining -= 128; - } + #pragma unroll + for(int i = 0; i < num_scalar_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half); - // load remaining elements (scalar loads) - constexpr int tail_start = num_vector_ops * 128; - constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; - - #pragma unroll - for(int i = 0; i < num_scalar_ops; i++) - { - int voffset = (offset + lane_id) * sizeof(half); - - emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, voffset); - offset += 64; - reg_idx += 1; - dim_remaining -= 64; + emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, voffset); + offset += 64; + reg_idx += 1; + } + } else if constexpr (std::is_same_v) { + // For float: load 64 elements per iteration (64 threads * 1 float) + constexpr int num_ops = (embedding_dim + 63) / 64; + + #pragma unroll + for(int i = 0; i < num_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(float); + emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, voffset); + offset += 64; + reg_idx += 1; + } } } }; @@ -529,49 +540,60 @@ struct store_row_per_warp { } }; -template -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { +template + requires(std::is_same_v || std::is_same_v) +struct store_row_per_warp { + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, sizeof(half) * embedding_dim); + amdgcn_make_buffer_resource(p_output, sizeof(emb_t) * embedding_dim); int offset = 0; int reg_idx = 0; - int dim_remaining = embedding_dim; + if constexpr (std::is_same_v) { + // For half: vector store 128 elements per iteration (64 threads * 2 halfs) + constexpr int num_vector_ops = embedding_dim / 128; - // vector store as many elements as possible - constexpr int num_vector_ops = embedding_dim / 128; + #pragma unroll + for(int i = 0; i < num_vector_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half2); - #pragma unroll - for(int i = 0; i < num_vector_ops; i++) - { - int voffset = (offset + lane_id) * sizeof(half2); + // Pack two half values into half2 for vectorized store + half2 val; + val.x = acc[reg_idx]; + val.y = acc[reg_idx + 1]; + llvm_amdgcn_raw_buffer_store_fp16x2(val, out_res, voffset); - // Pack two half values into half2 for vectorized store - half2 val; - val.x = acc[reg_idx]; - val.y = acc[reg_idx + 1]; - llvm_amdgcn_raw_buffer_store_fp16x2(val, out_res, voffset); + offset += 128; + reg_idx += 2; + } - offset += 128; - reg_idx += 2; - dim_remaining -= 128; - } + // store remaining elements (scalar stores) + constexpr int tail_start = num_vector_ops * 128; + constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; - // store remaining elements (scalar stores) - constexpr int tail_start = num_vector_ops * 128; - constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; + #pragma unroll + for(int i = 0; i < num_scalar_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half); - #pragma unroll - for(int i = 0; i < num_scalar_ops; i++) - { - int voffset = (offset + lane_id) * sizeof(half); + llvm_amdgcn_raw_buffer_store_fp16(acc[reg_idx], out_res, voffset); + offset += 64; + reg_idx += 1; + } + } else if constexpr (std::is_same_v) { + // For float: store 64 elements per iteration (64 threads * 1 float) + constexpr int num_ops = (embedding_dim + 63) / 64; - llvm_amdgcn_raw_buffer_store_fp16(acc[reg_idx], out_res, voffset); - offset += 64; - reg_idx += 1; - dim_remaining -= 64; + #pragma unroll + for(int i = 0; i < num_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(float); + llvm_amdgcn_raw_buffer_store_fp32(acc[reg_idx], out_res, voffset); + offset += 64; + reg_idx += 1; + } } } }; From 54daa91429e4c65ee21857f2ad6a6159b3fc194f Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Wed, 26 Nov 2025 15:39:44 +0000 Subject: [PATCH 3/8] adds 96 as a dimension size for optimized HIP kernel in Jinja2 templates --- .../backward/embedding_backward_split_kernel_warp_template.cu | 2 +- .../training/backward/embedding_backward_split_template.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 607f2030d6..6107822e7b 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -765,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} + {%- for kEmbeddingDim in [64, 96, 128, 160, 192, 256, 320] %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 460f5c8fa0..7e481b3a7d 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1257,7 +1257,7 @@ Tensor {{ embedding_cuda_op }}( if (use_hip_kernel && !mixed_D && !cached && supported_weights_type && supported_grad_type && same_precision && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + {%- for kDimSize in [64, 96, 128, 160, 192, 256, 320] %} {%- for kWeightDecayMode in [0, 1, 2] %} if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) { From 9bd40b669cf1a952b6f98c55e1be7802ed97d994 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 28 Nov 2025 10:59:23 +0000 Subject: [PATCH 4/8] split_embeddings_common: pass embedding vector dimension at runtime into load and store functions. Removes instantiations for special dimension values --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 355 +----------------- 1 file changed, 18 insertions(+), 337 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 0f38566c84..c91d652c7b 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -145,7 +145,8 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) { + int lane_id, + int32_t runtime_dim) { // Types are not supported, but we need an instance of run method to avoid // run-time .so symbol failure. Currently, the kernel dispatch for // unsupported type is guarded on host side @@ -168,10 +169,11 @@ template emb_t *emb_data, index_t row_index, const emb_t *p_emb_table, - int lane_id) + int lane_id, + int32_t runtime_dim) { int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim, sizeof(emb_t) * embedding_dim); + amdgcn_make_buffer_resource(p_emb_table + row_index * runtime_dim, sizeof(emb_t) * runtime_dim); int offset = 0; int reg_idx = 0; @@ -179,7 +181,7 @@ template if constexpr (std::is_same_v) { // For half: vector load 128 elements per iteration (64 threads * 2 halfs) constexpr int num_vector_ops = embedding_dim / 128; - + #pragma unroll for(int i = 0; i < num_vector_ops; i++) { @@ -203,6 +205,8 @@ template { int voffset = (offset + lane_id) * sizeof(half); + // we don't care about loading past the end of the embedding row because the + // emb_res buffer resource ensures that only 0s will be loaded emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, voffset); offset += 64; reg_idx += 1; @@ -210,11 +214,13 @@ template } else if constexpr (std::is_same_v) { // For float: load 64 elements per iteration (64 threads * 1 float) constexpr int num_ops = (embedding_dim + 63) / 64; - + #pragma unroll for(int i = 0; i < num_ops; i++) { int voffset = (offset + lane_id) * sizeof(float); + + // as above, we don't care about loading past the end of the embedding row emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, voffset); offset += 64; reg_idx += 1; @@ -223,84 +229,6 @@ template } }; -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 64); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 128); - *reinterpret_cast(emb_data) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 160, sizeof(half) * 160); - *reinterpret_cast(emb_data) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half)); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half)); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 256); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2)); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 320, sizeof(half) * 320); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2)); - emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 256) * sizeof(half)); - } -}; - template struct load_row_per_warp { static __device__ void run( @@ -316,114 +244,6 @@ struct load_row_per_warp { } }; -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 64); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 128); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 160, sizeof(float) * 160); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 256); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 192) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 320, sizeof(float) * 320); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 192) * sizeof(float)); - emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 256) * sizeof(float)); - } -}; - template < typename emb_t, int32_t embedding_dim, @@ -456,7 +276,7 @@ struct accumulate_row_per_warp { template struct store_row_per_warp { - static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id, int runtime_dim) { // Types are not supported, but we need an instance of run method to avoid // run-time .so symbol failure. Currently, the kernel dispatch for // unsupported type is guarded on host function @@ -471,81 +291,12 @@ struct store_row_per_warp { } }; -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16( - acc[2], out_res, (lane_id + 128) * sizeof(half)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16( - acc[2], out_res, (lane_id + 128) * sizeof(half)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc + 2), - out_res, - (lane_id + 64) * sizeof(half2)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc + 2), - out_res, - (lane_id + 64) * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16( - acc[4], out_res, (lane_id + 256) * sizeof(half)); - } -}; - template requires(std::is_same_v || std::is_same_v) struct store_row_per_warp { - static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id, int runtime_dim) { int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, sizeof(emb_t) * embedding_dim); + amdgcn_make_buffer_resource(p_output, sizeof(emb_t) * runtime_dim); int offset = 0; int reg_idx = 0; @@ -578,6 +329,8 @@ struct store_row_per_warp { { int voffset = (offset + lane_id) * sizeof(half); + // we don't care about storing past the end of the embedding row because the + // out_res buffer resource ensures that those writes will be ignored llvm_amdgcn_raw_buffer_store_fp16(acc[reg_idx], out_res, voffset); offset += 64; reg_idx += 1; @@ -590,6 +343,8 @@ struct store_row_per_warp { for(int i = 0; i < num_ops; i++) { int voffset = (offset + lane_id) * sizeof(float); + + // as above, we don't care about storing past the end of the embedding row llvm_amdgcn_raw_buffer_store_fp32(acc[reg_idx], out_res, voffset); offset += 64; reg_idx += 1; @@ -609,80 +364,6 @@ struct store_row_per_warp { } }; -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[3], out_res, (lane_id + 192) * sizeof(float)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[3], out_res, (lane_id + 192) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[4], out_res, (lane_id + 256) * sizeof(float)); - } -}; - // Helper function to pack fp16 and fp32 into int to further pass // into mov_dpp and readfirstlane() template From 79b52e88c3440afe09c9be925b8dd3a9bf7f0287 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 28 Nov 2025 13:01:07 +0000 Subject: [PATCH 5/8] refactors split_tbe_backward_hip_kernel to use runtime embedding dimension parameter --- ..._backward_split_device_kernel_template.hip | 40 ++++++++++--------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index c63f372a74..298abe6159 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -44,7 +44,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index, int32_t emb_dim) { if constexpr(segment_split == 0) { @@ -74,7 +74,7 @@ struct rowwise_adagrad_optimizer_t cache_t avg_square = wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / - embedding_dim; + emb_dim; cache_t momentum_new = momentum + avg_square; @@ -139,6 +139,8 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) { + // allow dword_per_row to be calculated at compile time based on [embedding_dim] because the + // ceil operation produces the same result as with the runtime param [emb_dim] constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; constexpr uint32_t length_mask = ~(segment_unroll - 1); @@ -220,13 +222,13 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); if constexpr (!weighted){ #pragma unroll for(int j = 2; j < segment_unroll; j += 2) @@ -238,7 +240,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); @@ -247,7 +249,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -274,7 +276,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); @@ -283,7 +285,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -307,13 +309,13 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); if constexpr (!weighted) { #pragma unroll @@ -326,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); @@ -335,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -353,7 +355,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); @@ -362,7 +364,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -385,7 +387,7 @@ L_tail_grad_acc: bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -403,7 +405,7 @@ L_tail_grad_acc: bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); @@ -414,10 +416,10 @@ L_tail_grad_acc: // load the old emb weight data load_row_per_warp::run( - &emb_data[0], emb_idx, p_emb_table, lane_id); + &emb_data[0], emb_idx, p_emb_table, lane_id, emb_dim); optimizer_t optimizer(opt_karg); - optimizer.template update(grad_acc, emb_data, emb_idx); + optimizer.template update(grad_acc, emb_data, emb_idx, emb_dim); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * emb_dim, lane_id, emb_dim); } } // namespace fbgemm_gpu::rocm From 8a40915182bdb71435fb7335edad7ac920495de1 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 28 Nov 2025 15:15:51 +0000 Subject: [PATCH 6/8] generalizes HIP global backward warp per row kernel and its invocation to arbitrary embeddins dimensions --- .../embedding_backward_split_kernel_warp_template.cu | 4 ++-- .../training/backward/embedding_backward_split_template.cu | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 6107822e7b..4506db95d1 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -631,7 +631,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- if weighted %} auto p_indice_weights_sorted = sorted_indice_weights.data(); {%- endif %} - auto emb_dim = embedding_dim; + auto emb_dim = max_D; constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; @@ -765,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 96, 128, 160, 192, 256, 320] %} + {%- for kEmbeddingDim in range(64, 2049, 64) %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 7e481b3a7d..1d5709ea14 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1257,9 +1257,9 @@ Tensor {{ embedding_cuda_op }}( if (use_hip_kernel && !mixed_D && !cached && supported_weights_type && supported_grad_type && same_precision && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 96, 128, 160, 192, 256, 320] %} + {%- for kBucketDim in range(64, 2049, 64) %} {%- for kWeightDecayMode in [0, 1, 2] %} - if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) + if (max_D <= {{ kBucketDim }} && max_D > {{ kBucketDim - 64 }} && weight_decay_mode == {{ kWeightDecayMode }}) { warp_per_row_grid_size = div_round_up(sorted_linear_indices_num_runs[0].item(), segments_per_workgroup); blockSize = dim3(256); @@ -1274,7 +1274,7 @@ Tensor {{ embedding_cuda_op }}( kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking, - {{ kDimSize }}, + {{ kBucketDim }}, {{ kWeightDecayMode }}>; } {%- endfor %} From f1be5d8bfae5b5f0144bd00461d4d3e6a4c90028 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 28 Nov 2025 15:18:36 +0000 Subject: [PATCH 7/8] split_embeddings_common: adds runtime embedding dimension param in wrapper load and store functions --- .../fbgemm_gpu/rocm/split_embeddings_common.h | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index c91d652c7b..f098486ae6 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -235,12 +235,14 @@ struct load_row_per_warp { c10::Half* emb_data, index_t row_index, const c10::Half* p_emb_table, - int lane_id) { + int lane_id, + int32_t runtime_dim) { load_row_per_warp::run( reinterpret_cast(emb_data), row_index, reinterpret_cast(p_emb_table), - lane_id); + lane_id, + runtime_dim); } }; @@ -355,12 +357,16 @@ struct store_row_per_warp { template struct store_row_per_warp { - static __device__ void - run(const c10::Half* emb_data, c10::Half* p_emb_table, int lane_id) { + static __device__ void run( + const c10::Half* emb_data, + c10::Half* p_emb_table, + int lane_id, + int32_t runtime_dim) { store_row_per_warp::run( reinterpret_cast(emb_data), reinterpret_cast(p_emb_table), - lane_id); + lane_id, + runtime_dim); } }; From d56ce53e11e5bed9abd83dde91ccd579c59c99a5 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 28 Nov 2025 18:14:33 +0000 Subject: [PATCH 8/8] unifies pointer argument style in load and store warp_per_row HIP kernels --- ..._backward_split_device_kernel_template.hip | 2 +- .../fbgemm_gpu/rocm/split_embeddings_common.h | 25 ++++++++++++------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index 298abe6159..11b4bb42f0 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -420,6 +420,6 @@ L_tail_grad_acc: optimizer_t optimizer(opt_karg); optimizer.template update(grad_acc, emb_data, emb_idx, emb_dim); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * emb_dim, lane_id, emb_dim); + store_row_per_warp::run(&emb_data[0], emb_idx, p_emb_table, lane_id, emb_dim); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index f098486ae6..95a0e43e34 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -276,9 +276,14 @@ struct accumulate_row_per_warp { } }; -template +template struct store_row_per_warp { - static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id, int runtime_dim) { + static __device__ void run( + const emb_t* acc, + index_t row_index, + emb_t* p_output_table, + int lane_id, + int runtime_dim) { // Types are not supported, but we need an instance of run method to avoid // run-time .so symbol failure. Currently, the kernel dispatch for // unsupported type is guarded on host function @@ -293,12 +298,12 @@ struct store_row_per_warp { } }; -template +template requires(std::is_same_v || std::is_same_v) -struct store_row_per_warp { - static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id, int runtime_dim) { +struct store_row_per_warp { + static __device__ void run(const emb_t* acc, index_t row_index, emb_t* p_output_table, int lane_id, int32_t runtime_dim) { int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, sizeof(emb_t) * runtime_dim); + amdgcn_make_buffer_resource(p_output_table + row_index * runtime_dim, sizeof(emb_t) * runtime_dim); int offset = 0; int reg_idx = 0; @@ -355,15 +360,17 @@ struct store_row_per_warp { } }; -template -struct store_row_per_warp { +template +struct store_row_per_warp { static __device__ void run( const c10::Half* emb_data, + index_t row_index, c10::Half* p_emb_table, int lane_id, int32_t runtime_dim) { - store_row_per_warp::run( + store_row_per_warp::run( reinterpret_cast(emb_data), + row_index, reinterpret_cast(p_emb_table), lane_id, runtime_dim);