Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,11 @@ Key sections in `configs/magrpo_classeval_config.yaml`:
ClassEval sub-slices or local mirrors.
- `external`: feedback configuration (use `code_feedback` for syntax/test diagnostics).
- `magrpo`: forwarded to `comlrl.trainers.reinforce.MAGRPOTrainer`. Includes collaboration
(`num_agents`, param-count assignment), sampling settings (`num_generations`, `num_turns`,
temperature/top_p), rollout buffering (`rollout_buffer_size`), optimization
(`num_agents`, param-count assignment), rollout settings (`num_generations`, `num_turns`),
rollout buffering (`rollout_buffer_size`), optimization
hyperparameters, and IO controls.
- Sampling knobs (`temperature`, `top_p`, `top_k`) are configured in `agent_model` and passed
to trainer args at runtime.
- `reward_processor`: optional post-processing for rewards (scale, shift).
- `output`: persistence knobs (save final model, output paths, verbose debug prints).

Expand Down
32 changes: 19 additions & 13 deletions configs/iac_classeval_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ agent_model:
type: qwen
temperature: 0.6
top_p: 0.6
top_k: null
max_length: 2048
torch_dtype: bfloat16

agents: null

critic_model:
name: "Qwen/Qwen3-4B-Instruct-2507"
name: Qwen/Qwen3-4B-Instruct-2507
type: qwen
temperature: 0.6
top_p: 0.6
max_length: 2048
torch_dtype: bfloat16

Expand All @@ -25,19 +24,26 @@ dataset:
eval_split: test[66:82]

output:
base_dir: output
save_final_model: false
save_path: output/final_model
base_dir: output_iac_classeval
verbose: false
save_final_model: false
save_path: output_iac_classeval

external:
mode: code_feedback
original_prompt: true
previous_response: true
external_prompt_passthrough: false

iac:
parallel_training: none
agent_devices:
- cuda:0
critic_devices:
- cuda:0
num_agents: 2
num_turns: 2
use_separate_critic: true
num_train_epochs: 40
agent_learning_rate: 5e-6
critic_learning_rate: 5e-6
Expand All @@ -46,13 +52,9 @@ iac:
rollout_buffer_size: 2
train_batch_size: 2
max_new_tokens: 600
temperature: 0.6
top_p: 0.6
top_k: null
num_generations: 1
use_separate_critic: true
discount: 0.9
early_termination_threshold: -0.2
num_generations: 1
eval_interval: 20
eval_num_samples: 4
eval_batch_size: 1
Expand All @@ -67,5 +69,9 @@ wandb:
project: classeval_dev
entity: null
name: codecompletion_classeval_iac
dir: output
tags: ["iac", "classeval", "code-completion", "turns_2"]
dir: output_iac_classeval
tags:
- iac
- classeval
- code-completion
- turns_2
30 changes: 18 additions & 12 deletions configs/maac_classeval_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@ agent_model:
type: qwen
temperature: 0.6
top_p: 0.6
top_k: null
max_length: 2048
torch_dtype: bfloat16

agents: null

critic_model:
name: "Qwen/Qwen3-4B-Instruct-2507"
name: Qwen/Qwen3-4B-Instruct-2507
type: qwen
temperature: 0.6
top_p: 0.6
max_length: 2048
torch_dtype: bfloat16

Expand All @@ -25,17 +24,23 @@ dataset:
eval_split: test[66:82]

output:
base_dir: output
save_final_model: false
save_path: output/final_model
base_dir: output_maac_classeval
verbose: false
save_final_model: false
save_path: output_maac_classeval

external:
mode: code_feedback
original_prompt: true
previous_response: true
external_prompt_passthrough: false

maac:
parallel_training: none
agent_devices:
- cuda:0
critic_devices:
- cuda:0
num_agents: 2
num_turns: 2
critic_type: v
Expand All @@ -46,12 +51,9 @@ maac:
rollout_buffer_size: 2
train_batch_size: 2
max_new_tokens: 600
temperature: 0.6
top_p: 0.6
top_k: null
num_generations: 1
discount: 0.9
early_termination_threshold: -0.2
num_generations: 1
eval_interval: 20
eval_num_samples: 2
eval_batch_size: 1
Expand All @@ -66,5 +68,9 @@ wandb:
project: classeval_dev
entity: null
name: codecompletion_classeval_maac
dir: output
tags: ["maac", "classeval", "code-completion", "turns_2"]
dir: output_maac_classeval
tags:
- maac
- classeval
- code-completion
- turns_2
22 changes: 14 additions & 8 deletions configs/magrpo_classeval_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ agent_model:
type: qwen
temperature: 0.6
top_p: 0.6
top_k: null
max_length: 2048
torch_dtype: bfloat16

