diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index 6b301d830e..ca7bd397ce 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -772,22 +772,9 @@ def __init__( # noqa C901 self.weights_precision = weights_precision - if torch.cuda.is_available() and torch.version.hip: - # NOTE: It was discovered that FP16 cache precision caused a 500x - # slowdown in performance of split_embedding_nobag_backward_codegen_rowwise_adagrad_unweighted_kernel_warp_per_row_1 - # kernel on ROCm, so to work around this, we fix cache precision to - # be FP32 always for the ROCm environment case. - # - # See: - # https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/ - cache_precision = SparseType.FP32 - self.log("Override cache_precision=SparseType.FP32 on ROCm") - else: - # NOTE: The changes from D65865527 are retained here until we can - # test that the the hack also works for non-ROCm environments. - cache_precision = ( - weights_precision if cache_precision is None else cache_precision - ) + cache_precision = ( + weights_precision if cache_precision is None else cache_precision + ) self.output_dtype: int = output_dtype.as_int() assert ( diff --git a/fbgemm_gpu/include/fbgemm_gpu/utils/float.cuh b/fbgemm_gpu/include/fbgemm_gpu/utils/float.cuh index 8c8a0a5117..52ee25d6a6 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/utils/float.cuh +++ b/fbgemm_gpu/include/fbgemm_gpu/utils/float.cuh @@ -51,9 +51,12 @@ struct Half4 { __device__ inline void store(at::Half* p) { #ifdef USE_ROCM - *reinterpret_cast(p) = *reinterpret_cast(&a); - *reinterpret_cast(p + 2) = - *reinterpret_cast(&b); + const unsigned int lo = *reinterpret_cast(&a); + const unsigned int hi = *reinterpret_cast(&b); + const unsigned long long packed = + static_cast(lo) | + (static_cast(hi) << 32); + *reinterpret_cast(p) = packed; #else #ifndef __HALF2_TO_UI