Skip to content
Open
4 changes: 2 additions & 2 deletions .github/workflows/docker/docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
services:
trinity-node-1:
image: trinity-rft-unittest:20260506
image: trinity-rft-unittest:20260512
cap_add:
- SYS_PTRACE
pull_policy: never
Expand Down Expand Up @@ -34,7 +34,7 @@ services:
capabilities: [gpu]

trinity-node-2:
image: trinity-rft-unittest:20260506
image: trinity-rft-unittest:20260512
cap_add:
- SYS_PTRACE
pull_policy: never
Expand Down
2 changes: 2 additions & 0 deletions perf/scripts/explorer/perf_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,6 @@ async def run_async(self) -> List[Experience]:
else 0.0
),
}
for exp in exps:
exp.reward = 1.0
return exps
9 changes: 9 additions & 0 deletions scripts/docker/Dockerfile.uv
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ RUN . /opt/venv/bin/activate && MAX_JOBS=${BUILD_JOBS} \
--config-settings="--build-option=--cuda_ext" \
git+https://github.com/NVIDIA/apex.git

# Install SGLang
RUN RUSTUP_HOME=/usr/local/rustup CARGO_HOME=/usr/local/cargo \
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
RUN . /opt/venv/bin/activate && . /usr/local/cargo/env && \
git clone -b v0.5.11 https://github.com/sgl-project/sglang.git /tmp/sglang && \
uv pip install /tmp/sglang/python && \
rm -rf /tmp/sglang && \
uv pip install transformers==5.8.0

# Set Env variables
# ENV LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/lib64

Expand Down
21 changes: 21 additions & 0 deletions tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ def test_inference_model_without_base_port_uses_ephemeral_port(self):

self.assertGreater(port, 0)

    def test_inference_model_random_port_ignores_base_port(self):
        """A random-port request must bypass the base_port + engine_id scheme."""
        # base_port=9000 with engine_id=5 would deterministically yield 9005
        # — presumably the model's non-random port formula; verify against
        # get_available_address's implementation.
        requested_port = 9005
        model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=5))

        _, port = model.get_available_address(random_port=True)

        # NOTE(review): a random port could in principle collide with 9005;
        # the assertion tolerates that tiny flake risk.
        self.assertNotEqual(port, requested_port)
        self.assertGreater(port, 0)

    def test_inference_model_random_port_can_use_port_reserved_by_api_server(self):
        """Random port selection must avoid a port that is already bound."""
        # 9006 corresponds to base_port=9000 + engine_id=6 — presumably the
        # deterministic port the API server would have reserved.
        requested_port = 9006
        model = DummyInferenceModel(InferenceModelConfig(base_port=9000, engine_id=6))

        # Occupy the deterministic port for the duration of the call so a
        # random allocation cannot land on it.
        with socket.socket() as occupied_socket:
            occupied_socket.bind(("", requested_port))

            _, port = model.get_available_address(random_port=True)

        self.assertNotEqual(port, requested_port)
        self.assertGreater(port, 0)

def test_load_default_config(self):
config = get_template_config()
config.buffer.batch_size = 8
Expand Down
282 changes: 282 additions & 0 deletions tests/common/sglang_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
import asyncio

from parameterized import parameterized_class
from transformers import AutoTokenizer

from tests.tools import (
CHAT_TEMPLATE,
RayUnittestBaseAsync,
get_model_path,
get_template_config,
)
from trinity.common.models import create_explorer_models
from trinity.common.models.model import ModelWrapper


async def prepare_engines(engines, auxiliary_engines):
    """Call ``prepare`` concurrently on every rollout engine.

    Args:
        engines: primary engine actors, each exposing ``prepare.remote()``.
        auxiliary_engines: iterable of engine groups (each a list of actors
            with the same ``prepare.remote()`` interface).

    All ``prepare`` calls are launched first and then awaited together so
    engine startup overlaps instead of running sequentially.
    """
    prepare_model_refs = [engine.prepare.remote() for engine in engines]
    # BUG FIX: the original inner loop was ``for engines in auxiliary_engines``,
    # which shadowed (and rebound) the ``engines`` parameter. Use a distinct
    # name for each auxiliary group.
    for engine_group in auxiliary_engines:
        for engine in engine_group:
            prepare_model_refs.append(engine.prepare.remote())
    await asyncio.gather(*prepare_model_refs)


