Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
9b3f93a
Add PostTrainBench Docker evaluation runner
lewtun Apr 28, 2026
dc63bec
Default PostTrainBench agent model
lewtun Apr 28, 2026
a985dc5
Merge branch 'main' into post-train-bench-lewis
lewtun Apr 28, 2026
92c3d38
Fix PostTrainBench container agent launch
lewtun Apr 28, 2026
3c33e73
Include Slurm job id in PostTrainBench run ids
lewtun Apr 28, 2026
e8aa2a9
Document PostTrainBench run artifact tree
lewtun Apr 28, 2026
23094cc
Make smoke PostTrainBench runs five minutes
lewtun Apr 28, 2026
9f49a5d
Use shorter PostTrainBench config names
lewtun Apr 28, 2026
e908c47
Add MCP
lewtun Apr 28, 2026
17ee55c
Export PostTrainBench source snapshot path
lewtun Apr 28, 2026
12b3cef
Set descriptive PostTrainBench Slurm task names
lewtun Apr 28, 2026
3fe6a74
Set PostTrainBench Slurm time by mode
lewtun Apr 28, 2026
6a8d353
Limit PostTrainBench smoke evaluation
lewtun Apr 28, 2026
f025068
Shorten PostTrainBench Slurm job names
lewtun Apr 28, 2026
d339d32
Reference final model files in PostTrainBench artifacts
lewtun Apr 28, 2026
0204244
Add PTB prompt
lewtun Apr 28, 2026
f5fd9f2
Merge branch 'main' into post-train-bench-lewis
lewtun Apr 28, 2026
103a834
Amend system prompt
lewtun Apr 28, 2026
48d96da
Harden PostTrainBench runner isolation
lewtun Apr 29, 2026
1478bb2
Fix PostTrainBench eval image build
lewtun Apr 30, 2026
84e0373
Fix Codex judge CLI option ordering
lewtun Apr 30, 2026
066cc15
Fix Codex judge API key auth
lewtun Apr 30, 2026
0e49b36
Make PostTrainBench model validation strict
lewtun Apr 30, 2026
f2c4e43
Fix config
lewtun Apr 30, 2026
bbc6055
Harden PostTrainBench integrity checks
lewtun Apr 30, 2026
c2cc788
Detect PostTrainBench harness tampering
lewtun May 1, 2026
7904837
Harden PostTrainBench runner finalization
lewtun May 1, 2026
abd1a76
Install benchmark agent from writable build copy
lewtun May 1, 2026
3c8dc8b
Use task budget for measured solve timeout
lewtun May 1, 2026
2f90b2b
Make smoke budget strict and realistic
lewtun May 1, 2026
746c3df
Make PostTrainBench smoke deterministic
lewtun May 1, 2026
0a7b23d
Add PostTrainBench validation reprompt variant
lewtun May 4, 2026
aa6e5d4
Add per-model PostTrainBench validation mode
lewtun May 4, 2026
d4fe09d
Make PostTrainBench integrity runner Python-compatible
lewtun May 4, 2026
d0b39be
Tighten PostTrainBench final model prompt contract
lewtun May 5, 2026
2f7db44
Add ten-job PostTrainBench smoke mode
lewtun May 5, 2026
8ceb9d9
Add PostTrainBench array throttle option
lewtun May 5, 2026
4619b63
Strengthen PostTrainBench reprompt recovery
lewtun May 5, 2026
c1778a8
Avoid false secret scan hits on redacted env logs
lewtun May 5, 2026
e14612e
Guard PostTrainBench against broad process kills
lewtun May 6, 2026
39fb2f3
Retry transient streaming LLM failures
lewtun May 6, 2026
e88bac7
Ignore PTB bytecode caches in integrity checks
lewtun May 7, 2026
f2246b6
Merge origin/main into post-train-bench-lewis
lewtun May 11, 2026
6f947e4
Fix PTB bytecode cleanup find command
lewtun May 11, 2026
0067b15
Remove PTB secret scan gate
lewtun May 14, 2026
84c96d7
Match PTB baseline fallback scoring
lewtun May 14, 2026
e754864
Increase PTB full run walltime
lewtun May 14, 2026
d8b4915
Merge origin/main into post-train-bench-lewis
lewtun May 14, 2026
4c05cd5
Update CLI rendering tests for config forwarding
lewtun May 14, 2026
7a7c1b7
Address review feedback on streaming and PTB aggregation
lewtun May 14, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,6 @@ datasets/
models/
checkpoint-*/
runs/
post_train_bench/runs/
wandb/
frontend/tsconfig.tsbuildinfo
5 changes: 5 additions & 0 deletions agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Config(BaseModel):
mcpServers: dict[str, MCPServerConfig] = {}
save_sessions: bool = True
session_dataset_repo: str = "smolagents/ml-intern-sessions"
upload_sessions: bool = True
# Per-user private dataset that mirrors each session in Claude Code JSONL
# format so the HF Agent Trace Viewer auto-renders it
# (https://huggingface.co/changelog/agent-trace-viewer). Created private
Expand All @@ -42,6 +43,10 @@ class Config(BaseModel):
heartbeat_interval_s: int = 60
yolo_mode: bool = False # Auto-approve all tool calls without confirmation
max_iterations: int = 300 # Max LLM calls per agent turn (-1 = unlimited)
# Bare filenames resolve under agent/prompts/. Absolute paths and relative
# paths with directory components are used exactly as configured.
system_prompt_file: str = "system_prompt_v3.yaml"
disabled_tools: list[str] = []

