Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion runner/primus-cli-direct.sh
Original file line number Diff line number Diff line change
Expand Up @@ -543,9 +543,10 @@ print_section ""
# STEP 11: Execute command
###############################################################################
# Temporarily allow pipeline to fail so we can capture PIPESTATUS and log it
set +e
eval "$CMD"
exit_code=$?

set -e
# Print result based on exit code
if [[ $exit_code -ge 128 ]]; then
LOG_ERROR "[direct] torchrun crashed due to signal $((exit_code - 128))"
Expand Down
46 changes: 2 additions & 44 deletions tests/trainer/test_maxtext_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,11 @@
###############################################################################

import os
import subprocess
import sys
import time

import pytest
from absl.testing import absltest

from primus.core.utils import logger
from tests.utils import PrimusUT
from tests.utils import PrimusUT, run_training_script

SKIP_TEST = os.getenv("JAX_SKIP_UT", "0") == "1"

Expand All @@ -36,49 +32,11 @@ def run_script(
train_log_path = os.path.join(ut_log_path, f"log.test_maxtext_trainer-{tag}.txt")
env["TRAIN_LOG"] = train_log_path

do_print_at_runtime = True
run_stdout = subprocess.PIPE if not do_print_at_runtime else sys.stdout
run_stderr = subprocess.PIPE if not do_print_at_runtime else sys.stderr

cmd = ["bash", shell_entry]
if extra_args:
cmd.extend(extra_args)

try:
logger.info(f"[{tag}] Begin MaxText run...")
start = time.time()
subprocess.run(
cmd,
check=True,
stdout=run_stdout,
stderr=run_stderr,
text=True,
env=env,
)
logger.info(f"[{tag}] End run, time={time.time() - start:.3f} s")

with open(train_log_path, "r") as f:
stdout_output = f.read()

return stdout_output, ""

except subprocess.CalledProcessError as e:
stderr_output = e.stderr or ""
stdout_output = e.stdout or ""
if os.path.exists(train_log_path):
try:
with open(train_log_path, "r") as f:
stdout_output = f.read()
except Exception as log_err:
logger.warning(f"[{tag}] Failed to read train log: {log_err}")

if "after training is done" in stdout_output:
logger.warning(f"[{tag}] Training likely succeeded despite return code != 0.")
logger.warning(f"stderr excerpt:\n{stderr_output[:1000]}")
else:
raise AssertionError(f"Shell script failed: {stderr_output.strip()}")

return stdout_output, stderr_output
return run_training_script(tag=tag, cmd=cmd, train_log_path=train_log_path, env=env)


class TestMaxTextTrainer(PrimusUT):
Expand Down
54 changes: 2 additions & 52 deletions tests/trainer/test_megatron_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@

import os
import re
import subprocess
import sys
import time
import unittest

from primus.core.utils import logger
from tests.utils import PrimusUT
from tests.utils import PrimusUT, run_training_script

_GFX_TO_PLATFORM = {
"gfx942": "MI300X",
Expand Down Expand Up @@ -57,13 +53,6 @@ def run_script(
train_log_path = os.path.join(ut_log_path, f"log.test_megatron_trainer-{tag}.txt")
env["TRAIN_LOG"] = train_log_path

# Follow the same UT pattern as TorchTitan trainer tests:
# - print logs at runtime to the console
# - read final output from TRAIN_LOG if present
do_print_at_runtime = True
run_stdout = subprocess.PIPE if not do_print_at_runtime else sys.stdout
run_stderr = subprocess.PIPE if not do_print_at_runtime else sys.stderr

cmd = [
"bash",
shell_entry,
Expand All @@ -79,46 +68,7 @@ def run_script(
if extra_args:
cmd.extend(extra_args)

try:
logger.info(f"Begin run {tag}...")
start = time.time()
subprocess.run(
cmd,
check=True,
stdout=run_stdout,
stderr=run_stderr,
text=True,
env=env,
)
logger.info(f"End run {tag}, time={time.time()-start:.3f} s")

logger.info(f"Training log path: {ut_log_path}/logs/UT-{ut_name}")

stdout_output = ""
if os.path.exists(train_log_path):
with open(train_log_path, "r") as f:
stdout_output = f.read()

return stdout_output, ""

except subprocess.CalledProcessError as e:
stderr_output = e.stderr or ""
stdout_output = e.stdout or ""

if os.path.exists(train_log_path):
try:
with open(train_log_path, "r") as f:
stdout_output = f.read()
except Exception as log_err:
logger.warning(f"[{tag}] Failed to read train log: {log_err}")

if "Training completed." in stdout_output:
logger.warning(f"[{tag}] Training likely succeeded despite return code != 0.")
logger.warning(f"stderr excerpt:\n{stderr_output[:1000]}")
else:
raise AssertionError(f"Shell script failed: {stderr_output.strip()}")

return stdout_output, stderr_output
return run_training_script(tag=tag, cmd=cmd, train_log_path=train_log_path, env=env)


class TestMegatronTrainer(PrimusUT):
Expand Down
48 changes: 2 additions & 46 deletions tests/trainer/test_torchtitan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
###############################################################################

import os
import subprocess
import sys
import time

from primus.core.utils import logger
from tests.utils import PrimusUT
from tests.utils import PrimusUT, run_training_script


def run_script(
Expand All @@ -30,10 +26,6 @@ def run_script(
train_log_path = os.path.join(ut_log_path, f"log.test_torchtitan_trainer-{tag}.txt")
env["TRAIN_LOG"] = train_log_path

do_print_at_runtime = True
run_stdout = subprocess.PIPE if not do_print_at_runtime else sys.stdout
run_stderr = subprocess.PIPE if not do_print_at_runtime else sys.stderr

cmd = [
"bash",
shell_entry,
Expand All @@ -49,43 +41,7 @@ def run_script(
if extra_args:
cmd.extend(extra_args)

try:
logger.info(f"[{tag}] Begin Titan run...")
start = time.time()
subprocess.run(
cmd,
check=True,
stdout=run_stdout,
stderr=run_stderr,
text=True,
env=env,
)
logger.info(f"[{tag}] End run, time={time.time() - start:.3f} s")

stdout_output = ""
if os.path.exists(train_log_path):
with open(train_log_path, "r") as f:
stdout_output = f.read()

return stdout_output, ""

except subprocess.CalledProcessError as e:
stderr_output = e.stderr or ""
stdout_output = e.stdout or ""
if os.path.exists(train_log_path):
try:
with open(train_log_path, "r") as f:
stdout_output = f.read()
except Exception as log_err:
logger.warning(f"[{tag}] Failed to read train log: {log_err}")

if "Training completed." in stdout_output:
logger.warning(f"[{tag}] Training likely succeeded despite return code != 0.")
logger.warning(f"stderr excerpt:\n{stderr_output[:1000]}")
else:
raise AssertionError(f"Shell script failed: {stderr_output.strip()}")

return stdout_output, stderr_output
return run_training_script(tag=tag, cmd=cmd, train_log_path=train_log_path, env=env)


class TestTorchTitanTrainer(PrimusUT):
Expand Down
80 changes: 80 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@


import os
import subprocess
import sys
import time
import unittest
from typing import Optional

from primus.core.utils import logger

TRAINING_COMPLETED_MARKER = "Training completed."


class PrimusUT(unittest.TestCase):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -37,3 +43,77 @@ def setUp(self):

def tearDown(self):
pass


def run_training_script(
tag: str,
cmd: list[str],
train_log_path: str,
env: Optional[dict] = None,
) -> tuple[str, str]:
"""Execute a training command and validate that training completed successfully.

Runs the command via subprocess, streams output to console, then reads the
training log file and asserts that the PrimusRuntime "Training completed."
marker is present. This catches silent failures where the process exits 0
but training did not actually finish.

Args:
tag: Human-readable label for log messages (e.g. "llama3_8B").
cmd: Command to execute (passed to subprocess.run).
train_log_path: Path to the training log file written by the launcher.
env: Environment variables for the subprocess.

Returns:
(stdout_output, stderr_output) tuple where stdout_output is the
content of train_log_path.

Raises:
AssertionError: If training did not complete successfully.
"""
try:
logger.info(f"[{tag}] Begin run...")
start = time.time()
subprocess.run(
cmd,
check=True,
stdout=sys.stdout,
stderr=sys.stderr,
text=True,
env=env,
)
logger.info(f"[{tag}] End run, time={time.time() - start:.3f} s")
logger.info(f"[{tag}] Training log: {train_log_path}")

stdout_output = ""
if os.path.exists(train_log_path):
with open(train_log_path, "r") as f:
stdout_output = f.read()

if TRAINING_COMPLETED_MARKER not in stdout_output:
raise AssertionError(
f"[{tag}] Process exited with code 0 but '{TRAINING_COMPLETED_MARKER}' "
f"not found in log output. Training may have failed silently.\n"
f"Log file: {train_log_path}"
)

return stdout_output, ""

except subprocess.CalledProcessError as e:
stderr_output = e.stderr or ""
stdout_output = e.stdout or ""

if os.path.exists(train_log_path):
try:
with open(train_log_path, "r") as f:
stdout_output = f.read()
except Exception as log_err:
logger.warning(f"[{tag}] Failed to read train log: {log_err}")

if TRAINING_COMPLETED_MARKER in stdout_output:
logger.warning(f"[{tag}] Training likely succeeded despite return code != 0.")
logger.warning(f"stderr excerpt:\n{stderr_output[:1000]}")
else:
raise AssertionError(f"[{tag}] Shell script failed: {stderr_output.strip()}")

return stdout_output, stderr_output
Loading