diff --git a/config.py b/config.py index 8a65304..ebfb5a5 100644 --- a/config.py +++ b/config.py @@ -4,6 +4,8 @@ """ import argparse +import os +import sys from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Optional @@ -11,26 +13,74 @@ import yaml +REPO_ROOT = os.path.dirname(os.path.abspath(__file__)) +COMLRL_ROOT = os.path.join(os.path.dirname(REPO_ROOT), "CoMLRL") +if COMLRL_ROOT not in sys.path: + sys.path.insert(0, COMLRL_ROOT) + + @dataclass(frozen=True) class ModelConfig: """Configuration for model loading and generation.""" name: str type: str = "qwen" - temperature: float = 0.7 - top_p: float = 0.9 + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None max_length: int = 2048 special_tokens: Dict[str, str] = field(default_factory=dict) torch_dtype: Optional[str] = None @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> "ModelConfig": + def from_dict( + cls, + config_dict: Dict[str, Any], + *, + require_sampling: bool = True, + ) -> "ModelConfig": """Create ModelConfig from dictionary.""" + if require_sampling: + missing = [ + key + for key in ("temperature", "top_p", "top_k") + if key not in config_dict + ] + if missing: + raise ValueError( + f"agent_model is missing required sampling fields: {', '.join(missing)}" + ) + + def _as_optional_float(value: Any) -> Optional[float]: + if value is None: + return None + try: + return float(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid float value: {value}") from exc + + def _as_optional_int(value: Any) -> Optional[int]: + if value is None: + return None + if isinstance(value, str) and value.strip().lower() in ("none", "null", ""): + return None + try: + return int(float(value)) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid int value: {value}") from exc + + temperature = _as_optional_float(config_dict.get("temperature")) + top_p = _as_optional_float(config_dict.get("top_p")) + top_k = _as_optional_int(config_dict.get("top_k")) + if require_sampling and (temperature is None or top_p is None): + raise ValueError("agent_model.temperature and agent_model.top_p must be non-null.") + return cls( name=config_dict.get("name", ""), type=config_dict.get("type", "qwen"), - temperature=config_dict.get("temperature", 0.7), - top_p=config_dict.get("top_p", 0.9), + temperature=temperature, + top_p=top_p, + top_k=top_k, max_length=config_dict.get("max_length", 2048), special_tokens=config_dict.get("special_tokens", {}), torch_dtype=( @@ -74,7 +124,7 @@ def get_agent_model_config(self) -> ModelConfig: model_section = self.get_section("agent_model") if not model_section: raise ValueError("No 'agent_model' section found in configuration") - return ModelConfig.from_dict(model_section) + return ModelConfig.from_dict(model_section, require_sampling=True) def get_critic_model_config(self, required: bool = True) -> Optional[ModelConfig]: """Get critic model configuration as ModelConfig object.""" @@ -83,7 +133,7 @@ def get_critic_model_config(self, required: bool = True) -> Optional[ModelConfig if required: raise ValueError("No 'critic_model' section found in configuration") return None - return ModelConfig.from_dict(critic_section) + return ModelConfig.from_dict(critic_section, require_sampling=False) def update(self, updates: Dict[str, Any]): """Update configuration with new values (deep merge).""" diff --git a/configs/ac_che_config.yaml b/configs/ac_che_config.yaml index 507f846..fd7cf0d 100644 --- a/configs/ac_che_config.yaml +++ b/configs/ac_che_config.yaml @@ -1,61 +1,62 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "OpenMLRL/CoopHumanEval" - type: "coophumaneval" - train_split: "train[16:]" - eval_split: "train[:16]" + name: OpenMLRL/CoopHumanEval + type: coophumaneval + train_split: train[16:] + eval_split: train[:16] output: - base_dir: "output" + base_dir: output_ac_che verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_ac_che external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 ac: + parallel_training: none num_turns: 2 num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 - advantage_normalization: true rollout_buffer_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 + advantage_normalization: true eval_interval: 40 eval_num_samples: 4 eval_batch_size: 1 reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "ac_coophumaneval" - dir: "output" - tags: ["ac", "coophumaneval", "single-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: ac_coophumaneval + dir: output_ac_che + tags: + - ac + - coophumaneval + - single-agent + - turns_2 diff --git a/configs/ac_he_config.yaml b/configs/ac_he_config.yaml index cfa73a3..b17532a 100644 --- a/configs/ac_he_config.yaml +++ b/configs/ac_he_config.yaml @@ -1,61 +1,62 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "openai/openai_humaneval" - type: "humaneval" - train_split: "test[33:163]" - eval_split: "test[:32]" + name: openai/openai_humaneval + type: humaneval + train_split: test[33:163] + eval_split: test[:32] output: - base_dir: "output" + base_dir: output_ac_he verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_ac_he external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 ac: + parallel_training: none num_turns: 2 num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 - advantage_normalization: true rollout_buffer_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 + advantage_normalization: true eval_interval: 40 eval_num_samples: 4 eval_batch_size: 1 reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "ac_humaneval" - dir: "output" - tags: ["ac", "humaneval", "single-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: ac_humaneval + dir: output_ac_he + tags: + - ac + - humaneval + - single-agent + - turns_2 diff --git a/configs/ac_mbpp_config.yaml b/configs/ac_mbpp_config.yaml index 4a502f4..1b49439 100644 --- a/configs/ac_mbpp_config.yaml +++ b/configs/ac_mbpp_config.yaml @@ -1,61 +1,62 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "OpenMLRL/MBPP" - type: "mbpp" - train_split: "test[15:65]" - eval_split: "test[:15]" + name: OpenMLRL/MBPP + type: mbpp + train_split: test[15:65] + eval_split: test[:15] output: - base_dir: "output" + base_dir: output_ac_mbpp verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_ac_mbpp external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 ac: + parallel_training: none num_turns: 2 num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 - advantage_normalization: true rollout_buffer_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 + advantage_normalization: true eval_interval: 40 eval_num_samples: 4 eval_batch_size: 1 reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "ac_mbpp" - dir: "output" - tags: ["ac", "mbpp", "single-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: ac_mbpp + dir: output_ac_mbpp + tags: + - ac + - mbpp + - single-agent + - turns_2 diff --git a/configs/grpo_che_config.yaml b/configs/grpo_che_config.yaml index 8eae75e..14f1b68 100644 --- a/configs/grpo_che_config.yaml +++ b/configs/grpo_che_config.yaml @@ -1,10 +1,11 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.8 + top_p: 0.95 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null @@ -13,31 +14,30 @@ critic_model: null critics: null dataset: - name: "OpenMLRL/CoopHumanEval" - type: "coophumaneval" - train_split: "train[16:]" - eval_split: "train[:16]" + name: OpenMLRL/CoopHumanEval + type: coophumaneval + train_split: train[16:] + eval_split: train[:16] output: - base_dir: "output" - save_final_model: false + base_dir: output_grpo_che verbose: false + save_final_model: false external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 grpo: + parallel_training: none num_turns: 2 num_train_epochs: 20 - agent_learning_rate: 2.0e-5 + agent_learning_rate: 2.0e-05 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.8 - top_p: 0.95 - top_k: null discount: 0.9 + joint_mode: aligned early_termination_threshold: -0.1 rollout_buffer_size: 2 train_batch_size: 2 @@ -45,12 +45,13 @@ grpo: eval_interval: 4 eval_num_samples: 4 eval_batch_size: 1 - joint_mode: aligned reward_shift: -2.1 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "grpo_coophumaneval" - dir: "output" - tags: ["grpo", "coophumaneval"] + project: comlrl + entity: OpenMLRL + name: grpo_coophumaneval + dir: output_grpo_che + tags: + - grpo + - coophumaneval diff --git a/configs/grpo_he_config.yaml b/configs/grpo_he_config.yaml index 9d9d131..51d9639 100644 --- a/configs/grpo_he_config.yaml +++ b/configs/grpo_he_config.yaml @@ -1,10 +1,11 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.8 + top_p: 0.95 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null @@ -13,31 +14,30 @@ critic_model: null critics: null dataset: - name: "openai/openai_humaneval" - type: "humaneval" - train_split: "test[33:163]" - eval_split: "test[:32]" + name: openai/openai_humaneval + type: humaneval + train_split: test[33:163] + eval_split: test[:32] output: - base_dir: "output" - save_final_model: false + base_dir: output_grpo_he verbose: false + save_final_model: false external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 grpo: + parallel_training: none num_turns: 2 num_train_epochs: 6 - agent_learning_rate: 2.0e-5 + agent_learning_rate: 2.0e-05 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.8 - top_p: 0.95 - top_k: null discount: 0.9 + joint_mode: aligned early_termination_threshold: -0.1 rollout_buffer_size: 2 train_batch_size: 2 @@ -45,12 +45,13 @@ grpo: eval_interval: 4 eval_num_samples: 4 eval_batch_size: 1 - joint_mode: aligned reward_shift: -2.1 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "grpo_humaneval" - dir: "output" - tags: ["grpo", "humaneval"] + project: comlrl + entity: OpenMLRL + name: grpo_humaneval + dir: output_grpo_he + tags: + - grpo + - humaneval diff --git a/configs/grpo_mbpp_config.yaml b/configs/grpo_mbpp_config.yaml index c850a52..6d2ddda 100644 --- a/configs/grpo_mbpp_config.yaml +++ b/configs/grpo_mbpp_config.yaml @@ -1,10 +1,11 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.8 + top_p: 0.95 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null @@ -13,31 +14,30 @@ critic_model: null critics: null dataset: - name: "OpenMLRL/MBPP" - type: "mbpp" - train_split: "test[15:65]" - eval_split: "test[:15]" + name: OpenMLRL/MBPP + type: mbpp + train_split: test[15:65] + eval_split: test[:15] output: - base_dir: "output" - save_final_model: false + base_dir: output_grpo_mbpp verbose: false + save_final_model: false external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 grpo: + parallel_training: none num_turns: 2 num_train_epochs: 8 - agent_learning_rate: 3.0e-5 + agent_learning_rate: 3.0e-05 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.8 - top_p: 0.95 - top_k: null discount: 0.9 + joint_mode: aligned early_termination_threshold: -0.1 rollout_buffer_size: 2 train_batch_size: 2 @@ -45,12 +45,13 @@ grpo: eval_interval: 4 eval_num_samples: 4 eval_batch_size: 1 - joint_mode: aligned reward_shift: -2.1 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "grpo_mbpp" - dir: "output" - tags: ["grpo", "mbpp"] + project: comlrl + entity: OpenMLRL + name: grpo_mbpp + dir: output_grpo_mbpp + tags: + - grpo + - mbpp diff --git a/configs/iac_che_config.yaml b/configs/iac_che_config.yaml index 0f2c3e7..e8944ea 100644 --- a/configs/iac_che_config.yaml +++ b/configs/iac_che_config.yaml @@ -1,54 +1,56 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "OpenMLRL/CoopHumanEval" - type: "coophumaneval" - train_split: "train[16:]" - eval_split: "train[:16]" + name: OpenMLRL/CoopHumanEval + type: coophumaneval + train_split: train[16:] + eval_split: train[:16] output: - base_dir: "output" + base_dir: output_iac_che verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_iac_che external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 use_separate_critic: true num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 value_clip_range: 0.2 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 eval_interval: 40 @@ -58,8 +60,12 @@ iac: reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "iac_coophumaneval" - dir: "output" - tags: ["iac", "coophumaneval", "multi-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: iac_coophumaneval + dir: output_iac_che + tags: + - iac + - coophumaneval + - multi-agent + - turns_2 diff --git a/configs/iac_he_config.yaml b/configs/iac_he_config.yaml index 59d34f0..4a9c2ac 100644 --- a/configs/iac_he_config.yaml +++ b/configs/iac_he_config.yaml @@ -1,54 +1,56 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "openai/openai_humaneval" - type: "humaneval" - train_split: "test[33:163]" - eval_split: "test[:32]" + name: openai/openai_humaneval + type: humaneval + train_split: test[33:163] + eval_split: test[:32] output: - base_dir: "output" + base_dir: output_iac_he verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_iac_he external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 use_separate_critic: true num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 value_clip_range: 0.2 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 eval_interval: 40 @@ -58,8 +60,12 @@ iac: reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "iac_humaneval" - dir: "output" - tags: ["iac", "humaneval", "multi-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: iac_humaneval + dir: output_iac_he + tags: + - iac + - humaneval + - multi-agent + - turns_2 diff --git a/configs/iac_mbpp_config.yaml b/configs/iac_mbpp_config.yaml index b4e0400..209fe8b 100644 --- a/configs/iac_mbpp_config.yaml +++ b/configs/iac_mbpp_config.yaml @@ -1,54 +1,56 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "OpenMLRL/MBPP" - type: "mbpp" - train_split: "test[15:65]" - eval_split: "test[:15]" + name: OpenMLRL/MBPP + type: mbpp + train_split: test[15:65] + eval_split: test[:15] output: - base_dir: "output" + base_dir: output_iac_mbpp verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_iac_mbpp external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false iac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 use_separate_critic: true num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 value_clip_range: 0.2 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 eval_interval: 40 @@ -58,8 +60,12 @@ iac: reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "iac_mbpp" - dir: "output" - tags: ["iac", "mbpp", "multi-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: iac_mbpp + dir: output_iac_mbpp + tags: + - iac + - mbpp + - multi-agent + - turns_2 diff --git a/configs/maac_che_config.yaml b/configs/maac_che_config.yaml index ee1e828..3ca2c20 100644 --- a/configs/maac_che_config.yaml +++ b/configs/maac_che_config.yaml @@ -1,53 +1,55 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "OpenMLRL/CoopHumanEval" - type: "coophumaneval" - train_split: "train[16:]" - eval_split: "train[:16]" + name: OpenMLRL/CoopHumanEval + type: coophumaneval + train_split: train[16:] + eval_split: train[:16] output: - base_dir: "output" + base_dir: output_maac_che verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_maac_che external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 - critic_type: "v" + critic_type: v num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 eval_interval: 40 @@ -57,8 +59,12 @@ maac: reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "maac_coophumaneval" - dir: "output" - tags: ["maac", "coophumaneval", "multi-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: maac_coophumaneval + dir: output_maac_che + tags: + - maac + - coophumaneval + - multi-agent + - turns_2 diff --git a/configs/maac_he_config.yaml b/configs/maac_he_config.yaml index 704c03d..7c6c60e 100644 --- a/configs/maac_he_config.yaml +++ b/configs/maac_he_config.yaml @@ -1,53 +1,55 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "openai/openai_humaneval" - type: "humaneval" - train_split: "test[33:163]" - eval_split: "test[:32]" + name: openai/openai_humaneval + type: humaneval + train_split: test[33:163] + eval_split: test[:32] output: - base_dir: "output" + base_dir: output_maac_he verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_maac_he external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 - critic_type: "v" + critic_type: v num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 eval_interval: 40 @@ -57,8 +59,12 @@ maac: reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "maac_humaneval" - dir: "output" - tags: ["maac", "humaneval", "multi-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: maac_humaneval + dir: output_maac_he + tags: + - maac + - humaneval + - multi-agent + - turns_2 diff --git a/configs/maac_mbpp_config.yaml b/configs/maac_mbpp_config.yaml index f038a89..0ef3f41 100644 --- a/configs/maac_mbpp_config.yaml +++ b/configs/maac_mbpp_config.yaml @@ -1,53 +1,55 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 agents: null critic_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen max_length: 2048 - torch_dtype: "bfloat16" + torch_dtype: bfloat16 critics: null dataset: - name: "OpenMLRL/MBPP" - type: "mbpp" - train_split: "test[15:65]" - eval_split: "test[:15]" + name: OpenMLRL/MBPP + type: mbpp + train_split: test[15:65] + eval_split: test[:15] output: - base_dir: "output" + base_dir: output_maac_mbpp verbose: false save_final_model: false - save_path: "output/maac_final" + save_path: output_maac_mbpp external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false maac: + parallel_training: none + agent_devices: + - cuda:0 + critic_devices: + - cuda:0 num_agents: 2 num_turns: 2 - critic_type: "v" + critic_type: v num_train_epochs: 80 - agent_learning_rate: 5.0e-6 - critic_learning_rate: 5.0e-6 + agent_learning_rate: 5.0e-06 + critic_learning_rate: 5.0e-06 value_loss_coef: 0.6 rollout_buffer_size: 4 train_batch_size: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 early_termination_threshold: -0.2 eval_interval: 40 @@ -57,8 +59,12 @@ maac: reward_shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "maac_mbpp" - dir: "output" - tags: ["maac", "mbpp", "multi-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: maac_mbpp + dir: output_maac_mbpp + tags: + - maac + - mbpp + - multi-agent + - turns_2 diff --git a/configs/magrpo_che_config.yaml b/configs/magrpo_che_config.yaml index 5ad628e..3f3e013 100644 --- a/configs/magrpo_che_config.yaml +++ b/configs/magrpo_che_config.yaml @@ -1,11 +1,12 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" special_tokens: {} + torch_dtype: bfloat16 agents: null @@ -14,33 +15,34 @@ critic_model: null critics: null dataset: - name: "OpenMLRL/CoopHumanEval" - type: "coophumaneval" - train_split: "train[16:]" - eval_split: "train[:16]" + name: OpenMLRL/CoopHumanEval + type: coophumaneval + train_split: train[16:] + eval_split: train[:16] seed: 42 output: - base_dir: "output" - save_final_model: false + base_dir: output_magrpo_che verbose: false + save_final_model: false external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_agents: 2 num_turns: 2 num_train_epochs: 8 - agent_learning_rate: 2.0e-5 + agent_learning_rate: 2.0e-05 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 joint_mode: aligned early_termination_threshold: -0.2 @@ -57,8 +59,12 @@ reward_processor: shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "magrpo_coophumaneval" - dir: "output" - tags: ["magrpo", "coophumaneval", "multi-agent", "turns_2"] + project: comlrl + entity: OpenMLRL + name: magrpo_coophumaneval + dir: output_magrpo_che + tags: + - magrpo + - coophumaneval + - multi-agent + - turns_2 diff --git a/configs/magrpo_he_config.yaml b/configs/magrpo_he_config.yaml index 2155bca..0525ab1 100644 --- a/configs/magrpo_he_config.yaml +++ b/configs/magrpo_he_config.yaml @@ -1,11 +1,12 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" special_tokens: {} + torch_dtype: bfloat16 agents: null @@ -14,33 +15,34 @@ critic_model: null critics: null dataset: - name: "openai/openai_humaneval" - type: "humaneval" - train_split: "test[33:163]" - eval_split: "test[:32]" + name: openai/openai_humaneval + type: humaneval + train_split: test[33:163] + eval_split: test[:32] seed: 42 output: - base_dir: "output" - save_final_model: false + base_dir: output_magrpo_he verbose: false + save_final_model: false external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_agents: 2 num_turns: 2 num_train_epochs: 6 - agent_learning_rate: 2.0e-5 + agent_learning_rate: 2.0e-05 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.6 - top_p: 0.6 - top_k: null discount: 0.9 joint_mode: aligned early_termination_threshold: -0.2 @@ -57,8 +59,11 @@ reward_processor: shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "magrpo_humaneval" - dir: "output" - tags: ["magrpo", "humaneval", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: magrpo_humaneval + dir: output_magrpo_he + tags: + - magrpo + - humaneval + - multi-agent diff --git a/configs/magrpo_mbpp_config.yaml b/configs/magrpo_mbpp_config.yaml index abb4b12..b7d64f1 100644 --- a/configs/magrpo_mbpp_config.yaml +++ b/configs/magrpo_mbpp_config.yaml @@ -1,11 +1,12 @@ agent_model: - name: "Qwen/Qwen2.5-Coder-3B" - type: "qwen" - temperature: 0.7 - top_p: 0.9 + name: Qwen/Qwen2.5-Coder-3B + type: qwen + temperature: 0.6 + top_p: 0.6 + top_k: null max_length: 2048 - torch_dtype: "bfloat16" special_tokens: {} + torch_dtype: bfloat16 agents: null @@ -14,33 +15,34 @@ critic_model: null critics: null dataset: - name: "OpenMLRL/MBPP" - type: "mbpp" - train_split: "test[15:65]" - eval_split: "test[:15]" + name: OpenMLRL/MBPP + type: mbpp + train_split: test[15:65] + eval_split: test[:15] seed: 42 output: - base_dir: "output" - save_final_model: false + base_dir: output_magrpo_mbpp verbose: false + save_final_model: false external: - mode: "level_feedback" + mode: level_feedback sandbox_slice: 1 + external_prompt_passthrough: false magrpo: + parallel_training: none + agent_devices: + - cuda:0 num_agents: 2 num_turns: 2 num_train_epochs: 8 - agent_learning_rate: 3.0e-5 + agent_learning_rate: 3.0e-05 logging_steps: 50 num_generations: 4 max_new_tokens: 256 - temperature: 0.8 - top_p: 0.95 - top_k: null discount: 0.9 joint_mode: aligned early_termination_threshold: -0.2 @@ -57,8 +59,11 @@ reward_processor: shift: -4 wandb: - project: "comlrl" - entity: "OpenMLRL" - name: "magrpo_mbpp" - dir: "output" - tags: ["magrpo", "mbpp", "multi-agent"] + project: comlrl + entity: OpenMLRL + name: magrpo_mbpp + dir: output_magrpo_mbpp + tags: + - magrpo + - mbpp + - multi-agent diff --git a/external/level_feedback.py b/external/level_feedback.py index a51ad5e..505f581 100644 --- a/external/level_feedback.py +++ b/external/level_feedback.py @@ -1,5 +1,6 @@ import ast import signal +import math from typing import Dict, List, Tuple, Optional from rewards.code_utils import ( @@ -36,7 +37,7 @@ def _run_tests( MAX_TIMEOUTS = 3 # Prepare execution environment - exec_globals: Dict[str, object] = {} + exec_globals: Dict[str, object] = {"math": math} try: exec(combined_code, exec_globals) except Exception as e: diff --git a/loggers/code_logger.py b/loggers/code_logger.py index c32c9d4..08da5cf 100644 --- a/loggers/code_logger.py +++ b/loggers/code_logger.py @@ -1,4 +1,5 @@ import signal +import math @@ -135,7 +136,7 @@ def code_reward_logger( try: # Load code definitions - exec_globals = {} + exec_globals = {"math": math} exec(combined_code, exec_globals) # Run individual test cases diff --git a/rewards/code_rewards.py b/rewards/code_rewards.py index f5021ea..a36f65a 100644 --- a/rewards/code_rewards.py +++ b/rewards/code_rewards.py @@ -1,5 +1,6 @@ import re import signal +import math from typing import List import builtins @@ -188,7 +189,7 @@ def print(*args, **kwargs): # type: ignore try: # Create execution environment (no timeout needed for function definitions) - exec_globals = {} + exec_globals = {"math": math} exec(combined_code, exec_globals) print("✅ Code definitions loaded successfully") diff --git a/train_ac.py b/train_ac.py index 720fd51..f872d20 100644 --- a/train_ac.py +++ b/train_ac.py @@ -311,9 +311,9 @@ def _resolver(prompt: str): reward_processor = RewardProcessors.shift(value=shift_val_f) # AC-specific config - top_k = ac_cfg.get("top_k") - temperature = ac_cfg.get("temperature", 0.6) - top_p = ac_cfg.get("top_p", 0.6) + top_k = model_config.top_k + temperature = model_config.temperature + top_p = model_config.top_p use_separate_critic = bool(ac_cfg.get("use_separate_critic", True)) model_kwargs: Dict[str, Any] = {} if model_config.torch_dtype is not None: @@ -384,6 +384,9 @@ def external_transition_fn( num_agents=1, num_generations=ac_cfg.get("num_generations", 1), use_separate_critic=use_separate_critic, + parallel_training=str(ac_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=ac_cfg.get("agent_devices", None), + critic_devices=ac_cfg.get("critic_devices", None), critic_value_head_hidden_dim=ac_cfg.get("critic_value_head_hidden_dim"), value_head_hidden_dim=ac_cfg.get("value_head_hidden_dim"), discount=ac_cfg.get("discount", 0.9), @@ -440,6 +443,9 @@ def _build_wandb_config( ): wandb_section = config.get_section("wandb") if hasattr(config, "get_section") else {} ac_section = config.get_section("ac") if hasattr(config, "get_section") else {} + model_section = ( + config.get_section("agent_model") if hasattr(config, "get_section") else {} + ) output_section = ( config.get_section("output") if hasattr(config, "get_section") else {} ) @@ -467,9 +473,9 @@ def _build_wandb_config( "trainer": { "num_turns": ac_section.get("num_turns", 1), "max_new_tokens": ac_section.get("max_new_tokens", 256), - "temperature": ac_section.get("temperature", 0.6), - "top_p": ac_section.get("top_p", 0.6), - "top_k": ac_section.get("top_k"), + "temperature": model_section.get("temperature"), + "top_p": model_section.get("top_p"), + "top_k": model_section.get("top_k"), "discount": ac_section.get("discount", 0.9), "use_separate_critic": ac_section.get("use_separate_critic", True), }, diff --git a/train_grpo.py b/train_grpo.py index c4a8359..e69f24a 100644 --- a/train_grpo.py +++ b/train_grpo.py @@ -242,9 +242,9 @@ def main(): ) print("Model loaded successfully!") - temperature = grpo_config.get("temperature", model_config.temperature) - top_p = grpo_config.get("top_p", model_config.top_p) - top_k = grpo_config.get("top_k") + temperature = model_config.temperature + top_p = model_config.top_p + top_k = model_config.top_k external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} # Register external context resolver using dataset items (for external modes) @@ -337,6 +337,8 @@ def _resolver(prompt: str): temperature=temperature, top_p=top_p, top_k=top_k, + parallel_training=str(grpo_config.get("parallel_training", "none")).strip().lower(), + agent_devices=grpo_config.get("agent_devices", None), discount=grpo_config.get("discount", 0.9), joint_mode=grpo_config.get("joint_mode", "aligned"), early_termination_threshold=grpo_config.get( diff --git a/train_iac.py b/train_iac.py index 7393d53..52c6c53 100644 --- a/train_iac.py +++ b/train_iac.py @@ -257,6 +257,17 @@ def main() -> None: config.save(config_save_path) external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} + _ext_passthrough = external_cfg.get("external_prompt_passthrough", False) + if isinstance(_ext_passthrough, str): + external_prompt_passthrough = _ext_passthrough.strip().lower() in { + "1", + "true", + "yes", + "y", + "on", + } + else: + external_prompt_passthrough = bool(_ext_passthrough) def _normalize_prompt(p: str) -> str: return " ".join((p or "").split()).strip() @@ -352,9 +363,9 @@ def _resolver(prompt: str): if shift_val_f is not None: reward_processor = RewardProcessors.shift(value=shift_val_f) - top_k = iac_cfg.get("top_k") - temperature = iac_cfg.get("temperature", 0.6) - top_p = iac_cfg.get("top_p", 0.6) + top_k = model_config.top_k + temperature = model_config.temperature + top_p = model_config.top_p use_separate_critic = bool(iac_cfg.get("use_separate_critic", True)) model_kwargs: Dict[str, Any] = {} if model_config.torch_dtype is not None: @@ -365,9 +376,9 @@ def _resolver(prompt: str): critic_model_kwargs = dict(model_kwargs) if critic_config is not None and critic_config.torch_dtype is not None: critic_model_kwargs["torch_dtype"] = critic_config.torch_dtype - num_turns = iac_cfg.get("num_turns", 1) + num_turns = iac_cfg.get("num_turns", 2) - rollout_buffer_size = iac_cfg.get("rollout_buffer_size", 8) + rollout_buffer_size = iac_cfg.get("rollout_buffer_size", 4) external_transition_fn = None if num_turns > 1: @@ -404,7 +415,7 @@ def external_transition_fn( external_transition=external_transition_fn, args=IACConfig( num_turns=num_turns, - num_train_epochs=iac_cfg.get("num_train_epochs", 40), + num_train_epochs=iac_cfg.get("num_train_epochs", 80), agent_learning_rate=iac_cfg.get("agent_learning_rate", 5e-6), critic_learning_rate=iac_cfg.get("critic_learning_rate", 5e-6), value_loss_coef=iac_cfg.get("value_loss_coef", 0.6), @@ -417,16 +428,20 @@ def external_transition_fn( num_agents=num_agents, num_generations=iac_cfg.get("num_generations", 1), use_separate_critic=use_separate_critic, + parallel_training=str(iac_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=iac_cfg.get("agent_devices", ["cuda:0"]), + critic_devices=iac_cfg.get("critic_devices", ["cuda:0"]), critic_value_head_hidden_dim=iac_cfg.get("critic_value_head_hidden_dim"), value_head_hidden_dim=iac_cfg.get("value_head_hidden_dim"), discount=iac_cfg.get("discount", 0.9), + external_prompt_passthrough=external_prompt_passthrough, early_termination_threshold=iac_cfg.get( "early_termination_threshold", -0.2 ), - eval_interval=iac_cfg.get("eval_interval", 16), + eval_interval=iac_cfg.get("eval_interval", 40), eval_num_samples=iac_cfg.get("eval_num_samples", 4), eval_batch_size=iac_cfg.get("eval_batch_size", 1), - logging_steps=iac_cfg.get("logging_steps", 1), + logging_steps=iac_cfg.get("logging_steps", 10), ), train_dataset=train_dataset, eval_dataset=eval_dataset, @@ -474,6 +489,9 @@ def _build_wandb_config( ): wandb_section = config.get_section("wandb") if hasattr(config, "get_section") else {} iac_section = config.get_section("iac") if hasattr(config, "get_section") else {} + model_section = ( + config.get_section("agent_model") if hasattr(config, "get_section") else {} + ) output_section = ( config.get_section("output") if hasattr(config, "get_section") else {} ) @@ -501,9 +519,9 @@ def _build_wandb_config( "trainer": { "num_turns": iac_section.get("num_turns", 1), "max_new_tokens": iac_section.get("max_new_tokens", 256), - "temperature": iac_section.get("temperature", 0.6), - "top_p": iac_section.get("top_p", 0.6), - "top_k": iac_section.get("top_k"), + "temperature": model_section.get("temperature"), + "top_p": model_section.get("top_p"), + "top_k": model_section.get("top_k"), "use_separate_critic": iac_section.get( "use_separate_critic", False ), diff --git a/train_maac.py b/train_maac.py index d0a8683..62cded5 100644 --- a/train_maac.py +++ b/train_maac.py @@ -248,6 +248,18 @@ def main() -> None: config.save(config_save_path) external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} + _ext_passthrough = external_cfg.get("external_prompt_passthrough", False) + if isinstance(_ext_passthrough, str): + external_prompt_passthrough = _ext_passthrough.strip().lower() in { + "1", + "true", + "yes", + "y", + "on", + } + else: + external_prompt_passthrough = bool(_ext_passthrough) + def _normalize_prompt(p: str) -> str: return " ".join((p or "").split()).strip() @@ -342,9 +354,9 @@ def _resolver(prompt: str): if shift_val_f is not None: reward_processor = RewardProcessors.shift(value=shift_val_f) - top_k = maac_cfg.get("top_k") - temperature = maac_cfg.get("temperature", 0.6) - top_p = maac_cfg.get("top_p", 0.6) + top_k = model_config.top_k + temperature = model_config.temperature + top_p = model_config.top_p model_kwargs: Dict[str, Any] = {} if model_config.torch_dtype is not None: model_kwargs["torch_dtype"] = model_config.torch_dtype @@ -400,26 +412,30 @@ def external_transition_fn( external_transition=external_transition_fn, args=MAACConfig( num_turns=num_turns, - num_train_epochs=maac_cfg.get("num_train_epochs", 40), + num_train_epochs=maac_cfg.get("num_train_epochs", 80), agent_learning_rate=maac_cfg.get("agent_learning_rate", 5e-6), critic_learning_rate=maac_cfg.get("critic_learning_rate", 5e-6), value_loss_coef=maac_cfg.get("value_loss_coef", 0.6), - rollout_buffer_size=maac_cfg.get("rollout_buffer_size", 8), + rollout_buffer_size=maac_cfg.get("rollout_buffer_size", 4), max_new_tokens=maac_cfg.get("max_new_tokens", 256), temperature=temperature, top_p=top_p, top_k=top_k, num_agents=num_agents, num_generations=maac_cfg.get("num_generations", 1), + parallel_training=str(maac_cfg.get("parallel_training", "none")).strip().lower(), + agent_devices=maac_cfg.get("agent_devices", ["cuda:0"]), + critic_devices=maac_cfg.get("critic_devices", ["cuda:0"]), discount=discount, + external_prompt_passthrough=external_prompt_passthrough, critic_type=maac_cfg.get("critic_type", "v"), early_termination_threshold=maac_cfg.get( "early_termination_threshold", -0.2 ), - eval_interval=maac_cfg.get("eval_interval", 16), + eval_interval=maac_cfg.get("eval_interval", 40), eval_num_samples=maac_cfg.get("eval_num_samples", 4), eval_batch_size=maac_cfg.get("eval_batch_size", 1), - logging_steps=maac_cfg.get("logging_steps", 1), + logging_steps=maac_cfg.get("logging_steps", 10), ), train_dataset=train_dataset, eval_dataset=eval_dataset, @@ -463,6 +479,9 @@ def _build_wandb_config( ): wandb_section = config.get_section("wandb") if hasattr(config, "get_section") else {} maac_section = config.get_section("maac") if hasattr(config, "get_section") else {} + model_section = ( + config.get_section("agent_model") if hasattr(config, "get_section") else {} + ) output_section = ( config.get_section("output") if hasattr(config, "get_section") else {} ) @@ -490,9 +509,9 @@ def _build_wandb_config( "trainer": { "num_turns": maac_section.get("num_turns", 2), "max_new_tokens": maac_section.get("max_new_tokens", 256), - "temperature": maac_section.get("temperature", 0.6), - "top_p": maac_section.get("top_p", 0.6), - "top_k": maac_section.get("top_k"), + "temperature": model_section.get("temperature"), + "top_p": model_section.get("top_p"), + "top_k": model_section.get("top_k"), "discount": maac_section.get("discount", 0.9), "critic_type": maac_section.get("critic_type", "v"), }, diff --git a/train_magrpo.py b/train_magrpo.py index ce77160..415c37f 100644 --- a/train_magrpo.py +++ b/train_magrpo.py @@ -17,7 +17,7 @@ from config import Config, add_config_args, parse_overrides from datasets import load_dataset import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers import AutoTokenizer from loggers.mt_code_logger import ( aggregate_mt_humaneval_metrics_for_logging, @@ -300,9 +300,21 @@ def main(): ) tokenizer = tokenizers[0] - temperature = magrpo_config.get("temperature", 0.6) - top_p = magrpo_config.get("top_p", 0.6) + temperature = model_config.temperature + top_p = model_config.top_p + top_k = model_config.top_k external_cfg = config.get_section("external") if hasattr(config, "get_section") else {} + _ext_passthrough = external_cfg.get("external_prompt_passthrough", False) + if isinstance(_ext_passthrough, str): + external_prompt_passthrough = _ext_passthrough.strip().lower() in { + "1", + "true", + "yes", + "y", + "on", + } + else: + external_prompt_passthrough = bool(_ext_passthrough) # Register external context resolver using dataset items def _normalize_prompt(p: str) -> str: @@ -394,33 +406,36 @@ def _resolver(prompt: str): magrpo_args_kwargs = { "num_turns": num_turns, - "num_train_epochs": magrpo_config.get("num_train_epochs", 20), - "agent_learning_rate": magrpo_config.get("agent_learning_rate", 5e-6), + "num_train_epochs": magrpo_config.get("num_train_epochs", 8), + "agent_learning_rate": magrpo_config.get("agent_learning_rate", 2e-5), "logging_steps": magrpo_config.get("logging_steps", 50), "num_generations": magrpo_config.get("num_generations", 4), "max_new_tokens": magrpo_config.get("max_new_tokens", 256), "temperature": temperature, "top_p": top_p, + "top_k": top_k, } - if "top_k" in magrpo_config: - magrpo_args_kwargs["top_k"] = magrpo_config.get("top_k") magrpo_args_kwargs.update( { "num_agents": num_agents, + "parallel_training": str( + magrpo_config.get("parallel_training", "none") + ).strip().lower(), + "agent_devices": magrpo_config.get("agent_devices", ["cuda:0"]), "discount": magrpo_config.get("discount", 0.9), "joint_mode": magrpo_config.get("joint_mode", "aligned"), "early_termination_threshold": magrpo_config.get( "early_termination_threshold", -0.2 ), - "rollout_buffer_size": magrpo_config.get("rollout_buffer_size", 2), - "train_batch_size": magrpo_config.get("train_batch_size", None), + "rollout_buffer_size": magrpo_config.get("rollout_buffer_size", 4), + "train_batch_size": magrpo_config.get("train_batch_size", 4), "advantage_normalization": magrpo_config.get( "advantage_normalization", True ), - "eval_interval": magrpo_config.get("eval_interval", 16), + "eval_interval": magrpo_config.get("eval_interval", 4), "eval_num_samples": magrpo_config.get("eval_num_samples", 4), "eval_batch_size": magrpo_config.get("eval_batch_size", 1), - "external_prompt_passthrough": True, + "external_prompt_passthrough": external_prompt_passthrough, } ) magrpo_args = MAGRPOConfig(**magrpo_args_kwargs) @@ -475,26 +490,6 @@ def _resolver(prompt: str): code_rewards.VERBOSE = bool(output_verbose) import external as external_mod external_mod.VERBOSE = bool(output_verbose) - model_kwargs: Dict[str, Any] = {} - if model_config.torch_dtype is not None: - model_kwargs["torch_dtype"] = model_config.torch_dtype - if agent_names: - agents = [ - AutoModelForCausalLM.from_pretrained( - name, - **model_kwargs, - ) - for name in agent_names - ] - else: - agents = [ - AutoModelForCausalLM.from_pretrained( - model_name, - **model_kwargs, - ) - for _ in range(num_agents) - ] - reward_processor = None if config.get("reward_processor.enabled", True): scale_factor = config.get("reward_processor.scale_factor", 1.0) @@ -510,11 +505,17 @@ def _resolver(prompt: str): prev = reward_processor reward_processor = (lambda p=prev, s=shift_proc: (lambda x: s(p(x))))() # Build trainer kwargs (grouped: model/data, reward/formatting, logging, args) + model_arg = model_name or None + agents_arg = agent_names trainer_kwargs = { - "agent_model": model_name or None, - "agents": agents, + "agent_model": model_arg, + "agents": agents_arg, "num_agents": num_agents, "tokenizer": tokenizers if agent_names else tokenizer, + "model_config": { + "torch_dtype": model_config.torch_dtype, + "special_tokens": model_config.special_tokens, + }, "train_dataset": train_dataset, "eval_dataset": eval_dataset, "reward_func": reward_func,