-
Notifications
You must be signed in to change notification settings - Fork 5
refactor agentjet job system #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c621922
544bff1
e9bc0e1
a191f71
621d235
b150420
642eddc
f532854
79407c4
39ab72e
76d0e3f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -170,3 +170,5 @@ modelscope_cache | |
| prompts | ||
| swarmexp | ||
| swarmlog | ||
| werewolves_swarm | ||
| .claude | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,14 +8,13 @@ | |
| from __future__ import annotations | ||
|
|
||
| import os | ||
| import time | ||
| import yaml | ||
| import tempfile | ||
| from types import SimpleNamespace | ||
| from typing import Any, Callable, Union | ||
|
|
||
| import yaml | ||
| from types import SimpleNamespace | ||
| from typing import Any, Callable, Union, cast | ||
| from loguru import logger | ||
|
|
||
|
|
||
| from ajet.default_config.ajet_default import Config | ||
| from ajet.utils.config_utils import ( | ||
| expand_ajet_hierarchical_config, | ||
|
|
@@ -30,70 +29,118 @@ | |
| setup_environment_vars, | ||
| ) | ||
|
|
||
| DEFAULT_DIR = "saved_experiments" | ||
|
|
||
| def override_current_yaml_value_if_given(override_value, current_value): | ||
| if override_value is not None: | ||
| return override_value | ||
| else: | ||
| return current_value | ||
|
|
||
| def _set_nested_attr(obj, attr_path: str, value): | ||
| keys = attr_path.split(".") | ||
| for key in keys[:-1]: | ||
| obj = getattr(obj, key) | ||
| setattr(obj, keys[-1], value) | ||
|
|
||
| def _get_nested_attr(obj, attr_path: str): | ||
| for key in attr_path.split("."): | ||
| obj = getattr(obj, key) | ||
| return obj | ||
|
|
||
| class AgentJetJob: | ||
| """Lightweight builder that launches AgentJet training as a subprocess.""" | ||
| """ | ||
| arg: base_yaml_config + **kwargs (yaml config, then override with kwargs) | ||
| arg: base_yaml_config (yaml config) | ||
| arg: **kwargs (yaml config, then override with kwargs) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| backbone: str = "verl", | ||
| model: str = "Qwen/Qwen2___5-7B-Instruct", | ||
| n_gpu: int = 8, | ||
| algorithm: str = "grpo", | ||
| project_name="ajet-swarm", | ||
| experiment_name="test", | ||
| n_gpu_for_infer: int | None = None, # only for trinity backbone | ||
| num_repeat: int = 8, | ||
| batch_size: int = 32, | ||
| swarm_mode: bool = True, | ||
| sample_collection_method: str = "rollout_until_finish_enough_tasks", | ||
| *kwargs, | ||
| base_yaml_config: str | None = None, | ||
| experiment_dir: str | None = None, | ||
| project_name: str | None = None, | ||
| experiment_name: str | None = None, | ||
| n_gpu: int | None = None, | ||
| model: str | None = None, | ||
| algorithm: str | None = None, | ||
| num_repeat: int | None = None, | ||
| batch_size: int | None = None, | ||
| swarm_mode: bool | None = None, | ||
| swarm_mode_sample_collection_method: str | None = None, | ||
| max_env_worker: int | None = None, | ||
| backbone: str | None = None, | ||
| ) -> None: | ||
| self.backbone = backbone | ||
| self.exp_dir = DEFAULT_DIR | ||
| self.project_name = project_name | ||
| self.exp_name = experiment_name | ||
| self.sample_collection_method = sample_collection_method | ||
| if swarm_mode: | ||
| default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) | ||
|
|
||
| if base_yaml_config is None: | ||
| base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) | ||
| else: | ||
| default_yaml = None | ||
| self.config_as_dict: dict = self.build_job_from_yaml(default_yaml) | ||
| logger.warning(f"Reading config from {base_yaml_config}.") | ||
| time.sleep(1) | ||
| self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config) | ||
| self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict) | ||
|
Comment on lines
+79
to
80
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's a critical bug in the initialization order of `__init__`: `self.build_job_from_yaml(base_yaml_config)` is called before the attributes it depends on are assigned, yet `build_job_from_yaml` reads `self.exp_name`, `self.backbone`, and `self.exp_dir` — and the refactored `__init__` only ever sets `self.experiment_name`, `self.backbone`, and `self.experiment_dir` afterwards. This will lead to an `AttributeError` the first time the class is instantiated.
This initialization flow needs to be re-architected to resolve this circular dependency. A possible approach is to first determine the values for parameters required for loading the configuration (like the experiment name, experiment directory, and backbone), assign them to `self`, then load the YAML configuration, and only afterwards apply the remaining keyword-argument overrides. |
||
|
|
||
| self.config.ajet.experiment_name = experiment_name | ||
| self.config.ajet.backbone = backbone | ||
| self.config.ajet.model.path = model | ||
| self.config.ajet.trainer_common.n_gpus_per_node = n_gpu | ||
| self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm | ||
| self.config.ajet.rollout.num_repeat = num_repeat | ||
| self.config.ajet.data.train_batch_size = batch_size | ||
| self.config.ajet.enable_swarm_mode = swarm_mode | ||
| self.config.ajet.swarm_mode_sample_collection_method = sample_collection_method | ||
| if n_gpu_for_infer is None and backbone == "trinity": | ||
| raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") | ||
| if (n_gpu_for_infer is not None) and backbone == "verl": | ||
| raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.") | ||
| else: | ||
| if backbone == "trinity": | ||
| assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}." | ||
| assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." | ||
| self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer | ||
| self.config.ajet.rollout.tensor_model_parallel_size = 1 | ||
| self.base_yaml_config: str = cast(str, base_yaml_config) # currently may be None, but will be set later | ||
| self.experiment_dir: str = cast(str, experiment_dir) | ||
| self.project_name: str = cast(str, project_name) | ||
| self.experiment_name: str = cast(str, experiment_name) | ||
| self.n_gpu: int = cast(int, n_gpu) | ||
| self.model: str = cast(str, model) | ||
| self.algorithm: str = cast(str, algorithm) | ||
| self.num_repeat: int = cast(int, num_repeat) | ||
| self.batch_size: int = cast(int, batch_size) | ||
| self.swarm_mode: bool = cast(bool, swarm_mode) | ||
| self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method) | ||
| self.max_env_worker: int = cast(int, max_env_worker) | ||
| self.backbone: str = cast(str, backbone) | ||
|
|
||
| # see `ajet/default_config/ajet_ts_default.yaml` | ||
| overrides = { | ||
| "ajet.experiment_dir": "experiment_dir", | ||
| "ajet.project_name": "project_name", | ||
| "ajet.experiment_name": "experiment_name", | ||
| "ajet.model.path": "model", | ||
| "ajet.trainer_common.n_gpus_per_node": "n_gpu", | ||
| "ajet.trainer_common.algorithm.adv_estimator": "algorithm", | ||
| "ajet.rollout.num_repeat": "num_repeat", | ||
| "ajet.data.train_batch_size": "batch_size", | ||
| "ajet.enable_swarm_mode": "swarm_mode", | ||
| "ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method", | ||
| "ajet.rollout.max_env_worker": "max_env_worker", | ||
| "ajet.backbone": "backbone", | ||
| } | ||
|
|
||
| # if any value given in kwargs, override the corresponding value in config | ||
| for attr_path, override_val in overrides.items(): | ||
| # get value from yaml config | ||
| # >> e.g. current_model = self.config.model.path | ||
| current_val = _get_nested_attr(self.config, attr_path) | ||
|
|
||
| # if override_val (given in __init__) is not None, use it to override the value from yaml config | ||
| # >> e.g. new_model = self.model if (self.model is not None) else current_model | ||
| new_val = override_current_yaml_value_if_given(getattr(self, override_val), current_val) | ||
|
|
||
| # write final value to `self.config`` | ||
| # >> e.g. self.config.model.path = new_model | ||
| _set_nested_attr(self.config, attr_path, new_val) | ||
|
|
||
| # write final value to `self` | ||
| # >> e.g. self.model = new_model | ||
| setattr(self, override_val, new_val) | ||
|
|
||
| if self.backbone == "trinity": | ||
| raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.") | ||
|
|
||
|
|
||
| def build_job_from_yaml(self, yaml_path: str | None) -> dict: | ||
| self.config_as_dict = read_ajet_hierarchical_config( | ||
| yaml_path, | ||
| exp_name=self.exp_name, | ||
| backbone=self.backbone, | ||
| write_to=None, | ||
| exp_dir=self.exp_dir, | ||
| ) | ||
| self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None) | ||
| logger.info(f"Built AgentJet job config: {yaml_path}") | ||
| return self.config_as_dict | ||
|
|
||
|
|
||
| def dump_job_as_yaml(self, yaml_path: str) -> str: | ||
| if os.path.dirname(yaml_path): | ||
| os.makedirs(os.path.dirname(yaml_path), exist_ok=True) | ||
|
|
@@ -102,6 +149,7 @@ def dump_job_as_yaml(self, yaml_path: str) -> str: | |
| logger.info(f"Saved training config to {yaml_path}") | ||
| return yaml_path | ||
|
|
||
|
|
||
| def set_workflow( | ||
| self, workflow: Union[str, Callable[..., Any]], ensure_reward_in_workflow: bool = False | ||
| ) -> "AgentJetJob": | ||
|
|
@@ -110,6 +158,7 @@ def set_workflow( | |
| # ensure_reward_in_workflow | ||
| return self | ||
|
|
||
|
|
||
| def set_data( | ||
| self, | ||
| type: str, | ||
|
|
@@ -136,60 +185,3 @@ def set_data( | |
|
|
||
| return self | ||
|
|
||
| def tune(self, *args, **kwargs) -> "AgentJetJob": | ||
| import ray | ||
| ast_cfg = self.config.ajet | ||
| if not ast_cfg.rollout or not ast_cfg.rollout.user_workflow: | ||
| raise ValueError("Workflow must be set via set_workflow before tuning.") | ||
| if not ast_cfg.task_reader: | ||
| raise ValueError("Data source must be set via set_data before tuning.") | ||
|
|
||
| backbone = self.config.ajet.backbone | ||
| exp_dir = self.config.ajet.experiment_dir | ||
|
|
||
| with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as temp_yaml: | ||
| yaml_path = temp_yaml.name | ||
| self.dump_job_as_yaml(yaml_path) | ||
| args = SimpleNamespace( | ||
| conf=yaml_path, | ||
| backbone=backbone, | ||
| exp_dir=exp_dir, | ||
| with_logview=False, | ||
| debug=False, | ||
| ) | ||
|
|
||
| if args.backbone != "debug": | ||
| # Enforce GPU availability and free memory threshold before proceeding | ||
| check_avail_gpu(min_free_ratio=0.95) | ||
|
|
||
| # finalize experiment config | ||
| main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( | ||
| yaml_path, exp_dir, backbone | ||
| ) | ||
|
|
||
| # setup environment variables for ray | ||
| env = setup_environment_vars(args, exp_config, main_yaml_fp) | ||
|
|
||
| # start ray if not already started | ||
| if not ray.is_initialized(): | ||
| from ajet.utils.launch_utils import start_ray_service | ||
|
|
||
| start_ray_service(args, env) | ||
| else: | ||
| raise RuntimeError( | ||
| "Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job." | ||
| ) | ||
|
|
||
| # start training process | ||
| if args.conf and main_yaml_fp and exe_exp_base and exp_config: | ||
| execute_training_process( | ||
| args, | ||
| get_backbone_target(args.backbone), | ||
| main_yaml_fp, | ||
| exe_exp_base, | ||
| main_yaml_fp, | ||
| env, | ||
| exp_config, | ||
| ) | ||
|
|
||
| return self | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `time.sleep(1)` here is a code smell. Sleeps like this are often used to work around race conditions or to ensure a log message is visible before a potential crash. This can hide underlying issues and make the code's behavior dependent on timing. It would be better to identify and fix the root cause rather than using a sleep.