From 915af9d1a2debc9a8018585f72ca851bc4dad52b Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Fri, 5 Dec 2025 13:44:27 +0000 Subject: [PATCH 1/2] split_embeddings_common: fix byte-offset error in HIP load_row_per_warp for D=320, half --- fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 59f96a19b7..e931a0d733 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -226,7 +226,7 @@ struct load_row_per_warp { llvm_amdgcn_raw_buffer_load_fp16x2( emb_res, (lane_id + 64) * sizeof(half2)); emb_data[4] = llvm_amdgcn_raw_buffer_load_fp16( - emb_res, (lane_id + 128) * sizeof(half)); + emb_res, (lane_id + 256) * sizeof(half)); } }; From 3a983a0c87715c0f35139281809e5caddee95a52 Mon Sep 17 00:00:00 2001 From: Aryaman Gupta Date: Tue, 9 Dec 2025 10:05:54 +0000 Subject: [PATCH 2/2] backward_adagrad_test: add regression test for fp16 weights with D=320 --- .../tbe/training/backward_adagrad_test.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py index 2c22a3fd5e..d78282a409 100644 --- a/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py +++ b/fbgemm_gpu/test/tbe/training/backward_adagrad_test.py @@ -19,6 +19,7 @@ from .backward_adagrad_common import ( additional_decorators, adjust_mixed_B_st, + CacheAlgorithm, common_settings, common_strategy, execute_backward_adagrad, @@ -221,6 +222,29 @@ def test_backward_adagrad_fp16_pmSUM_with_max_norm( # noqa C901 **kwargs, ) + @unittest.skipIf(*gpu_unavailable) + def test_backward_adagrad_fp16_pmSUM_D320(self) -> None: + execute_backward_adagrad( + T=2, + # Pass D=80: the test harness multiplies D by 4, so the effective D is 80 * 4 = 320 
+ D=80, + B=16, + log_E=4, + L=4, + D_gradcheck=1, + weights_precision=SparseType.FP16, + stochastic_rounding=False, + weighted=False, + row_wise=True, + mixed=False, + mixed_B=False, + use_cache=False, + cache_algorithm=CacheAlgorithm.LRU, + pooling_mode=PoolingMode.SUM, + use_cpu=False, + output_dtype=SparseType.FP16, + ) + if __name__ == "__main__": unittest.main()