Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9d272dc
[verl] feat: add trust_remote_code arg and compilation_config dict su…
sijyang Feb 11, 2026
93ebe11
[verl] feat: add logprobs and request_id support across sampling pipe…
sijyang Feb 11, 2026
def041f
[verl] feat: weight sync, memory lifecycle and DP isolation for verl …
sijyang Feb 11, 2026
dbc7704
[verl] feat: utility command dispatch and broadcast communication
sijyang Feb 11, 2026
c1870e2
[verl] feat: basic integration with verl - load_weights, sleep/wake_u…
sijyang Feb 11, 2026
2479e27
[atom] fix: rope parameters handling, remove CLI trust_remote_code, a…
sijyang Feb 12, 2026
4c58b0a
[atom] feat: implement packed weight handling in ModelRunner for FP8 …
sijyang Mar 5, 2026
452847e
[verl] refactor: decouple RLHF rollout logic from inference engine in…
sijyang Mar 18, 2026
2e299a8
[verl] feat: extend tokenIDProcessor for logprobs support and enhance…
sijyang Apr 3, 2026
a3e2f65
fix: patch NCCL device binding for DP-isolated ModelRunner
sijyang Apr 13, 2026
e9f107e
refactor: minimize diff against main by reverting non-functional changes
sijyang Apr 13, 2026
e062e18
refactor: improve code readability by formatting and organizing funct…
sijyang Apr 13, 2026
60060f9
refactor: extract sleep logic from engine_core busy_loop into helper …
sijyang Apr 22, 2026
7424177
[verl] refactor: merge logprobs and DP isolation into base ModelRunne…
sijyang Apr 15, 2026
be3ba87
refactor: rename sleep state variables and update related logic for R…
sijyang Apr 23, 2026
aeeeab6
fix: restore mark_trace profiler around cudagraph capture
sijyang Apr 30, 2026
cef0263
docs: add veRL + Megatron + ATOM environment setup guide for ROCm
sijyang Apr 30, 2026
f7c935e
[verl] feat: add logprobs and request_id support across sampling pipe…
sijyang Feb 11, 2026
7a7b672
[verl] refactor: unify load_weights API with auto mode selection
sijyang May 24, 2026
9b344f4
fix: batch token ID processing in tokenIDProcessor
sijyang May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,7 @@ class Config:
enable_tbo: bool = False
enable_tbo_decode: bool = False
enable_low_latency: bool = False
runner_qualname: str = "atom.model_engine.model_runner.ModelRunner"

# only use for plugin mode
plugin_config: Optional[PluginConfig] = None
Expand All @@ -987,6 +988,8 @@ def _set_cudagraph_sizes(self):
self.graph_bs = cuda_graph_sizes

def __post_init__(self):
if isinstance(self.compilation_config, dict):
self.compilation_config = CompilationConfig(**self.compilation_config)
# assert os.path.isdir(self.model)

