Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,5 @@ modelscope_cache
prompts
swarmexp
swarmlog
werewolves_swarm
.claude
2 changes: 1 addition & 1 deletion ajet/backbone/main_trinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def patched_trainer_get_actor(cls, config: Config):
Explorer.get_actor = classmethod(patched_explorer_get_actor)
Trainer.get_actor = classmethod(patched_trainer_get_actor)

if ajet_config.ajet.enable_experimental_interchange_server:
if ajet_config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
start_interchange_server(ajet_config)

Expand Down
4 changes: 2 additions & 2 deletions ajet/backbone/main_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run_ppo(config: DictConfig) -> None:
def on_shutdown():
if ray.is_initialized():
ray.shutdown()
if config.ajet.enable_experimental_interchange_server:
if config.ajet.enable_interchange_server:
if config.ajet.enable_swarm_mode:
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
print("Changing engine status to OFFLINE before shutdown...")
Expand Down Expand Up @@ -250,7 +250,7 @@ def run(self, config):

from ajet.backbone.trainer_verl import AjetRayPPOTrainer

if config.ajet.enable_experimental_interchange_server:
if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
start_interchange_server(config)

Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/main_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def main(config):
os.environ.update(runtime_env["env_vars"])
# atexit.register(lambda: print("Process exiting, performing cleanup..."))

if config.ajet.enable_experimental_interchange_server:
if config.ajet.enable_interchange_server:
from ajet.tuner_lib.experimental.as_oai_model_server import start_interchange_server
start_interchange_server(config)
if config.ajet.enable_swarm_mode:
Expand Down
4 changes: 2 additions & 2 deletions ajet/backbone/trainer_verl.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def init_workers(self):
)

def _update_interchange_server_status_flag(self, status: str):
if self.config.ajet.enable_experimental_interchange_server:
if self.config.ajet.enable_interchange_server:
if self.config.ajet.enable_swarm_mode:
from ajet.tuner_lib.experimental.interchange_utils import http_change_engine_status
http_change_engine_status(self.config, status, global_step=self.global_steps)
Expand Down Expand Up @@ -858,7 +858,7 @@ def fit(self): # noqa: C901
self.global_steps += 1

# # when the OAI request interchange is enabled, we need to clear the cache periodically
# if self.config.ajet.enable_experimental_interchange_server:
# if self.config.ajet.enable_interchange_server:
# from ajet.tuner_lib.experimental.as_oai_model_server import ensure_dat_interchange_server_cache_clear
# ensure_dat_interchange_server_cache_clear()

Expand Down
2 changes: 1 addition & 1 deletion ajet/backbone/warm_up.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def warm_up_task_judge_when_needed(config):
def clean_up_tmp_ajet_dir(config):
"""Clean up old IPC socket files in /tmp/ajet directory."""
import time
if config.ajet.enable_experimental_interchange_server is False:
if config.ajet.enable_interchange_server is False:
return

tmp_dir = "/tmp/ajet"
Expand Down
18 changes: 4 additions & 14 deletions ajet/context_tracker/multiagent_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,6 @@ def extract_text_content_from_content_dict(self, msg):
# },
# ],
# }
# or tool_result format?? not observed yet:
# msg = {
# "role": "tool",
# "content": [
# {
# "type": "tool_result",
# "id": "call_xxx",
# "output": "tool output content",
# "name": "tool_name"
# },
# ],
# }


str_content = ""
for item in msg["content"]:
Expand Down Expand Up @@ -332,6 +319,7 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline):
)
):
logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n")
# from ajet import bp; bp("SWARM")
return


