Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 57 additions & 7 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,83 @@
"""

import argparse
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

import yaml


REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL")
if COMLRL_ROOT not in sys.path:
sys.path.insert(0, COMLRL_ROOT)


@dataclass(frozen=True)
class ModelConfig:
"""Configuration for model loading and generation."""

name: str
type: str = "qwen"
temperature: float = 0.7
top_p: float = 0.9
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
max_length: int = 2048
special_tokens: Dict[str, str] = field(default_factory=dict)
torch_dtype: Optional[str] = None

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "ModelConfig":
def from_dict(
cls,
config_dict: Dict[str, Any],
*,
require_sampling: bool = True,
) -> "ModelConfig":
"""Create ModelConfig from dictionary."""
if require_sampling:
missing = [
key
for key in ("temperature", "top_p", "top_k")
if key not in config_dict
]
if missing:
raise ValueError(
f"agent_model is missing required sampling fields: {', '.join(missing)}"
)

def _as_optional_float(value: Any) -> Optional[float]:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Invalid float value: {value}") from exc

def _as_optional_int(value: Any) -> Optional[int]:
if value is None:
return None
if isinstance(value, str) and value.strip().lower() in ("none", "null", ""):
return None
try:
return int(float(value))
except (TypeError, ValueError) as exc:
raise ValueError(f"Invalid int value: {value}") from exc

temperature = _as_optional_float(config_dict.get("temperature"))
top_p = _as_optional_float(config_dict.get("top_p"))
top_k = _as_optional_int(config_dict.get("top_k"))
if require_sampling and (temperature is None or top_p is None):
raise ValueError("agent_model.temperature and agent_model.top_p must be non-null.")

return cls(
name=config_dict.get("name", ""),
type=config_dict.get("type", "qwen"),
temperature=config_dict.get("temperature", 0.7),
top_p=config_dict.get("top_p", 0.9),
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_length=config_dict.get("max_length", 2048),
special_tokens=config_dict.get("special_tokens", {}),
torch_dtype=(
Expand Down Expand Up @@ -74,7 +124,7 @@ def get_agent_model_config(self) -> ModelConfig:
model_section = self.get_section("agent_model")
if not model_section:
raise ValueError("No 'agent_model' section found in configuration")
return ModelConfig.from_dict(model_section)
return ModelConfig.from_dict(model_section, require_sampling=True)

def get_critic_model_config(self, required: bool = True) -> Optional[ModelConfig]:
"""Get critic model configuration as ModelConfig object."""
Expand All @@ -83,7 +133,7 @@ def get_critic_model_config(self, required: bool = True) -> Optional[ModelConfig
if required:
raise ValueError("No 'critic_model' section found in configuration")
return None
return ModelConfig.from_dict(critic_section)
return ModelConfig.from_dict(critic_section, require_sampling=False)

def update(self, updates: Dict[str, Any]):
"""Update configuration with new values (deep merge)."""
Expand Down
50 changes: 25 additions & 25 deletions configs/ac_arxiv_config.yaml
Original file line number Diff line number Diff line change
@@ -1,50 +1,47 @@
agent_model:
name: "Qwen/Qwen3-1.7B"
type: "qwen"
name: Qwen/Qwen3-1.7B
type: qwen
temperature: 0.7
top_p: 0.9
top_k: null
max_length: 2048
torch_dtype: "auto"
torch_dtype: auto

agents: null

critic_model:
name: "Qwen/Qwen3-1.7B"
type: "qwen"
temperature: 0.7
top_p: 0.9
name: Qwen/Qwen3-1.7B
type: qwen
max_length: 2048
torch_dtype: "auto"
torch_dtype: auto

critics: null

dataset:
name: "OpenMLRL/arXiv_abstract"
type: "arxiv"
train_split: "train[:1000]"
eval_split: "val[:1000]"
name: OpenMLRL/arXiv_abstract
type: arxiv
train_split: train[:1000]
eval_split: val[:1000]

tokenizer:
padding_side: "left"
padding_side: left

output:
base_dir: "./ac_output"
base_dir: output_ac_arxiv
verbose: false
save_final_model: true
save_path: "./ac_output/ac_arxiv"
save_path: output_ac_arxiv

ac:
parallel_training: none
num_turns: 1
num_train_epochs: 4
agent_learning_rate: 5.0e-6
critic_learning_rate: 5.0e-6
agent_learning_rate: 5.0e-06
critic_learning_rate: 5.0e-06
value_loss_coef: 0.6
advantage_normalization: true
rollout_buffer_size: 4
max_new_tokens: 512
temperature: 0.7
top_p: 0.9
top_k: null
advantage_normalization: true
eval_interval: 20
eval_num_samples: 4
eval_batch_size: 1
Expand All @@ -56,7 +53,10 @@ reward_processor:
shift: 0.0

wandb:
project: "comlrl"
entity: "OpenMLRL"
name: "ac_arxiv"
tags: ["ac", "arxiv", "single-agent"]
project: comlrl
entity: OpenMLRL
name: ac_arxiv
tags:
- ac
- arxiv
- single-agent
50 changes: 25 additions & 25 deletions configs/ac_tldr_config.yaml
Original file line number Diff line number Diff line change
@@ -1,50 +1,47 @@
agent_model:
name: "Qwen/Qwen3-1.7B"
type: "qwen"
name: Qwen/Qwen3-1.7B
type: qwen
temperature: 0.7
top_p: 0.9
top_k: null
max_length: 2048
torch_dtype: "auto"
torch_dtype: auto

agents: null

critic_model:
name: "Qwen/Qwen3-1.7B"
type: "qwen"
temperature: 0.7
top_p: 0.9
name: Qwen/Qwen3-1.7B
type: qwen
max_length: 2048
torch_dtype: "auto"
torch_dtype: auto

critics: null

dataset:
name: "trl-lib/tldr"
type: "tldr"
train_split: "train[:1000]"
eval_split: "test[:1000]"
name: trl-lib/tldr
type: tldr
train_split: train[:1000]
eval_split: test[:1000]

tokenizer:
padding_side: "left"
padding_side: left

output:
base_dir: "./ac_output"
base_dir: output_ac_tldr
verbose: false
save_final_model: true
save_path: "./ac_output/ac_tldr"
save_path: output_ac_tldr

ac:
parallel_training: none
num_turns: 1
num_train_epochs: 4
agent_learning_rate: 5.0e-6
critic_learning_rate: 5.0e-6
agent_learning_rate: 5.0e-06
critic_learning_rate: 5.0e-06
value_loss_coef: 0.6
advantage_normalization: true
rollout_buffer_size: 4
max_new_tokens: 256
temperature: 0.7
top_p: 0.9
top_k: null
advantage_normalization: true
eval_interval: 20
eval_num_samples: 4
eval_batch_size: 1
Expand All @@ -56,7 +53,10 @@ reward_processor:
shift: 0.0

wandb:
project: "comlrl"
entity: "OpenMLRL"
name: "ac_tldr"
tags: ["ac", "tldr", "single-agent"]
project: comlrl
entity: OpenMLRL
name: ac_tldr
tags:
- ac
- tldr
- single-agent
42 changes: 23 additions & 19 deletions configs/grpo_arxiv_config.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
agent_model:
name: "Qwen/Qwen3-1.7B"
type: "qwen"
temperature: 0.7
top_p: 0.9
name: Qwen/Qwen3-1.7B
type: qwen
temperature: 0.6
top_p: 0.6
top_k: null
max_length: 2048
torch_dtype: "auto"
torch_dtype: auto

agents: null

Expand All @@ -13,29 +14,29 @@ critic_model: null
critics: null

dataset:
name: "OpenMLRL/arXiv_abstract"
type: "arxiv"
train_split: "train[:1000]"
eval_split: "val[:1000]"
name: OpenMLRL/arXiv_abstract
type: arxiv
train_split: train[:1000]
eval_split: val[:1000]

tokenizer:
padding_side: "left"
padding_side: left

output:
base_dir: "./grpo_output"
base_dir: output_grpo_arxiv
verbose: false
save_final_model: true
save_path: "./grpo_output/arxiv_single"
save_path: output_grpo_arxiv

grpo:
parallel_training: none
num_turns: 1
num_train_epochs: 1
agent_learning_rate: 5.0e-6
agent_learning_rate: 5.0e-06
logging_steps: 400
num_generations: 4
joint_mode: aligned
max_new_tokens: 512
top_k: null
joint_mode: aligned
rollout_buffer_size: 2
train_batch_size: 2
advantage_normalization: true
Expand All @@ -49,7 +50,10 @@ reward_processor:
shift: 0.0

wandb:
project: "comlrl"
entity: "OpenMLRL"
name: "grpo_arxiv"
tags: ["grpo", "arxiv", "single-agent"]
project: comlrl
entity: OpenMLRL
name: grpo_arxiv
tags:
- grpo
- arxiv
- single-agent
Loading