assert 1 <= self.tensor_parallel_size <= 8
Expand Down
13 changes: 12 additions & 1 deletion atom/model_engine/async_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
runner_qualname: str,
rank: int,
kv_output_addr: str | None = None,
all_ranks_barrier=None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -107,6 +108,8 @@ def __init__(
t.start()
self.io_threads.append(t)

self.all_ranks_barrier = all_ranks_barrier

runner_class = resolve_obj_by_qualname(runner_qualname)
self.runners: list[object] = []
self.runners = [runner_class(rank, *args, **kwargs)]
Expand Down Expand Up @@ -162,15 +165,22 @@ def send_output_to_socket(self, addr: str, output_queue: queue.Queue):
serialized_obj = pickle.dumps(result)
socket.send(serialized_obj)

# Functions that require all TP ranks to synchronize via barrier before
# rank 0 returns, so the caller can safely reuse/overwrite shared buffers.
_BARRIER_FUNCS = {"update_weights_from_ipc", "update_weights_from_shm"}

def busy_loop(self):
"""Main event loop: dequeue RPCs and dispatch to runners."""
while True:
func_name, args = self.get_func()
need_barrier = func_name in self._BARRIER_FUNCS
for runner in self.runners:
func = getattr(runner, func_name, None)
if func is None:
continue
out = func(*args)
if need_barrier and self.all_ranks_barrier is not None:
self.all_ranks_barrier.wait()
if out is not None:
if (
self.io_addrs[1] is not None
Expand All @@ -179,7 +189,6 @@ def busy_loop(self):
self.io_queues[1].put_nowait(out)
if self.kv_queue is not None and func_name in self._KV_FUNC_NAMES:
self.kv_queue.put_nowait(out)

if func_name == "exit":
break
logger.debug(f"{self.label}: exit busy_loop...")
Expand Down Expand Up @@ -227,6 +236,7 @@ def __init__(self, finalizer, proc_num: int, runner: str, *args):
import atexit

atexit.register(self._cleanup_shared_memory)
self.all_ranks_barrier = ctx.Barrier(proc_num)
init_exit_handler(self)

# KV output aggregation infrastructure
Expand All @@ -252,6 +262,7 @@ def __init__(self, finalizer, proc_num: int, runner: str, *args):
runner,
i,
self.kv_output_addrs[i],
self.all_ranks_barrier,
*args,
),
)
Expand Down
98 changes: 87 additions & 11 deletions atom/model_engine/engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import zmq
from atom.config import Config, ParallelConfig
from atom.model_engine.async_proc import AsyncIOProcManager
from atom.rollout.engine_utility import EngineUtilityHandler
from atom.model_engine.scheduler import Scheduler
from atom.model_engine.sequence import Sequence, SequenceStatus, get_exit_sequence
from atom.utils import init_exit_handler, make_zmq_socket
Expand Down Expand Up @@ -44,6 +45,8 @@ class EngineCoreRequestType(enum.Enum):
STREAM = b"\x06"
# Signal that EngineCore is fully initialized and ready
READY = b"\x07"
# Response to a synchronous utility command
UTILITY_RESPONSE = b"\x08"


