diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh index b9db6e47f8..e28eaf00c4 100644 --- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh +++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_device_kernel_template.cuh @@ -11,6 +11,7 @@ #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" #include "fbgemm_gpu/utils/tensor_accessor_builder.h" #include "fbgemm_gpu/split_embeddings_utils.cuh" +#include "fbgemm_gpu/rocm/split_embeddings_common.h" using namespace fbgemm_gpu; @@ -102,6 +103,10 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( const unsigned int shfl_sync_mask, const int32_t num_vecs ) { + // const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE; + // auto p_output_grad = grad_output.data(); + // const grad_t* p_output_grad = grad_output.data(); + // Copy value to vecs to make num_vecs known at compile time when // kUseVecBlocking == false const int32_t vecs = kUseVecBlocking ? num_vecs : kFixedMaxVecsPerThread; @@ -160,18 +165,50 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}( #pragma unroll kFixedMaxVecsPerThread for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) { const int32_t d = {{ d }}; - Vec4TAcc grad_out_vec( - {%- if nobag and is_index_select %} - // grad_output is 1d - &grad_output[grad_offset + l_j * grad_stride + d] - {%- elif nobag %} - &grad_output[l_j][d] - {%- elif vbe %} - &grad_output[0][grad_offset_j + d] - {%- else %} - &grad_output[b_j][0] + D_start_j + d - {%- endif %} // if nobag - ); + // Vec4TAcc grad_out_vec( + // {%- if nobag and is_index_select %} + // // grad_output is 1d + // &grad_output[grad_offset + l_j * grad_stride + d] + // {%- elif nobag %} + // &grad_output[l_j][d] + // {%- elif vbe %} + // &grad_output[0][grad_offset_j + d] + // {%- else %} + // &grad_output[b_j][0] + D_start_j + d + // {%- endif %} // if nobag + // ); + {%- if nobag and is_index_select %} + Vec4TAcc grad_out_vec(&grad_output[grad_offset + l_j * grad_stride + d]); + {%- elif nobag %} + Vec4TAcc grad_out_vec(&grad_output[l_j][d]); + {%- elif vbe %} + Vec4TAcc grad_out_vec(&grad_output[0][grad_offset_j + d]); + {%- else %} + int32x4_t emb_res = fbgemm_gpu::rocm::amdgcn_make_buffer_resource(&grad_output[b_j][0] + D_start_j + d); + + Vec4TAcc grad_out_vec; + // Vec4TAcc grad_out_vec( + // &grad_output[b_j][0] + D_start_j + d + // // if nobag + // ); + grad_out_vec.acc.x = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 0 * sizeof(half), 0, 0); + grad_out_vec.acc.y = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 1 * sizeof(half), 0, 0); + grad_out_vec.acc.z = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 2 * sizeof(half), 0, 0); + grad_out_vec.acc.w = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 3 * sizeof(half), 0, 0); + // grad_out_vec.acc.x = val0; + // grad_out_vec.acc.y = val1; + // grad_out_vec.acc.z = val2; + // grad_out_vec.acc.w = val3; + + // int32x4_t emb_res = fbgemm_gpu::rocm::amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim); + // emb_data[i] = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp32( + // emb_res, (lane_id + i * 64) * sizeof(float), 0, 0); + + // grad_out_vec.acc = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp32x4(emb_res, lane_id * sizeof(float) * 4, 0, 0); + // grad_out_vec.acc = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp32x4(emb_res, 0, 0, 0); + + // grad_out_vec = *((const float4*)p_grad_out_vec); + {%- endif %} // if nobag {%- if weighted %} grad_sum[vec].fma_(grad_out_vec, idx_weight_j); diff --git a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h index 974eae2594..016e6b10da 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h +++ b/fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h @@ -67,6 +67,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16( __asm("llvm.amdgcn.raw.buffer.load.f16"); #endif +// __device__ float4 llvm_amdgcn_raw_buffer_load_fp32x4( +// int32x4_t srsrc, +// int32_t voffset, +// int32_t soffset, +// int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + __device__ float llvm_amdgcn_raw_buffer_load_fp32( int32x4_t srsrc, int32_t voffset,