MPI4PY based replay buffer communication #489

Draft
chirayuharyan wants to merge 6 commits into master from mpi4py-replay-buffer

@chirayuharyan (Collaborator) commented Mar 13, 2026

  • I've read the .github/CONTRIBUTING.md file
  • My code follows the typing guidelines
  • I've added appropriate tests
  • I've run pre-commit hooks locally

Description

Communication Backend Abstraction

This PR updates the replay buffer communication to support the MPI4PY backend.

To enable this, I introduced wrapper functions in distributed.py (e.g., send, recv, etc.). These wrappers abstract the underlying communication implementation, allowing us to switch between backends by selecting either PyTorch distributed or MPI internally.
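A minimal sketch of what such a wrapper layer could look like is shown below. It is not the code in distributed.py: the class names, `init_backend`, and the in-process `LoopbackBackend` are hypothetical and exist only to illustrate the dispatch pattern. The MPI path assumes mpi4py's pickle-based `send`/`recv`, and the torch path assumes `torch.distributed.send_object_list`/`recv_object_list` with an already initialized process group.

```python
from collections import deque
from typing import Any, Deque, Dict, Optional, Protocol


class CommBackend(Protocol):
    """Interface every backend in this sketch implements (hypothetical)."""

    def send(self, obj: Any, dst: int) -> None: ...
    def recv(self, src: int) -> Any: ...


class MPIBackend:
    """Pickle-based point-to-point transfer via mpi4py (imported lazily)."""

    def __init__(self) -> None:
        from mpi4py import MPI  # only required when this backend is selected

        self._comm = MPI.COMM_WORLD

    def send(self, obj: Any, dst: int) -> None:
        self._comm.send(obj, dest=dst)

    def recv(self, src: int) -> Any:
        return self._comm.recv(source=src)


class TorchBackend:
    """Object transfer via torch.distributed; assumes the process group
    has already been initialized elsewhere."""

    def __init__(self) -> None:
        import torch.distributed as dist  # lazy import, as above

        self._dist = dist

    def send(self, obj: Any, dst: int) -> None:
        self._dist.send_object_list([obj], dst=dst)

    def recv(self, src: int) -> Any:
        out = [None]
        self._dist.recv_object_list(out, src=src)  # fills `out` in place
        return out[0]


class LoopbackBackend:
    """Single-process stand-in: a message sent to rank r is read back with
    src=r. Used here only to exercise the dispatch logic."""

    def __init__(self) -> None:
        self._queues: Dict[int, Deque[Any]] = {}

    def send(self, obj: Any, dst: int) -> None:
        self._queues.setdefault(dst, deque()).append(obj)

    def recv(self, src: int) -> Any:
        return self._queues[src].popleft()


_BACKENDS = {"mpi": MPIBackend, "torch": TorchBackend, "loopback": LoopbackBackend}
_active: Optional[CommBackend] = None


def init_backend(name: str) -> None:
    """Select the backend once; callers then use send/recv unchanged."""
    global _active
    if name not in _BACKENDS:
        raise ValueError(f"unknown backend: {name!r}")
    _active = _BACKENDS[name]()


def send(obj: Any, dst: int) -> None:
    if _active is None:
        raise RuntimeError("call init_backend() first")
    _active.send(obj, dst)


def recv(src: int) -> Any:
    if _active is None:
        raise RuntimeError("call init_backend() first")
    return _active.recv(src)
```

With this shape, call sites in the replay buffer only ever see `send` and `recv`; swapping the library is a one-line change in `init_backend`.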

Serialization Improvements

Additionally, I optimized the message serialization and deserialization logic used during communication.

With the MPI4PY backend, we observe 8-12 GB/s of bandwidth, compared to ~100 MB/s with Gloo. However, the end-to-end speedup of the communication phase is only ~2.4x, likely because the current send/recv operations are still blocking calls rather than asynchronous ones.
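As a rough sanity check on these figures, the sketch below applies an Amdahl-style model. It assumes the communication phase splits into a raw-transfer part that scales with bandwidth and a fixed part (for example serialization or blocking waits) that does not. The function name and the 100x bandwidth ratio (10 GB/s versus 100 MB/s, taken from the middle of the reported range) are illustrative assumptions, not measurements from this PR.

```python
def transfer_fraction(overall_speedup: float, bandwidth_ratio: float) -> float:
    """Fraction f of the original phase time spent on raw transfer, found by
    solving overall_speedup = 1 / ((1 - f) + f / bandwidth_ratio) for f."""
    return (1.0 - 1.0 / overall_speedup) / (1.0 - 1.0 / bandwidth_ratio)


# Assumed inputs: ~2.4x phase speedup, ~100x raw bandwidth gain.
f = transfer_fraction(overall_speedup=2.4, bandwidth_ratio=100.0)
print(f"transfer share of original phase time: {f:.2f}")
```

Under these assumptions roughly 60% of the original phase time was raw transfer, so the remaining ~40% would bound further gains until the non-transfer cost is reduced or overlapped.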

Changes in train_hypergrid.py

To simplify backend selection for replay buffer communication and selective averaging, the configuration has been reduced to two flags:

--dist_lib : Specifies the distributed library to use (torch or mpi)

--torch_backend : If torch is selected as dist_lib, this flag selects the PyTorch distributed backend (e.g., gloo, mpi, etc.)

The --spawn_backend flag has been removed. Previously, it toggled between PyTorch Distributed–based model averaging and MPI4PY-based averaging. This behavior is now controlled directly through the --dist_lib flag.
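The two-flag layout could be expressed with argparse roughly as follows. This is a sketch rather than the parser in train_hypergrid.py: the default values, help strings, and the `describe_backend` helper are assumptions added for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Backend selection sketch")
    parser.add_argument(
        "--dist_lib",
        choices=["torch", "mpi"],
        default="torch",  # assumed default, not taken from the PR
        help="Distributed library to use.",
    )
    parser.add_argument(
        "--torch_backend",
        default="gloo",  # assumed default, not taken from the PR
        help="PyTorch distributed backend; only used when --dist_lib=torch.",
    )
    return parser


def describe_backend(args: argparse.Namespace) -> str:
    """Hypothetical helper: one branch on dist_lib replaces --spawn_backend."""
    if args.dist_lib == "torch":
        return f"torch.distributed ({args.torch_backend})"
    return "mpi4py"
```

For example, `describe_backend(build_parser().parse_args(["--dist_lib", "mpi"]))` returns `"mpi4py"`, and any `--torch_backend` value is ignored on that path.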

@chirayuharyan chirayuharyan changed the title mpi4py final comms MPI4PY based replay buffer communication Mar 13, 2026
Comment on lines 1229 to 1230
wandb.log(to_log, step=iteration)

The code references args.spawn_backend, but the --spawn_backend argument was removed from the argument parser (previously defined around lines 1384-1389). This will raise AttributeError: 'Namespace' object has no attribute 'spawn_backend' when this condition is evaluated.

# Fix: Replace with the new argument:
if args.dist_lib == "torch":
    assert averaging_policy_torch is not None

Spotted by Graphite

Collaborator Author

Fixed

Comment on lines +1250 to +1251
assert averaging_policy_torch is not None
assert averaging_policy_mpi4py is not None

At least one of these assertions will always fail, because only one of the two policies is initialized for a given backend. Looking at the initialization logic (lines 924-988), either averaging_policy_torch or averaging_policy_mpi4py is set depending on args.dist_lib, but never both. The other is always None, so its assertion fails during cleanup.

Fix: Remove these assertions or change to:

assert (averaging_policy_torch is not None) or (averaging_policy_mpi4py is not None)
Suggested change
- assert averaging_policy_torch is not None
- assert averaging_policy_mpi4py is not None
+ assert (averaging_policy_torch is not None) or (averaging_policy_mpi4py is not None)

Spotted by Graphite

Collaborator Author

Fixed

@chirayuharyan chirayuharyan marked this pull request as draft March 16, 2026 03:49
@josephdviviano
Collaborator

@younik what is the status of this PR given you just merged something similar?

codecov bot commented Mar 30, 2026

Codecov Report

❌ Patch coverage is 23.77049% with 93 lines in your changes missing coverage. Please review.
✅ Project coverage is 72.20%. Comparing base (f0605a8) to head (3cbe811).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/gfn/utils/distributed.py | 15.85% | 69 Missing ⚠️ |
| src/gfn/containers/replay_buffer.py | 43.47% | 12 Missing and 1 partial ⚠️ |
| src/gfn/containers/replay_buffer_manager.py | 27.27% | 8 Missing ⚠️ |
| src/gfn/containers/message.py | 50.00% | 3 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #489      +/-   ##
==========================================
- Coverage   72.48%   72.20%   -0.28%     
==========================================
  Files          55       55              
  Lines        8519     8574      +55     
  Branches     1090     1102      +12     
==========================================
+ Hits         6175     6191      +16     
- Misses       1957     1995      +38     
- Partials      387      388       +1     

| Files with missing lines | Coverage Δ |
|---|---|
| src/gfn/containers/message.py | 68.18% <50.00%> (-1.82%) ⬇️ |
| src/gfn/containers/replay_buffer_manager.py | 26.15% <27.27%> (+5.16%) ⬆️ |
| src/gfn/containers/replay_buffer.py | 72.13% <43.47%> (+0.86%) ⬆️ |
| src/gfn/utils/distributed.py | 15.63% <15.85%> (-0.86%) ⬇️ |