Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/gfn/containers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ def __init__(self, message_type: MessageType, message_data: Any = None):
self.message_type = message_type
self.message_data = message_data

def serialize(self) -> torch.ByteTensor:
def serialize(self) -> torch.Tensor:
"""Convert message into a tensor of bytes."""
obj_bytes = pickle.dumps(self)
return torch.frombuffer(bytearray(obj_bytes), dtype=torch.uint8) # type: ignore[return-value]

@staticmethod
def deserialize(byte_tensor: torch.ByteTensor) -> Message:
def deserialize(byte_tensor: torch.Tensor) -> Message:
"""Reconstruct Message from a tensor of bytes."""
obj_bytes = bytes(byte_tensor.numpy())
return pickle.loads(obj_bytes)
98 changes: 70 additions & 28 deletions src/gfn/containers/replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from typing import Protocol, Union, cast, runtime_checkable

import torch
import torch.distributed as dist

from gfn.containers.message import Message, MessageType
from gfn.containers.states_container import StatesContainer
from gfn.containers.trajectories import Trajectories
from gfn.containers.transitions import Transitions
from gfn.env import Env
from gfn.utils.common import Timer
from gfn.utils.distributed import recv, send


@runtime_checkable
Expand All @@ -31,18 +32,36 @@ def terminating_states(self): ... # noqa: E704


class ReplayBuffer:
"""A replay buffer for storing containers.
"""A replay buffer for storing training containers.

Supports local-only operation and distributed remote buffer communication.

Features:
- **Local buffering**: Stores Trajectories, Transitions, or
StatesContainers up to a fixed capacity.
- **Prioritized capacity**: Optionally keeps only the highest-reward
items when the buffer is full.
- **Prioritized sampling**: Optionally samples with probability
proportional to reward (softmax over log-rewards).
- **Remote buffer communication**: When ``remote_manager_rank`` is set,
periodically sends batched containers to a remote
``ReplayBufferManager`` and receives score dictionaries back.
- **Communication backends**: The ``communication_backend`` parameter
selects between ``"torch"`` (PyTorch distributed / Gloo) and
``"mpi"`` (MPI4PY, ~8-12 GB/s vs ~100 MB/s with Gloo).
- **Async scoring**: When ``async_score`` is enabled, trajectory sends
are fire-and-forget; scores are collected lazily on the next
``add()`` call (1-iteration stale), decoupling training throughput
from buffer scoring latency.
- **Timing instrumentation**: When ``timing`` is enabled, serialization,
send, and receive durations are recorded for profiling.

Attributes:
env: The environment associated with the containers.
capacity: The maximum number of items the buffer can hold.
training_container: The buffer contents (Trajectories, Transitions,
or StatesContainer). This is dynamically set based on the type of the
or StatesContainer). Dynamically set based on the type of the
first added object.
prioritized_capacity: Whether to use prioritized capacity
(keep highest-reward items).
prioritized_sampling: Whether to sample items with probability proportional
to their reward.
"""

def __init__(
Expand All @@ -53,6 +72,8 @@ def __init__(
prioritized_sampling: bool = False,
remote_manager_rank: int | None = None,
remote_buffer_freq: int = 1,
communication_backend: str = "mpi",
timing: bool = False,
async_score: bool = False,
):
"""Initializes a ReplayBuffer instance.
Expand All @@ -67,6 +88,11 @@ def __init__(
None if no remote manager is assigned.
remote_buffer_freq: Frequency (in number of add() calls) at which to contact
the remote buffer manager.
communication_backend: Communication backend for remote buffer operations.
``"mpi"`` uses MPI4PY (higher bandwidth), ``"torch"`` uses PyTorch
distributed (Gloo/MPI).
timing: If True, record durations for serialize/send/recv operations
in ``timing_data`` for profiling.
async_score: If True, trajectory sends are fire-and-forget; scores
are collected lazily on the next add() call (1-iteration stale).
Decouples training throughput from buffer scoring latency.
Expand All @@ -78,20 +104,16 @@ def __init__(
self.prioritized_capacity = prioritized_capacity
self.prioritized_sampling = prioritized_sampling
self.pending_container: ContainerUnion | None = None
self.communication_backend = communication_backend
self.timing = timing
self.timing_data: dict[str, list[float]] = {}

# Remote buffer fields
self.remote_manager_rank = remote_manager_rank
self.remote_buffer_freq = remote_buffer_freq
self._add_counter = 0
self.async_score = async_score
self._pending_score: bool = False # True when a score recv is outstanding.
if self.remote_manager_rank is not None:
backend = dist.get_backend()
if backend != "gloo":
raise RuntimeError(
f"Replay Buffer Manager is only supported with the 'gloo' backend, "
f"but the current backend is '{backend}'."
)

@property
def device(self) -> torch.device:
Expand Down Expand Up @@ -141,11 +163,13 @@ def add(self, training_container: ContainerUnion) -> dict[str, float] | None:
# Collect stale score from previous send (if any), then
# fire-and-forget the new batch.
stale_score = self._collect_pending_score()
self._send_objs_async(self.pending_container)
with Timer(self.timing_data, "send_objs", enabled=self.timing):
self._send_objs_async(self.pending_container)
self.pending_container = None
return stale_score
else:
score = self._send_objs(self.pending_container)
with Timer(self.timing_data, "send_objs", enabled=self.timing):
score = self._send_objs(self.pending_container)
self.pending_container = None
return score

Expand Down Expand Up @@ -215,19 +239,24 @@ def _recv_worker():
def _send_data(self, training_container: ContainerUnion) -> None:
"""Send a training container to the remote manager."""
msg = Message(MessageType.DATA, training_container)
msg_tensor = msg.serialize()
length_tensor = torch.IntTensor([len(msg_tensor)])
dist.send(length_tensor, dst=self.remote_manager_rank)
dist.send(msg_tensor, dst=self.remote_manager_rank)
with Timer(self.timing_data, "serialize_objs", enabled=self.timing):
msg_tensor = msg.serialize()
with Timer(self.timing_data, "send_data", enabled=self.timing):
send(
msg_tensor,
dst_rank=self.remote_manager_rank,
backend=self.communication_backend,
)

def _recv_score(self) -> dict[str, float]:
"""Receive a score dictionary from the remote manager."""
length_tensor = torch.zeros(1, dtype=torch.int32)
dist.recv(length_tensor, src=self.remote_manager_rank)
length = length_tensor.item()
score_tensor = torch.ByteTensor(length)
dist.recv(score_tensor, src=self.remote_manager_rank)
return Message.deserialize(score_tensor).message_data
with Timer(self.timing_data, "recv_score", enabled=self.timing):
_src_rank, score_tensor = recv(
src_rank=self.remote_manager_rank,
backend=self.communication_backend,
)
with Timer(self.timing_data, "deserialize_score", enabled=self.timing):
return Message.deserialize(score_tensor).message_data

def __repr__(self) -> str:
"""Returns a string representation of the ReplayBuffer.
Expand Down Expand Up @@ -362,6 +391,13 @@ def load(self, path: str):
if self.training_container is not None:
self.training_container = type(self.training_container).load(self.env, path)

def timing_log(self) -> str:
    """Returns a formatted string of the timing information for the replay buffer.

    Each entry in ``timing_data`` is reported as the total (summed) duration
    in seconds, one per line, under a fixed header.
    """
    lines = ["Replay Buffer Timing Information:"]
    for name, durations in self.timing_data.items():
        lines.append(f"{name}: {sum(durations):.4f} s")
    # Join with newlines and keep the trailing newline after the last entry.
    return "\n".join(lines) + "\n"


class NormBasedDiversePrioritizedReplayBuffer(ReplayBuffer):
"""A replay buffer with diversity-based prioritization.
Expand Down Expand Up @@ -389,6 +425,8 @@ def __init__(
p_norm_distance: float = 1.0,
remote_manager_rank: int | None = None,
remote_buffer_freq: int = 1,
communication_backend: str = "mpi",
timing: bool = False,
async_score: bool = False,
):
"""Initializes a NormBasedDiversePrioritizedReplayBuffer instance.
Expand All @@ -403,6 +441,8 @@ def __init__(
None if no remote manager is assigned.
remote_buffer_freq: Frequency (in number of add() calls) at which to contact
the remote buffer manager.
communication_backend: Communication backend (``"mpi"`` or ``"torch"``).
timing: If True, record operation durations for profiling.
async_score: If True, trajectory sends are fire-and-forget; scores
are collected lazily on the next add() call.
"""
Expand All @@ -412,6 +452,8 @@ def __init__(
prioritized_capacity=True,
remote_manager_rank=remote_manager_rank,
remote_buffer_freq=remote_buffer_freq,
communication_backend=communication_backend,
timing=timing,
async_score=async_score,
)
self.cutoff_distance = cutoff_distance
Expand Down Expand Up @@ -523,8 +565,8 @@ class TerminatingStateBuffer(ReplayBuffer):
training_container: The buffer contents (StatesContainer).
"""

def __init__(self, env: Env, capacity: int = 1000, **kwargs):
super().__init__(env, capacity, **kwargs)
def __init__(self, env: Env, capacity: int = 1000, timing: bool = False, **kwargs):
super().__init__(env, capacity, timing=timing, **kwargs)
self.training_container = StatesContainer(env)

def _local_add(self, training_container: ContainerUnion):
Expand Down
69 changes: 25 additions & 44 deletions src/gfn/containers/replay_buffer_manager.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import math
from typing import Callable, Optional

import torch
import torch.distributed as dist

from gfn.containers.message import Message, MessageType
from gfn.containers.replay_buffer import (
NormBasedDiversePrioritizedReplayBuffer,
ReplayBuffer,
)
from gfn.env import Env
from gfn.utils.distributed import recv, send


class ReplayBufferManager:
Expand All @@ -23,25 +21,23 @@ def __init__(
diverse_replay_buffer: bool = False,
capacity: int = 10000,
remote_manager_rank: int | None = None,
communication_backend: str = "mpi",
):
self.rank = rank
self.is_running = True
self.exit_counter = 0
self.num_training_ranks = num_training_ranks
self.scoring_function = scoring_function or self.default_scoring_function
backend = dist.get_backend()
if backend != "gloo":
raise RuntimeError(
f"Replay Buffer Manager is only supported with the 'gloo' backend, "
f"but the current backend is '{backend}'."
)
self.communication_backend = communication_backend

self.diverse_replay_buffer = diverse_replay_buffer
self.capacity = capacity
self.remote_manager_rank = remote_manager_rank
if self.diverse_replay_buffer:
self.replay_buffer = NormBasedDiversePrioritizedReplayBuffer(
env, capacity=self.capacity
env,
capacity=self.capacity,
communication_backend=self.communication_backend,
)
else:
self.replay_buffer = ReplayBuffer(
Expand All @@ -50,6 +46,7 @@ def __init__(
prioritized_capacity=True, # Always prioritize high reward items.
remote_manager_rank=self.remote_manager_rank,
remote_buffer_freq=1,
communication_backend=self.communication_backend,
)

def default_scoring_function(self, obj, sender_rank: int = -1) -> dict[str, float]:
Expand All @@ -76,17 +73,21 @@ def run(self):
)
message = Message(message_type=MessageType.DATA, message_data=score_dict)
message_tensor = message.serialize()
length_message_tensor = torch.IntTensor([len(message_tensor)])
dist.send(length_message_tensor, dst=sender_rank)
dist.send(message_tensor, dst=sender_rank)
send(
message_tensor,
dst_rank=sender_rank,
backend=self.communication_backend,
)

elif msg.message_type == MessageType.GET_METADATA:
metadata = self._compute_metadata()
msg = Message(message_type=MessageType.DATA, message_data=metadata)
metadata_tensor = msg.serialize()
length_metadata_tensor = torch.IntTensor([len(metadata_tensor)])
dist.send(length_metadata_tensor, dst=sender_rank)
dist.send(metadata_tensor, dst=sender_rank)
send(
metadata_tensor,
dst_rank=sender_rank,
backend=self.communication_backend,
)

elif msg.message_type == MessageType.EXIT:
self.exit_counter = self.exit_counter + 1
Expand All @@ -101,46 +102,26 @@ def run(self):
)

def _recv_object(self):
# Receive the length.
length_tensor = torch.IntTensor([0])
sender_rank = dist.recv(length_tensor)
length = length_tensor.item()

# Receive the actual serialized data.
byte_tensor = torch.ByteTensor(length)
dist.recv(byte_tensor, src=sender_rank)
sender_rank, byte_tensor = recv(backend=self.communication_backend)

# Deserialize back into object.
msg = Message.deserialize(byte_tensor)
return sender_rank, msg, length
return sender_rank, msg, len(byte_tensor)

@staticmethod
def send_termination_signal(manager_rank: int):
def send_termination_signal(manager_rank: int, backend):
"""Sends a termination signal to the replay buffer manager."""
rank = dist.get_rank()
msg = Message(message_type=MessageType.EXIT, message_data=None)
msg_bytes = msg.serialize()
length_tensor = torch.IntTensor([len(msg_bytes)])
dist.send(length_tensor, dst=manager_rank)
dist.send(msg_bytes, dst=manager_rank)
print(
f"Rank {rank} sent termination signal to replay buffer manager {manager_rank}."
)
send(msg_bytes, dst_rank=manager_rank, backend=backend)

@staticmethod
def get_metadata(manager_rank: int) -> dict:
def get_metadata(manager_rank: int, backend) -> dict:
"""Sends a get metadata signal to the replay buffer manager."""
msg = Message(message_type=MessageType.GET_METADATA, message_data=None)
msg_bytes = msg.serialize()

length_tensor = torch.IntTensor([len(msg_bytes)])
dist.send(length_tensor, dst=manager_rank)

dist.send(msg_bytes, dst=manager_rank)
length_metadata_tensor = torch.IntTensor([0])

dist.recv(length_metadata_tensor, src=manager_rank)
metadata_tensor = torch.ByteTensor(length_metadata_tensor.item())

dist.recv(metadata_tensor, src=manager_rank)
send(msg_bytes, dst_rank=manager_rank, backend=backend)
_src_rank, metadata_tensor = recv(manager_rank, backend=backend)
metadata = Message.deserialize(metadata_tensor)
return metadata.message_data
Loading
Loading