MPI4PY based replay buffer communication#489
Conversation
`wandb.log(to_log, step=iteration)`
The code references args.spawn_backend but the --spawn_backend argument was removed from the argument parser (previously defined around line 1384-1389). This will cause AttributeError: 'Namespace' object has no attribute 'spawn_backend' when this condition is evaluated.
Fix: Replace with the new argument:

```python
if args.dist_lib == "torch":
    assert averaging_policy_torch is not None
```

Spotted by Graphite
```python
assert averaging_policy_torch is not None
assert averaging_policy_mpi4py is not None
```
These assertions will always fail because only one of the two policies can be initialized based on the backend. Looking at the initialization logic (lines 924-988), either averaging_policy_torch OR averaging_policy_mpi4py is set depending on args.dist_lib, but never both. One will always be None, causing the assertion to fail during cleanup.
Fix: Remove these assertions, or change them to:

```python
assert (averaging_policy_torch is not None) or (averaging_policy_mpi4py is not None)
```
Spotted by Graphite
@younik what is the status of this PR given you just merged something similar?
23f235e to ec3753c
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master     #489      +/-   ##
==========================================
- Coverage   72.48%   72.20%   -0.28%
==========================================
  Files          55       55
  Lines        8519     8574      +55
  Branches     1090     1102      +12
==========================================
+ Hits         6175     6191      +16
- Misses       1957     1995      +38
- Partials      387      388       +1
```
Description
Communication Backend Abstraction
This PR updates the replay buffer communication to support the MPI4PY backend. To enable this, I introduced wrapper functions in `distributed.py` (e.g., `send`, `recv`, etc.). These wrappers abstract the underlying communication implementation, allowing us to switch between backends by selecting either PyTorch distributed or MPI internally.

Serialization Improvements

Additionally, I optimized the message serialization and deserialization logic used during communication. With the MPI4PY backend, we observe 8-12 GB/s bandwidth, compared to ~100 MB/s with Gloo. However, the end-to-end improvement in the communication phase is ~2.4x. This is likely because the current `send`/`recv` operations are still blocking calls and not async.

Changes in `train_hypergrid.py`

To simplify backend selection for replay buffer communication and selective averaging, the configuration has been reduced to two flags:

- `--dist_lib`: Specifies the distributed library to use (`torch` or `mpi`)
- `--torch_backend`: If `torch` is selected as `dist_lib`, this flag selects the PyTorch distributed backend (e.g., `gloo`, `mpi`, etc.)

The `--spawn_backend` flag has been removed. Previously, it toggled between PyTorch Distributed–based model averaging and MPI4PY-based averaging. This behavior is now controlled directly through the `--dist_lib` flag.
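For context, the wrapper layer described under "Communication Backend Abstraction" could be sketched roughly as below. This is a minimal illustration, not the PR's actual `distributed.py` code: the helper `_backend_for` and the lazy-import dispatch inside `send` are assumptions.

```python
# Hypothetical sketch of a backend-dispatching communication wrapper.
# `_backend_for` and the lazy imports are illustrative assumptions; the
# real distributed.py in this PR may be structured differently.

def _backend_for(dist_lib: str) -> str:
    """Map the --dist_lib flag to the library used for send/recv."""
    if dist_lib == "torch":
        return "torch.distributed"
    if dist_lib == "mpi":
        return "mpi4py"
    raise ValueError(f"unknown dist_lib: {dist_lib!r}")


def send(tensor, dst: int, dist_lib: str = "mpi"):
    """Blocking point-to-point send, dispatched to the chosen backend."""
    if dist_lib == "torch":
        import torch.distributed as dist
        dist.send(tensor, dst=dst)  # backend (gloo/mpi/nccl) chosen at init_process_group
    else:
        from mpi4py import MPI
        # Uppercase Send transmits a raw buffer, skipping pickle overhead.
        MPI.COMM_WORLD.Send(tensor.numpy(), dest=dst)
```

A matching `recv` wrapper would dispatch the same way, which is what lets callers stay backend-agnostic.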
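On the serialization side, one common way to approach buffer-level bandwidth (the gap behind the 8-12 GB/s vs ~100 MB/s numbers above) is to pack messages into a fixed binary layout instead of pickling Python objects. A minimal sketch, assuming a hypothetical `(state id, action, reward)` transition layout that is not the PR's actual message format:

```python
import struct

# Hypothetical fixed-layout serialization for one replay-buffer transition.
# The field layout (int64 state id, int32 action, float32 reward) is an
# assumption for illustration only.
TRANSITION_FMT = "=qif"  # native byte order, standard sizes, no padding
TRANSITION_SIZE = struct.calcsize(TRANSITION_FMT)  # 8 + 4 + 4 = 16 bytes


def pack_transition(state_id: int, action: int, reward: float) -> bytes:
    """Serialize one transition into a fixed-size byte buffer."""
    return struct.pack(TRANSITION_FMT, state_id, action, reward)


def unpack_transition(buf: bytes) -> tuple:
    """Inverse of pack_transition."""
    return struct.unpack(TRANSITION_FMT, buf)
```

Fixed-size buffers like this can be handed to mpi4py's uppercase `Send`/`Recv`, which avoid pickle entirely, whereas the lowercase `send`/`recv` pickle arbitrary Python objects.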
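The two-flag configuration could be reconstructed roughly as follows; the available choices and defaults here are assumptions, not the exact code in `train_hypergrid.py`:

```python
import argparse

# Hypothetical reconstruction of the two flags described above; choices
# and defaults are assumptions for illustration.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dist_lib", choices=["torch", "mpi"], default="torch",
    help="Distributed library for replay buffer communication and averaging.",
)
parser.add_argument(
    "--torch_backend", choices=["gloo", "mpi", "nccl"], default="gloo",
    help="PyTorch distributed backend; only used when --dist_lib=torch.",
)

args = parser.parse_args(["--dist_lib", "mpi"])
```

With this layout, code that previously branched on `args.spawn_backend` branches on `args.dist_lib` instead, which is exactly the change the first review comment above flags as missing at one call site.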