Expand All @@ -346,7 +334,9 @@ def detect_tool_call_madness(self, llm_output):
# llm_output["tool_calls"] is not None, and is not []
tool_calls = llm_output["tool_calls"]
if "wrong_toolcall" in self.config.ajet.rollout.compute_madness_checklist:
copy_tool_calls = copy.deepcopy(tool_calls)
# copy_tool_calls = copy.deepcopy(tool_calls)
# Shallow copy is sufficient - we're only reading the data
copy_tool_calls = tool_calls
wrong_toolcall = False
for i in range(len(copy_tool_calls)):
if ("function" in copy_tool_calls[i]) and (
Expand Down
206 changes: 99 additions & 107 deletions ajet/copilot/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from __future__ import annotations

import os
import time
import yaml
import tempfile
from types import SimpleNamespace
from typing import Any, Callable, Union

import yaml
from types import SimpleNamespace
from typing import Any, Callable, Union, cast
from loguru import logger


from ajet.default_config.ajet_default import Config
from ajet.utils.config_utils import (
expand_ajet_hierarchical_config,
Expand All @@ -30,70 +29,118 @@
setup_environment_vars,
)

DEFAULT_DIR = "saved_experiments"

def override_current_yaml_value_if_given(override_value, current_value):
    """Return *override_value* unless it is None, in which case keep *current_value*.

    Used to let explicit constructor arguments take precedence over values
    loaded from the YAML config.
    """
    return current_value if override_value is None else override_value

def _set_nested_attr(obj, attr_path: str, value):
keys = attr_path.split(".")
for key in keys[:-1]:
obj = getattr(obj, key)
setattr(obj, keys[-1], value)

def _get_nested_attr(obj, attr_path: str):
for key in attr_path.split("."):
obj = getattr(obj, key)
return obj

class AgentJetJob:
"""Lightweight builder that launches AgentJet training as a subprocess."""
"""
arg: base_yaml_config + **kwargs (yaml config, then override with kwargs)
arg: base_yaml_config (yaml config)
arg: **kwargs (yaml config, then override with kwargs)
"""

def __init__(
self,
backbone: str = "verl",
model: str = "Qwen/Qwen2___5-7B-Instruct",
n_gpu: int = 8,
algorithm: str = "grpo",
project_name="ajet-swarm",
experiment_name="test",
n_gpu_for_infer: int | None = None, # only for trinity backbone
num_repeat: int = 8,
batch_size: int = 32,
swarm_mode: bool = True,
sample_collection_method: str = "rollout_until_finish_enough_tasks",
*kwargs,
base_yaml_config: str | None = None,
experiment_dir: str | None = None,
project_name: str | None = None,
experiment_name: str | None = None,
n_gpu: int | None = None,
model: str | None = None,
algorithm: str | None = None,
num_repeat: int | None = None,
batch_size: int | None = None,
swarm_mode: bool | None = None,
swarm_mode_sample_collection_method: str | None = None,
max_env_worker: int | None = None,
backbone: str | None = None,
) -> None:
self.backbone = backbone
self.exp_dir = DEFAULT_DIR
self.project_name = project_name
self.exp_name = experiment_name
self.sample_collection_method = sample_collection_method
if swarm_mode:
default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))

if base_yaml_config is None:
base_yaml_config = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
else:
default_yaml = None
self.config_as_dict: dict = self.build_job_from_yaml(default_yaml)
logger.warning(f"Reading config from {base_yaml_config}.")
time.sleep(1)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The time.sleep(1) here is a code smell. Sleeps like this are often used to work around race conditions or to ensure a log message is visible before a potential crash. This can hide underlying issues and make the code's behavior dependent on timing. It would be better to identify and fix the root cause rather than using a sleep.

self.config_as_dict: dict = self.build_job_from_yaml(base_yaml_config)
self.config = Config.update_from_dict_recursive(Config(), self.config_as_dict)
Comment on lines +79 to 80

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There's a critical bug in the __init__ method's logic. The call to self.build_job_from_yaml(base_yaml_config) on line 79 happens before the instance attributes it depends on (like self.experiment_name, self.backbone, and self.experiment_dir) are initialized. These attributes are only assigned values later in the method.

This will lead to an AttributeError when build_job_from_yaml is called, as it tries to access these uninitialized attributes. The logic has a circular dependency:

  1. build_job_from_yaml is called to create self.config.
  2. build_job_from_yaml needs attributes like self.experiment_dir.
  3. The final value of self.experiment_dir is determined by the override logic (lines 112-128).
  4. The override logic needs self.config to get the base value from the YAML.

