Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
#include "fbgemm_gpu/utils/tensor_accessor_builder.h"
#include "fbgemm_gpu/split_embeddings_utils.cuh"
#include "fbgemm_gpu/rocm/split_embeddings_common.h"

using namespace fbgemm_gpu;

Expand Down Expand Up @@ -102,6 +103,10 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
const unsigned int shfl_sync_mask,
const int32_t num_vecs
) {
// const uint32_t lane_id = threadIdx.x % AMDGCN_WAVE_SIZE;
// auto p_output_grad = grad_output.data();
// const grad_t* p_output_grad = grad_output.data();

// Copy value to vecs to make num_vecs known at compile time when
// kUseVecBlocking == false
const int32_t vecs = kUseVecBlocking ? num_vecs : kFixedMaxVecsPerThread;
Expand Down Expand Up @@ -160,18 +165,50 @@ DEVICE_INLINE void compute_grad_sum_{{ kdesc }}(
#pragma unroll kFixedMaxVecsPerThread
for (int32_t vec = 0; vec < kFixedMaxVecsPerThread && {{ d }} < D; ++vec) {
const int32_t d = {{ d }};
Vec4TAcc<grad_t> grad_out_vec(
{%- if nobag and is_index_select %}
// grad_output is 1d
&grad_output[grad_offset + l_j * grad_stride + d]
{%- elif nobag %}
&grad_output[l_j][d]
{%- elif vbe %}
&grad_output[0][grad_offset_j + d]
{%- else %}
&grad_output[b_j][0] + D_start_j + d
{%- endif %} // if nobag
);
// Vec4TAcc<grad_t> grad_out_vec(
// {%- if nobag and is_index_select %}
// // grad_output is 1d
// &grad_output[grad_offset + l_j * grad_stride + d]
// {%- elif nobag %}
// &grad_output[l_j][d]
// {%- elif vbe %}
// &grad_output[0][grad_offset_j + d]
// {%- else %}
// &grad_output[b_j][0] + D_start_j + d
// {%- endif %} // if nobag
// );
{%- if nobag and is_index_select %}
Vec4TAcc<grad_t> grad_out_vec(&grad_output[grad_offset + l_j * grad_stride + d]);
{%- elif nobag %}
Vec4TAcc<grad_t> grad_out_vec(&grad_output[l_j][d]);
{%- elif vbe %}
Vec4TAcc<grad_t> grad_out_vec(&grad_output[0][grad_offset_j + d]);
{%- else %}
int32x4_t emb_res = fbgemm_gpu::rocm::amdgcn_make_buffer_resource(&grad_output[b_j][0] + D_start_j + d);

Vec4TAcc<grad_t> grad_out_vec;
// Vec4TAcc<grad_t> grad_out_vec(
// &grad_output[b_j][0] + D_start_j + d
// // if nobag
// );
grad_out_vec.acc.x = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 0 * sizeof(half), 0, 0);
grad_out_vec.acc.y = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 1 * sizeof(half), 0, 0);
grad_out_vec.acc.z = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 2 * sizeof(half), 0, 0);
grad_out_vec.acc.w = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp16(emb_res, 3 * sizeof(half), 0, 0);
// grad_out_vec.acc.x = val0;
// grad_out_vec.acc.y = val1;
// grad_out_vec.acc.z = val2;
// grad_out_vec.acc.w = val3;

// int32x4_t emb_res = fbgemm_gpu::rocm::amdgcn_make_buffer_resource(p_emb_table + row_index * embedding_dim);
// emb_data[i] = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp32(
// emb_res, (lane_id + i * 64) * sizeof(float), 0, 0);

// grad_out_vec.acc = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp32x4(emb_res, lane_id * sizeof(float) * 4, 0, 0);
// grad_out_vec.acc = fbgemm_gpu::rocm::llvm_amdgcn_raw_buffer_load_fp32x4(emb_res, 0, 0, 0);

// grad_out_vec = *((const float4*)p_grad_out_vec);
{%- endif %} // if nobag

{%- if weighted %}
grad_sum[vec].fma_(grad_out_vec, idx_weight_j);
Expand Down
6 changes: 6 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/rocm/split_embeddings_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ __device__ half llvm_amdgcn_raw_buffer_load_fp16(
__asm("llvm.amdgcn.raw.buffer.load.f16");
#endif

// __device__ float4 llvm_amdgcn_raw_buffer_load_fp32x4(
// int32x4_t srsrc,
// int32_t voffset,
// int32_t soffset,
// int32_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");

__device__ float llvm_amdgcn_raw_buffer_load_fp32(
int32x4_t srsrc,
int32_t voffset,
Expand Down