diff --git a/README.md b/README.md index a3ff149..eb859ec 100644 --- a/README.md +++ b/README.md @@ -45,9 +45,11 @@ Key sections in `configs/magrpo_classeval_config.yaml`: ClassEval sub-slices or local mirrors. - `external`: feedback configuration (use `code_feedback` for syntax/test diagnostics). - `magrpo`: forwarded to `comlrl.trainers.reinforce.MAGRPOTrainer`. Includes collaboration - (`num_agents`, param-count assignment), sampling settings (`num_generations`, `num_turns`, - temperature/top_p), rollout buffering (`rollout_buffer_size`), optimization + (`num_agents`, param-count assignment), rollout settings (`num_generations`, `num_turns`), + rollout buffering (`rollout_buffer_size`), optimization hyperparameters, and IO controls. +- Sampling knobs (`temperature`, `top_p`, `top_k`) are configured in `agent_model` and passed + to trainer args at runtime. - `reward_processor`: optional post-processing for rewards (scale, shift). - `output`: persistence knobs (save final model, output paths, verbose debug prints). diff --git a/configs/iac_classeval_config.yaml b/configs/iac_classeval_config.yaml index 9ed0e7a..01f7881 100644 --- a/configs/iac_classeval_config.yaml +++ b/configs/iac_classeval_config.yaml @@ -3,16 +3,15 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen3-4B-Instruct-2507" + name: Qwen/Qwen3-4B-Instruct-2507 type: qwen - temperature: 0.6 - top_p: 0.6 max_length: 2048 torch_dtype: bfloat16 @@ -25,19 +24,26 @@ dataset: eval_split: test[66:82] output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_iac_classeval verbose: false + save_final_model: false + save_path: output_iac_classeval external: mode: code_feedback original_prompt: true previous_response: true + external_prompt_passthrough: false iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 + use_separate_critic: true num_train_epochs: 40 agent_learning_rate: 5e-6 critic_learning_rate: 5e-6 @@ -46,13 +52,9 @@ iac: rollout_buffer_size: 2 train_batch_size: 2 max_new_tokens: 600 - temperature: 0.6 - top_p: 0.6 - top_k: null - num_generations: 1 - use_separate_critic: true discount: 0.9 early_termination_threshold: -0.2 + num_generations: 1 eval_interval: 20 eval_num_samples: 4 eval_batch_size: 1 @@ -67,5 +69,9 @@ wandb: project: classeval_dev entity: null name: codecompletion_classeval_iac - dir: output - tags: ["iac", "classeval", "code-completion", "turns_2"] + dir: output_iac_classeval + tags: + - iac + - classeval + - code-completion + - turns_2 diff --git a/configs/maac_classeval_config.yaml b/configs/maac_classeval_config.yaml index a57fa4a..d15ddbc 100644 --- a/configs/maac_classeval_config.yaml +++ b/configs/maac_classeval_config.yaml @@ -3,16 +3,15 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen3-4B-Instruct-2507" + name: Qwen/Qwen3-4B-Instruct-2507 type: qwen - temperature: 0.6 - top_p: 0.6 max_length: 2048 torch_dtype: bfloat16 @@ -25,17 +24,23 @@ dataset: eval_split: test[66:82] output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_maac_classeval verbose: false + save_final_model: false + save_path: output_maac_classeval external: mode: code_feedback original_prompt: true previous_response: true + external_prompt_passthrough: false maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 critic_type: v @@ -46,12 +51,9 @@ maac: rollout_buffer_size: 2 train_batch_size: 2 max_new_tokens: 600 - temperature: 0.6 - top_p: 0.6 - top_k: null - num_generations: 1 discount: 0.9 early_termination_threshold: -0.2 + num_generations: 1 eval_interval: 20 eval_num_samples: 2 eval_batch_size: 1 @@ -66,5 +68,9 @@ wandb: project: classeval_dev entity: null name: codecompletion_classeval_maac - dir: output - tags: ["maac", "classeval", "code-completion", "turns_2"] + dir: output_maac_classeval + tags: + - maac + - classeval + - code-completion + - turns_2 diff --git a/configs/magrpo_classeval_config.yaml b/configs/magrpo_classeval_config.yaml index 813d2b8..051d0bd 100644 --- a/configs/magrpo_classeval_config.yaml +++ b/configs/magrpo_classeval_config.yaml @@ -3,6 +3,7 @@ agent_model: type: qwen temperature: 0.6 top_p: 0.6 + top_k: null max_length: 2048 torch_dtype: bfloat16 @@ -19,17 +20,21 @@ dataset: eval_split: test[66:82] output: - base_dir: output - save_final_model: false - save_path: output/final_model + base_dir: output_magrpo_classeval verbose: false + save_final_model: false + save_path: output_magrpo_classeval external: mode: code_feedback original_prompt: true previous_response: true + external_prompt_passthrough: false magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_agents: 2 num_turns: 2 num_train_epochs: 13 @@ -37,9 +42,6 @@ magrpo: logging_steps: 1 num_generations: 2 max_new_tokens: 600 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 joint_mode: aligned early_termination_threshold: -0.2 @@ -59,5 +61,9 @@ wandb: project: classeval_dev entity: null name: codecompletion_classeval_magrpo - dir: output - tags: ["magrpo", "classeval", "code-completion", "turns_2"] + dir: output_magrpo_classeval + tags: + - magrpo + - classeval + - code-completion + - turns_2 diff --git a/train/train_iac.py b/train/train_iac.py index eade3a1..b58ec90 100644 --- a/train/train_iac.py +++ b/train/train_iac.py @@ -21,6 +21,9 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) sys.path.insert(0, REPO_ROOT) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import load_dataset # type: ignore from transformers import AutoTokenizer # type: ignore @@ -177,6 +180,55 @@ def _as_bool(x: Any, default: bool) -> bool: return bool(default) +def _as_device_spec(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + if s.lower() in ("none", "null", ""): + return None + return s + if isinstance(x, (list, tuple)): + return [str(v) for v in x] + return str(x) + + +def _read_sampling_config(model_cfg: Dict[str, Any], *, section: str = "agent_model") -> Dict[str, Any]: + if not isinstance(model_cfg, dict): + raise ValueError(f"{section} must be a mapping.") + missing = [key for key in ("temperature", "top_p", "top_k") if key not in model_cfg] + if missing: + raise ValueError( + f"{section} is missing required sampling fields: {', '.join(missing)}." + ) + + def _require_float(key: str) -> float: + value = model_cfg.get(key) + if value is None or isinstance(value, bool): + raise ValueError(f"{section}.{key} must be provided as a float.") + try: + return float(value) + except Exception as exc: + raise ValueError(f"{section}.{key} must be a float, got {value!r}.") from exc + + def _parse_top_k() -> Optional[int]: + value = model_cfg.get("top_k") + if value is None: + return None + if isinstance(value, str) and value.strip().lower() in ("none", "null", ""): + return None + try: + return int(float(value)) + except Exception as exc: + raise ValueError(f"{section}.top_k must be an integer or null, got {value!r}.") from exc + + return { + "temperature": _require_float("temperature"), + "top_p": _require_float("top_p"), + "top_k": _parse_top_k(), + } + + def _map_dtype(x: Any) -> Any: if isinstance(x, torch.dtype): return x @@ -205,41 +257,50 @@ def _filter_config(candidate: Dict[str, Any], cfg_cls: Any) -> Dict[str, Any]: return {k: v for k, v in candidate.items() if k in params} -def _build_iac_args(cfg: Dict[str, Any], *, model_name: Optional[str]) -> IACConfig: +def _build_iac_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> IACConfig: tr = cfg.get("iac") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} use_separate_critic = _as_bool(tr.get("use_separate_critic", True), True) adv_norm = tr.get("advantage_normalization", tr.get("normalize_advantage", True)) candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), + "num_turns": _as_int(tr.get("num_turns", 2), 2), "num_train_epochs": _as_int(tr.get("num_train_epochs", 40), 40), "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 5e-6), 5e-6), "critic_learning_rate": _as_opt_float( tr.get("critic_learning_rate", 5e-6), 5e-6 ), - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 8), 8), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 2), 2), "value_loss_coef": _as_float(tr.get("value_loss_coef", 0.6), 0.6), "value_clip_range": _as_opt_float(tr.get("value_clip_range", 0.2), 0.2), "advantage_normalization": _as_bool(adv_norm, True), - "max_new_tokens": _as_int(tr.get("max_new_tokens", 256), 256), - "temperature": _as_float(tr.get("temperature", 0.6), 0.6), - "top_p": _as_float(tr.get("top_p", 0.6), 0.6), - "top_k": _as_opt_int(tr.get("top_k", None), None), + "max_new_tokens": _as_int(tr.get("max_new_tokens", 600), 600), + "temperature": sampling_cfg["temperature"], + "top_p": sampling_cfg["top_p"], + "top_k": sampling_cfg["top_k"], "num_agents": _as_int(tr.get("num_agents", 2), 2), "num_generations": _as_int(tr.get("num_generations", 1), 1), "use_separate_critic": use_separate_critic, + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), + "critic_devices": _as_device_spec(tr.get("critic_devices", ["cuda:0"])), "critic_value_head_hidden_dim": _as_opt_int( tr.get("critic_value_head_hidden_dim", None), None ), "value_head_hidden_dim": _as_opt_int(tr.get("value_head_hidden_dim", None), None), "discount": _as_float(tr.get("discount", 0.9), 0.9), + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), "early_termination_threshold": _as_opt_float( - tr.get("early_termination_threshold", None), None + tr.get("early_termination_threshold", -0.2), -0.2 ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), + "eval_interval": _as_int(tr.get("eval_interval", 20), 20), "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), "logging_steps": _as_int(tr.get("logging_steps", 1), 1), @@ -333,7 +394,7 @@ def main() -> int: if tmp_base: os.environ["CLASSEVAL_TMP_BASE"] = str(tmp_base) - model_name = str(model_cfg.get("name", "Qwen/Qwen2.5-Coder-7B")).strip() + model_name = str(model_cfg.get("name", "Qwen/Qwen3-4B-Instruct-2507")).strip() agent_names = cfg.get("agents") model_kwargs: Dict[str, Any] = {} @@ -349,7 +410,8 @@ def main() -> int: if torch_dtype is not None: model_kwargs["torch_dtype"] = torch_dtype - iac_args = _build_iac_args(cfg, model_name=model_name) + sampling_cfg = _read_sampling_config(model_cfg, section="agent_model") + iac_args = _build_iac_args(cfg, sampling_cfg=sampling_cfg) num_agents = int(getattr(iac_args, "num_agents", 1)) if agent_names is not None: if not isinstance(agent_names, (list, tuple)) or not all( diff --git a/train/train_maac.py b/train/train_maac.py index 4ca6fa9..5dc429c 100644 --- a/train/train_maac.py +++ b/train/train_maac.py @@ -21,6 +21,9 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) sys.path.insert(0, REPO_ROOT) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import load_dataset # type: ignore from transformers import AutoTokenizer # type: ignore @@ -177,6 +180,55 @@ def _as_bool(x: Any, default: bool) -> bool: return bool(default) +def _as_device_spec(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + if s.lower() in ("none", "null", ""): + return None + return s + if isinstance(x, (list, tuple)): + return [str(v) for v in x] + return str(x) + + +def _read_sampling_config(model_cfg: Dict[str, Any], *, section: str = "agent_model") -> Dict[str, Any]: + if not isinstance(model_cfg, dict): + raise ValueError(f"{section} must be a mapping.") + missing = [key for key in ("temperature", "top_p", "top_k") if key not in model_cfg] + if missing: + raise ValueError( + f"{section} is missing required sampling fields: {', '.join(missing)}." + ) + + def _require_float(key: str) -> float: + value = model_cfg.get(key) + if value is None or isinstance(value, bool): + raise ValueError(f"{section}.{key} must be provided as a float.") + try: + return float(value) + except Exception as exc: + raise ValueError(f"{section}.{key} must be a float, got {value!r}.") from exc + + def _parse_top_k() -> Optional[int]: + value = model_cfg.get("top_k") + if value is None: + return None + if isinstance(value, str) and value.strip().lower() in ("none", "null", ""): + return None + try: + return int(float(value)) + except Exception as exc: + raise ValueError(f"{section}.top_k must be an integer or null, got {value!r}.") from exc + + return { + "temperature": _require_float("temperature"), + "top_p": _require_float("top_p"), + "top_k": _parse_top_k(), + } + + def _map_dtype(x: Any) -> Any: if isinstance(x, torch.dtype): return x @@ -205,10 +257,13 @@ def _filter_config(candidate: Dict[str, Any], cfg_cls: Any) -> Dict[str, Any]: return {k: v for k, v in candidate.items() if k in params} -def _build_maac_args(cfg: Dict[str, Any], *, model_name: Optional[str]) -> MAACConfig: +def _build_maac_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> MAACConfig: tr = cfg.get("maac") or {} if not isinstance(tr, dict): tr = {} + ext = cfg.get("external") or {} + if not isinstance(ext, dict): + ext = {} output_cfg = cfg.get("output", {}) or {} adv_norm = tr.get("advantage_normalization", tr.get("normalize_advantage", True)) @@ -217,28 +272,34 @@ def _build_maac_args(cfg: Dict[str, Any], *, model_name: Optional[str]) -> MAACC critic_type = str(critic_type) candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), + "num_turns": _as_int(tr.get("num_turns", 2), 2), "num_train_epochs": _as_int(tr.get("num_train_epochs", 40), 40), "agent_learning_rate": _as_float(tr.get("agent_learning_rate", 5e-6), 5e-6), "critic_learning_rate": _as_float( tr.get("critic_learning_rate", 5e-6), 5e-6 ), - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 8), 8), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 2), 2), "value_loss_coef": _as_float(tr.get("value_loss_coef", 0.6), 0.6), "advantage_normalization": _as_bool(adv_norm, True), - "max_new_tokens": _as_int(tr.get("max_new_tokens", 256), 256), - "temperature": _as_float(tr.get("temperature", 0.6), 0.6), - "top_p": _as_float(tr.get("top_p", 0.6), 0.6), - "top_k": _as_opt_int(tr.get("top_k", None), None), + "max_new_tokens": _as_int(tr.get("max_new_tokens", 600), 600), + "temperature": sampling_cfg["temperature"], + "top_p": sampling_cfg["top_p"], + "top_k": sampling_cfg["top_k"], "num_agents": _as_int(tr.get("num_agents", 2), 2), "num_generations": _as_int(tr.get("num_generations", 1), 1), "discount": _as_float(tr.get("discount", 0.9), 0.9), + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), + "critic_devices": _as_device_spec(tr.get("critic_devices", ["cuda:0"])), "critic_type": critic_type, + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), "early_termination_threshold": _as_opt_float( - tr.get("early_termination_threshold", None), None + tr.get("early_termination_threshold", -0.2), -0.2 ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), - "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), + "eval_interval": _as_int(tr.get("eval_interval", 20), 20), + "eval_num_samples": _as_int(tr.get("eval_num_samples", 2), 2), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), "logging_steps": _as_int(tr.get("logging_steps", 1), 1), } @@ -333,7 +394,7 @@ def main() -> int: if tmp_base: os.environ["CLASSEVAL_TMP_BASE"] = str(tmp_base) - model_name = str(model_cfg.get("name", "Qwen/Qwen2.5-Coder-7B")).strip() + model_name = str(model_cfg.get("name", "Qwen/Qwen3-4B-Instruct-2507")).strip() agent_names = cfg.get("agents") model_kwargs: Dict[str, Any] = {} @@ -349,7 +410,8 @@ def main() -> int: if torch_dtype is not None: model_kwargs["torch_dtype"] = torch_dtype - maac_args = _build_maac_args(cfg, model_name=model_name) + sampling_cfg = _read_sampling_config(model_cfg, section="agent_model") + maac_args = _build_maac_args(cfg, sampling_cfg=sampling_cfg) num_agents = int(getattr(maac_args, "num_agents", 1)) if agent_names is not None: if not isinstance(agent_names, (list, tuple)) or not all( diff --git a/train/train_magrpo.py b/train/train_magrpo.py index ec497b3..1b8e10b 100644 --- a/train/train_magrpo.py +++ b/train/train_magrpo.py @@ -18,13 +18,19 @@ REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(REPO_ROOT)) sys.path.insert(0, REPO_ROOT) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) from datasets import load_dataset # type: ignore -from transformers import AutoModelForCausalLM, AutoTokenizer # type: ignore +from transformers import AutoTokenizer # type: ignore import torch # type: ignore from comlrl.trainers.reinforce import MAGRPOTrainer # type: ignore -from LLM_Collab_Code_Completion.utils.trainer_args import get_trainer_args +from LLM_Collab_Code_Completion.utils.trainer_args import ( + get_trainer_args, + get_agent_sampling_config, +) from LLM_Collab_Code_Completion.utils.data import ( extract_class_name, @@ -151,7 +157,7 @@ def main(): if isinstance(eval_split, str): eval_split = eval_split.strip() or None - num_agents = int(magrpo_cfg.get("num_agents", 1)) + num_agents = int(magrpo_cfg.get("num_agents", 2)) if not eval_split: print("dataset.eval_split is required.") @@ -188,7 +194,7 @@ def main(): tmp_base = None if tmp_base: os.environ["CLASSEVAL_TMP_BASE"] = str(tmp_base) - model_name = model_cfg.get("name", "Qwen/Qwen2.5-3B") + model_name = model_cfg.get("name", "Qwen/Qwen3-4B-Instruct-2507") agent_names = cfg.get("agents") if agent_names is not None: if not isinstance(agent_names, (list, tuple)) or not all( @@ -196,8 +202,6 @@ def main(): ): raise ValueError("agents must be a list of model names.") agent_names = [str(x) for x in agent_names] - model_kwargs: Dict[str, Any] = {} - dtype_cfg = ( model_cfg.get("dtype") or model_cfg.get("torch_dtype") @@ -227,9 +231,6 @@ def _map_dtype(x): except Exception: torch_dtype = None - if torch_dtype is not None: - model_kwargs["torch_dtype"] = torch_dtype - tokenizer_source = agent_names[0] if agent_names else model_name if not tokenizer_source: raise ValueError("agent_model.name or agents must be provided.") @@ -242,19 +243,10 @@ def _map_dtype(x): tok.pad_token = tok.eos_token tokenizer = tokenizers[0] - agents = [] - if agent_names: - for name in agent_names: - agent = AutoModelForCausalLM.from_pretrained(name, **model_kwargs) - agents.append(agent) - else: - for _ in range(num_agents): - agent = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs) - agents.append(agent) - strategy = get_strategy(num_agents=num_agents, seed=seed) - magrpo_args = get_trainer_args(cfg) + sampling_cfg = get_agent_sampling_config(cfg) + magrpo_args = get_trainer_args(cfg, sampling_cfg=sampling_cfg) formatters = build_agent_formatters(strategy) reward_func = get_reward_function(strategy=strategy, num_agents=num_agents) @@ -335,8 +327,12 @@ def _map_dtype(x): trainer_kwargs = { "agent_model": model_name or None, - "agents": agents, + "agents": agent_names, "num_agents": num_agents, + "model_config": { + "torch_dtype": torch_dtype, + "special_tokens": model_cfg.get("special_tokens", {}), + }, "reward_func": reward_func, "formatters": formatters, "args": magrpo_args, diff --git a/utils/trainer_args.py b/utils/trainer_args.py index 51294ce..48ea230 100644 --- a/utils/trainer_args.py +++ b/utils/trainer_args.py @@ -94,44 +94,99 @@ def _as_bool(x: Any, default: bool) -> bool: return bool(x) -def get_trainer_args(cfg: Dict[str, Any]) -> MAGRPOConfig: +def _as_device_spec(x: Any) -> Any: + if x is None: + return None + if isinstance(x, str): + s = x.strip() + if s.lower() in ("none", "null", ""): + return None + return s + if isinstance(x, (list, tuple)): + return [str(v) for v in x] + return str(x) + + +def get_agent_sampling_config(cfg: Dict[str, Any]) -> Dict[str, Any]: + model_cfg = cfg.get("agent_model") + if not isinstance(model_cfg, dict): + raise ValueError("agent_model must be a mapping.") + missing = [key for key in ("temperature", "top_p", "top_k") if key not in model_cfg] + if missing: + raise ValueError( + f"agent_model is missing required sampling fields: {', '.join(missing)}" + ) + + def _require_float(key: str) -> float: + value = model_cfg.get(key) + if value is None or isinstance(value, bool): + raise ValueError(f"agent_model.{key} must be provided as a float.") + try: + return float(value) + except Exception as exc: + raise ValueError(f"agent_model.{key} must be a float, got {value!r}.") from exc + + top_k_raw = model_cfg.get("top_k") + if isinstance(top_k_raw, str) and top_k_raw.strip().lower() in ("none", "null", ""): + top_k_val: Optional[int] = None + elif top_k_raw is None: + top_k_val = None + else: + try: + top_k_val = int(float(top_k_raw)) + except Exception as exc: + raise ValueError( + f"agent_model.top_k must be an integer or null, got {top_k_raw!r}." + ) from exc + + return { + "temperature": _require_float("temperature"), + "top_p": _require_float("top_p"), + "top_k": top_k_val, + } + + +def get_trainer_args(cfg: Dict[str, Any], *, sampling_cfg: Dict[str, Any]) -> MAGRPOConfig: tr = cfg.get("magrpo", {}) - lr_val = tr.get("agent_learning_rate", 3e-5) + ext = cfg.get("external", {}) + lr_val = tr.get("agent_learning_rate", 1e-5) candidate = { - "num_turns": _as_int(tr.get("num_turns", 1), 1), - "num_train_epochs": _as_int(tr.get("num_train_epochs", 3), 3), - "agent_learning_rate": _as_float(lr_val, 3e-5), - "logging_steps": _as_int(tr.get("logging_steps", 50), 50), - "num_generations": _as_int(tr.get("num_generations", 4), 4), - "max_new_tokens": _as_int(tr.get("max_new_tokens", 512), 512), - "temperature": _as_float(tr.get("temperature", 0.2), 0.2), - "top_p": _as_float(tr.get("top_p", 0.95), 0.95), + "num_turns": _as_int(tr.get("num_turns", 2), 2), + "num_train_epochs": _as_int(tr.get("num_train_epochs", 13), 13), + "agent_learning_rate": _as_float(lr_val, 1e-5), + "logging_steps": _as_int(tr.get("logging_steps", 1), 1), + "num_generations": _as_int(tr.get("num_generations", 2), 2), + "max_new_tokens": _as_int(tr.get("max_new_tokens", 600), 600), + "temperature": _as_float(sampling_cfg.get("temperature"), 0.6), + "top_p": _as_float(sampling_cfg.get("top_p"), 0.6), + "top_k": _as_opt_int(sampling_cfg.get("top_k"), None), } - if "top_k" in tr: - candidate["top_k"] = _as_opt_int(tr.get("top_k", None), None) candidate.update( { - "num_agents": _as_int(tr.get("num_agents", 1), 1), + "num_agents": _as_int(tr.get("num_agents", 2), 2), + "parallel_training": str(tr.get("parallel_training", "none")).strip().lower(), + "agent_devices": _as_device_spec(tr.get("agent_devices", ["cuda:0"])), "discount": _as_float(tr.get("discount", 0.9), 0.9), "joint_mode": str(tr.get("joint_mode", "aligned")), + "early_termination_threshold": _as_opt_float( + tr.get("early_termination_threshold", -0.2), -0.2 + ), } ) - if "early_termination_threshold" in tr: - candidate["early_termination_threshold"] = _as_opt_float( - tr.get("early_termination_threshold", None), None - ) candidate.update( { - "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 2), 2), - "train_batch_size": _as_opt_int(tr.get("train_batch_size", None), None), + "rollout_buffer_size": _as_int(tr.get("rollout_buffer_size", 1), 1), + "train_batch_size": _as_opt_int(tr.get("train_batch_size", 1), 1), "advantage_normalization": _as_bool( tr.get("advantage_normalization", True), True ), - "eval_interval": _as_int(tr.get("eval_interval", 16), 16), + "eval_interval": _as_int(tr.get("eval_interval", 10), 10), "eval_num_samples": _as_int(tr.get("eval_num_samples", 4), 4), "eval_batch_size": _as_int(tr.get("eval_batch_size", 1), 1), - "external_prompt_passthrough": True, + "external_prompt_passthrough": _as_bool( + ext.get("external_prompt_passthrough", False), False + ), } )