From 49465f550405cfe21dcc29b01395ed3008d286c0 Mon Sep 17 00:00:00 2001 From: zhimding Date: Wed, 27 Aug 2025 09:46:16 +0000 Subject: [PATCH 1/3] apply unroll and prefetch optimization --- ..._backward_split_device_kernel_template.cuh | 105 ++++++++++++ ...ing_backward_split_kernel_warp_template.cu | 120 ++++++++++++- ...optimizer_split_device_kernel_template.cuh | 158 ++++++++++++++++++ 3 files changed, 379 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..3c4e79cf8d 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -63,6 +63,111 @@ DEVICE_INLINE void store_grad_sum( */ #} +{%- if not nobag and not weighted and vbe %} +template < + typename grad_t, + typename cache_t, + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize = kWarpSize, + int32_t VEC_WIDTH, + bool kUseVecBlocking +> +DEVICE_INLINE void compute_grad_sum_unweighted_vbe_rowwise_adagrad( + Vec4TAcc* grad_sum, + Vec4TAcc* smem_grad_sum, + const pta::PackedTensorAccessor64& grad_output, + const pta::PackedTensorAccessor32& D_offsets, + const int32_t D, + const int32_t T, + const pta::PackedTensorAccessor32& sorted_infos, + const pta::PackedTensorAccessor32& B_offsets, + const pta::PackedTensorAccessor32& row_output_offsets, + const int32_t info_B_num_bits, + const uint32_t info_B_mask, + const int32_t segment_start, + const int32_t sl_start, + const int32_t sl_end, + const unsigned int shfl_sync_mask, + const int32_t num_vecs, + const int32_t b_t_pre, + const int32_t boff_pre +) { + // Copy value to vecs to make num_vecs known at compile time when + // kUseVecBlocking == false + const int32_t vecs = kUseVecBlocking ? num_vecs : kFixedMaxVecsPerThread; + for (int32_t vec_start = 0; + vec_start < vecs; + vec_start += kFixedMaxVecsPerThread) { + + // Reset grad_sum vectors + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; vec < kFixedMaxVecsPerThread; vec++) { + grad_sum[vec].acc.x = 0; + grad_sum[vec].acc.y = 0; + grad_sum[vec].acc.z = 0; + grad_sum[vec].acc.w = 0; + } + + for (int32_t sl = sl_start; sl < sl_end; sl += kThreadGroupSize) { + auto sl_j = sl + threadIdx.x; + const auto b_t = (sl==sl_start && vec_start==0) ? b_t_pre : (sl_j < sl_end + ? reinterpret_cast( + &sorted_infos[0])[segment_start + sl_j] + : 0); + const auto b = b_t & info_B_mask; + const auto t = b_t >> info_B_num_bits; + const auto boff = (sl == sl_start && vec_start == 0) ? boff_pre: B_offsets[t]; + const auto grad_offset = row_output_offsets[B_offsets[t] + b]; + const int32_t d = threadIdx.x * VEC_WIDTH; + + for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; j += 8) { + const auto grad_offset_j0 = SHFL_SYNC(grad_offset, j); + const auto grad_offset_j1 = SHFL_SYNC(grad_offset, j + 1); + const auto grad_offset_j2 = SHFL_SYNC(grad_offset, j + 2); + const auto grad_offset_j3 = SHFL_SYNC(grad_offset, j + 3); + const auto grad_offset_j4 = SHFL_SYNC(grad_offset, j + 4); + const auto grad_offset_j5 = SHFL_SYNC(grad_offset, j + 5); + const auto grad_offset_j6 = SHFL_SYNC(grad_offset, j + 6); + const auto grad_offset_j7 = SHFL_SYNC(grad_offset, j + 7); + if (threadIdx.x * VEC_WIDTH < D) { + Vec4TAcc grad_out_vec0 = Vec4TAcc(&grad_output[0][grad_offset_j0 + d]); + Vec4TAcc grad_out_vec1 = sl + j + 1 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j1 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec2 = sl + j + 2 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j2 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec3 = sl + j + 3 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j3 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec4 = sl + j + 4 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j4 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec5 = sl + j + 5 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j5 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec6 = sl + j + 6 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j6 + d]) : Vec4TAcc(); + Vec4TAcc grad_out_vec7 = sl + j + 7 < sl_end ? Vec4TAcc(&grad_output[0][grad_offset_j7 + d]) : Vec4TAcc(); + grad_sum[0].add_(grad_out_vec0); + grad_sum[0].add_(grad_out_vec1); + grad_sum[0].add_(grad_out_vec2); + grad_sum[0].add_(grad_out_vec3); + grad_sum[0].add_(grad_out_vec4); + grad_sum[0].add_(grad_out_vec5); + grad_sum[0].add_(grad_out_vec6); + grad_sum[0].add_(grad_out_vec7); + + } + } + } + + {%- set d_vec = "((vec + vec_start) * kThreadGroupSize + threadIdx.x)" %} + + if (smem_grad_sum) { + // Store grad_sum in smem_grad_sum + #pragma unroll kFixedMaxVecsPerThread + for (int32_t vec = 0; + (vec < kFixedMaxVecsPerThread) && {{ d_vec }} * VEC_WIDTH < D; + ++vec) { + const int32_t d_vec = {{ d_vec }}; + smem_grad_sum[d_vec] = grad_sum[vec]; + } + } + } +} + +{%- endif %} + template < typename grad_t, typename cache_t, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index 322d997e83..a0658d044b 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -153,6 +153,15 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( const float weight_decay_base = 1 - learning_rate * weight_decay; {%- endif %} + {%- if not nobag and vbe and not weighted and not ssd %} + const auto run_sum = sorted_linear_indices_run.size(0) < sorted_linear_indices_num_runs[0] + ? sorted_linear_indices_run.size(0) + : sorted_linear_indices_num_runs[0]; + int64_t linear_index_pre = sorted_linear_indices_run[start_run_id]; + int32_t segment_start_pre = sorted_linear_indices_cumulative_run_lengths[start_run_id]; + int32_t segment_end_pre = sorted_linear_indices_cumulative_run_lengths[start_run_id + 1]; + {%- endif %} + #ifdef FBGEMM_USE_SUBWARP_SHUFFLE const unsigned int shfl_sync_mask = ((1L << kThreadGroupSize) - 1) << @@ -169,6 +178,24 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( ? smem.getPointer() + threadIdx.y * grad_sum_stride : nullptr; + {%- if vbe and not weighted and not ssd and not nobag and optimizer == "rowwise_adagrad" and not is_gwd_kernel %} + int32_t segment_start = segment_start_pre; + int32_t segment_end = segment_end_pre; + int64_t linear_index = linear_index_pre; + int32_t SL = segment_end - segment_start; + auto info_0 = reinterpret_cast(&sorted_infos[0])[segment_start_pre]; + auto t_0 = info_0 >> info_B_num_bits; + auto weights_placement = static_cast(weights_placements[t_0]); + + auto b_t_pre = threadIdx.x < SL + ? reinterpret_cast(&sorted_infos[0])[segment_start + threadIdx.x] + : 0; + + auto t = b_t_pre >> info_B_num_bits; + auto boff_pre = B_offsets[t]; + + for (uint32_t run_id = start_run_id; run_id < run_sum; run_id += gridDim.x * blockDim.y) { + {%- else %} for (uint32_t run_id = start_run_id; run_id < sorted_linear_indices_run.size(0) && run_id < sorted_linear_indices_num_runs[0]; run_id += gridDim.x * blockDim.y) { @@ -179,12 +206,19 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( const int32_t segment_end = sorted_linear_indices_cumulative_run_lengths[run_id + 1]; const int32_t SL = segment_end - segment_start; - + {%- endif %} if (SL >= max_segment_length_per_warp) { continue; } + {%- if vbe and not weighted and not ssd and not nobag and optimizer == "rowwise_adagrad" and not is_gwd_kernel %} + if (run_id + gridDim.x * blockDim.y < run_sum) { + linear_index_pre = sorted_linear_indices_run[run_id + gridDim.x * blockDim.y]; + segment_start_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y]; + segment_end_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y + 1]; + } + {%- else %} // now, each segment corresponds to exactly one table `t` and row in // that table (`idx`). Thus, we can hoist out some of the book-keeping. {%- if not nobag %} @@ -194,6 +228,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( const auto info_0 = sorted_infos[segment_start]; int32_t t_0 = info_0 % T; {%- endif %} + {%- endif %} int64_t hash_size = hash_size_cumsum[t_0]; {%- if not nobag or is_index_select %} @@ -219,6 +254,45 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH; const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth; + {%- if not nobag and not weighted and vbe and optimizer == "rowwise_adagrad" and not is_gwd_kernel and not ssd %} + compute_grad_sum_unweighted_vbe_rowwise_adagrad< + grad_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + grad_sum, + smem_grad_sum, + grad_output, + D_offsets, + D, + T, + sorted_infos, + B_offsets, + row_output_offsets, + info_B_num_bits, + info_B_mask, + segment_start, + sl_start, + sl_end, + shfl_sync_mask, + num_vecs, + b_t_pre, + boff_pre + ); + if (run_id + gridDim.x * blockDim.y < run_sum) { + info_0 = reinterpret_cast(&sorted_infos[0])[segment_start_pre]; + } + + segment_start = segment_start_pre; + segment_end = segment_end_pre; + linear_index = linear_index_pre; + SL = segment_end - segment_start; + b_t_pre = threadIdx.x < SL + ? reinterpret_cast(&sorted_infos[0])[segment_start + threadIdx.x] + : 0; + {%- else %} compute_grad_sum_{{ kdesc }}< grad_t, cache_t, @@ -255,7 +329,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( sl_end, shfl_sync_mask, num_vecs - ); + ); + {%- endif %} // Copy value to max_vecs to make max_vecs_per_thread known at compile time // when kUseVecBlocking == false @@ -263,6 +338,42 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; {%- if not dense and optimizer != "none" %} + {%- if not nobag and not weighted and vbe and optimizer == "rowwise_adagrad" and not is_gwd_kernel and not ssd %} + vbe_unweighted_split_rowwise_adagrad_table_update_kernel< + emb_t, + cache_t, + kFixedMaxVecsPerThread, + kThreadGroupSize, + VEC_WIDTH, + kUseVecBlocking>( + dev_weights, + uvm_weights, + lxu_cache_weights, + weights_placements, + weights_offsets, + sorted_{{ locs_or_addrs_tensor }}, + grad_sum, + smem_grad_sum, + smem_grad_sum, // shared_weight_update_row (reuse smem_grad_sum) + stochastic_rounding, + stochastic_rounding_philox_args, + run_id, + segment_start, + D, + t_0, + idx, + 1, // global_weight_decay + shfl_sync_mask, + max_vecs, + weights_placement, + {{ args.split_kernel_arg_names | join(", ") }} + ); + + t_0 = info_0 >> info_B_num_bits; + auto weights_placement = static_cast(weights_placements[t_0]); + t = b_t_pre >> info_B_num_bits; + boff_pre = B_offsets[t]; + {%- else %} {{ mdesc }}_{{ optimizer }}_table_update_kernel< emb_t, cache_t, @@ -303,7 +414,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( enable_optimizer_offloading, {%- endif %} {{ args.split_kernel_arg_names | join(", ") }} - ); + ); + {%- endif %} {%- else %} // Write deduplicated gradient to grad_dev_weights gradient is sparse // for split_embedding and dense for dense_embedding @@ -853,4 +965,4 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd #endif //////////////////////////////////////////////////////////////////////////////// {%- endif %} - // clang-format on + // clang-format on \ No newline at end of file diff --git a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh index e4fb6c548c..66fb2ae430 100644 --- a/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh @@ -176,4 +176,162 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel( {{ split_post_update }} } +{%- if optimizer == "rowwise_adagrad" and not ssd %} +template < + typename emb_t, + typename cache_t, + {%- for ph_name in args.placeholder_tensor_names %} + {%- set ph_type = "{}_ph_t".format(ph_name) %} + typename {{ ph_type }}, + {%- endfor %} + int32_t kFixedMaxVecsPerThread, + int32_t kThreadGroupSize = kWarpSize, + int32_t VEC_WIDTH, + bool kUseVecBlocking +> +DEVICE_INLINE void vbe_unweighted_{{ mdesc }}_{{ optimizer }}_table_update_kernel( + pta::PackedTensorAccessor64& dev_weights, + pta::PackedTensorAccessor64& uvm_weights, + pta::PackedTensorAccessor64& lxu_cache_weights, + const pta::PackedTensorAccessor32& weights_placements, + const pta::PackedTensorAccessor32& weights_offsets, + const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits>& sorted_{{ locs_or_addrs_tensor }}, + Vec4TAcc* grad_sum, + Vec4TAcc* smem_grad_sum, + Vec4TAcc* shared_weight_update_row, + const bool stochastic_rounding, + const at::PhiloxCudaState& stochastic_rounding_philox_args, + const uint32_t run_id, + const uint32_t cache_loc_run_id, + const int32_t D, + const int32_t t, + const int64_t idx, + {%- if has_global_weight_decay_support %} + const float global_weight_decay, + {%- endif %} + const uint32_t shfl_sync_mask, + const int32_t max_vecs_per_thread, + PlacementType weights_placement, + {%- if ssd %} + const bool enable_optimizer_offloading, + {%- endif %} + {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }} +) { + constexpr auto kIsInt8 = std::is_same_v; + // Copy value to max_vecs to make max_vecs_per_thread known at compile time + // when kUseVecBlocking == false + const int32_t max_vecs = + kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; + const int64_t weights_offset = weights_offsets[t]; + emb_t* __restrict__ weights {nullptr}; + cache_t* __restrict__ cache_weights {nullptr}; + int32_t D_emb = D; + if constexpr (kIsInt8) { + D_emb += kINT8QparamsBytes; + } + // const auto weights_placement = static_cast(weights_placements[t]); + if (weights_placement == PlacementType::DEVICE) { + weights = &dev_weights[weights_offset + idx * D_emb]; + } else { + weights = {{ "nullptr" if ssd else "&uvm_weights[weights_offset + idx * D_emb]" }}; + } + if (weights_placement == PlacementType::MANAGED_CACHING) { + const auto {{ locs_or_addrs_idx }} = sorted_{{ locs_or_addrs_tensor }}[cache_loc_run_id]; + {%- if ssd %} + cache_weights = reinterpret_cast( + *reinterpret_cast(&{{ locs_or_addrs_idx }})); + {%- else %} + if ({{ locs_or_addrs_idx }} != kCacheLocationMissing) { + cache_weights = &lxu_cache_weights[{{ locs_or_addrs_idx }}][0]; + } + {%- endif %} + } + {%- for tensor in args.split_tensors %} + {{ args.split_tensor_types[tensor] }}* __restrict__ {{ tensor }}; + const auto {{ tensor }}_placement = static_cast({{ tensor }}_placements[t]); + const int64_t {{ tensor }}_offset = {{ tensor }}_offsets[t]; + if ({{ tensor }}_placement == PlacementType::DEVICE) { + {{ tensor }} = &{{ tensor }}_dev[{{ tensor }}_offset]; + } else { + {{ tensor }} = &{{ tensor }}_uvm[{{ tensor }}_offset]; + } + {%- endfor %} + + auto weight_row_template = + WeightRow>( + weights, + cache_weights, + D, + stochastic_rounding, + &stochastic_rounding_philox_args, + threadIdx.x + run_id * blockDim.x); + + float2 qparams_template; + if constexpr (kIsInt8) { + if (!cache_weights) { + qparams_template = weight_row_template.load_qparams(); + } + } + + {%- if not ssd %} + [[maybe_unused]] constexpr auto enable_optimizer_offloading = false; + {%- endif %} + + {{ split_precomputation }} + + {# /* Note: technically, global weight decay (gwd) compensation should be done before + `split_precomputation`). But since decouple mode in `rowwise_adagrad` only computes correction, + the order of applying gwd does not matter. We perform gwd update before `split_weight_update` + below to minimize number of times to load weights. + So, note that the behavior may be different if you want to enable gwd for other optimizers + such as `lamb` or `partial_rowwise_lamb`. + */#} + float2 qparams_new; + {{ + generate_optimized_grad_sum_loop_access( + """ + Vec4TAcc weight_new = weight_row_template.load(d, qparams_template); + Vec4TAcc& grad = {grad_vec}; + {global_weight_decay_update} + {split_weight_update} + if (kIsInt8 && !cache_weights) { + shared_weight_update_row[d_vec] = weight_new; + } else { + // qparams_new not used if type is not int8 + weight_row_template.store(weight_new, d, qparams_new); + } + """, + other_formats={ + "split_weight_update": split_weight_update, + "global_weight_decay_update": "weight_new.mul_(global_weight_decay);" if has_global_weight_decay_support else "" + }, + ) + }} + + if constexpr (kIsInt8) { + if (!cache_weights) { + // Calculate new qparams after row update + qparams_new = thrust_find_qparams>( + shared_weight_update_row, D); + weight_row_template.store_qparams(qparams_new); + + // Fetch cached updated row from shared mem and quantize on-the-fly + // when saving to lowp embedding + for (int32_t vec = 0; + (vec * kThreadGroupSize + threadIdx.x) * VEC_WIDTH < D; + ++vec) { + const auto d_vec = vec * kThreadGroupSize + threadIdx.x; + const int32_t d = d_vec * VEC_WIDTH; + weight_row_template.store( + shared_weight_update_row[d_vec], + d, + qparams_new); + } + } + } + + {{ split_post_update }} +} +{%- endif %} + // clang-format on From f7d39c7faf53f61387e48053a002513ef1a7e0de Mon Sep 17 00:00:00 2001 From: wulley Date: Thu, 28 Aug 2025 02:25:02 +0000 Subject: [PATCH 2/3] fix --- .../embedding_backward_split_device_kernel_template.cuh | 2 +- .../embedding_backward_split_kernel_warp_template.cu | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index 3c4e79cf8d..00aad1911b 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -117,7 +117,7 @@ DEVICE_INLINE void compute_grad_sum_unweighted_vbe_rowwise_adagrad( const auto b = b_t & info_B_mask; const auto t = b_t >> info_B_num_bits; const auto boff = (sl == sl_start && vec_start == 0) ? boff_pre: B_offsets[t]; - const auto grad_offset = row_output_offsets[B_offsets[t] + b]; + const auto grad_offset = row_output_offsets[boff + b]; // if vbe // if not nobag const int32_t d = threadIdx.x * VEC_WIDTH; for (int32_t j = 0; j < kThreadGroupSize && sl + j < sl_end; j += 8) { diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index a0658d044b..d632b7f509 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -230,6 +230,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} {%- endif %} + // now, each segment corresponds to exactly one table `t` and row in + // that table (`idx`). Thus, we can hoist out some of the book-keeping. int64_t hash_size = hash_size_cumsum[t_0]; {%- if not nobag or is_index_select %} const auto D_start_t0 = D_offsets[t_0]; @@ -243,7 +245,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( const auto grad_stride = permute_output_dim_0_1 ? D_offsets[T] : D; {%- endif %} {%- endif %} - int64_t idx = linear_index - hash_size; + int64_t idx = linear_index - hash_size; // the id value or emb index {{ compute_global_weight_decay(is_gwd_kernel) }} @@ -367,7 +369,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( max_vecs, weights_placement, {{ args.split_kernel_arg_names | join(", ") }} - ); + ); // if not dense and optimizer != "none" t_0 = info_0 >> info_B_num_bits; auto weights_placement = static_cast(weights_placements[t_0]); @@ -441,7 +443,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( idx, max_vecs ); - {%- endif %} // if not dense and optimizer != "none" + {%- endif %} } } From 833da5bfa65cf5576ece2f71acd40ae9a889a68a Mon Sep 17 00:00:00 2001 From: wulley Date: Thu, 28 Aug 2025 11:50:52 +0000 Subject: [PATCH 3/3] rm redundant codegen --- ..._backward_split_device_kernel_template.cuh | 2 +- ...ing_backward_split_kernel_warp_template.cu | 32 +++++++++++-------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index 00aad1911b..f3a807cfd2 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -63,7 +63,7 @@ DEVICE_INLINE void store_grad_sum( */ #} -{%- if not nobag and not weighted and vbe %} +{%- if not nobag and not weighted and vbe and not ssd %} template < typename grad_t, typename cache_t, diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu index d632b7f509..a0d49d6baf 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu @@ -153,7 +153,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( const float weight_decay_base = 1 - learning_rate * weight_decay; {%- endif %} - {%- if not nobag and vbe and not weighted and not ssd %} + {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %} const auto run_sum = sorted_linear_indices_run.size(0) < sorted_linear_indices_num_runs[0] ? sorted_linear_indices_run.size(0) : sorted_linear_indices_num_runs[0]; @@ -178,7 +178,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( ? smem.getPointer() + threadIdx.y * grad_sum_stride : nullptr; - {%- if vbe and not weighted and not ssd and not nobag and optimizer == "rowwise_adagrad" and not is_gwd_kernel %} + {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %} int32_t segment_start = segment_start_pre; int32_t segment_end = segment_end_pre; int64_t linear_index = linear_index_pre; @@ -196,6 +196,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( for (uint32_t run_id = start_run_id; run_id < run_sum; run_id += gridDim.x * blockDim.y) { {%- else %} + for (uint32_t run_id = start_run_id; run_id < sorted_linear_indices_run.size(0) && run_id < sorted_linear_indices_num_runs[0]; run_id += gridDim.x * blockDim.y) { @@ -208,17 +209,19 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( const int32_t SL = segment_end - segment_start; {%- endif %} + if (SL >= max_segment_length_per_warp) { continue; } - {%- if vbe and not weighted and not ssd and not nobag and optimizer == "rowwise_adagrad" and not is_gwd_kernel %} + {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %} if (run_id + gridDim.x * blockDim.y < run_sum) { linear_index_pre = sorted_linear_indices_run[run_id + gridDim.x * blockDim.y]; segment_start_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y]; segment_end_pre = sorted_linear_indices_cumulative_run_lengths[run_id + gridDim.x * blockDim.y + 1]; } {%- else %} + // now, each segment corresponds to exactly one table `t` and row in // that table (`idx`). Thus, we can hoist out some of the book-keeping. {%- if not nobag %} @@ -230,8 +233,6 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( {%- endif %} {%- endif %} - // now, each segment corresponds to exactly one table `t` and row in - // that table (`idx`). Thus, we can hoist out some of the book-keeping. int64_t hash_size = hash_size_cumsum[t_0]; {%- if not nobag or is_index_select %} const auto D_start_t0 = D_offsets[t_0]; @@ -245,7 +246,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( const auto grad_stride = permute_output_dim_0_1 ? D_offsets[T] : D; {%- endif %} {%- endif %} - int64_t idx = linear_index - hash_size; // the id value or emb index + int64_t idx = linear_index - hash_size; {{ compute_global_weight_decay(is_gwd_kernel) }} @@ -256,7 +257,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( constexpr int32_t kGroupVecWidth = kThreadGroupSize * VEC_WIDTH; const int32_t num_vecs = (D + kGroupVecWidth - 1) / kGroupVecWidth; - {%- if not nobag and not weighted and vbe and optimizer == "rowwise_adagrad" and not is_gwd_kernel and not ssd %} + {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %} compute_grad_sum_unweighted_vbe_rowwise_adagrad< grad_t, cache_t, @@ -295,6 +296,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( ? reinterpret_cast(&sorted_infos[0])[segment_start + threadIdx.x] : 0; {%- else %} + compute_grad_sum_{{ kdesc }}< grad_t, cache_t, @@ -331,7 +333,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( sl_end, shfl_sync_mask, num_vecs - ); + ); {%- endif %} // Copy value to max_vecs to make max_vecs_per_thread known at compile time @@ -340,7 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( kUseVecBlocking ? max_vecs_per_thread : kFixedMaxVecsPerThread; {%- if not dense and optimizer != "none" %} - {%- if not nobag and not weighted and vbe and optimizer == "rowwise_adagrad" and not is_gwd_kernel and not ssd %} + {%- if not is_gwd_kernel and not nobag and vbe and not weighted and not ssd and optimizer == "rowwise_adagrad" and mdesc == "split" %} vbe_unweighted_split_rowwise_adagrad_table_update_kernel< emb_t, cache_t, @@ -360,7 +362,9 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( stochastic_rounding, stochastic_rounding_philox_args, run_id, - segment_start, + use_uniq_cache_locations + ? (run_id - table_unique_indices_offsets[t_0]) + : segment_start, D, t_0, idx, @@ -369,7 +373,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( max_vecs, weights_placement, {{ args.split_kernel_arg_names | join(", ") }} - ); // if not dense and optimizer != "none" + ); // if not dense and optimizer != "none" t_0 = info_0 >> info_B_num_bits; auto weights_placement = static_cast(weights_placements[t_0]); @@ -416,7 +420,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( enable_optimizer_offloading, {%- endif %} {{ args.split_kernel_arg_names | join(", ") }} - ); + ); // if not dense and optimizer != "none" {%- endif %} {%- else %} // Write deduplicated gradient to grad_dev_weights gradient is sparse @@ -442,8 +446,8 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row( weights_offset, idx, max_vecs - ); - {%- endif %} + ); // if not dense and optimizer != "none" + {%- endif %} } }