def assert_experience_tokens_match_text(test_case, tokenizer, exp, prompt_contents, response_text):
    """Assert that the token ids stored on ``exp`` decode back to the text.

    Decodes the full token sequence, the prompt prefix (first
    ``exp.prompt_length`` tokens), and the response suffix, then checks that
    every expected prompt fragment and the response text appear in the
    appropriate decoded strings.
    """
    token_ids = exp.tokens.tolist()
    boundary = exp.prompt_length

    def _decode(ids):
        return tokenizer.decode(ids, skip_special_tokens=False)

    full_text = _decode(token_ids)
    prompt_text = _decode(token_ids[:boundary])
    decoded_response_text = _decode(token_ids[boundary:])

    for content in prompt_contents:
        test_case.assertIn(content, full_text)
        test_case.assertIn(content, prompt_text)
    for haystack in (full_text, decoded_response_text):
        test_case.assertIn(response_text, haystack)


@parameterized_class(
    (
        "tensor_parallel_size",
        "engine_num",
        "enable_history",
    ),
    [
        (1, 1, True),
        (1, 2, False),
        (2, 1, True),
    ],
)
class TestSGLangOpenAIAPI(RayUnittestBaseAsync):
    """End-to-end tests of the SGLang rollout engine via its OpenAI API.

    Parameterized over tensor-parallel size, number of engines, and whether
    the ``ModelWrapper`` records request/response history. Each configuration
    exercises plain chat completions, tool-call prompts, streaming (only in
    the no-history configuration), ``chat_async`` and ``generate_async``.
    """

    def setUp(self):
        """Build an explore-mode config using SGLang and start the engines.

        Also creates a ``ModelWrapper`` around the first engine (history
        recording controlled by the class parameter) and the matching
        tokenizer for decoding token ids back to text.
        """
        self.config = get_template_config()
        self.config.mode = "explore"
        self.config.model.model_path = get_model_path()
        self.config.explorer.rollout_model.engine_type = "sglang"
        self.config.explorer.rollout_model.engine_num = self.engine_num
        self.config.explorer.rollout_model.tensor_parallel_size = self.tensor_parallel_size
        self.config.explorer.rollout_model.chat_template = CHAT_TEMPLATE
        self.config.explorer.rollout_model.enable_openai_api = True
        # Fixed base port for the OpenAI-compatible server; engines derive
        # their own ports from it.
        self.config.explorer.rollout_model.base_port = 13000
        self.config.check_and_update()

        self.engines, self.auxiliary_engines = create_explorer_models(self.config)
        self.model_wrapper = ModelWrapper(self.engines[0], enable_history=self.enable_history)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model.model_path,
            trust_remote_code=self.config.explorer.rollout_model.trust_remote_code,
        )

    def _assert_experience_matches_text(self, exp, prompt_contents, response_text):
        """Check an experience has a non-empty prompt/response token split and
        that its tokens decode back to the given prompt/response text."""
        self.assertGreater(exp.prompt_length, 0)
        self.assertGreater(len(exp.tokens), exp.prompt_length)
        assert_experience_tokens_match_text(
            self, self.tokenizer, exp, prompt_contents, response_text
        )

    def _assert_history_matches_responses(self, expected_count, prompt_contents, response_texts):
        """Validate the wrapper's recorded history against observed responses.

        When history is disabled, asserts the history is empty and returns an
        empty list; otherwise drains the history, checks one experience per
        response (in order), and returns the extracted experiences.
        """
        if not self.enable_history:
            self.assertEqual(len(self.model_wrapper.history), 0)
            return []

        exps = self.model_wrapper.extract_experience_from_history()
        self.assertEqual(len(exps), expected_count)
        for exp, response_text in zip(exps, response_texts):
            self.assertEqual(exp.response_text, response_text)
            self._assert_experience_matches_text(exp, prompt_contents, response_text)
        return exps

    def _get_tool_call_case(self):
        """Return a canned tool-call conversation.

        Returns:
            Tuple of (messages, expected prompt fragments, tool schemas) for a
            weather-tool exchange where the tool result is already present in
            the conversation, so the model only has to summarize it.
        """
        tool_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "Use the weather tool result to answer what the weather is in Boston.",
            },
            {
                "role": "assistant",
                "content": "",
                "tool_calls": [
                    {
                        "id": "call_weather_boston",
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": '{"location":"Boston","unit":"fahrenheit"}',
                        },
                    }
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call_weather_boston",
                "content": "The weather in Boston is 72 F.",
            },
        ]
        # The empty assistant ``content`` (index 2) is deliberately excluded:
        # only non-empty message texts are expected in the rendered prompt.
        tool_prompt_contents = [
            tool_messages[0]["content"],
            tool_messages[1]["content"],
            tool_messages[3]["content"],
        ]
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "description": "Get the current weather in a given location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string"},
                            "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                        },
                        "required": ["location"],
                    },
                },
            }
        ]
        return tool_messages, tool_prompt_contents, tools

    async def _collect_response_texts(self, response):
        """Assert every choice has non-empty content and return the texts.

        NOTE(review): declared ``async`` although it awaits nothing —
        presumably for symmetry with the streaming collector.
        """
        response_texts = []
        for choice in response.choices:
            self.assertIsNotNone(choice.message.content)
            self.assertGreater(len(choice.message.content), 0)
            response_texts.append(choice.message.content)
        return response_texts

    async def _collect_stream_contents(self, stream_response, n):
        """Accumulate streamed delta chunks into ``n`` per-choice strings,
        keyed by each chunk choice's ``index``."""
        contents = ["" for _ in range(n)]
        async for chunk in stream_response:
            for choice in chunk.choices:
                if choice.delta.content is not None:
                    contents[choice.index] += choice.delta.content
        return contents

    async def test_chat_completions(self):
        """Exercise the OpenAI-compatible endpoint end to end.

        Covers: plain chat completions (n=2), tool-call prompts with
        ``tool_choice="none"``, streaming variants (only when history is
        disabled), and the wrapper's ``chat_async``/``generate_async`` paths,
        cross-checking recorded history against the observed responses.
        """
        # Bring up all engines concurrently, then the wrapper itself.
        await prepare_engines(self.engines, self.auxiliary_engines)
        await self.model_wrapper.prepare()

        self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
        self.assertIsNotNone(self.model_wrapper.api_address)

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Write one short sentence about Boston."},
        ]
        prompt_contents = [message["content"] for message in messages]

        openai_client = self.model_wrapper.get_openai_async_client()
        response = await openai_client.chat.completions.create(
            model=openai_client.model_path,
            messages=messages,
            n=2,
            temperature=0.7,
            max_tokens=32,
        )

        self.assertEqual(len(response.choices), 2)
        response_texts = await self._collect_response_texts(response)
        self._assert_history_matches_responses(2, prompt_contents, response_texts)

        # Tool-call conversation; tool_choice="none" forces a plain text
        # answer instead of a new tool invocation.
        tool_messages, tool_prompt_contents, tools = self._get_tool_call_case()
        tool_response = await openai_client.chat.completions.create(
            model=openai_client.model_path,
            messages=tool_messages,
            tools=tools,
            tool_choice="none",
            temperature=0.7,
            max_tokens=32,
        )

        self.assertEqual(len(tool_response.choices), 1)
        tool_response_texts = await self._collect_response_texts(tool_response)
        self._assert_history_matches_responses(1, tool_prompt_contents, tool_response_texts)

        # Streaming requests are only exercised in the no-history
        # configuration; the history assertions below then just verify the
        # history stays empty.
        if not self.enable_history:
            stream_response = await openai_client.chat.completions.create(
                model=openai_client.model_path,
                messages=messages,
                n=2,
                stream=True,
                temperature=0.7,
                max_tokens=32,
            )
            stream_contents = await self._collect_stream_contents(stream_response, 2)

            self.assertEqual(len(stream_contents), 2)
            for content in stream_contents:
                self.assertGreater(len(content), 0)
            self._assert_history_matches_responses(2, prompt_contents, stream_contents)

            stream_tool_response = await openai_client.chat.completions.create(
                model=openai_client.model_path,
                messages=tool_messages,
                tools=tools,
                tool_choice="none",
                n=1,
                stream=True,
                temperature=0.7,
                max_tokens=32,
            )
            stream_tool_contents = await self._collect_stream_contents(stream_tool_response, 1)

            self.assertEqual(len(stream_tool_contents), 1)
            self.assertGreater(len(stream_tool_contents[0]), 0)
            self._assert_history_matches_responses(1, tool_prompt_contents, stream_tool_contents)

        # Wrapper-level chat API: returns experiences directly.
        chat_exps = await self.model_wrapper.chat_async(
            messages,
            n=2,
            temperature=0.7,
            max_tokens=32,
        )

        self.assertEqual(len(chat_exps), 2)
        for exp in chat_exps:
            self.assertGreater(len(exp.response_text), 0)
            self.assertGreater(exp.prompt_length, 0)
            self.assertGreater(len(exp.tokens), exp.prompt_length)

        if self.enable_history:
            # Recorded history must mirror the experiences chat_async returned.
            chat_history = self.model_wrapper.extract_experience_from_history()
            self.assertEqual(len(chat_history), 2)
            for exp, recorded_exp in zip(chat_exps, chat_history):
                self.assertEqual(recorded_exp.response_text, exp.response_text)
                self.assertEqual(recorded_exp.prompt_length, exp.prompt_length)
                self._assert_experience_matches_text(
                    recorded_exp, prompt_contents, exp.response_text
                )
        else:
            self.assertEqual(len(self.model_wrapper.history), 0)

        # Wrapper-level raw-prompt generation API.
        generate_prompt = "Write one short sentence about Boston."
        generate_exps = await self.model_wrapper.generate_async(
            [generate_prompt],
            n=2,
            temperature=0.7,
            max_tokens=32,
        )

        self.assertEqual(len(generate_exps), 2)
        for exp in generate_exps:
            self.assertEqual(exp.prompt_text, generate_prompt)
            self.assertGreater(len(exp.response_text), 0)
            self.assertGreater(exp.prompt_length, 0)
            self.assertGreater(len(exp.tokens), exp.prompt_length)

        if self.enable_history:
            generate_history = self.model_wrapper.extract_experience_from_history()
            self.assertEqual(len(generate_history), 2)
            for exp, recorded_exp in zip(generate_exps, generate_history):
                self.assertEqual(recorded_exp.response_text, exp.response_text)
                self.assertEqual(recorded_exp.prompt_text, exp.prompt_text)
                self._assert_experience_matches_text(
                    recorded_exp, [generate_prompt], exp.response_text
                )
        else:
            self.assertEqual(len(self.model_wrapper.history), 0)