Expand All @@ -19,27 +20,28 @@ dataset:
eval_split: test[66:82]

output:
base_dir: output
save_final_model: false
save_path: output/final_model
base_dir: output_magrpo_classeval
verbose: false
save_final_model: false
save_path: output_magrpo_classeval

external:
mode: code_feedback
original_prompt: true
previous_response: true
external_prompt_passthrough: false

magrpo:
parallel_training: none
agent_devices:
- cuda:0
num_agents: 2
num_turns: 2
num_train_epochs: 13
agent_learning_rate: 1e-5
logging_steps: 1
num_generations: 2
max_new_tokens: 600
temperature: 0.6
top_p: 0.6
top_k: null
discount: 0.9
joint_mode: aligned
early_termination_threshold: -0.2
Expand All @@ -59,5 +61,9 @@ wandb:
project: classeval_dev
entity: null
name: codecompletion_classeval_magrpo
dir: output
tags: ["magrpo", "classeval", "code-completion", "turns_2"]
dir: output_magrpo_classeval
tags:
- magrpo
- classeval
- code-completion
- turns_2
84 changes: 73 additions & 11 deletions train/train_iac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, os.path.dirname(REPO_ROOT))
sys.path.insert(0, REPO_ROOT)
COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL")
if COMLRL_ROOT not in sys.path:
sys.path.insert(0, COMLRL_ROOT)

from datasets import load_dataset # type: ignore
from transformers import AutoTokenizer # type: ignore
Expand Down Expand Up @@ -177,6 +180,55 @@ def _as_bool(x: Any, default: bool) -> bool:
return bool(default)


def _as_device_spec(x: Any) -> Any:
    """Normalize a device setting read from the config.

    Args:
        x: Raw config value: ``None``, a device string, a list/tuple of
            device entries, or any other object.

    Returns:
        ``None`` when ``x`` is ``None`` or a string that, once stripped and
        lower-cased, is ``"none"``, ``"null"`` or empty; the stripped string
        for any other string; a list of stripped strings for a list/tuple;
        ``str(x)`` for everything else.
    """
    if x is None:
        return None
    if isinstance(x, str):
        s = x.strip()
        if s.lower() in ("none", "null", ""):
            return None
        return s
    if isinstance(x, (list, tuple)):
        # Strip each entry so list items get the same whitespace handling as
        # a scalar string (e.g. " cuda:0 " -> "cuda:0").
        return [str(v).strip() for v in x]
    return str(x)


def _read_sampling_config(
    model_cfg: Dict[str, Any], *, section: str = "agent_model"
) -> Dict[str, Any]:
    """Validate and extract the sampling settings from a model config section.

    Args:
        model_cfg: Mapping that must contain the keys ``temperature``,
            ``top_p`` and ``top_k``.
        section: Name of the config section, used only in error messages.

    Returns:
        A dict with ``temperature`` and ``top_p`` as floats and ``top_k`` as
        an int or ``None``.

    Raises:
        ValueError: If ``model_cfg`` is not a dict, a required key is absent,
            ``temperature``/``top_p`` is ``None``, a bool, or not convertible
            to float, or ``top_k`` is a bool or not an integral number.
    """
    if not isinstance(model_cfg, dict):
        raise ValueError(f"{section} must be a mapping.")
    missing = [key for key in ("temperature", "top_p", "top_k") if key not in model_cfg]
    if missing:
        raise ValueError(
            f"{section} is missing required sampling fields: {', '.join(missing)}."
        )

    def _require_float(key: str) -> float:
        value = model_cfg.get(key)
        # bool is a subclass of int; reject it so `true` is not read as 1.0.
        if value is None or isinstance(value, bool):
            raise ValueError(f"{section}.{key} must be provided as a float.")
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"{section}.{key} must be a float, got {value!r}.") from exc

    def _parse_top_k() -> Optional[int]:
        value = model_cfg.get("top_k")
        if value is None:
            return None
        if isinstance(value, str) and value.strip().lower() in ("none", "null", ""):
            return None
        message = f"{section}.top_k must be an integer or null, got {value!r}."
        # Same bool guard as the float fields: `true` must not become top_k=1.
        if isinstance(value, bool):
            raise ValueError(message)
        if isinstance(value, int):
            return value
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(message) from exc
        # is_integer() is False for nan/inf and for fractional values, so
        # inputs like 2.7 are rejected instead of being silently truncated.
        if not number.is_integer():
            raise ValueError(message)
        return int(number)

    return {
        "temperature": _require_float("temperature"),
        "top_p": _require_float("top_p"),
        "top_k": _parse_top_k(),
    }