# Permission control parameters
confirm_cpu_jobs: bool = True
Expand Down
12 changes: 10 additions & 2 deletions agent/context_manager/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,16 @@ def _load_system_prompt(
hf_token: str | None = None,
local_mode: bool = False,
):
"""Load and render the system prompt from YAML file with Jinja2"""
prompt_file = Path(__file__).parent.parent / "prompts" / f"{prompt_file_suffix}"
"""Load and render the system prompt YAML file with Jinja2.

Bare prompt filenames are looked up under ``agent/prompts/``. Absolute
paths and relative paths with directory components are explicit paths.
"""
configured_path = Path(prompt_file_suffix)
Comment thread
lewtun marked this conversation as resolved.
if configured_path.is_absolute() or configured_path.parent != Path("."):
prompt_file = configured_path
else:
prompt_file = Path(__file__).parent.parent / "prompts" / prompt_file_suffix

with open(prompt_file, "r") as f:
prompt_data = yaml.safe_load(f)
Expand Down
247 changes: 148 additions & 99 deletions agent/core/agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import asyncio
import json
import logging
import random
import time
from dataclasses import dataclass, field
from pathlib import Path
Expand Down Expand Up @@ -403,9 +404,9 @@ async def _record_manual_approved_spend_if_needed(


# -- LLM retry constants --------------------------------------------------
_MAX_LLM_RETRIES = 3
_LLM_RETRY_DELAYS = [5, 15, 30] # seconds between retries
_LLM_RATE_LIMIT_RETRY_DELAYS = [30, 60] # exceed Bedrock's ~60s TPM bucket window
_MAX_LLM_RETRIES = 8
_LLM_RETRY_DELAYS = [15, 30, 60, 120, 300, 600, 600] # seconds between retries
_LLM_RATE_LIMIT_RETRY_DELAYS = [60, 120, 300, 600, 600, 600, 600]


def _is_rate_limit_error(error: Exception) -> bool:
Expand Down Expand Up @@ -455,6 +456,12 @@ def _retry_delay_for(error: Exception, attempt_index: int) -> int | None:
return schedule[attempt_index]


def _retry_delay_with_jitter(delay: int) -> int:
"""Add bounded jitter to avoid synchronized retry bursts."""
jitter = random.randint(0, max(1, min(60, delay // 5)))
return delay + jitter


def _is_transient_error(error: Exception) -> bool:
"""Return True for errors that are likely transient and worth retrying."""
err_str = str(error).lower()
Expand Down Expand Up @@ -852,12 +859,39 @@ async def _call_llm_streaming(
session: Session, messages, tools, llm_params
) -> LLMResult:
"""Call the LLM with streaming, emitting assistant_chunk events."""
response = None
_healed_effort = False # one-shot safety net per call
_healed_thinking_signature = False
messages, tools = with_prompt_caching(messages, tools, llm_params.get("model"))
t_start = time.monotonic()

async def _send_stream_reset_if_needed(
emitted_assistant_chunk: bool,
*,
attempt_index: int,
delay_s: int | None = None,
reason: str,
) -> None:
if not emitted_assistant_chunk:
return
data = {
"attempt": attempt_index + 1,
"next_attempt": attempt_index + 2,
"max_attempts": _MAX_LLM_RETRIES,
"reason": reason,
}
if delay_s is not None:
data["delay_s"] = delay_s
await session.send_event(Event(event_type="assistant_stream_reset", data=data))

for _llm_attempt in range(_MAX_LLM_RETRIES):
full_content = ""
emitted_assistant_chunk = False
tool_calls_acc: dict[int, dict] = {}
token_count = 0
finish_reason = None
final_usage_chunk = None
chunks = []
should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))
try:
response = await acompletion(
messages=messages,
Expand All @@ -868,7 +902,90 @@ async def _call_llm_streaming(
timeout=600,
**llm_params,
)
break

async for chunk in response:
chunks.append(chunk)
if session.is_cancelled:
tool_calls_acc.clear()
break

choice = chunk.choices[0] if chunk.choices else None
if not choice:
if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk
continue

delta = choice.delta
if choice.finish_reason:
finish_reason = choice.finish_reason

if delta.content:
full_content += delta.content
emitted_assistant_chunk = True
await session.send_event(
Event(
event_type="assistant_chunk",
data={"content": delta.content},
)
)

if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {
"id": "",
"type": "function",
"function": {"name": "", "arguments": ""},
}
if tc_delta.id:
tool_calls_acc[idx]["id"] = tc_delta.id
if tc_delta.function:
if tc_delta.function.name:
tool_calls_acc[idx]["function"]["name"] += (
tc_delta.function.name
)
if tc_delta.function.arguments:
tool_calls_acc[idx]["function"]["arguments"] += (
tc_delta.function.arguments
)

if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk

usage = await telemetry.record_llm_call(
session,
model=llm_params.get("model", session.config.model_name),
response=final_usage_chunk,
latency_ms=int((time.monotonic() - t_start) * 1000),
finish_reason=finish_reason,
)
thinking_blocks = None
reasoning_content = None
if chunks and should_replay_thinking:
try:
rebuilt = stream_chunk_builder(chunks, messages=messages)
if rebuilt and getattr(rebuilt, "choices", None):
rebuilt_msg = rebuilt.choices[0].message
thinking_blocks, reasoning_content = _extract_thinking_state(
rebuilt_msg
)
except Exception:
logger.debug(
"Failed to rebuild streaming thinking state", exc_info=True
)

return LLMResult(
content=full_content or None,
tool_calls_acc=tool_calls_acc,
token_count=token_count,
finish_reason=finish_reason,
usage=usage,
thinking_blocks=thinking_blocks,
reasoning_content=reasoning_content,
)
except ContextWindowExceededError:
raise
except Exception as e:
Expand All @@ -879,6 +996,11 @@ async def _call_llm_streaming(
llm_params = await _heal_effort_and_rebuild_params(
session, e, llm_params
)
await _send_stream_reset_if_needed(
emitted_assistant_chunk,
attempt_index=_llm_attempt,
reason="effort_config_retry",
)
await session.send_event(
Event(
event_type="tool_log",
Expand All @@ -896,115 +1018,41 @@ async def _call_llm_streaming(
already_healed=_healed_thinking_signature,
):
_healed_thinking_signature = True
await _send_stream_reset_if_needed(
emitted_assistant_chunk,
attempt_index=_llm_attempt,
reason="thinking_signature_retry",
)
continue
_delay = _retry_delay_for(e, _llm_attempt)
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
_sleep_delay = _retry_delay_with_jitter(_delay)
logger.warning(
"Transient LLM error (attempt %d/%d): %s — retrying in %ds",
"Transient LLM streaming error (attempt %d/%d): %s — retrying in %ds",
_llm_attempt + 1,
_MAX_LLM_RETRIES,
e,
_delay,
_sleep_delay,
)
await _send_stream_reset_if_needed(
emitted_assistant_chunk,
attempt_index=_llm_attempt,
delay_s=_sleep_delay,
reason="transient_error_retry",
)
await session.send_event(
Event(
event_type="tool_log",
data={
"tool": "system",
"log": f"LLM connection error, retrying in {_delay}s...",
"log": f"LLM stream error, retrying in {_sleep_delay}s...",
},
)
)
await asyncio.sleep(_delay)
await asyncio.sleep(_sleep_delay)
continue
raise

full_content = ""
tool_calls_acc: dict[int, dict] = {}
token_count = 0
finish_reason = None
final_usage_chunk = None
chunks = []
should_replay_thinking = _should_replay_thinking_state(llm_params.get("model"))

async for chunk in response:
chunks.append(chunk)
if session.is_cancelled:
tool_calls_acc.clear()
break

choice = chunk.choices[0] if chunk.choices else None
if not choice:
if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk
continue

delta = choice.delta
if choice.finish_reason:
finish_reason = choice.finish_reason

if delta.content:
full_content += delta.content
await session.send_event(
Event(event_type="assistant_chunk", data={"content": delta.content})
)

if delta.tool_calls:
for tc_delta in delta.tool_calls:
idx = tc_delta.index
if idx not in tool_calls_acc:
tool_calls_acc[idx] = {
"id": "",
"type": "function",
"function": {"name": "", "arguments": ""},
}
if tc_delta.id:
tool_calls_acc[idx]["id"] = tc_delta.id
if tc_delta.function:
if tc_delta.function.name:
tool_calls_acc[idx]["function"]["name"] += (
tc_delta.function.name
)
if tc_delta.function.arguments:
tool_calls_acc[idx]["function"]["arguments"] += (
tc_delta.function.arguments
)

if hasattr(chunk, "usage") and chunk.usage:
token_count = chunk.usage.total_tokens
final_usage_chunk = chunk

usage = await telemetry.record_llm_call(
session,
model=llm_params.get("model", session.config.model_name),
response=final_usage_chunk,
latency_ms=int((time.monotonic() - t_start) * 1000),
finish_reason=finish_reason,
)
thinking_blocks = None
reasoning_content = None
if chunks and should_replay_thinking:
try:
rebuilt = stream_chunk_builder(chunks, messages=messages)
if rebuilt and getattr(rebuilt, "choices", None):
rebuilt_msg = rebuilt.choices[0].message
thinking_blocks, reasoning_content = _extract_thinking_state(
rebuilt_msg
)
except Exception:
logger.debug("Failed to rebuild streaming thinking state", exc_info=True)

return LLMResult(
content=full_content or None,
tool_calls_acc=tool_calls_acc,
token_count=token_count,
finish_reason=finish_reason,
usage=usage,
thinking_blocks=thinking_blocks,
reasoning_content=reasoning_content,
)


async def _call_llm_non_streaming(
session: Session, messages, tools, llm_params
Expand Down Expand Up @@ -1056,23 +1104,24 @@ async def _call_llm_non_streaming(
continue
_delay = _retry_delay_for(e, _llm_attempt)
if _llm_attempt < _MAX_LLM_RETRIES - 1 and _delay is not None:
_sleep_delay = _retry_delay_with_jitter(_delay)
logger.warning(
"Transient LLM error (attempt %d/%d): %s — retrying in %ds",
_llm_attempt + 1,
_MAX_LLM_RETRIES,
e,
_delay,
_sleep_delay,
)
await session.send_event(
Event(
event_type="tool_log",
data={
"tool": "system",
"log": f"LLM connection error, retrying in {_delay}s...",
"log": f"LLM connection error, retrying in {_sleep_delay}s...",
},
)
)
await asyncio.sleep(_delay)
await asyncio.sleep(_sleep_delay)
continue
raise

Expand Down Expand Up @@ -2139,7 +2188,7 @@ async def submission_loop(
# Retry any failed uploads from previous sessions (fire-and-forget).
# Includes the personal trace repo when enabled so a session that failed
# to publish to the user's HF dataset gets a fresh attempt on next run.
if config and config.save_sessions:
if config and config.save_sessions and config.upload_sessions:
Session.retry_failed_uploads_detached(
directory=str(DEFAULT_SESSION_LOG_DIR),
repo_id=config.session_dataset_repo,
Expand Down
Loading
Loading