Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions csrc/include/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ DINLINE O downcast(V val)
}

// This function is meant to be used as the first synchronization in the all
// reduce kernel. Thus, it doesn't need to make any visibility guarantees for
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
// reduce kernel. When using registered CUDA graph inputs, prior memory accesses
// from producer kernels must be system-visible before peers read the input
// buffers. Issue __threadfence_system() to ensure global visibility.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg,
#ifndef USE_ROCM
Expand All @@ -156,6 +156,11 @@ DINLINE void start_sync(const RankSignals& sg,
int rank)
{
#ifdef USE_ROCM
// Ensure prior memory writes (e.g. from producer kernels in CUDA graph)
// are visible to peer GPUs before signaling readiness.
if(threadIdx.x == 0)
__threadfence_system();
__syncthreads();
uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
if(threadIdx.x < ngpus)
{
Expand Down Expand Up @@ -907,7 +912,7 @@ struct AbsMaxFunctor
template<typename T>
DINLINE T shfl_xor(T var, int mask, int width = opus::get_warp_size())
{
static_assert(sizeof(T) == 4);
static_assert(sizeof(T) == 4);
int self = opus::lane_id();
int index = (self & ~(width - 1)) + ((self ^ mask) & (width - 1));
return __builtin_bit_cast(T, __builtin_amdgcn_ds_bpermute(index << 2, __builtin_bit_cast(int, var)));
Expand Down Expand Up @@ -3539,7 +3544,7 @@ void dispatchFusedQKNormAllReduce(hipStream_t stream,
std::to_string(d));
}
RankData* ptrs = get_buffer_RD(stream, qkv_in);

#define DISPATCH_QKNORM_AR_FUSION_KERNEL(NGPUS) \
{ \
qknorm_allreduce_fusion_kernel_2stage_launcher<T, NGPUS>(ptrs, \
Expand Down
Loading