This initialization flow needs to be re-architected to resolve this circular dependency. A possible approach is to first determine the values for parameters required for loading the configuration (like experiment_name, backbone, experiment_dir) by checking the __init__ arguments, and then use them to call read_ajet_hierarchical_config. After the config is loaded, the rest of the parameters can be overridden.


self.config.ajet.experiment_name = experiment_name
self.config.ajet.backbone = backbone
self.config.ajet.model.path = model
self.config.ajet.trainer_common.n_gpus_per_node = n_gpu
self.config.ajet.trainer_common.algorithm.adv_estimator = algorithm
self.config.ajet.rollout.num_repeat = num_repeat
self.config.ajet.data.train_batch_size = batch_size
self.config.ajet.enable_swarm_mode = swarm_mode
self.config.ajet.swarm_mode_sample_collection_method = sample_collection_method
if n_gpu_for_infer is None and backbone == "trinity":
raise ValueError("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone.")
if (n_gpu_for_infer is not None) and backbone == "verl":
raise ValueError("n_gpu_for_infer is only for trinity backbone, please set it to `None`.")
else:
if backbone == "trinity":
assert isinstance(n_gpu_for_infer, int), f"`n_gpu_for_infer` should be int, got {type(n_gpu_for_infer)}."
assert n_gpu_for_infer < n_gpu, "`n_gpu_for_infer` should be less than `n_gpu`."
self.config.ajet.rollout.n_vllm_engine = n_gpu_for_infer
self.config.ajet.rollout.tensor_model_parallel_size = 1
self.base_yaml_config: str = cast(str, base_yaml_config) # currently may be None, but will be set later
self.experiment_dir: str = cast(str, experiment_dir)
self.project_name: str = cast(str, project_name)
self.experiment_name: str = cast(str, experiment_name)
self.n_gpu: int = cast(int, n_gpu)
self.model: str = cast(str, model)
self.algorithm: str = cast(str, algorithm)
self.num_repeat: int = cast(int, num_repeat)
self.batch_size: int = cast(int, batch_size)
self.swarm_mode: bool = cast(bool, swarm_mode)
self.swarm_mode_sample_collection_method: str = cast(str, swarm_mode_sample_collection_method)
self.max_env_worker: int = cast(int, max_env_worker)
self.backbone: str = cast(str, backbone)

# see `ajet/default_config/ajet_ts_default.yaml`
overrides = {
"ajet.experiment_dir": "experiment_dir",
"ajet.project_name": "project_name",
"ajet.experiment_name": "experiment_name",
"ajet.model.path": "model",
"ajet.trainer_common.n_gpus_per_node": "n_gpu",
"ajet.trainer_common.algorithm.adv_estimator": "algorithm",
"ajet.rollout.num_repeat": "num_repeat",
"ajet.data.train_batch_size": "batch_size",
"ajet.enable_swarm_mode": "swarm_mode",
"ajet.swarm_mode_sample_collection_method": "swarm_mode_sample_collection_method",
"ajet.rollout.max_env_worker": "max_env_worker",
"ajet.backbone": "backbone",
}

# if any value given in kwargs, override the corresponding value in config
for attr_path, override_val in overrides.items():
# get value from yaml config
# >> e.g. current_model = self.config.model.path
current_val = _get_nested_attr(self.config, attr_path)

# if override_val (given in __init__) is not None, use it to override the value from yaml config
# >> e.g. new_model = self.model if (self.model is not None) else current_model
new_val = override_current_yaml_value_if_given(getattr(self, override_val), current_val)

# write final value to `self.config``
# >> e.g. self.config.model.path = new_model
_set_nested_attr(self.config, attr_path, new_val)

# write final value to `self`
# >> e.g. self.model = new_model
setattr(self, override_val, new_val)

if self.backbone == "trinity":
raise NotImplementedError("Trinity backbone is not yet supported in AgentJetJob.")


