Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- if weighted %}
auto p_indice_weights_sorted = sorted_indice_weights.data();
{%- endif %}
auto emb_dim = embedding_dim;
auto emb_dim = max_D;
constexpr int32_t segment_prefetch = 2;
constexpr int32_t segment_unroll = 8;
constexpr int32_t segment_split = 0;
Expand Down Expand Up @@ -765,7 +765,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
{%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
{%- for cache_type in ['float', 'at::Half'] %}
{%- for index_type in ['int32_t', 'int64_t'] %}
{%- for kEmbeddingDim in [64, 128, 160, 192, 256, 320] %}
{%- for kEmbeddingDim in range(64, 2049, 64) %}
{%- for kWeighDecayMode in [0, 1, 2] %}
{{ hip_template_instantiation(
emb_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1257,9 +1257,9 @@ Tensor {{ embedding_cuda_op }}(
if (use_hip_kernel && !mixed_D && !cached && supported_weights_type && supported_grad_type && same_precision && rocm::is_supported_cdna())
{
constexpr int segments_per_workgroup = 4;
{%- for kDimSize in [64, 128, 160, 192, 256, 320] %}
{%- for kBucketDim in range(64, 2049, 64) %}
{%- for kWeightDecayMode in [0, 1, 2] %}
if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }})
if (max_D <= {{ kBucketDim }} && max_D > {{ kBucketDim - 64 }} && weight_decay_mode == {{ kWeightDecayMode }})
{
warp_per_row_grid_size = div_round_up(sorted_linear_indices_num_runs[0].item<int32_t>(), segments_per_workgroup);
blockSize = dim3(256);
Expand All @@ -1274,7 +1274,7 @@ Tensor {{ embedding_cuda_op }}(
kFixedMaxVecsPerThread,
kThreadGroupSize,
kUseVecBlocking,
{{ kDimSize }},
{{ kBucketDim }},
{{ kWeightDecayMode }}>;
}
{%- endfor %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct rowwise_adagrad_optimizer_t
}

template <int32_t thread_length, int32_t segment_split>
__device__ void update(cache_t* acc, emb_t* weight, index_t row_index)
__device__ void update(cache_t* acc, emb_t* weight, index_t row_index, int32_t emb_dim)
{
if constexpr(segment_split == 0)
{
Expand Down Expand Up @@ -74,7 +74,7 @@ struct rowwise_adagrad_optimizer_t

cache_t avg_square =
wave_reduce<reduce_op_sum_t<cache_t>, cache_t, AMDGCN_WAVE_SIZE>(local_sum_squre) /
embedding_dim;
emb_dim;

cache_t momentum_new = momentum + avg_square;

Expand Down Expand Up @@ -139,6 +139,8 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
optimizer_karg_t opt_karg,
const float * p_sorted_indice_weights = nullptr)
{
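// dword_per_row stays a compile-time constant derived from the template parameter
// [embedding_dim] instead of the runtime parameter [emb_dim]: the ceiling division by
// THREADS_PER_ROW is expected to give the same result for both values.
// NOTE(review): this presumably relies on emb_dim falling inside the last
// THREADS_PER_ROW-wide bucket that ends at embedding_dim — confirm against the host-side dispatch.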
constexpr uint32_t dword_per_row = (embedding_dim + THREADS_PER_ROW - 1) / THREADS_PER_ROW;
constexpr uint32_t waves_per_block = block_size / AMDGCN_WAVE_SIZE;
constexpr uint32_t length_mask = ~(segment_unroll - 1);
Expand Down Expand Up @@ -220,13 +222,13 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[0] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);

table_index = infos[1] >> info_B_num_bits;
bag_index = infos[1] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);
if constexpr (!weighted){
#pragma unroll
for(int j = 2; j < segment_unroll; j += 2)
Expand All @@ -238,7 +240,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[dword_per_row], lane_id);
Expand All @@ -247,7 +249,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j + 1] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);
}

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
Expand All @@ -274,7 +276,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]);
Expand All @@ -283,7 +285,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j + 1] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);
}

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
Expand All @@ -307,13 +309,13 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[0] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);

table_index = infos[1] >> info_B_num_bits;
bag_index = infos[1] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);

if constexpr (!weighted) {
#pragma unroll
Expand All @@ -326,7 +328,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[dword_per_row], lane_id);
Expand All @@ -335,7 +337,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j + 1] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);
}

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
Expand All @@ -353,7 +355,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[dword_per_row], lane_id, indice_weights[j-1]);
Expand All @@ -362,7 +364,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
bag_index = infos[j + 1] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[dword_per_row], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);
}

accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
Expand All @@ -385,7 +387,7 @@ L_tail_grad_acc:
bag_index = infos[0] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);
accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[0], lane_id);

Expand All @@ -403,7 +405,7 @@ L_tail_grad_acc:
bag_index = infos[0] & info_B_mask;

load_row_per_warp<grad_t, embedding_dim, index_t>::run(
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * embedding_dim, lane_id);
&grad_data[0], bag_index * num_tables, p_output_grad + table_index * emb_dim, lane_id, emb_dim);
accumulate_row_per_warp<grad_t, embedding_dim, cache_t, weighted>::run(
&grad_acc[0], &grad_data[0], lane_id, indice_weights[0]);

Expand All @@ -414,10 +416,10 @@ L_tail_grad_acc:

// load the old emb weight data
load_row_per_warp<emb_t, embedding_dim, index_t>::run(
&emb_data[0], emb_idx, p_emb_table, lane_id);
&emb_data[0], emb_idx, p_emb_table, lane_id, emb_dim);
optimizer_t optimizer(opt_karg);
optimizer.template update<dword_per_row, segment_split>(grad_acc, emb_data, emb_idx);
optimizer.template update<dword_per_row, segment_split>(grad_acc, emb_data, emb_idx, emb_dim);

store_row_per_warp<emb_t, embedding_dim>::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id);
store_row_per_warp<emb_t, embedding_dim, index_t>::run(&emb_data[0], emb_idx, p_emb_table, lane_id, emb_dim);
}
} // namespace fbgemm_gpu::rocm
Loading