From 174394eac6e62dd3c9d10412cac4369879d50b52 Mon Sep 17 00:00:00 2001 From: jichen Date: Thu, 21 Aug 2025 10:26:35 +0000 Subject: [PATCH] apply Vec4T on vbe forward --- ...embedding_forward_split_kernel_template.cu | 21 ++----------------- .../embedding_forward_split_template.cu | 5 ----- 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu index 0122cfcee9..ac1e771fc6 100644 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -84,11 +84,7 @@ using namespace fbgemm_gpu; {#-/* Set the weights row accessor */#} - {%- if is_rocm %} - const auto weights_row = rocm::WeightRowAccessorVec2 - {%- else %} const auto weights_row = WeightRowAccessor - {%- endif %} < {{ 'cache_t' if from_cache else 'emb_t' }}, cache_t @@ -182,11 +178,7 @@ using namespace fbgemm_gpu; {%- endif %} {#-/* Set the weights row accessor */#} - {%- if is_rocm %} - const auto weights_row = rocm::WeightRowAccessorVec2 - {%- else %} const auto weights_row = WeightRowAccessor - {%- endif %} < {{ 'cache_t' if from_cache else 'emb_t' }}, cache_t @@ -319,7 +311,7 @@ using namespace fbgemm_gpu; {%- if is_rocm %} {%- if not nobag %} - rocm::Vec2T vals[kManualUnrollLength * kMaxVecsPerThread]; + Vec4T vals[kManualUnrollLength * kMaxVecsPerThread]; {%- endif %} // Iterate over kThreadGroupSize indices for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength) @@ -633,12 +625,7 @@ batch_index_select_dim0_codegen_forward_kernel( #endif // Elements are processed 4 at a time through fbgemm_gpu::Vec4 (CUDA float4, 16 bytes) - // for CUDA devices and 2 at a time for ROCm - {%- if is_rocm %} - constexpr int VEC_WIDTH = 2; - {%- else %} constexpr int VEC_WIDTH = 4; - {%- endif %} {%- if is_rocm %} // Unroll factor for ROCm devices constexpr int kManualUnrollLength = 4; @@ -743,12 +730,8 @@ batch_index_select_dim0_codegen_forward_kernel( const float inv_L = (mean_pooling && L != 0) ? static_cast(1.0) / L: static_cast(1.0); // Set up the accumulator buffer - {%- if is_rocm %} - rocm::Vec2T accumulators[kMaxVecsPerThread]; - {%- else %} Vec4T accumulators[kMaxVecsPerThread]; {%- endif %} - {%- endif %} {%- if dense %} {{ embedding_pool_or_store("NULL") }} @@ -930,7 +913,7 @@ batch_index_select_dim0_codegen_forward_kernel {%- endmacro %} {%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %} - {%- set max_vecs_per_thread = 2 * kMaxVecsPerThread if is_rocm else kMaxVecsPerThread %} + {%- set max_vecs_per_thread = kMaxVecsPerThread %} {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %} {%- for cache_type in ['float', 'at::Half'] %} {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %} diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu index bbd62a8bbc..b2b47b9935 100755 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu @@ -716,12 +716,7 @@ batch_index_select_dim0_codegen_forward_cuda( // kFixedMaxVecsPerThread instead of kMaxVecsPerThread. But // kMaxVecsPerThread and kFixedMaxVecsPerThread are the same // forward - {%- if is_rocm %} - // Account for Vec2 load for ROCm - constexpr auto kMaxVecsPerThread = 2 * kFixedMaxVecsPerThread; - {%- else %} constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread; - {%- endif %} const auto grid = min( div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),