Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <map>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
Expand Down Expand Up @@ -1092,10 +1094,6 @@ struct ggml_cuda_device_info {
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};

#ifdef GGML_USE_NCCL
ncclComm_t comms[GGML_CUDA_MAX_DEVICES];
#endif // GGML_USE_NCCL
};

const ggml_cuda_device_info & ggml_cuda_info();
Expand Down Expand Up @@ -1154,7 +1152,6 @@ struct ggml_cuda_pool_alloc {
ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
};


// backend interface

struct ggml_tensor_extra_gpu {
Expand Down
63 changes: 41 additions & 22 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -338,14 +338,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
}
}

#ifdef GGML_USE_NCCL
int dev_ids[GGML_CUDA_MAX_DEVICES];
for (int id = 0; id < info.device_count; ++id) {
dev_ids[id] = id;
}
NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids));
#endif // GGML_USE_NCCL

return info;
}

Expand Down Expand Up @@ -574,6 +566,20 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
}

#ifdef GGML_USE_NCCL
// Cache of NCCL communicators, keyed by the list of participating device ids.
// Communicators are created lazily on first use and kept for the process lifetime.
static std::map<std::vector<int>, std::vector<ncclComm_t>> ggml_cuda_nccl_comms;
static std::mutex ggml_cuda_nccl_mutex; // guards the map above

// Returns (creating on first use) the NCCL communicators for the given set of devices.
// Thread-safe; the returned vector is a copy so callers can use it without holding the lock.
static std::vector<ncclComm_t> ggml_cuda_get_nccl_comms(const std::vector<int> & devs) {
    std::lock_guard lock(ggml_cuda_nccl_mutex);
    // Single map traversal instead of find + repeated operator[] lookups:
    auto it = ggml_cuda_nccl_comms.find(devs);
    if (it == ggml_cuda_nccl_comms.end()) {
        std::vector<ncclComm_t> comms(devs.size());
        // ncclCommInitAll takes an int device count; make the narrowing explicit.
        NCCL_CHECK(ncclCommInitAll(comms.data(), (int) devs.size(), devs.data()));
        it = ggml_cuda_nccl_comms.emplace(devs, std::move(comms)).first;
    }
    return it->second;
}
#endif // GGML_USE_NCCL

// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured

Expand Down Expand Up @@ -1139,19 +1145,29 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i]));
}

const ggml_cuda_device_info info = ggml_cuda_info();
const ggml_cuda_device_info & info = ggml_cuda_info();
std::vector<int> dev_ids;
dev_ids.reserve(n_backends);
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
dev_ids.push_back(cuda_ctx->device);
}
const std::vector<ncclComm_t> comms = ggml_cuda_get_nccl_comms(dev_ids);
Comment on lines +1148 to +1155
Copy link
Copy Markdown
Contributor

@gaugarg-nv gaugarg-nv Apr 13, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer if we move these ops during setup and not in the critical path. Maybe during meta-backend initialization when all simple backends are known?