17 changes: 9 additions & 8 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def setUp(self):


@parameterized_class(
("strategy",),
("strategy", "engine_type"),
[
("fsdp2",),
("megatron",),
("fsdp2", "sglang"),
("megatron", "vllm"),
],
)
class TestTrainerCountdown(BaseTrainerCase):
Expand All @@ -92,6 +92,7 @@ def test_trainer(self):
"original_max_position_embeddings": 16384,
}
self.config.model.rope_theta = 10000
self.config.explorer.rollout_model.engine_type = self.engine_type
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("countdown")
self.config.buffer.explorer_input.taskset.data_selector = DataSelectorConfig(
selector_type="shuffle", seed=42
Expand Down Expand Up @@ -239,12 +240,11 @@ def tearDown(self):


@parameterized_class(
("fsdp_strategy", "offloading"),
("fsdp_strategy", "offloading", "engine_type"),
[
("fsdp", False),
("fsdp2", False),
("fsdp", True),
("fsdp2", True),
("fsdp", False, "vllm"),
("fsdp2", False, "vllm"),
("fsdp2", True, "sglang"),
],
)
class TestTrainerGSM8K(BaseTrainerCase):
Expand All @@ -259,6 +259,7 @@ def test_trainer(self):
}
# self.config.algorithm.repeat_times = 8 # TODO: used for real testing
# self.config.buffer.batch_size = 96 # TODO: used for real testing
self.config.explorer.rollout_model.engine_type = self.engine_type
self.config.buffer.total_epochs = 1
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.trainer.trainer_strategy = self.fsdp_strategy
Expand Down
Loading
Loading