From c621922d276c45628b5bffea97f5259417bb1bce Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Mon, 2 Mar 2026 12:26:21 +0800 Subject: [PATCH 01/11] refactor --- ajet/backbone/main_trinity.py | 2 +- ajet/backbone/main_verl.py | 4 +- ajet/backbone/main_vllm.py | 2 +- ajet/backbone/trainer_verl.py | 4 +- ajet/backbone/warm_up.py | 2 +- ajet/copilot/job.py | 208 +++++------ ajet/default_config/ajet_default.yaml | 2 +- ajet/default_config/ajet_ts_default.yaml | 26 +- ajet/launcher.py | 4 +- ajet/swarm_cli.py | 4 +- ajet/tuner.py | 8 +- .../tuner_lib/experimental/as_swarm_client.py | 4 +- ajet/utils/core_env_vars.py | 2 +- docs/en/support_agentscope.md | 2 +- docs/en/support_http.md | 2 +- docs/en/support_langchain.md | 2 +- docs/en/support_oaisdk.md | 2 +- .../benchmark_appworld.yaml | 4 +- .../benchmark_appworld_2nodes.yaml | 4 +- .../benchmark_appworld_oai_sdk.yaml | 4 +- .../benchmark_countdown.yaml | 4 +- .../benchmark_learn2ask.yaml | 4 +- .../bench/benchmark_math/benchmark_math.yaml | 4 +- .../benchmark_math_oai_sdk.yaml | 4 +- .../benchmark_math_raw_http.yaml | 4 +- tutorial/example_appworld/appworld.yaml | 4 +- .../example_appworld/appworld_oai_sdk.yaml | 4 +- tutorial/example_countdown/countdown.yaml | 4 +- .../example_deep_finance/deep_finance.yaml | 4 +- .../yaml_template/deep_finance_template.yaml | 4 +- .../deep_finance_template_maxlen.yaml | 4 +- .../yaml_template/infer.yaml | 4 +- .../example_feedback_tracing.yaml | 4 +- tutorial/example_learn2ask/learn2ask.yaml | 4 +- tutorial/example_rubrics_judge/r_judge.yaml | 4 +- .../example_werewolves_swarm/agent_roll.py | 108 ++++++ tutorial/example_werewolves_swarm/game.py | 351 ++++++++++++++++++ tutorial/example_werewolves_swarm/prompt.py | 176 +++++++++ tutorial/example_werewolves_swarm/start.py | 155 ++++++++ .../structured_model.py | 88 +++++ tutorial/example_werewolves_swarm/utils.py | 161 ++++++++ .../example_werewolves_swarm/werewolves.md | 9 + .../example_werewolves_swarm/werewolves.yaml | 76 ++++ 43 
files changed, 1305 insertions(+), 171 deletions(-) create mode 100644 tutorial/example_werewolves_swarm/agent_roll.py create mode 100644 tutorial/example_werewolves_swarm/game.py create mode 100644 tutorial/example_werewolves_swarm/prompt.py create mode 100644 tutorial/example_werewolves_swarm/start.py create mode 100644 tutorial/example_werewolves_swarm/structured_model.py create mode 100644 tutorial/example_werewolves_swarm/utils.py create mode 100644 tutorial/example_werewolves_swarm/werewolves.md create mode 100644 tutorial/example_werewolves_swarm/werewolves.yaml diff --git a/ajet/backbone/main_trinity.py b/ajet/backbone/main_trinity.py index 7956305f..dc06c21c 100644 --- a/ajet/backbone/main_trinity.py +++ b/ajet/backbone/main_trinity.py @@ -52,7 +52,7 @@ def patched_trainer_get_actor(cls, config: Config): Explorer.get_actor = classmethod(patched_explorer_get_actor) Trainer.get_actor = classmethod(patched_trainer_get_actor) - if ajet_config.ajet.enable_experimental_interchange_server: + if ajet_config.ajet.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server start_interchange_server(ajet_config) diff --git a/ajet/backbone/main_verl.py b/ajet/backbone/main_verl.py index 1e8c9c01..8eebb95f 100644 --- a/ajet/backbone/main_verl.py +++ b/ajet/backbone/main_verl.py @@ -67,7 +67,7 @@ def run_ppo(config: DictConfig) -> None: def on_shutdown(): if ray.is_initialized(): ray.shutdown() - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: if config.ajet.enable_swarm_mode: from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status print("Changing engine status to OFFLINE before shutdown...") @@ -250,7 +250,7 @@ def run(self, config): from ajet.backbone.trainer_verl import AjetRayPPOTrainer - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_server 
import start_interchange_server start_interchange_server(config) diff --git a/ajet/backbone/main_vllm.py b/ajet/backbone/main_vllm.py index edbcb161..4e8b717c 100644 --- a/ajet/backbone/main_vllm.py +++ b/ajet/backbone/main_vllm.py @@ -186,7 +186,7 @@ def main(config): os.environ.update(runtime_env["env_vars"]) # atexit.register(lambda: print("Process exiting, performing cleanup...")) - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server start_interchange_server(config) if config.ajet.enable_swarm_mode: diff --git a/ajet/backbone/trainer_verl.py b/ajet/backbone/trainer_verl.py index 5ca7ce83..28f09f95 100644 --- a/ajet/backbone/trainer_verl.py +++ b/ajet/backbone/trainer_verl.py @@ -457,7 +457,7 @@ def init_workers(self): ) def _update_interchange_server_status_flag(self, status: str): - if self.config.ajet.enable_experimental_interchange_server: + if self.config.ajet.enable_interchange_server: if self.config.ajet.enable_swarm_mode: from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status http_change_engine_status(self.config, status, global_step=self.global_steps) @@ -858,7 +858,7 @@ def fit(self): # noqa: C901 self.global_steps += 1 # # when enabled oai request interchange, we need to clear the cache from time to time - # if self.config.ajet.enable_experimental_interchange_server: + # if self.config.ajet.enable_interchange_server: # from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear # ensure_dat_interchange_server_cache_clear() diff --git a/ajet/backbone/warm_up.py b/ajet/backbone/warm_up.py index 192f4a0b..6e261b6a 100644 --- a/ajet/backbone/warm_up.py +++ b/ajet/backbone/warm_up.py @@ -54,7 +54,7 @@ def warm_up_task_judge_when_needed(config): def clean_up_tmp_ajet_dir(config): """Clean up old IPC socket files in /tmp/ajet directory.""" import time - if 
config.ajet.enable_experimental_interchange_server is False: + if config.ajet.enable_interchange_server is False: return tmp_dir = "/tmp/ajet" diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index 665c41cd..f74b44ac 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -8,14 +8,13 @@ from __future__ import annotations import os +import time +import yaml import tempfile -from types import SimpleNamespace -from typing import Any, Callable, Union -import yaml +from types import SimpleNamespace +from typing import Any, Callable, Union, cast from loguru import logger - - from ajet.default_config.ajet_default import Config from ajet.utils.config_utils import ( expand_ajet_hierarchical_config, @@ -30,70 +29,122 @@ setup_environment_vars, ) -DEFAULT_DIR = "saved_experiments" + +def override_current_yaml_value_if_given(override_value, current_value): + if override_value is not None: + return override_value + else: + return current_value + +def _set_nested_attr(obj, attr_path: str, value): + keys = attr_path.split(".") + for key in keys[:-1]: + obj = getattr(obj, key) + setattr(obj, keys[-1], value) + +def _get_nested_attr(obj, attr_path: str): + for key in attr_path.split("."): + obj = getattr(obj, key) + return obj class AgentJetJob: - """Lightweight builder that launches AgentJet training as a subprocess.""" + """ + arg: base_yaml_config + **kwargs (yaml config, then override with kwargs) + arg: base_yaml_config (yaml config) + arg: **kwargs (yaml config, then override with kwargs) + """ def __init__( self, - backbone: str = "verl", - model: str = "Qwen/Qwen2___5-7B-Instruct", - n_gpu: int = 8, - algorithm: str = "grpo", - project_name="ajet-swarm", - experiment_name="test", - n_gpu_for_infer: int | None = None, # only for trinity backbone - num_repeat: int = 8, - batch_size: int = 32, - swarm_mode: bool = True, - sample_collection_method: str = "rollout_until_finish_enough_tasks", - *kwargs, + base_yaml_config: str | None = None, + experiment_dir: str | 
None = None, + project_name: str | None = None, + experiment_name: str | None = None, + n_gpu: int | None = None, + model: str | None = None, + algorithm: str | None = None, + num_repeat: int | None = None, + batch_size: int | None = None, + swarm_mode: bool | None = None, + swarm_mode_sample_collection_method: str | None = None, + max_env_worker: int | None = None, + backbone: str | None = None, ) -> None: - self.backbone = backbone - self.exp_dir = DEFAULT_DIR - self.project_name = project_name - self.exp_name = experiment_name - self.sample_collection_method = sample_collection_method - if swarm_mode: - default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) + + if base_yaml_config is None: + base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml")) else: - default_yaml = None - self.config_as_dict: dict = self.build_job_from_yaml(default_yaml) + logger.warning(f"Reading config from {base_yaml_config}.") + time.sleep(1) + self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config) self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict) - self.config.ajet.experiment_name = experiment_name - self.config.ajet.backbone = backbone - self.config.ajet.model.path = model - self.config.ajet.trainer_common.n_gpus_per_node = n_gpu - self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm - self.config.ajet.rollout.num_repeat = num_repeat - self.config.ajet.data.train_batch_size = batch_size - self.config.ajet.enable_swarm_mode = swarm_mode - self.config.ajet.swarm_mode_sample_collection_method = sample_collection_method - if n_gpu_for_infer is None and backbone == "trinity": - raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.") - if (n_gpu_for_infer is not None) and backbone == "verl": - raise ValueError("n_gpu_for_infer is only for trinity backbone, please 
set it to `None`.") - else: - if backbone == "trinity": - assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}." - assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`." - self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer - self.config.ajet.rollout.tensor_model_parallel_size = 1 + self.base_yaml_config: str = cast(str, base_yaml_config) # currently may be None, but will be set later + self.experiment_dir: str = cast(str, experiment_dir) + self.project_name: str = cast(str, project_name) + self.experiment_name: str = cast(str, experiment_name) + self.n_gpu: int = cast(int, n_gpu) + self.model: str = cast(str, model) + self.algorithm: str = cast(str, algorithm) + self.num_repeat: int = cast(int, num_repeat) + self.batch_size: int = cast(int, batch_size) + self.swarm_mode: bool = cast(bool, swarm_mode) + self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method) + self.max_env_worker: int = cast(int, max_env_worker) + self.backbone: str = cast(str, backbone) + + # see `ajet/default_config/ajet_ts_default.yaml` + overrides = { + "ajet.experiment_dir": "experiment_dir", + "ajet.project_name": "project_name", + "ajet.experiment_name": "experiment_name", + "ajet.model.path": "model", + "ajet.trainer_common.n_gpus_per_node": "n_gpu", + "ajet.trainer_common.algorithm.adv_estimator": "algorithm", + "ajet.rollout.num_repeat": "num_repeat", + "ajet.data.train_batch_size": "batch_size", + "ajet.enable_swarm_mode": "swarm_mode", + "ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method", + "ajet.rollout.max_env_worker": "max_env_worker", + "ajet.backbone": "backbone", + } + + # if any value given in kwargs, override the corresponding value in config + for attr_path, override_val in overrides.items(): + # get value from yaml config + # >> e.g. 
current_model = self.config.model.path + current_val = _get_nested_attr(self.config, attr_path) + + # if override_val (given in __init__) is not None, use it to override the value from yaml config + # >> e.g. new_model = self.model if (self.model is not None) else current_model + new_val = override_current_yaml_value_if_given(getattr(self, override_val), current_val) + + # write final value to `self.config`` + # >> e.g. self.config.model.path = new_model + _set_nested_attr(self.config, attr_path, new_val) + + # write final value to `self` + # >> e.g. self.model = new_model + setattr(self, override_val, new_val) + + if self.backbone == "trinity": + raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.") + def build_job_from_yaml(self, yaml_path: str | None) -> dict: + assert self.experiment_dir is not None, "experiment_dir must be provided either in constructor or in yaml config." self.config_as_dict = read_ajet_hierarchical_config( yaml_path, - exp_name=self.exp_name, + exp_name=self.experiment_name, backbone=self.backbone, write_to=None, - exp_dir=self.exp_dir, + exp_dir=self.experiment_dir, ) self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None) logger.info(f"Built AgentJet job config: {yaml_path}") return self.config_as_dict + def dump_job_as_yaml(self, yaml_path: str) -> str: if os.path.dirname(yaml_path): os.makedirs(os.path.dirname(yaml_path), exist_ok=True) @@ -102,6 +153,7 @@ def dump_job_as_yaml(self, yaml_path: str) -> str: logger.info(f"Saved training config to {yaml_path}") return yaml_path + def set_workflow( self, workflow: Union[str, Callable[..., Any]], ensure_reward_in_workflow: bool = False ) -> "AgentJetJob": @@ -110,6 +162,7 @@ def set_workflow( # ensure_reward_in_workflow return self + def set_data( self, type: str, @@ -136,60 +189,3 @@ def set_data( return self - def tune(self, *args, **kwargs) -> "AgentJetJob": - import ray - ast_cfg = self.config.ajet - if not ast_cfg.rollout or 
not ast_cfg.rollout.user_workflow: - raise ValueError("Workflow must be set via set_workflow before tuning.") - if not ast_cfg.task_reader: - raise ValueError("Data source must be set via set_data before tuning.") - - backbone = self.config.ajet.backbone - exp_dir = self.config.ajet.experiment_dir - - with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as temp_yaml: - yaml_path = temp_yaml.name - self.dump_job_as_yaml(yaml_path) - args = SimpleNamespace( - conf=yaml_path, - backbone=backbone, - exp_dir=exp_dir, - with_logview=False, - debug=False, - ) - - if args.backbone != "debug": - # Enforce GPU availability and free memory threshold before proceeding - check_avail_gpu(min_free_ratio=0.95) - - # finalize experiment config - main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( - yaml_path, exp_dir, backbone - ) - - # setup environment variables for ray - env = setup_environment_vars(args, exp_config, main_yaml_fp) - - # start ray if not already started - if not ray.is_initialized(): - from ajet.utils.launch_utils import start_ray_service - - start_ray_service(args, env) - else: - raise RuntimeError( - "Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job." 
- ) - - # start training process - if args.conf and main_yaml_fp and exe_exp_base and exp_config: - execute_training_process( - args, - get_backbone_target(args.backbone), - main_yaml_fp, - exe_exp_base, - main_yaml_fp, - env, - exp_config, - ) - - return self diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 0f164971..531eb321 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -284,7 +284,7 @@ ajet: # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature enable_swarm_mode: False # both swarm / oai share the same interchange server - enable_experimental_interchange_server: False + enable_interchange_server: False # interchange server configuration interchange_server: interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml index 1db9bdd5..bde6f48b 100644 --- a/ajet/default_config/ajet_ts_default.yaml +++ b/ajet/default_config/ajet_ts_default.yaml @@ -3,11 +3,11 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" experiment_dir: "auto" # {exp-dir}/{experiment_name} - backbone: debug # `debug` or `trinity` or `verl` + backbone: verl model: # which model should be trained - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-3B-Instruct + path: "Qwen/Qwen2.5-7B-Instruct" rollout: # the path to the workflow class @@ -21,7 +21,7 @@ ajet: judge_protocol: null # reward must come from remote user agent workflow, so set to null # the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature - enable_experimental_interchange_server: True + enable_interchange_server: True # train in cloud, run episode locally enable_swarm_mode: True # both swarm / oai share the same interchange server @@ -44,21 +44,35 @@ ajet: # (Hint: a **task_id** is 
considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.) swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" + data: + # max number of tokens for prompt + max_prompt_length: 3000 + # max number of tokens for response + max_response_length: 15000 + # how many tasks per training batch + train_batch_size: 32 + # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) + rollout: # maximum number of parallel environments / simulate workers max_env_worker: 128 + # how many times a task should be repeated + num_repeat: 4 trainer_common: logger: tensorboard + n_gpus_per_node: 8 + algorithm: + adv_estimator: grpo -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - - verl_default # verl inherit 1/1 + - verl_default - ajet_default - _self_ diff --git a/ajet/launcher.py b/ajet/launcher.py index c9d5705b..f3638e74 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -150,8 +150,8 @@ def start_swarm_server(env, config): assert config.ajet.enable_swarm_mode, ( "Please enable_swarm_mode in config to start swarm server." ) - assert config.ajet.enable_experimental_interchange_server, ( - "Please enable_experimental_interchange_server in config to start swarm server." + assert config.ajet.enable_interchange_server, ( + "Please enable_interchange_server in config to start swarm server." 
) from ajet.tuner_lib.experimental.as_oai_model_server import ( start_interchange_server, diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index eb5dd866..8ce5a570 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -24,8 +24,8 @@ def start_swarm_server(env, config, port): assert config.ajet.enable_swarm_mode, ( "Please enable_swarm_mode in config to start swarm server." ) - assert config.ajet.enable_experimental_interchange_server, ( - "Please enable_experimental_interchange_server in config to start swarm server." + assert config.ajet.enable_interchange_server, ( + "Please enable_interchange_server in config to start swarm server." ) # Set the port in the config diff --git a/ajet/tuner.py b/ajet/tuner.py index b780d44a..45a54425 100644 --- a/ajet/tuner.py +++ b/ajet/tuner.py @@ -23,7 +23,7 @@ def __init__( self.context_tracker = context_tracker self.llm_inference_fn = llm_inference_fn self.target2proxy_registry: dict[str, dict[str,TunerTypeUnion]] = {} - self.enable_interchange_server = config.ajet.enable_experimental_interchange_server + self.enable_interchange_server = config.ajet.enable_interchange_server if self.enable_interchange_server: self.proxy_client_started = False @@ -102,10 +102,10 @@ def as_oai_baseurl_apikey( ``` """ - assert self.enable_interchange_server, "Please enable `ajet.enable_experimental_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature." + assert self.enable_interchange_server, "Please enable `ajet.enable_interchange_server` in yaml config to use `as_oai_baseurl_apikey` feature." 
if self.proxy_client_started is False: self.proxy_client_started = True - self._enable_experimental_interchange_server(self.llm_inference_fn) + self._enable_interchange_server(self.llm_inference_fn) baseurl_apikey_model = OpenaiClientBaseUrlTuner( config=self.config, context_tracker=self.context_tracker, @@ -168,7 +168,7 @@ def get_context_tracker(self) -> MultiAgentContextTracker: return self.context_tracker - def _enable_experimental_interchange_server(self, llm_inference_fn): + def _enable_interchange_server(self, llm_inference_fn): # experimental reverse proxy start if self.enable_interchange_server: from ajet.tuner_lib.experimental.as_oai_model_client import InterchangeClient diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/as_swarm_client.py index bf7c57f3..58b1e481 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/as_swarm_client.py @@ -113,8 +113,8 @@ def _check_throttle_policy(self, throttle_policy: SwarmThrottlePolicy, pool_info if self._agent_jet_job: # check and raise early errors when possible - assert self._agent_jet_job.sample_collection_method == "rollout_until_finish_enough_tasks", \ - f"Current sample collection method ({self._agent_jet_job.sample_collection_method}) does not support throttle policy." + assert self._agent_jet_job.swarm_mode_sample_collection_method == "rollout_until_finish_enough_tasks", \ + f"Current sample collection method ({self._agent_jet_job.swarm_mode_sample_collection_method}) does not support throttle policy." 
# only_this_client_uuid = throttle_policy.throttle_method in ["Task_Ratio_Limit"] only_this_client_uuid = True diff --git a/ajet/utils/core_env_vars.py b/ajet/utils/core_env_vars.py index e48e1dda..9df18216 100644 --- a/ajet/utils/core_env_vars.py +++ b/ajet/utils/core_env_vars.py @@ -15,7 +15,7 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict: if config.ajet.trainer_common.nnodes == 1: master_node_ip = "localhost" else: - if config.ajet.enable_experimental_interchange_server: + if config.ajet.enable_interchange_server: if config.ajet.interchange_server.interchange_method == "ipc": raise ValueError("IPC interchange method is not supported for multi-node setup. Please set `ajet.interchange_server.interchange_method: tcp` ") diff --git a/docs/en/support_agentscope.md b/docs/en/support_agentscope.md index e551e4d9..13d308a5 100644 --- a/docs/en/support_agentscope.md +++ b/docs/en/support_agentscope.md @@ -64,7 +64,7 @@ This article introduce the way to convert different types of ways to convert you ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... ``` diff --git a/docs/en/support_http.md b/docs/en/support_http.md index 0bf3ab3d..d7659b1b 100644 --- a/docs/en/support_http.md +++ b/docs/en/support_http.md @@ -89,7 +89,7 @@ in this AI era, you can always start from scratch and build your own "high-scrap ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... ``` diff --git a/docs/en/support_langchain.md b/docs/en/support_langchain.md index d1e12890..344163e2 100644 --- a/docs/en/support_langchain.md +++ b/docs/en/support_langchain.md @@ -80,7 +80,7 @@ This article introduce the way to convert different types of ways to convert you ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... 
``` diff --git a/docs/en/support_oaisdk.md b/docs/en/support_oaisdk.md index b60b03e3..104a1c26 100644 --- a/docs/en/support_oaisdk.md +++ b/docs/en/support_oaisdk.md @@ -84,7 +84,7 @@ This article introduce the way to convert different types of ways to convert you ajet: ... - enable_experimental_interchange_server: True + enable_interchange_server: True ... ``` diff --git a/tests/bench/benchmark_appworld/benchmark_appworld.yaml b/tests/bench/benchmark_appworld/benchmark_appworld.yaml index f83e91f0..3622ed1b 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld.yaml +++ b/tests/bench/benchmark_appworld/benchmark_appworld.yaml @@ -58,14 +58,14 @@ ajet: execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml index 4ae12f17..f53ca63b 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml +++ b/tests/bench/benchmark_appworld/benchmark_appworld_2nodes.yaml @@ -63,14 +63,14 @@ trinity: sync_offset: 0 sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff 
--git a/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml index e3175d19..89d82afb 100644 --- a/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml +++ b/tests/bench/benchmark_appworld/benchmark_appworld_oai_sdk.yaml @@ -56,14 +56,14 @@ ajet: execute_testing_lambda: "tests/bench/benchmark_appworld/benchmark_appworld.py->TestProbe" # -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_countdown/benchmark_countdown.yaml b/tests/bench/benchmark_countdown/benchmark_countdown.yaml index fcd07f35..53cdd902 100644 --- a/tests/bench/benchmark_countdown/benchmark_countdown.yaml +++ b/tests/bench/benchmark_countdown/benchmark_countdown.yaml @@ -124,14 +124,14 @@ ajet: execute_testing_lambda: "tests/bench/benchmark_countdown/benchmark_countdown.py->TestProbe" # FOR ROBOT TESTING PURPOSE ONLY. 
NOT FOR HUMAN -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml index dd3b6a18..b435a0ac 100644 --- a/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml +++ b/tests/bench/benchmark_learn2ask/benchmark_learn2ask.yaml @@ -57,14 +57,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tests/bench/benchmark_math/benchmark_math.yaml b/tests/bench/benchmark_math/benchmark_math.yaml index 36648d5f..f0f8d896 100644 --- a/tests/bench/benchmark_math/benchmark_math.yaml +++ b/tests/bench/benchmark_math/benchmark_math.yaml @@ -62,14 +62,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit - trinity_default # trinity inherit diff --git a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml 
b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml index a3dadd1a..e7bf0aba 100644 --- a/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml +++ b/tests/bench/benchmark_math/benchmark_math_oai_sdk.yaml @@ -59,14 +59,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit - trinity_default # trinity inherit diff --git a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml index 8a4fc433..88c9aa15 100644 --- a/tests/bench/benchmark_math/benchmark_math_raw_http.yaml +++ b/tests/bench/benchmark_math/benchmark_math_raw_http.yaml @@ -59,14 +59,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit - trinity_default # trinity inherit diff --git a/tutorial/example_appworld/appworld.yaml b/tutorial/example_appworld/appworld.yaml index 316c605b..3ccb91b7 100644 --- a/tutorial/example_appworld/appworld.yaml +++ b/tutorial/example_appworld/appworld.yaml @@ -54,14 +54,14 @@ ajet: n_gpus_per_node: 8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# 
------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_appworld/appworld_oai_sdk.yaml b/tutorial/example_appworld/appworld_oai_sdk.yaml index 056aac91..4b159b7f 100644 --- a/tutorial/example_appworld/appworld_oai_sdk.yaml +++ b/tutorial/example_appworld/appworld_oai_sdk.yaml @@ -53,14 +53,14 @@ ajet: n_gpus_per_node: 8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_countdown/countdown.yaml b/tutorial/example_countdown/countdown.yaml index d5b161bf..6dcadf81 100644 --- a/tutorial/example_countdown/countdown.yaml +++ b/tutorial/example_countdown/countdown.yaml @@ -135,14 +135,14 @@ ajet: execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN. 
-# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/deep_finance.yaml b/tutorial/example_deep_finance/deep_finance.yaml index e5de33da..fcb429c4 100644 --- a/tutorial/example_deep_finance/deep_finance.yaml +++ b/tutorial/example_deep_finance/deep_finance.yaml @@ -71,14 +71,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml index 38aa82ed..d9f559b9 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template.yaml @@ -75,14 +75,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit 
------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml index 0ddd541c..02fa6f73 100644 --- a/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml +++ b/tutorial/example_deep_finance/yaml_template/deep_finance_template_maxlen.yaml @@ -76,14 +76,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_deep_finance/yaml_template/infer.yaml b/tutorial/example_deep_finance/yaml_template/infer.yaml index 5e9d400e..7dcf60ff 100644 --- a/tutorial/example_deep_finance/yaml_template/infer.yaml +++ b/tutorial/example_deep_finance/yaml_template/infer.yaml @@ -76,14 +76,14 @@ actor_rollout_ref: rollout: tensor_model_parallel_size: 8 gpu_memory_utilization: 0.8 -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_feedback_tracing/example_feedback_tracing.yaml b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml index 1cb01333..894aca7c 100644 
--- a/tutorial/example_feedback_tracing/example_feedback_tracing.yaml +++ b/tutorial/example_feedback_tracing/example_feedback_tracing.yaml @@ -62,14 +62,14 @@ trainer: - console -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_learn2ask/learn2ask.yaml b/tutorial/example_learn2ask/learn2ask.yaml index acacbce2..211c1e8a 100644 --- a/tutorial/example_learn2ask/learn2ask.yaml +++ b/tutorial/example_learn2ask/learn2ask.yaml @@ -53,14 +53,14 @@ trinity: sync_method: nccl -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 1/1 diff --git a/tutorial/example_rubrics_judge/r_judge.yaml b/tutorial/example_rubrics_judge/r_judge.yaml index 5834da5f..94e0e940 100644 --- a/tutorial/example_rubrics_judge/r_judge.yaml +++ b/tutorial/example_rubrics_judge/r_judge.yaml @@ -57,14 +57,14 @@ ajet: -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ hydra: searchpath: - file://ajet/default_config - file://ajet/default_config/verl # verl only - file://ajet/default_config/trinity # trinity only -# ------------------ 不需要修改 ------------------ +# ------------------ do not edit ------------------ defaults: - verl_default # verl inherit 1/1 - trinity_default # trinity inherit 
1/1 diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py new file mode 100644 index 00000000..f81346bd --- /dev/null +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- + +import os +import re +import requests +from textwrap import dedent +from ajet.schema.task import Task, WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.task_reader import RouterTaskReader +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + +GRPO_N = 6 # grpo group size +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") +REMOTE_BATCH_SIZE = 32 +REMOTE_ALLOCATE_GPU_PER_NODE = 8 + +def main(): + + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "random_dummy", + reader_config = AjetTaskReader() + ) + + ajet_job = AgentJetJob( + base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml", + algorithm="grpo", + experiment_name="werewolves_swarm", + ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + swarm_worker.auto_sync_train_config_and_start_engine( + ajet_job, + force_restart=True, + ) + + GRPO_N = ajet_job.num_repeat + REMOTE_BATCH_SIZE = ajet_job.batch_size + + # temperature: 0.7 + # max_env_worker: 64 + # num_repeat: 6 + + def rollout(task): + try: + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` 
+ # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return + except: + pass + + executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) + for _ in range(NUM_EPOCH): + for _, task in enumerate(dataset.generate_training_tasks()): + for _ in range(GRPO_N): + executor.submit_with_periodic_drain(fn=rollout, task=task) + + return None + + + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + # Prepare base_url, api_key + base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) + # Read dataset item + query, reference_answer = (task.main_query, task.metadata["answer"]) + # Prepare messages + messages = [ + { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. + You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, + { "role": "user", "content": query } + ] + # Use raw http requests (non-streaming) to get response + # "Connection: close" prevents keep-alive pool reuse, which can cause BadStatusLine + # errors under high concurrency when stale pooled connections return residual bytes. 
+ response = requests.post( + f"{base_url}/chat/completions", + json = { "model": "fill_whatever_model", "messages": messages, "stream": False }, + headers = { "Authorization": f"Bearer {api_key}", "Connection": "close" }, + timeout = 300, + ) + response.raise_for_status() + final_answer = response.json()['choices'][0]['message']['content'] + + reference_answer = reference_answer.split("####")[-1].strip() + pattern = r"\\boxed\{([^}]*)\}" + match = re.search(pattern, final_answer) + if match: is_success = match.group(1) == reference_answer + else: is_success = False + raw_reward = 1.0 if is_success else 0.0 + # Return + return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) + + + + +if __name__ == "__main__": + main() diff --git a/tutorial/example_werewolves_swarm/game.py b/tutorial/example_werewolves_swarm/game.py new file mode 100644 index 00000000..10246c32 --- /dev/null +++ b/tutorial/example_werewolves_swarm/game.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# pylint: disable=too-many-branches, too-many-statements, no-name-in-module +"""A werewolf game implemented by agentscope.""" +from agentscope.agent import ReActAgent +from agentscope.pipeline import MsgHub, fanout_pipeline, sequential_pipeline + +# Uncomment the following line to use Chinese prompts +# from tutorial.example_werewolves.prompt import ChinesePrompts as Prompts +from loguru import logger + +from tutorial.example_werewolves.prompt import EnglishPrompts as Prompts +from tutorial.example_werewolves.structured_model import ( + DiscussionModel, + WitchResurrectModel, + get_hunter_model, + get_poison_model, + get_seer_model, + get_vote_model, +) +from tutorial.example_werewolves.utils import ( + MAX_DISCUSSION_ROUND, + MAX_GAME_ROUND, + EchoAgent, + Players, + majority_vote, + names_to_str, +) + + +class BadGuyException(Exception): + ... 
+ + +moderator = EchoAgent() +# moderator.set_console_output_enabled(False) + + +async def hunter_stage( + hunter_agent: ReActAgent, + players: Players, +) -> str | None: + """Because the hunter's stage may happen in two places: killed at night + or voted during the day, we define a function here to avoid duplication.""" + global moderator + msg_hunter = await hunter_agent( + await moderator(Prompts.to_hunter.format(name=hunter_agent.name)), + structured_model=get_hunter_model(players.current_alive), + ) + if msg_hunter.metadata.get("shoot"): + return msg_hunter.metadata.get("name", None) + return None + + +async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C901 + """The main entry of the werewolf game + + Args: + agents (`list[ReActAgent]`): + A list of 9 agents. + """ + assert len(agents) == 9, "The werewolf game needs exactly 9 players." + + # Init the players' status + players = Players() + + # If the witch has healing and poison potion + healing, poison = True, True + + # If it's the first day, the dead can leave a message + first_day = True + + # Broadcast the game begin message + async with MsgHub(participants=agents) as greeting_hub: + await greeting_hub.broadcast( + await moderator( + Prompts.to_all_new_game.format(names_to_str(agents)), + ), + ) + + # Assign roles to the agents + for agent, role in zip(agents, roles): + # Tell the agent its role + await agent.observe( + await moderator( + f"[{agent.name} ONLY] {agent.name}, your role is {role}.", + ), + ) + players.add_player(agent, role) + + # Printing the roles + players.print_roles() + + # GAME BEGIN! 
+ for _ in range(MAX_GAME_ROUND): + # Create a MsgHub for all players to broadcast messages + async with MsgHub( + participants=players.current_alive, + enable_auto_broadcast=False, # manual broadcast only + name="alive_players", + ) as alive_players_hub: + # Night phase + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_night), + ) + killed_player, poisoned_player, shot_player = None, None, None + + try: + # Werewolves discuss + async with MsgHub( + players.werewolves, + enable_auto_broadcast=True, + announcement=await moderator( + Prompts.to_wolves_discussion.format( + names_to_str(players.werewolves), + names_to_str(players.current_alive), + ), + ), + name="werewolves", + ) as werewolves_hub: + # Discussion + n_werewolves = len(players.werewolves) + for _ in range(1, MAX_DISCUSSION_ROUND * n_werewolves + 1): + res = await players.werewolves[_ % n_werewolves]( + structured_model=DiscussionModel, + ) + if _ % n_werewolves == 0 and res.metadata.get( + "reach_agreement", + ): + break + + # Werewolves vote + # Disable auto broadcast to avoid following other's votes + werewolves_hub.set_auto_broadcast(False) + msgs_vote = await fanout_pipeline( + players.werewolves, + msg=await moderator(content=Prompts.to_wolves_vote), + structured_model=get_vote_model(players.current_alive), + enable_gather=False, + ) + killed_player, votes = majority_vote( + [_.metadata.get("vote") for _ in msgs_vote], + ) + # Postpone the broadcast of voting + await werewolves_hub.broadcast( + [ + *msgs_vote, + await moderator( + Prompts.to_wolves_res.format(votes, killed_player), + ), + ], + ) + except Exception as e: + raise BadGuyException( + f"Werewolves failed to make a decision: {e}", + ) + + # Witch's turn + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_witch_turn), + ) + msg_witch_poison = None + for agent in players.witch: + # Cannot heal witch herself + msg_witch_resurrect = None + if healing and killed_player != agent.name: + msg_witch_resurrect 
= await agent( + await moderator( + Prompts.to_witch_resurrect.format( + witch_name=agent.name, + dead_name=killed_player, + ), + ), + structured_model=WitchResurrectModel, + ) + if msg_witch_resurrect.metadata.get("resurrect"): + killed_player = None + healing = False + + # Has poison potion and hasn't used the healing potion + if poison and not ( + msg_witch_resurrect and msg_witch_resurrect.metadata["resurrect"] + ): + msg_witch_poison = await agent( + await moderator( + Prompts.to_witch_poison.format( + witch_name=agent.name, + ), + ), + structured_model=get_poison_model( + players.current_alive, + ), + ) + if msg_witch_poison.metadata.get("poison"): + poisoned_player = msg_witch_poison.metadata.get("name") + poison = False + + # Seer's turn + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_seer_turn), + ) + for agent in players.seer: + msg_seer = await agent( + await moderator( + Prompts.to_seer.format( + agent.name, + names_to_str(players.current_alive), + ), + ), + structured_model=get_seer_model(players.current_alive), + ) + if msg_seer.metadata.get("name"): + player = msg_seer.metadata["name"] + await agent.observe( + await moderator( + Prompts.to_seer_result.format( + agent_name=player, + role=players.name_to_role[player], + ), + ), + ) + + # Hunter's turn + for agent in players.hunter: + # If killed and not by witch's poison + if killed_player == agent.name and poisoned_player != agent.name: + shot_player = await hunter_stage(agent, players) + + # Update alive players + dead_tonight = [killed_player, poisoned_player, shot_player] + players.update_players(dead_tonight) + + # Day phase + if len([_ for _ in dead_tonight if _]) > 0: + await alive_players_hub.broadcast( + await moderator( + Prompts.to_all_day.format( + names_to_str([_ for _ in dead_tonight if _]), + ), + ), + ) + + # The killed player leave a last message in first night + if killed_player and first_day: + msg_moderator = await moderator( + 
Prompts.to_dead_player.format(killed_player), + ) + await alive_players_hub.broadcast(msg_moderator) + # Leave a message + last_msg = await players.name_to_agent[killed_player]() + await alive_players_hub.broadcast(last_msg) + + else: + await alive_players_hub.broadcast( + await moderator(Prompts.to_all_peace), + ) + + # Check winning + res = players.check_winning() + if res: + await moderator(res) + break + + # Discussion + await alive_players_hub.broadcast( + await moderator( + Prompts.to_all_discuss.format( + names=names_to_str(players.current_alive), + ), + ), + ) + # Open the auto broadcast to enable discussion + alive_players_hub.set_auto_broadcast(True) + await sequential_pipeline(players.current_alive) + # Disable auto broadcast to avoid leaking info + alive_players_hub.set_auto_broadcast(False) + + # Voting + msgs_vote = await fanout_pipeline( + players.current_alive, + await moderator( + Prompts.to_all_vote.format( + names_to_str(players.current_alive), + ), + ), + structured_model=get_vote_model(players.current_alive), + enable_gather=False, + ) + voted_player, votes = majority_vote( + [_.metadata.get("vote") for _ in msgs_vote], + ) + # Broadcast the voting messages together to avoid influencing + # each other + voting_msgs = [ + *msgs_vote, + await moderator( + Prompts.to_all_res.format(votes, voted_player), + ), + ] + + # Leave a message if voted + if voted_player: + prompt_msg = await moderator( + Prompts.to_dead_player.format(voted_player), + ) + last_msg = await players.name_to_agent[voted_player]( + prompt_msg, + ) + voting_msgs.extend([prompt_msg, last_msg]) + + await alive_players_hub.broadcast(voting_msgs) + + # If the voted player is the hunter, he can shoot someone + shot_player = None + for agent in players.hunter: + if voted_player == agent.name: + shot_player = await hunter_stage(agent, players) + if shot_player: + await alive_players_hub.broadcast( + await moderator( + Prompts.to_all_hunter_shoot.format( + shot_player, + ), + ), + ) + + # 
Update alive players + dead_today = [voted_player, shot_player] + players.update_players(dead_today) + + # Check winning + res = players.check_winning() + if res: + async with MsgHub(players.all_players) as all_players_hub: + res_msg = await moderator(res) + await all_players_hub.broadcast(res_msg) + break + + # The day ends + first_day = False + + # # Game over, each player reflects + # await fanout_pipeline( + # agents=agents, + # msg=await moderator(Prompts.to_all_reflect), + # ) + + alive_wolves = players.werewolves + good_guy_win = len(alive_wolves) == 0 + logger.warning("**********************************") + logger.warning(f"Good guy win: {good_guy_win}, alive werewolves: {alive_wolves}") + return good_guy_win diff --git a/tutorial/example_werewolves_swarm/prompt.py b/tutorial/example_werewolves_swarm/prompt.py new file mode 100644 index 00000000..52d0f650 --- /dev/null +++ b/tutorial/example_werewolves_swarm/prompt.py @@ -0,0 +1,176 @@ +# -*- coding: utf-8 -*- +"""Default prompts""" + + +class EnglishPrompts: + """English prompts used to guide the werewolf game.""" + + to_dead_player = ( + "{}, you're eliminated now. Now you can make a final statement to " + "all alive players before you leave the game." + ) + + to_all_new_game = ( + "A new game is starting, the players are: {}. Now we randomly " + "reassign the roles to each player and inform them of their roles " + "privately." + ) + + to_all_night = ( + "Night has fallen, everyone close your eyes. Werewolves open your " + "eyes and choose a player to eliminate tonight." + ) + + to_wolves_discussion = ( + "[WEREWOLVES ONLY] {}, you should discuss and " + "decide on a player to eliminate tonight. Current alive players " + "are {}. Remember to set `reach_agreement` to True if you reach an " + "agreement during the discussion." + ) + + to_wolves_vote = "[WEREWOLVES ONLY] Which player do you vote to kill?" + + to_wolves_res = ( + "[WEREWOLVES ONLY] The voting result is {}. 
So you have chosen to " "eliminate {}." + ) + + to_all_witch_turn = "Witch's turn, witch open your eyes and decide your action tonight..." + to_witch_resurrect = ( + "[WITCH ONLY] {witch_name}, you're the witch, and tonight {dead_name} " + "is eliminated. You can resurrect him/her by using your healing " + "potion, " + "and note you can only use it once in the whole game. Do you want to " + "resurrect {dead_name}? Give me your reason and decision." + ) + + to_witch_resurrect_no = "[WITCH ONLY] The witch has chosen not to resurrect the player." + to_witch_resurrect_yes = "[WITCH ONLY] The witch has chosen to resurrect the player." + + to_witch_poison = ( + "[WITCH ONLY] {witch_name}, as a witch, you have a one-time-use " + "poison potion, do you want to use it tonight? Give me your reason " + "and decision." + ) + + to_all_seer_turn = ( + "Seer's turn, seer open your eyes and check one player's identity " "tonight..." + ) + + to_seer = ( + "[SEER ONLY] {}, as the seer you can check one player's identity " + "tonight. Who do you want to check? Give me your reason and decision." + ) + + to_seer_result = "[SEER ONLY] You've checked {agent_name}, and the result is: {role}." + + to_hunter = ( + "[HUNTER ONLY] {name}, as the hunter you're eliminated tonight. You " + "can choose one player to take down with you. Also, you can choose " + "not to use this ability. Give me your reason and decision." + ) + + to_all_hunter_shoot = "The hunter has chosen to shoot {} down with him/herself." + + to_all_day = ( + "The day is coming, all players open your eyes. Last night, " + "the following player(s) has been eliminated: {}." + ) + + to_all_peace = ( + "The day is coming, all the players open your eyes. Last night is " + "peaceful, no player is eliminated." + ) + + to_all_discuss = ( + "Now the alive players are {names}. The game goes on, it's time to " + "discuss and vote a player to be eliminated. Now you each take turns " + "to speak once in the order of {names}." 
+ ) + + to_all_vote = ( + "Now the discussion is over. Everyone, please vote to eliminate one " + "player from the alive players: {}." + ) + + to_all_res = "The voting result is {}. So {} has been voted out." + + to_all_wolf_win = ( + "There are {n_alive} players alive, and {n_werewolves} of them are " + "werewolves. " + "The game is over and werewolves win🐺🎉!" + "In this game, the true roles of all players are: {true_roles}" + ) + + to_all_village_win = ( + "All the werewolves have been eliminated." + "The game is over and villagers win🏘️🎉!" + "In this game, the true roles of all players are: {true_roles}" + ) + + to_all_continue = "The game goes on." + + to_all_reflect = ( + "The game is over. Now each player can reflect on their performance. " + "Note each player only has one chance to speak and the reflection is " + "only visible to themselves." + ) + + +class ChinesePrompts: + """Chinese prompts used to guide the werewolf game.""" + + to_dead_player = "{}, 你已被淘汰。现在你可以向所有存活玩家发表最后的遗言。" + + to_all_new_game = "新的一局游戏开始,参与玩家包括:{}。现在为每位玩家重新随机分配身份,并私下告知各自身份。" + + to_all_night = "天黑了,请所有人闭眼。狼人请睁眼,选择今晚要淘汰的一名玩家..." + + to_wolves_discussion = ( + "[仅狼人可见] {}, 你们可以讨论并决定今晚要淘汰的玩家。当前存活玩家有:{}。" "如果达成一致,请将 `reach_agreement` 设为 True。" + ) + + to_wolves_vote = "[仅狼人可见] 你投票要杀死哪位玩家?" + + to_wolves_res = "[仅狼人可见] 投票结果为 {},你们选择淘汰 {}。" + + to_all_witch_turn = "轮到女巫行动,女巫请睁眼并决定今晚的操作..." + to_witch_resurrect = ( + "[仅女巫可见] {witch_name},你是女巫,今晚{dead_name}被淘汰。" + "你可以用解药救他/她,注意解药全局只能用一次。你要救{dead_name}吗?" + "请给出理由和决定。" + ) + + to_witch_resurrect_no = "[仅女巫可见] 女巫选择不救该玩家。" + to_witch_resurrect_yes = "[仅女巫可见] 女巫选择救活该玩家。" + + to_witch_poison = "[仅女巫可见] {witch_name},你有一瓶一次性毒药,今晚要使用吗?请给出理由和决定。" + + to_all_seer_turn = "轮到预言家行动,预言家请睁眼并查验一名玩家身份..." 
+ + to_seer = "[仅预言家可见] {}, 你是预言家,今晚可以查验一名玩家身份。你要查谁?请给出理由和决定。" + + to_seer_result = "[仅预言家可见] 你查验了{agent_name},结果是:{role}。" + + to_hunter = "[仅猎人可见] {name},你是猎人,今晚被淘汰。你可以选择带走一名玩家,也可以选择不带走。请给出理由和决定。" + + to_all_hunter_shoot = "猎人选择带走 {} 一起出局。" + + to_all_day = "天亮了,请所有玩家睁眼。昨晚被淘汰的玩家有:{}。" + + to_all_peace = "天亮了,请所有玩家睁眼。昨晚平安夜,无人被淘汰。" + + to_all_discuss = "现在存活玩家有:{names}。游戏继续,大家开始讨论并投票淘汰一名玩家。请按顺序({names})依次发言。" + + to_all_vote = "讨论结束。请大家从存活玩家中投票淘汰一人:{}。" + + to_all_res = "投票结果为 {},{} 被淘汰。" + + to_all_wolf_win = ( + "当前存活玩家共{n_alive}人,其中{n_werewolves}人为狼人。" "游戏结束,狼人获胜🐺🎉!" "本局所有玩家真实身份为:{true_roles}" + ) + + to_all_village_win = "所有狼人已被淘汰。游戏结束,村民获胜🏘️🎉!本局所有玩家真实身份为:{true_roles}" + + to_all_continue = "游戏继续。" + + to_all_reflect = "游戏结束。现在每位玩家可以对自己的表现进行反思。注意每位玩家只有一次发言机会,且反思内容仅自己可见。" diff --git a/tutorial/example_werewolves_swarm/start.py b/tutorial/example_werewolves_swarm/start.py new file mode 100644 index 00000000..879b6101 --- /dev/null +++ b/tutorial/example_werewolves_swarm/start.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +# flake8: noqa: E501 + +"""The main entry point for the werewolf game.""" + +from typing import List +import numpy as np +import dotenv +dotenv.load_dotenv() + +from textwrap import dedent + +from agentscope.agent import ReActAgent +from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter +from agentscope.model import OpenAIChatModel +from loguru import logger +from pydantic import Field + +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask +from tutorial.example_werewolves.game import BadGuyException, werewolves_game + + +def get_official_agent_prompt(name) -> str: + system_prompt = dedent( + f""" + You're a werewolf game player named {name}. + + # YOUR TARGET + Your target is to win the game with your teammates as much as possible. + + # GAME RULES + - In werewolf game, players are divided into three werewolves, three villagers, one seer, one hunter and one witch. 
+ - Werewolves: kill one player each night, and must hide identity during the day. + - Villagers: ordinary players without special abilities, try to identify and eliminate werewolves. + - Seer: A special villager who can check one player's identity each night. + - Witch: A special villager with two one-time-use potions: a healing potion to save a player from being killed at night, and a poison to eliminate one player at night. + - Hunter: A special villager who can take one player down with them when they are eliminated. + - The game alternates between night and day phases until one side wins: + - Night Phase + - Werewolves choose one victim + - Seer checks one player's identity + - Witch decides whether to use potions + - Moderator announces who died during the night + - Day Phase + - All players discuss and vote to eliminate one suspected player + + # GAME GUIDANCE + - Try your best to win the game with your teammates, tricks, lies, and deception are all allowed, e.g. pretending to be a different role. + - During discussion, don't be political, be direct and to the point. + - The day phase voting provides important clues. For example, the werewolves may vote together, attack the seer, etc. + ## GAME GUIDANCE FOR WEREWOLF + - Seer is your greatest threat, who can check one player's identity each night. Analyze players' speeches, find out the seer and eliminate him/her will greatly increase your chances of winning. + - In the first night, making random choices is common for werewolves since no information is available. + - Pretending to be other roles (seer, witch or villager) is a common strategy to hide your identity and mislead other villagers in the day phase. + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. 
+ ## GAME GUIDANCE FOR SEER + - Seer is very important to villagers, exposing yourself too early may lead to being targeted by werewolves. + - Your ability to check one player's identity is crucial. + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. + ## GAME GUIDANCE FOR WITCH + - Witch has two powerful potions, use them wisely to protect key villagers or eliminate suspected werewolves. + - The outcome of the night phase provides important clues. For example, if the dead player is hunter, etc. Use this information to adjust your strategy. + ## GAME GUIDANCE FOR HUNTER + - Using your ability in day phase will expose your role (since only hunter can take one player down) + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, etc. Use this information to adjust your strategy. + ## GAME GUIDANCE FOR VILLAGER + - Protecting special villagers, especially the seer, is crucial for your team's success. + - Werewolves may pretend to be the seer. Be cautious and don't trust anyone easily. + - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. + + # NOTE + - [IMPORTANT] DO NOT make up any information that is not provided by the moderator or other players. + - This is a TEXT-based game, so DO NOT use or make up any non-textual information. + - Always critically reflect on whether your evidence exist, and avoid making assumptions. + - Your response should be specific and concise, provide clear reason and avoid unnecessary elaboration. + - Generate your one-line response by using the `generate_response` function. 
+ - Don't repeat the others' speeches.""" + ) + return system_prompt + + +class ExampleWerewolves(Workflow): + trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.") + + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + + # ensure trainable targets is legal + assert self.trainable_targets is not None, "trainable_targets cannot be None in ExampleWerewolves (because we want to demonstrate a explicit multi-agent case)." + + # bad guys and good guys cannot be trained simultaneously + # (because mix-cooperation-competition MARL needs too many advanced techniques to be displayed here) + if "werewolf" in self.trainable_targets: + assert len(self.trainable_targets) == 1, "Cannot train hostile roles simultaneously." + else: + assert len(self.trainable_targets) != 0, "No trainable targets specified." + + # make and shuffle roles (fix random seed for reproducibility) + roles = ["werewolf"] * 3 + ["villager"] * 3 + ["seer", "witch", "hunter"] + task_id = workflow_task.task.metadata["random_number"] + np.random.seed(int(task_id)) + np.random.shuffle(roles) + + # initialize agents + players = [] + for i, role in enumerate(roles): + default_model = OpenAIChatModel( + model_name="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", + stream=False, + client_args={"base_url": "http://22.17.52.4:2888/v1"}, + api_key="no_api_key", + generate_kwargs={"temperature": 0.01}, + ) + model_for_this_agent = tuner.as_agentscope_model( + agent_name=f"Player{i + 1}", # the name of this agent + target_tag=role, # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent. 
+ debug_model=default_model, # the model used when this agent is not in `self.trainable_targets` + ) + agent = ReActAgent( + name=f"Player{i + 1}", + sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), + model=model_for_this_agent, + formatter=DashScopeMultiAgentFormatter() + if role in self.trainable_targets + else OpenAIMultiAgentFormatter(), + max_iters=3 if role in self.trainable_targets else 5, + ) + # agent.set_console_output_enabled(False) + players += [agent] + + # reward condition + try: + good_guy_win = await werewolves_game(players, roles) + raw_reward = 0 + is_success = False + if (good_guy_win and self.trainable_targets[0] != "werewolf") or ( + not good_guy_win and self.trainable_targets[0] == "werewolf" + ): + raw_reward = 1 + is_success = True + logger.warning(f"Raw reward: {raw_reward}") + logger.warning(f"Is success: {is_success}") + except BadGuyException as e: + logger.bind(exception=True).exception( + f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent)." + ) + raw_reward = -0.1 + is_success = False + except Exception as e: + logger.bind(exception=True).exception( + f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent)." 
+ ) + raw_reward = -0.1 + is_success = False + + return WorkflowOutput(reward=raw_reward, is_success=is_success) diff --git a/tutorial/example_werewolves_swarm/structured_model.py b/tutorial/example_werewolves_swarm/structured_model.py new file mode 100644 index 00000000..46390589 --- /dev/null +++ b/tutorial/example_werewolves_swarm/structured_model.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +"""The structured output models used in the werewolf game.""" +from typing import Literal + +from agentscope.agent import AgentBase +from pydantic import BaseModel, Field + + +class DiscussionModel(BaseModel): + """The output format for discussion.""" + + reach_agreement: bool = Field( + description="Whether you have reached an agreement or not", + ) + + +def get_vote_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the vote model by player names.""" + + class VoteModel(BaseModel): + """The vote output format.""" + + vote: Literal[tuple(_.name for _ in agents)] = Field( # type: ignore + description="The name of the player you want to vote for", + ) + + return VoteModel + + +class WitchResurrectModel(BaseModel): + """The output format for witch resurrect action.""" + + resurrect: bool = Field( + description="Whether you want to resurrect the player", + ) + + +def get_poison_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the poison model by player names.""" + + class WitchPoisonModel(BaseModel): + """The output format for witch poison action.""" + + poison: bool = Field( + description="Do you want to use the poison potion", + ) + name: Literal[tuple(_.name for _ in agents)] | None = ( # type: ignore + Field( + description="The name of the player you want to poison, if you " + "don't want to poison anyone, just leave it empty", + default=None, + ) + ) + + return WitchPoisonModel + + +def get_seer_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the seer model by player names.""" + + class SeerModel(BaseModel): + """The output format for seer 
action.""" + + name: Literal[tuple(_.name for _ in agents)] = Field( # type: ignore + description="The name of the player you want to check", + ) + + return SeerModel + + +def get_hunter_model(agents: list[AgentBase]) -> type[BaseModel]: + """Get the hunter model by player agents.""" + + class HunterModel(BaseModel): + """The output format for hunter action.""" + + shoot: bool = Field( + description="Whether you want to use the shooting ability or not", + ) + name: Literal[tuple(_.name for _ in agents)] | None = ( # type: ignore + Field( + description="The name of the player you want to shoot, if you " + "don't want to use the ability, just leave it empty", + default=None, + ) + ) + + return HunterModel diff --git a/tutorial/example_werewolves_swarm/utils.py b/tutorial/example_werewolves_swarm/utils.py new file mode 100644 index 00000000..c9dd0039 --- /dev/null +++ b/tutorial/example_werewolves_swarm/utils.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +"""Utility functions for the werewolf game.""" +from collections import defaultdict +from typing import Any + +import numpy as np +from agentscope.agent import AgentBase, ReActAgent +from agentscope.message import Msg + +from tutorial.example_werewolves.prompt import EnglishPrompts as Prompts + +# MAX_GAME_ROUND = 30 +# MAX_DISCUSSION_ROUND = 3 +MAX_GAME_ROUND = 7 +MAX_DISCUSSION_ROUND = 2 + + +def majority_vote(votes: list[str]) -> tuple: + """Return the vote with the most counts.""" + result = max(set(votes), key=votes.count) + names, counts = np.unique(votes, return_counts=True) + conditions = ", ".join( + [f"{name}: {count}" for name, count in zip(names, counts)], + ) + return result, conditions + + +def names_to_str(agents: list[str] | list[ReActAgent]) -> str: + """Return a string of agent names.""" + if not agents: + return "" + + if len(agents) == 1: + if isinstance(agents[0], ReActAgent): + return agents[0].name + return agents[0] + + names = [] + for agent in agents: + if isinstance(agent, ReActAgent): + 
names.append(agent.name) + else: + names.append(agent) + return ", ".join([*names[:-1], "and " + names[-1]]) + + +class EchoAgent(AgentBase): + """Echo agent that repeats the input message.""" + + def __init__(self) -> None: + super().__init__() + self.name = "Moderator" + + async def reply(self, content: str) -> Msg: + """Repeat the input content with its name and role.""" + msg = Msg( + self.name, + content, + role="assistant", + ) + await self.print(msg) + return msg + + async def handle_interrupt( + self, + *args: Any, + **kwargs: Any, + ) -> Msg: + """Handle interrupt.""" + + async def observe(self, msg: Msg | list[Msg] | None) -> None: + """Observe the user's message.""" + + +class Players: + """Maintain the players' status.""" + + def __init__(self) -> None: + """Initialize the players.""" + # The mapping from player name to role + self.name_to_role = {} + self.role_to_names = defaultdict(list) + self.name_to_agent = {} + self.werewolves = [] + self.villagers = [] + self.seer = [] + self.hunter = [] + self.witch = [] + self.current_alive = [] + self.all_players = [] + + def add_player(self, player: ReActAgent, role: str) -> None: + """Add a player to the game. + + Args: + player (`ReActAgent`): + The player to be added. + role (`str`): + The role of the player. + """ + self.name_to_role[player.name] = role + self.name_to_agent[player.name] = player + self.role_to_names[role].append(player.name) + self.all_players.append(player) + if role == "werewolf": + self.werewolves.append(player) + elif role == "villager": + self.villagers.append(player) + elif role == "seer": + self.seer.append(player) + elif role == "hunter": + self.hunter.append(player) + elif role == "witch": + self.witch.append(player) + else: + raise ValueError(f"Unknown role: {role}") + self.current_alive.append(player) + + def update_players(self, dead_players: list[ReActAgent]) -> None: + """Update the current alive players. 
+ + Args: + dead_players (`list[ReActAgent]`): + A list of dead players to be removed. + """ + self.werewolves = [_ for _ in self.werewolves if _.name not in dead_players] + self.villagers = [_ for _ in self.villagers if _.name not in dead_players] + self.seer = [_ for _ in self.seer if _.name not in dead_players] + self.hunter = [_ for _ in self.hunter if _.name not in dead_players] + self.witch = [_ for _ in self.witch if _.name not in dead_players] + self.current_alive = [_ for _ in self.current_alive if _.name not in dead_players] + + def print_roles(self) -> None: + """Print the roles of all players.""" + print("Roles:") + for name, role in self.name_to_role.items(): + print(f" - {name}: {role}") + + def check_winning(self) -> str | None: + """Check if the game is over and return the winning message.""" + + # Prepare true roles string + true_roles = ( + f'{names_to_str(self.role_to_names["werewolf"])} are werewolves, ' + f'{names_to_str(self.role_to_names["villager"])} are villagers, ' + f'{names_to_str(self.role_to_names["seer"])} is the seer, ' + f'{names_to_str(self.role_to_names["hunter"])} is the hunter, ' + f'and {names_to_str(self.role_to_names["witch"])} is the witch.' 
+ ) + + if len(self.werewolves) * 2 >= len(self.current_alive): + return Prompts.to_all_wolf_win.format( + n_alive=len(self.current_alive), + n_werewolves=len(self.werewolves), + true_roles=true_roles, + ) + if self.current_alive and not self.werewolves: + return Prompts.to_all_village_win.format( + true_roles=true_roles, + ) + return None diff --git a/tutorial/example_werewolves_swarm/werewolves.md b/tutorial/example_werewolves_swarm/werewolves.md new file mode 100644 index 00000000..737a11f4 --- /dev/null +++ b/tutorial/example_werewolves_swarm/werewolves.md @@ -0,0 +1,9 @@ +# Training werewolf game agents + + +Please refer to the document at [`docs/en/example_werewolves.md`](docs/en/example_werewolves.md) + + +# Translate to yaml + +tutorial/example_werewolves_swarm/werewolves.yaml \ No newline at end of file diff --git a/tutorial/example_werewolves_swarm/werewolves.yaml b/tutorial/example_werewolves_swarm/werewolves.yaml new file mode 100644 index 00000000..7f786985 --- /dev/null +++ b/tutorial/example_werewolves_swarm/werewolves.yaml @@ -0,0 +1,76 @@ +# ------------------ main config ------------------ +ajet: + project_name: example_werewolves_swarm + experiment_dir: "auto" # {exp-dir}/{experiment_name} + task_reader: + type: random_dummy # ✨ + + model: + # ✨ select model to be trained + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-7B-Instruct + + + rollout: + user_workflow: null + temperature: 0.7 + max_env_worker: 64 + num_repeat: 6 + agent_madness_reward: 0.0 + tensor_model_parallel_size: 1 + max_num_seqs: 40 + # monitor LLM's abnormal behaviors during rollout + compute_madness_checklist: + - "nonsense" + max_response_length_in_one_turn: 1024 + max_model_len: 22000 + + task_reader: + type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` + + task_judge: + # ✨ select evaluation function + judge_protocol: null + + # the experimental ZeroMQ interchange server feature that 
allows `tuner.as_oai_baseurl_apikey` feature + enable_interchange_server: True + # train in cloud, run episode locally + enable_swarm_mode: True + # both swarm / oai share the same interchange server + interchange_server: + interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node) + interchange_server_port: 10086 + num_fastapi_process: 2 # 1, 2 or 4 is fine + max_fastapi_threads: 512 # 64 or 128 is fine + max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker` + already_started: False # do not edit, used by `swarm` + + swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks" + + debug: + debug_max_parallel: 1 + debug_first_n_tasks: 1 + + data: + train_batch_size: 32 + max_prompt_length: 4000 + max_response_length: 18000 + + trainer_common: + save_freq: 5 + test_freq: 99999 + total_epochs: 99999 + total_training_steps: 25 + nnodes: 2 + n_gpus_per_node: 8 + +# ------------------ do not edit ------------------ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl + +# ------------------ do not edit ------------------ +defaults: + - verl_default + - ajet_default + - _self_ From 544bff1483f3bc0b3beecedf47d50f21d0ed09a4 Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Mon, 2 Mar 2026 15:42:34 +0800 Subject: [PATCH 02/11] stage werewolve example --- ajet/copilot/job.py | 4 - .../tuner_lib/experimental/as_swarm_client.py | 15 +- .../tuner_lib/experimental/as_swarm_server.py | 1 + .../experimental/interchange_utils.py | 8 +- ajet/utils/config_utils.py | 13 +- ajet/utils/launch_utils.py | 7 + docs/en/ajet-swarm-docker.md | 2 + math_gsm8k_grpo/yaml_backup.yaml | 559 ++++++++++++++++++ tutorial/example_math_swarm/math.py | 20 +- tutorial/example_werewolves/start.py | 6 +- .../example_werewolves_swarm/agent_roll.py | 46 +- .../example_werewolves_swarm/convert_skill.md | 77 +++ tutorial/example_werewolves_swarm/game.py | 351 ----------- tutorial/example_werewolves_swarm/prompt.py 
| 176 ------ tutorial/example_werewolves_swarm/start.py | 155 ----- .../structured_model.py | 88 --- tutorial/example_werewolves_swarm/utils.py | 161 ----- .../example_werewolves_swarm/werewolves.md | 9 - .../example_werewolves_swarm/werewolves.yaml | 6 +- 19 files changed, 698 insertions(+), 1006 deletions(-) create mode 100644 math_gsm8k_grpo/yaml_backup.yaml create mode 100644 tutorial/example_werewolves_swarm/convert_skill.md delete mode 100644 tutorial/example_werewolves_swarm/game.py delete mode 100644 tutorial/example_werewolves_swarm/prompt.py delete mode 100644 tutorial/example_werewolves_swarm/start.py delete mode 100644 tutorial/example_werewolves_swarm/structured_model.py delete mode 100644 tutorial/example_werewolves_swarm/utils.py delete mode 100644 tutorial/example_werewolves_swarm/werewolves.md diff --git a/ajet/copilot/job.py b/ajet/copilot/job.py index f74b44ac..96ae2031 100644 --- a/ajet/copilot/job.py +++ b/ajet/copilot/job.py @@ -132,13 +132,9 @@ def __init__( def build_job_from_yaml(self, yaml_path: str | None) -> dict: - assert self.experiment_dir is not None, "experiment_dir must be provided either in constructor or in yaml config." self.config_as_dict = read_ajet_hierarchical_config( yaml_path, - exp_name=self.experiment_name, - backbone=self.backbone, write_to=None, - exp_dir=self.experiment_dir, ) self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None) logger.info(f"Built AgentJet job config: {yaml_path}") diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/as_swarm_client.py index 58b1e481..f4dfcee5 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/as_swarm_client.py @@ -52,6 +52,8 @@ def raise_for_status_with_detail(resp): raise RuntimeError(f"SwarmClient error {resp.status_code} with non-JSON response: {response_text}") from e +class SwarmServerOfflineError(Exception): ... 
+ class SwarmClient(object): @@ -437,7 +439,7 @@ def start_engine(self): self._wait_until_status_change_to(desired_status="ENGINE.ROLLING") logger.success("Training engine is now ROLLING and ready.") - def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True): + def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose=True, timeout=1800): """ Poll engine status until it reaches desired_status. Reports status every 5 seconds while waiting. @@ -446,12 +448,20 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= self.logger_info(f"Polling engine status until {desired_status}...") last_report_time = time.time() init_poll_time = last_report_time + initial_status, _ = self.get_engine_status() while True: try: current_status, _ = self.get_engine_status() current_time = time.time() + # Check if timeout has been reached + if current_time - init_poll_time >= timeout: + raise TimeoutError(f"Timeout reached while waiting for engine status to change to {desired_status}") + + if (initial_status == "ENGINE.OFFLINE") and (current_status == "ENGINE.OFFLINE"): + raise SwarmServerOfflineError(f"Engine status changed from {initial_status} to OFFLINE while waiting for {desired_status}. This may indicate an error in the engine. 
Please check the swarm server logs for details.") + # Report status every 5 seconds if current_time - last_report_time >= 30: if verbose: @@ -467,6 +477,9 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= # Wait a bit before next poll time.sleep(5) + except SwarmServerOfflineError as e: + raise e + except Exception as e: logger.error(f"Error polling engine status: {e}") time.sleep(5) diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/as_swarm_server.py index a5d156d9..81a5b5ff 100644 --- a/ajet/tuner_lib/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/experimental/as_swarm_server.py @@ -393,6 +393,7 @@ def override_param_callback(config): main_yaml_fp, env, exp_config, + True, # is_swarm_server ), ) p.daemon = True diff --git a/ajet/tuner_lib/experimental/interchange_utils.py b/ajet/tuner_lib/experimental/interchange_utils.py index c50dd93d..b5a52e43 100644 --- a/ajet/tuner_lib/experimental/interchange_utils.py +++ b/ajet/tuner_lib/experimental/interchange_utils.py @@ -111,8 +111,12 @@ class UpdateEngineStatusRequest(BaseModel): def get_interchange_server_url(config): port = os.getenv("AJET_DAT_INTERCHANGE_PORT") - if config.ajet.interchange_server.interchange_server_port != 'auto': - port = str(int(config.ajet.interchange_server.interchange_server_port)) + if isinstance(config, dict): + interchange_server_port = config.get("ajet", {}).get("interchange_server", {}).get("interchange_server_port", "auto") + else: + interchange_server_port = config.ajet.interchange_server.interchange_server_port + if interchange_server_port != 'auto': + port = str(int(interchange_server_port)) assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set" master_node_ip = os.getenv("MASTER_NODE_IP", "localhost") base_url = f"http://{master_node_ip}:{port}" diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py index b02f1205..8a6a301e 100644 --- a/ajet/utils/config_utils.py +++ 
b/ajet/utils/config_utils.py @@ -171,7 +171,7 @@ def config_safe_guard(config: dict, backbone: str) -> dict: def read_ajet_hierarchical_config( - yaml_fp, exp_name, backbone, write_to=None, exp_dir=DEFAULT_DIR, override_param_callback=None + yaml_fp, exp_name=None, backbone=None, write_to=None, exp_dir=None, override_param_callback=None ): if yaml_fp is None: config = { @@ -193,9 +193,12 @@ def read_ajet_hierarchical_config( else: with open(yaml_fp, "r", encoding="utf-8") as file: config = yaml.safe_load(file) - config["ajet"]["experiment_name"] = exp_name - config["ajet"]["experiment_dir"] = os.path.join(exp_dir, exp_name) - config["ajet"]["backbone"] = backbone + if exp_name is not None: + config["ajet"]["experiment_name"] = exp_name + if (exp_dir is not None) and (exp_name is not None): + config["ajet"]["experiment_dir"] = os.path.join(exp_dir, exp_name) + if backbone is not None: + config["ajet"]["backbone"] = backbone # remove extra config of verl for trinity if backbone == "debug": @@ -324,7 +327,7 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callb ## 4. edit new yaml config = read_ajet_hierarchical_config( - yaml_backup_dst, exp_name, backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback + yaml_backup_dst, exp_name=exp_name, backbone=backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback ) config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst) diff --git a/ajet/utils/launch_utils.py b/ajet/utils/launch_utils.py index 441fbc93..46200ad9 100644 --- a/ajet/utils/launch_utils.py +++ b/ajet/utils/launch_utils.py @@ -319,6 +319,7 @@ def execute_training_process( exe_yaml_path, env, exp_config, + is_swarm_server=False, ): """ Execute the training process based on the specified backbone and configuration. 
@@ -403,7 +404,13 @@ def execute_training_process( subprocess.run(cmd, check=True, cwd=os.path.abspath("./"), env=env) except subprocess.CalledProcessError as e: logger.error(f"Error running subprocess: {e}") + if is_swarm_server: + from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(exp_config, "ENGINE.OFFLINE", global_step=0) sys.exit(1) except Exception as e: logger.error(f"Unexpected error: {e}") + if is_swarm_server: + from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status + http_change_engine_status(exp_config, "ENGINE.OFFLINE", global_step=0) sys.exit(1) diff --git a/docs/en/ajet-swarm-docker.md b/docs/en/ajet-swarm-docker.md index 38a3c653..15c054a7 100644 --- a/docs/en/ajet-swarm-docker.md +++ b/docs/en/ajet-swarm-docker.md @@ -24,6 +24,7 @@ docker run --rm -it \ -v ./swarmlog:/workspace/log \ -v ./swarmexp:/workspace/saved_experiments \ -p 10086:10086 \ + -e SWANLAB_API_KEY=$SWANLAB_API_KEY \ --gpus=all \ --shm-size=32GB \ ghcr.io/modelscope/agentjet:main \ @@ -89,6 +90,7 @@ docker run --rm -it \ -v ./swarmlog:/workspace/log \ -v ./swarmexp:/workspace/saved_experiments \ -p 10086:10086 \ + -e SWANLAB_API_KEY=$SWANLAB_API_KEY \ --gpus=all \ --shm-size=32GB \ ghcr.io/modelscope/agentjet:main \ diff --git a/math_gsm8k_grpo/yaml_backup.yaml b/math_gsm8k_grpo/yaml_backup.yaml new file mode 100644 index 00000000..69a60a3c --- /dev/null +++ b/math_gsm8k_grpo/yaml_backup.yaml @@ -0,0 +1,559 @@ +actor_rollout_ref: + actor: + _target_: verl.workers.config.FSDPActorConfig + checkpoint: + _target_: verl.trainer.config.CheckpointConfig + async_save: false + load_contents: + - model + - optimizer + - extra + save_contents: + - model + - optimizer + - extra + clip_ratio: 0.2 + clip_ratio_c: 3.0 + clip_ratio_high: 0.2 + clip_ratio_low: 0.2 + entropy_checkpointing: false + entropy_coeff: 0 + entropy_from_logits_with_chunking: false + fsdp_config: + optimizer_offload: true + 
param_offload: true + grad_clip: 1.0 + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl + loss_agg_mode: seq-mean-token-mean + optim: + lr: 1.0e-06 + override_ppo_mini_batch_num: 1 + policy_loss: + _target_: verl.workers.config.PolicyLossConfig + clip_cov_lb: 1.0 + clip_cov_ratio: 0.0002 + clip_cov_ub: 5.0 + kl_cov_ratio: 0.0002 + loss_mode: vanilla + ppo_kl_coef: 0.1 + ppo_epochs: 1 + ppo_max_token_len_per_gpu: 18000 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: 1 + ppo_mini_batch_size: 16 + shuffle: false + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true + use_fused_kernels: false + use_kl_loss: true + use_remove_padding: true + use_torch_compile: true + hybrid_engine: true + model: + custom_chat_template: null + enable_activation_offload: false + enable_gradient_checkpointing: true + exclude_modules: null + external_lib: null + fused_kernel_options: + impl_backend: torch + lora_alpha: 16 + lora_rank: 0 + override_config: {} + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct + target_modules: all-linear + trust_remote_code: false + use_fused_kernels: false + use_liger: false + use_remove_padding: true + use_shm: false + nccl_timeout: 600 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + all_ranks: false + discrete: false + ranks: [] + ref: + entropy_checkpointing: false + entropy_from_logits_with_chunking: false + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + forward_prefetch: false + param_offload: true + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + log_prob_max_token_len_per_gpu: 18000 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: true + model: null + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true + use_torch_compile: true + rollout: + agent: + agent_loop_config_path: null + custom_async_server: + name: null + path: null + num_workers: 8 + calculate_log_probs: 
false + cudagraph_capture_sizes: null + custom_dataflow_cls: + name: '' + path: '' + disable_log_stats: true + do_sample: true + dtype: bfloat16 + enable_chunked_prefill: true + enforce_eager: true + engine_kwargs: + sglang: + attention_backend: null + vllm: + disable_mm_preprocessor_cache: false + swap_space: null + free_cache_engine: true + gamma: 1.0 + gpu_memory_utilization: 0.9 + ignore_eos: false + layered_summon: false + load_format: dummy_dtensor + log_prob_max_token_len_per_gpu: 18000 + log_prob_micro_batch_size: null + log_prob_micro_batch_size_per_gpu: 4 + log_prob_use_dynamic_bsz: true + max_env_worker: 64 + max_model_len: 18000 + max_num_batched_tokens: 8192 + max_num_seqs: 10 + mode: async + multi_stage_wake_up: false + multi_turn: + expected_steps: 1 + max_sample_per_task: 30 + max_steps: 30 + n: 1 + name: vllm + ppo_micro_batch_size_per_gpu: 1 + prompt_length: 3000 + response_length: 10000 + skip_dump_dir: /tmp/rollout_dump + skip_rollout: false + temperature: 0.9 + tensor_model_parallel_size: 1 + top_k: -1 + top_p: 1.0 + trace: + backend: null + token2text: false + update_weights_bucket_megabytes: 512 + val_kwargs: + do_sample: false + num_repeat: 1 + temperature: 0.0 + top_k: -1 + top_p: 1.0 +ajet: + backbone: verl + context_tracker: + alien_llm_model: qwen3-235b-a22b-instruct-2507 + alien_llm_response_length: 512 + detect_timeline_snap: false + fix_retokenization_drift: true + log_tool_format_check: false + log_tool_format_error_detail: false + timeline_merging_policy: + ignore_tools: true + timeline_compare_level: text + data: + max_prompt_length: 3000 + max_response_length: 15000 + train_batch_size: 32 + debug: + debug_first_n_tasks: 2 + debug_max_parallel: 4 + debug_tensor_parallel_size: 4 + debug_vllm_port: 18000 + debug_vllm_seed: 12345 + enable_interchange_server: true + enable_swarm_mode: true + execute_test: false + execute_testing_lambda: '' + experiment_dir: math_gsm8k_grpo + experiment_name: math_gsm8k_grpo + interchange_server: + 
already_started: true + interchange_method: ipc + interchange_server_port: 10086 + max_fastapi_threads: 512 + max_inference_tracker_threads: 64 + num_fastapi_process: 2 + model: + path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct + project_name: ajet_default_project + rollout: + agent_madness_reward: -1.0 + agent_madness_termination: true + compute_madness_checklist: + - nonsense + force_disable_toolcalls: false + gamma: 1.0 + max_env_worker: 128 + max_model_len: 18000 + max_num_seqs: 10 + max_response_length_in_one_turn: 4096 + multi_turn: + expected_steps: 1 + max_sample_per_task: 30 + max_steps: 30 + n_vllm_engine: 1 + name: vllm + num_repeat: 4 + step_skip_action: 0 + submit_oversample_multiplier: 1.5 + temperature: 0.9 + tensor_model_parallel_size: 1 + top_p: 1.0 + user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow + val_kwargs: + do_sample: false + num_repeat: 1 + temperature: 0.0 + top_k: -1 + top_p: 1.0 + swarm_mode_sample_collection_max_cached_episodes: 9999 + swarm_mode_sample_collection_method: rollout_until_finish_enough_tasks + task_judge: + alien_llm_model: qwen3-235b-a22b-instruct-2507 + alien_llm_response_length: 512 + judge_protocol: null + judge_type: customized_protocol + rubrics_auto_grader: + answer_field: final_answer + categories_number: 5 + custom_evaluation_prompt: null + enable_categorization: false + grader_mode: pointwise + grader_name: auto_grader + input_data_type: jsonl_dataset_file + jsonl_dataset_file: + training: + file_path: tutorial/example_rm_auto_grader/rubrics_train.jsonl + language: en + max_score: 1 + min_score: 0 + model_name: qwen-max + query_field: main_query + query_specific_generate_number: 1 + reference_field: answer + task_reader: + data_generation: + deduplication_filter: + enabled: true + params: + api_key: null + base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + db_path: ./.similarity_db + model: text-embedding-v4 + similarity_threshold: 0.8 + 
document_reader: + cache_enabled: true + chunk_size: 5120 + document_path: + - dataset/document/your-document1.pdf + - dataset/document/your-document2.pdf + languages: + - eng + split_by: sentence + llm_model: qwen-long + llm_response_length: 8192 + num_workers: 32 + query_reader: + jsonl_dataset_file: + training: + file_path: dataset/jsonl/your-queries.jsonl + type: jsonl_dataset_file + sampling_params: + temperature: 0 + task_num: 10 + env_service: + env_action_preference: code + env_type: appworld + env_url: http://127.0.0.1:8080 + training_split: train + validation_split: dev + huggingface_dat_repo: + dataset_name: null + dataset_path: gsm8k + http_proxy_address: '' + training_split: train + validation_split: validation + jsonl_dataset_file: + training: + file_path: /path/to/training/data.jsonl + validation: + file_path: /path/to/validation/data.jsonl + type: random_dummy + task_runner: + llm_infer_submit_method: async + wrapper_multiprocessing_timeout: 3600 + wrapper_type: asyncio-with-gc + trainer_common: + algorithm: + adv_estimator: grpo + use_kl_in_reward: false + checkpoint_base_dir: ./saved_checkpoints + fsdp_config: + optimizer_offload: true + param_offload: true + kl_loss_coef: 0.002 + kl_loss_type: low_var_kl + logger: tensorboard + mini_batch_num: 1 + n_gpus_per_node: 8 + nnodes: 1 + optim: + lr: 1.0e-06 + save_freq: 20 + save_trajectory_as_json_file: false + test_freq: 20 + total_epochs: 50 + ulysses_sequence_parallel_size: 1 + use_kl_loss: true + val_before_train: false + val_pass_n: 4 +algorithm: + _target_: verl.trainer.config.AlgoConfig + adv_estimator: grpo + gamma: 1.0 + kl_ctrl: + _target_: verl.trainer.config.KLControlConfig + horizon: 10000 + kl_coef: 0.001 + target_kl: 0.1 + type: fixed + kl_penalty: kl + lam: 1.0 + norm_adv_by_std_in_grpo: true + pf_ppo: + reweight_method: pow + weight_pow: 2.0 + use_kl_in_reward: false + use_pf_ppo: false +critic: + _target_: verl.workers.config.FSDPCriticConfig + checkpoint: + _target_: 
verl.trainer.config.CheckpointConfig + async_save: false + load_contents: + - model + - optimizer + - extra + save_contents: + - model + - optimizer + - extra + cliprange_value: 0.5 + enable: false + forward_max_token_len_per_gpu: 32768 + forward_micro_batch_size: null + forward_micro_batch_size_per_gpu: null + grad_clip: 1.0 + loss_agg_mode: seq-mean-token-mean + model: + _target_: verl.workers.config.FSDPCriticModelCfg + enable_activation_offload: false + enable_gradient_checkpointing: true + external_lib: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + forward_prefetch: false + fsdp_size: -1 + offload_policy: false + optimizer_offload: false + param_offload: false + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + lora_alpha: 16 + lora_rank: 0 + override_config: {} + path: ~/models/deepseek-llm-7b-chat + target_modules: all-linear + tokenizer_path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + trust_remote_code: false + use_remove_padding: false + use_shm: false + optim: + _target_: verl.workers.config.FSDPOptimizerConfig + lr: 1.0e-05 + lr_warmup_steps: -1 + lr_warmup_steps_ratio: 0.0 + min_lr_ratio: null + total_training_steps: -1 + warmup_style: constant + weight_decay: 0.01 + ppo_epochs: 1 + ppo_max_token_len_per_gpu: 32768 + ppo_micro_batch_size: null + ppo_micro_batch_size_per_gpu: null + ppo_mini_batch_size: 16 + profiler: + _target_: verl.utils.profiler.ProfilerConfig + all_ranks: false + discrete: false + ranks: [] + rollout_n: 1 + shuffle: false + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true +custom_reward_function: + name: compute_score + path: null +data: + custom_cls: + name: null + path: null + datagen: + name: null + path: null + dataloader_num_workers: 8 + fast_eval: true + filter_overlong_prompts: true + filter_overlong_prompts_workers: 1 + image_key: images + max_prompt_length: 3000 + max_response_length: 15000 + prompt_key: prompt + 
return_full_prompt: false + return_multi_modal_inputs: true + return_raw_chat: true + return_raw_input_ids: false + reward_fn_key: data_source + sampler: + class_name: null + class_path: null + seed: 42 + shuffle: true + tokenizer: null + train_batch_size: 32 + train_files: ~/data/rlhf/gsm8k/train.parquet + truncation: error + trust_remote_code: false + use_shm: false + val_batch_size: 100000000000 + val_files: ~/data/rlhf/gsm8k/test.parquet + validation_shuffle: false + video_key: videos +defaults: +- verl_default +- ajet_default +- _self_ +hydra: + searchpath: + - file://ajet/default_config + - file://ajet/default_config/verl +ray_init: + num_cpus: null + timeline_json_file: null +reward_model: + enable: false + forward_max_token_len_per_gpu: 32768 + launch_reward_fn_async: false + max_length: null + micro_batch_size: null + micro_batch_size_per_gpu: null + model: + external_lib: null + fsdp_config: + _target_: verl.workers.config.FSDPEngineConfig + forward_prefetch: false + fsdp_size: -1 + param_offload: false + reshard_after_forward: true + wrap_policy: + min_num_params: 0 + input_tokenizer: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct + path: ~/models/FsfairX-LLaMA3-RM-v0.1 + trust_remote_code: false + use_fused_kernels: false + use_remove_padding: false + use_shm: false + profiler: + _target_: verl.utils.profiler.ProfilerConfig + all_ranks: false + discrete: false + ranks: [] + reward_manager: naive + sandbox_fusion: + max_concurrent: 64 + memory_limit_mb: 1024 + url: null + strategy: fsdp + ulysses_sequence_parallel_size: 1 + use_dynamic_bsz: true +trainer: + balance_batch: true + checkpoint_base_dir: ./saved_checkpoints + controller_nsight_options: + cuda-graph-trace: graph + cuda-memory-usage: 'true' + trace: cuda,nvtx,cublas,ucx + critic_warmup: 0 + default_hdfs_dir: null + default_local_dir: ${trainer.checkpoint_base_dir}/${trainer.project_name}/${trainer.experiment_name} + del_local_ckpt_after_load: false + device: cuda + 
esi_redundant_time: 0 + experiment_name: math_gsm8k_grpo + hfmodelpath: '' + log_val_generations: 0 + logger: + - console + - tensorboard + max_actor_ckpt_to_keep: null + max_critic_ckpt_to_keep: null + n_gpus_per_node: 8 + nnodes: 1 + npu_profile: + options: + analysis: true + level: level1 + record_shapes: false + roles: + - all + save_path: ./profiler_data + with_cpu: true + with_memory: false + with_module: false + with_npu: true + with_stack: false + profile_continuous_steps: false + profile_steps: null + project_name: ajet_default_project + ray_wait_register_center_timeout: 300 + resume_from_path: null + resume_mode: auto + rollout_data_dir: null + save_freq: 20 + test_freq: 20 + total_epochs: 50 + total_training_steps: null + use_legacy_worker_impl: auto + val_before_train: false + val_only: false + val_pass_n: 4 + validation_data_dir: null + worker_nsight_options: + capture-range: cudaProfilerApi + capture-range-end: null + cuda-graph-trace: graph + cuda-memory-usage: 'true' + kill: none + trace: cuda,nvtx,cublas,ucx diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py index 041e544f..d0bf7096 100644 --- a/tutorial/example_math_swarm/math.py +++ b/tutorial/example_math_swarm/math.py @@ -35,17 +35,19 @@ def main(): ) ) - # # Hand shake with remote swarm server + # Hand shake with remote swarm server swarm_worker = SwarmClient(AJET_SWARM_URL) + ajet_job = AgentJetJob( + experiment_name="math_gsm8k_grpo", + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_MODEL_PATH, + batch_size=REMOTE_BATCH_SIZE, + num_repeat=GRPO_N, + ) + print(ajet_job.config.to_dict()) swarm_worker.auto_sync_train_config_and_start_engine( - AgentJetJob( - experiment_name="math_gsm8k_grpo", - algorithm="grpo", - n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, - model=REMOTE_MODEL_PATH, - batch_size=REMOTE_BATCH_SIZE, - num_repeat=GRPO_N, - ), + ajet_job, force_restart=True, ) diff --git a/tutorial/example_werewolves/start.py 
b/tutorial/example_werewolves/start.py index 879b6101..fb3d6df7 100644 --- a/tutorial/example_werewolves/start.py +++ b/tutorial/example_werewolves/start.py @@ -81,6 +81,8 @@ def get_official_agent_prompt(name) -> str: class ExampleWerewolves(Workflow): trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.") + big_external_opponent_llm_url = "http://22.17.52.4:2888/v1" + big_external_opponent_llm_name = "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/" async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: @@ -104,11 +106,11 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl players = [] for i, role in enumerate(roles): default_model = OpenAIChatModel( - model_name="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", stream=False, - client_args={"base_url": "http://22.17.52.4:2888/v1"}, api_key="no_api_key", generate_kwargs={"temperature": 0.01}, + model_name=self.big_external_opponent_llm_name, + client_args={"base_url": self.big_external_opponent_llm_url}, ) model_for_this_agent = tuner.as_agentscope_model( agent_name=f"Player{i + 1}", # the name of this agent diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py index f81346bd..6b31c6bd 100644 --- a/tutorial/example_werewolves_swarm/agent_roll.py +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -12,11 +12,8 @@ from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient -GRPO_N = 6 # grpo group size NUM_EPOCH = 10000 AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") -REMOTE_BATCH_SIZE = 32 -REMOTE_ALLOCATE_GPU_PER_NODE = 8 def main(): @@ -42,10 +39,6 @@ def main(): GRPO_N = ajet_job.num_repeat REMOTE_BATCH_SIZE = ajet_job.batch_size - # temperature: 
0.7 - # max_env_worker: 64 - # num_repeat: 6 - def rollout(task): try: # begin episode @@ -67,41 +60,14 @@ def rollout(task): return None - - def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): - # Prepare base_url, api_key - base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key) - # Read dataset item - query, reference_answer = (task.main_query, task.metadata["answer"]) - # Prepare messages - messages = [ - { "role": "system", "content": dedent("""You are an agent specialized in solving math problems. Please solve the math problem given to you. - You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""") }, - { "role": "user", "content": query } - ] - # Use raw http requests (non-streaming) to get response - # "Connection: close" prevents keep-alive pool reuse, which can cause BadStatusLine - # errors under high concurrency when stale pooled connections return residual bytes. 
- response = requests.post( - f"{base_url}/chat/completions", - json = { "model": "fill_whatever_model", "messages": messages, "stream": False }, - headers = { "Authorization": f"Bearer {api_key}", "Connection": "close" }, - timeout = 300, + import asyncio + from tutorial.example_werewolves.start import ExampleWerewolves + game = ExampleWerewolves( + trainable_targets=["werewolf"], ) - response.raise_for_status() - final_answer = response.json()['choices'][0]['message']['content'] - - reference_answer = reference_answer.split("####")[-1].strip() - pattern = r"\\boxed\{([^}]*)\}" - match = re.search(pattern, final_answer) - if match: is_success = match.group(1) == reference_answer - else: is_success = False - raw_reward = 1.0 if is_success else 0.0 - # Return - return WorkflowOutput(reward=raw_reward, metadata={"final_answer": final_answer}) - - + res = asyncio.run(game.execute(task, api_baseurl_key)) + return res if __name__ == "__main__": diff --git a/tutorial/example_werewolves_swarm/convert_skill.md b/tutorial/example_werewolves_swarm/convert_skill.md new file mode 100644 index 00000000..9ad3ff16 --- /dev/null +++ b/tutorial/example_werewolves_swarm/convert_skill.md @@ -0,0 +1,77 @@ +训练复杂智能体的时候,推荐先从yaml配置出发 + +首先,复制一份基础配置 ajet/default_config/ajet_ts_default.yaml + +cp ajet/default_config/ajet_ts_default.yaml tutorial/example_werewolves_swarm/werewolves.yaml + +然后对配置中的参数进行修改: + +---- opencode命令:这里补充一个参数配置说明表格,参考tutorial/example_werewolves_swarm/werewolves.yaml ---- + + +# 编写训练循环 (Swarm Client) + +Swarm Client 流程如下: + +- 连接蜂群 +- 蜂群server初始化 +- 开始EPOCH循环 + + +```python +# -*- coding: utf-8 -*- + +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") + +def main(): + + ajet_job = AgentJetJob( + base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml", + algorithm="grpo", + experiment_name="werewolves_swarm", + ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + 
swarm_worker.auto_sync_train_config_and_start_engine( ajet_job, force_restart=True ) + + GRPO_N = ajet_job.num_repeat + REMOTE_BATCH_SIZE = ajet_job.batch_size + + def rollout(task): + try: + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return + except: + pass + + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "random_dummy", + reader_config = AjetTaskReader() + ) + executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) + for _ in range(NUM_EPOCH): + for _, task in enumerate(dataset.generate_training_tasks()): + for _ in range(GRPO_N): + executor.submit_with_periodic_drain(fn=rollout, task=task) + + return None + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + raise NotImplementedError("see below.") + + +if __name__ == "__main__": + main() + +``` + +# 编写Agent (Swarm Client) diff --git a/tutorial/example_werewolves_swarm/game.py b/tutorial/example_werewolves_swarm/game.py deleted file mode 100644 index 10246c32..00000000 --- a/tutorial/example_werewolves_swarm/game.py +++ /dev/null @@ -1,351 +0,0 @@ -# -*- coding: utf-8 -*- -# pylint: disable=too-many-branches, too-many-statements, no-name-in-module -"""A werewolf game implemented by agentscope.""" -from agentscope.agent import ReActAgent -from agentscope.pipeline import MsgHub, fanout_pipeline, sequential_pipeline - -# Uncomment the following line to use Chinese prompts -# from tutorial.example_werewolves.prompt import ChinesePrompts as Prompts -from loguru import logger - -from 
tutorial.example_werewolves.prompt import EnglishPrompts as Prompts -from tutorial.example_werewolves.structured_model import ( - DiscussionModel, - WitchResurrectModel, - get_hunter_model, - get_poison_model, - get_seer_model, - get_vote_model, -) -from tutorial.example_werewolves.utils import ( - MAX_DISCUSSION_ROUND, - MAX_GAME_ROUND, - EchoAgent, - Players, - majority_vote, - names_to_str, -) - - -class BadGuyException(Exception): - ... - - -moderator = EchoAgent() -# moderator.set_console_output_enabled(False) - - -async def hunter_stage( - hunter_agent: ReActAgent, - players: Players, -) -> str | None: - """Because the hunter's stage may happen in two places: killed at night - or voted during the day, we define a function here to avoid duplication.""" - global moderator - msg_hunter = await hunter_agent( - await moderator(Prompts.to_hunter.format(name=hunter_agent.name)), - structured_model=get_hunter_model(players.current_alive), - ) - if msg_hunter.metadata.get("shoot"): - return msg_hunter.metadata.get("name", None) - return None - - -async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C901 - """The main entry of the werewolf game - - Args: - agents (`list[ReActAgent]`): - A list of 9 agents. - """ - assert len(agents) == 9, "The werewolf game needs exactly 9 players." 
- - # Init the players' status - players = Players() - - # If the witch has healing and poison potion - healing, poison = True, True - - # If it's the first day, the dead can leave a message - first_day = True - - # Broadcast the game begin message - async with MsgHub(participants=agents) as greeting_hub: - await greeting_hub.broadcast( - await moderator( - Prompts.to_all_new_game.format(names_to_str(agents)), - ), - ) - - # Assign roles to the agents - for agent, role in zip(agents, roles): - # Tell the agent its role - await agent.observe( - await moderator( - f"[{agent.name} ONLY] {agent.name}, your role is {role}.", - ), - ) - players.add_player(agent, role) - - # Printing the roles - players.print_roles() - - # GAME BEGIN! - for _ in range(MAX_GAME_ROUND): - # Create a MsgHub for all players to broadcast messages - async with MsgHub( - participants=players.current_alive, - enable_auto_broadcast=False, # manual broadcast only - name="alive_players", - ) as alive_players_hub: - # Night phase - await alive_players_hub.broadcast( - await moderator(Prompts.to_all_night), - ) - killed_player, poisoned_player, shot_player = None, None, None - - try: - # Werewolves discuss - async with MsgHub( - players.werewolves, - enable_auto_broadcast=True, - announcement=await moderator( - Prompts.to_wolves_discussion.format( - names_to_str(players.werewolves), - names_to_str(players.current_alive), - ), - ), - name="werewolves", - ) as werewolves_hub: - # Discussion - n_werewolves = len(players.werewolves) - for _ in range(1, MAX_DISCUSSION_ROUND * n_werewolves + 1): - res = await players.werewolves[_ % n_werewolves]( - structured_model=DiscussionModel, - ) - if _ % n_werewolves == 0 and res.metadata.get( - "reach_agreement", - ): - break - - # Werewolves vote - # Disable auto broadcast to avoid following other's votes - werewolves_hub.set_auto_broadcast(False) - msgs_vote = await fanout_pipeline( - players.werewolves, - msg=await moderator(content=Prompts.to_wolves_vote), - 
structured_model=get_vote_model(players.current_alive), - enable_gather=False, - ) - killed_player, votes = majority_vote( - [_.metadata.get("vote") for _ in msgs_vote], - ) - # Postpone the broadcast of voting - await werewolves_hub.broadcast( - [ - *msgs_vote, - await moderator( - Prompts.to_wolves_res.format(votes, killed_player), - ), - ], - ) - except Exception as e: - raise BadGuyException( - f"Werewolves failed to make a decision: {e}", - ) - - # Witch's turn - await alive_players_hub.broadcast( - await moderator(Prompts.to_all_witch_turn), - ) - msg_witch_poison = None - for agent in players.witch: - # Cannot heal witch herself - msg_witch_resurrect = None - if healing and killed_player != agent.name: - msg_witch_resurrect = await agent( - await moderator( - Prompts.to_witch_resurrect.format( - witch_name=agent.name, - dead_name=killed_player, - ), - ), - structured_model=WitchResurrectModel, - ) - if msg_witch_resurrect.metadata.get("resurrect"): - killed_player = None - healing = False - - # Has poison potion and hasn't used the healing potion - if poison and not ( - msg_witch_resurrect and msg_witch_resurrect.metadata["resurrect"] - ): - msg_witch_poison = await agent( - await moderator( - Prompts.to_witch_poison.format( - witch_name=agent.name, - ), - ), - structured_model=get_poison_model( - players.current_alive, - ), - ) - if msg_witch_poison.metadata.get("poison"): - poisoned_player = msg_witch_poison.metadata.get("name") - poison = False - - # Seer's turn - await alive_players_hub.broadcast( - await moderator(Prompts.to_all_seer_turn), - ) - for agent in players.seer: - msg_seer = await agent( - await moderator( - Prompts.to_seer.format( - agent.name, - names_to_str(players.current_alive), - ), - ), - structured_model=get_seer_model(players.current_alive), - ) - if msg_seer.metadata.get("name"): - player = msg_seer.metadata["name"] - await agent.observe( - await moderator( - Prompts.to_seer_result.format( - agent_name=player, - 
role=players.name_to_role[player], - ), - ), - ) - - # Hunter's turn - for agent in players.hunter: - # If killed and not by witch's poison - if killed_player == agent.name and poisoned_player != agent.name: - shot_player = await hunter_stage(agent, players) - - # Update alive players - dead_tonight = [killed_player, poisoned_player, shot_player] - players.update_players(dead_tonight) - - # Day phase - if len([_ for _ in dead_tonight if _]) > 0: - await alive_players_hub.broadcast( - await moderator( - Prompts.to_all_day.format( - names_to_str([_ for _ in dead_tonight if _]), - ), - ), - ) - - # The killed player leave a last message in first night - if killed_player and first_day: - msg_moderator = await moderator( - Prompts.to_dead_player.format(killed_player), - ) - await alive_players_hub.broadcast(msg_moderator) - # Leave a message - last_msg = await players.name_to_agent[killed_player]() - await alive_players_hub.broadcast(last_msg) - - else: - await alive_players_hub.broadcast( - await moderator(Prompts.to_all_peace), - ) - - # Check winning - res = players.check_winning() - if res: - await moderator(res) - break - - # Discussion - await alive_players_hub.broadcast( - await moderator( - Prompts.to_all_discuss.format( - names=names_to_str(players.current_alive), - ), - ), - ) - # Open the auto broadcast to enable discussion - alive_players_hub.set_auto_broadcast(True) - await sequential_pipeline(players.current_alive) - # Disable auto broadcast to avoid leaking info - alive_players_hub.set_auto_broadcast(False) - - # Voting - msgs_vote = await fanout_pipeline( - players.current_alive, - await moderator( - Prompts.to_all_vote.format( - names_to_str(players.current_alive), - ), - ), - structured_model=get_vote_model(players.current_alive), - enable_gather=False, - ) - voted_player, votes = majority_vote( - [_.metadata.get("vote") for _ in msgs_vote], - ) - # Broadcast the voting messages together to avoid influencing - # each other - voting_msgs = [ - 
*msgs_vote, - await moderator( - Prompts.to_all_res.format(votes, voted_player), - ), - ] - - # Leave a message if voted - if voted_player: - prompt_msg = await moderator( - Prompts.to_dead_player.format(voted_player), - ) - last_msg = await players.name_to_agent[voted_player]( - prompt_msg, - ) - voting_msgs.extend([prompt_msg, last_msg]) - - await alive_players_hub.broadcast(voting_msgs) - - # If the voted player is the hunter, he can shoot someone - shot_player = None - for agent in players.hunter: - if voted_player == agent.name: - shot_player = await hunter_stage(agent, players) - if shot_player: - await alive_players_hub.broadcast( - await moderator( - Prompts.to_all_hunter_shoot.format( - shot_player, - ), - ), - ) - - # Update alive players - dead_today = [voted_player, shot_player] - players.update_players(dead_today) - - # Check winning - res = players.check_winning() - if res: - async with MsgHub(players.all_players) as all_players_hub: - res_msg = await moderator(res) - await all_players_hub.broadcast(res_msg) - break - - # The day ends - first_day = False - - # # Game over, each player reflects - # await fanout_pipeline( - # agents=agents, - # msg=await moderator(Prompts.to_all_reflect), - # ) - - alive_wolves = players.werewolves - good_guy_win = len(alive_wolves) == 0 - logger.warning("**********************************") - logger.warning(f"Good guy win: {good_guy_win}, alive werewolves: {alive_wolves}") - return good_guy_win diff --git a/tutorial/example_werewolves_swarm/prompt.py b/tutorial/example_werewolves_swarm/prompt.py deleted file mode 100644 index 52d0f650..00000000 --- a/tutorial/example_werewolves_swarm/prompt.py +++ /dev/null @@ -1,176 +0,0 @@ -# -*- coding: utf-8 -*- -"""Default prompts""" - - -class EnglishPrompts: - """English prompts used to guide the werewolf game.""" - - to_dead_player = ( - "{}, you're eliminated now. Now you can make a final statement to " - "all alive players before you leave the game." 
- ) - - to_all_new_game = ( - "A new game is starting, the players are: {}. Now we randomly " - "reassign the roles to each player and inform them of their roles " - "privately." - ) - - to_all_night = ( - "Night has fallen, everyone close your eyes. Werewolves open your " - "eyes and choose a player to eliminate tonight." - ) - - to_wolves_discussion = ( - "[WEREWOLVES ONLY] {}, you should discuss and " - "decide on a player to eliminate tonight. Current alive players " - "are {}. Remember to set `reach_agreement` to True if you reach an " - "agreement during the discussion." - ) - - to_wolves_vote = "[WEREWOLVES ONLY] Which player do you vote to kill?" - - to_wolves_res = ( - "[WEREWOLVES ONLY] The voting result is {}. So you have chosen to " "eliminate {}." - ) - - to_all_witch_turn = "Witch's turn, witch open your eyes and decide your action tonight..." - to_witch_resurrect = ( - "[WITCH ONLY] {witch_name}, you're the witch, and tonight {dead_name} " - "is eliminated. You can resurrect him/her by using your healing " - "potion, " - "and note you can only use it once in the whole game. Do you want to " - "resurrect {dead_name}? Give me your reason and decision." - ) - - to_witch_resurrect_no = "[WITCH ONLY] The witch has chosen not to resurrect the player." - to_witch_resurrect_yes = "[WITCH ONLY] The witch has chosen to resurrect the player." - - to_witch_poison = ( - "[WITCH ONLY] {witch_name}, as a witch, you have a one-time-use " - "poison potion, do you want to use it tonight? Give me your reason " - "and decision." - ) - - to_all_seer_turn = ( - "Seer's turn, seer open your eyes and check one player's identity " "tonight..." - ) - - to_seer = ( - "[SEER ONLY] {}, as the seer you can check one player's identity " - "tonight. Who do you want to check? Give me your reason and decision." - ) - - to_seer_result = "[SEER ONLY] You've checked {agent_name}, and the result is: {role}." 
- - to_hunter = ( - "[HUNTER ONLY] {name}, as the hunter you're eliminated tonight. You " - "can choose one player to take down with you. Also, you can choose " - "not to use this ability. Give me your reason and decision." - ) - - to_all_hunter_shoot = "The hunter has chosen to shoot {} down with him/herself." - - to_all_day = ( - "The day is coming, all players open your eyes. Last night, " - "the following player(s) has been eliminated: {}." - ) - - to_all_peace = ( - "The day is coming, all the players open your eyes. Last night is " - "peaceful, no player is eliminated." - ) - - to_all_discuss = ( - "Now the alive players are {names}. The game goes on, it's time to " - "discuss and vote a player to be eliminated. Now you each take turns " - "to speak once in the order of {names}." - ) - - to_all_vote = ( - "Now the discussion is over. Everyone, please vote to eliminate one " - "player from the alive players: {}." - ) - - to_all_res = "The voting result is {}. So {} has been voted out." - - to_all_wolf_win = ( - "There are {n_alive} players alive, and {n_werewolves} of them are " - "werewolves. " - "The game is over and werewolves win🐺🎉!" - "In this game, the true roles of all players are: {true_roles}" - ) - - to_all_village_win = ( - "All the werewolves have been eliminated." - "The game is over and villagers win🏘️🎉!" - "In this game, the true roles of all players are: {true_roles}" - ) - - to_all_continue = "The game goes on." - - to_all_reflect = ( - "The game is over. Now each player can reflect on their performance. " - "Note each player only has one chance to speak and the reflection is " - "only visible to themselves." - ) - - -class ChinesePrompts: - """Chinese prompts used to guide the werewolf game.""" - - to_dead_player = "{}, 你已被淘汰。现在你可以向所有存活玩家发表最后的遗言。" - - to_all_new_game = "新的一局游戏开始,参与玩家包括:{}。现在为每位玩家重新随机分配身份,并私下告知各自身份。" - - to_all_night = "天黑了,请所有人闭眼。狼人请睁眼,选择今晚要淘汰的一名玩家..." 
- - to_wolves_discussion = ( - "[仅狼人可见] {}, 你们可以讨论并决定今晚要淘汰的玩家。当前存活玩家有:{}。" "如果达成一致,请将 `reach_agreement` 设为 True。" - ) - - to_wolves_vote = "[仅狼人可见] 你投票要杀死哪位玩家?" - - to_wolves_res = "[仅狼人可见] 投票结果为 {},你们选择淘汰 {}。" - - to_all_witch_turn = "轮到女巫行动,女巫请睁眼并决定今晚的操作..." - to_witch_resurrect = ( - "[仅女巫可见] {witch_name},你是女巫,今晚{dead_name}被淘汰。" - "你可以用解药救他/她,注意解药全局只能用一次。你要救{dead_name}吗?" - "请给出理由和决定。" - ) - - to_witch_resurrect_no = "[仅女巫可见] 女巫选择不救该玩家。" - to_witch_resurrect_yes = "[仅女巫可见] 女巫选择救活该玩家。" - - to_witch_poison = "[仅女巫可见] {witch_name},你有一瓶一次性毒药,今晚要使用吗?请给出理由和决定。" - - to_all_seer_turn = "轮到预言家行动,预言家请睁眼并查验一名玩家身份..." - - to_seer = "[仅预言家可见] {}, 你是预言家,今晚可以查验一名玩家身份。你要查谁?请给出理由和决定。" - - to_seer_result = "[仅预言家可见] 你查验了{agent_name},结果是:{role}。" - - to_hunter = "[仅猎人可见] {name},你是猎人,今晚被淘汰。你可以选择带走一名玩家,也可以选择不带走。请给出理由和决定。" - - to_all_hunter_shoot = "猎人选择带走 {} 一起出局。" - - to_all_day = "天亮了,请所有玩家睁眼。昨晚被淘汰的玩家有:{}。" - - to_all_peace = "天亮了,请所有玩家睁眼。昨晚平安夜,无人被淘汰。" - - to_all_discuss = "现在存活玩家有:{names}。游戏继续,大家开始讨论并投票淘汰一名玩家。请按顺序({names})依次发言。" - - to_all_vote = "讨论结束。请大家从存活玩家中投票淘汰一人:{}。" - - to_all_res = "投票结果为 {},{} 被淘汰。" - - to_all_wolf_win = ( - "当前存活玩家共{n_alive}人,其中{n_werewolves}人为狼人。" "游戏结束,狼人获胜🐺🎉!" 
"本局所有玩家真实身份为:{true_roles}" - ) - - to_all_village_win = "所有狼人已被淘汰。游戏结束,村民获胜🏘️🎉!本局所有玩家真实身份为:{true_roles}" - - to_all_continue = "游戏继续。" - - to_all_reflect = "游戏结束。现在每位玩家可以对自己的表现进行反思。注意每位玩家只有一次发言机会,且反思内容仅自己可见。" diff --git a/tutorial/example_werewolves_swarm/start.py b/tutorial/example_werewolves_swarm/start.py deleted file mode 100644 index 879b6101..00000000 --- a/tutorial/example_werewolves_swarm/start.py +++ /dev/null @@ -1,155 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa: E501 - -"""The main entry point for the werewolf game.""" - -from typing import List -import numpy as np -import dotenv -dotenv.load_dotenv() - -from textwrap import dedent - -from agentscope.agent import ReActAgent -from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter -from agentscope.model import OpenAIChatModel -from loguru import logger -from pydantic import Field - -from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask -from tutorial.example_werewolves.game import BadGuyException, werewolves_game - - -def get_official_agent_prompt(name) -> str: - system_prompt = dedent( - f""" - You're a werewolf game player named {name}. - - # YOUR TARGET - Your target is to win the game with your teammates as much as possible. - - # GAME RULES - - In werewolf game, players are divided into three werewolves, three villagers, one seer, one hunter and one witch. - - Werewolves: kill one player each night, and must hide identity during the day. - - Villagers: ordinary players without special abilities, try to identify and eliminate werewolves. - - Seer: A special villager who can check one player's identity each night. - - Witch: A special villager with two one-time-use potions: a healing potion to save a player from being killed at night, and a poison to eliminate one player at night. - - Hunter: A special villager who can take one player down with them when they are eliminated. 
- - The game alternates between night and day phases until one side wins: - - Night Phase - - Werewolves choose one victim - - Seer checks one player's identity - - Witch decides whether to use potions - - Moderator announces who died during the night - - Day Phase - - All players discuss and vote to eliminate one suspected player - - # GAME GUIDANCE - - Try your best to win the game with your teammates, tricks, lies, and deception are all allowed, e.g. pretending to be a different role. - - During discussion, don't be political, be direct and to the point. - - The day phase voting provides important clues. For example, the werewolves may vote together, attack the seer, etc. - ## GAME GUIDANCE FOR WEREWOLF - - Seer is your greatest threat, who can check one player's identity each night. Analyze players' speeches, find out the seer and eliminate him/her will greatly increase your chances of winning. - - In the first night, making random choices is common for werewolves since no information is available. - - Pretending to be other roles (seer, witch or villager) is a common strategy to hide your identity and mislead other villagers in the day phase. - - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. - ## GAME GUIDANCE FOR SEER - - Seer is very important to villagers, exposing yourself too early may lead to being targeted by werewolves. - - Your ability to check one player's identity is crucial. - - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. - ## GAME GUIDANCE FOR WITCH - - Witch has two powerful potions, use them wisely to protect key villagers or eliminate suspected werewolves. - - The outcome of the night phase provides important clues. 
For example, if the dead player is hunter, etc. Use this information to adjust your strategy. - ## GAME GUIDANCE FOR HUNTER - - Using your ability in day phase will expose your role (since only hunter can take one player down) - - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, etc. Use this information to adjust your strategy. - ## GAME GUIDANCE FOR VILLAGER - - Protecting special villagers, especially the seer, is crucial for your team's success. - - Werewolves may pretend to be the seer. Be cautious and don't trust anyone easily. - - The outcome of the night phase provides important clues. For example, if witch uses the healing or poison potion, if the dead player is hunter, etc. Use this information to adjust your strategy. - - # NOTE - - [IMPORTANT] DO NOT make up any information that is not provided by the moderator or other players. - - This is a TEXT-based game, so DO NOT use or make up any non-textual information. - - Always critically reflect on whether your evidence exist, and avoid making assumptions. - - Your response should be specific and concise, provide clear reason and avoid unnecessary elaboration. - - Generate your one-line response by using the `generate_response` function. - - Don't repeat the others' speeches.""" - ) - return system_prompt - - -class ExampleWerewolves(Workflow): - trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.") - - async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: - - # ensure trainable targets is legal - assert self.trainable_targets is not None, "trainable_targets cannot be None in ExampleWerewolves (because we want to demonstrate a explicit multi-agent case)." 
- - # bad guys and good guys cannot be trained simultaneously - # (because mix-cooperation-competition MARL needs too many advanced techniques to be displayed here) - if "werewolf" in self.trainable_targets: - assert len(self.trainable_targets) == 1, "Cannot train hostile roles simultaneously." - else: - assert len(self.trainable_targets) != 0, "No trainable targets specified." - - # make and shuffle roles (fix random seed for reproducibility) - roles = ["werewolf"] * 3 + ["villager"] * 3 + ["seer", "witch", "hunter"] - task_id = workflow_task.task.metadata["random_number"] - np.random.seed(int(task_id)) - np.random.shuffle(roles) - - # initialize agents - players = [] - for i, role in enumerate(roles): - default_model = OpenAIChatModel( - model_name="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", - stream=False, - client_args={"base_url": "http://22.17.52.4:2888/v1"}, - api_key="no_api_key", - generate_kwargs={"temperature": 0.01}, - ) - model_for_this_agent = tuner.as_agentscope_model( - agent_name=f"Player{i + 1}", # the name of this agent - target_tag=role, # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent. 
- debug_model=default_model, # the model used when this agent is not in `self.trainable_targets` - ) - agent = ReActAgent( - name=f"Player{i + 1}", - sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), - model=model_for_this_agent, - formatter=DashScopeMultiAgentFormatter() - if role in self.trainable_targets - else OpenAIMultiAgentFormatter(), - max_iters=3 if role in self.trainable_targets else 5, - ) - # agent.set_console_output_enabled(False) - players += [agent] - - # reward condition - try: - good_guy_win = await werewolves_game(players, roles) - raw_reward = 0 - is_success = False - if (good_guy_win and self.trainable_targets[0] != "werewolf") or ( - not good_guy_win and self.trainable_targets[0] == "werewolf" - ): - raw_reward = 1 - is_success = True - logger.warning(f"Raw reward: {raw_reward}") - logger.warning(f"Is success: {is_success}") - except BadGuyException as e: - logger.bind(exception=True).exception( - f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent)." - ) - raw_reward = -0.1 - is_success = False - except Exception as e: - logger.bind(exception=True).exception( - f"Error during game execution. Game cannot continue, whatever the cause, let's punish trainable agents (Although they maybe innocent)." 
- ) - raw_reward = -0.1 - is_success = False - - return WorkflowOutput(reward=raw_reward, is_success=is_success) diff --git a/tutorial/example_werewolves_swarm/structured_model.py b/tutorial/example_werewolves_swarm/structured_model.py deleted file mode 100644 index 46390589..00000000 --- a/tutorial/example_werewolves_swarm/structured_model.py +++ /dev/null @@ -1,88 +0,0 @@ -# -*- coding: utf-8 -*- -"""The structured output models used in the werewolf game.""" -from typing import Literal - -from agentscope.agent import AgentBase -from pydantic import BaseModel, Field - - -class DiscussionModel(BaseModel): - """The output format for discussion.""" - - reach_agreement: bool = Field( - description="Whether you have reached an agreement or not", - ) - - -def get_vote_model(agents: list[AgentBase]) -> type[BaseModel]: - """Get the vote model by player names.""" - - class VoteModel(BaseModel): - """The vote output format.""" - - vote: Literal[tuple(_.name for _ in agents)] = Field( # type: ignore - description="The name of the player you want to vote for", - ) - - return VoteModel - - -class WitchResurrectModel(BaseModel): - """The output format for witch resurrect action.""" - - resurrect: bool = Field( - description="Whether you want to resurrect the player", - ) - - -def get_poison_model(agents: list[AgentBase]) -> type[BaseModel]: - """Get the poison model by player names.""" - - class WitchPoisonModel(BaseModel): - """The output format for witch poison action.""" - - poison: bool = Field( - description="Do you want to use the poison potion", - ) - name: Literal[tuple(_.name for _ in agents)] | None = ( # type: ignore - Field( - description="The name of the player you want to poison, if you " - "don't want to poison anyone, just leave it empty", - default=None, - ) - ) - - return WitchPoisonModel - - -def get_seer_model(agents: list[AgentBase]) -> type[BaseModel]: - """Get the seer model by player names.""" - - class SeerModel(BaseModel): - """The output format for 
seer action.""" - - name: Literal[tuple(_.name for _ in agents)] = Field( # type: ignore - description="The name of the player you want to check", - ) - - return SeerModel - - -def get_hunter_model(agents: list[AgentBase]) -> type[BaseModel]: - """Get the hunter model by player agents.""" - - class HunterModel(BaseModel): - """The output format for hunter action.""" - - shoot: bool = Field( - description="Whether you want to use the shooting ability or not", - ) - name: Literal[tuple(_.name for _ in agents)] | None = ( # type: ignore - Field( - description="The name of the player you want to shoot, if you " - "don't want to the ability, just leave it empty", - default=None, - ) - ) - - return HunterModel diff --git a/tutorial/example_werewolves_swarm/utils.py b/tutorial/example_werewolves_swarm/utils.py deleted file mode 100644 index c9dd0039..00000000 --- a/tutorial/example_werewolves_swarm/utils.py +++ /dev/null @@ -1,161 +0,0 @@ -# -*- coding: utf-8 -*- -"""Utility functions for the werewolf game.""" -from collections import defaultdict -from typing import Any - -import numpy as np -from agentscope.agent import AgentBase, ReActAgent -from agentscope.message import Msg - -from tutorial.example_werewolves.prompt import EnglishPrompts as Prompts - -# MAX_GAME_ROUND = 30 -# MAX_DISCUSSION_ROUND = 3 -MAX_GAME_ROUND = 7 -MAX_DISCUSSION_ROUND = 2 - - -def majority_vote(votes: list[str]) -> tuple: - """Return the vote with the most counts.""" - result = max(set(votes), key=votes.count) - names, counts = np.unique(votes, return_counts=True) - conditions = ", ".join( - [f"{name}: {count}" for name, count in zip(names, counts)], - ) - return result, conditions - - -def names_to_str(agents: list[str] | list[ReActAgent]) -> str: - """Return a string of agent names.""" - if not agents: - return "" - - if len(agents) == 1: - if isinstance(agents[0], ReActAgent): - return agents[0].name - return agents[0] - - names = [] - for agent in agents: - if isinstance(agent, ReActAgent): 
- names.append(agent.name) - else: - names.append(agent) - return ", ".join([*names[:-1], "and " + names[-1]]) - - -class EchoAgent(AgentBase): - """Echo agent that repeats the input message.""" - - def __init__(self) -> None: - super().__init__() - self.name = "Moderator" - - async def reply(self, content: str) -> Msg: - """Repeat the input content with its name and role.""" - msg = Msg( - self.name, - content, - role="assistant", - ) - await self.print(msg) - return msg - - async def handle_interrupt( - self, - *args: Any, - **kwargs: Any, - ) -> Msg: - """Handle interrupt.""" - - async def observe(self, msg: Msg | list[Msg] | None) -> None: - """Observe the user's message.""" - - -class Players: - """Maintain the players' status.""" - - def __init__(self) -> None: - """Initialize the players.""" - # The mapping from player name to role - self.name_to_role = {} - self.role_to_names = defaultdict(list) - self.name_to_agent = {} - self.werewolves = [] - self.villagers = [] - self.seer = [] - self.hunter = [] - self.witch = [] - self.current_alive = [] - self.all_players = [] - - def add_player(self, player: ReActAgent, role: str) -> None: - """Add a player to the game. - - Args: - player (`ReActAgent`): - The player to be added. - role (`str`): - The role of the player. - """ - self.name_to_role[player.name] = role - self.name_to_agent[player.name] = player - self.role_to_names[role].append(player.name) - self.all_players.append(player) - if role == "werewolf": - self.werewolves.append(player) - elif role == "villager": - self.villagers.append(player) - elif role == "seer": - self.seer.append(player) - elif role == "hunter": - self.hunter.append(player) - elif role == "witch": - self.witch.append(player) - else: - raise ValueError(f"Unknown role: {role}") - self.current_alive.append(player) - - def update_players(self, dead_players: list[ReActAgent]) -> None: - """Update the current alive players. 
- - Args: - dead_players (`list[ReActAgent]`): - A list of dead players to be removed. - """ - self.werewolves = [_ for _ in self.werewolves if _.name not in dead_players] - self.villagers = [_ for _ in self.villagers if _.name not in dead_players] - self.seer = [_ for _ in self.seer if _.name not in dead_players] - self.hunter = [_ for _ in self.hunter if _.name not in dead_players] - self.witch = [_ for _ in self.witch if _.name not in dead_players] - self.current_alive = [_ for _ in self.current_alive if _.name not in dead_players] - - def print_roles(self) -> None: - """Print the roles of all players.""" - print("Roles:") - for name, role in self.name_to_role.items(): - print(f" - {name}: {role}") - - def check_winning(self) -> str | None: - """Check if the game is over and return the winning message.""" - - # Prepare true roles string - true_roles = ( - f'{names_to_str(self.role_to_names["werewolf"])} are werewolves, ' - f'{names_to_str(self.role_to_names["villager"])} are villagers, ' - f'{names_to_str(self.role_to_names["seer"])} is the seer, ' - f'{names_to_str(self.role_to_names["hunter"])} is the hunter, ' - f'and {names_to_str(self.role_to_names["witch"])} is the witch.' 
- ) - - if len(self.werewolves) * 2 >= len(self.current_alive): - return Prompts.to_all_wolf_win.format( - n_alive=len(self.current_alive), - n_werewolves=len(self.werewolves), - true_roles=true_roles, - ) - if self.current_alive and not self.werewolves: - return Prompts.to_all_village_win.format( - true_roles=true_roles, - ) - return None diff --git a/tutorial/example_werewolves_swarm/werewolves.md b/tutorial/example_werewolves_swarm/werewolves.md deleted file mode 100644 index 737a11f4..00000000 --- a/tutorial/example_werewolves_swarm/werewolves.md +++ /dev/null @@ -1,9 +0,0 @@ -# Training a basic math agent - - -Please refer to document at [`docs/en/example_werewolves.md`](docs/en/example_werewolves.md) - - -# Translate to yaml - -tutorial/example_werewolves_swarm/werewolves.yaml \ No newline at end of file diff --git a/tutorial/example_werewolves_swarm/werewolves.yaml b/tutorial/example_werewolves_swarm/werewolves.yaml index 7f786985..fae6ac50 100644 --- a/tutorial/example_werewolves_swarm/werewolves.yaml +++ b/tutorial/example_werewolves_swarm/werewolves.yaml @@ -57,10 +57,10 @@ ajet: trainer_common: save_freq: 5 - test_freq: 99999 - total_epochs: 99999 + test_freq: 9999999 + total_epochs: 9999999 total_training_steps: 25 - nnodes: 2 + nnodes: 1 n_gpus_per_node: 8 # ------------------ do not edit ------------------ From e9bc0e11fa15e59478545988da5d818faae28147 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Mon, 2 Mar 2026 17:05:31 +0800 Subject: [PATCH 03/11] refactor: update configuration and improve swarm client functionality --- .gitignore | 1 + ajet/default_config/ajet_default.yaml | 2 +- ajet/tuner_lib/as_oai_baseurl_apikey.py | 7 +++++-- ajet/tuner_lib/experimental/as_swarm_client.py | 4 ++-- ajet/tuner_lib/experimental/as_swarm_server.py | 2 +- tutorial/example_math_swarm/math.py | 17 +++++++---------- tutorial/example_werewolves/start.py | 10 ++++------ tutorial/example_werewolves_swarm/agent_roll.py | 11 +++++------ 8 files changed, 26 
insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index bdb3f5f4..1b88f0c1 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,4 @@ modelscope_cache prompts swarmexp swarmlog +werewolves_swarm diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 531eb321..32bcd822 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -3,7 +3,7 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" experiment_dir: "auto" # {exp-dir}/{experiment_name} - backbone: debug # `debug` or `trinity` or `verl` + backbone: verl # `debug` or `trinity` or `verl` model: diff --git a/ajet/tuner_lib/as_oai_baseurl_apikey.py b/ajet/tuner_lib/as_oai_baseurl_apikey.py index cc40c61b..dcc5e549 100644 --- a/ajet/tuner_lib/as_oai_baseurl_apikey.py +++ b/ajet/tuner_lib/as_oai_baseurl_apikey.py @@ -29,8 +29,11 @@ class OpenaiBaseUrlAndApiKey(BaseModel): episode_uuid: str = Field(default="episode_id", description="reserved field.") def as_agentscope_model(self, *args, **kwargs): - from agentscope.model import DashScopeChatModel - return DashScopeChatModel(model_name="AgentJet-Model", api_key=self.api_key, base_http_api_url=self.base_url) + from agentscope.model import OpenAIChatModel + return OpenAIChatModel( + model_name="AgentJet-Model", api_key=self.api_key, + client_args={"base_url": self.base_url} + ) def as_raw_openai_sdk_client(self, *args, **kwargs): from openai import AsyncOpenAI diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/as_swarm_client.py index f4dfcee5..61d22941 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/as_swarm_client.py @@ -195,7 +195,7 @@ def _should_throttle(self, throttle_policy: SwarmThrottlePolicy, pool_info: Curr self._remember_seen_task(throttle_policy.current_task_id, throttle_policy.expected_batch_size, throttle_policy.expected_num_repeat) return 
should_throttle - def begin_episode(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: + def begin_episode(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: """ Block until an episode is claimed. Argument: @@ -210,7 +210,7 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt """ return self._begin_episode_auto_retry(discard_episode_timeout, episode_type, throttle_policy) - def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: + def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: # max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by client (call `end_episode` will be re-route to `abort_episode`) max_episode_time = 2*discard_episode_timeout diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/as_swarm_server.py index 81a5b5ff..eb0a9503 100644 --- a/ajet/tuner_lib/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/experimental/as_swarm_server.py @@ -708,7 +708,7 @@ async def get_episode_buffer(): @app.post("/update_current_batch_rollout_pool_information", response_model=BoolResponse) async def update_current_batch_rollout_pool_information(req: CurrentBatchRolloutPoolInformation): """Update the current batch rollout pool information.""" - if VERBOSE: + if DEBUG: logger.info(f"Running /update_current_batch_rollout_pool_information") try: with shared_mem_dict_lock: diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py index d0bf7096..3c9ba92f 100644 --- 
a/tutorial/example_math_swarm/math.py +++ b/tutorial/example_math_swarm/math.py @@ -52,16 +52,13 @@ def main(): ) def rollout(task): - try: - # begin episode - episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) - # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) - workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` - # report output back to swarm remote - swarm_worker.end_episode(task, episode_uuid, workflow_output) - return - except: - pass + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) for _ in range(NUM_EPOCH): diff --git a/tutorial/example_werewolves/start.py b/tutorial/example_werewolves/start.py index fb3d6df7..e06241ec 100644 --- a/tutorial/example_werewolves/start.py +++ b/tutorial/example_werewolves/start.py @@ -12,7 +12,7 @@ from agentscope.agent import ReActAgent from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter -from agentscope.model import OpenAIChatModel +from agentscope.model import OpenAIChatModel, DashScopeChatModel from loguru import logger from pydantic import Field @@ -81,8 +81,8 @@ def get_official_agent_prompt(name) -> str: class ExampleWerewolves(Workflow): trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.") - big_external_opponent_llm_url = "http://22.17.52.4:2888/v1" - big_external_opponent_llm_name = "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/" + 
big_external_opponent_llm_url: str = Field(default="http://22.17.52.4:2888/v1", description="The URL of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM API URL.") + big_external_opponent_llm_name: str = Field(default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", description="The model name of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM name.") async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: @@ -121,9 +121,7 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl name=f"Player{i + 1}", sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), model=model_for_this_agent, - formatter=DashScopeMultiAgentFormatter() - if role in self.trainable_targets - else OpenAIMultiAgentFormatter(), + formatter=DashScopeMultiAgentFormatter() if isinstance(model_for_this_agent, DashScopeChatModel) else OpenAIMultiAgentFormatter(), max_iters=3 if role in self.trainable_targets else 5, ) # agent.set_console_output_enabled(False) diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py index 6b31c6bd..e7b920bd 100644 --- a/tutorial/example_werewolves_swarm/agent_roll.py +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -1,15 +1,12 @@ # -*- coding: utf-8 -*- import os -import re -import requests -from textwrap import dedent -from ajet.schema.task import Task, WorkflowOutput +from ajet.schema.task import Task from ajet.copilot.job import AgentJetJob from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey -from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.default_config.ajet_default import AjetTaskReader from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient NUM_EPOCH = 
10000 @@ -33,7 +30,7 @@ def main(): swarm_worker = SwarmClient(AJET_SWARM_URL) swarm_worker.auto_sync_train_config_and_start_engine( ajet_job, - force_restart=True, + force_restart=False, ) GRPO_N = ajet_job.num_repeat @@ -65,6 +62,8 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): from tutorial.example_werewolves.start import ExampleWerewolves game = ExampleWerewolves( trainable_targets=["werewolf"], + big_external_opponent_llm_name="Qwen/Qwen3-235B-A22B-Instruct-2507", + big_external_opponent_llm_url="http://22.16.90.187/v1", ) res = asyncio.run(game.execute(task, api_baseurl_key)) return res From a191f7186ce7fbba79d4e99b042bf260cf03f034 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 3 Mar 2026 16:36:39 +0800 Subject: [PATCH 04/11] improve communication protocol --- ajet/context_tracker/multiagent_tracking.py | 13 - ajet/default_config/ajet_default.yaml | 5 +- ajet/default_config/ajet_ts_default.yaml | 12 +- ajet/launcher.py | 9 +- ajet/swarm_cli.py | 10 +- ajet/task_rollout/async_llm_bridge.py | 270 ++------- ajet/task_rollout/native_parallel_worker.py | 13 +- ajet/task_rollout/single_worker.py | 11 +- ajet/tuner_lib/as_oai_baseurl_apikey.py | 3 + .../experimental/as_oai_model_client.py | 53 +- .../experimental/as_oai_model_server.py | 120 +++- .../tuner_lib/experimental/as_swarm_client.py | 8 +- .../tuner_lib/experimental/as_swarm_server.py | 17 +- ajet/utils/config_utils.py | 28 +- ajet/utils/swarm_overwatch.py | 8 +- ajet/utils/thread_executors.py | 21 +- math_gsm8k_grpo/yaml_backup.yaml | 559 ------------------ tutorial/example_math_swarm/math.py | 3 +- tutorial/example_werewolves/game.py | 10 +- .../example_werewolves_swarm/agent_roll.py | 2 +- .../example_werewolves_swarm/werewolves.yaml | 4 +- 21 files changed, 264 insertions(+), 915 deletions(-) delete mode 100644 math_gsm8k_grpo/yaml_backup.yaml diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index b332a11d..30be665c 
100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -84,19 +84,6 @@ def extract_text_content_from_content_dict(self, msg): # }, # ], # } - # or tool_result format?? not observed yet: - # msg = { - # "role": "tool", - # "content": [ - # { - # "type": "tool_result", - # "id": "call_xxx", - # "output": "tool output content", - # "name": "tool_name" - # }, - # ], - # } - str_content = "" for item in msg["content"]: diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 32bcd822..7a62ee75 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -2,7 +2,7 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" - experiment_dir: "auto" # {exp-dir}/{experiment_name} + experiment_dir: "{exp-dir}/{experiment_name}" # {exp-dir}/{experiment_name} backbone: verl # `debug` or `trinity` or `verl` @@ -85,6 +85,7 @@ ajet: num_repeat: 1 + task_reader: # how to read dataset / environment type: huggingface_dat_repo # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` @@ -306,8 +307,6 @@ ajet: swarm_mode_sample_collection_max_cached_episodes: 9999 task_runner: - # submit llm infer submit method - llm_infer_submit_method: "async" # options: "sync", "async" # how to wrap the user-defined workflow wrapper_type: "asyncio-with-gc" diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml index bde6f48b..ab584d61 100644 --- a/ajet/default_config/ajet_ts_default.yaml +++ b/ajet/default_config/ajet_ts_default.yaml @@ -2,7 +2,7 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" - experiment_dir: "auto" # {exp-dir}/{experiment_name} + experiment_dir: "{exp-dir}/{experiment_name}" # {exp-dir}/{experiment_name} backbone: verl model: @@ -12,6 +12,10 @@ ajet: rollout: # the path to the workflow class user_workflow: null 
+ # maximum number of parallel environments / simulate workers + max_env_worker: 128 + # how many times a task should be repeated + num_repeat: 4 task_reader: type: random_dummy # `env_service` or `jsonl_dataset_file` or `huggingface_dat_repo` or `data_generation` or `random_dummy` @@ -53,12 +57,6 @@ ajet: train_batch_size: 32 # [Hint]: The final number of samples per update will be: N_{sample} = (data.train_batch_size * rollout.num_repeat * rollout.multi_turn.expected_steps) - rollout: - # maximum number of parallel environments / simulate workers - max_env_worker: 128 - # how many times a task should be repeated - num_repeat: 4 - trainer_common: logger: tensorboard n_gpus_per_node: 8 diff --git a/ajet/launcher.py b/ajet/launcher.py index f3638e74..3dbda6e1 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -168,7 +168,7 @@ def main(): from ajet.utils.swarm_overwatch import start_overwatch logger.info(f"Starting Swarm Overwatch for server: {args.swarm_overwatch}") - start_overwatch(args.swarm_overwatch, refresh_interval=1.0) + start_overwatch(args.swarm_overwatch, refresh_interval=2.0) return # Enforce GPU availability and free memory threshold before proceeding @@ -204,7 +204,6 @@ def main(): # read configuration from yaml exp_config = None - exp_dir = args.exp_dir or DEFAULT_DIR if args.swarm_server and (not args.conf): args.conf = os.path.abspath( os.path.join( @@ -215,6 +214,7 @@ def main(): "Please provide a valid config file for swarm server mode." 
) if args.conf: + exp_dir = args.exp_dir or DEFAULT_DIR yaml_path = args.conf ( main_yaml_fp, @@ -222,7 +222,10 @@ def main(): exp_name, exp_config, ) = prepare_experiment_config( - yaml_path, exp_dir, args.backbone, storage=(not args.swarm_server) + yaml_path=yaml_path, + exp_base_dir=exp_dir, + backbone=args.backbone, + storage=(not args.swarm_server) ) # setup environment variables diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index 8ce5a570..3434a5b6 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -61,7 +61,10 @@ def cmd_start(args): exp_name, exp_config, ) = prepare_experiment_config( - yaml_path, exp_dir, "verl", storage=False + yaml_path=yaml_path, + exp_base_dir=exp_dir, + backbone="verl", + storage=False ) # Setup environment variables @@ -73,7 +76,6 @@ def __init__(self, conf, backbone, exp_dir): self.swarm_server = True self.swarm_overwatch = "" self.debug = "" - swarm_args = SwarmArgs(args.conf, "verl", args.exp_dir) env, exp_config = setup_environment_vars(swarm_args, exp_config, main_yaml_fp) @@ -131,9 +133,9 @@ def main(): parser_overwatch.add_argument( "--refresh-interval", type=float, - default=1.0, + default=2.0, required=False, - help="Refresh interval in seconds (default: 1.0)", + help="Refresh interval in seconds (default: 2.0)", ) parser_overwatch.set_defaults(func=cmd_overwatch) diff --git a/ajet/task_rollout/async_llm_bridge.py b/ajet/task_rollout/async_llm_bridge.py index 3015f631..ced9cf16 100644 --- a/ajet/task_rollout/async_llm_bridge.py +++ b/ajet/task_rollout/async_llm_bridge.py @@ -1,9 +1,9 @@ -import asyncio import copy import json import time import uuid -from typing import Any, Callable, Dict, List, Literal, Union +from typing import Any, Callable, Dict, List, Literal, Union, Awaitable +from typing import TYPE_CHECKING from loguru import logger from omegaconf import DictConfig @@ -15,12 +15,13 @@ from ajet.schema.logprob import TokenAndProb from ajet.utils.tokenizer import ajet_apply_chat_template -from 
ajet.utils.async_utils import run_async_coroutine_with_timeout -from ajet.utils.testing_utils import _mock_if_test_mode, _test_if_test_mode from ajet.schema.convertion import convert_llm_proxy_response_to_oai_response from ajet.schema.convertion import convert_llm_proxy_response_to_agentscope_response from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker +if TYPE_CHECKING: + from vllm.entrypoints.openai.protocol import ChatCompletionRequest + ChatResponse = Union[OpenAIChatCompletion, AgentScopeChatResponse] @@ -58,207 +59,6 @@ def __init__( self.max_llm_retries = max_llm_retries self.tool_parser = Hermes2ProToolParser(self.tokenizer) - def get_llm_inference_fn_sync(self, sampling_params: dict = {}) -> Callable: # noqa: C901 - - def llm_chat_verl( - messages: List[Dict[str, str]], - custom_sampling_params: dict = {}, - tools=[], - request_id: str = "", - ) -> dict: - request_id = uuid.uuid4().hex - - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - - input_messages = copy.deepcopy(messages) - prompt_text = ajet_apply_chat_template( - tokenizer=self.tokenizer, - conversation=input_messages, - tools=tools, - add_generation_prompt=True, - tokenize=False, - ) - prompt_ids = self.tokenizer(prompt_text)["input_ids"] - - if self.config.ajet.execute_test: - _test_if_test_mode("prompt_text", prompt_text, self.config) - - final_res = run_async_coroutine_with_timeout( - self.async_rollout_manager.generate( - request_id=request_id, - prompt_ids=prompt_ids, - sampling_params=updated_sampling_params, - ), - timeout=1800, - ) - - if self.config.ajet.rollout.name == "vllm": - final_res: VerlVllmRequestOutput - token_array = final_res.outputs[0].token_ids - logprob_array = final_res.outputs[0].logprobs - elif self.config.ajet.rollout.name == "sglang": - token_array = final_res - - decoded_text = 
self.tokenizer.decode(token_array) # type: ignore - if self.config.ajet.execute_test: - decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config) - - if decoded_text.endswith("<|im_end|>"): - decoded_text = decoded_text[: -len("<|im_end|>")] - - # if tool call - tool_calls = None - if ( - ("" in decoded_text) - and ("" in decoded_text) - and (not self.config.ajet.rollout.force_disable_toolcalls) - ): - parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore - parsed_tool_calls = parsed_tool_calls.model_dump() - if self.config.ajet.execute_test: - _test_if_test_mode( - "parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config - ) - model_called = parsed_tool_calls["tools_called"] - if model_called: - tool_calls = parsed_tool_calls["tool_calls"] - is_bad_toolcall = False - for i in range(len(tool_calls)): - if "function" in tool_calls[i] and "arguments" in tool_calls[i]["function"]: - expect_dict = json.loads(tool_calls[i]["function"]["arguments"]) - if not isinstance(expect_dict, dict): - is_bad_toolcall = True - if is_bad_toolcall: - tool_calls = None - decoded_text = decoded_text - else: - decoded_text = parsed_tool_calls["content"] - if decoded_text is None: - decoded_text = "" - - return { - "role": "assistant", - "request_id": request_id, - "content": decoded_text, - "tool_calls": tool_calls, - "tokens": [ - TokenAndProb( - token_id=token_id, - logprob=logprob[token_id].logprob, # Warning: vllm logprob does not participant training (not reliable enough), for log only. 
- decoded_string=logprob[token_id].decoded_token, - ) - for token_id, logprob in zip(token_array, logprob_array) # type: ignore - ], - } - - - def llm_chat_remote( - messages: List[Dict[str, str]], - custom_sampling_params: dict = {}, - tools=[], - request_id: str = "", - ) -> dict: - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - updated_sampling_params.update({"logprobs": 1, "return_tokens_as_token_ids": True}) - input_messages = copy.deepcopy(messages) - for i in range(self.max_llm_retries): - try: - # this function is defined in `ajet/backbone/main_vllm.py` - output_message = self.async_rollout_manager.submit_chat_completions( - messages=input_messages, - sampling_params=updated_sampling_params, - tools=tools, - request_id=request_id, - ) - break - except Exception as e: - logger.bind(exception=True).exception(f"rollout_server.{i} error: {e.args}") - time.sleep(i + 1) - return output_message[-1] # type: ignore - - - def llm_chat_trinity( - messages: List[Dict[str, str]], - custom_sampling_params: dict = {}, - tools=[], - request_id: str = "", - ) -> dict: - async def main(): - updated_sampling_params = {} - if sampling_params: - updated_sampling_params.update(sampling_params) - if custom_sampling_params: - updated_sampling_params.update(custom_sampling_params) - updated_sampling_params.pop("min_tokens") - - if tools: - response = await self.async_rollout_manager.chat.completions.create( - model=self.async_rollout_manager.model_path, - messages=messages, - logprobs=True, - tools=tools, - top_logprobs=0, - **updated_sampling_params, - ) - else: - response = await self.async_rollout_manager.chat.completions.create( - model=self.async_rollout_manager.model_path, - messages=messages, - logprobs=True, - top_logprobs=0, - **updated_sampling_params, - ) - return response - - response = run_async_coroutine_with_timeout(main(), 
timeout=1800) # type: ignore - prompt_text = self.tokenizer.decode(response.model_extra["prompt_token_ids"]) - prompt_token_ids = response.model_extra["prompt_token_ids"] - content = response.choices[0].message.content - message = response.choices[0].message.model_dump(exclude_unset=True, exclude_none=True) - - if content is None: - content = "" - - if ("" in content) and (not message.get("tool_calls", None)): - # logger.bind(exception=True).exception(f"Bad toolcall discovered \n\nprompt_text:\n{prompt_text}\n\nrepsonse:\n{content}") - logger.warning(f"Bad toolcall discovered: {content}") - - return { - "role": "assistant", - "request_id": response.id, - "content": content, - "prompt_text": prompt_text, - "prompt_token_ids": prompt_token_ids, - "tool_calls": message.get("tool_calls", []), - "tokens": [ - TokenAndProb( - token_id=token, - logprob=tokenlogprob.logprob, # Warning: vllm logprob does not participant training, for log only. - decoded_string=tokenlogprob.token, - ) - for tokenlogprob, token in zip( - response.choices[0].logprobs.content, - response.choices[0].token_ids, - ) - ], - } - - if self.llm_mode == "remote": - return llm_chat_remote - if self.llm_mode == "trinity": - return llm_chat_trinity - else: - return llm_chat_verl - - def get_llm_inference_fn_async(self, sampling_params: dict = {}) -> Callable: # noqa: C901 @@ -286,9 +86,6 @@ async def llm_chat_verl( ) prompt_ids = self.tokenizer(prompt_text)["input_ids"] - if self.config.ajet.execute_test: - _test_if_test_mode("prompt_text", prompt_text, self.config) - final_res = await self.async_rollout_manager.generate( request_id=request_id, prompt_ids=prompt_ids, @@ -303,13 +100,11 @@ async def llm_chat_verl( token_array = final_res decoded_text = self.tokenizer.decode(token_array) # type: ignore - if self.config.ajet.execute_test: - decoded_text = _mock_if_test_mode("mock_decoded_text", decoded_text, self.config) if decoded_text.endswith("<|im_end|>"): decoded_text = decoded_text[: 
-len("<|im_end|>")] - # if tool call + # if tool call, use vLLM tool parser to extract tool calls and validate them tool_calls = None if ( ("" in decoded_text) @@ -319,10 +114,7 @@ async def llm_chat_verl( parsed_tool_calls = self.tool_parser.extract_tool_calls(decoded_text, None) # type: ignore parsed_tool_calls = parsed_tool_calls.model_dump() - if self.config.ajet.execute_test: - _test_if_test_mode( - "parsed_tool_calls", parsed_tool_calls["tool_calls"], self.config - ) + model_called = parsed_tool_calls["tools_called"] if model_called: tool_calls = parsed_tool_calls["tool_calls"] @@ -474,7 +266,7 @@ class OpenaiLlmProxyWithTracker(object): def __init__( self, - llm_inference_fn: Callable, # Callable[AjetStandardLlmBridgeRequest, AjetStandardLlmBridgeResponse] + llm_inference_fn: Callable[..., Awaitable[Dict]], # Callable[AjetStandardLlmBridgeRequest, AjetStandardLlmBridgeResponse] context_tracker: MultiAgentContextTracker, config, ) -> None: @@ -483,15 +275,39 @@ def __init__( self.config = config + async def chat_completion_request( + self, + req: "ChatCompletionRequest", + timeline_uuid: str, + agent_name: str, + target_tag: str, + episode_uuid: str, + ): + from openai.types.chat.chat_completion import ChatCompletion + req_as_dict = req.model_dump() + + # infer + process with context tracker + llm_output = await self.run_infer( + messages=req_as_dict["messages"], + tools=req_as_dict["tools"], + tool_choice="auto", + ) + # convert to OpenAI ChatCompletion format + response: ChatCompletion = convert_llm_proxy_response_to_oai_response(llm_output) + # this is an important id assignment + response.id = timeline_uuid + assert isinstance(response, ChatCompletion) + return response + + async def __call__( self, messages: List[dict], tools: List = [], tool_choice: str = "auto", - structured_model=None, **kwargs, ) -> ChatResponse: - llm_output = await self.run_infer(messages, tools, tool_choice, structured_model, **kwargs) + llm_output = await self.run_infer(messages, 
tools, tool_choice, **kwargs) return convert_llm_proxy_response_to_oai_response(llm_output) @@ -500,9 +316,8 @@ async def run_infer( messages: List[dict], tools: List = [], tool_choice: str = "auto", # always auto - structured_model=None, # this is for AgentScope only **kwargs, - ): + ) -> Dict: # generate timeline uuid timeline_uuid = uuid.uuid4().hex @@ -527,16 +342,10 @@ async def run_infer( # else: # otherwise, for abnormal output, can still proceed, but we do not track output anymore - # run llm inference ✨ - if self.config.ajet.task_runner.llm_infer_submit_method == "sync": - llm_output = await asyncio.to_thread( - self.llm_inference_fn, converted_message, custom_sampling_params, tools - ) - else: - llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools) - + # run llm inference ✨ (llm_chat_verl) + llm_output = await self.llm_inference_fn(converted_message, custom_sampling_params, tools) - # begin context tracking + # context tracking self.context_tracker.step_track(llm_output, context_safe, converted_message, tools, timeline_uuid=timeline_uuid) return llm_output @@ -554,7 +363,6 @@ def construct_overflow_response(self): - # ---------------------------------------------------------------------------------------------- # ------------------------ call async llm with context tracker (AgentScope) -------------------- # ---------------------------------------------------------------------------------------------- @@ -570,6 +378,6 @@ async def __call__( **kwargs, ) -> AgentScopeChatResponse: - llm_output = await self.run_infer(messages, tools, tool_choice, structured_model) + llm_output = await self.run_infer(messages, tools, tool_choice) response = convert_llm_proxy_response_to_agentscope_response(llm_output, structured_model=structured_model) return response diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 79f0bdcd..4afad524 100644 --- 
a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -9,6 +9,7 @@ import numpy as np import torch import threading +from math import ceil from loguru import logger from tensordict import TensorDict from torch.nn.utils.rnn import pad_sequence @@ -172,11 +173,11 @@ def rollout_swarm( # noqa: C901 Build a pool of threads to run context trackers in parallel, each thread re-spawn after complete, until reaching conditions to stop. """ - + # from ajet import bp; bp("SWARM") tracker_array: List[SingleAgentContextTracker] = [] rollout_n = self.rollout_n n_batch_task = len(tasks) - n_task = min(len(tasks), self.max_parallel // rollout_n) + n_task = min(len(tasks), ceil(self.max_parallel / rollout_n)) assert n_task > 0, f"n_task is not valid, n_task = min(len(tasks), self.max_parallel // rollout_n) = {n_task}" self.current_token_count_time = time.time() @@ -370,20 +371,20 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma self._write_swarm_rollout_dynamic_log(observation_window) meet_stop_condition_after_new_results = stop_condition(completed_task_id_map_ct) if meet_stop_condition_after_new_results: - print("Sending soft stop signal to all threads...") + logger.info("Sending soft stop signal to all threads...") stop_all_threads_soft() break # wait for all threads to complete - print('Finalizing all threads...') + logger.info('Finalizing all threads...') executor.shutdown(wait=True) # stop all threads hard - print("Sending hard stop signal to all threads...") + logger.info("Sending hard stop signal to all threads...") stop_all_threads_hard() # build tracker_array - print('Collecting results...') + logger.info('Collecting results...') for ct_list in completed_task_id_map_ct.values(): tracker_array.extend(ct_list) diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index 4791a47f..f8b51a66 100644 --- a/ajet/task_rollout/single_worker.py +++ 
b/ajet/task_rollout/single_worker.py @@ -90,14 +90,9 @@ def rollout_env_worker( """ sampling_params = get_sample_params(mode, self.config) - if self.config.ajet.task_runner.llm_infer_submit_method == "sync": - llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_sync( - sampling_params=sampling_params - ) - else: - llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async( - sampling_params=sampling_params - ) + llm_inference_fn = self.async_llm_bridge.get_llm_inference_fn_async( + sampling_params=sampling_params + ) episode_uuid = uuid.uuid4().hex workflow_task = WorkflowTask( diff --git a/ajet/tuner_lib/as_oai_baseurl_apikey.py b/ajet/tuner_lib/as_oai_baseurl_apikey.py index dcc5e549..0931e980 100644 --- a/ajet/tuner_lib/as_oai_baseurl_apikey.py +++ b/ajet/tuner_lib/as_oai_baseurl_apikey.py @@ -12,11 +12,13 @@ class MockAsyncCompletions(AsyncCompletions): async def create(self, *args, **kwargs) -> Any: # type: ignore return await self._client.create(*args, **kwargs) # type: ignore + class MockAsyncChat(AsyncChat): @property def completions(self) -> MockAsyncCompletions: # type: ignore return MockAsyncCompletions(self._client) + class OpenaiBaseUrlAndApiKey(BaseModel): """ At this layer, we will determine which model to use: - training model @@ -39,6 +41,7 @@ def as_raw_openai_sdk_client(self, *args, **kwargs): from openai import AsyncOpenAI return AsyncOpenAI(api_key=self.api_key, base_url=self.base_url) + class OpenaiClientBaseUrlTuner(BaseModel): """ At this layer, we will determine which model to use: - training model diff --git a/ajet/tuner_lib/experimental/as_oai_model_client.py b/ajet/tuner_lib/experimental/as_oai_model_client.py index 8288b609..616ab19a 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/experimental/as_oai_model_client.py @@ -5,19 +5,17 @@ import os import time import zmq -import base64 import json from loguru import logger from typing import TYPE_CHECKING -from 
openai.types.chat.chat_completion import ChatCompletion from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket -from ajet.tuner_lib.experimental.interchange_utils import DEBUG, API_KEY_PREFIX +from ajet.tuner_lib.experimental.interchange_utils import DEBUG if TYPE_CHECKING: - from vllm.entrypoints.openai.protocol import ChatCompletionRequest + pass context = zmq.Context() atexit.register(context.term) @@ -31,6 +29,7 @@ class InterchangeClient: """ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker", llm_inference_fn, config): + from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker self.episode_uuid = episode_uuid self.context_tracker = context_tracker self.llm_inference_fn = llm_inference_fn @@ -40,37 +39,12 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker self.ipc_path = ipc_path self.interchange_method = config.ajet.interchange_server.interchange_method self.max_inference_tracker_threads = config.ajet.interchange_server.max_inference_tracker_threads - - async def llm_infer( - self, - req: "ChatCompletionRequest", - timeline_uuid: str, - agent_name: str, - target_tag: str, - episode_uuid: str, - ) -> ChatCompletion: - from ajet.task_rollout.async_llm_bridge import OpenaiLlmProxyWithTracker - - req_as_dict = req.model_dump() self.llm_proxy_with_tracker = OpenaiLlmProxyWithTracker( context_tracker=self.context_tracker, config=self.config, llm_inference_fn=self.llm_inference_fn, ) - # infer + process with context tracker - response = await self.llm_proxy_with_tracker( - messages=req_as_dict["messages"], - tools=req_as_dict["tools"], - tool_choice="auto", - ) - - # this is an important id assignment - response.id = timeline_uuid - assert isinstance(response, 
ChatCompletion) - return response - - @property def should_soft_terminate(self) -> bool: if self._should_terminate: @@ -98,15 +72,15 @@ def begin_service(self): if DEBUG: logger.info(f"[client] {self.episode_uuid} | Starting InterchangeClient service loop...") self.socket = context.socket(zmq.REP) self.socket.bind(f"{self.episode_contect_address}") - self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 3 second timeout for REP + self.socket.setsockopt(zmq.RCVTIMEO, 1*1000) # 1 second timeout for REP self.executor = SharedInterchangeThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() if DEBUG: logger.info(f"[client] {self.episode_uuid} | Submitting _begin_service_threading to executor...") future = self.executor.submit(self._begin_service_threading) # wait till service begin running - time.sleep(0.5) wait_time = 1 + time.sleep(wait_time) while future._state == 'PENDING': if self.should_soft_terminate or self.should_hard_terminate: future.cancel() @@ -130,12 +104,15 @@ def _begin_service_threading(self): try: while not self.should_hard_terminate: - # listen for next request from remote try: - # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() has begun (should_terminate {self.should_terminate})") + + # : + # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : socket.send_string(int_req.model_dump_json()) + # : InterchangeCompletionRequest object in JSON string format message = self.socket.recv_string() + ever_receive_anything = True - # if DEBUG: logger.info(f"[client] {self.episode_uuid} | socket.recv_string() is done") except zmq.Again as e: if self.should_hard_terminate: # abort_episode() @@ -162,7 +139,7 @@ def _begin_service_threading(self): future = loop.run_in_executor( context_tracker_executor, asyncio.run, - self.llm_infer( + self.llm_proxy_with_tracker.chat_completion_request( req=parsed_msg.completion_request, timeline_uuid=parsed_msg.timeline_uuid, agent_name=parsed_msg.agent_name, @@ -172,9 +149,13 @@ def 
_begin_service_threading(self): ) result = loop.run_until_complete(future).model_dump_json() # type: ignore - # great, let's send back the result if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)") + + # + # : ajet/tuner_lib/experimental/as_oai_model_server.py + # : result_str = socket.recv_string() self.socket.send_string(result) + if DEBUG: logger.info(f"[client] {self.episode_uuid} | after send_string (send llm call result)") except: logger.exception(f"[client] {self.episode_uuid} | Exception occurred in service loop.") diff --git a/ajet/tuner_lib/experimental/as_oai_model_server.py b/ajet/tuner_lib/experimental/as_oai_model_server.py index 301483ff..51bf9bb1 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/experimental/as_oai_model_server.py @@ -25,6 +25,7 @@ from loguru import logger from pydantic import BaseModel from fastapi import FastAPI, Header, HTTPException, Request +from fastapi.responses import StreamingResponse from contextlib import asynccontextmanager from multiprocessing import Manager, Process from concurrent.futures import ThreadPoolExecutor @@ -32,6 +33,9 @@ from vllm.entrypoints.openai.protocol import ChatCompletionRequest from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction from ajet.utils.networking import get_host_ip from ajet.tuner_lib.experimental.interchange_utils import EpisodeStatus @@ -44,6 +48,7 @@ class InterchangeCompletionRequest(BaseModel): target_tag: str episode_uuid: str timeline_uuid: str + preserve_sampling_params: bool = False class HealthCheckRequest(BaseModel): agent_name: str @@ -85,14 +90,25 @@ def _begin_handle_chat_completion(episode_address, int_req: 
InterchangeCompletio socket.setsockopt(zmq.RCVTIMEO, 6*1000) # 6 second recv timeout socket.connect(f"{episode_address}") if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | connect done") + + # + # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : message = self.socket.recv_string() socket.send_string(int_req.model_dump_json()) + if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | send_string") result_str = "" for _ in range(50): # max 5 minutes wait try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") + + # : + # : ajet/tuner_lib/experimental/as_oai_model_client.py + # : self.socket.send_string(result) + # : ChatCompletion object in JSON string format result_str = socket.recv_string() + break except zmq.Again as e: # check whether server is still in rolling status @@ -112,6 +128,89 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio return result_object + async def mock_as_stream_response(result: ChatCompletion): + """ + Convert a non-streaming ChatCompletion result to streaming format. 
+ + Args: + result: ChatCompletion object to convert to streaming format + + Yields: + Server-sent events formatted as streaming chat completion chunks + """ + content = result.choices[0].message.content if result.choices else "" + role = result.choices[0].message.role if result.choices else "assistant" + # try: + # thinking = result.choices[0].message.reasoning_content + # except: + # thinking = None + tool_calls = result.choices[0].message.tool_calls if result.choices and result.choices[0].message.tool_calls else None + delta_tool_calls = [] # tool_calls: Optional[List[ChoiceDeltaToolCall]] = None + finish_reason = result.choices[0].finish_reason + if tool_calls: + delta_tool_calls = [ChoiceDeltaToolCall( + index=index, + id=tc.id, + function=ChoiceDeltaToolCallFunction( + name = tc.function.name, + arguments = tc.function.arguments, + ), + type=tc.type + ) for index, tc in enumerate(tool_calls)] + + # First chunk with role + first_chunk = ChatCompletionChunk( + id=result.id, + model=result.model, + created=result.created, + object="chat.completion.chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta(role=role, content=""), + finish_reason=None + ) + ] + ) + dat = f"data: {first_chunk.model_dump_json()}\n\n" + yield dat + + # Content chunk + content_chunk = ChatCompletionChunk( + id=result.id, + model=result.model, + created=result.created, + object="chat.completion.chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta(role=role, content=content, tool_calls=delta_tool_calls), + finish_reason=None + ) + ] + ) + dat = f"data: {content_chunk.model_dump_json()}\n\n" + yield dat + + # Final chunk with finish_reason + final_chunk = ChatCompletionChunk( + id=result.id, + model=result.model, + created=result.created, + object="chat.completion.chunk", + choices=[ + ChunkChoice( + index=0, + delta=ChoiceDelta(), + finish_reason=finish_reason + ) + ] + ) + dat = f"data: {final_chunk.model_dump_json()}\n\n" + yield dat + yield "data: [DONE]\n\n" + + 
@app.get("/health") async def health(): return {"status": "ok"} @@ -149,12 +248,13 @@ async def chat_completions(request: Request, authorization: str = Header(None)): # Parse request body body = await request.json() new_req = ChatCompletionRequest.model_validate(body) - if new_req.stream: - return HTTPException(status_code=400, detail="Streaming responses not supported in current AgentJet version, please set `stream=false` for now.") # Create timeline UUID timeline_uuid = uuid.uuid4().hex + # if training, ignore all sampling parameters from request + preserve_sampling_params = False + # enable_swarm_mode if enable_swarm_mode: from ajet.tuner_lib.experimental.as_swarm_server import ep_key @@ -174,6 +274,14 @@ async def chat_completions(request: Request, authorization: str = Header(None)): es.latest_activity_timestamp = time.time() es.llm_call_count += 1 shared_mem_dict[ep_key(episode_uuid)] = es + if es.episode_type == "eval": + preserve_sampling_params = True + + # For streaming, we process as non-streaming but return in streaming format + original_stream = new_req.stream + if original_stream: + new_req.stream = False + new_req.stream_options = None # Add to received queue int_req = InterchangeCompletionRequest( @@ -182,10 +290,16 @@ async def chat_completions(request: Request, authorization: str = Header(None)): target_tag = target_tag, episode_uuid = episode_uuid, timeline_uuid = timeline_uuid, + preserve_sampling_params = preserve_sampling_params, ) if DEBUG: logger.info(f"episode_uuid: {episode_uuid} | Received new chat completion request (outside thread)") loop = asyncio.get_running_loop() - return await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) + result = await loop.run_in_executor(request.app.state.executor, _begin_handle_chat_completion, episode_address, int_req, episode_uuid) + + if original_stream: + return StreamingResponse(mock_as_stream_response(result), 
media_type="text/event-stream") + + return result if enable_swarm_mode: diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/as_swarm_client.py index 61d22941..6b3f9168 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/as_swarm_client.py @@ -212,7 +212,7 @@ def begin_episode(self, discard_episode_timeout=240, episode_type="train", throt def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]: # max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by client (call `end_episode` will be re-route to `abort_episode`) - max_episode_time = 2*discard_episode_timeout + max_episode_time = 8*discard_episode_timeout status, status_json = self.get_engine_status() # warm up connection and log the status if status not in ["ENGINE.ROLLING"]: @@ -318,13 +318,13 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut if episode_uuid in self.record_episode_expire_time: remain_time = self.record_episode_expire_time.pop(episode_uuid, 0) - time.time() if remain_time < 0: - logger.warning(f"Episode {episode_uuid} has expired (expired {-remain_time} seconds ago). Please use a larger `discard_episode_timeout` and `max_episode_time` when `begin_episode`. Skipping end_episode.") + logger.warning(f"Episode {episode_uuid} has expired (expired {-remain_time} seconds ago). Please use a larger `discard_episode_timeout` when `begin_episode`. Skipping end_episode.") # send abort signal to server to clean up episode self.abort_episode(episode_uuid) return else: # send abort signal to server to clean up episode - logger.warning(f"Episode {episode_uuid} has expired (expired at least {CLEAN_RECORD_TIMEOUT} seconds ago). Please use a larger `discard_episode_timeout` and `max_episode_time` when `begin_episode`. 
Skipping end_episode.") + logger.warning(f"Episode {episode_uuid} has expired (expired at least {CLEAN_RECORD_TIMEOUT} seconds ago). Please use a larger `discard_episode_timeout` when `begin_episode`. Skipping end_episode.") self.abort_episode(episode_uuid) return @@ -459,7 +459,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= if current_time - init_poll_time >= timeout: raise TimeoutError(f"Timeout reached while waiting for engine status to change to {desired_status}") - if (initial_status == "ENGINE.OFFLINE") and (current_status == "ENGINE.OFFLINE"): + if (initial_status == "ENGINE.OFFLINE") and (current_status == "ENGINE.OFFLINE") and (desired_status!="ENGINE.OFFLINE"): raise SwarmServerOfflineError(f"Engine status changed from {initial_status} to OFFLINE while waiting for {desired_status}. This may indicate an error in the engine. Please check the swarm server logs for details.") # Report status every 5 seconds diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/as_swarm_server.py index eb0a9503..e29cfbfa 100644 --- a/ajet/tuner_lib/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/experimental/as_swarm_server.py @@ -334,12 +334,9 @@ async def start_engine(): yaml_str = shared_mem_dict["train_config_yaml"] config_dict = yaml_module.safe_load(yaml_str) backbone = config_dict.get("ajet", {}).get("backbone", "verl") - DEFAULT_DIR = "saved_experiments" - exp_dir_final = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR) - if exp_dir_final != DEFAULT_DIR: - # remove last dir level if possible - exp_dir_final = os.path.dirname(exp_dir_final) - + exp_base_dir = os.path.dirname( + config_dict.get("ajet", {}).get("experiment_dir", "saved_experiments") + ) # Save YAML to temporary file with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as temp_file: @@ -351,7 +348,6 @@ async def start_engine(): args = SimpleNamespace( conf=main_yaml_fp, backbone=backbone, - 
exp_dir=exp_dir_final, with_logview=False, debug=False, ) @@ -367,7 +363,12 @@ def override_param_callback(config): return config # Finalize experiment config - main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(main_yaml_fp, exp_dir_final, backbone, override_param_callback) + main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config( + yaml_path=main_yaml_fp, + exp_base_dir=exp_base_dir, + backbone=backbone, + override_param_callback=override_param_callback, + ) # Setup environment variables env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp) diff --git a/ajet/utils/config_utils.py b/ajet/utils/config_utils.py index 8a6a301e..13c4d693 100644 --- a/ajet/utils/config_utils.py +++ b/ajet/utils/config_utils.py @@ -171,7 +171,7 @@ def config_safe_guard(config: dict, backbone: str) -> dict: def read_ajet_hierarchical_config( - yaml_fp, exp_name=None, backbone=None, write_to=None, exp_dir=None, override_param_callback=None + yaml_fp, experiment_name=None, backbone=None, write_to=None, experiment_dir=None, override_param_callback=None ): if yaml_fp is None: config = { @@ -193,10 +193,10 @@ def read_ajet_hierarchical_config( else: with open(yaml_fp, "r", encoding="utf-8") as file: config = yaml.safe_load(file) - if exp_name is not None: - config["ajet"]["experiment_name"] = exp_name - if (exp_dir is not None) and (exp_name is not None): - config["ajet"]["experiment_dir"] = os.path.join(exp_dir, exp_name) + if experiment_name is not None: + config["ajet"]["experiment_name"] = experiment_name + if (experiment_dir is not None): + config["ajet"]["experiment_dir"] = experiment_dir if backbone is not None: config["ajet"]["backbone"] = backbone @@ -248,14 +248,14 @@ def expand_ajet_hierarchical_config(config, write_to=None): return config_final -def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callback=None, storage=True): +def prepare_experiment_config(yaml_path, exp_base_dir, backbone, 
override_param_callback=None, storage=True): """ Prepare experiment configuration by reading YAML, setting up backup directories, and copying necessary files for the experiment. Args: yaml_path: Path to the YAML configuration file - exp_dir: Directory where experiment artifacts and backups should be stored + exp_base_dir: Directory where experiment artifacts and backups should be stored backbone: Backbone identifier that controls config munging Returns: @@ -284,8 +284,8 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callb else: exp_name = exp_name.replace("|", "-") - backup_dir = os.path.abspath(os.path.join(exp_dir, exp_name, "backup")) - yaml_backup_dst = os.path.join(exp_dir, exp_name, "yaml_backup.yaml") + backup_dir = os.path.abspath(os.path.join(exp_base_dir, exp_name, "backup")) + yaml_backup_dst = os.path.join(exp_base_dir, exp_name, "yaml_backup.yaml") yaml_backup_dst = os.path.abspath(yaml_backup_dst) exe_exp_base = os.path.dirname(yaml_backup_dst) @@ -326,12 +326,18 @@ def prepare_experiment_config(yaml_path, exp_dir, backbone, override_param_callb shutil.copyfile(yaml_backup_src, yaml_backup_dst) ## 4. 
edit new yaml + experiment_dir = f"{exp_base_dir}/{exp_name}" config = read_ajet_hierarchical_config( - yaml_backup_dst, exp_name=exp_name, backbone=backbone, write_to=yaml_backup_dst, exp_dir=exp_dir, override_param_callback=override_param_callback + yaml_backup_dst, + experiment_name=exp_name, + backbone=backbone, + write_to=yaml_backup_dst, + experiment_dir=experiment_dir, + override_param_callback=override_param_callback ) config_final = expand_ajet_hierarchical_config(config, write_to=yaml_backup_dst) if not storage: - shutil.rmtree(os.path.join(exp_dir, exp_name)) + shutil.rmtree(os.path.join(exp_base_dir, exp_name)) return yaml_backup_dst, exe_exp_base, exp_name, config_final diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py index da234a8f..fd97fa41 100644 --- a/ajet/utils/swarm_overwatch.py +++ b/ajet/utils/swarm_overwatch.py @@ -23,13 +23,13 @@ class SwarmOverwatch: """Real-time monitoring interface for swarm rollout pool""" - def __init__(self, server_url: str, refresh_interval: float = 1.0): + def __init__(self, server_url: str, refresh_interval: float = 2.0): """ Initialize the overwatch monitor Args: server_url: Base URL of the swarm server (e.g., http://localhost:10086) - refresh_interval: Refresh interval in seconds (default: 1.0) + refresh_interval: Refresh interval in seconds (default: 2.0) """ self.server_url = server_url.rstrip("/") self.refresh_interval = refresh_interval @@ -480,13 +480,13 @@ def run(self): ) -def start_overwatch(server_url: str, refresh_interval: float = 1.0): +def start_overwatch(server_url: str, refresh_interval: float = 2.0): """ Start the swarm overwatch monitoring interface Args: server_url: Base URL of the swarm server - refresh_interval: Refresh interval in seconds (default: 1.0) + refresh_interval: Refresh interval in seconds (default: 2.0) """ overwatch = SwarmOverwatch(server_url, refresh_interval) overwatch.run() diff --git a/ajet/utils/thread_executors.py b/ajet/utils/thread_executors.py 
index 8702e00c..87ea0744 100644 --- a/ajet/utils/thread_executors.py +++ b/ajet/utils/thread_executors.py @@ -45,12 +45,17 @@ def shutdown(self, wait=True): class PeriodicDrainThreadPoolExecutor: """A ThreadPoolExecutor that bounds the number of pending tasks via a semaphore.""" - def __init__(self, workers=100, auto_retry=True): + def __init__(self, workers=100, max_parallel=None, auto_retry=True, block_first_run=False): self._max_workers = workers - self._executor = ThreadPoolExecutor(max_workers=workers) + if max_parallel is None: + self._max_parallel = workers + else: + self._max_parallel = max_parallel + self._executor = ThreadPoolExecutor(max_workers=self._max_parallel) self._submitted_count = 0 self._auto_retry = auto_retry self.current_futures = [] + self._slow_first_run = block_first_run def submit(self, fn, *args, **kwargs): """Submit a task, blocking if the pending queue is full.""" @@ -63,9 +68,15 @@ def retry_wrapper(fn, *args, **kwargs): logger.exception(f"[run_episodes_until_all_complete] Error executing episode: {e}. 
Retrying...") if self._auto_retry: - return self._executor.submit(retry_wrapper, fn, *args, **kwargs) + future = self._executor.submit(retry_wrapper, fn, *args, **kwargs) else: - return self._executor.submit(fn, *args, **kwargs) + future = self._executor.submit(fn, *args, **kwargs) + + if self._slow_first_run: + self._slow_first_run = False + future.result() # Wait for the first run to complete before allowing more tasks to be submitted + + return future def submit_with_periodic_drain(self, fn, *args, **kwargs): """Submit a task, draining all in-flight work every `drain_every_n_job` submissions.""" @@ -86,4 +97,4 @@ def submit_with_periodic_drain(self, fn, *args, **kwargs): def shutdown(self, wait=True): """Shut down the underlying executor.""" - self._executor.shutdown(wait=wait) \ No newline at end of file + self._executor.shutdown(wait=wait) diff --git a/math_gsm8k_grpo/yaml_backup.yaml b/math_gsm8k_grpo/yaml_backup.yaml deleted file mode 100644 index 69a60a3c..00000000 --- a/math_gsm8k_grpo/yaml_backup.yaml +++ /dev/null @@ -1,559 +0,0 @@ -actor_rollout_ref: - actor: - _target_: verl.workers.config.FSDPActorConfig - checkpoint: - _target_: verl.trainer.config.CheckpointConfig - async_save: false - load_contents: - - model - - optimizer - - extra - save_contents: - - model - - optimizer - - extra - clip_ratio: 0.2 - clip_ratio_c: 3.0 - clip_ratio_high: 0.2 - clip_ratio_low: 0.2 - entropy_checkpointing: false - entropy_coeff: 0 - entropy_from_logits_with_chunking: false - fsdp_config: - optimizer_offload: true - param_offload: true - grad_clip: 1.0 - kl_loss_coef: 0.002 - kl_loss_type: low_var_kl - loss_agg_mode: seq-mean-token-mean - optim: - lr: 1.0e-06 - override_ppo_mini_batch_num: 1 - policy_loss: - _target_: verl.workers.config.PolicyLossConfig - clip_cov_lb: 1.0 - clip_cov_ratio: 0.0002 - clip_cov_ub: 5.0 - kl_cov_ratio: 0.0002 - loss_mode: vanilla - ppo_kl_coef: 0.1 - ppo_epochs: 1 - ppo_max_token_len_per_gpu: 18000 - ppo_micro_batch_size: null - 
ppo_micro_batch_size_per_gpu: 1 - ppo_mini_batch_size: 16 - shuffle: false - strategy: fsdp - ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true - use_fused_kernels: false - use_kl_loss: true - use_remove_padding: true - use_torch_compile: true - hybrid_engine: true - model: - custom_chat_template: null - enable_activation_offload: false - enable_gradient_checkpointing: true - exclude_modules: null - external_lib: null - fused_kernel_options: - impl_backend: torch - lora_alpha: 16 - lora_rank: 0 - override_config: {} - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct - target_modules: all-linear - trust_remote_code: false - use_fused_kernels: false - use_liger: false - use_remove_padding: true - use_shm: false - nccl_timeout: 600 - profiler: - _target_: verl.utils.profiler.ProfilerConfig - all_ranks: false - discrete: false - ranks: [] - ref: - entropy_checkpointing: false - entropy_from_logits_with_chunking: false - fsdp_config: - _target_: verl.workers.config.FSDPEngineConfig - forward_prefetch: false - param_offload: true - reshard_after_forward: true - wrap_policy: - min_num_params: 0 - log_prob_max_token_len_per_gpu: 18000 - log_prob_micro_batch_size: null - log_prob_micro_batch_size_per_gpu: 4 - log_prob_use_dynamic_bsz: true - model: null - strategy: fsdp - ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true - use_torch_compile: true - rollout: - agent: - agent_loop_config_path: null - custom_async_server: - name: null - path: null - num_workers: 8 - calculate_log_probs: false - cudagraph_capture_sizes: null - custom_dataflow_cls: - name: '' - path: '' - disable_log_stats: true - do_sample: true - dtype: bfloat16 - enable_chunked_prefill: true - enforce_eager: true - engine_kwargs: - sglang: - attention_backend: null - vllm: - disable_mm_preprocessor_cache: false - swap_space: null - free_cache_engine: true - gamma: 1.0 - gpu_memory_utilization: 0.9 - ignore_eos: false - layered_summon: false - load_format: 
dummy_dtensor - log_prob_max_token_len_per_gpu: 18000 - log_prob_micro_batch_size: null - log_prob_micro_batch_size_per_gpu: 4 - log_prob_use_dynamic_bsz: true - max_env_worker: 64 - max_model_len: 18000 - max_num_batched_tokens: 8192 - max_num_seqs: 10 - mode: async - multi_stage_wake_up: false - multi_turn: - expected_steps: 1 - max_sample_per_task: 30 - max_steps: 30 - n: 1 - name: vllm - ppo_micro_batch_size_per_gpu: 1 - prompt_length: 3000 - response_length: 10000 - skip_dump_dir: /tmp/rollout_dump - skip_rollout: false - temperature: 0.9 - tensor_model_parallel_size: 1 - top_k: -1 - top_p: 1.0 - trace: - backend: null - token2text: false - update_weights_bucket_megabytes: 512 - val_kwargs: - do_sample: false - num_repeat: 1 - temperature: 0.0 - top_k: -1 - top_p: 1.0 -ajet: - backbone: verl - context_tracker: - alien_llm_model: qwen3-235b-a22b-instruct-2507 - alien_llm_response_length: 512 - detect_timeline_snap: false - fix_retokenization_drift: true - log_tool_format_check: false - log_tool_format_error_detail: false - timeline_merging_policy: - ignore_tools: true - timeline_compare_level: text - data: - max_prompt_length: 3000 - max_response_length: 15000 - train_batch_size: 32 - debug: - debug_first_n_tasks: 2 - debug_max_parallel: 4 - debug_tensor_parallel_size: 4 - debug_vllm_port: 18000 - debug_vllm_seed: 12345 - enable_interchange_server: true - enable_swarm_mode: true - execute_test: false - execute_testing_lambda: '' - experiment_dir: math_gsm8k_grpo - experiment_name: math_gsm8k_grpo - interchange_server: - already_started: true - interchange_method: ipc - interchange_server_port: 10086 - max_fastapi_threads: 512 - max_inference_tracker_threads: 64 - num_fastapi_process: 2 - model: - path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct - project_name: ajet_default_project - rollout: - agent_madness_reward: -1.0 - agent_madness_termination: true - compute_madness_checklist: - - nonsense - force_disable_toolcalls: false - 
gamma: 1.0 - max_env_worker: 128 - max_model_len: 18000 - max_num_seqs: 10 - max_response_length_in_one_turn: 4096 - multi_turn: - expected_steps: 1 - max_sample_per_task: 30 - max_steps: 30 - n_vllm_engine: 1 - name: vllm - num_repeat: 4 - step_skip_action: 0 - submit_oversample_multiplier: 1.5 - temperature: 0.9 - tensor_model_parallel_size: 1 - top_p: 1.0 - user_workflow: tutorial.example_appworld.appworld->ExampleAgentScopeWorkflow - val_kwargs: - do_sample: false - num_repeat: 1 - temperature: 0.0 - top_k: -1 - top_p: 1.0 - swarm_mode_sample_collection_max_cached_episodes: 9999 - swarm_mode_sample_collection_method: rollout_until_finish_enough_tasks - task_judge: - alien_llm_model: qwen3-235b-a22b-instruct-2507 - alien_llm_response_length: 512 - judge_protocol: null - judge_type: customized_protocol - rubrics_auto_grader: - answer_field: final_answer - categories_number: 5 - custom_evaluation_prompt: null - enable_categorization: false - grader_mode: pointwise - grader_name: auto_grader - input_data_type: jsonl_dataset_file - jsonl_dataset_file: - training: - file_path: tutorial/example_rm_auto_grader/rubrics_train.jsonl - language: en - max_score: 1 - min_score: 0 - model_name: qwen-max - query_field: main_query - query_specific_generate_number: 1 - reference_field: answer - task_reader: - data_generation: - deduplication_filter: - enabled: true - params: - api_key: null - base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 - db_path: ./.similarity_db - model: text-embedding-v4 - similarity_threshold: 0.8 - document_reader: - cache_enabled: true - chunk_size: 5120 - document_path: - - dataset/document/your-document1.pdf - - dataset/document/your-document2.pdf - languages: - - eng - split_by: sentence - llm_model: qwen-long - llm_response_length: 8192 - num_workers: 32 - query_reader: - jsonl_dataset_file: - training: - file_path: dataset/jsonl/your-queries.jsonl - type: jsonl_dataset_file - sampling_params: - temperature: 0 - task_num: 10 - 
env_service: - env_action_preference: code - env_type: appworld - env_url: http://127.0.0.1:8080 - training_split: train - validation_split: dev - huggingface_dat_repo: - dataset_name: null - dataset_path: gsm8k - http_proxy_address: '' - training_split: train - validation_split: validation - jsonl_dataset_file: - training: - file_path: /path/to/training/data.jsonl - validation: - file_path: /path/to/validation/data.jsonl - type: random_dummy - task_runner: - llm_infer_submit_method: async - wrapper_multiprocessing_timeout: 3600 - wrapper_type: asyncio-with-gc - trainer_common: - algorithm: - adv_estimator: grpo - use_kl_in_reward: false - checkpoint_base_dir: ./saved_checkpoints - fsdp_config: - optimizer_offload: true - param_offload: true - kl_loss_coef: 0.002 - kl_loss_type: low_var_kl - logger: tensorboard - mini_batch_num: 1 - n_gpus_per_node: 8 - nnodes: 1 - optim: - lr: 1.0e-06 - save_freq: 20 - save_trajectory_as_json_file: false - test_freq: 20 - total_epochs: 50 - ulysses_sequence_parallel_size: 1 - use_kl_loss: true - val_before_train: false - val_pass_n: 4 -algorithm: - _target_: verl.trainer.config.AlgoConfig - adv_estimator: grpo - gamma: 1.0 - kl_ctrl: - _target_: verl.trainer.config.KLControlConfig - horizon: 10000 - kl_coef: 0.001 - target_kl: 0.1 - type: fixed - kl_penalty: kl - lam: 1.0 - norm_adv_by_std_in_grpo: true - pf_ppo: - reweight_method: pow - weight_pow: 2.0 - use_kl_in_reward: false - use_pf_ppo: false -critic: - _target_: verl.workers.config.FSDPCriticConfig - checkpoint: - _target_: verl.trainer.config.CheckpointConfig - async_save: false - load_contents: - - model - - optimizer - - extra - save_contents: - - model - - optimizer - - extra - cliprange_value: 0.5 - enable: false - forward_max_token_len_per_gpu: 32768 - forward_micro_batch_size: null - forward_micro_batch_size_per_gpu: null - grad_clip: 1.0 - loss_agg_mode: seq-mean-token-mean - model: - _target_: verl.workers.config.FSDPCriticModelCfg - enable_activation_offload: 
false - enable_gradient_checkpointing: true - external_lib: null - fsdp_config: - _target_: verl.workers.config.FSDPEngineConfig - forward_prefetch: false - fsdp_size: -1 - offload_policy: false - optimizer_offload: false - param_offload: false - reshard_after_forward: true - wrap_policy: - min_num_params: 0 - lora_alpha: 16 - lora_rank: 0 - override_config: {} - path: ~/models/deepseek-llm-7b-chat - target_modules: all-linear - tokenizer_path: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct - trust_remote_code: false - use_remove_padding: false - use_shm: false - optim: - _target_: verl.workers.config.FSDPOptimizerConfig - lr: 1.0e-05 - lr_warmup_steps: -1 - lr_warmup_steps_ratio: 0.0 - min_lr_ratio: null - total_training_steps: -1 - warmup_style: constant - weight_decay: 0.01 - ppo_epochs: 1 - ppo_max_token_len_per_gpu: 32768 - ppo_micro_batch_size: null - ppo_micro_batch_size_per_gpu: null - ppo_mini_batch_size: 16 - profiler: - _target_: verl.utils.profiler.ProfilerConfig - all_ranks: false - discrete: false - ranks: [] - rollout_n: 1 - shuffle: false - strategy: fsdp - ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true -custom_reward_function: - name: compute_score - path: null -data: - custom_cls: - name: null - path: null - datagen: - name: null - path: null - dataloader_num_workers: 8 - fast_eval: true - filter_overlong_prompts: true - filter_overlong_prompts_workers: 1 - image_key: images - max_prompt_length: 3000 - max_response_length: 15000 - prompt_key: prompt - return_full_prompt: false - return_multi_modal_inputs: true - return_raw_chat: true - return_raw_input_ids: false - reward_fn_key: data_source - sampler: - class_name: null - class_path: null - seed: 42 - shuffle: true - tokenizer: null - train_batch_size: 32 - train_files: ~/data/rlhf/gsm8k/train.parquet - truncation: error - trust_remote_code: false - use_shm: false - val_batch_size: 100000000000 - val_files: ~/data/rlhf/gsm8k/test.parquet - 
validation_shuffle: false - video_key: videos -defaults: -- verl_default -- ajet_default -- _self_ -hydra: - searchpath: - - file://ajet/default_config - - file://ajet/default_config/verl -ray_init: - num_cpus: null - timeline_json_file: null -reward_model: - enable: false - forward_max_token_len_per_gpu: 32768 - launch_reward_fn_async: false - max_length: null - micro_batch_size: null - micro_batch_size_per_gpu: null - model: - external_lib: null - fsdp_config: - _target_: verl.workers.config.FSDPEngineConfig - forward_prefetch: false - fsdp_size: -1 - param_offload: false - reshard_after_forward: true - wrap_policy: - min_num_params: 0 - input_tokenizer: /mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2___5-14B-Instruct - path: ~/models/FsfairX-LLaMA3-RM-v0.1 - trust_remote_code: false - use_fused_kernels: false - use_remove_padding: false - use_shm: false - profiler: - _target_: verl.utils.profiler.ProfilerConfig - all_ranks: false - discrete: false - ranks: [] - reward_manager: naive - sandbox_fusion: - max_concurrent: 64 - memory_limit_mb: 1024 - url: null - strategy: fsdp - ulysses_sequence_parallel_size: 1 - use_dynamic_bsz: true -trainer: - balance_batch: true - checkpoint_base_dir: ./saved_checkpoints - controller_nsight_options: - cuda-graph-trace: graph - cuda-memory-usage: 'true' - trace: cuda,nvtx,cublas,ucx - critic_warmup: 0 - default_hdfs_dir: null - default_local_dir: ${trainer.checkpoint_base_dir}/${trainer.project_name}/${trainer.experiment_name} - del_local_ckpt_after_load: false - device: cuda - esi_redundant_time: 0 - experiment_name: math_gsm8k_grpo - hfmodelpath: '' - log_val_generations: 0 - logger: - - console - - tensorboard - max_actor_ckpt_to_keep: null - max_critic_ckpt_to_keep: null - n_gpus_per_node: 8 - nnodes: 1 - npu_profile: - options: - analysis: true - level: level1 - record_shapes: false - roles: - - all - save_path: ./profiler_data - with_cpu: true - with_memory: false - with_module: false - with_npu: true - 
with_stack: false - profile_continuous_steps: false - profile_steps: null - project_name: ajet_default_project - ray_wait_register_center_timeout: 300 - resume_from_path: null - resume_mode: auto - rollout_data_dir: null - save_freq: 20 - test_freq: 20 - total_epochs: 50 - total_training_steps: null - use_legacy_worker_impl: auto - val_before_train: false - val_only: false - val_pass_n: 4 - validation_data_dir: null - worker_nsight_options: - capture-range: cudaProfilerApi - capture-range-end: null - cuda-graph-trace: graph - cuda-memory-usage: 'true' - kill: none - trace: cuda,nvtx,cublas,ucx diff --git a/tutorial/example_math_swarm/math.py b/tutorial/example_math_swarm/math.py index 3c9ba92f..c1351a9d 100644 --- a/tutorial/example_math_swarm/math.py +++ b/tutorial/example_math_swarm/math.py @@ -28,7 +28,8 @@ def main(): reader_type = "huggingface_dat_repo", reader_config = AjetTaskReader( huggingface_dat_repo = HuggingfaceDatRepo( - dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic", + dataset_path = '/mnt/data_cpfs/model_cache/modelscope/dataset/openai/gsm8k/main', + # dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic", # dataset_path = "openai/gsm8k", # dataset_name = "main", ) diff --git a/tutorial/example_werewolves/game.py b/tutorial/example_werewolves/game.py index 10246c32..8eca099e 100644 --- a/tutorial/example_werewolves/game.py +++ b/tutorial/example_werewolves/game.py @@ -44,7 +44,7 @@ async def hunter_stage( global moderator msg_hunter = await hunter_agent( await moderator(Prompts.to_hunter.format(name=hunter_agent.name)), - structured_model=get_hunter_model(players.current_alive), + structured_model=get_hunter_model(players.all_players), ) if msg_hunter.metadata.get("shoot"): return msg_hunter.metadata.get("name", None) @@ -134,7 +134,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 msgs_vote = await fanout_pipeline( players.werewolves, msg=await 
moderator(content=Prompts.to_wolves_vote), - structured_model=get_vote_model(players.current_alive), + structured_model=get_vote_model(players.all_players), enable_gather=False, ) killed_player, votes = majority_vote( @@ -187,7 +187,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 ), ), structured_model=get_poison_model( - players.current_alive, + players.all_players, ), ) if msg_witch_poison.metadata.get("poison"): @@ -206,7 +206,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 names_to_str(players.current_alive), ), ), - structured_model=get_seer_model(players.current_alive), + structured_model=get_seer_model(players.all_players), ) if msg_seer.metadata.get("name"): player = msg_seer.metadata["name"] @@ -282,7 +282,7 @@ async def werewolves_game(agents: list[ReActAgent], roles) -> bool: # noqa: C90 names_to_str(players.current_alive), ), ), - structured_model=get_vote_model(players.current_alive), + structured_model=get_vote_model(players.all_players), enable_gather=False, ) voted_player, votes = majority_vote( diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py index e7b920bd..2e5fa672 100644 --- a/tutorial/example_werewolves_swarm/agent_roll.py +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -63,7 +63,7 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): game = ExampleWerewolves( trainable_targets=["werewolf"], big_external_opponent_llm_name="Qwen/Qwen3-235B-A22B-Instruct-2507", - big_external_opponent_llm_url="http://22.16.90.187/v1", + big_external_opponent_llm_url="http://22.14.116.243/v1", ) res = asyncio.run(game.execute(task, api_baseurl_key)) return res diff --git a/tutorial/example_werewolves_swarm/werewolves.yaml b/tutorial/example_werewolves_swarm/werewolves.yaml index fae6ac50..e096e8f6 100644 --- a/tutorial/example_werewolves_swarm/werewolves.yaml +++ 
b/tutorial/example_werewolves_swarm/werewolves.yaml @@ -2,8 +2,6 @@ ajet: project_name: example_werewolves_swarm experiment_dir: "auto" # {exp-dir}/{experiment_name} - task_reader: - type: random_dummy # ✨ model: # ✨ select model to be trained @@ -17,7 +15,7 @@ ajet: num_repeat: 6 agent_madness_reward: 0.0 tensor_model_parallel_size: 1 - max_num_seqs: 40 + # max_num_seqs: 40 # monitor LLM's abormal behaviors during rollout compute_madness_checklist: - "nonsense" From 621d235b9bcf54b018dce7d5e3d8c5754bca8e52 Mon Sep 17 00:00:00 2001 From: Qingxu Fu Date: Tue, 3 Mar 2026 17:15:10 +0800 Subject: [PATCH 05/11] bug patch --- tutorial/example_werewolves/start.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tutorial/example_werewolves/start.py b/tutorial/example_werewolves/start.py index e06241ec..14ed14ba 100644 --- a/tutorial/example_werewolves/start.py +++ b/tutorial/example_werewolves/start.py @@ -105,18 +105,19 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl # initialize agents players = [] for i, role in enumerate(roles): - default_model = OpenAIChatModel( - stream=False, - api_key="no_api_key", - generate_kwargs={"temperature": 0.01}, - model_name=self.big_external_opponent_llm_name, - client_args={"base_url": self.big_external_opponent_llm_url}, - ) - model_for_this_agent = tuner.as_agentscope_model( - agent_name=f"Player{i + 1}", # the name of this agent - target_tag=role, # `target_tag in self.trainable_targets` means we train this agent, otherwise we do not train this agent. 
- debug_model=default_model, # the model used when this agent is not in `self.trainable_targets` - ) + if role not in self.trainable_targets: + model_for_this_agent = OpenAIChatModel( + stream=False, + api_key="no_api_key", + generate_kwargs={"temperature": 0.01}, + model_name=self.big_external_opponent_llm_name, + client_args={"base_url": self.big_external_opponent_llm_url}, + ) + else: + model_for_this_agent = tuner.as_agentscope_model( + agent_name=f"Player{i + 1}", + target_tag=role, + ) agent = ReActAgent( name=f"Player{i + 1}", sys_prompt=get_official_agent_prompt(f"Player{i + 1}"), @@ -124,7 +125,7 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl formatter=DashScopeMultiAgentFormatter() if isinstance(model_for_this_agent, DashScopeChatModel) else OpenAIMultiAgentFormatter(), max_iters=3 if role in self.trainable_targets else 5, ) - # agent.set_console_output_enabled(False) + agent.set_console_output_enabled(False) players += [agent] # reward condition From b150420b5178b9f41092b655ce80222858b0d830 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 3 Mar 2026 18:04:14 +0800 Subject: [PATCH 06/11] patch save dir bug --- ajet/context_tracker/multiagent_tracking.py | 1 + ajet/launcher.py | 4 +-- ajet/swarm_cli.py | 4 +-- ajet/task_rollout/native_parallel_worker.py | 2 +- .../tuner_lib/experimental/as_swarm_server.py | 8 +++-- .../example_werewolves_swarm/agent_roll.py | 31 +++++++++---------- 6 files changed, 26 insertions(+), 24 deletions(-) diff --git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py index 30be665c..8faf80a2 100644 --- a/ajet/context_tracker/multiagent_tracking.py +++ b/ajet/context_tracker/multiagent_tracking.py @@ -319,6 +319,7 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline): ) ): logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n") + # from ajet import bp; bp("SWARM") return diff --git a/ajet/launcher.py 
b/ajet/launcher.py index 3dbda6e1..d3a5b4cd 100644 --- a/ajet/launcher.py +++ b/ajet/launcher.py @@ -214,7 +214,7 @@ def main(): "Please provide a valid config file for swarm server mode." ) if args.conf: - exp_dir = args.exp_dir or DEFAULT_DIR + exp_base_dir = args.exp_dir or DEFAULT_DIR yaml_path = args.conf ( main_yaml_fp, @@ -223,7 +223,7 @@ def main(): exp_config, ) = prepare_experiment_config( yaml_path=yaml_path, - exp_base_dir=exp_dir, + exp_base_dir=exp_base_dir, backbone=args.backbone, storage=(not args.swarm_server) ) diff --git a/ajet/swarm_cli.py b/ajet/swarm_cli.py index 3434a5b6..723d9b5a 100644 --- a/ajet/swarm_cli.py +++ b/ajet/swarm_cli.py @@ -42,7 +42,7 @@ def start_swarm_server(env, config, port): def cmd_start(args): """Handle the 'start' subcommand.""" # Use default config if not provided - exp_dir = args.exp_dir or DEFAULT_DIR + exp_base_dir = args.exp_dir or DEFAULT_DIR if not args.conf: args.conf = os.path.abspath( os.path.join( @@ -62,7 +62,7 @@ def cmd_start(args): exp_config, ) = prepare_experiment_config( yaml_path=yaml_path, - exp_base_dir=exp_dir, + exp_base_dir=exp_base_dir, backbone="verl", storage=False ) diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 4afad524..e8699fda 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -173,7 +173,7 @@ def rollout_swarm( # noqa: C901 Build a pool of threads to run context trackers in parallel, each thread re-spawn after complete, until reaching conditions to stop. 
""" - # from ajet import bp; bp("SWARM") + tracker_array: List[SingleAgentContextTracker] = [] rollout_n = self.rollout_n n_batch_task = len(tasks) diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/as_swarm_server.py index e29cfbfa..27a1b5b5 100644 --- a/ajet/tuner_lib/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/experimental/as_swarm_server.py @@ -334,9 +334,11 @@ async def start_engine(): yaml_str = shared_mem_dict["train_config_yaml"] config_dict = yaml_module.safe_load(yaml_str) backbone = config_dict.get("ajet", {}).get("backbone", "verl") - exp_base_dir = os.path.dirname( - config_dict.get("ajet", {}).get("experiment_dir", "saved_experiments") - ) + DEFAULT_DIR = "saved_experiments" + experiment_dir = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR) + if experiment_dir == "auto": + exp_base_dir = DEFAULT_DIR + exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir)) # Save YAML to temporary file with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as temp_file: diff --git a/tutorial/example_werewolves_swarm/agent_roll.py b/tutorial/example_werewolves_swarm/agent_roll.py index 2e5fa672..c51bc01a 100644 --- a/tutorial/example_werewolves_swarm/agent_roll.py +++ b/tutorial/example_werewolves_swarm/agent_roll.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- import os -from ajet.schema.task import Task +from ajet.schema.task import Task, WorkflowTask from ajet.copilot.job import AgentJetJob from ajet.task_reader import RouterTaskReader from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor @@ -24,37 +24,36 @@ def main(): base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml", algorithm="grpo", experiment_name="werewolves_swarm", + max_env_worker=128, ) # Hand shake with remote swarm server swarm_worker = SwarmClient(AJET_SWARM_URL) swarm_worker.auto_sync_train_config_and_start_engine( ajet_job, - force_restart=False, + # force_restart=True, ) GRPO_N = 
ajet_job.num_repeat REMOTE_BATCH_SIZE = ajet_job.batch_size def rollout(task): - try: - # begin episode - episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) - # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) - workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` - # report output back to swarm remote - swarm_worker.end_episode(task, episode_uuid, workflow_output) - return - except: - pass + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=240) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return - executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) + + executor = PeriodicDrainThreadPoolExecutor(workers=1, max_parallel=64, auto_retry=True, block_first_run=True) for _ in range(NUM_EPOCH): for _, task in enumerate(dataset.generate_training_tasks()): for _ in range(GRPO_N): executor.submit_with_periodic_drain(fn=rollout, task=task) - return None + return def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): @@ -63,9 +62,9 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): game = ExampleWerewolves( trainable_targets=["werewolf"], big_external_opponent_llm_name="Qwen/Qwen3-235B-A22B-Instruct-2507", - big_external_opponent_llm_url="http://22.14.116.243/v1", + big_external_opponent_llm_url="http://22.14.116.243:2888/v1", ) - res = asyncio.run(game.execute(task, api_baseurl_key)) + res = asyncio.run(game.execute(WorkflowTask(task=task), api_baseurl_key)) return res From 642eddc19508402af4b7f19f11175c6356d34156 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Tue, 3 Mar 2026 18:40:10 +0800 
Subject: [PATCH 07/11] force check agentscope version --- tutorial/example_werewolves/start.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tutorial/example_werewolves/start.py b/tutorial/example_werewolves/start.py index 14ed14ba..0e0ab4b0 100644 --- a/tutorial/example_werewolves/start.py +++ b/tutorial/example_werewolves/start.py @@ -4,6 +4,7 @@ """The main entry point for the werewolf game.""" from typing import List +import agentscope import numpy as np import dotenv dotenv.load_dotenv() @@ -86,6 +87,7 @@ class ExampleWerewolves(Workflow): async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: + assert agentscope.__version__ == "1.0.7", "AgentScope has too many bugs across versions, please use version 1.0.7 for werewolves example." # ensure trainable targets is legal assert self.trainable_targets is not None, "trainable_targets cannot be None in ExampleWerewolves (because we want to demonstrate a explicit multi-agent case)." From f532854f123fd86ab6b7358e7834a073df5ae382 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 4 Mar 2026 14:21:26 +0800 Subject: [PATCH 08/11] sharing httpx client --- .gitignore | 1 + ajet/task_rollout/native_parallel_worker.py | 6 ++-- .../experimental/as_oai_model_server.py | 5 +-- .../tuner_lib/experimental/as_swarm_client.py | 34 +++++++++---------- .../tuner_lib/experimental/as_swarm_server.py | 3 +- .../experimental/interchange_utils.py | 8 +++-- ajet/utils/swarm_overwatch.py | 3 +- ajet/utils/tokenizer.py | 30 ++++++++++++++-- 8 files changed, 60 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 1b88f0c1..7b64dc3e 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,4 @@ prompts swarmexp swarmlog werewolves_swarm +.claude diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index e8699fda..141ac67e 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ 
-233,7 +233,7 @@ def enough_finished_task_stop_condition(completed_task_id_map_ct) -> bool: f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] " f"Deleting cached episodes to release memory..." ) - completed_task_id_map_ct = {} + completed_task_id_map_ct.clear() return (total_completed_tasks >= n_batch_task) def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: @@ -258,7 +258,7 @@ def enough_non_dummy_task_stop_condition(completed_task_id_map_ct) -> bool: f"Too many cached episodes [{total_completed_episodes}] has exceeded the max cached episodes [{self.config.ajet.swarm_mode_sample_collection_max_cached_episodes}] " f"Deleting cached episodes to release memory..." ) - completed_task_id_map_ct = {} + completed_task_id_map_ct.clear() return (total_completed_non_dummy_tasks >= n_batch_task) # select stop condition function based on config @@ -387,6 +387,7 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma logger.info('Collecting results...') for ct_list in completed_task_id_map_ct.values(): tracker_array.extend(ct_list) + completed_task_id_map_ct.clear() # TODO: support multi-step reward task_success_rate = np.mean( @@ -402,7 +403,6 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma update_rollout_result_array_preview(observation_window, completed_task_id_map_ct) self._write_swarm_rollout_dynamic_log(observation_window) - return tracker_array diff --git a/ajet/tuner_lib/experimental/as_oai_model_server.py b/ajet/tuner_lib/experimental/as_oai_model_server.py index 51bf9bb1..bc8ffb95 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/experimental/as_oai_model_server.py @@ -428,6 +428,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: # polling for server ready start_time = time.time() + _httpx_client = 
httpx.Client(timeout=0.5) while True: if interchange_server and interchange_server.exitcode is not None: logger.error(f"Interchange server subprocess failed to start. Return code: {interchange_server.exitcode}") @@ -437,7 +438,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: logger.error(msg) raise RuntimeError(msg) try: - if httpx.get(health_url, timeout=0.5).status_code == 200: + if _httpx_client.get(health_url).status_code == 200: break except Exception: # keep waiting @@ -462,7 +463,7 @@ def start_interchange_server(config, blocking=False, env={}) -> int: interchange_server.join() except KeyboardInterrupt: logger.info("Shutting down interchange server...") - try: httpx.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code + try: _httpx_client.post(f"http://127.0.0.1:{port}/stop_engine", timeout=8).status_code except Exception: pass if interchange_server: diff --git a/ajet/tuner_lib/experimental/as_swarm_client.py b/ajet/tuner_lib/experimental/as_swarm_client.py index 6b3f9168..4bf9d8e5 100644 --- a/ajet/tuner_lib/experimental/as_swarm_client.py +++ b/ajet/tuner_lib/experimental/as_swarm_client.py @@ -71,6 +71,8 @@ def __init__(self, server_url: str): self._agent_jet_job = None # throttle self._recent_seen_tasks = [] + # reuse httpx client to avoid creating SSL context repeatedly + self._http_client = httpx.Client(timeout=GENERAL_TIMEOUT) def logger_info(self, message): # logger with de-duplication within 1 second to prevent log flooding @@ -252,10 +254,9 @@ def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="t discard_episode_timeout=discard_episode_timeout, throttle_policy=throttle_policy ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/claim_episode", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) data = ClaimEpisodeResponse.model_validate(resp.json()) @@ -337,10 +338,9 @@ def end_episode(self, 
task:Task, episode_uuid: str, workflow_output: WorkflowOut task_id=task_id ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/end_episode", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) data = EndEpisodeResponse.model_validate(resp.json()) @@ -366,10 +366,9 @@ def abort_episode(self, episode_uuid: str): task_id="" ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/abort_episode", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) data = EndEpisodeResponse.model_validate(resp.json()) @@ -399,10 +398,9 @@ def sync_train_config(self, agent_jet_job: AgentJetJob): req_obj = SyncTrainConfigRequest(yaml_as_string=yaml_str) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/sync_train_config", - json=req_obj.model_dump(), - timeout=GENERAL_TIMEOUT + json=req_obj.model_dump() ) raise_for_status_with_detail(resp) self.logger_info("Synced train config to Swarm server") @@ -422,7 +420,7 @@ def start_engine(self): raise RuntimeError(f"Cannot start engine when engine is NOT ENGINE.OFFLINE. 
(current status: {current_status})") # Send start engine request - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/start_engine", json={}, timeout=600 @@ -487,7 +485,7 @@ def _wait_until_status_change_to(self, desired_status="ENGINE.ROLLING", verbose= @cache_with_ttl(ttl=0.5) def get_engine_status(self) -> Tuple[str, dict]: try: - resp = httpx.get( + resp = self._http_client.get( f"{self.server_url}/get_engine_status", timeout=10 ) @@ -512,7 +510,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool: client_uuid=self.client_uuid, episode_uuid=episode_uuid ) - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/can_continue_episode", json=req_obj.model_dump(), timeout=10 @@ -526,7 +524,7 @@ def can_continue_episode(self, episode_uuid: str) -> bool: def get_episode_buffer(self) -> List[EpisodeStatus]: try: - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/get_episode_buffer", json={}, timeout=10 @@ -585,7 +583,7 @@ def stop_engine(self): self.logger_info("Engine is already OFFLINE. No action needed.") return - resp = httpx.post( + resp = self._http_client.post( f"{self.server_url}/stop_engine", json={}, timeout=600 @@ -605,7 +603,7 @@ def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation: Returns statistics about completed episodes, tasks, and progress. 
""" try: - resp = httpx.get( + resp = self._http_client.get( f"{self.server_url}/get_current_batch_rollout_pool_information", timeout=10 ) diff --git a/ajet/tuner_lib/experimental/as_swarm_server.py b/ajet/tuner_lib/experimental/as_swarm_server.py index 27a1b5b5..aa82984a 100644 --- a/ajet/tuner_lib/experimental/as_swarm_server.py +++ b/ajet/tuner_lib/experimental/as_swarm_server.py @@ -338,7 +338,8 @@ async def start_engine(): experiment_dir = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR) if experiment_dir == "auto": exp_base_dir = DEFAULT_DIR - exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir)) + else: + exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir)) # Save YAML to temporary file with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as temp_file: diff --git a/ajet/tuner_lib/experimental/interchange_utils.py b/ajet/tuner_lib/experimental/interchange_utils.py index b5a52e43..880b87b0 100644 --- a/ajet/tuner_lib/experimental/interchange_utils.py +++ b/ajet/tuner_lib/experimental/interchange_utils.py @@ -109,6 +109,8 @@ class UpdateEngineStatusRequest(BaseModel): VERBOSE = True +shared_http_client = httpx.Client(timeout=10.0) + def get_interchange_server_url(config): port = os.getenv("AJET_DAT_INTERCHANGE_PORT") if isinstance(config, dict): @@ -127,7 +129,7 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No if new_status not in VALID_STATUSES: raise ValueError(f"Invalid engine status: {new_status}") - resp = httpx.post( + resp = shared_http_client.post( f"{get_interchange_server_url(config)}/update_engine_status", json={"engine_status": new_status, "engine_status_detail": new_status_detail, "global_step": global_step}, timeout=10 @@ -137,7 +139,7 @@ def http_change_engine_status(config, new_status: str, new_status_detail: str|No def is_episode_claimed(config, episode_uuid: str, unregister_if_not_claimed: bool) -> bool: - resp = httpx.post( + resp = 
shared_http_client.post( f"{get_interchange_server_url(config)}/is_episode_claimed", json={"episode_uuid": episode_uuid, "unregister_if_not_claimed": unregister_if_not_claimed}, timeout=5 @@ -168,7 +170,7 @@ def http_register_episode(config, zmq_listen_result_addr=zmq_listen_result_addr, ) # send http request to swarm server to register episode - response = httpx.post( + response = shared_http_client.post( f"{interchange_http_addr}/register_episode", json=rer.model_dump(), # 或者 rer.model_dump() 如果使用 Pydantic v2 timeout=2 diff --git a/ajet/utils/swarm_overwatch.py b/ajet/utils/swarm_overwatch.py index fd97fa41..9f8b7717 100644 --- a/ajet/utils/swarm_overwatch.py +++ b/ajet/utils/swarm_overwatch.py @@ -37,11 +37,12 @@ def __init__(self, server_url: str, refresh_interval: float = 2.0): self.last_update_time = None self.error_count = 0 self.total_requests = 0 + self._httpx_client = httpx.Client(timeout=5.0) def fetch_pool_info(self) -> Optional[CurrentBatchRolloutPoolInformation]: """Fetch current batch rollout pool information from server""" try: - response = httpx.get( + response = self._httpx_client.get( f"{self.server_url}/get_current_batch_rollout_pool_information", timeout=5.0, ) diff --git a/ajet/utils/tokenizer.py b/ajet/utils/tokenizer.py index 94ab8007..30128a30 100644 --- a/ajet/utils/tokenizer.py +++ b/ajet/utils/tokenizer.py @@ -19,6 +19,9 @@ def cleanup_messages(messages: List[Dict]) -> List[Dict]: pass return messages_copied +# Cache storage +_cache = {} + def ajet_apply_chat_template( tokenizer, @@ -28,16 +31,39 @@ def ajet_apply_chat_template( tokenize: bool = True, ): conversation = cleanup_messages(conversation) + + # Create cache key by hashing all inputs + cache_key = ( + id(tokenizer), + hash(json.dumps(conversation, sort_keys=True)), + hash(json.dumps(tools, sort_keys=True)) if tools else 0, + add_generation_prompt, + tokenize, + ) + + # Check cache + if cache_key in _cache: + return _cache[cache_key] + + # Compute result if tools: - return 
tokenizer.apply_chat_template( + result = tokenizer.apply_chat_template( conversation, tools, add_generation_prompt=add_generation_prompt, tokenize=tokenize, ) else: - return tokenizer.apply_chat_template( + result = tokenizer.apply_chat_template( conversation, tokenize=tokenize, add_generation_prompt=add_generation_prompt, ) + + # Store in cache (implement LRU eviction if cache gets too large) + if len(_cache) >= 1024: + # Remove oldest item (first inserted) + _cache.pop(next(iter(_cache))) + + _cache[cache_key] = result + return result From 79407c43a593c5e909736fc7c333f7b76c79c363 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 4 Mar 2026 15:09:15 +0800 Subject: [PATCH 09/11] fix memory leak --- ajet/default_config/ajet_default.yaml | 2 +- ajet/default_config/ajet_ts_default.yaml | 2 +- .../experimental/as_oai_model_client.py | 39 ++++++++++++------- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/ajet/default_config/ajet_default.yaml b/ajet/default_config/ajet_default.yaml index 7a62ee75..1539d028 100644 --- a/ajet/default_config/ajet_default.yaml +++ b/ajet/default_config/ajet_default.yaml @@ -2,7 +2,7 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" - experiment_dir: "{exp-dir}/{experiment_name}" # {exp-dir}/{experiment_name} + experiment_dir: "auto" # {exp-dir}/{experiment_name} backbone: verl # `debug` or `trinity` or `verl` diff --git a/ajet/default_config/ajet_ts_default.yaml b/ajet/default_config/ajet_ts_default.yaml index ab584d61..90e3f4bd 100644 --- a/ajet/default_config/ajet_ts_default.yaml +++ b/ajet/default_config/ajet_ts_default.yaml @@ -2,7 +2,7 @@ ajet: project_name: "ajet_default_project" experiment_name: "read_yaml_name" - experiment_dir: "{exp-dir}/{experiment_name}" # {exp-dir}/{experiment_name} + experiment_dir: "auto" # {exp-dir}/{experiment_name} backbone: verl model: diff --git a/ajet/tuner_lib/experimental/as_oai_model_client.py b/ajet/tuner_lib/experimental/as_oai_model_client.py 
index 616ab19a..aaecde5c 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_client.py +++ b/ajet/tuner_lib/experimental/as_oai_model_client.py @@ -131,23 +131,36 @@ def _begin_service_threading(self): # begin to run the llm request, monitored by context tracker # we re-use previously created thread for best performance if DEBUG: logger.info(f"[client] {self.episode_uuid} | before asyncio run self.llm_infer") + + # Check if there's a running event loop try: loop = asyncio.get_running_loop() - except: + created_new_loop = False + except RuntimeError: + # No running loop, create a new one loop = asyncio.new_event_loop() - context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() - future = loop.run_in_executor( - context_tracker_executor, - asyncio.run, - self.llm_proxy_with_tracker.chat_completion_request( - req=parsed_msg.completion_request, - timeline_uuid=parsed_msg.timeline_uuid, - agent_name=parsed_msg.agent_name, - target_tag=parsed_msg.target_tag, - episode_uuid=parsed_msg.episode_uuid, + asyncio.set_event_loop(loop) + created_new_loop = True + + try: + context_tracker_executor = SharedInferenceTrackerThreadExecutor(self.max_inference_tracker_threads).get_shared_executor() + future = loop.run_in_executor( + context_tracker_executor, + asyncio.run, + self.llm_proxy_with_tracker.chat_completion_request( + req=parsed_msg.completion_request, + timeline_uuid=parsed_msg.timeline_uuid, + agent_name=parsed_msg.agent_name, + target_tag=parsed_msg.target_tag, + episode_uuid=parsed_msg.episode_uuid, + ) ) - ) - result = loop.run_until_complete(future).model_dump_json() # type: ignore + result = loop.run_until_complete(future).model_dump_json() # type: ignore + finally: + # Clean up the event loop if we created it + if created_new_loop: + loop.close() + asyncio.set_event_loop(None) if DEBUG: logger.info(f"[client] {self.episode_uuid} | before send_string (send llm call result)") From 
39ab72e97e30653ce2ed556dd65a52a5e64e0ddc Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 4 Mar 2026 16:06:34 +0800 Subject: [PATCH 10/11] add thread safety to cache operations and implement LRU eviction --- .../experimental/as_oai_model_server.py | 12 +++++++++ ajet/utils/tokenizer.py | 27 ++++++++++++------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/ajet/tuner_lib/experimental/as_oai_model_server.py b/ajet/tuner_lib/experimental/as_oai_model_server.py index bc8ffb95..367c8086 100644 --- a/ajet/tuner_lib/experimental/as_oai_model_server.py +++ b/ajet/tuner_lib/experimental/as_oai_model_server.py @@ -24,6 +24,7 @@ from loguru import logger from pydantic import BaseModel +from functools import lru_cache from fastapi import FastAPI, Header, HTTPException, Request from fastapi.responses import StreamingResponse from contextlib import asynccontextmanager @@ -63,6 +64,9 @@ class HealthCheckRequest(BaseModel): context = zmq.Context() atexit.register(context.term) +@lru_cache(maxsize=128) +def ep_key(episode_uuid: str) -> str: + return f"episodes-{episode_uuid}" def get_app(max_fastapi_threads: int = 512, enable_swarm_mode=False, shared_mem_dict=None, shared_mem_dict_lock=None) -> Tuple[FastAPI, Optional[Coroutine]]: @@ -100,6 +104,14 @@ def _begin_handle_chat_completion(episode_address, int_req: InterchangeCompletio result_str = "" for _ in range(50): # max 5 minutes wait + + if enable_swarm_mode: + assert shared_mem_dict is not None + ep_stat = shared_mem_dict[ep_key(episode_uuid)] + episode_status = ep_stat.episode_status + if episode_status != "claimed": + raise HTTPException(status_code=404, detail="The episode is not claimed, cannot accept new requests.") + try: if DEBUG: logger.info(f"[server] episode_uuid: {episode_uuid} | recv_string begin.") diff --git a/ajet/utils/tokenizer.py b/ajet/utils/tokenizer.py index 30128a30..64587381 100644 --- a/ajet/utils/tokenizer.py +++ b/ajet/utils/tokenizer.py @@ -1,5 +1,6 @@ import copy import json 
+import threading from typing import Dict, List @@ -21,6 +22,7 @@ def cleanup_messages(messages: List[Dict]) -> List[Dict]: # Cache storage _cache = {} +_cache_lock = threading.Lock() def ajet_apply_chat_template( @@ -41,11 +43,12 @@ def ajet_apply_chat_template( tokenize, ) - # Check cache - if cache_key in _cache: - return _cache[cache_key] + # Check cache with thread safety + with _cache_lock: + if cache_key in _cache: + return _cache[cache_key] - # Compute result + # Compute result (time consuming) - outside lock to avoid blocking other threads if tools: result = tokenizer.apply_chat_template( conversation, @@ -60,10 +63,16 @@ def ajet_apply_chat_template( add_generation_prompt=add_generation_prompt, ) - # Store in cache (implement LRU eviction if cache gets too large) - if len(_cache) >= 1024: - # Remove oldest item (first inserted) - _cache.pop(next(iter(_cache))) + # Store in cache with thread safety (implement LRU eviction if cache gets too large) + with _cache_lock: + if len(_cache) >= 1024: + # Remove oldest item (first inserted) + try: + _cache.pop(next(iter(_cache))) + except KeyError: + # Cache was modified by another thread, which is fine + pass + + _cache[cache_key] = result - _cache[cache_key] = result return result From 76d0e3fba7621795872e9a9d55ffa799db4e6940 Mon Sep 17 00:00:00 2001 From: "qingxu.fu" Date: Wed, 4 Mar 2026 16:59:01 +0800 Subject: [PATCH 11/11] implement skills and skillbench example --- ajet/context_tracker/multiagent_tracking.py | 4 +- ajet/copilot/train-complex-blackbox/SKILL.md | 174 ++++++++++++++++++ ajet/copilot/write-swarm-client/SKILL.md | 87 ++++++--- ajet/schema/extended_msg.py | 2 + ajet/task_rollout/native_parallel_worker.py | 78 ++++++++ ajet/task_rollout/single_worker.py | 4 + .../opencode_build_skillbench_agent.prompt.md | 17 ++ 7 files changed, 343 insertions(+), 23 deletions(-) create mode 100644 ajet/copilot/train-complex-blackbox/SKILL.md create mode 100644 tutorial/opencode_build_skillbench_agent.prompt.md diff 
--git a/ajet/context_tracker/multiagent_tracking.py b/ajet/context_tracker/multiagent_tracking.py
index 8faf80a2..9a6069fc 100644
--- a/ajet/context_tracker/multiagent_tracking.py
+++ b/ajet/context_tracker/multiagent_tracking.py
@@ -334,7 +334,9 @@ def detect_tool_call_madness(self, llm_output):
         # llm_output["tool_calls"] is not None, and is not []
         tool_calls = llm_output["tool_calls"]
         if "wrong_toolcall" in self.config.ajet.rollout.compute_madness_checklist:
-            copy_tool_calls = copy.deepcopy(tool_calls)
+            # copy_tool_calls = copy.deepcopy(tool_calls)
+            # Shallow copy is sufficient - we're only reading the data
+            copy_tool_calls = tool_calls
             wrong_toolcall = False
             for i in range(len(copy_tool_calls)):
                 if ("function" in copy_tool_calls[i]) and (
diff --git a/ajet/copilot/train-complex-blackbox/SKILL.md b/ajet/copilot/train-complex-blackbox/SKILL.md
new file mode 100644
index 00000000..6bd8e808
--- /dev/null
+++ b/ajet/copilot/train-complex-blackbox/SKILL.md
@@ -0,0 +1,174 @@
+---
+name: train-complex-blackbox
+description: Create a trainable agent loop or agent workflow with AgentJet
+license: Complete terms in LICENSE.txt
+---
+
+
+## 0. Ask user for API key + model (or API key + base url + model) for debugging
+
+This is not 100% necessary, but it can help a lot in debugging in step 1.
+If the user has not given an API key, ask the user to provide one.
+
+
+By default, the code you write should be located at ./tutorial/opencode_build_xxxxxx/*.py
+
+## 1. Initial Programming
+
+### Writing dataset collector (`get_training_dataset_item_list.py`)
+- `get_training_dataset_item_list.py`: Returns a list of training data items. Each item may be a string identifier of a training task, or a dict containing the necessary information for the training task.
+
+### Episode Runner (`run_episode_once.py`)
+- `run_episode_once.py`:
+
+    - Argument Parser: takes (training data item identifier + api-key + base-url) as input, model-name is not required, you can make up a model name because we ignore it.
+
+    - Execute the agent: read the document of the agent user asked you to train, figure out how to execute the agent. In most cases you can use subprocess to start a commandline process to execute the agent, your biggest issue is to figure out how to pass the training data item identifier, api-key and base-url to that commandline process. You can also use python code to execute the agent if you think it's more convenient.
+
+    - Reward: extract / compute the reward/score for the agent's output. Some agents have a clear reward signal, but others don't.
+        - clear reward signal: take that down as the reward, no need to do extra reward engineering.
+        - no clear reward signal: you need to design a reward function to compute the reward/score for the agent's output. You can use another LLM to help you design the reward function, or you can design it by yourself if you have domain knowledge.
+
+
+### Test
+
+Remember to test these two parts before moving to step 2, make sure they work as expected.
+
+
+
+## 2. Writing training code
+
+This part is easy, simply follow this template and change the necessary part such as dataset path, model name, etc.
+ +`agent_roll.py` + +```python +# -*- coding: utf-8 -*- + +import os +import re +import requests +from textwrap import dedent +from ajet.schema.task import Task, WorkflowOutput +from ajet.copilot.job import AgentJetJob +from ajet.task_reader import RouterTaskReader +from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor +from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey +from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo +from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient + +# python -m tutorial.example_math_swarm.math + +GRPO_N = 4 # grpo group size +NUM_EPOCH = 10000 +AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086") +REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct") +REMOTE_BATCH_SIZE = 32 +REMOTE_ALLOCATE_GPU_PER_NODE = 8 + +def main(): + + # Handshake with swarm remote, then send training param to swarm remote (such as model to be trained, algorithm, etc) + dataset = RouterTaskReader( + reader_type = "huggingface_dat_repo", + reader_config = AjetTaskReader( + huggingface_dat_repo = HuggingfaceDatRepo( + dataset_path = '/mnt/data_cpfs/model_cache/modelscope/dataset/openai/gsm8k/main', + # dataset_path = "/root/agentjet/benchmark_datasets/dataset/gsm8k/socratic", + # dataset_path = "openai/gsm8k", + # dataset_name = "main", + ) + ) + ) + # Load the CountDown dataset + # print(f"Loading dataset from: {LOCAL_DATASET_PATH}") + # dataset = RouterTaskReader( + # reader_type="jsonl_dataset_file", + # reader_config=AjetTaskReader( + # jsonl_dataset_file=JsonlDatasetFile( + # training=JsonlTrainingFp(file_path=LOCAL_DATASET_PATH) + # ) + # ), + # ) + + # Hand shake with remote swarm server + swarm_worker = SwarmClient(AJET_SWARM_URL) + ajet_job = AgentJetJob( + experiment_name="math_gsm8k_grpo", + algorithm="grpo", + n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE, + model=REMOTE_MODEL_PATH, + 
batch_size=REMOTE_BATCH_SIZE, + num_repeat=GRPO_N, + ) + print(ajet_job.config.to_dict()) + swarm_worker.auto_sync_train_config_and_start_engine( + ajet_job, + force_restart=True, + ) + + def rollout(task): + # begin episode + episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60) + # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key ) + workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output` + # report output back to swarm remote + swarm_worker.end_episode(task, episode_uuid, workflow_output) + return + + executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True) + for _ in range(NUM_EPOCH): + for _, task in enumerate(dataset.generate_training_tasks()): + for _ in range(GRPO_N): + executor.submit_with_periodic_drain(fn=rollout, task=task) + + return None + + +def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey): + .... + raw_reward: float = ... # compute the reward for the agent's output + return WorkflowOutput(reward=raw_reward, metadata={"important_metadata": important_metadata}) + + +if __name__ == "__main__": + main() + + +``` + + +It is very clear now, your job in step 2 is to: + +- use `get_training_dataset_item_list.py` to generate `List[Task]` (`from ajet.schema.task import Task`) +- use `run_episode_once.py` to execute a single episode and place it in `execute_agent` function + + +## 3. Simplify your code and fix bugs + +before moving to step 4, you can simplify your code and fix bugs to make sure it can run smoothly. + + +## 4. Training + +Finally, you can start training. + +Run `ajet-swarm start` to start training server (if the user has already installed agentjet swarm environment), +if the user has docker environment, you can also refer to `docs/en/ajet-swarm-docker.md` to start a AgentSwarm docker container. 
+
+Create a duplication of `agent_roll.py` named `agent_roll_one_episode_debug.py`, and modify it to only run one episode, this can help you debug whether the episode runner and reward function work as expected.
+
+After the server side is ready, run
+```bash
+python /path/to/agent_roll_one_episode_debug.py
+```
+watch the console log to see if the episode can be executed successfully and the reward can be computed correctly.
+
+If anything goes wrong, keep the server running, rewrite and fix `agent_roll_one_episode_debug.py`, and run it again until it can run one episode successfully.
+
+Next, patch `agent_roll.py` if there are any bugs discovered via the debugging of `agent_roll_one_episode_debug.py`, and then run
+```bash
+python /path/to/agent_roll.py
+```
+
+to start the training!
diff --git a/ajet/copilot/write-swarm-client/SKILL.md b/ajet/copilot/write-swarm-client/SKILL.md
index c62693e5..0b98902d 100644
--- a/ajet/copilot/write-swarm-client/SKILL.md
+++ b/ajet/copilot/write-swarm-client/SKILL.md
@@ -4,25 +4,24 @@ description: Create a trainable agent loop or agent workflow with AgentJet
 license: Complete terms in LICENSE.txt
 ---
 
-## 简介:
-你的任务是根据要求,创建一个可训练 Agent (或者Agent Loop,多智能体系统等等),提供给用户做强化学习训练。
-在AgentJet强化学习框架下,这是非常简单的。
+## Introduction:
 
-首先,根据用户的要求,给智能体系统起一个名字,例如 user_math_agent
+Your task is to create a trainable Agent (or Agent Loop, multi-agent system, etc.) based on the requirements, and provide it to the user for reinforcement learning training. Under the AgentJet reinforcement learning framework, this is very simple.
-接下来,创建Agent源文件: -tutorial/user_math_agent/agent_roll.py (以 tutorial/example_academic_trans_swarm/trans_roll.py 为模板,变化不大,关键是向用户索取必要的参数) -tutorial/user_math_agent/agent_run.py (根据用户的要求,创建运行智能体的函数,或者类,都可以。同步异步都可以。) -tutorial/user_math_agent/readme.md (Agent说明,以及训练、调试方法说明) +Next, create the directory: +`tutorial/user_math_agent` +Then, create the Agent source files: +- `tutorial/user_math_agent/agent_roll.py` (Use `tutorial/example_academic_trans_swarm/trans_roll.py` as a template. There aren't many changes — the key is to ask the user for the necessary parameters.) +- `tutorial/user_math_agent/agent_run.py` (Create the function or class to run the agent based on the user's requirements. Synchronous or asynchronous are both fine.) +- `tutorial/user_math_agent/readme.md` (Agent description, along with training and debugging instructions.) -## 智能体编写方法 +## How to Write the Agent -使用 OpenAI SDK 编写智能体,主要包含以下三个函数(以及必要的子函数和子模块): +Write the agent using the OpenAI SDK. It mainly includes the following three functions (along with any necessary sub-functions and sub-modules): ``` from ajet.schema.task import Task, WorkflowOutput @@ -31,24 +30,68 @@ def _compute_reward(...) def _execute_agent(...) -def run_agent_and_compute_reward(task: Task, base_url:string, api_key:string) -> WorkflowOutput: +def run_agent_and_compute_reward(task: Task, base_url: string, api_key: string) -> WorkflowOutput: ``` -在 agent_roll 中,直接import run_agent_and_compute_reward即可。 +In `agent_roll`, simply import `run_agent_and_compute_reward`. -- 智能体的编写要领:通过一个或几个Agent的协作,高效完成用户给定的任务。 -- 奖励编写的要领:容易验证的,使用规则直接计算。不容易验证的,模仿 `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` 中的方法,使用其他大型模型生成 LLM as Judge 程序。 +- **Key points for writing the agent:** Efficiently complete the user's given task through the collaboration of one or several Agents. +- **Key points for writing the reward:** For things that are easy to verify, calculate directly using rules. 
For things that are hard to verify, follow the approach in `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` and use other large models to create an LLM-as-Judge program. +## Training and Debugging Instructions -## 训练、调试方法说明 +Overall, the user first runs `ajet-swarm start`, then runs `agent_roll.py`, and training begins. You do not need to and are not allowed to run these bash commands. +- First, help the user write `agent_run.py` and `agent_roll.py`. +- Then, write clear instructions to guide the user through training (`readme.md`). -总体而言,就是用户先运行 `ajet-swarm start`, 然后再运行 `agent_roll.py` 训练就开始了。你不需要也不被允许运行这些bash命令。 -- 首先帮助用户写好 `agent_run.py` 和 `agent_roll.py`, -- 然后写清楚引导用户训练的说明(readme.md), -你的任务就完成了。 +Your task is then complete. -以下是一些参考资料。 +Below are some reference materials. +--- + +## Introduction: + +Your task is to create a trainable Agent (or Agent Loop, multi-agent system, etc.) based on the requirements, and provide it to the user for reinforcement learning training. Under the AgentJet reinforcement learning framework, this is very simple. + +First, give the agent system a name based on the user's requirements, for example `user_math_agent`. + +Next, create the directory: +`tutorial/user_math_agent` + +Then, create the Agent source files: +- `tutorial/user_math_agent/agent_roll.py` (Use `tutorial/example_academic_trans_swarm/trans_roll.py` as a template. There aren't many changes — the key is to ask the user for the necessary parameters.) +- `tutorial/user_math_agent/agent_run.py` (Create the function or class to run the agent based on the user's requirements. Synchronous or asynchronous are both fine.) +- `tutorial/user_math_agent/readme.md` (Agent description, along with training and debugging instructions.) + +## How to Write the Agent + +Write the agent using the OpenAI SDK. 
It mainly includes the following three functions (along with any necessary sub-functions and sub-modules): + +``` +from ajet.schema.task import Task, WorkflowOutput + +def _compute_reward(...) + +def _execute_agent(...) + +def run_agent_and_compute_reward(task: Task, base_url: string, api_key: string) -> WorkflowOutput: +``` + +In `agent_roll`, simply import `run_agent_and_compute_reward`. + +- **Key points for writing the agent:** Efficiently complete the user's given task through the collaboration of one or several Agents. +- **Key points for writing the reward:** For things that are easy to verify, calculate directly using rules. For things that are hard to verify, follow the approach in `tutorial/example_academic_trans_swarm/train_multi_model/trans_reward.py` and use other large models to create an LLM-as-Judge program. + +## Training and Debugging Instructions + +Overall, the user first runs `ajet-swarm start`, then runs `agent_roll.py`, and training begins. You do not need to and are not allowed to run these bash commands. +- First, help the user write `agent_run.py` and `agent_roll.py`. +- Then, write clear instructions to guide the user through training (`readme.md`). + +Your task is then complete. + +Below are some reference materials. 
--- # Using AgentJet Swarm to Train Your Agents diff --git a/ajet/schema/extended_msg.py b/ajet/schema/extended_msg.py index dfaa7460..5e789074 100644 --- a/ajet/schema/extended_msg.py +++ b/ajet/schema/extended_msg.py @@ -244,9 +244,11 @@ def get_inc_simple(self, text_frag_from, text_frag_to, tokenizer): tokenizer_output = tokenizer(text_frag_from, return_tensors="pt", padding=False) tokenizer_input_ids = tokenizer_output["input_ids"][0].tolist() token_ids_acc = tokenizer_input_ids + del tokenizer_output # Free memory immediately tokenizer_output = tokenizer(text_frag_to, return_tensors="pt", padding=False) input_ids = tokenizer_output["input_ids"][0].tolist() + del tokenizer_output # Free memory immediately # get the new tokens added in this step input_id_increment = input_ids[len(token_ids_acc) :] FN_DEBUG = False diff --git a/ajet/task_rollout/native_parallel_worker.py b/ajet/task_rollout/native_parallel_worker.py index 141ac67e..8db4058c 100644 --- a/ajet/task_rollout/native_parallel_worker.py +++ b/ajet/task_rollout/native_parallel_worker.py @@ -1,7 +1,9 @@ """Parallel environment rollout orchestration utilities.""" import os +import gc import time +import tracemalloc from concurrent.futures import Future, ThreadPoolExecutor from typing import Dict, List, Literal from urllib.parse import quote @@ -90,6 +92,64 @@ def _write_swarm_rollout_dynamic_log(self, observation_window): f.write(string_buffer) return + def _check_memory_leak(self): + """Check for memory leaks by comparing memory snapshots.""" + if not self._tracemalloc_started: + tracemalloc.start() + self._tracemalloc_started = True + logger.info("Memory tracking started (tracemalloc)") + self._memory_snapshot = tracemalloc.take_snapshot() + return + + # Take a new snapshot + gc.collect() # Force garbage collection before snapshot + current_snapshot = tracemalloc.take_snapshot() + + if self._memory_snapshot is not None: + # Compare snapshots + top_stats = 
current_snapshot.compare_to(self._memory_snapshot, 'lineno') + + logger.info("=" * 80) + logger.info("Memory Leak Detection: Top 10 differences since last rollout_swarm call") + logger.info("=" * 80) + + total_size_diff = 0 + for stat in top_stats[:10]: + total_size_diff += stat.size_diff + logger.info(f"{stat}") + + # Convert to MB + total_size_diff_mb = total_size_diff / 1024 / 1024 + logger.info(f"\nTotal memory difference: {total_size_diff_mb:.2f} MB") + + # Show top current memory consumers + logger.info("\n" + "=" * 80) + logger.info("Top 10 current memory allocations") + logger.info("=" * 80) + top_current = current_snapshot.statistics('lineno') + for stat in top_current[:10]: + logger.info(f"{stat}") + + logger.info("=" * 80) + + # Enhanced leak detection: show traceback for largest leak + if total_size_diff_mb > 10: # Only if leak is significant (>10MB) + logger.warning(f"SIGNIFICANT MEMORY LEAK DETECTED: {total_size_diff_mb:.2f} MB") + logger.info("\n" + "=" * 80) + logger.info("Detailed traceback for top 3 memory leaks:") + logger.info("=" * 80) + for i, stat in enumerate(top_stats[:3], 1): + if stat.size_diff > 0: + logger.info(f"\n--- Leak #{i}: +{stat.size_diff / 1024 / 1024:.2f} MB, {stat.count_diff} objects ---") + logger.info(f"File: {stat.traceback.format()[0] if stat.traceback else 'Unknown'}") + if stat.traceback and len(stat.traceback) > 1: + logger.info("Full traceback:") + for line in stat.traceback.format(): + logger.info(f" {line}") + logger.info("=" * 80) + + # Update snapshot for next comparison + self._memory_snapshot = current_snapshot def rollout_static( self, @@ -174,6 +234,9 @@ def rollout_swarm( # noqa: C901 each thread re-spawn after complete, until reaching conditions to stop. 
""" + # # Memory leak detection: compare with previous snapshot + # self._check_memory_leak() + tracker_array: List[SingleAgentContextTracker] = [] rollout_n = self.rollout_n n_batch_task = len(tasks) @@ -403,6 +466,21 @@ def update_rollout_result_array_preview(observation_window, completed_task_id_ma update_rollout_result_array_preview(observation_window, completed_task_id_map_ct) self._write_swarm_rollout_dynamic_log(observation_window) + + # Explicit cleanup to prevent memory leaks + logger.debug("Performing explicit cleanup...") + # Clear futures list + futures.clear() + # Clear observation window + observation_window.clear() + # Delete local function references to break circular refs + del stop_condition_callback + del stop_condition + del update_rollout_result_array_preview + del count_tasks + # Force garbage collection + gc.collect() + return tracker_array diff --git a/ajet/task_rollout/single_worker.py b/ajet/task_rollout/single_worker.py index f8b51a66..baca7e34 100644 --- a/ajet/task_rollout/single_worker.py +++ b/ajet/task_rollout/single_worker.py @@ -72,6 +72,10 @@ def __init__( max_llm_retries=max_llm_retries, ) + # Memory leak tracking + self._memory_snapshot = None + self._tracemalloc_started = False + @retry_with_backoff(max_retry_attr="max_llm_retries") def rollout_env_worker( self, diff --git a/tutorial/opencode_build_skillbench_agent.prompt.md b/tutorial/opencode_build_skillbench_agent.prompt.md new file mode 100644 index 00000000..4a290cf9 --- /dev/null +++ b/tutorial/opencode_build_skillbench_agent.prompt.md @@ -0,0 +1,17 @@ +# Train SkillBench with AgentJet Swarm with Vibe Coding + +result is generated by `claude sonnet 4.5` + +============================= + +你的任务是训练这个仓库中的智能体:https://github.com/benchflow-ai/skillsbench.git +仓库你需要下载到 ../skillsbench_swarm_test +这是在调试过程中你可以使用的模型(openrouter) + "url": "https://openrouter-openrouter-esyubhyrxv.ap-northeast-1.fcapp.run/api/v1", + "key": 
"sk-or-v1-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + "model": "qwen/qwen3-max" + + + +你的skill(首先读取该SKILL文件,获取必要知识): +- ajet/copilot/train-complex-blackbox/SKILL.md