diff --git a/perf/scripts/explorer/perf_workflow.py b/perf/scripts/explorer/perf_workflow.py
new file mode 100644
index 00000000000..4b705617875
--- /dev/null
+++ b/perf/scripts/explorer/perf_workflow.py
@@ -0,0 +1,80 @@
+import time
+from typing import Any, List, Optional, cast
+
+from trinity.common.experience import Experience
+from trinity.common.models.model import ModelWrapper
+from trinity.common.workflows.workflow import Task, Workflow
+
+
+class PerfWorkflow(Workflow):
+    """A workflow for performance testing of Explorer with OpenAI API calls."""
+
+    is_async: bool = True
+    can_reset: bool = True
+
+    def __init__(
+        self,
+        *,
+        task: Task,
+        model: ModelWrapper,
+        auxiliary_models: Optional[List[ModelWrapper]] = None,
+    ):
+        super().__init__(
+            task=task,
+            model=model,
+            auxiliary_models=auxiliary_models,
+        )
+        self.client = self.model.get_openai_async_client()
+        self.model_path = getattr(self.client, "model_path")
+        self.reset(task)
+
+    def reset(self, task: Task) -> None:
+        raw_task = task.raw_task or {}
+        self.messages = raw_task.get("messages") or []
+        if not self.messages:
+            raise ValueError("PerfWorkflow requires task.raw_task['messages'].")
+        self.tools = raw_task.get("tools")
+
+    async def run_async(self) -> List[Experience]:
+        request_latencies = []
+        usage_prompt_tokens = 0.0
+        usage_completion_tokens = 0.0
+        for i in range(len(self.messages)):
+            if self.messages[i].get("role") == "assistant":
+                # Replay the conversation up to this assistant turn to measure latency and token usage; the response content itself is ignored.
+                request_kwargs = {
+                    "model": self.model_path,
+                    "messages": self.messages[:i],
+                }
+                if self.tools is not None:
+                    request_kwargs["tools"] = self.tools
+
+                request_start = time.perf_counter()
+                responses = await self.client.chat.completions.create(**request_kwargs)
+                request_latency = time.perf_counter() - request_start
+                request_latencies.append(request_latency)
+
+                usage = cast(Any, getattr(responses, "usage", None))
+                prompt_tokens = getattr(usage, "prompt_tokens", None)
+                completion_tokens = getattr(usage, "completion_tokens", None)
+                if isinstance(prompt_tokens, (int, float)):
+                    usage_prompt_tokens += float(prompt_tokens)
+                if isinstance(completion_tokens, (int, float)):
+                    usage_completion_tokens += float(completion_tokens)
+
+        self.logger.info("Received response: %s", responses.choices[0].message)
+        exps = self.model.extract_experience_from_history()
+        total_request_latency = sum(request_latencies)
+        exps[0].metrics = {
+            "prompt_length": usage_prompt_tokens,
+            "response_length": usage_completion_tokens,
+            "api_call_prompt_tokens_per_second": (
+                usage_prompt_tokens / total_request_latency if total_request_latency > 0 else 0.0
+            ),
+            "api_call_response_tokens_per_second": (
+                usage_completion_tokens / total_request_latency
+                if total_request_latency > 0
+                else 0.0
+            ),
+        }
+        return exps
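PerfWorkflow only replays conversations that already contain assistant turns: reset() requires task.raw_task["messages"], and run_async() issues one chat.completions request per assistant message, using all preceding messages as the prompt. A minimal raw_task sketch (illustrative values only, not taken from the repository):

# Illustrative only: the shape of raw_task that PerfWorkflow.reset() expects.
# For each assistant turn at index i, run_async() sends messages[:i] to the
# OpenAI-compatible endpoint and records latency and token usage.
raw_task = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},   # measured request uses messages[:2]
        {"role": "user", "content": "And times 10?"},
        {"role": "assistant", "content": "40"},  # measured request uses messages[:4]
    ],
    # "tools" is optional; when present it is forwarded to chat.completions.create.
}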
diff --git a/pyproject.toml b/pyproject.toml
index bbb92fb6b83..2417e448f77 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,7 +44,7 @@ dependencies = [
     "matplotlib",
     "psutil",
     "nvidia-ml-py",
-    "transformers>=5.6.2",
+    "transformers>=5.8.0",
     "datasets>=4.0.0",
     "typer>=0.20.1",
 ]
@@ -56,6 +56,9 @@ trinity = "trinity.cli.launcher:main"
 vllm = [
     "vllm>=0.19.1,<=0.20.1",
 ]
+sglang = [
+    "sglang==0.5.11",
+]
 data = [
     "py-data-juicer>=1.4.3"
 ]
diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv
index 58e62d7d100..740a9d796e1 100644
--- a/scripts/docker/Dockerfile.uv
+++ b/scripts/docker/Dockerfile.uv
@@ -19,7 +19,7 @@ RUN chmod 1777 /tmp && apt update && apt install -y \
     curl git wget vim tmux net-tools cmake \
     python3 python3-pip python3-dev python3-packaging python3-venv \
     libomp-dev libnuma1 infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \
-    libnuma-dev \
+    libnuma-dev protobuf-compiler \
     && rm -rf /var/lib/apt/lists/* \
     && ln -sf /usr/bin/python3 /usr/bin/python \
     && ln -sf /usr/bin/pip3 /usr/bin/pip
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 58fe4708571..7c9bfbe1671 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -4,17 +4,75 @@
 import math
 import os
 import shutil
+import socket
 import unittest
+from unittest.mock import patch
 
 import torch
 
 from tests.tools import get_template_config, get_unittest_dataset_config
 from trinity.common.config import InferenceModelConfig, load_config
+from trinity.common.constants import SyncMethod
+from trinity.common.models.model import InferenceModel
 
 CHECKPOINT_ROOT_DIR = os.path.join(os.path.dirname(__file__), "temp_checkpoint_dir")
 
 
+class DummyInferenceModel(InferenceModel):
+    async def generate(self, prompt: str, **kwargs):
+        raise NotImplementedError
+
+    async def chat(self, messages, **kwargs):
+        raise NotImplementedError
+
+    async def logprobs(self, token_ids, **kwargs):
+        raise NotImplementedError
+
+    async def convert_messages_to_experience(self, messages, tools=None, temperature=None):
+        raise NotImplementedError
+
+    async def sync_model(
+        self, model_version: int, sync_method: SyncMethod, timeout: float = 1200
+    ) -> int:
+        return model_version
+
+    def get_model_version(self) -> int:
+        return 0
+
+
 class TestConfig(unittest.TestCase):
+    def test_inference_model_base_port_uses_engine_id(self):
+        model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=3))
+
+        _, port = model.get_available_address()
+
+        self.assertEqual(port, 9003)
+
+    def test_inference_model_base_port_falls_back_when_unavailable(self):
+        requested_port = 9004
+        model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=4))
+
+        with socket.socket() as occupied_socket:
+            occupied_socket.bind(("", requested_port))
+
+            with patch.object(model.logger, "warning") as mock_warning:
+                _, port = model.get_available_address()
+
+        self.assertNotEqual(port, requested_port)
+        self.assertGreater(port, 0)
+        mock_warning.assert_called_once_with(
+            "Configured port %s is unavailable for engine %s; falling back to an ephemeral port.",
+            requested_port,
+            4,
+        )
+
+    def test_inference_model_without_base_port_uses_ephemeral_port(self):
+        model = DummyInferenceModel(InferenceModelConfig())
+
+        _, port = model.get_available_address()
+
+        self.assertGreater(port, 0)
+
     def test_load_default_config(self):
         config = get_template_config()
         config.buffer.batch_size = 8
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index 89d8cc2e606..b2deb91ef5d 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -282,7 +282,7 @@
     def __init__(self):
         super().__init__(InferenceModelConfig(model_path="dummy_model"))
 
-    def sync_model(self, model_version, update_weight_args_list):
+    def sync_model(self, model_version, sync_method, timeout):
        return True
 
     async def prepare(self):
@@ -329,7 +329,7 @@
 
 @ray.remote
 class DummyAuxiliaryModel(InferenceModel):
-    def sync_model(self, model_version, update_weight_args_list):
+    def sync_model(self, model_version, sync_method, timeout):
         return True
 
     def get_model_version(self):
diff
--git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index dd9d4c8c4c8..bb45ac6093c 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -66,12 +66,14 @@ async def new_finish_explore_step(self: Explorer, step: int, model_version: int) await asyncio.sleep(explore_step_time_list[step - 1]) dummy_exps = [ Experience( - tokens=torch.tensor([0, 0, 0]), + tokens=torch.tensor([0, 1, 2]), info={"model_version": model_version}, ) for _ in range(self.config.buffer.train_batch_size) ] - await self.experience_pipeline.process.remote(Experience.serialize_many(dummy_exps)) + await self.rollout_coordinator.process_experiences.remote( + [Experience.serialize_many(dummy_exps)] + ) self.monitor.log(metric, step=step) Explorer.explore_step = new_explore_step @@ -347,6 +349,7 @@ class TestPullLatestWeights(unittest.IsolatedAsyncioTestCase): def setUp(self): self.explorer = object.__new__(Explorer) + self.explorer.config = Config() self.explorer.logger = MagicMock() self.explorer.models = [MagicMock(), MagicMock()] self.explorer.synchronizer = MagicMock() @@ -378,7 +381,11 @@ async def test_pull_latest_weights(self, model_version, new_version, expect_sync for m in self.explorer.models: if expect_sync: - m.sync_model.remote.assert_called_once_with(new_version) + m.sync_model.remote.assert_called_once_with( + new_version, + self.explorer.config.synchronizer.sync_method, + timeout=self.explorer.config.synchronizer.sync_timeout, + ) else: m.sync_model.remote.assert_not_called() diff --git a/tests/template/config.yaml b/tests/template/config.yaml index 1659d2e4a94..757bacc9f32 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -44,7 +44,7 @@ explorer: gpu_memory_utilization: 0.8 trainer: trainer_type: verl - trainer_strategy: fsdp + trainer_strategy: fsdp2 save_interval: 100 save_hf_checkpoint: never grad_clip: 1.0 diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 739b64c8169..69673abca68 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -78,7 +78,7 @@ def setUp(self): @parameterized_class( ("strategy",), [ - ("fsdp",), + ("fsdp2",), ("megatron",), ], ) @@ -576,7 +576,7 @@ def run_serve(config: Config, stop_event=None) -> None: @parameterized_class( ("use_priority_queue", "strategy"), - [(False, "fsdp"), (True, "fsdp"), (True, "megatron")], + [(False, "fsdp"), (True, "fsdp2"), (True, "megatron")], ) class TestFullyAsyncMode(unittest.TestCase): def setUp(self): @@ -603,16 +603,14 @@ def test_fully_async_mode(self): config.synchronizer.sync_method = SyncMethod.CHECKPOINT config.synchronizer.sync_style = SyncStyle.EXPLORER_DRIVEN config.synchronizer.sync_interval = 8 + config.trainer.trainer_strategy = self.strategy config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) trainer_config.mode = "train" trainer_config.buffer.train_batch_size = 4 - if self.strategy == "megatron": - trainer_config.trainer.trainer_strategy = "megatron" trainer_config.check_and_update() - if self.strategy == "megatron": - _trainer_config = trainer_config.trainer.trainer_config - _trainer_config.critic.strategy = "megatron" + _trainer_config = trainer_config.trainer.trainer_config + _trainer_config.critic.strategy = self.strategy explorer1_config = deepcopy(config) explorer1_config.trainer = deepcopy(trainer_config.trainer) diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index d474a2819cc..0e3a01de0d5 100644 --- 
a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -1,5 +1,6 @@ """Reader of the Queue buffer.""" +import traceback from typing import Dict, List, Optional import ray @@ -33,11 +34,24 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: ) except StopAsyncIteration: raise StopIteration() + except Exception as e: + if "StopAsyncIteration" in traceback.format_exc(): + raise StopIteration() from e + else: + raise return exps async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: batch_size = self.read_batch_size if batch_size is None else batch_size - exp_bytes = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) + try: + exp_bytes = await self.queue.get_batch.remote( + batch_size, timeout=self.timeout, **kwargs + ) + except Exception as e: + if "StopAsyncIteration" in traceback.format_exc(): + raise StopAsyncIteration() from e + else: + raise exps = Experience.deserialize_many(exp_bytes) if len(exps) != batch_size: raise TimeoutError( diff --git a/trinity/common/config.py b/trinity/common/config.py index 57ff7f003a0..b6729bcd662 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -546,6 +546,8 @@ class InferenceModelConfig: repetition_penalty: Optional[float] = None # used for testing very long response generation, do not set it unless you know what you are doing ignore_eos: bool = False + # for multi-modal models + enable_multimodal: bool = False # override chat template in model chat_template: Optional[str] = None @@ -559,6 +561,7 @@ class InferenceModelConfig: # For OpenAI API enable_openai_api: bool = False enable_log_requests: bool = False # whether to enable request logging in vLLM API server + base_port: Optional[int] = None # For tool calls in OpenAI API enable_auto_tool_choice: bool = False @@ -572,6 +575,7 @@ class InferenceModelConfig: # ! 
DO NOT SET bundle_indices: str = "" + engine_id: int = 0 ray_namespace: Optional[str] = None cuda_visible_devices: Optional[str] = None @@ -751,7 +755,7 @@ class ExplorerConfig: class TrainerConfig: name: str = TRAINER_NAME trainer_type: str = "verl" - trainer_strategy: str = "fsdp" # "fsdp", "fsdp2" or "megatron" + trainer_strategy: str = "fsdp2" # "fsdp", "fsdp2" or "megatron" save_interval: int = 0 enable_preview: bool = True # enable rollout preview in wandb total_steps: Optional[ diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index d5fb9d80513..ea9fddf398f 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -1,6 +1,7 @@ import asyncio import os from collections import defaultdict +from copy import deepcopy from typing import Dict, List, Tuple import ray @@ -65,6 +66,10 @@ def create_explorer_models( from trinity.common.models.vllm_model import vLLMRolloutModel engine_cls = vLLMRolloutModel + elif config.explorer.rollout_model.engine_type == "sglang": + from trinity.common.models.sglang_model import SGLangRolloutModel + + engine_cls = SGLangRolloutModel elif config.explorer.rollout_model.engine_type == "external": rollout_engines = create_external_models( config=config.explorer.rollout_model, @@ -114,11 +119,15 @@ def create_explorer_models( if config.mode == "colocate": rollout_engine = ( - ray.remote(vLLMRolloutModel) + ray.remote(engine_cls) .options( name=f"{config.explorer.name}_rollout_model_0", num_cpus=0, - num_gpus=0, + num_gpus=( + config.explorer.rollout_model.tensor_parallel_size + if config.explorer.rollout_model.engine_type == "sglang" + else 0 + ), namespace=config.ray_namespace, ) .remote( @@ -127,6 +136,28 @@ def create_explorer_models( ) return [rollout_engine], [] + if config.explorer.rollout_model.engine_type == "sglang": + rollout_engines = create_sglang_explorer_models( + config=config.explorer.rollout_model, + actor_name=f"{config.explorer.name}_rollout_model", + ) + + if config.explorer.rollout_model.enable_history: + logger.info( + "Model History recording is enabled. Please periodically extract " + "history via `extract_experience_from_history` to avoid out-of-memory issues." 
+ ) + + auxiliary_engines = [] + for i, model_config in enumerate(config.explorer.auxiliary_models): + engines = create_sglang_explorer_models( + config=model_config, + actor_name=f"{config.explorer.name}_auxiliary_model_{model_config.name or i}", + ) + auxiliary_engines.append(engines) + + return rollout_engines, auxiliary_engines + num_gpus = ( config.explorer.rollout_model.engine_num * config.explorer.rollout_model.tensor_parallel_size @@ -141,7 +172,12 @@ def create_explorer_models( allocator = _BundleAllocator(num_gpus=num_gpus) # create rollout models - rollout_engines = create_vllm_inference_models( + model_factory = ( + create_sglang_inference_models + if config.explorer.rollout_model.engine_type == "sglang" + else create_vllm_inference_models + ) + rollout_engines = model_factory( config=config.explorer.rollout_model, allocator=allocator, actor_name=f"{config.explorer.name}_rollout_model", @@ -156,7 +192,7 @@ def create_explorer_models( # create auxiliary models auxiliary_engines = [] for i, model_config in enumerate(config.explorer.auxiliary_models): - engines = create_vllm_inference_models( + engines = model_factory( config=model_config, allocator=allocator, actor_name=f"{config.explorer.name}_auxiliary_model_{model_config.name or i}", @@ -176,14 +212,49 @@ def create_vllm_inference_models( models = [] for i in range(config.engine_num): bundles_for_engine = allocator.allocate(config.tensor_parallel_size) - config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) + model_config = deepcopy(config) + model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) + model_config.engine_id = i models.append( ray.remote(vLLMRolloutModel) .options( name=f"{actor_name}_{i}", num_cpus=0, - num_gpus=0 if config.tensor_parallel_size > 1 else 1, - namespace=config.ray_namespace, + num_gpus=0 if model_config.tensor_parallel_size > 1 else 1, + namespace=model_config.ray_namespace, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=allocator.pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundles_for_engine[0], + ), + ) + .remote( + config=model_config, + ) + ) + return models + + +def create_sglang_inference_models( + config: InferenceModelConfig, + allocator: _BundleAllocator, + actor_name: str, +) -> List: + from trinity.common.models.sglang_model import SGLangRolloutModel + + models = [] + for i in range(config.engine_num): + bundles_for_engine = allocator.allocate(config.tensor_parallel_size) + model_config = deepcopy(config) + model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) + model_config.engine_id = i + models.append( + ray.remote(SGLangRolloutModel) + .options( + name=f"{actor_name}_{i}", + num_cpus=0, + num_gpus=0 if model_config.tensor_parallel_size > 1 else 1, + namespace=model_config.ray_namespace, scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=allocator.pg, placement_group_capture_child_tasks=True, @@ -191,7 +262,42 @@ def create_vllm_inference_models( ), ) .remote( - config=config, + config=model_config, + ) + ) + return models + + +def create_sglang_explorer_models( + config: InferenceModelConfig, + actor_name: str, +) -> List: + from trinity.common.models.sglang_model import SGLangRolloutModel + + models = [] + for i in range(config.engine_num): + model_config = deepcopy(config) + model_config.engine_id = i + engine_pg = placement_group( + [{"GPU": model_config.tensor_parallel_size}], + strategy="PACK", + ) + ray.get(engine_pg.ready()) 
+ models.append( + ray.remote(SGLangRolloutModel) + .options( + name=f"{actor_name}_{i}", + num_cpus=0, + num_gpus=model_config.tensor_parallel_size, + namespace=model_config.ray_namespace, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=engine_pg, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=0, + ), + ) + .remote( + config=model_config, ) ) return models @@ -214,17 +320,19 @@ def create_external_models( models = [] for i in range(config.engine_num): + model_config = deepcopy(config) + model_config.engine_id = i models.append( ray.remote(ExternalModel) .options( name=f"{actor_name}_{i}", num_cpus=0, num_gpus=0, - namespace=config.ray_namespace, + namespace=model_config.ray_namespace, runtime_env={"env_vars": env_vars}, ) .remote( - config=config, + config=model_config, ) ) return models diff --git a/trinity/common/models/external_model.py b/trinity/common/models/external_model.py index 2d7e4ce1499..697a40878f7 100644 --- a/trinity/common/models/external_model.py +++ b/trinity/common/models/external_model.py @@ -4,6 +4,7 @@ import torch from trinity.common.config import InferenceModelConfig +from trinity.common.constants import SyncMethod from trinity.common.experience import Experience from trinity.common.models.model import InferenceModel @@ -153,7 +154,9 @@ async def convert_messages_to_experience( exp.tools = tools return exp - async def sync_model(self, model_version: int) -> int: + async def sync_model( + self, model_version: int, sync_method: SyncMethod, timeout: float = 1200 + ) -> int: # for self.model_version = model_version return self.model_version diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 8150893921d..37b9ebf4e48 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """Base Model Class""" + import asyncio import copy import socket @@ -12,7 +13,7 @@ from torch import Tensor from trinity.common.config import InferenceModelConfig -from trinity.common.constants import RunningStatus +from trinity.common.constants import RunningStatus, SyncMethod from trinity.common.experience import Experience from trinity.common.models.utils import get_action_mask_method from trinity.utils.log import get_logger @@ -54,7 +55,9 @@ async def prepare(self) -> None: pass @abstractmethod - async def sync_model(self, model_version: int) -> int: + async def sync_model( + self, model_version: int, method: SyncMethod, timeout: float = 1200 + ) -> int: """Sync the model with the latest model_version.""" @abstractmethod @@ -64,6 +67,18 @@ def get_model_version(self) -> int: def get_available_address(self) -> Tuple[str, int]: """Get the address of the actor.""" address = ray.util.get_node_ip_address() + if self.config.base_port is not None: + configured_port = self.config.base_port + self.config.engine_id + with socket.socket() as s: + try: + s.bind(("", configured_port)) + return address, configured_port + except OSError: + self.logger.warning( + "Configured port %s is unavailable for engine %s; falling back to an ephemeral port.", + configured_port, + self.config.engine_id, + ) with socket.socket() as s: s.bind(("", 0)) port = s.getsockname()[1] @@ -589,9 +604,9 @@ async def get_current_load(self) -> int: data = response.json() return data["server_load"] - async def sync_model_weights(self, model_version: int) -> None: + async def sync_model_weights(self, model_version: int, method: SyncMethod) -> None: """Sync the model weights""" - await 
self.model.sync_model.remote(model_version) + await self.model.sync_model.remote(model_version, method) def extract_experience_from_history(self, clear_history: bool = True) -> List[Experience]: """Extract experiences from the history.""" diff --git a/trinity/common/models/sglang_model.py b/trinity/common/models/sglang_model.py new file mode 100644 index 00000000000..8c22086eb81 --- /dev/null +++ b/trinity/common/models/sglang_model.py @@ -0,0 +1,515 @@ +from __future__ import annotations + +import asyncio +import os +import traceback +from logging import Logger +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Sequence, Tuple + +import httpx +import torch +from transformers import AutoTokenizer + +from trinity.common.config import InferenceModelConfig +from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME, SyncMethod +from trinity.common.experience import Experience +from trinity.common.models.model import BaseInferenceModel +from trinity.manager.synchronizer import Synchronizer + +if TYPE_CHECKING: + from sglang.srt.server_args import ServerArgs + + +class SGLangClient: + """A simple http client to interact with the SGLang API server.""" + + def __init__(self, server_url: str, api_key: Optional[str], logger: Logger): + self.server_url = server_url + self.api_key = api_key + self.logger = logger + + async def _server_call( + self, + method: Literal["GET", "POST"], + endpoint: str, + payload: Optional[dict] = None, + timeout: float = 60, + ) -> dict: + async with httpx.AsyncClient( + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {self.api_key}" if self.api_key else "", + } + ) as client: + url = f"{self.server_url}{endpoint}" + self.logger.debug( + f"Making {method} request to SGLang API server at {url} with payload: {payload}" + ) + try: + if method == "GET": + response = await client.get(url, timeout=timeout) + elif method == "POST": + response = await client.post(url, json=payload or {}, timeout=timeout) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + response.raise_for_status() + return response.json() + except Exception: + self.logger.error( + f"Error during {method} request to SGLang API server at {url}:\n{traceback.format_exc()}" + ) + return {"error": traceback.format_exc()} + + async def health_check(self) -> bool: + try: + async with httpx.AsyncClient( + headers={ + "Content-Type": "application/json; charset=utf-8", + "Authorization": f"Bearer {self.api_key}" if self.api_key else "", + } + ) as client: + response = await client.get(f"{self.server_url}/health", timeout=5) + return response.status_code == 200 + except Exception as e: + self.logger.debug(f"SGLang API server health check failed: {e}") + return False + + async def init_weights_update_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + timeout: int = 1200, + ) -> bool: + payload = { + "master_address": master_address, + "master_port": master_port, + "rank_offset": rank_offset, + "world_size": world_size, + "group_name": group_name, + "backend": backend, + } + response = await self._server_call( + "POST", "/init_weights_update_group", payload, timeout=timeout + ) + success = response.get("success", False) + if not success: + self.logger.error( + f"Failed to initialize weights update group in SGLang API server: {response.get('message')}" + ) + return success + + async def destroy_weights_update_group(self, group_name: str) -> bool: + 
payload = {"group_name": group_name} + response = await self._server_call("POST", "/destroy_weights_update_group", payload) + success = response.get("success", False) + if not success: + self.logger.error( + f"Failed to destroy weights update group in SGLang API server: {response.get('message')}" + ) + return success + + async def update_weights_from_distributed( + self, + state_dict_meta_list: List[Tuple[str, str, Tuple]], + group_name: str, + flash_cache: bool = True, + abort_all_requests: bool = True, + weight_version: Optional[str] = None, + timeout: float = 300, + ) -> bool: + names = [meta[0] for meta in state_dict_meta_list] + dtypes = [meta[1] for meta in state_dict_meta_list] + shapes = [meta[2] for meta in state_dict_meta_list] + payload = { + "names": names, + "dtypes": dtypes, + "shapes": shapes, + "group_name": group_name, + "flash_cache": flash_cache, + "abort_all_requests": abort_all_requests, + "weight_version": weight_version, + } + response = await self._server_call( + "POST", "/update_weights_from_distributed", payload, timeout=timeout + ) + success = response.get("success", False) + self.logger.info( + "Response from update_weights_from_distributed: %s", response.get("message", "") + ) + if not success: + self.logger.error( + f"Failed to update weights from distributed in SGLang API server: {response.get('message')}" + ) + return success + + async def update_weights_from_disk( + self, + model_path: str, + abort_all_requests: bool = True, + weight_version: Optional[str] = None, + is_async: bool = False, + timeout: float = 300, + ) -> bool: + payload = { + "model_path": model_path, + "abort_all_requests": abort_all_requests, + "weight_version": weight_version, + "is_async": is_async, + "torch_empty_cache": True, + } + response = await self._server_call( + "POST", "/update_weights_from_disk", payload, timeout=timeout + ) + success = response.get("success", False) + if not success: + self.logger.error( + f"Failed to update weights from disk in SGLang API server: {response.get('message')}" + ) + return success + + async def generate(self, input_ids: List[int], **kwargs) -> Sequence[dict[str, Any]]: + sampling_params = { + "n": kwargs.get("n", 1), + "temperature": kwargs.get("temperature"), + "top_p": kwargs.get("top_p"), + "top_k": kwargs.get("top_k"), + "max_new_tokens": kwargs.get("max_tokens"), + "min_new_tokens": kwargs.get("min_tokens"), + "repetition_penalty": kwargs.get("repetition_penalty"), + "stop": kwargs.get("stop"), + "ignore_eos": kwargs.get("ignore_eos"), + } + sampling_params = {k: v for k, v in sampling_params.items() if v is not None} + + payload: dict[str, Any] = { + "sampling_params": sampling_params, + "return_logprob": kwargs.get("return_logprob", False), + "top_logprobs_num": kwargs.get("top_logprobs_num", 0), + "return_text_in_logprobs": False, + "input_ids": input_ids, + } + + response = await self._server_call( + "POST", + "/generate", + payload, + timeout=kwargs.get("timeout", 300), + ) + if isinstance(response, dict) and response.get("error"): + raise RuntimeError(f"Failed to generate with SGLang: {response['error']}") + if isinstance(response, dict): + return [response] + if isinstance(response, list): + return response + raise TypeError(f"Unexpected SGLang generate response type: {type(response)!r}") + + +class SGLangRolloutModel(BaseInferenceModel): + """Wrapper around the SGLang engine to handle async requests. + + Args: + config (Config): The config. 
+ """ + + def __init__( + self, + config: InferenceModelConfig, + ) -> None: + super().__init__(config) + if config.cuda_visible_devices: + os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_visible_devices + if not self.config.enable_openai_api: + self.logger.warning("SGLangRolloutModel requires OpenAI API to be enabled.") + self.config.enable_openai_api = True + os.environ["SGLANG_GRPC_PORT"] = "12345" # a dummy port not actually used + os.environ["SGLANG_ENABLE_GRPC"] = "0" + self.api_server_host: Optional[str] = None + self.api_server_port: Optional[int] = None + self.api_server: Optional[asyncio.Task[None]] = None + self.api_client: Optional[SGLangClient] = None + self.synchronizer = None + self.state_dict_meta: List[Tuple[str, str, Tuple]] = [] + self.model_version = 0 + self.server_args: Optional[ServerArgs] = None + self._prepared = False + self._has_weight_update_group = False + self.async_lock = asyncio.Lock() + self.group_name = ROLLOUT_WEIGHT_SYNC_GROUP_NAME + + async def init_process_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + explorer_name: str, + backend: str = "nccl", + timeout: int = 1200, + state_dict_meta: Optional[List[Tuple[str, str, Tuple]]] = None, + ): + assert ( + self.api_client is not None + ), "API client must be initialized before calling init_process_group" + if not self.synchronizer: + self.synchronizer = Synchronizer.get_actor(namespace=self.config.ray_namespace) + self.logger.info( + "SGLang starting init_process_group:\n" + f" > address={master_address}:{master_port}\n" + f" > rank_offset={rank_offset}\n" + f" > world_size={world_size}" + ) + self.state_dict_meta = state_dict_meta or [] + self.group_name = group_name + resp = await self.api_client.init_weights_update_group( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + timeout=timeout, + ) + self.logger.info("SGLang init_process_group finished.") + self._has_weight_update_group = resp + return resp + + async def _initialize_tokenizer(self) -> None: + if self.tokenizer is not None: + return + self.tokenizer = AutoTokenizer.from_pretrained( + self.config.model_path, + trust_remote_code=self.config.trust_remote_code, + ) + self.tokenizer.truncation_side = "left" + + async def prepare(self) -> None: + async with self.async_lock: + if self._prepared: + return + await self.run_api_server() + self._prepared = True + + @staticmethod + def _extract_output_logprobs(meta_info: dict[str, Any]) -> List[float]: + output_token_logprobs = meta_info.get("output_token_logprobs") or [] + return [float(logprob) for logprob, *_ in output_token_logprobs] + + def _normalize_chat_messages(self, messages: List[dict]) -> List[dict]: + normalized_messages = [] + for message in messages: + normalized_message = dict(message) + content = normalized_message.get("content") + if isinstance(content, list): + text_parts = [item["text"] for item in content if item.get("type") == "text"] + normalized_message["content"] = "".join(text_parts) + normalized_messages.append(normalized_message) + return normalized_messages + + async def generate(self, prompt: str, lora_request=None, **kwargs) -> Sequence[Experience]: + assert self.api_client is not None, "API client must be initialized before calling generate" + if self.tokenizer is None: + await self._initialize_tokenizer() + + returned_seq, is_valid = self._handle_prompt_truncation(prompt, **kwargs) + if not 
is_valid: + return returned_seq + prompt_token_ids = list(returned_seq) + + logprobs = kwargs.get("logprobs", self.config.logprobs) + return_logprob = logprobs is not None and logprobs is not False + responses = await self.api_client.generate( + input_ids=prompt_token_ids, + n=kwargs.get("n", 1), + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + top_k=kwargs.get("top_k", self.config.top_k), + max_tokens=kwargs.get("max_tokens", self.config.max_response_tokens), + min_tokens=kwargs.get("min_tokens", self.config.min_response_tokens), + repetition_penalty=kwargs.get("repetition_penalty", self.config.repetition_penalty), + stop=kwargs.get("stop"), + ignore_eos=kwargs.get("ignore_eos", self.config.ignore_eos), + return_logprob=return_logprob, + timeout=kwargs.get("timeout", 300), + ) + + prompt_text = self.tokenizer.decode(prompt_token_ids) + experiences = [] + for response in responses: + response_token_ids = response.get("output_ids") or [] + response_text = response.get("text") or "" + if not response_token_ids and response_text: + response_token_ids = self.tokenizer.encode(response_text, add_special_tokens=False) + + meta_info = response.get("meta_info") or {} + prompt_length = int(meta_info.get("prompt_tokens") or len(prompt_token_ids)) + if return_logprob: + response_logprobs = torch.tensor( + self._extract_output_logprobs(meta_info), + dtype=torch.float32, + ) + else: + response_logprobs = torch.tensor([], dtype=torch.float32) + + experiences.append( + Experience( + tokens=torch.tensor(prompt_token_ids + response_token_ids, dtype=torch.int32), + logprobs=response_logprobs, + prompt_length=prompt_length, + prompt_text=prompt_text, + response_text=response_text, + ) + ) + return experiences + + async def chat(self, messages: List[dict], lora_request=None, **kwargs) -> Sequence[Experience]: + if self.tokenizer is None: + await self._initialize_tokenizer() + + normalized_messages = self._normalize_chat_messages(messages) + prompt = self.apply_chat_template(self.tokenizer, normalized_messages) + return await self.generate(prompt=prompt, lora_request=lora_request, **kwargs) + + async def logprobs(self, token_ids: List[int], **kwargs) -> torch.Tensor: + raise NotImplementedError("SGLangRolloutModel does not support logprobs.") + + async def convert_messages_to_experience( + self, + messages: List[dict], + tools=None, + temperature: Optional[float] = None, + ) -> Experience: + del messages, tools, temperature + raise NotImplementedError( + "SGLangRolloutModel does not support convert_messages_to_experience." 
+ ) + + def _build_server_args(self, host: str, port: int): + from sglang.srt.server_args import ServerArgs + + server_args_kwargs = { + "model_path": self.config.model_path, + "host": host, + "port": port, + "tp_size": self.config.tensor_parallel_size, + "dtype": self.config.dtype, + "mem_fraction_static": self.config.gpu_memory_utilization, + "served_model_name": self.config.name or self.config.model_path, + "trust_remote_code": self.config.trust_remote_code, + "context_length": self.config.max_model_len, + "enable_multimodal": self.config.enable_multimodal, + "skip_server_warmup": True, + "disable_piecewise_cuda_graph": True, + "api_key": "EMPTY", + "device": "cuda", + } + # if self.config.chat_template: + # server_args_kwargs["chat_template"] = self.config.chat_template + # TODO: fill in nnodes and node_rank for distributed setups + return ServerArgs(**server_args_kwargs) + + def _get_api_server_exit_reason(self) -> Optional[str]: + if self.api_server is None or not self.api_server.done(): + return None + if self.api_server.cancelled(): + return "cancelled" + exc = self.api_server.exception() + return "unknown error" if exc is None else repr(exc) + + async def _wait_until_server_ready(self, server_url: str) -> None: + max_retries = 100 + assert self.server_args is not None and self.api_client is not None + for _ in range(max_retries): + reason = self._get_api_server_exit_reason() + if reason is not None: + raise RuntimeError(f"SGLang API server exited before becoming ready: {reason}.") + if await self.api_client.health_check(): + self.logger.info(f"SGLang API server at {server_url} is ready.") + return + self.logger.debug(f"SGLang API server at {server_url} not ready yet, retrying...") + await asyncio.sleep(2) + self.logger.error( + f"SGLang API server at {server_url} not ready after {max_retries} attempts." + ) + await self.shutdown() + raise RuntimeError( + f"SGLang API server at {server_url} not ready after {max_retries} attempts." 
+ ) + + async def run_api_server(self) -> bool: + from trinity.common.models.sglang_patch import get_api_server + + if self.api_server_host is None or self.api_server_port is None: + self.api_server_host, self.api_server_port = self.get_available_address() + self.server_args = self._build_server_args( + host=self.api_server_host, + port=self.api_server_port, + ) + self.api_server = get_api_server(self.server_args, logger=self.logger) + self.api_client = SGLangClient( + server_url=f"http://{self.api_server_host}:{self.api_server_port}", + api_key=self.server_args.api_key, + logger=self.logger, + ) + await self._wait_until_server_ready(self.server_args.url()) + return True + + async def shutdown(self) -> None: + if self.api_server is not None: + if self._has_weight_update_group: + await self.api_client.destroy_weights_update_group(group_name=self.group_name) + self.api_server.cancel() + try: + await self.api_server + except asyncio.CancelledError: + pass + reason = self._get_api_server_exit_reason() + if reason not in {None, "cancelled"}: + self.logger.warning("Embedded SGLang HTTP server exited with error: %s", reason) + self.api_server = None + self.api_client = None + + async def sync_model( + self, model_version: int, method: SyncMethod, timeout: float = 1200 + ) -> int: + assert ( + self.api_client is not None + ), "API client must be initialized before calling sync_model" + assert ( + self.synchronizer is not None + ), "Synchronizer must be initialized before calling sync_model" + self.logger.info(f"Synchronizing model to version {model_version} using method {method}...") + if method == SyncMethod.NCCL: + assert self.state_dict_meta, "state_dict_meta must be initialized for NCCL sync" + await self.api_client.update_weights_from_distributed( + state_dict_meta_list=self.state_dict_meta, + group_name=self.group_name, + weight_version=str(model_version), + timeout=timeout, + ) + self.model_version = model_version + elif method == SyncMethod.CHECKPOINT: + model_path = await self.synchronizer.get_latest_model_path.remote() + if model_path is not None: + await self.api_client.update_weights_from_disk( + model_path=model_path, + weight_version=str(model_version), + timeout=timeout, + ) + self.model_version = model_version + else: + raise ValueError(f"Unsupported sync method for SGLang: {method}") + self.logger.info("Synchronization finished.") + return model_version + + def get_model_version(self) -> int: + return self.model_version + + def get_api_server_url(self) -> Optional[str]: + if not self._prepared: + raise RuntimeError("Model is not prepared. 
Please call `prepare()` first.") + return f"http://{self.api_server_host}:{self.api_server_port}" diff --git a/trinity/common/models/sglang_patch/__init__.py b/trinity/common/models/sglang_patch/__init__.py new file mode 100644 index 00000000000..7f2d6fdb07f --- /dev/null +++ b/trinity/common/models/sglang_patch/__init__.py @@ -0,0 +1,5 @@ +from trinity.common.models.sglang_patch.api_patch import get_api_server + +__all__ = [ + "get_api_server", +] diff --git a/trinity/common/models/sglang_patch/api_patch.py b/trinity/common/models/sglang_patch/api_patch.py new file mode 100644 index 00000000000..87050893e2c --- /dev/null +++ b/trinity/common/models/sglang_patch/api_patch.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import asyncio +from logging import Logger +from typing import Callable, Dict, List, Optional + +import uvicorn +from sglang.srt.entrypoints.engine import Engine +from sglang.srt.entrypoints.http_server import ( + _execute_server_warmup, + _GlobalState, + add_prometheus_track_response_middleware, + app, + app_has_admin_force_endpoints, + envs, + set_global_state, + set_uvicorn_logging_configs, +) +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import kill_process_tree +from sglang.srt.utils.watchdog import SubprocessWatchdog + + +def _setup_and_run_http_server( # noqa: C901 + server_args: ServerArgs, + tokenizer_manager, + template_manager, + port_args, + scheduler_infos: List[Dict], + subprocess_watchdog: Optional[SubprocessWatchdog], + logger: Logger, + server: uvicorn.Server, + execute_warmup_func: Callable = _execute_server_warmup, + launch_callback: Optional[Callable[[], None]] = None, +) -> "asyncio.Task[None]": + set_global_state( + _GlobalState( + tokenizer_manager=tokenizer_manager, + template_manager=template_manager, + scheduler_info=scheduler_infos[0], + ) + ) + del port_args + + if tokenizer_manager is not None: + tokenizer_manager._subprocess_watchdog = subprocess_watchdog + + if server_args.enable_metrics: + add_prometheus_track_response_middleware(app) + + setattr(app, "is_single_tokenizer_mode", True) + setattr(app, "server_args", server_args) + setattr( + app, + "warmup_thread_kwargs", + { + "server_args": server_args, + "launch_callback": launch_callback, + "execute_warmup_func": execute_warmup_func, + }, + ) + + if server_args.api_key or server_args.admin_api_key or app_has_admin_force_endpoints(app): + from sglang.srt.utils.auth import add_api_key_middleware + + add_api_key_middleware( + app, + api_key=server_args.api_key, + admin_api_key=server_args.admin_api_key, + ) + + set_uvicorn_logging_configs(server_args) + if server_args.ssl_certfile: + logger.info( + "SSL enabled: certfile=%s, keyfile=%s", + server_args.ssl_certfile, + server_args.ssl_keyfile, + ) + + cleaned_up = False + + async def cleanup() -> None: + nonlocal cleaned_up + if cleaned_up: + return + cleaned_up = True + + if subprocess_watchdog is not None: + subprocess_watchdog.stop() + for process in getattr(subprocess_watchdog, "_processes", []): + if process is None or process.pid is None: + continue + try: + kill_process_tree(process.pid, wait_timeout=60) + except Exception as exc: + logger.warning( + "Failed to terminate SGLang child process %s: %s", + process.pid, + exc, + ) + if tokenizer_manager is not None and hasattr(tokenizer_manager, "_subprocess_watchdog"): + tokenizer_manager._subprocess_watchdog = None + + import sglang.srt.entrypoints.http_server as http_server_module + + http_server_module._global_state = None + + task = 
asyncio.create_task(server.serve()) + + def _cleanup_task(task: "asyncio.Task[None]") -> None: + if task.cancelled(): + asyncio.create_task(cleanup()) + return + if task.exception() is not None: + logger.warning("Embedded SGLang HTTP server exited with error: %s", task.exception()) + asyncio.create_task(cleanup()) + + task.add_done_callback(_cleanup_task) + return task + + +def get_api_server( + server_args: ServerArgs, + logger: Logger, +) -> "asyncio.Task[None]": + if server_args.enable_http2: + raise NotImplementedError("Embedded SGLang server does not support HTTP/2 yet.") + if server_args.tokenizer_worker_num != 1: + raise NotImplementedError( + "Embedded SGLang server currently supports tokenizer_worker_num == 1 only." + ) + if server_args.enable_ssl_refresh: + raise NotImplementedError("Embedded SGLang server does not support SSL refresh yet.") + + ( + tokenizer_manager, + template_manager, + port_args, + scheduler_init_result, + subprocess_watchdog, + ) = Engine._launch_subprocesses( + server_args=server_args, + init_tokenizer_manager_func=Engine.init_tokenizer_manager_func, + run_scheduler_process_func=Engine.run_scheduler_process_func, + run_detokenizer_process_func=Engine.run_detokenizer_process_func, + ) + + config = uvicorn.Config( + app, + host=server_args.host, + port=server_args.port, + root_path=server_args.fastapi_root_path, + log_level=server_args.log_level_http or server_args.log_level, + timeout_keep_alive=envs.SGLANG_TIMEOUT_KEEP_ALIVE.get(), + loop="uvloop", + ssl_keyfile=server_args.ssl_keyfile, + ssl_certfile=server_args.ssl_certfile, + ssl_ca_certs=server_args.ssl_ca_certs, + ssl_keyfile_password=server_args.ssl_keyfile_password, + ) + server = uvicorn.Server(config) + return _setup_and_run_http_server( + server_args, + tokenizer_manager, + template_manager, + port_args, + scheduler_init_result.scheduler_infos, + subprocess_watchdog, + logger, + server, + ) diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py index 1629e1e9f95..d5ded3a18a4 100644 --- a/trinity/common/models/tinker_model.py +++ b/trinity/common/models/tinker_model.py @@ -9,6 +9,7 @@ from torch import Tensor from trinity.common.config import InferenceModelConfig +from trinity.common.constants import SyncMethod from trinity.common.experience import Experience from trinity.common.models.model import BaseInferenceModel from trinity.manager.synchronizer import Synchronizer @@ -149,7 +150,9 @@ async def prepare(self) -> None: ) await self._initialize_tokenizer() - async def sync_model(self, model_version: int) -> int: + async def sync_model( + self, model_version: int, sync_method: SyncMethod, timeout: float = 1200 + ) -> int: self.model_version = model_version remote_sampler_path, _ = await self.synchronizer.get_model_state_dict.remote() self.model = await self.service_client.create_sampling_client_async( diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index f3537638d63..01de7052613 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -10,6 +10,7 @@ from transformers import AutoProcessor from trinity.common.config import InferenceModelConfig +from trinity.common.constants import SyncMethod from trinity.common.experience import Experience from trinity.common.models.mm_utils import ( build_mm_input_for_training, @@ -56,13 +57,6 @@ def __init__( os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0" if self.config.enable_runtime_lora_updating: os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "1" - 
if not config.enforce_eager: - # To avoid torch compile conflicts when multiple model are started simultaneously. - # remove this when the following PR is released: - # https://github.com/vllm-project/vllm/pull/27616 - os.environ["VLLM_CACHE_ROOT"] = os.path.expanduser( - f"~/.cache/vllm/{config.bundle_indices}" - ) self.tokenization_kwargs = { "truncate_prompt_tokens": config.max_prompt_tokens if config.enable_prompt_truncation @@ -487,7 +481,9 @@ async def _collective_rpc( method, timeout, args, kwargs ) - async def sync_model(self, model_version: int) -> int: + async def sync_model( + self, model_version: int, sync_method: SyncMethod, timeout: float = 1200 + ) -> int: """Sync model weights to vLLM.""" if self.enable_lora: # Revise the lora path; no need to sync weights manually. @@ -504,8 +500,10 @@ async def sync_model(self, model_version: int) -> int: self.model_version = model_version return model_version await self.async_llm.reset_prefix_cache() - await self._collective_rpc("update_weight") - self.logger.info("Sync model weights to vLLM successfully.") + await self._collective_rpc("update_weight", timeout=timeout) + self.logger.info( + f"Synchronized model to version {model_version} using method {sync_method}." + ) self.model_version = model_version return model_version @@ -519,7 +517,7 @@ async def init_process_group( explorer_name: str, backend: str = "nccl", timeout: int = 1200, - state_dict_meta: dict = None, + state_dict_meta: List = None, ): return await self._collective_rpc( "init_process_group", diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index a6239c4f044..527690a64ed 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -54,9 +54,7 @@ def init_process_group( timeout=timeout, world_size=world_size, rank=self._weight_update_rank, - device_id=self.device, ) - torch.distributed.barrier(group=self._model_update_group) self.logger.info("vLLM init_process_group finished.") self._explorer_name = explorer_name self._namespace = namespace @@ -92,6 +90,5 @@ def update_weight(self): weight = weight.type(self.model_config.dtype) self.model_runner.model.load_weights(weights=[(name, weight)]) del weight - torch.distributed.barrier(group=self._model_update_group) torch.cuda.synchronize() torch.cuda.empty_cache() diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index d24cf41a8ad..25853322fd9 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -328,7 +328,7 @@ async def run_async(self) -> List[Experience]: # TODO: Optimize the generate function messages = self.format_messages() - self.logger.debug("start chat") + self.logger.info("start chat") responses = await self.model.chat_async(messages, **self.rollout_args) for i, response in enumerate(responses): reward_dict = self.reward_fn( # type: ignore [misc] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 30466b67387..40b4efe80a8 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- """The explorer module""" + from __future__ import annotations import asyncio @@ -144,7 +145,16 @@ async def setup_model_level_weight_sync_group(self): async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int: self.logger.info(f"Start to update model weights from checkpoint at step {step_num}.") step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num) 
- await asyncio.gather(*[model.sync_model.remote(step_num) for model in self.models]) + await asyncio.gather( + *[ + model.sync_model.remote( + step_num, + self.config.synchronizer.sync_method, + timeout=self.config.synchronizer.sync_timeout, + ) + for model in self.models + ] + ) self.logger.info(f"Model weights updated to checkpoint at step {step_num}.") return step_num # type: ignore @@ -157,7 +167,14 @@ async def _pull_latest_weights(self): if self.model_version != -1 or new_version > 0: self.logger.info(f"New model weights version: {new_version}") await asyncio.gather( - *[model.sync_model.remote(new_version) for model in self.models] + *[ + model.sync_model.remote( + new_version, + self.config.synchronizer.sync_method, + timeout=self.config.synchronizer.sync_timeout, + ) + for model in self.models + ] ) self.model_version = new_version else: @@ -174,7 +191,14 @@ async def _nccl_weights_update(self): return self.model_version = new_version await asyncio.gather( - *[model.sync_model.remote(self.model_version) for model in self.models] + *[ + model.sync_model.remote( + self.model_version, + self.config.synchronizer.sync_method, + timeout=self.config.synchronizer.sync_timeout, + ) + for model in self.models + ] ) async def prepare(self) -> None: diff --git a/trinity/explorer/proxy/service.py b/trinity/explorer/proxy/service.py index d8f83ca3b49..eeb5246cb40 100644 --- a/trinity/explorer/proxy/service.py +++ b/trinity/explorer/proxy/service.py @@ -5,7 +5,7 @@ import torch -from trinity.common.constants import RunningStatus +from trinity.common.constants import RunningStatus, SyncMethod from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.explorer.explorer import Explorer @@ -106,7 +106,7 @@ async def _sync_model_weights(self, index: int) -> None: ) latest_version = self.latest_model_version # capture the latest version # perform synchronization - await self.models[index].sync_model_weights(latest_version) + await self.models[index].sync_model_weights(latest_version, method=SyncMethod.CHECKPOINT) self.logger.info(f"Model {index} synchronized to version {latest_version}.") self.model_version_map[index] = await self.models[index].model_version_async self.models[index].status = RunningStatus.RUNNING diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 8682e74c767..3c72364674c 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -36,6 +36,7 @@ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle): self._ready_condition = asyncio.Condition() self.model_state_dict = None self.model_version = 0 + self.model_path = None self.checkpoint_shard_counter = defaultdict(lambda: 0) self.ref_count = 0 self._modules = {module_ref} @@ -263,6 +264,10 @@ async def set_model_state_dict( async with self._ready_condition: self.model_state_dict = model_state_dict self.model_version = trainer_step + # TODO: check model_path for different trainer types + self.model_path = os.path.join( + self.config.checkpoint_job_dir, f"global_step_{trainer_step}", "actor" + ) self.logger.info(f"Set model state dict version to {trainer_step}.") self._ready_condition.notify_all() @@ -351,6 +356,16 @@ async def get_latest_model_version(self) -> int: async with self._ready_condition: return self.model_version + async def get_latest_model_path(self) -> Optional[str]: + """ + Get the latest model path available in the synchronizer. + + Returns: + The current model path. 
+        """
+        async with self._ready_condition:
+            return self.model_path
+
     async def ready_to_nccl_sync(self, module: str, trainer_step: int) -> Union[int, None]:
         """
         Prepare for NCCL-based synchronization between modules.
diff --git a/trinity/perf/report_metrics.py b/trinity/perf/report_metrics.py
new file mode 100644
index 00000000000..1485560513b
--- /dev/null
+++ b/trinity/perf/report_metrics.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from typing import Any, Optional
+
+EXPERIENCE_COUNT_METRIC_KEY = "experience_pipeline/experience_count"
+PROMPT_LENGTH_MEAN_METRIC_KEY = "rollout/prompt_length/mean"
+RESPONSE_LENGTH_MEAN_METRIC_KEY = "rollout/response_length/mean"
+API_CALL_PROMPT_TOKENS_PER_SECOND_MEAN_METRIC_KEY = "rollout/api_call_prompt_tokens_per_second/mean"
+API_CALL_RESPONSE_TOKENS_PER_SECOND_MEAN_METRIC_KEY = (
+    "rollout/api_call_response_tokens_per_second/mean"
+)
+
+
+def compute_global_token_throughput_metrics(
+    execution_time_sec: Optional[float], step_metrics: list[dict[str, Any]]
+) -> dict[str, float | None]:
+    api_call_prompt_tokens_per_second_values = [
+        float(step_metric[API_CALL_PROMPT_TOKENS_PER_SECOND_MEAN_METRIC_KEY])
+        for step_metric in step_metrics
+        if step_metric.get(API_CALL_PROMPT_TOKENS_PER_SECOND_MEAN_METRIC_KEY) is not None
+    ]
+    api_call_response_tokens_per_second_values = [
+        float(step_metric[API_CALL_RESPONSE_TOKENS_PER_SECOND_MEAN_METRIC_KEY])
+        for step_metric in step_metrics
+        if step_metric.get(API_CALL_RESPONSE_TOKENS_PER_SECOND_MEAN_METRIC_KEY) is not None
+    ]
+
+    if execution_time_sec is None or execution_time_sec <= 0:
+        return {
+            "prompt_tokens_per_second": None,
+            "response_tokens_per_second": None,
+            "api_call_prompt_tokens_per_second": (
+                sum(api_call_prompt_tokens_per_second_values)
+                / len(api_call_prompt_tokens_per_second_values)
+                if api_call_prompt_tokens_per_second_values
+                else None
+            ),
+            "api_call_response_tokens_per_second": (
+                sum(api_call_response_tokens_per_second_values)
+                / len(api_call_response_tokens_per_second_values)
+                if api_call_response_tokens_per_second_values
+                else None
+            ),
+        }
+
+    prompt_token_total = 0.0
+    response_token_total = 0.0
+    for step_metric in step_metrics:
+        experience_count = step_metric.get(EXPERIENCE_COUNT_METRIC_KEY)
+        prompt_length_mean = step_metric.get(PROMPT_LENGTH_MEAN_METRIC_KEY)
+        response_length_mean = step_metric.get(RESPONSE_LENGTH_MEAN_METRIC_KEY)
+        if experience_count is None:
+            continue
+        if prompt_length_mean is not None:
+            prompt_token_total += float(experience_count) * float(prompt_length_mean)
+        if response_length_mean is not None:
+            response_token_total += float(experience_count) * float(response_length_mean)
+
+    return {
+        "prompt_tokens_per_second": prompt_token_total / float(execution_time_sec),
+        "response_tokens_per_second": response_token_total / float(execution_time_sec),
+        "api_call_prompt_tokens_per_second": (
+            sum(api_call_prompt_tokens_per_second_values)
+            / len(api_call_prompt_tokens_per_second_values)
+            if api_call_prompt_tokens_per_second_values
+            else None
+        ),
+        "api_call_response_tokens_per_second": (
+            sum(api_call_response_tokens_per_second_values)
+            / len(api_call_response_tokens_per_second_values)
+            if api_call_response_tokens_per_second_values
+            else None
+        ),
+    }
diff --git a/trinity/perf/report_viewer.py b/trinity/perf/report_viewer.py
index 415bfe74ca5..8c22ec7bc40 100644
--- a/trinity/perf/report_viewer.py
+++ b/trinity/perf/report_viewer.py
@@ -21,6 +21,8 @@
         "rollout/time/task_execution/mean",
         "rollout/prompt_length/mean",
"rollout/response_length/mean", + "rollout/api_call_prompt_tokens_per_second/mean", + "rollout/api_call_response_tokens_per_second/mean", "experience_pipeline/experience_count", ], "trainer": [], @@ -255,6 +257,16 @@ def render_header(report: dict[str, Any], report_path: str) -> None: st.code(str(status["error"])) +def compute_global_token_throughput_metrics(report: dict[str, Any]) -> dict[str, float | None]: + timing = report.get("timing", {}) + return { + "prompt_tokens_per_second": timing.get("prompt_tokens_per_second"), + "response_tokens_per_second": timing.get("response_tokens_per_second"), + "api_call_prompt_tokens_per_second": timing.get("api_call_prompt_tokens_per_second"), + "api_call_response_tokens_per_second": timing.get("api_call_response_tokens_per_second"), + } + + def render_global_metrics(report: dict[str, Any]) -> None: st.header("Global Metrics") timing = report.get("timing", {}) @@ -265,18 +277,21 @@ def render_global_metrics(report: dict[str, Any]) -> None: metric_key, timing.get(metric_key), ) - for metric_key in ("startup_time_sec", "execution_time_sec", "total_time_sec") + for metric_key in ("startup_time_sec", "execution_time_sec") ) + metric_items.extend(compute_global_token_throughput_metrics(report).items()) shown_items = [(key, value) for key, value in metric_items if value is not None] if not shown_items: st.info("No global metrics found in this report.") return - columns = st.columns(min(4, len(shown_items))) - for index, (metric_key, value) in enumerate(shown_items): - with columns[index % len(columns)]: - render_metric_card(metric_key, value) + for row_start in range(0, len(shown_items), 2): + row_items = shown_items[row_start : row_start + 2] + columns = st.columns(len(row_items)) + for column, (metric_key, value) in zip(columns, row_items): + with column: + render_metric_card(metric_key, value) def render_step_metrics(report: dict[str, Any]) -> None: diff --git a/trinity/perf/stage_perf.py b/trinity/perf/stage_perf.py index 7185c26ebaa..dca98cbeccc 100644 --- a/trinity/perf/stage_perf.py +++ b/trinity/perf/stage_perf.py @@ -13,6 +13,7 @@ from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline from trinity.common.config import Config, load_config +from trinity.perf.report_metrics import compute_global_token_throughput_metrics from trinity.perf.resource_sampler import ResourceSampler from trinity.perf.tensorboard_metrics import ( TensorBoardScalarReader, @@ -62,6 +63,18 @@ def build_explorer_perf_payload( "output_json": options.output_path, } + timing = { + "startup_time_sec": startup_time_sec, + "execution_time_sec": execution_time_sec, + "total_time_sec": total_time_sec, + } + timing.update( + compute_global_token_throughput_metrics( + execution_time_sec=execution_time_sec, + step_metrics=step_metrics, + ) + ) + return { "run_meta": { "module": "explorer", @@ -72,11 +85,7 @@ def build_explorer_perf_payload( "pid": os.getpid(), "generated_at": time.time(), }, - "timing": { - "startup_time_sec": startup_time_sec, - "execution_time_sec": execution_time_sec, - "total_time_sec": total_time_sec, - }, + "timing": timing, **resource_payload, "step_metrics": step_metrics, "artifacts": artifacts, diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py index 1051a4a169c..9c93ade5350 100644 --- a/trinity/trainer/verl/fsdp_checkpoint_manager.py +++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py @@ -115,7 +115,12 @@ def upload_state_dict(self, global_step: int): state_dict_config = 
diff --git a/trinity/trainer/verl/fsdp_checkpoint_manager.py b/trinity/trainer/verl/fsdp_checkpoint_manager.py
index 1051a4a169c..9c93ade5350 100644
--- a/trinity/trainer/verl/fsdp_checkpoint_manager.py
+++ b/trinity/trainer/verl/fsdp_checkpoint_manager.py
@@ -115,7 +115,12 @@ def upload_state_dict(self, global_step: int):
         state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         with get_fsdp_state_ctx(self.model, StateDictType.FULL_STATE_DICT, state_dict_config, None):
             state_dict = self.model.state_dict()
-        state_dict = {key: value.to("cpu") for key, value in state_dict.items()}
+        state_dict = {
+            key: (value.full_tensor() if hasattr(value, "full_tensor") else value)
+            .detach()
+            .to("cpu")
+            for key, value in state_dict.items()
+        }
         self._upload_state_dict(state_dict, global_step)

     def _save_with_thread(
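# Editor's sketch (not part of the patch): why the comprehension above probes for full_tensor().
# Under FSDP2, state_dict() entries can be DTensor shards; DTensor.full_tensor() gathers the
# shard into a regular torch.Tensor, which is then detached and moved to CPU before upload.
# Plain tensors (e.g. an FSDP1 full state dict) skip the gather. The helper name below is
# hypothetical; only the hasattr-based duck typing mirrors the patch.
import torch


def gather_to_cpu(value: torch.Tensor) -> torch.Tensor:
    if hasattr(value, "full_tensor"):  # DTensor-like sharded parameter
        value = value.full_tensor()
    return value.detach().to("cpu")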
diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py
index 8de3a3bd434..1620bcf5987 100644
--- a/trinity/trainer/verl/fsdp_workers.py
+++ b/trinity/trainer/verl/fsdp_workers.py
@@ -100,6 +100,12 @@
 from trinity.utils.log import get_logger


+def _align_trainable_param_dtype(module: torch.nn.Module, target_dtype: torch.dtype) -> None:
+    for param in module.parameters():
+        if param.requires_grad and param.dtype != target_dtype:
+            param.data = param.data.to(dtype=target_dtype)
+
+
 class ActorRolloutRefWorker(Worker, DistProfilerExtension):
     """
     This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
@@ -320,6 +326,20 @@ def _build_model_optimizer(  # noqa: C901
         log_gpu_memory_usage(f"Before init {role} from HF AutoModel", logger=self.logger)

         local_path = model_path
+        fsdp_strategy = self.config.actor.strategy
+        mixed_precision_config = fsdp_config.get("mixed_precision", None)
+        if mixed_precision_config is not None:
+            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
+            reduce_dtype = PrecisionType.to_dtype(
+                mixed_precision_config.get("reduce_dtype", "fp32")
+            )
+            buffer_dtype = PrecisionType.to_dtype(
+                mixed_precision_config.get("buffer_dtype", "fp32")
+            )
+        else:
+            param_dtype = PrecisionType.to_dtype(fsdp_config.dtype)
+            reduce_dtype = torch.float32
+            buffer_dtype = torch.float32

         # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
         # TODO(zhangchi.usc1992): 1. support create from random initialized model. 2. Support init with FSDP directly
@@ -334,7 +354,10 @@

         torch_dtype = fsdp_config.model_dtype
         if torch_dtype is None:
-            torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
+            if self._is_lora and fsdp_strategy == "fsdp2":
+                torch_dtype = param_dtype
+            else:
+                torch_dtype = torch.float32 if self._is_actor else torch.bfloat16
         else:
             torch_dtype = PrecisionType.to_dtype(torch_dtype)

@@ -434,6 +457,7 @@
         if self._is_lora:
             self.logger.info("Applying LoRA to actor module")
             actor_module.enable_input_require_grads()
+            autocast_adapter_dtype = fsdp_strategy != "fsdp2"

             lora_adapter_path = self.config.model.get("lora_adapter_path")
             if lora_adapter_path is not None:
@@ -449,7 +473,10 @@
                 )

                 actor_module = PeftModel.from_pretrained(
-                    actor_module, local_adapter_path, is_trainable=True
+                    actor_module,
+                    local_adapter_path,
+                    is_trainable=True,
+                    autocast_adapter_dtype=autocast_adapter_dtype,
                 )
                 peft_config = actor_module.peft_config["default"]
                 # Ensure task_type is TaskType enum, not string
@@ -466,7 +493,14 @@
                 "exclude_modules": convert_to_regular_types(self.config.model.exclude_modules),
                 "bias": "none",
             }
-            actor_module = get_peft_model(actor_module, LoraConfig(**lora_config))
+            actor_module = get_peft_model(
+                actor_module,
+                LoraConfig(**lora_config),
+                autocast_adapter_dtype=autocast_adapter_dtype,
+            )
+
+            if fsdp_strategy == "fsdp2":
+                _align_trainable_param_dtype(actor_module, param_dtype)

         self.use_orig_params = fsdp_config.get("use_orig_params", False)
         if self.config.actor.get("freeze_vision_tower", False):
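# Editor's sketch (not part of the patch): the LoRA dtype flow that the fsdp2 branch above aims
# for. With autocast_adapter_dtype=False, peft keeps the injected adapter weights in the base
# model's dtype instead of upcasting them to float32; the loop then casts any remaining
# trainable parameters to one target dtype, mirroring _align_trainable_param_dtype. The model
# name, rank, and target dtype are illustrative.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct").to(dtype=torch.bfloat16)
lora_model = get_peft_model(
    base,
    LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"),
    autocast_adapter_dtype=False,  # keep adapters in bf16 rather than fp32
)
for param in lora_model.parameters():
    if param.requires_grad and param.dtype != torch.bfloat16:
        param.data = param.data.to(dtype=torch.bfloat16)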
@@ -488,20 +522,6 @@
         log_gpu_memory_usage(f"After init {role} from HF AutoModel", logger=self.logger)

         # We wrap FSDP for rollout as well
-        mixed_precision_config = fsdp_config.get("mixed_precision", None)
-        if mixed_precision_config is not None:
-            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
-            reduce_dtype = PrecisionType.to_dtype(
-                mixed_precision_config.get("reduce_dtype", "fp32")
-            )
-            buffer_dtype = PrecisionType.to_dtype(
-                mixed_precision_config.get("buffer_dtype", "fp32")
-            )
-        else:
-            param_dtype = PrecisionType.to_dtype(fsdp_config.dtype)
-            reduce_dtype = torch.float32
-            buffer_dtype = torch.float32
-
         mixed_precision = MixedPrecision(
             param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
         )
@@ -523,7 +543,6 @@
         # We force reference policy to use CPUOffload to save memory.
         # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation
         cpu_offload = None if role == "actor" else CPUOffload(offload_params=True)
-        fsdp_strategy = self.config.actor.strategy
         if fsdp_strategy == "fsdp":
             actor_module_fsdp = FSDP(
                 actor_module,
@@ -804,6 +823,7 @@ def setup_weight_sync_group(self):
         )

         timeout = self.config.synchronizer.sync_timeout
+        self.logger.info("Trainer start init_process_group.")
         self._model_update_group = init_process_group(
             host=master_address,
             port=master_port,
@@ -811,11 +831,10 @@
             backend="nccl",
             timeout=timeout,
             world_size=world_size,
-            device_id=torch.device(f"cuda:{get_device_id()}"),
             rank=0,
         )
-        torch.distributed.barrier(group=self._model_update_group)
         ray.get(setup_ref)
+        self.logger.info("Trainer explorer setup confirmation received.")

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def sync_weight(self):
@@ -838,7 +857,6 @@ def sync_weight(self):
                 torch.distributed.broadcast(full_param, 0, group=self._model_update_group)
             del full_param
         if torch.distributed.get_rank() == 0:
-            torch.distributed.barrier(group=self._model_update_group)
             torch.cuda.synchronize()

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
@@ -1300,7 +1318,28 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
         if self.rank == 0:
             self.logger.info(f"Critic overriding config {override_config_kwargs}")

+        fsdp_config = self.config.model.fsdp_config
+        mixed_precision_config = fsdp_config.get("mixed_precision", None)
+        if mixed_precision_config is not None:
+            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
+            reduce_dtype = PrecisionType.to_dtype(
+                mixed_precision_config.get("reduce_dtype", "fp32")
+            )
+            buffer_dtype = PrecisionType.to_dtype(
+                mixed_precision_config.get("buffer_dtype", "fp32")
+            )
+        else:
+            param_dtype = torch.bfloat16
+            reduce_dtype = torch.float32
+            buffer_dtype = torch.float32
+
         torch_dtype = self.config.model.fsdp_config.model_dtype or "fp32"
+        if (
+            self._is_lora
+            and config.strategy == "fsdp2"
+            and self.config.model.fsdp_config.model_dtype is None
+        ):
+            torch_dtype = param_dtype
         torch_dtype = PrecisionType.to_dtype(torch_dtype)

         from transformers import AutoConfig
@@ -1367,6 +1406,7 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
         if self._is_lora:
             self.logger.info("Applying LoRA to critic module")
             critic_module.enable_input_require_grads()
+            autocast_adapter_dtype = config.strategy != "fsdp2"

             # Check if we should load a pre-trained LoRA adapter
             lora_adapter_path = self.config.model.get("lora_adapter_path")
@@ -1383,7 +1423,10 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
                 )

                 critic_module = PeftModel.from_pretrained(
-                    critic_module, local_adapter_path, is_trainable=True
+                    critic_module,
+                    local_adapter_path,
+                    is_trainable=True,
+                    autocast_adapter_dtype=autocast_adapter_dtype,
                 )
                 peft_config = critic_module.peft_config["default"]
                 # Ensure task_type is TaskType enum, not string
@@ -1401,28 +1444,20 @@ def _build_critic_model_optimizer(self, config):  # noqa: C901
                 "target_modules": convert_to_regular_types(self.config.model.target_modules),
                 "bias": "none",
             }
-            critic_module = get_peft_model(critic_module, LoraConfig(**lora_config))
+            critic_module = get_peft_model(
+                critic_module,
+                LoraConfig(**lora_config),
+                autocast_adapter_dtype=autocast_adapter_dtype,
+            )
+
+            if config.strategy == "fsdp2":
+                _align_trainable_param_dtype(critic_module, param_dtype)

         if self.rank == 0:
             print_model_size(critic_module)

         self.critic_model_config = critic_model_config

-        fsdp_config = self.config.model.fsdp_config
-        mixed_precision_config = fsdp_config.get("mixed_precision", None)
-        if mixed_precision_config is not None:
-            param_dtype = PrecisionType.to_dtype(mixed_precision_config.get("param_dtype", "bf16"))
-            reduce_dtype = PrecisionType.to_dtype(
-                mixed_precision_config.get("reduce_dtype", "fp32")
-            )
-            buffer_dtype = PrecisionType.to_dtype(
-                mixed_precision_config.get("buffer_dtype", "fp32")
-            )
-        else:
-            param_dtype = torch.bfloat16
-            reduce_dtype = torch.float32
-            buffer_dtype = torch.float32
-
         mixed_precision = MixedPrecision(
             param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype
         )
diff --git a/trinity/trainer/verl/megatron_workers.py b/trinity/trainer/verl/megatron_workers.py
index 36576ad4550..26d6b470f13 100644
--- a/trinity/trainer/verl/megatron_workers.py
+++ b/trinity/trainer/verl/megatron_workers.py
@@ -782,7 +782,6 @@ def setup_weight_sync_group(self):
             world_size=world_size,
             rank=0,
         )
-        torch.distributed.barrier(group=self._model_update_group)
         ray.get(setup_ref)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
@@ -797,7 +796,6 @@ def sync_weight(self):
             torch.distributed.broadcast(weight, 0, group=self._model_update_group)
             del weight
         if torch.distributed.get_rank() == 0:
-            torch.distributed.barrier(group=self._model_update_group)
             torch.cuda.synchronize()
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.actor_module)
diff --git a/trinity/trainer/verl/verl_trainer.py b/trinity/trainer/verl/verl_trainer.py
index 1796c88e041..a06acb81505 100644
--- a/trinity/trainer/verl/verl_trainer.py
+++ b/trinity/trainer/verl/verl_trainer.py
@@ -191,6 +191,9 @@ def __init__(
         global_config: Config,
     ):
         self.logger = get_logger(__name__, in_ray_actor=True)
+        self.logger.info(
+            f"Initializing verl Trainer with {global_config.trainer.trainer_strategy} backend"
+        )
         train_config = global_config.trainer
         config = OmegaConf.structured(train_config.trainer_config)
         # download the checkpoint from hdfs