diff --git a/config.py b/config.py index 8a65304..ebfb5a5 100644 --- a/config.py +++ b/config.py @@ -4,6 +4,8 @@ """ import argparse +import os +import sys from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Optional @@ -11,26 +13,74 @@ import yaml +REPO_ROOT = os.path.dirname(os.path.abspath(__file__)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) + + @dataclass(frozen=True) class ModelConfig: """Configuration for model loading and generation.""" name: str type: str = "qwen" - temperature: float = 0.7 - top_p: float = 0.9 + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None max_length: int = 2048 special_tokens: Dict[str, str] = field(default_factory=dict) torch_dtype: Optional[str] = None @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> "ModelConfig": + def from_dict( + cls, + config_dict: Dict[str, Any], + *, + require_sampling: bool = True, + ) -> "ModelConfig": """Create ModelConfig from dictionary.""" + if require_sampling: + missing = [ + key + for key in ("temperature", "top_p", "top_k") + if key not in config_dict + ] + if missing: + raise ValueError( + f"agent_model is missing required sampling fields: {', '.join(missing)}" + ) + + def _as_optional_float(value: Any) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid float value: {value}") from exc + + def _as_optional_int(value: Any) -> Optional[int]: + if value is None: + return None + if isinstance(value, str) and value.strip().lower() in ("none", "null", ""): + return None + try: + return int(float(value)) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid int value: {value}") from exc + + temperature = _as_optional_float(config_dict.get("temperature")) + top_p = _as_optional_float(config_dict.get("top_p")) + top_k = _as_optional_int(config_dict.get("top_k")) + if require_sampling and (temperature is None or top_p is None): + raise ValueError("agent_model.temperature and agent_model.top_p must be non-null.") + return cls( name=config_dict.get("name", ""), type=config_dict.get("type", "qwen"), - temperature=config_dict.get("temperature", 0.7), - top_p=config_dict.get("top_p", 0.9), + temperature=temperature, + top_p=top_p, + top_k=top_k, max_length=config_dict.get("max_length", 2048), special_tokens=config_dict.get("special_tokens", {}), torch_dtype=( @@ -74,7 +124,7 @@ def get_agent_model_config(self) -> ModelConfig: model_section = self.get_section("agent_model") if not model_section: raise ValueError("No 'agent_model' section found in configuration") - return ModelConfig.from_dict(model_section) + return ModelConfig.from_dict(model_section, require_sampling=True) def get_critic_model_config(self, required: bool = True) -> Optional[ModelConfig]: """Get critic model configuration as ModelConfig object.""" @@ -83,7 +133,7 @@ def get_critic_model_config(self, required: bool = True) -> Optional[ModelConfig if required: raise ValueError("No 'critic_model' section found in configuration") return None - return ModelConfig.from_dict(critic_section) + return ModelConfig.from_dict(critic_section, require_sampling=False) def update(self, updates: Dict[str, Any]): """Update configuration with new values (deep merge).""" diff --git a/configs/ac_arxiv_config.yaml b/configs/ac_arxiv_config.yaml index 948242f..e15a060 100644 --- 
a/configs/ac_arxiv_config.yaml +++ b/configs/ac_arxiv_config.yaml @@ -1,50 +1,47 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null critic_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen3-1.7B + type: qwen max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto critics: null dataset: - name: "OpenMLRL/arXiv_abstract" - type: "arxiv" - train_split: "train[:1000]" - eval_split: "val[:1000]" + name: OpenMLRL/arXiv_abstract + type: arxiv + train_split: train[:1000] + eval_split: val[:1000] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./ac_output" + base_dir: output_ac_arxiv verbose: false save_final_model: true - save_path: "./ac_output/ac_arxiv" + save_path: output_ac_arxiv ac: + parallel_training: none num_turns: 1 num_train_epochs: 4 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 - advantage_normalization: true rollout_buffer_size: 4 max_new_tokens: 512 - temperature: 0.7 - top_p: 0.9 - top_k: null + advantage_normalization: true eval_interval: 20 eval_num_samples: 4 eval_batch_size: 1 @@ -56,7 +53,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "ac_arxiv" - tags: ["ac", "arxiv", "single-agent"] + project: comlrl + entity: OpenMLRL + name: ac_arxiv + tags: + - ac + - arxiv + - single-agent diff --git a/configs/ac_tldr_config.yaml b/configs/ac_tldr_config.yaml index 855ce83..8c26b76 100644 --- a/configs/ac_tldr_config.yaml +++ b/configs/ac_tldr_config.yaml @@ -1,50 +1,47 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null critic_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen3-1.7B + type: qwen max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto critics: null dataset: - name: "trl-lib/tldr" - type: "tldr" - train_split: "train[:1000]" - eval_split: "test[:1000]" + name: trl-lib/tldr + type: tldr + train_split: train[:1000] + eval_split: test[:1000] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./ac_output" + base_dir: output_ac_tldr verbose: false save_final_model: true - save_path: "./ac_output/ac_tldr" + save_path: output_ac_tldr ac: + parallel_training: none num_turns: 1 num_train_epochs: 4 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 - advantage_normalization: true rollout_buffer_size: 4 max_new_tokens: 256 - temperature: 0.7 - top_p: 0.9 - top_k: null + advantage_normalization: true eval_interval: 20 eval_num_samples: 4 eval_batch_size: 1 @@ -56,7 +53,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "ac_tldr" - tags: ["ac", "tldr", "single-agent"] + project: comlrl + entity: OpenMLRL + name: ac_tldr + tags: + - ac + - tldr + - single-agent diff --git a/configs/grpo_arxiv_config.yaml b/configs/grpo_arxiv_config.yaml index ef37939..e887935 100644 --- a/configs/grpo_arxiv_config.yaml +++ b/configs/grpo_arxiv_config.yaml @@ -1,10 +1,11 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: 
Qwen/Qwen3-1.7B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null @@ -13,29 +14,29 @@ critic_model: null critics: null dataset: - name: "OpenMLRL/arXiv_abstract" - type: "arxiv" - train_split: "train[:1000]" - eval_split: "val[:1000]" + name: OpenMLRL/arXiv_abstract + type: arxiv + train_split: train[:1000] + eval_split: val[:1000] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./grpo_output" + base_dir: output_grpo_arxiv verbose: false save_final_model: true - save_path: "./grpo_output/arxiv_single" + save_path: output_grpo_arxiv grpo: + parallel_training: none num_turns: 1 num_train_epochs: 1 - agent_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 logging_steps: 400 num_generations: 4 - joint_mode: aligned max_new_tokens: 512 - top_k: null + joint_mode: aligned rollout_buffer_size: 2 train_batch_size: 2 advantage_normalization: true @@ -49,7 +50,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "grpo_arxiv" - tags: ["grpo", "arxiv", "single-agent"] + project: comlrl + entity: OpenMLRL + name: grpo_arxiv + tags: + - grpo + - arxiv + - single-agent diff --git a/configs/grpo_tldr_config.yaml b/configs/grpo_tldr_config.yaml index f19e4a1..f1b0be9 100644 --- a/configs/grpo_tldr_config.yaml +++ b/configs/grpo_tldr_config.yaml @@ -1,10 +1,11 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen3-1.7B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null @@ -13,29 +14,29 @@ critic_model: null critics: null dataset: - name: "trl-lib/tldr" - type: "tldr" - train_split: "train[:1000]" - eval_split: "test[:1000]" + name: trl-lib/tldr + type: tldr + train_split: train[:1000] + eval_split: test[:1000] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./grpo_output" + base_dir: output_grpo_tldr verbose: false save_final_model: true - save_path: "./grpo_output/tldr_single" + save_path: output_grpo_tldr grpo: + parallel_training: none num_turns: 1 num_train_epochs: 1 - agent_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 logging_steps: 400 num_generations: 4 - joint_mode: aligned max_new_tokens: 256 - top_k: null + joint_mode: aligned rollout_buffer_size: 2 train_batch_size: 2 advantage_normalization: true @@ -49,7 +50,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "grpo_tldr" - tags: ["grpo", "tldr", "single-agent"] + project: comlrl + entity: OpenMLRL + name: grpo_tldr + tags: + - grpo + - tldr + - single-agent diff --git a/configs/iac_arxiv_config.yaml b/configs/iac_arxiv_config.yaml index 5548117..5c006be 100644 --- a/configs/iac_arxiv_config.yaml +++ b/configs/iac_arxiv_config.yaml @@ -1,53 +1,54 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null critic_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen3-1.7B + type: qwen max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto critics: null dataset: - name: "OpenMLRL/arXiv_abstract" - type: "arxiv" - train_split: "train[:1000]" - eval_split: "val[:1000]" + name: OpenMLRL/arXiv_abstract + type: arxiv + train_split: train[:1000] + eval_split: val[:1000] tokenizer: - padding_side: "left" + 
padding_side: left output: - base_dir: "./iac_output" + base_dir: output_iac_arxiv verbose: false save_final_model: true - save_path: "./iac_output/iac_arxiv" + save_path: output_iac_arxiv iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 1 use_separate_critic: true num_train_epochs: 20 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 value_clip_range: 0.2 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.7 - top_p: 0.9 - top_k: null eval_interval: 20 eval_num_samples: 4 eval_batch_size: 1 @@ -59,7 +60,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "iac_arxiv" - tags: ["iac", "arxiv", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: iac_arxiv + tags: + - iac + - arxiv + - multi-agent diff --git a/configs/iac_tldr_config.yaml b/configs/iac_tldr_config.yaml index 823bb3c..9cce1f0 100644 --- a/configs/iac_tldr_config.yaml +++ b/configs/iac_tldr_config.yaml @@ -1,53 +1,54 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null critic_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen3-1.7B + type: qwen max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto critics: null dataset: - name: "trl-lib/tldr" - type: "tldr" - train_split: "train[:1000]" - eval_split: "test[:1000]" + name: trl-lib/tldr + type: tldr + train_split: train[:1000] + eval_split: test[:1000] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./iac_output" + base_dir: output_iac_tldr verbose: false save_final_model: true - save_path: "./iac_output/iac_tldr" + save_path: output_iac_tldr iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 1 use_separate_critic: true num_train_epochs: 20 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 value_clip_range: 0.2 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.7 - top_p: 0.9 - top_k: null eval_interval: 20 eval_num_samples: 4 eval_batch_size: 1 @@ -59,7 +60,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "iac_tldr" - tags: ["iac", "tldr", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: iac_tldr + tags: + - iac + - tldr + - multi-agent diff --git a/configs/maac_arxiv_config.yaml b/configs/maac_arxiv_config.yaml index a6d74ae..48aaab2 100644 --- a/configs/maac_arxiv_config.yaml +++ b/configs/maac_arxiv_config.yaml @@ -1,52 +1,53 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null critic_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen3-1.7B + type: qwen max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto critics: null dataset: - name: "OpenMLRL/arXiv_abstract" - type: "arxiv" - train_split: "train[:1100]" - eval_split: "val[:1100]" + name: OpenMLRL/arXiv_abstract + type: arxiv + train_split: train[:1100] + eval_split: val[:1100] tokenizer: - padding_side: "left" + 
padding_side: left output: - base_dir: "./maac_output" + base_dir: output_maac_arxiv verbose: false save_final_model: true - save_path: "./maac_output/maac_arxiv" + save_path: output_maac_arxiv maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 1 - critic_type: "v" + critic_type: v num_train_epochs: 20 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.7 - top_p: 0.9 - top_k: null eval_interval: 20 eval_num_samples: 4 eval_batch_size: 1 @@ -58,7 +59,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "maac_arxiv" - tags: ["maac", "arxiv", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: maac_arxiv + tags: + - maac + - arxiv + - multi-agent diff --git a/configs/maac_tldr_config.yaml b/configs/maac_tldr_config.yaml index 39f9927..c2c49d3 100644 --- a/configs/maac_tldr_config.yaml +++ b/configs/maac_tldr_config.yaml @@ -1,52 +1,53 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null critic_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen3-1.7B + type: qwen max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto critics: null dataset: - name: "trl-lib/tldr" - type: "tldr" - train_split: "train[:1100]" - eval_split: "test[:1100]" + name: trl-lib/tldr + type: tldr + train_split: train[:1100] + eval_split: test[:1100] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./maac_output" + base_dir: output_maac_tldr verbose: false save_final_model: true - save_path: "./maac_output/maac_tldr" + save_path: output_maac_tldr maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 1 - critic_type: "v" + critic_type: v num_train_epochs: 20 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.7 - top_p: 0.9 - top_k: null eval_interval: 20 eval_num_samples: 4 eval_batch_size: 1 @@ -58,7 +59,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "maac_tldr" - tags: ["maac", "tldr", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: maac_tldr + tags: + - maac + - tldr + - multi-agent diff --git a/configs/magrpo_arxiv_config.yaml b/configs/magrpo_arxiv_config.yaml index 7648710..2f76c1b 100644 --- a/configs/magrpo_arxiv_config.yaml +++ b/configs/magrpo_arxiv_config.yaml @@ -1,10 +1,11 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null @@ -13,30 +14,30 @@ critic_model: null critics: null dataset: - name: "OpenMLRL/arXiv_abstract" - type: "arxiv" - train_split: "train[:1100]" - eval_split: "val[:1100]" + name: OpenMLRL/arXiv_abstract + type: arxiv + train_split: train[:1100] + eval_split: val[:1100] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./magrpo_output" + base_dir: output_magrpo_arxiv verbose: false save_final_model: true - save_path: 
"./magrpo_output/arxiv" + save_path: output_magrpo_arxiv magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_turns: 1 num_train_epochs: 2 - agent_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.7 - top_p: 0.9 - top_k: null joint_mode: aligned rollout_buffer_size: 1 train_batch_size: 1 @@ -51,7 +52,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "magrpo_arxiv" - tags: ["magrpo", "arxiv", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: magrpo_arxiv + tags: + - magrpo + - arxiv + - multi-agent diff --git a/configs/magrpo_tldr_config.yaml b/configs/magrpo_tldr_config.yaml index 96f81a9..f52672b 100644 --- a/configs/magrpo_tldr_config.yaml +++ b/configs/magrpo_tldr_config.yaml @@ -1,10 +1,11 @@ agent_model: - name: "Qwen/Qwen3-1.7B" - type: "qwen" + name: Qwen/Qwen3-1.7B + type: qwen temperature: 0.7 top_p: 0.9 + top_k: null max_length: 2048 - torch_dtype: "auto" + torch_dtype: auto agents: null @@ -13,30 +14,30 @@ critic_model: null critics: null dataset: - name: "trl-lib/tldr" - type: "tldr" - train_split: "train[:1100]" - eval_split: "test[:1100]" + name: trl-lib/tldr + type: tldr + train_split: train[:1100] + eval_split: test[:1100] tokenizer: - padding_side: "left" + padding_side: left output: - base_dir: "./magrpo_output" + base_dir: output_magrpo_tldr verbose: false save_final_model: true - save_path: "./magrpo_output/tldr" + save_path: output_magrpo_tldr magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_turns: 1 num_train_epochs: 2 - agent_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.7 - top_p: 0.9 - top_k: null joint_mode: aligned rollout_buffer_size: 1 train_batch_size: 1 @@ -51,7 +52,10 @@ reward_processor: shift: 0.0 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "magrpo_tldr" - tags: ["magrpo", "tldr", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: magrpo_tldr + tags: + - magrpo + - tldr + - multi-agent diff --git a/train_ac.py b/train_ac.py index 80139db..6488f8d 100644 --- a/train_ac.py +++ b/train_ac.py @@ -222,9 +222,9 @@ def main() -> None: f"Single-agent AC expects num_agents=1; received num_agents={num_agents}." ) - temperature = ac_cfg.get("temperature", model_config.temperature) - top_p = ac_cfg.get("top_p", model_config.top_p) - top_k = ac_cfg.get("top_k") + temperature = model_config.temperature + top_p = model_config.top_p + top_k = model_config.top_k use_separate_critic = bool(ac_cfg.get("use_separate_critic", True)) model_kwargs: Dict[str, Any] = {} if model_config.torch_dtype is not None: @@ -291,6 +291,9 @@ def main() -> None: num_agents=1, num_generations=ac_cfg.get("num_generations", 1), use_separate_critic=use_separate_critic, + parallel_training=str(ac_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=ac_cfg.get("agent_devices", None), + critic_devices=ac_cfg.get("critic_devices", None), critic_value_head_hidden_dim=ac_cfg.get("critic_value_head_hidden_dim"), value_head_hidden_dim=ac_cfg.get("value_head_hidden_dim"), discount=ac_cfg.get("discount", 0.9), diff --git a/train_grpo.py b/train_grpo.py index cc6f585..2a3a903 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -210,9 +210,9 @@ def main(): "Please set grpo.num_turns=1 (or remove the field) in the config." 
) - temperature = grpo_cfg.get("temperature", model_config.temperature) - top_p = grpo_cfg.get("top_p", model_config.top_p) - top_k = grpo_cfg.get("top_k") + temperature = model_config.temperature + top_p = model_config.top_p + top_k = model_config.top_k grpo_args = MAGRPOConfig( num_turns=1, @@ -225,6 +225,8 @@ def main(): top_p=top_p, top_k=top_k, num_agents=1, + parallel_training=str(grpo_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=grpo_cfg.get("agent_devices", None), early_termination_threshold=grpo_cfg.get( "early_termination_threshold", -0.2 ), diff --git a/train_iac.py b/train_iac.py index e625e7a..7c3736d 100644 --- a/train_iac.py +++ b/train_iac.py @@ -255,9 +255,9 @@ def main() -> None: tok.add_special_tokens(model_config.special_tokens) tokenizer = tokenizers[0] - temperature = iac_cfg.get("temperature", model_config.temperature) - top_p = iac_cfg.get("top_p", model_config.top_p) - top_k = iac_cfg.get("top_k") + temperature = model_config.temperature + top_p = model_config.top_p + top_k = model_config.top_k use_separate_critic = bool(iac_cfg.get("use_separate_critic", True)) model_kwargs: Dict[str, Any] = {} if model_config.torch_dtype is not None: @@ -305,7 +305,7 @@ def main() -> None: external_transition=None, args=IACConfig( num_turns=1, - num_train_epochs=iac_cfg.get("num_train_epochs", 1), + num_train_epochs=iac_cfg.get("num_train_epochs", 20), agent_learning_rate=iac_cfg.get("agent_learning_rate", 5e-6), critic_learning_rate=iac_cfg.get("critic_learning_rate", 5e-6), value_loss_coef=iac_cfg.get("value_loss_coef", 0.6), @@ -318,13 +318,16 @@ def main() -> None: num_agents=num_agents, num_generations=iac_cfg.get("num_generations", 1), use_separate_critic=use_separate_critic, + parallel_training=str(iac_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=iac_cfg.get("agent_devices", ["cuda:0"]), + critic_devices=iac_cfg.get("critic_devices", ["cuda:0"]), critic_value_head_hidden_dim=iac_cfg.get("critic_value_head_hidden_dim"), value_head_hidden_dim=iac_cfg.get("value_head_hidden_dim"), discount=iac_cfg.get("discount", 0.9), - eval_interval=iac_cfg.get("eval_interval", 4), + eval_interval=iac_cfg.get("eval_interval", 20), eval_num_samples=iac_cfg.get("eval_num_samples", 4), eval_batch_size=iac_cfg.get("eval_batch_size", 1), - logging_steps=iac_cfg.get("logging_steps", 1), + logging_steps=iac_cfg.get("logging_steps", 50), ), train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/train_maac.py b/train_maac.py index 51b798d..167621c 100644 --- a/train_maac.py +++ b/train_maac.py @@ -246,9 +246,9 @@ def main() -> None: tok.add_special_tokens(model_config.special_tokens) tokenizer = tokenizers[0] - temperature = maac_cfg.get("temperature", model_config.temperature) - top_p = maac_cfg.get("top_p", model_config.top_p) - top_k = maac_cfg.get("top_k") + temperature = model_config.temperature + top_p = model_config.top_p + top_k = model_config.top_k model_kwargs: Dict[str, Any] = {} if model_config.torch_dtype is not None: model_kwargs["torch_dtype"] = model_config.torch_dtype @@ -303,7 +303,7 @@ def main() -> None: external_transition=None, args=MAACConfig( num_turns=1, - num_train_epochs=maac_cfg.get("num_train_epochs", 1), + num_train_epochs=maac_cfg.get("num_train_epochs", 20), agent_learning_rate=maac_cfg.get("agent_learning_rate", 5e-6), critic_learning_rate=maac_cfg.get("critic_learning_rate", 5e-6), value_loss_coef=maac_cfg.get("value_loss_coef", 0.6), @@ -314,12 +314,15 @@ def main() -> None: top_k=top_k, 
num_agents=num_agents, num_generations=maac_cfg.get("num_generations", 1), + parallel_training=str(maac_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=maac_cfg.get("agent_devices", ["cuda:0"]), + critic_devices=maac_cfg.get("critic_devices", ["cuda:0"]), discount=maac_cfg.get("discount", 0.9), critic_type=maac_cfg.get("critic_type", "v"), - eval_interval=maac_cfg.get("eval_interval", 4), + eval_interval=maac_cfg.get("eval_interval", 20), eval_num_samples=maac_cfg.get("eval_num_samples", 4), eval_batch_size=maac_cfg.get("eval_batch_size", 1), - logging_steps=maac_cfg.get("logging_steps", 1), + logging_steps=maac_cfg.get("logging_steps", 50), ), train_dataset=train_dataset, eval_dataset=eval_dataset, diff --git a/train_magrpo.py b/train_magrpo.py index 47d486d..9367113 100644 --- a/train_magrpo.py +++ b/train_magrpo.py @@ -11,7 +11,7 @@ from config import Config, add_config_args, parse_overrides from datasets import load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer from loggers.arxiv_logger import ( aggregate_arxiv_metrics_for_logging, @@ -263,10 +263,6 @@ def main(): train_dataset = load_dataset(dataset_name, split=train_split) eval_dataset = load_dataset(dataset_name, split=eval_split) - model_kwargs: Dict[str, Any] = {} - if model_config.torch_dtype is not None: - model_kwargs["torch_dtype"] = model_config.torch_dtype - agents_field = config.get("agents") agent_names = None if isinstance(agents_field, (list, tuple)): @@ -302,22 +298,6 @@ def main(): tok.add_special_tokens(model_config.special_tokens) tokenizer = tokenizers[0] - if agent_names: - agents = [ - AutoModelForCausalLM.from_pretrained( - name, - **model_kwargs, - ) - for name in agent_names - ] - else: - agents = [ - AutoModelForCausalLM.from_pretrained( - model_name, - **model_kwargs, - ) - for _ in range(num_agents) - ] magrpo_cfg = config.get_section("magrpo") num_turns_cfg = magrpo_cfg.get("num_turns") if num_turns_cfg is not None and int(num_turns_cfg) != 1: @@ -326,28 +306,30 @@ def main(): "Please set magrpo.num_turns=1 (or remove the field) in the config." 
) - temperature = magrpo_cfg.get("temperature", model_config.temperature) - top_p = magrpo_cfg.get("top_p", model_config.top_p) - top_k = magrpo_cfg.get("top_k") + temperature = model_config.temperature + top_p = model_config.top_p + top_k = model_config.top_k magrpo_args = MAGRPOConfig( num_turns=1, - num_train_epochs=magrpo_cfg.get("num_train_epochs", 1), + num_train_epochs=magrpo_cfg.get("num_train_epochs", 2), agent_learning_rate=magrpo_cfg.get("agent_learning_rate", 5e-6), - logging_steps=magrpo_cfg.get("logging_steps", 10), + logging_steps=magrpo_cfg.get("logging_steps", 50), num_generations=magrpo_cfg.get("num_generations", 4), max_new_tokens=magrpo_cfg.get("max_new_tokens", 256), temperature=temperature, top_p=top_p, top_k=top_k, num_agents=num_agents, + parallel_training=str(magrpo_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=magrpo_cfg.get("agent_devices", ["cuda:0"]), early_termination_threshold=magrpo_cfg.get( "early_termination_threshold", -0.2 ), - rollout_buffer_size=magrpo_cfg.get("rollout_buffer_size", 2), - train_batch_size=magrpo_cfg.get("train_batch_size"), + rollout_buffer_size=magrpo_cfg.get("rollout_buffer_size", 1), + train_batch_size=magrpo_cfg.get("train_batch_size", 1), advantage_normalization=magrpo_cfg.get("advantage_normalization", True), - eval_interval=magrpo_cfg.get("eval_interval", 4), + eval_interval=magrpo_cfg.get("eval_interval", 20), eval_num_samples=magrpo_cfg.get("eval_num_samples", 4), eval_batch_size=magrpo_cfg.get("eval_batch_size", 1), ) @@ -405,8 +387,12 @@ def main(): trainer_kwargs: Dict[str, Any] = { "agent_model": model_name or None, - "agents": agents, + "agents": agent_names, "num_agents": num_agents, + "model_config": { + "torch_dtype": model_config.torch_dtype, + "special_tokens": model_config.special_tokens, + }, "reward_func": reward_func, "formatters": formatters, "args": magrpo_args,