diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 607f2030d6..4506db95d1 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -631,7 +631,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- if weighted %} auto p_indice_weights_sorted = sorted_indice_weights.data(); {%- endif %} - auto emb_dim = embedding_dim; + auto emb_dim = max_D; constexpr int32_t segment_prefetch = 2; constexpr int32_t segment_unroll = 8; constexpr int32_t segment_split = 0; @@ -765,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for index_type in ['int32_t', 'int64_t'] %} - {%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %} + {%- for kEmbeddingDim in range(64, 2049, 64) %} {%- for kWeighDecayMode in [0, 1, 2] %} {{ hip_template_instantiation( emb_type, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu index 460f5c8fa0..1d5709ea14 100755 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu @@ -1257,9 +1257,9 @@ Tensor {{ embedding_cuda_op }}( if (use_hip_kernel && !mixed_D && !cached && supported_weights_type && supported_grad_type && same_precision && rocm::is_supported_cdna()) { constexpr int segments_per_workgroup = 4; - {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} + {%- for kBucketDim in range(64, 2049, 64) %} {%- for kWeightDecayMode in [0, 1, 2] %} - if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) + if (max_D <= {{ kBucketDim }} && max_D > {{ kBucketDim - 64 }} && weight_decay_mode == {{ kWeightDecayMode }}) { warp_per_row_grid_size = div_round_up(sorted_linear_indices_num_runs[0].item(), segments_per_workgroup); blockSize = dim3(256); @@ -1274,7 +1274,7 @@ Tensor {{ embedding_cuda_op }}( kFixedMaxVecsPerThread, kThreadGroupSize, kUseVecBlocking, - {{ kDimSize }}, + {{ kBucketDim }}, {{ kWeightDecayMode }}>; } {%- endfor %} diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index c63f372a74..11b4bb42f0 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -44,7 +44,7 @@ struct rowwise_adagrad_optimizer_t } template - __device__ void update(cache_t* acc, emb_t* weight, index_t row_index) + __device__ void update(cache_t* acc, emb_t* weight, index_t row_index, int32_t emb_dim) { if constexpr(segment_split == 0) { @@ -74,7 +74,7 @@ struct rowwise_adagrad_optimizer_t cache_t avg_square = wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / - embedding_dim; + emb_dim; cache_t momentum_new = momentum + avg_square; @@ -139,6 +139,8 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( optimizer_karg_t opt_karg, const float * p_sorted_indice_weights = nullptr) { + // allow dword_per_row to be calculated at compile time based on [embedding_dim] because the + // ceil operation produces the same result as with the runtime param [emb_dim] constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW; constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE; constexpr uint32_t length_mask = ~(segment_unroll - 1); @@ -220,13 +222,13 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); if constexpr (!weighted){ #pragma unroll for(int j = 2; j < segment_unroll; j += 2) @@ -238,7 +240,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); @@ -247,7 +249,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -274,7 +276,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); @@ -283,7 +285,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -307,13 +309,13 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); table_index = infos[1] >> info_B_num_bits; bag_index = infos[1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); if constexpr (!weighted) { #pragma unroll @@ -326,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id); @@ -335,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -353,7 +355,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]); @@ -362,7 +364,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( bag_index = infos[j + 1] & info_B_mask; load_row_per_warp::run( - &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); } accumulate_row_per_warp::run( @@ -385,7 +387,7 @@ L_tail_grad_acc: bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id); @@ -403,7 +405,7 @@ L_tail_grad_acc: bag_index = infos[0] & info_B_mask; load_row_per_warp::run( - &grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id); + &grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim); accumulate_row_per_warp::run( &grad_acc[0], &grad_data[0], lane_id, indice_weights[0]); @@ -414,10 +416,10 @@ L_tail_grad_acc: // load the old emb weight data load_row_per_warp::run( - &emb_data[0], emb_idx, p_emb_table, lane_id); + &emb_data[0], emb_idx, p_emb_table, lane_id, emb_dim); optimizer_t optimizer(opt_karg); - optimizer.template update(grad_acc, emb_data, emb_idx); + optimizer.template update(grad_acc, emb_data, emb_idx, emb_dim); - store_row_per_warp::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id); + store_row_per_warp::run(&emb_data[0], emb_idx, p_emb_table, lane_id, emb_dim); } } // namespace fbgemm_gpu::rocm diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index d7d6954d00..95a0e43e34 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -145,7 +145,8 @@ struct load_row_per_warp { emb_t* emb_data, index_t row_index, const emb_t* p_emb_table, - int lane_id) { + int lane_id, + int32_t runtime_dim) { // Types are not supported, but we need an instance of run method to avoid // run-time .so symbol failure. Currently, the kernel dispatch for // unsupported type is guarded on host side @@ -160,83 +161,73 @@ struct load_row_per_warp { } }; -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 64); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp16(emb_res, lane_id * sizeof(half)); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 128); - *reinterpret_cast(emb_data) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - } -}; - -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 160, sizeof(half) * 160); - *reinterpret_cast(emb_data) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half)); - } -}; +template + requires(std::is_same_v || std::is_same_v) + struct load_row_per_warp + { + static __device__ void run( + emb_t *emb_data, + index_t row_index, + const emb_t *p_emb_table, + int lane_id, + int32_t runtime_dim) + { + int32x4_t emb_res = + amdgcn_make_buffer_resource(p_emb_table + row_index * runtime_dim, sizeof(emb_t) * runtime_dim); + + int offset = 0; + int reg_idx = 0; + + if constexpr (std::is_same_v) { + // For half: vector load 128 elements per iteration (64 threads * 2 halfs) + constexpr int num_vector_ops = embedding_dim / 128; + + #pragma unroll + for(int i = 0; i < num_vector_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half2); + + half2 val = llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, voffset); + // Unpack into register array + emb_data[reg_idx] = val.x; + emb_data[reg_idx + 1] = val.y; + + offset += 128; + reg_idx += 2; + } -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - *reinterpret_cast(emb_data) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half)); - } -}; + // load remaining elements (scalar loads) + constexpr int tail_start = num_vector_ops * 128; + constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 256); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2)); - } -}; + #pragma unroll + for(int i = 0; i < num_scalar_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half); -template -struct load_row_per_warp { - static __device__ void - run(half* emb_data, index_t row_index, const half* p_emb_table, int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 320, sizeof(half) * 320); - *reinterpret_cast(&emb_data[0]) = - llvm_amdgcn_raw_buffer_load_fp16x2(emb_res, lane_id * sizeof(half2)); - *reinterpret_cast(&emb_data[2]) = - llvm_amdgcn_raw_buffer_load_fp16x2( - emb_res, (lane_id + 64) * sizeof(half2)); - emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 256) * sizeof(half)); - } -}; + // we don't care about loading past the end of the embedding row because the + // emb_res buffer resource ensures that only 0s will be loaded + emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp16(emb_res, voffset); + offset += 64; + reg_idx += 1; + } + } else if constexpr (std::is_same_v) { + // For float: load 64 elements per iteration (64 threads * 1 float) + constexpr int num_ops = (embedding_dim + 63) / 64; + + #pragma unroll + for(int i = 0; i < num_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(float); + + // as above, we don't care about loading past the end of the embedding row + emb_data[reg_idx] = llvm_amdgcn_raw_buffer_load_fp32(emb_res, voffset); + offset += 64; + reg_idx += 1; + } + } + } + }; template struct load_row_per_warp { @@ -244,120 +235,14 @@ struct load_row_per_warp { c10::Half* emb_data, index_t row_index, const c10::Half* p_emb_table, - int lane_id) { + int lane_id, + int32_t runtime_dim) { load_row_per_warp::run( reinterpret_cast(emb_data), row_index, reinterpret_cast(p_emb_table), - lane_id); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 64); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 128); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 160, sizeof(float) * 160); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 192); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = - amdgcn_make_buffer_resource(p_emb_table + row_index * 256); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 192) * sizeof(float)); - } -}; - -template -struct load_row_per_warp { - static __device__ void run( - float* emb_data, - index_t row_index, - const float* p_emb_table, - int lane_id) { - int32x4_t emb_res = amdgcn_make_buffer_resource( - p_emb_table + row_index * 320, sizeof(float) * 320); - emb_data[0] = - llvm_amdgcn_raw_buffer_load_fp32(emb_res, lane_id * sizeof(float)); - emb_data[1] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 64) * sizeof(float)); - emb_data[2] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 128) * sizeof(float)); - emb_data[3] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 192) * sizeof(float)); - emb_data[4] = llvm_amdgcn_raw_buffer_load_fp32( - emb_res, (lane_id + 256) * sizeof(float)); + lane_id, + runtime_dim); } }; @@ -391,9 +276,14 @@ struct accumulate_row_per_warp { } }; -template +template struct store_row_per_warp { - static __device__ void run(const emb_t* acc, emb_t* p_output, int lane_id) { + static __device__ void run( + const emb_t* acc, + index_t row_index, + emb_t* p_output_table, + int lane_id, + int runtime_dim) { // Types are not supported, but we need an instance of run method to avoid // run-time .so symbol failure. Currently, the kernel dispatch for // unsupported type is guarded on host function @@ -408,157 +298,82 @@ struct store_row_per_warp { } }; -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16(acc[0], out_res, lane_id * sizeof(half)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { +template + requires(std::is_same_v || std::is_same_v) +struct store_row_per_warp { + static __device__ void run(const emb_t* acc, index_t row_index, emb_t* p_output_table, int lane_id, int32_t runtime_dim) { int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, 160 * sizeof(half)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16( - acc[2], out_res, (lane_id + 128) * sizeof(half)); - } -}; + amdgcn_make_buffer_resource(p_output_table + row_index * runtime_dim, sizeof(emb_t) * runtime_dim); -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16( - acc[2], out_res, (lane_id + 128) * sizeof(half)); - } -}; + int offset = 0; + int reg_idx = 0; -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc + 2), - out_res, - (lane_id + 64) * sizeof(half2)); - } -}; + if constexpr (std::is_same_v) { + // For half: vector store 128 elements per iteration (64 threads * 2 halfs) + constexpr int num_vector_ops = embedding_dim / 128; -template <> -struct store_row_per_warp { - static __device__ void run(const half* acc, half* p_output, int lane_id) { - int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, 320 * sizeof(half)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc), out_res, lane_id * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16x2( - *reinterpret_cast(acc + 2), - out_res, - (lane_id + 64) * sizeof(half2)); - llvm_amdgcn_raw_buffer_store_fp16( - acc[4], out_res, (lane_id + 256) * sizeof(half)); - } -}; + #pragma unroll + for(int i = 0; i < num_vector_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half2); -template -struct store_row_per_warp { - static __device__ void - run(const c10::Half* emb_data, c10::Half* p_emb_table, int lane_id) { - store_row_per_warp::run( - reinterpret_cast(emb_data), - reinterpret_cast(p_emb_table), - lane_id); - } -}; + // Pack two half values into half2 for vectorized store + half2 val; + val.x = acc[reg_idx]; + val.y = acc[reg_idx + 1]; + llvm_amdgcn_raw_buffer_store_fp16x2(val, out_res, voffset); -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - } -}; + offset += 128; + reg_idx += 2; + } -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - } -}; + // store remaining elements (scalar stores) + constexpr int tail_start = num_vector_ops * 128; + constexpr int num_scalar_ops = (embedding_dim - tail_start + 63) / 64; -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, sizeof(float) * 160); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - } -}; + #pragma unroll + for(int i = 0; i < num_scalar_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(half); -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - } -}; - -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = amdgcn_make_buffer_resource(p_output); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[3], out_res, (lane_id + 192) * sizeof(float)); + // we don't care about storing past the end of the embedding row because the + // out_res buffer resource ensures that those writes will be ignored + llvm_amdgcn_raw_buffer_store_fp16(acc[reg_idx], out_res, voffset); + offset += 64; + reg_idx += 1; + } + } else if constexpr (std::is_same_v) { + // For float: store 64 elements per iteration (64 threads * 1 float) + constexpr int num_ops = (embedding_dim + 63) / 64; + + #pragma unroll + for(int i = 0; i < num_ops; i++) + { + int voffset = (offset + lane_id) * sizeof(float); + + // as above, we don't care about storing past the end of the embedding row + llvm_amdgcn_raw_buffer_store_fp32(acc[reg_idx], out_res, voffset); + offset += 64; + reg_idx += 1; + } + } } }; -template <> -struct store_row_per_warp { - static __device__ void run(const float* acc, float* p_output, int lane_id) { - int32x4_t out_res = - amdgcn_make_buffer_resource(p_output, sizeof(float) * 320); - llvm_amdgcn_raw_buffer_store_fp32(acc[0], out_res, lane_id * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[1], out_res, (lane_id + 64) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[2], out_res, (lane_id + 128) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[3], out_res, (lane_id + 192) * sizeof(float)); - llvm_amdgcn_raw_buffer_store_fp32( - acc[4], out_res, (lane_id + 256) * sizeof(float)); +template +struct store_row_per_warp { + static __device__ void run( + const c10::Half* emb_data, + index_t row_index, + c10::Half* p_emb_table, + int lane_id, + int32_t runtime_dim) { + store_row_per_warp::run( + reinterpret_cast(emb_data), + row_index, + reinterpret_cast(p_emb_table), + lane_id, + runtime_dim); } };