[AsyncRL] Use keep mode for pause and resume #1179

Merged
SumanthRH merged 28 commits into NovaSky-AI:main from hao-aaron:async-train-keep on Mar 13, 2026
Conversation

@hao-aaron (Collaborator) commented Feb 19, 2026

The new keep mode introduced in vllm-project/vllm#32103 can be used to simplify async training. The current approach pauses with abort mode to deal with in-flight requests, then retries those requests with the new tokens after resume. Keep mode instead freezes in-flight requests while paused, so we send each request only once, with no retries, and never have to handle intermediate tokens returned by aborted requests.

Removed the retry loop from both the old inference path and the new one (_SKYRL_USE_NEW_INFERENCE=1).
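
For context, a minimal sketch of the keep-mode flow this enables, assuming a vLLM server exposing the /pause endpoint shown later in this thread; the /resume endpoint name and the do_weight_update callback are assumptions for illustration:

import aiohttp

async def sync_weights_with_keep_pause(url: str, do_weight_update):
    async with aiohttp.ClientSession() as session:
        # Freeze in-flight requests in place instead of aborting them;
        # clear_cache=false keeps their KV cache intact across the pause.
        async with session.post(f"{url}/pause", params={"mode": "keep", "clear_cache": "false"}) as resp:
            resp.raise_for_status()
        await do_weight_update()
        # Frozen requests pick up where they left off; no client-side retry loop.
        async with session.post(f"{url}/resume") as resp:
            resp.raise_for_status()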

gsm8k: (training curve screenshot) vs. baseline (screenshot)

searchr1: (training curve screenshot) vs. baseline (screenshot)

dapo-aime: (training curve screenshot) vs. baseline (screenshot)

Regression checks:
https://wandb.ai/sky-posttraining-uc-berkeley/async-keep?nw=nwuserahao

TODOs:

  • Implementation
  • Unit tests
  • Regression check


@CharlieFRuan self-assigned this Mar 6, 2026
@hao-aaron marked this pull request as ready for review Mar 10, 2026
@hao-aaron changed the title from "[WIP][AsyncRL] Use keep mode for pause and resume" to "[AsyncRL] Use keep mode for pause and resume" Mar 10, 2026
@SumanthRH self-assigned this Mar 10, 2026
devin-ai-integration[bot] left a comment (marked as resolved).

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces a new fully asynchronous DAPO trainer with soft overlong punishment for rewards, and adds new shell scripts for fully async DAPO and GRPO training. It also updates GPU configurations for existing scripts and explicitly sets CUDA device visibility for a retrieval server.

The most significant change refactors the inference engine client to use vLLM's native keep-mode pause/resume functionality, removing the previous client-side retry logic and related dataclass definitions. The PauseMode enum is updated to reflect these changes, and the _call_server method is modified to support query parameters for the new pause mechanism. Additionally, the PyTorch version check is improved, and the flash_attn imports are made more robust with a fallback implementation.

Review comments highlight a potential runtime error in model_wrapper.py due to an unpad_input signature mismatch when flash_attn is installed, and a bug in remote_inference_client.py where _call_server does not pass the params argument to session.request, preventing the new pause query parameters from reaching the server.
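
A hypothetical sketch of the _call_server fix the review describes, forwarding params through to session.request; the class and method structure here are assumptions for illustration:

import aiohttp

class _ClientSketch:
    # Hypothetical stand-in for the remote inference client; only `url` is assumed.
    def __init__(self, url: str):
        self.url = url

    async def _call_server(self, endpoint: str, method: str = "POST", params: dict | None = None):
        # Forward query parameters (e.g. mode=keep) to session.request;
        # dropping `params` here is exactly the bug the review flags.
        async with aiohttp.ClientSession() as session:
            async with session.request(method, f"{self.url}/{endpoint}", params=params) as resp:
                resp.raise_for_status()
                return await resp.json()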

Comment on lines +18 to +41

try:
    from flash_attn.bert_padding import pad_input, unpad_input
except ImportError:
    import torch.nn.functional as _F

    def unpad_input(hidden_states, attention_mask, unused_mask=None):
        all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
        seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
        used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = _F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        b, s = hidden_states.shape[:2]
        flat = hidden_states.reshape(b * s, *hidden_states.shape[2:])
        return flat[indices], indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch

    def pad_input(hidden_states, indices, batch, seqlen):
        output = torch.zeros(
            (batch * seqlen, *hidden_states.shape[1:]),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return output.reshape(batch, seqlen, *hidden_states.shape[1:])
Contributor comment (critical):

There's a potential signature mismatch for the unpad_input function that could lead to a runtime error. The flash_attn.bert_padding.unpad_input function returns 4 values, but the call sites in this file (e.g., on line 315) expect to unpack 5 values. This will raise a ValueError if flash_attn is installed.

To resolve this and ensure compatibility regardless of whether flash_attn is present, I recommend wrapping the imported unpad_input function to align its return signature with the call sites. Since the fifth return value is unused, returning None for it would be a safe fix.

try:
    from flash_attn.bert_padding import pad_input
    from flash_attn.bert_padding import unpad_input as _flash_unpad_input

    def unpad_input(hidden_states, attention_mask, unused_mask=None):
        if unused_mask is not None:
            raise NotImplementedError("unused_mask is not supported with flash_attn's unpad_input.")
        # Wrap to return 5 values to match call sites; the 5th is unused.
        return (*_flash_unpad_input(hidden_states, attention_mask), None)

except ImportError:
    import torch.nn.functional as _F

    def unpad_input(hidden_states, attention_mask, unused_mask=None):
        all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
        seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
        used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = _F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        b, s = hidden_states.shape[:2]
        flat = hidden_states.reshape(b * s, *hidden_states.shape[2:])
        return flat[indices], indices, cu_seqlens, max_seqlen_in_batch, used_seqlens_in_batch

    def pad_input(hidden_states, indices, batch, seqlen):
        output = torch.zeros(
            (batch * seqlen, *hidden_states.shape[1:]),
            device=hidden_states.device,
            dtype=hidden_states.dtype,
        )
        output[indices] = hidden_states
        return output.reshape(batch, seqlen, *hidden_states.shape[1:])
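
As a quick sanity check of the fallback path (illustrative only; shapes chosen arbitrarily):

import torch

hidden = torch.randn(2, 4, 8)                      # (batch, seqlen, hidden)
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])  # sequence lengths 2 and 3
flat, indices, cu_seqlens, max_len, used = unpad_input(hidden, mask)
assert flat.shape == (5, 8) and max_len == 3       # 5 real tokens total
repadded = pad_input(flat, indices, 2, 4)
assert repadded.shape == hidden.shape              # round-trip restores the shape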

@kouroshHakha (Collaborator) left a comment

Leaving quick reviews. @SumanthRH @CharlieFRuan should review in more detail.

Comment on lines +445 to +446

if mode == PauseMode.KEEP:
    params["clear_cache"] = "false"

Collaborator: this should be controlled by some cfg arg.

Member: +1. Let's add clear_cache as a kwarg to pause, defaulting to false.
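
A minimal sketch of the suggested signature, assuming the PauseMode enum from this PR; the class wiring and _call_server stub are hypothetical stand-ins:

from enum import Enum

class PauseMode(Enum):
    ABORT = "abort"
    KEEP = "keep"

class EngineClientSketch:
    # Hypothetical stand-in; the real method lives on the inference client.
    async def _call_server(self, endpoint: str, params: dict) -> None:
        ...  # HTTP plumbing elided

    async def pause_generation(self, mode: PauseMode = PauseMode.KEEP, clear_cache: bool = False) -> None:
        # clear_cache defaults to False so a keep-mode pause retains the
        # KV cache of frozen in-flight requests.
        params = {"mode": mode.value, "clear_cache": str(clear_cache).lower()}
        await self._call_server("pause", params=params)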

from skyrl.backends.skyrl_train.utils.torch_utils import chunked_entropy_from_logits, logprobs_from_logits
from flash_attn.bert_padding import pad_input, unpad_input

try:
Collaborator: how are these related to this PR?

Collaborator (author): left over from before the version upgrade, can remove.

@SumanthRH (Member) left a comment

Nice! Overall looks good; thanks for migrating both the old and the new inference codepaths.

Please clean up the branch to isolate your changes. Left a few other minor comments.

Member: @hao-aaron Let's not add the scripts for fully async RL with DAPO and Search R1 in this PR.

@CharlieFRuan, can you add these scripts in a separate PR?

Member: Please revert these changes.

Member: Revert.

Member: Wait, why are there changes in the legacy YAML? Your PR is converting the legacy YAML into the format of the new dataclasses introduced in #1001.

Collaborator (author): reverted.

Comment thread: skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py (outdated)
Comment thread: skyrl/backends/skyrl_train/inference_servers/remote_inference_client.py (outdated)
Comment thread: skyrl/backends/skyrl_train/workers/model_wrapper.py
Comment thread: skyrl/train/config/config.py
Comment thread: skyrl-gym/skyrl_gym/tools/search.py
@devin-ai-integration (Bot) left a comment

Devin Review found 1 new potential issue (12 additional findings in Devin Review).

Comment on lines +217 to +218

if session_id:
    headers["X-Session-ID"] = str(session_id)
@devin-ai-integration (Bot) commented Mar 11, 2026

🟡 Falsy session_id (e.g., 0) silently drops X-Session-ID header

In _generate_single, the check if session_id: at line 217 is falsy for valid session IDs like session_id=0 (an integer) or session_id="". When session_id=0 is passed, the X-Session-ID header won't be set, meaning the upstream router won't do session-aware routing for that request. The same pattern exists in chat_completion (remote_inference_client.py:263) and completion (remote_inference_client.py:295). The correct check should be if session_id is not None:.
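
A minimal sketch of the corrected check (the header plumbing around it is assumed for illustration):

headers: dict = {}
session_id = 0  # a valid session ID that happens to be falsy
if session_id is not None:  # `if session_id:` would silently skip this branch
    headers["X-Session-ID"] = str(session_id)
assert headers["X-Session-ID"] == "0"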


@devin-ai-integration (Bot) left a comment

Devin Review found 1 new potential issue (14 additional findings in Devin Review).

Comment on lines +286 to +289

async with session.post(
    f"{self.url}/pause",
    params={"mode": "keep"},
) as resp:

🔴 RemoteInferenceEngine.pause_generation() omits clear_cache=false, risking KV cache eviction in keep mode

When RemoteInferenceEngine.pause_generation() calls the vLLM /pause endpoint, it only sends params={"mode": "keep"} but does not include clear_cache=false. Both other implementations explicitly pass clear_cache=False: AsyncVLLMInferenceEngine.pause_generation() at vllm_engine.py:635 passes clear_cache=clear_cache (default False), and RemoteInferenceClient.pause_generation() at remote_inference_client.py:451 passes "clear_cache": str(clear_cache).lower() (default "false"). The test mock server at test_remote_inference_client.py:72 defaults clear_cache to "true", confirming the server-side default clears cache. Without explicitly sending clear_cache=false, keep-mode pause through RemoteInferenceEngine will clear the KV cache, causing frozen in-flight requests to lose their cached state and defeating the purpose of keep mode.

Suggested change:

async with session.post(
    f"{self.url}/pause",
    params={"mode": "keep", "clear_cache": "false"},
) as resp:

Member: @hao-aaron, test_pause_and_continue_generation should now run with the new inference codepath, right?

_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra fsdp --extra dev pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_pause_and_continue_generation.py

If so, can you add it to the GPU CI script here:

_SKYRL_USE_NEW_INFERENCE=1 uv run --isolated --extra dev --extra fsdp pytest -s tests/backends/skyrl_train/gpu/gpu_ci/test_skyrl_gym_generator.py

@SumanthRH merged commit abe814f into NovaSky-AI:main on Mar 13, 2026; 5 of 6 checks passed.