class EngineCore:
Expand All @@ -54,6 +57,14 @@ def __init__(self, config: Config, input_address: str, output_address: str):
self.stream_output_queue = (
queue.Queue()
) # Queue for streaming intermediate outputs
# Queue for utility commands (processed in busy_loop to avoid thread contention)
self.utility_queue = queue.Queue()
self._has_pending_utility = (
False # Flag to avoid checking empty queue every loop
)
self._is_rl_weights_offloaded = (
False # True when weights are offloaded for RL training
)
self.input_address = input_address
self.output_address = output_address
self.output_thread = threading.Thread(
Expand Down Expand Up @@ -84,9 +95,10 @@ def __init__(self, config: Config, input_address: str, output_address: str):
self.runner_mgr = AsyncIOProcManager(
self._finalizer,
config.tensor_parallel_size,
"atom.model_engine.model_runner.ModelRunner",
config.runner_qualname,
Comment thread
sijyang marked this conversation as resolved.
config,
)

block_info = self.runner_mgr.call_func("get_num_blocks", wait_out=True)
num_blocks = block_info["num_kvcache_blocks"]
config.per_req_cache_equiv_blocks = block_info.get(
Expand All @@ -102,7 +114,6 @@ def __init__(self, config: Config, input_address: str, output_address: str):

config.num_kvcache_blocks = num_blocks
if not config.enforce_eager:
# Start profiler before cudagraph capture only if mark-trace is enabled.
Comment thread
sijyang marked this conversation as resolved.
if self.profile_enbaled and self.mark_trace:
self.runner_mgr.call_func(
"start_profiler", "capture_graph", wait_out=True
Expand All @@ -114,7 +125,6 @@ def __init__(self, config: Config, input_address: str, output_address: str):
f"{self.label}: cudagraph capture{bs} cost: {cap_cost:.2f} seconds"
)
if self.profile_enbaled and self.mark_trace:
# Persist a dedicated capture-graph trace immediately.
self.runner_mgr.call_func("stop_profiler", wait_out=True)
good = True
finally:
Expand All @@ -132,6 +142,10 @@ def __init__(self, config: Config, input_address: str, output_address: str):
world_size=config.tensor_parallel_size
)

self.utility_handler = EngineUtilityHandler(
self.runner_mgr, self.output_queue, label=self.label
)

self._send_ready_signal()
logger.info(f"{self.label}: EngineCore fully initialized and ready")

Expand Down Expand Up @@ -181,12 +195,26 @@ def run_engine(config: Config, input_address: str, output_address: str):
if engine is not None:
engine.exit()

def _is_idle_rl_weights_offloaded(self) -> bool:
    """Report whether this tick should be skipped because weights are offloaded.

    While ``self._is_rl_weights_offloaded`` is set, the call sleeps for
    10 ms before returning so that a caller polling in a loop does not
    spin the CPU. When the flag is clear, no sleep is performed.

    Returns:
        True if the weights are offloaded and the caller should skip model
        execution for this tick, False otherwise.
    """
    if not self._is_rl_weights_offloaded:
        return False
    # Nothing can be executed right now; throttle the polling loop.
    time.sleep(0.01)
    return True

def busy_loop(self):
    """Run the engine's main loop until the input queue signals shutdown.

    Each iteration, in order:
      1. hands ``self.utility_queue`` to ``self.utility_handler`` so queued
         utility commands are processed on this thread;
      2. pulls and processes pending input; a truthy result ends the loop;
      3. skips the rest of the tick while RL weights are offloaded;
      4. runs one engine step if the scheduler still has unfinished work.
    """
    while True:
        self.utility_handler.process_queue(self.utility_queue, self)

        # A truthy result means shutdown was requested; leave immediately.
        if self.pull_and_process_input_queue():
            break

        # No model execution while weights are offloaded for RL training.
        if self._is_idle_rl_weights_offloaded():
            continue

        if self.scheduler.is_finished():
            continue
        self._process_engine_step()

Expand Down Expand Up @@ -306,7 +334,8 @@ def process_input_sockets(self, input_address: str):
)
self.input_queue.put_nowait(reqs)
elif request_type == EngineCoreRequestType.UTILITY:
# Handle utility commands like start_profile/stop_profile
# The profiler and MTP-stats commands are handled inline just below; every
# other utility command is queued so that busy_loop processes it, keeping
# its runner_mgr.call_func() calls on the busy_loop (main) thread.
cmd = reqs.get("cmd") if isinstance(reqs, dict) else None
logger.debug(f"{self.label}: input get UTILITY command: {cmd}")
if cmd == "start_profile":
Expand All @@ -315,6 +344,10 @@ def process_input_sockets(self, input_address: str):
self.stop_profiler()
elif cmd == "get_mtp_stats":
self.print_mtp_statistics()
else:
# Queue command for processing in busy_loop (main thread)
self.utility_queue.put_nowait((cmd, reqs))
self._has_pending_utility = True
elif request_type == EngineCoreRequestType.SHUTDOWN:
logger.debug(f"{self.label}: input get {request_type}")
self.input_queue.put_nowait([get_exit_sequence()])
Expand Down Expand Up @@ -346,6 +379,15 @@ def process_output_sockets(self, output_address: str):
logger.debug(f"{self.label}: sent READY signal")
continue

if isinstance(item, tuple) and item[0] == "UTILITY_RESPONSE":
# Send utility command response back to CoreManager
response_data = item[1]
serialized_obj = pickle.dumps(
(EngineCoreRequestType.UTILITY_RESPONSE, response_data)
)
socket.send(serialized_obj)
continue