def build_job_from_yaml(self, yaml_path: str | None) -> dict:
    """Load and expand the hierarchical AJET config from *yaml_path*.

    Reads the hierarchical YAML config, expands it, caches the result on
    ``self.config_as_dict`` and returns it.

    Args:
        yaml_path: Path to the base YAML config, or None to let
            ``read_ajet_hierarchical_config`` fall back to its defaults.

    Returns:
        The expanded config as a plain dict.
    """
    # BUGFIX: this method previously read ``self.exp_name`` / ``self.exp_dir``,
    # attributes that the current ``__init__`` never assigns (it sets
    # ``experiment_name`` / ``experiment_dir`` instead), which would raise
    # AttributeError. Additionally, ``__init__`` calls this method BEFORE
    # those attributes are assigned, so we read them defensively with a
    # None fallback.
    # NOTE(review): the real fix is to restructure ``__init__`` so the
    # config is loaded after the identifying parameters are resolved —
    # TODO confirm that read_ajet_hierarchical_config accepts None for
    # exp_name/exp_dir/backbone.
    self.config_as_dict = read_ajet_hierarchical_config(
        yaml_path,
        exp_name=getattr(self, "experiment_name", None),
        backbone=getattr(self, "backbone", None),
        write_to=None,
        exp_dir=getattr(self, "experiment_dir", None),
    )
    self.config_as_dict = expand_ajet_hierarchical_config(self.config_as_dict, write_to=None)
    logger.info(f"Built AgentJet job config: {yaml_path}")
    return self.config_as_dict


def dump_job_as_yaml(self, yaml_path: str) -> str:
if os.path.dirname(yaml_path):
os.makedirs(os.path.dirname(yaml_path), exist_ok=True)
Expand All @@ -102,6 +149,7 @@ def dump_job_as_yaml(self, yaml_path: str) -> str:
logger.info(f"Saved training config to {yaml_path}")
return yaml_path


def set_workflow(
self, workflow: Union[str, Callable[..., Any]], ensure_reward_in_workflow: bool = False
) -> "AgentJetJob":
Expand All @@ -110,6 +158,7 @@ def set_workflow(
# ensure_reward_in_workflow
return self


def set_data(
self,
type: str,
Expand All @@ -136,60 +185,3 @@ def set_data(

return self

def tune(self, *args, **kwargs) -> "AgentJetJob":
import ray
ast_cfg = self.config.ajet
if not ast_cfg.rollout or not ast_cfg.rollout.user_workflow:
raise ValueError("Workflow must be set via set_workflow before tuning.")
if not ast_cfg.task_reader:
raise ValueError("Data source must be set via set_data before tuning.")

backbone = self.config.ajet.backbone
exp_dir = self.config.ajet.experiment_dir

with tempfile.NamedTemporaryFile(mode="w+", delete=False, suffix=".yaml") as temp_yaml:
yaml_path = temp_yaml.name
self.dump_job_as_yaml(yaml_path)
args = SimpleNamespace(
conf=yaml_path,
backbone=backbone,
exp_dir=exp_dir,
with_logview=False,
debug=False,
)

if args.backbone != "debug":
# Enforce GPU availability and free memory threshold before proceeding
check_avail_gpu(min_free_ratio=0.95)

# finalize experiment config
main_yaml_fp, exe_exp_base, exp_name, exp_config = prepare_experiment_config(
yaml_path, exp_dir, backbone
)

# setup environment variables for ray
env = setup_environment_vars(args, exp_config, main_yaml_fp)

# start ray if not already started
if not ray.is_initialized():
from ajet.utils.launch_utils import start_ray_service

start_ray_service(args, env)
else:
raise RuntimeError(
"Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job."
)

# start training process
if args.conf and main_yaml_fp and exe_exp_base and exp_config:
execute_training_process(
args,
get_backbone_target(args.backbone),
main_yaml_fp,
exe_exp_base,
main_yaml_fp,
env,
exp_config,
)

return self
Loading
Loading