[AsyncRL] Use keep mode for pause and resume #1179
SumanthRH merged 28 commits into NovaSky-AI:main from async-train-keep
Conversation
Code Review
This pull request introduces a new fully asynchronous DAPO trainer with soft overlong punishment for rewards and adds new shell scripts for fully async DAPO and GRPO training. It also updates GPU configurations for existing scripts and explicitly sets CUDA device visibility for a retrieval server. The main change refactors the inference engine client to use vLLM's native keep-mode pause/resume functionality, removing the previous client-side retry logic and the related dataclass definitions. The PauseMode enum is updated to reflect these changes, and the _call_server method is modified to support query parameters for the new pause mechanism. Additionally, the PyTorch version check is improved, and the flash_attn imports are made more robust with a fallback implementation. Review comments highlight a potential runtime error in model_wrapper.py due to an unpad_input signature mismatch when flash_attn is installed, and a bug in remote_inference_client.py where the params argument is not passed to session.request in _call_server, so query parameters never reach the new pause endpoint.
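For reference, a minimal sketch of the _call_server fix the summary points to, assuming an aiohttp client and a hypothetical standalone signature (the real method lives on RemoteInferenceClient): the query parameters must be forwarded to session.request, otherwise mode/clear_cache never reach the vLLM /pause endpoint.

import aiohttp

async def _call_server(base_url: str, endpoint: str, payload: dict | None = None, params: dict | None = None) -> dict:
    # Sketch of the fix: pass params= through to session.request so query
    # parameters (e.g. mode/clear_cache for /pause) actually reach the server.
    async with aiohttp.ClientSession() as session:
        async with session.request("POST", f"{base_url}/{endpoint}", json=payload, params=params) as resp:
            resp.raise_for_status()
            return await resp.json()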
try:
    from flash_attn.bert_padding import pad_input, unpad_input
except ImportError:
    import torch.nn.functional as _F

    def unpad_input(hidden_states, attention_mask, unused_mask=None):
        all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
        seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
        used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = _F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        b, s = hidden_states.shape[:2]
        flat = hidden_states.reshape(b * s, *hidden_states.shape[2:])
        return flat[indices], indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch

    def pad_input(hidden_states, indices, batch, seqlen):
        output = torch.zeros(
            (batch * seqlen, *hidden_states.shape[1:]),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return output.reshape(batch, seqlen, *hidden_states.shape[1:])
There's a potential signature mismatch for the unpad_input function that could lead to a runtime error. The flash_attn.bert_padding.unpad_input function returns 4 values, but the call sites in this file (e.g., on line 315) expect to unpack 5 values. This will raise a ValueError if flash_attn is installed.
To resolve this and ensure compatibility regardless of whether flash_attn is present, I recommend wrapping the imported unpad_input function to align its return signature with the call sites. Since the fifth return value is unused, returning None for it would be a safe fix.
try:
    from flash_attn.bert_padding import pad_input
    from flash_attn.bert_padding import unpad_input as _flash_unpad_input

    def unpad_input(hidden_states, attention_mask, unused_mask=None):
        if unused_mask is not None:
            raise NotImplementedError("unused_mask is not supported with flash_attn's unpad_input.")
        # Wrap to return 5 values to match call sites; the 5th is unused.
        return (*_flash_unpad_input(hidden_states, attention_mask), None)

except ImportError:
    import torch.nn.functional as _F

    def unpad_input(hidden_states, attention_mask, unused_mask=None):
        all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
        seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
        used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = _F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        b, s = hidden_states.shape[:2]
        flat = hidden_states.reshape(b * s, *hidden_states.shape[2:])
        return flat[indices], indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch

    def pad_input(hidden_states, indices, batch, seqlen):
        output = torch.zeros(
            (batch * seqlen, *hidden_states.shape[1:]),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return output.reshape(batch, seqlen, *hidden_states.shape[1:])
kouroshHakha left a comment:
Leaving quick reviews. @SumanthRH @CharlieFRuan should review in more details
if mode == PauseMode.KEEP:
    params["clear_cache"] = "false"
this should be controlled by some cfg arg
+1
Let's add clear_cache as a kwarg to pause, default false.
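A minimal sketch of what that could look like, assuming PauseMode is a string-valued enum and that _call_server forwards params as query parameters (per the review summary above):

async def pause_generation(self, mode: PauseMode = PauseMode.KEEP, clear_cache: bool = False) -> None:
    # Expose clear_cache as a kwarg (default False) instead of hard-coding "false".
    params = {"mode": mode.value}
    if mode == PauseMode.KEEP:
        params["clear_cache"] = str(clear_cache).lower()
    await self._call_server("pause", params=params)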
from skyrl.backends.skyrl_train.utils.torch_utils import chunked_entropy_from_logits, logprobs_from_logits
from flash_attn.bert_padding import pad_input, unpad_input

try:
how are these related to this PR?
Left over from before the version upgrade; can remove.
SumanthRH left a comment:
Nice! Overall looks good; thanks for migrating both the old and the new inference codepaths.
Please clean up the branch to isolate your changes. Left a few other minor comments.
@hao-aaron Let's not add the scripts for fully async RL with DAPO and Search R1 in this PR.
@CharlieFRuan can you add these scripts in a separate PR?
Wait, why are there changes in the legacy YAML?
Your PR is converting the legacy YAML into the format of the new dataclasses introduced in #1001.
if session_id:
    headers["X-Session-ID"] = str(session_id)
🟡 Falsy session_id (e.g., 0) silently drops X-Session-ID header
In _generate_single, the check if session_id: at line 217 is falsy for valid session IDs like session_id=0 (an integer) or session_id="". When session_id=0 is passed, the X-Session-ID header won't be set, meaning the upstream router won't do session-aware routing for that request. The same pattern exists in chat_completion (remote_inference_client.py:263) and completion (remote_inference_client.py:295). The correct check should be if session_id is not None:.
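A minimal sketch of the suggested check, applied the same way at the other two call sites:

# Only skip the header when no session ID was provided; 0 and "" are still valid IDs.
if session_id is not None:
    headers["X-Session-ID"] = str(session_id)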
async with session.post(
    f"{self.url}/pause",
    params={"mode": "keep"},
) as resp:
🔴 RemoteInferenceEngine.pause_generation() omits clear_cache=false, risking KV cache eviction in keep mode
When RemoteInferenceEngine.pause_generation() calls the vLLM /pause endpoint, it only sends params={"mode": "keep"} but does not include clear_cache=false. Both other implementations explicitly pass clear_cache=False: AsyncVLLMInferenceEngine.pause_generation() at vllm_engine.py:635 passes clear_cache=clear_cache (default False), and RemoteInferenceClient.pause_generation() at remote_inference_client.py:451 passes "clear_cache": str(clear_cache).lower() (default "false"). The test mock server at test_remote_inference_client.py:72 defaults clear_cache to "true", confirming the server-side default clears cache. Without explicitly sending clear_cache=false, keep-mode pause through RemoteInferenceEngine will clear the KV cache, causing frozen in-flight requests to lose their cached state and defeating the purpose of keep mode.
Suggested change:

async with session.post(
    f"{self.url}/pause",
    params={"mode": "keep", "clear_cache": "false"},
) as resp:
@hao-aaron test_pause_and_continue_generation should now run with the new inference codepath, right?
_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra fsdp --extra dev pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_pause_and_continue_generation.py
If so, can you add it to the GPU CI script here: SkyRL/ci/gpu_ci_run_skyrl_train.sh, line 37 (as of af42cf8)?
The new keep mode introduced in vllm-project/vllm#32103 can be used to simplify async training. The current approach uses the abort pause mode to deal with in-flight requests and re-sends those requests (with the tokens generated so far) after resume. Keep mode instead freezes in-flight requests while paused, so we only send a single request with no retries and never have to handle intermediate tokens returned by aborted requests.
Removed the retry loop from both the old inference path and the new one (_SKYRL_USE_NEW_INFERENCE=1).
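For illustration, a rough sketch of the keep-mode flow around a weight sync; the helper and method names here are hypothetical, the real pause/resume APIs live on the inference client/engine classes:

async def sync_weights_with_keep_mode(client, sync_policy_weights):
    # Freeze in-flight requests in place instead of aborting them.
    await client.pause_generation(clear_cache=False)
    # Push updated policy weights to the inference engines.
    await sync_policy_weights()
    # Frozen requests resume where they stopped; no client-side retry needed.
    await client.continue_generation()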
gsm8k: [keep-mode run and baseline plots not shown]
searchr1: [keep-mode run and baseline plots not shown]
dapo-aime: [keep-mode run and baseline plots not shown]
Regression checks:
https://wandb.ai/sky-posttraining-uc-berkeley/async-keep?nw=nwuserahao
TODOS: