-
Notifications
You must be signed in to change notification settings - Fork 9
Optimizing writes in nobag inference kernel #157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: aryaman/upstream
Are you sure you want to change the base?
Changes from all commits
1d15bad
db2c70e
46acccb
01c7d8a
15d6906
b56d302
89d30b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -281,6 +281,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no | |
| } | ||
| } | ||
|
|
||
| {% if not nobag %} | ||
| for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { | ||
| #pragma unroll OutputRowsPerThread | ||
| for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { | ||
|
|
@@ -304,7 +305,6 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no | |
|
|
||
| using scalar_t = {{ emb_weight_type.cpp_type_name }}; | ||
|
|
||
| {% if not nobag %} | ||
| #pragma unroll AccumulateStoreRequests | ||
| for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x]; | ||
|
|
@@ -314,57 +314,135 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no | |
| accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); | ||
| {% endif %} | ||
| } | ||
| {% else %} | ||
| const int32_t output_j = indices_starts[i] + L_start + input_row_idx; | ||
| if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) { | ||
| #pragma unroll AccumulateStoreRequests | ||
| for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { | ||
| // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: | ||
| // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to | ||
| // the scale/shift handling). | ||
| // Reason: to avoid divergence the first thread in the warp computes garbage. | ||
| const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x]; | ||
| if (output_d < D) { | ||
| const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }})); | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); | ||
| acc.store(&output[output_j][output_d], num_valid_outputs); | ||
| } | ||
| } | ||
| } | ||
| {% else %} | ||
| using scalar_t = {{ emb_weight_type.cpp_type_name }}; | ||
| {% if emb_weight_type.primitive_type == "INT" and is_rocm %} | ||
| if (D % (kWarpSize * kOutputsPerThread) == 0 && D_padding > 0) { | ||
| for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { | ||
| #pragma unroll OutputRowsPerThread | ||
| for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { | ||
| bool valid = L_start + input_row_idx < Ls[i]; | ||
| if (!valid) { | ||
| continue; | ||
| } | ||
| } else if constexpr (std::is_same_v<output_t, uint8_t>) { | ||
| // INT8: | ||
| // apply per feature row-wise int8 | ||
| auto thread_local_min = std::numeric_limits<float>::max(); | ||
| auto thread_local_max = std::numeric_limits<float>::lowest(); | ||
| float2 qparams; | ||
| #pragma unroll AccumulateStoreRequests | ||
| for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { | ||
| auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x]; | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); | ||
| if (output_d < D) { | ||
| const uint32_t* row = reinterpret_cast<const uint32_t*>(&buffers[warp_idx][i][input_row_idx][0]); | ||
| half2 shift_scale = reinterpret_cast<const half2*>(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; | ||
| // Size of the per-row scale/bias header in scalar_t units; derived from shift_scale's type. | ||
| constexpr uint32_t kHeaderScalarOffset = sizeof(decltype(shift_scale)) / sizeof(scalar_t); | ||
| const uint32_t opt_iters = D / (kWarpSize * kOutputsPerThread); | ||
| const int32_t output_j = indices_starts[i] + L_start + input_row_idx; | ||
| if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) { | ||
| for (uint32_t j = 0; j < opt_iters; ++j) { | ||
| const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread; | ||
| // +kHeaderScalarOffset skips the half2 scale/bias header at row[0] | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); | ||
| acc.store(&output[output_j][output_d], {{ (32 // emb_weight_type.bit_width) }}); | ||
| } | ||
| } else if constexpr (std::is_same_v<output_t, uint8_t>) { | ||
| auto thread_local_min = std::numeric_limits<float>::max(); | ||
| auto thread_local_max = std::numeric_limits<float>::lowest(); | ||
| float2 qparams; | ||
| // Pass 1: min/max scan | ||
| for (uint32_t j = 0; j < opt_iters; ++j) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Loops with
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I did investigate that. I created a dynamic if-else ladder for different embedding dimensions (256, 512, 768, etc.) and manually unrolled the loops within each case. This negatively impacted the performance, though. |
||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); | ||
| thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); | ||
| thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); | ||
| } | ||
| qparams = warp_find_qparams(thread_local_min, thread_local_max); | ||
| // Pass 2: requantize and store | ||
| for (uint32_t j = 0; j < opt_iters; ++j) { | ||
| const auto output_d = j * kWarpSize * kOutputsPerThread + threadIdx.x * kOutputsPerThread; | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[j * kWarpSize + threadIdx.x + kHeaderScalarOffset]; | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::INT> acc(v, shift_scale); | ||
| acc.store(&output[output_j][output_d], qparams, {{ (32 // emb_weight_type.bit_width) }}); | ||
| } | ||
| if (threadIdx.x == 0) { | ||
| store_qparams_to_row(&output[output_j][D], qparams); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } else | ||
| {% endif %} | ||
| {#- For non-INT weight types or non-ROCm builds the {%- if %} above renders | ||
| nothing; this block then becomes a bare scope holding the unmodified loop. -#} | ||
| { | ||
| for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { | ||
| #pragma unroll OutputRowsPerThread | ||
| for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { | ||
| bool valid = L_start + input_row_idx < Ls[i]; | ||
| if (!valid) { | ||
| continue; | ||
| } | ||
| qparams = warp_find_qparams(thread_local_min, thread_local_max); | ||
| #pragma unroll AccumulateStoreRequests | ||
| for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { | ||
| const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x]; | ||
| if (output_d < D) { | ||
| const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }})); | ||
| const uint32_t* row = reinterpret_cast<const uint32_t*>(&buffers[warp_idx][i][input_row_idx][0]); | ||
| // scale and bias are at the beginning of each row. | ||
| // rationale: have scale/shift at start since these get loaded first | ||
| // and then broadcasted around so it might speed up the first cache miss. | ||
| {% if emb_weight_type.primitive_type == "INT" %} | ||
| // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should | ||
| // read the certain shift_scale related to the row in the packed_bag. | ||
| half2 shift_scale = reinterpret_cast<const half2*>(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; | ||
| {% endif %} | ||
|
|
||
| {% if weighted %} | ||
| float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; | ||
| {% endif %} | ||
|
|
||
| const int32_t output_j = indices_starts[i] + L_start + input_row_idx; | ||
| if constexpr (std::is_same_v<output_t, float> || std::is_same_v<output_t, at::Half> || std::is_same_v<output_t, at::BFloat16>) { | ||
| #pragma unroll AccumulateStoreRequests | ||
| for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { | ||
| // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: | ||
| // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to | ||
| // the scale/shift handling). | ||
| // Reason: to avoid divergence the first thread in the warp computes garbage. | ||
| const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x]; | ||
| if (output_d < D) { | ||
| const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }})); | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); | ||
| acc.store(&output[output_j][output_d], num_valid_outputs); | ||
| } | ||
| } | ||
| } else if constexpr (std::is_same_v<output_t, uint8_t>) { | ||
| // INT8: | ||
| // apply per feature row-wise int8 | ||
| auto thread_local_min = std::numeric_limits<float>::max(); | ||
| auto thread_local_max = std::numeric_limits<float>::lowest(); | ||
| float2 qparams; | ||
| #pragma unroll AccumulateStoreRequests | ||
| for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { | ||
| auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x]; | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); | ||
| acc.store(&output[output_j][output_d], qparams, num_valid_outputs); | ||
| if (output_d < D) { | ||
| thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); | ||
| thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); | ||
| } | ||
| } | ||
| qparams = warp_find_qparams(thread_local_min, thread_local_max); | ||
| #pragma unroll AccumulateStoreRequests | ||
| for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { | ||
| const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; | ||
| scalar_t v = reinterpret_cast<const scalar_t*>(row)[kWarpSize * j + threadIdx.x]; | ||
| if (output_d < D) { | ||
| const int num_valid_outputs = min(static_cast<int>(D - output_d), static_cast<int>({{ (32 // emb_weight_type.bit_width) }})); | ||
| VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); | ||
| acc.store(&output[output_j][output_d], qparams, num_valid_outputs); | ||
| } | ||
| } | ||
| if (threadIdx.x == 0) { | ||
| store_qparams_to_row(&output[output_j][D], qparams); | ||
| } | ||
| } | ||
| if (threadIdx.x == 0) { | ||
| store_qparams_to_row(&output[output_j][D], qparams); | ||
| } | ||
| } | ||
| {% endif %} | ||
| } | ||
| } | ||
| {% endif %} | ||
| } | ||
|
|
||
| {% if not nobag %} | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't `D_padding > 0` always true in the case of quantization to INT?