Skip to content
64 changes: 57 additions & 7 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,83 @@
"""

import argparse
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

import yaml


# Absolute path of the directory containing this file (the repo root).
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
# Sibling-checkout layout: the CoMLRL repo is expected to live next to this
# repository's root directory — NOTE(review): confirm this matches the
# intended checkout layout.
COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL")
if COMLRL_ROOT not in sys.path:
    # Prepend so a sibling CoMLRL checkout takes precedence over any
    # installed copy of the package.
    sys.path.insert(0, COMLRL_ROOT)


@dataclass(frozen=True)
class ModelConfig:
"""Configuration for model loading and generation."""

name: str
type: str = "qwen"
temperature: float = 0.7
top_p: float = 0.9
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
max_length: int = 2048
special_tokens: Dict[str, str] = field(default_factory=dict)
torch_dtype: Optional[str] = None

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "ModelConfig":
def from_dict(
cls,
config_dict: Dict[str, Any],
*,
require_sampling: bool = True,
) -> "ModelConfig":
"""Create ModelConfig from dictionary."""
if require_sampling:
missing = [
key
for key in ("temperature", "top_p", "top_k")
if key not in config_dict
]
if missing:
raise ValueError(
f"agent_model is missing required sampling fields: {', '.join(missing)}"
)

def _as_optional_float(value: Any) -> Optional[float]:
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Invalid float value: {value}") from exc

def _as_optional_int(value: Any) -> Optional[int]:
if value is None:
return None
if isinstance(value, str) and value.strip().lower() in ("none", "null", ""):
return None
try:
return int(float(value))
except (TypeError, ValueError) as exc:
raise ValueError(f"Invalid int value: {value}") from exc

temperature = _as_optional_float(config_dict.get("temperature"))
top_p = _as_optional_float(config_dict.get("top_p"))
top_k = _as_optional_int(config_dict.get("top_k"))
if require_sampling and (temperature is None or top_p is None):
raise ValueError("agent_model.temperature and agent_model.top_p must be non-null.")

return cls(
name=config_dict.get("name", ""),
type=config_dict.get("type", "qwen"),
temperature=config_dict.get("temperature", 0.7),
top_p=config_dict.get("top_p", 0.9),
temperature=temperature,
top_p=top_p,
top_k=top_k,
max_length=config_dict.get("max_length", 2048),
special_tokens=config_dict.get("special_tokens", {}),
torch_dtype=(
def get_agent_model_config(self) -> ModelConfig:
    """Return the 'agent_model' section as a ModelConfig.

    The agent requires explicit sampling settings, so ``from_dict`` is
    called with ``require_sampling=True`` (temperature/top_p/top_k must
    be present and temperature/top_p non-null).

    Raises:
        ValueError: if the 'agent_model' section is missing or empty.
    """
    model_section = self.get_section("agent_model")
    if not model_section:
        raise ValueError("No 'agent_model' section found in configuration")
    return ModelConfig.from_dict(model_section, require_sampling=True)

def get_critic_model_config(self, required: bool = True) -> Optional[ModelConfig]:
    """Return the 'critic_model' section as a ModelConfig, or None.

    Sampling fields are not required for the critic
    (``require_sampling=False``): the critic configs omit
    temperature/top_p entirely.

    Args:
        required: when True, a missing or empty 'critic_model' section
            raises; when False it yields None instead.

    Raises:
        ValueError: if required is True and the section is missing/empty.
    """
    critic_section = self.get_section("critic_model")
    if not critic_section:
        if required:
            raise ValueError("No 'critic_model' section found in configuration")
        return None
    return ModelConfig.from_dict(critic_section, require_sampling=False)

def update(self, updates: Dict[str, Any]):
"""Update configuration with new values (deep merge)."""
Expand Down
57 changes: 29 additions & 28 deletions configs/ac_che_config.yaml
Original file line number Diff line number Diff line change
@@ -1,61 +1,62 @@
# AC training config — CoopHumanEval. (Reconstructed post-diff document:
# the pasted diff interleaved old and new lines, producing duplicate keys.)
agent_model:
  name: Qwen/Qwen2.5-Coder-3B
  type: qwen
  temperature: 0.6
  top_p: 0.6
  top_k: null
  max_length: 2048
  torch_dtype: bfloat16

agents: null

# Critic omits sampling fields on purpose: it is scored, not sampled.
critic_model:
  name: Qwen/Qwen2.5-Coder-3B
  type: qwen
  max_length: 2048
  torch_dtype: bfloat16

critics: null

dataset:
  name: OpenMLRL/CoopHumanEval
  type: coophumaneval
  train_split: train[16:]
  eval_split: train[:16]

output:
  base_dir: output_ac_che
  verbose: false
  save_final_model: false
  save_path: output_ac_che

external:
  mode: level_feedback
  sandbox_slice: 1

ac:
  parallel_training: none
  num_turns: 2
  num_train_epochs: 80
  agent_learning_rate: 5.0e-06
  critic_learning_rate: 5.0e-06
  value_loss_coef: 0.6
  advantage_normalization: true
  rollout_buffer_size: 4
  max_new_tokens: 256
  temperature: 0.6
  top_p: 0.6
  top_k: null
  discount: 0.9
  early_termination_threshold: -0.2
  eval_interval: 40
  eval_num_samples: 4
  eval_batch_size: 1
  reward_shift: -4

wandb:
  project: comlrl
  entity: OpenMLRL
  name: ac_coophumaneval
  dir: output_ac_che
  tags:
    - ac
    - coophumaneval
    - single-agent
    - turns_2
57 changes: 29 additions & 28 deletions configs/ac_he_config.yaml
Original file line number Diff line number Diff line change
@@ -1,61 +1,62 @@
# AC training config — HumanEval. (Reconstructed post-diff document:
# the pasted diff interleaved old and new lines, producing duplicate keys.)
agent_model:
  name: Qwen/Qwen2.5-Coder-3B
  type: qwen
  temperature: 0.6
  top_p: 0.6
  top_k: null
  max_length: 2048
  torch_dtype: bfloat16

agents: null

# Critic omits sampling fields on purpose: it is scored, not sampled.
critic_model:
  name: Qwen/Qwen2.5-Coder-3B
  type: qwen
  max_length: 2048
  torch_dtype: bfloat16

critics: null

dataset:
  name: openai/openai_humaneval
  type: humaneval
  train_split: test[33:163]
  eval_split: test[:32]

output:
  base_dir: output_ac_he
  verbose: false
  save_final_model: false
  save_path: output_ac_he

external:
  mode: level_feedback
  sandbox_slice: 1

ac:
  parallel_training: none
  num_turns: 2
  num_train_epochs: 80
  agent_learning_rate: 5.0e-06
  critic_learning_rate: 5.0e-06
  value_loss_coef: 0.6
  advantage_normalization: true
  rollout_buffer_size: 4
  max_new_tokens: 256
  temperature: 0.6
  top_p: 0.6
  top_k: null
  discount: 0.9
  early_termination_threshold: -0.2
  eval_interval: 40
  eval_num_samples: 4
  eval_batch_size: 1
  reward_shift: -4

wandb:
  project: comlrl
  entity: OpenMLRL
  name: ac_humaneval
  dir: output_ac_he
  tags:
    - ac
    - humaneval
    - single-agent
    - turns_2
57 changes: 29 additions & 28 deletions configs/ac_mbpp_config.yaml
Original file line number Diff line number Diff line change
@@ -1,61 +1,62 @@
# AC training config — MBPP. (Reconstructed post-diff document:
# the pasted diff interleaved old and new lines, producing duplicate keys.)
agent_model:
  name: Qwen/Qwen2.5-Coder-3B
  type: qwen
  temperature: 0.6
  top_p: 0.6
  top_k: null
  max_length: 2048
  torch_dtype: bfloat16

agents: null

# Critic omits sampling fields on purpose: it is scored, not sampled.
critic_model:
  name: Qwen/Qwen2.5-Coder-3B
  type: qwen
  max_length: 2048
  torch_dtype: bfloat16

critics: null

dataset:
  name: OpenMLRL/MBPP
  type: mbpp
  train_split: test[15:65]
  eval_split: test[:15]

output:
  base_dir: output_ac_mbpp
  verbose: false
  save_final_model: false
  save_path: output_ac_mbpp

external:
  mode: level_feedback
  sandbox_slice: 1

ac:
  parallel_training: none
  num_turns: 2
  num_train_epochs: 80
  agent_learning_rate: 5.0e-06
  critic_learning_rate: 5.0e-06
  value_loss_coef: 0.6
  advantage_normalization: true
  rollout_buffer_size: 4
  max_new_tokens: 256
  temperature: 0.6
  top_p: 0.6
  top_k: null
  discount: 0.9
  early_termination_threshold: -0.2
  eval_interval: 40
  eval_num_samples: 4
  eval_batch_size: 1
  reward_shift: -4

wandb:
  project: comlrl
  entity: OpenMLRL
  name: ac_mbpp
  dir: output_ac_mbpp
  tags:
    - ac
    - mbpp
    - single-agent
    - turns_2
Loading