# Regular finished sequences
seqs = item
valid_seqs = [
Expand Down Expand Up @@ -410,28 +452,53 @@ def exit(self):
if dp_group := getattr(self, "dp_group", None):
stateless_destroy_torch_distributed_process_group(dp_group)

def _gather_local_dp_state(self) -> tuple[bool, int, int, bool, bool]:
    """Collect this core's scheduling state for the DP-wide state sync.

    A core whose weights are offloaded for RL training still has to take
    part in ``_sync_dp_state`` so the other DP ranks are not left waiting,
    so it reports placeholder values instead of querying the scheduler.
    The offload flag is part of the returned state so that all cores can
    agree to skip model execution together rather than only some of them
    entering DP-wide collectives.

    Returns:
        ``(is_prefill, num_tokens, num_reqs, unfinished, offloaded)``.
        When offloaded, the first four entries are ``False, 0, 0, False``
        and the scheduler is not consulted.
    """
    weights_offloaded = self._is_rl_weights_offloaded
    if weights_offloaded:
        # Placeholder state: nothing to schedule on this core.
        return False, 0, 0, False, weights_offloaded

    batch_is_prefill, batch_tokens, batch_reqs = self.scheduler.get_next_batch_info()
    has_pending_work = not self.scheduler.is_finished()
    return (
        batch_is_prefill,
        batch_tokens,
        batch_reqs,
        has_pending_work,
        weights_offloaded,
    )

def busy_loop(self):
shutdown = False
while True:
self.utility_handler.process_queue(self.utility_queue, self)
shutdown = shutdown or self.pull_and_process_input_queue()

local_is_prefill, local_num_tokens, local_num_reqs = (
self.scheduler.get_next_batch_info()
)
local_unfinished = not self.scheduler.is_finished()
(
local_is_prefill,
local_num_tokens,
local_num_reqs,
local_unfinished,
local_offloaded,
) = self._gather_local_dp_state()

(
global_has_prefill,
global_max_tokens,
global_max_reqs,
global_has_unfinished,
global_shutdown,
global_offloaded,
) = self._sync_dp_state(
local_is_prefill,
local_num_tokens,
local_num_reqs,
local_unfinished,
shutdown,
local_offloaded,
)

if global_shutdown and not global_has_unfinished:
Expand All @@ -440,6 +507,10 @@ def busy_loop(self):
)
break

if global_offloaded:
time.sleep(0.01)
continue

if not global_has_unfinished and not self.engines_running:
self.engines_running = False
continue
Expand Down Expand Up @@ -489,25 +560,28 @@ def _sync_dp_state(
local_num_reqs: int,
local_has_unfinished: bool,
local_shutdown: bool = False,
) -> tuple[bool, int, int, bool, bool]:
local_offloaded: bool = False,
) -> tuple[bool, int, int, bool, bool, bool]:
if self._shutting_down:
return (
local_is_prefill,
local_num_tokens,
local_num_reqs,
local_has_unfinished,
True,
local_offloaded,
)

try:
# Pack all state: [is_prefill, num_tokens, num_reqs, has_unfinished, shutdown]
# Pack all state: [is_prefill, num_tokens, num_reqs, has_unfinished, shutdown, offloaded]
state_tensor = torch.tensor(
[
1 if local_is_prefill else 0,
local_num_tokens,
local_num_reqs,
1 if local_has_unfinished else 0,
1 if local_shutdown else 0,
1 if local_offloaded else 0,
],
dtype=torch.int64,
device="cpu",
Expand All @@ -520,23 +594,25 @@ def _sync_dp_state(
global_max_reqs = state_tensor[2].item()
global_has_unfinished = state_tensor[3].item() == 1
global_shutdown = state_tensor[4].item() == 1
global_offloaded = state_tensor[5].item() == 1
return (
global_has_prefill,
global_max_tokens,
global_max_reqs,
global_has_unfinished,
global_shutdown,
global_offloaded,
)
except RuntimeError as e:
logger.warning(f"{self.label}: _sync_dp_state failed: {e}")
# If sync fails, assume shutdown to prevent hang
self._shutting_down = True
return (
local_is_prefill,
local_num_tokens,
local_num_reqs,
local_has_unfinished,
True,
local_offloaded,
)

def _sync_shutdown_state(self, local_should_shutdown: bool) -> bool:
Expand Down
Loading
Loading