From 4f2256245bc860290dd2ce0451fb972297720c48 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Thu, 19 Mar 2026 21:27:56 +0000 Subject: [PATCH 1/5] use acc_type for momentum instead of cache_t --- ...ding_backward_split_device_kernel_template.hip | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip index c63f372a74..a5814c2931 100644 --- a/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip +++ b/fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip @@ -32,6 +32,7 @@ #include #include +#include #include "fbgemm_gpu/rocm/split_embeddings_common.h" namespace fbgemm_gpu::rocm { @@ -48,8 +49,9 @@ struct rowwise_adagrad_optimizer_t { if constexpr(segment_split == 0) { - cache_t * p_momentum = reinterpret_cast(karg.p_momentum); - cache_t momentum = p_momentum[row_index]; // should be s_load + using momentum_t = at::acc_type; + momentum_t* p_momentum = reinterpret_cast(karg.p_momentum); + momentum_t momentum = p_momentum[row_index]; // should be s_load // compute per row square sum cache_t local_sum_squre = .0f; if constexpr(weight_decay_mode == 1) @@ -72,11 +74,11 @@ struct rowwise_adagrad_optimizer_t } } - cache_t avg_square = - wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) / + momentum_t avg_square = + static_cast(wave_reduce, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre)) / embedding_dim; - cache_t momentum_new = momentum + avg_square; + momentum_t momentum_new = momentum + avg_square; cache_t multiplier = karg.learning_rate / (sqrtf(momentum_new) + karg.eps); cache_t correction; @@ -164,7 +166,8 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}( const int64_t emb_idx = linear_index - hash_size; p_emb_table += hash_size * emb_dim; - opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); + using momentum_t = at::acc_type; + opt_karg.p_momentum = reinterpret_cast(reinterpret_cast(opt_karg.p_momentum) + hash_size); const int32_t segment_length = segment_end - segment_start; From 0dfaed7f0e423ea9e67925965a62fb4a504d6843 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 23 Mar 2026 13:58:40 +0000 Subject: [PATCH 2/5] backward_adagrad_test.py: updates unit test to vary params --- .../tbe/training/backward_adagrad_common.py | 2 + .../tbe/training/backward_adagrad_test.py | 77 ++++++++++++++----- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py index 73ee7b98e6..fd03e8d9c7 100755 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_common.py @@ -50,6 +50,7 @@ gpu_unavailable, gradcheck, optests, + skipIfNotRocm, skipIfRocm, TEST_WITH_ROCM, use_cpu_strategy, @@ -62,6 +63,7 @@ gpu_unavailable, gradcheck, optests, + skipIfNotRocm, skipIfRocm, TEST_WITH_ROCM, use_cpu_strategy, diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py index 931be47ffd..713bb5b1e1 100644 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py @@ -23,6 +23,7 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, SplitTableBatchedEmbeddingBagsCodegen, + WeightDecayMode, ) from hypothesis import given, settings @@ -38,6 +39,7 @@ gpu_unavailable, optests, PoolingMode, + skipIfNotRocm, SparseType, st, ) @@ -234,28 +236,61 @@ def test_backward_adagrad_fp16_pmSUM_with_max_norm( # noqa C901 **kwargs, ) + @given( + T=st.integers(min_value=1, max_value=5), + D=st.sampled_from([16, 32, 40, 48, 64, 80]), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=2, max_value=20), + weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), + weighted=st.booleans(), + weight_decay_mode=st.sampled_from( + [ + WeightDecayMode.NONE, + WeightDecayMode.L2, + WeightDecayMode.DECOUPLE, + ] + ), + ) + @settings(**common_settings) @unittest.skipIf(*gpu_unavailable) - def test_backward_adagrad_fp16_pmSUM_D320(self) -> None: - execute_backward_adagrad( - T=2, - # using D=80 since the test harness multiplies D by 4, so 80*4=320 - D=80, - B=16, - log_E=4, - L=4, - D_gradcheck=1, - weights_precision=SparseType.FP16, - stochastic_rounding=False, - weighted=False, - row_wise=True, - mixed=False, - mixed_B=False, - use_cache=False, - cache_algorithm=CacheAlgorithm.LRU, - pooling_mode=PoolingMode.SUM, - use_cpu=False, - output_dtype=SparseType.FP16, - ) + @skipIfNotRocm("Test evaluates ROCm backward kernels") + def test_backward_adagrad_rocm_optimized( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + weight_decay_mode: WeightDecayMode, + ) -> None: + for use_optimized_kernel in ["0", "1"]: + os.environ["FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL"] = use_optimized_kernel + logging.info( + f"Testing ROCm backward kernel with FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL={use_optimized_kernel}" + ) + execute_backward_adagrad( + T=T, + D=D, + B=B, + log_E=log_E, + L=L, + D_gradcheck=1, + weights_precision=weights_precision, + stochastic_rounding=False, + weighted=weighted, + row_wise=True, + mixed=False, + mixed_B=False, + use_cache=False, + cache_algorithm=CacheAlgorithm.LRU, + pooling_mode=PoolingMode.SUM, + use_cpu=False, + output_dtype=weights_precision, + weight_decay_mode=weight_decay_mode, + ) @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*gpu_memory_lt_gb(40)) From 3a179adc251f29cd2157ebc93751083a2de0f427 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 23 Mar 2026 11:23:17 -0500 Subject: [PATCH 3/5] split env variable calls --- .../tbe/training/backward_adagrad_test.py | 121 ++++++++++++++---- 1 file changed, 94 insertions(+), 27 deletions(-) diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py index 713bb5b1e1..011a7ad115 100644 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py @@ -236,6 +236,84 @@ def test_backward_adagrad_fp16_pmSUM_with_max_norm( # noqa C901 **kwargs, ) + def _test_backward_adagrad_rocm_kernel( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + weight_decay_mode: WeightDecayMode, + ) -> None: + """Helper method for ROCm backward kernel tests.""" + execute_backward_adagrad( + T=T, + D=D, + B=B, + log_E=log_E, + L=L, + D_gradcheck=1, + weights_precision=weights_precision, + stochastic_rounding=False, + weighted=weighted, + row_wise=True, + mixed=False, + mixed_B=False, + use_cache=False, + cache_algorithm=CacheAlgorithm.LRU, + pooling_mode=PoolingMode.SUM, + use_cpu=False, + output_dtype=weights_precision, + weight_decay_mode=weight_decay_mode, + ) + + @given( + T=st.integers(min_value=1, max_value=5), + D=st.sampled_from([16, 32, 40, 48, 64, 80]), + B=st.integers(min_value=1, max_value=128), + log_E=st.integers(min_value=3, max_value=5), + L=st.integers(min_value=2, max_value=20), + weights_precision=st.sampled_from([SparseType.FP16, SparseType.FP32]), + weighted=st.booleans(), + weight_decay_mode=st.sampled_from( + [ + WeightDecayMode.NONE, + WeightDecayMode.L2, + WeightDecayMode.DECOUPLE, + ] + ), + ) + @settings(**common_settings) + @unittest.skipIf(*gpu_unavailable) + @skipIfNotRocm("Test evaluates ROCm stock backward kernel") + def test_backward_adagrad_rocm_stock_kernel( + self, + T: int, + D: int, + B: int, + log_E: int, + L: int, + weights_precision: SparseType, + weighted: bool, + weight_decay_mode: WeightDecayMode, + ) -> None: + os.environ["FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL"] = "0" + logging.info( + "Testing ROCm backward kernel with FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL=0 (stock)" + ) + self._test_backward_adagrad_rocm_kernel( + T=T, + D=D, + B=B, + log_E=log_E, + L=L, + weights_precision=weights_precision, + weighted=weighted, + weight_decay_mode=weight_decay_mode, + ) + @given( T=st.integers(min_value=1, max_value=5), D=st.sampled_from([16, 32, 40, 48, 64, 80]), @@ -254,8 +332,8 @@ def test_backward_adagrad_fp16_pmSUM_with_max_norm( # noqa C901 ) @settings(**common_settings) @unittest.skipIf(*gpu_unavailable) - @skipIfNotRocm("Test evaluates ROCm backward kernels") - def test_backward_adagrad_rocm_optimized( + @skipIfNotRocm("Test evaluates ROCm optimized backward kernel") + def test_backward_adagrad_rocm_optimized_kernel( self, T: int, D: int, @@ -266,31 +344,20 @@ def test_backward_adagrad_rocm_optimized( weighted: bool, weight_decay_mode: WeightDecayMode, ) -> None: - for use_optimized_kernel in ["0", "1"]: - os.environ["FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL"] = use_optimized_kernel - logging.info( - f"Testing ROCm backward kernel with FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL={use_optimized_kernel}" - ) - execute_backward_adagrad( - T=T, - D=D, - B=B, - log_E=log_E, - L=L, - D_gradcheck=1, - weights_precision=weights_precision, - stochastic_rounding=False, - weighted=weighted, - row_wise=True, - mixed=False, - mixed_B=False, - use_cache=False, - cache_algorithm=CacheAlgorithm.LRU, - pooling_mode=PoolingMode.SUM, - use_cpu=False, - output_dtype=weights_precision, - weight_decay_mode=weight_decay_mode, - ) + os.environ["FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL"] = "1" + logging.info( + "Testing ROCm backward kernel with FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL=1 (optimized)" + ) + self._test_backward_adagrad_rocm_kernel( + T=T, + D=D, + B=B, + log_E=log_E, + L=L, + weights_precision=weights_precision, + weighted=weighted, + weight_decay_mode=weight_decay_mode, + ) @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*gpu_memory_lt_gb(40)) From 77f752b89085c831723b9ad03377ad10506cac06 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 23 Mar 2026 16:39:43 +0000 Subject: [PATCH 4/5] backward_adagrad_test.py: modifies test name and desc --- fbgemm_gpu/test/tbe/training/backward_adagrad_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py index 011a7ad115..7bcc45946d 100644 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py @@ -287,8 +287,8 @@ def _test_backward_adagrad_rocm_kernel( ) @settings(**common_settings) @unittest.skipIf(*gpu_unavailable) - @skipIfNotRocm("Test evaluates ROCm stock backward kernel") - def test_backward_adagrad_rocm_stock_kernel( + @skipIfNotRocm("Test evaluates fallback kernel on ROCm") + def test_backward_adagrad_rocm_fallback_kernel( self, T: int, D: int, From ba53028fcccbfb86bee77548ad5e3d25949568e5 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Mon, 23 Mar 2026 12:14:41 -0500 Subject: [PATCH 5/5] backward_adagrad_test: reset env variable after test --- .../tbe/training/backward_adagrad_test.py | 66 ++++++++++++------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py index 7bcc45946d..ee2934caf2 100644 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py @@ -299,20 +299,29 @@ def test_backward_adagrad_rocm_fallback_kernel( weighted: bool, weight_decay_mode: WeightDecayMode, ) -> None: - os.environ["FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL"] = "0" + env_var = "FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL" + original_value = os.environ.get(env_var) + os.environ[env_var] = "0" logging.info( - "Testing ROCm backward kernel with FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL=0 (stock)" - ) - self._test_backward_adagrad_rocm_kernel( - T=T, - D=D, - B=B, - log_E=log_E, - L=L, - weights_precision=weights_precision, - weighted=weighted, - weight_decay_mode=weight_decay_mode, + f"Testing ROCm backward kernel with {env_var}=0 (stock)" ) + try: + self._test_backward_adagrad_rocm_kernel( + T=T, + D=D, + B=B, + log_E=log_E, + L=L, + weights_precision=weights_precision, + weighted=weighted, + weight_decay_mode=weight_decay_mode, + ) + finally: + # Restore original value + if original_value is None: + os.environ.pop(env_var, None) + else: + os.environ[env_var] = original_value @given( T=st.integers(min_value=1, max_value=5), @@ -344,20 +353,29 @@ def test_backward_adagrad_rocm_optimized_kernel( weighted: bool, weight_decay_mode: WeightDecayMode, ) -> None: - os.environ["FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL"] = "1" + env_var = "FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL" + original_value = os.environ.get(env_var) + os.environ[env_var] = "1" logging.info( - "Testing ROCm backward kernel with FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL=1 (optimized)" - ) - self._test_backward_adagrad_rocm_kernel( - T=T, - D=D, - B=B, - log_E=log_E, - L=L, - weights_precision=weights_precision, - weighted=weighted, - weight_decay_mode=weight_decay_mode, + f"Testing ROCm backward kernel with {env_var}=1 (optimized)" ) + try: + self._test_backward_adagrad_rocm_kernel( + T=T, + D=D, + B=B, + log_E=log_E, + L=L, + weights_precision=weights_precision, + weighted=weighted, + weight_decay_mode=weight_decay_mode, + ) + finally: + # Restore original value + if original_value is None: + os.environ.pop(env_var, None) + else: + os.environ[env_var] = original_value @unittest.skipIf(*gpu_unavailable) @unittest.skipIf(*gpu_memory_lt_gb(40))