def _map_dtype(x: Any) -> Any:
if isinstance(x, torch.dtype):
return x
Expand Down Expand Up @@ -205,41 +257,50 @@ def _filter_config(candidate: Dict[str, Any], cfg_cls: Any) -> Dict[str, Any]:
return {k: v for k, v in candidate.items() if k in params}


def _build_iac_args(cfg: Dict[str, Any], *, model_name: Optional[str]) -> IACConfig:
def _build_iac_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> IACConfig:
tr = cfg.get("iac") or {}
if not isinstance(tr, dict):
tr = {}
ext = cfg.get("external") or {}
if not isinstance(ext, dict):
ext = {}
use_separate_critic = _as_bool(tr.get("use_separate_critic", True), True)

adv_norm = tr.get("advantage_normalization", tr.get("normalize_advantage", True))

candidate = {
"num_turns": _as_int(tr.get("num_turns", 1), 1),
"num_turns": _as_int(tr.get("num_turns", 2), 2),
"num_train_epochs": _as_int(tr.get("num_train_epochs", 40), 40),
"agent_learning_rate": _as_float(tr.get("agent_learning_rate", 5e-6), 5e-6),
"critic_learning_rate": _as_opt_float(
tr.get("critic_learning_rate", 5e-6), 5e-6
),
"rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 8), 8),
"rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 2), 2),
"value_loss_coef": _as_float(tr.get("value_loss_coef", 0.6), 0.6),
"value_clip_range": _as_opt_float(tr.get("value_clip_range", 0.2), 0.2),
"advantage_normalization": _as_bool(adv_norm, True),
"max_new_tokens": _as_int(tr.get("max_new_tokens", 256), 256),
"temperature": _as_float(tr.get("temperature", 0.6), 0.6),
"top_p": _as_float(tr.get("top_p", 0.6), 0.6),
"top_k": _as_opt_int(tr.get("top_k", None), None),
"max_new_tokens": _as_int(tr.get("max_new_tokens", 600), 600),
"temperature": sampling_cfg["temperature"],
"top_p": sampling_cfg["top_p"],
"top_k": sampling_cfg["top_k"],
"num_agents": _as_int(tr.get("num_agents", 2), 2),
"num_generations": _as_int(tr.get("num_generations", 1), 1),
"use_separate_critic": use_separate_critic,
"parallel_training": str(tr.get("parallel_training", "none")).strip().lower(),
"agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])),
"critic_devices": _as_device_spec(tr.get("critic_devices", ["cuda:0"])),
"critic_value_head_hidden_dim": _as_opt_int(
tr.get("critic_value_head_hidden_dim", None), None
),
"value_head_hidden_dim": _as_opt_int(tr.get("value_head_hidden_dim", None), None),
"discount": _as_float(tr.get("discount", 0.9), 0.9),
"external_prompt_passthrough": _as_bool(
ext.get("external_prompt_passthrough", False), False
),
"early_termination_threshold": _as_opt_float(
tr.get("early_termination_threshold", None), None
tr.get("early_termination_threshold", -0.2), -0.2
),
"eval_interval": _as_int(tr.get("eval_interval", 16), 16),
"eval_interval": _as_int(tr.get("eval_interval", 20), 20),
"eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4),
"eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1),
"logging_steps": _as_int(tr.get("logging_steps", 1), 1),
Expand Down Expand Up @@ -333,7 +394,7 @@ def main() -> int:
if tmp_base:
os.environ["CLASSEVAL_TMP_BASE"] = str(tmp_base)

model_name = str(model_cfg.get("name", "Qwen/Qwen2.5-Coder-7B")).strip()
model_name = str(model_cfg.get("name", "Qwen/Qwen3-4B-Instruct-2507")).strip()
agent_names = cfg.get("agents")
model_kwargs: Dict[str, Any] = {}

Expand All @@ -349,7 +410,8 @@ def main() -> int:
if torch_dtype is not None:
model_kwargs["torch_dtype"] = torch_dtype

iac_args = _build_iac_args(cfg, model_name=model_name)
sampling_cfg = _read_sampling_config(model_cfg, section="agent_model")
iac_args = _build_iac_args(cfg, sampling_cfg=sampling_cfg)
num_agents = int(getattr(iac_args, "num_agents", 1))
if agent_names is not None:
if not isinstance(agent_names, (list, tuple)) or not all(
Expand Down
Loading