Currently, this seems expensive for a critical path as we are:

  • Allocating and reserving vector (heap allocation)
  • Doing map lookup with a vector<int> as key
  • Returning vector<ncclComm_t> by value (heap allocation and copies)
  • Doing Mutex lock/unlock

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean but I'm not sure there is a good way to do this (and also shouldn't it be fine as long as the CPU is sufficiently far ahead of the GPUs?). Longer-term I've been thinking it may make sense to extend the API to allow for evaluating multiple ggml graphs on multiple CUDA backends in tandem. This would then allow us to capture the NCCL operations in a CUDA graph anyways so this is less of a problem.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check if this has any perf overhead, especially for small or MOE models?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance
GPU Model Microbatch size Test t/s b8750 t/s f7dc561 Speedup
2x RTX 4090 granitemoe 3B Q4_0 1 pp512 288.59 289.66 1.00
2x RTX 4090 granitemoe 3B Q4_0 2 pp512 511.75 510.89 1.00
2x RTX 4090 granitemoe 3B Q4_0 4 pp512 848.57 851.49 1.00
2x RTX 4090 granitemoe 3B Q4_0 8 pp512 1339.81 1336.25 1.00
2x RTX 4090 granitemoe 3B Q4_0 16 pp512 1550.93 1548.15 1.00
2x RTX 4090 granitemoe 3B Q4_0 32 pp512 2403.81 2385.87 0.99
2x RTX 4090 granitemoe 3B Q4_0 64 pp512 3547.17 3547.75 1.00
2x RTX 4090 granitemoe 3B Q4_0 128 pp512 4926.61 4903.73 1.00
2x RTX 4090 granitemoe 3B Q4_0 256 pp512 7253.22 7234.47 1.00
2x RTX 4090 granitemoe 3B Q4_0 512 pp512 16310.43 16336.26 1.00
2x RTX 4090 llama 1B Q4_0 1 pp512 710.85 711.42 1.00
2x RTX 4090 llama 1B Q4_0 2 pp512 1264.17 1251.33 0.99
2x RTX 4090 llama 1B Q4_0 4 pp512 2168.67 2188.49 1.01
2x RTX 4090 llama 1B Q4_0 8 pp512 3178.45 3111.29 0.98
2x RTX 4090 llama 1B Q4_0 16 pp512 4325.19 4299.78 0.99
2x RTX 4090 llama 1B Q4_0 32 pp512 6723.29 6734.26 1.00
2x RTX 4090 llama 1B Q4_0 64 pp512 9630.32 9687.58 1.01
2x RTX 4090 llama 1B Q4_0 128 pp512 12848.67 12909.88 1.00
2x RTX 4090 llama 1B Q4_0 256 pp512 18234.67 18545.47 1.02
2x RTX 4090 llama 1B Q4_0 512 pp512 30060.60 30200.48 1.00

The above numbers are based on the averages of 40 benchmark runs each. I'm having difficulty measuring any consistent effect at all and the spread between individual runs is much larger than the difference in the averages of the two commits. So I think it's safe to say that the performance impact is negligible.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the data. This looks good.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer if we move these ops during setup and not in the critical path. Maybe during meta-backend initialization when all simple backends are known?

I see what you mean but I'm not sure there is a good way to do this

Currently, the meta backend queries the backend registry on every graph compute for an implementation of "ggml_backend_allreduce_tensor":

if (n_backends > 1 && i < n_subgraphs - 1) {
bool backend_allreduce_success = false;
ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address(
ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor");
if (allreduce_tensor) {
std::vector<ggml_backend_t> backends;
backends.reserve(n_backends);
std::vector<ggml_tensor *> nodes;
nodes.reserve(n_backends);
for (size_t j = 0; j < n_backends; j++) {
auto & bcj = backend_ctx->backend_configs[j];
backends.push_back(bcj.backend);
ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main;
nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]);
}
backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends);
}
if (!backend_allreduce_success) {
const ggml_status status = allreduce_fallback(i);
if (status != GGML_STATUS_SUCCESS) {
return status;
}
}
}

This makes it difficult to maintain the "allreduce state" (i.e. the NCCL communicators).

We can fix this by replacing the `"ggml_backend_allreduce_tensor"` proc with:

  • "ggml_backend_allreduce_init"
  • "ggml_backend_allreduce_apply"
  • "ggml_backend_allreduce_free"

The ggml_backend_cuda_allreduce_init() and ggml_backend_cuda_allreduce_free() will create and destroy the NCCL communicators. This way the lifetime of the communicators will be managed by the meta backend.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would also be fine with me.


// For small tensors, simply reduce them as FP32.
// The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0.
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
{
std::lock_guard lock(ggml_cuda_nccl_mutex);
if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) {
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comms[i], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());

return true;
return true;
}
}

// For large tensors it's faster to compress them to BF16 for the reduction:
Expand All @@ -1169,12 +1185,15 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t
CUDA_CHECK(cudaGetLastError());
}

NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream()));
{
std::lock_guard lock(ggml_cuda_nccl_mutex);
NCCL_CHECK(ncclGroupStart());
for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comms[i], cuda_ctx->stream()));
}
NCCL_CHECK(ncclGroupEnd());
}
NCCL_CHECK(ncclGroupEnd());

for (size_t i = 0; i < n_backends; ++i) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context;
Expand Down
Loading