diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 56a67f1edc8..5b022ba038a 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -25,6 +25,8 @@
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <map>
+#include <mutex>
 #include <cstdio>
 #include <string>
 #include <vector>
@@ -1092,10 +1094,6 @@ struct ggml_cuda_device_info {
     cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

     std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
-
-#ifdef GGML_USE_NCCL
-    ncclComm_t comms[GGML_CUDA_MAX_DEVICES];
-#endif // GGML_USE_NCCL
 };

 const ggml_cuda_device_info & ggml_cuda_info();
@@ -1154,7 +1152,6 @@ struct ggml_cuda_pool_alloc {
     ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
 };

-
 // backend interface

 struct ggml_tensor_extra_gpu {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 8613d20b9f9..a02d7885274 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -338,14 +338,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
         }
     }

-#ifdef GGML_USE_NCCL
-    int dev_ids[GGML_CUDA_MAX_DEVICES];
-    for (int id = 0; id < info.device_count; ++id) {
-        dev_ids[id] = id;
-    }
-    NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
-#endif // GGML_USE_NCCL
-
     return info;
 }

@@ -574,6 +566,20 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }

+#ifdef GGML_USE_NCCL
+static std::map<std::vector<int>, std::vector<ncclComm_t>> ggml_cuda_nccl_comms;
+static std::mutex ggml_cuda_nccl_mutex;
+
+static std::vector<ncclComm_t> ggml_cuda_get_nccl_comms(const std::vector<int> & devs) {
+    std::lock_guard<std::mutex> lock(ggml_cuda_nccl_mutex);
+    if (ggml_cuda_nccl_comms.find(devs) == ggml_cuda_nccl_comms.end()) {
+        ggml_cuda_nccl_comms[devs].resize(devs.size());
+        NCCL_CHECK(ncclCommInitAll(ggml_cuda_nccl_comms[devs].data(), devs.size(), devs.data()));
+    }
+    return ggml_cuda_nccl_comms[devs];
+}
+#endif // GGML_USE_NCCL
+
 // destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
 // this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
 static std::mutex ggml_cuda_lock;
@@ -1139,19 +1145,29 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
         GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
     }

-    const ggml_cuda_device_info info = ggml_cuda_info();
+    const ggml_cuda_device_info & info = ggml_cuda_info();
+    std::vector<int> dev_ids;
+    dev_ids.reserve(n_backends);
+    for (size_t i = 0; i < n_backends; ++i) {
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+        dev_ids.push_back(cuda_ctx->device);
+    }
+    const std::vector<ncclComm_t> comms = ggml_cuda_get_nccl_comms(dev_ids);

     // For small tensors, simply reduce them as FP32.
     // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
-    if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
-        NCCL_CHECK(ncclGroupStart());
-        for (size_t i = 0; i < n_backends; ++i) {
-            ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
-            NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
-        }
-        NCCL_CHECK(ncclGroupEnd());
+    {
+        std::lock_guard<std::mutex> lock(ggml_cuda_nccl_mutex);
+        if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
+            NCCL_CHECK(ncclGroupStart());
+            for (size_t i = 0; i < n_backends; ++i) {
+                ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+                NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comms[i], cuda_ctx->stream()));
+            }
+            NCCL_CHECK(ncclGroupEnd());

-        return true;
+            return true;
+        }
     }

     // For large tensors it's faster to compress them to BF16 for the reduction:
@@ -1169,12 +1185,15 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
         CUDA_CHECK(cudaGetLastError());
     }

-    NCCL_CHECK(ncclGroupStart());
-    for (size_t i = 0; i < n_backends; ++i) {
-        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
-        NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
+    {
+        std::lock_guard<std::mutex> lock(ggml_cuda_nccl_mutex);
+        NCCL_CHECK(ncclGroupStart());
+        for (size_t i = 0; i < n_backends; ++i) {
+            ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
+            NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comms[i], cuda_ctx->stream()));
+        }
+        NCCL_CHECK(ncclGroupEnd());
     }
-    NCCL_CHECK(ncclGroupEnd());

     for (size_t i = 0; i < n_backends; ++i) {
         ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
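For context, the core of the change is a lazily-initialized, mutex-guarded communicator cache: ggml_cuda_get_nccl_comms keys the cache by the exact device list, so the first allreduce over a given set of GPUs pays for ncclCommInitAll once and every later call reuses the same NCCL clique, while a different device set gets its own communicators. The same mutex is also held around each ncclGroupStart()/ncclGroupEnd() pair so that group calls from different backend threads cannot interleave. Below is a minimal standalone sketch of that cache pattern, illustrative only: the names sketch.cpp, comm_cache, comm_mutex and get_comms are placeholders, error handling is reduced to fprintf/exit, and it assumes NCCL plus at least two visible CUDA devices.

    // sketch.cpp: standalone illustration of the device-list-keyed communicator cache.
    // Assumed build line: nvcc sketch.cpp -lnccl -o sketch
    #include <cstdio>
    #include <cstdlib>
    #include <map>
    #include <mutex>
    #include <vector>
    #include <nccl.h>

    // Cache keyed by the exact list of device ids; two different device sets
    // (e.g. {0,1} and {0,2}) get independent NCCL cliques.
    static std::map<std::vector<int>, std::vector<ncclComm_t>> comm_cache;
    static std::mutex comm_mutex;

    // Returns one communicator per device, creating the group on first use.
    static std::vector<ncclComm_t> get_comms(const std::vector<int> & devs) {
        std::lock_guard<std::mutex> lock(comm_mutex); // one thread initializes, others wait
        auto it = comm_cache.find(devs);
        if (it == comm_cache.end()) {
            std::vector<ncclComm_t> comms(devs.size());
            // Single-process initialization of one communicator per listed device:
            if (ncclCommInitAll(comms.data(), (int) devs.size(), devs.data()) != ncclSuccess) {
                fprintf(stderr, "ncclCommInitAll failed\n");
                exit(1);
            }
            it = comm_cache.emplace(devs, std::move(comms)).first;
        }
        return it->second; // cache hit: same communicators as the first call
    }

    int main() {
        const std::vector<int> devs = {0, 1};         // assumes >= 2 visible CUDA devices
        std::vector<ncclComm_t> a = get_comms(devs);  // initializes the {0,1} clique
        std::vector<ncclComm_t> b = get_comms(devs);  // reuses it, no re-initialization
        printf("communicators reused: %s\n", a[0] == b[0] ? "yes" : "no");
        return 0;
    }

Compared with the removed code, which eagerly created one global communicator per device at ggml_cuda_init() time, this keys communicators to the device group actually used, which matters when allreduces run over a subset of the visible GPUs.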