diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index 3ccb526b75..c285cb6306 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -1,6 +1,6 @@ services: trinity-node-1: - image: trinity-rft-megatron:latest-unittest + image: trinity-rft-unittest:20250918 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block" environment: @@ -28,7 +28,7 @@ services: capabilities: [gpu] trinity-node-2: - image: trinity-rft-megatron:latest-unittest + image: trinity-rft-unittest:20250918 pull_policy: never command: sh -c "pip install -e .[dev] && ray start --address=trinity-node-1:6379 --block" environment: diff --git a/benchmark/config/countdown-template.yaml b/benchmark/config/countdown-template.yaml index 9cd84c7cc8..ce12745766 100644 --- a/benchmark/config/countdown-template.yaml +++ b/benchmark/config/countdown-template.yaml @@ -40,7 +40,6 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - path: '' use_priority_queue: true replay_buffer_kwargs: priority_fn: linear_decay @@ -133,7 +132,6 @@ trainer: default_hdfs_dir: null remove_previous_ckpt_in_save: false del_local_ckpt_after_load: false - val_before_train: false max_actor_ckpt_to_keep: null max_critic_ckpt_to_keep: null critic: diff --git a/benchmark/config/gsm8k-template.yaml b/benchmark/config/gsm8k-template.yaml index 2a9127a237..f27ca38201 100644 --- a/benchmark/config/gsm8k-template.yaml +++ b/benchmark/config/gsm8k-template.yaml @@ -45,7 +45,6 @@ buffer: experience_buffer: name: experience_buffer storage_type: queue - path: '' use_priority_queue: true replay_buffer_kwargs: priority_fn: linear_decay @@ -131,7 +130,6 @@ trainer: default_hdfs_dir: null remove_previous_ckpt_in_save: false del_local_ckpt_after_load: false - val_before_train: false max_actor_ckpt_to_keep: null max_critic_ckpt_to_keep: null monitor: diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 4b324ba9d4..88579f2fcc 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -617,7 +617,6 @@ trainer: default_hdfs_dir: null remove_previous_ckpt_in_save: False del_local_ckpt_after_load: False - val_before_train: False max_actor_ckpt_to_keep: 5 max_critic_ckpt_to_keep: 5 ``` diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 8219ba88c4..b7151020de 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -616,7 +616,6 @@ trainer: default_hdfs_dir: null remove_previous_ckpt_in_save: False del_local_ckpt_after_load: False - val_before_train: False max_actor_ckpt_to_keep: 5 max_critic_ckpt_to_keep: 5 ``` diff --git a/pyproject.toml b/pyproject.toml index dbcd513e35..451be9db4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ agent = [ "agentscope" ] rm_gallery = [ - "rm-gallery>=0.1.1" + "rm-gallery>=0.1.5" ] dev = [ "pre-commit>=2.17.0", @@ -123,6 +123,3 @@ known_third_party = ["wandb"] [project.urls] "Homepage" = "https://github.com/modelscope/Trinity-RFT" "Documentation" = "https://modelscope.github.io/Trinity-RFT/" - -[tool.uv] -override-dependencies=["math_verify>=0.8.0"] # rm-gallery requires math_verify<0.8.0 which is not compatible with trinity-rft diff --git 
a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile index 153e7f1f9b..dfb6854240 100644 --- a/scripts/docker/Dockerfile +++ b/scripts/docker/Dockerfile @@ -15,7 +15,7 @@ RUN apt update && apt install -y \ python3 python3-pip python3-dev python3-packaging \ libomp-dev infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \ && rm -rf /var/lib/apt/lists/* \ - && ln -sf /usr/bin/python3 /usr/bin/python + && ln -sf /usr/bin/python3 /usr/bin/python \ && ln -sf /usr/bin/pip3 /usr/bin/pip diff --git a/scripts/docker_for_megatron/Dockerfile b/scripts/docker_for_megatron/Dockerfile index 80c083a835..ef0c5e2137 100644 --- a/scripts/docker_for_megatron/Dockerfile +++ b/scripts/docker_for_megatron/Dockerfile @@ -18,7 +18,7 @@ RUN apt update && apt install -y \ python3 python3-pip python3-dev python3-packaging \ libomp-dev infiniband-diags libibverbs-dev librdmacm-dev rdma-core perftest \ && rm -rf /var/lib/apt/lists/* \ - && ln -sf /usr/bin/python3 /usr/bin/python + && ln -sf /usr/bin/python3 /usr/bin/python \ && ln -sf /usr/bin/pip3 /usr/bin/pip # For Aliyun users: update pip mirror to aliyun to speed up pip install diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 40aa80d606..fb9cdf670d 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -121,7 +121,7 @@ def setUp(self): pprint(self.config) self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=self.enable_history + self.engines[0], engine_type="vllm", enable_history=self.enable_history ) async def test_generate( @@ -240,7 +240,7 @@ def setUp(self): self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper(self.engines[0], model_type="vllm", enable_history=True) + self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) def test_model_len(self): messages = [ @@ -277,7 +277,7 @@ def setUp(self): self.config = get_template_config() self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_type = "vllm" self.config.explorer.rollout_model.engine_num = 1 self.config.explorer.rollout_model.tensor_parallel_size = 1 self.config.explorer.rollout_model.use_v1 = True @@ -286,11 +286,9 @@ def setUp(self): self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=True - ) + self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) self.model_wrapper_no_history = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=False + self.engines[0], engine_type="vllm", enable_history=False ) def test_api(self): @@ -348,7 +346,7 @@ def setUp(self): self.config = get_template_config() self.config.mode = "explore" self.config.model.model_path = get_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_type = "vllm" self.config.explorer.rollout_model.engine_num = 1 self.config.explorer.rollout_model.tensor_parallel_size = 1 self.config.explorer.rollout_model.use_v1 = True @@ -357,11 +355,9 @@ def setUp(self): self.config.check_and_update() self.engines, self.auxiliary_engines = 
create_inference_models(self.config) - self.model_wrapper = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=True - ) + self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) self.model_wrapper_no_history = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=False + self.engines[0], engine_type="vllm", enable_history=False ) async def test_api_async(self): @@ -537,7 +533,7 @@ def setUp(self): self.config = get_template_config() self.config.mode = "explore" self.config.model.model_path = get_api_model_path() - self.config.explorer.rollout_model.engine_type = "vllm_async" + self.config.explorer.rollout_model.engine_type = "vllm" self.config.explorer.rollout_model.engine_num = 1 self.config.explorer.rollout_model.tensor_parallel_size = 1 self.config.explorer.rollout_model.use_v1 = True @@ -551,11 +547,9 @@ def setUp(self): self.config.check_and_update() self.engines, self.auxiliary_engines = create_inference_models(self.config) - self.model_wrapper = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=True - ) + self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) self.model_wrapper_no_history = ModelWrapper( - self.engines[0], model_type="vllm_async", enable_history=False + self.engines[0], engine_type="vllm", enable_history=False ) def test_api_tool_calls(self): diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml index 1027ef31d0..d46fd20499 100644 --- a/tests/template/verl_config.yaml +++ b/tests/template/verl_config.yaml @@ -92,6 +92,5 @@ trainer: default_hdfs_dir: null remove_previous_ckpt_in_save: False del_local_ckpt_after_load: False - val_before_train: False max_actor_ckpt_to_keep: 1 max_critic_ckpt_to_keep: 1 diff --git a/trinity/buffer/storage/file.py b/trinity/buffer/storage/file.py index 6f393cc48f..9af8fc5520 100644 --- a/trinity/buffer/storage/file.py +++ b/trinity/buffer/storage/file.py @@ -34,7 +34,7 @@ class FileStorage: """ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - if storage_config.path is None: + if not storage_config.path: storage_config.path = default_storage_path(storage_config, config) ext = os.path.splitext(storage_config.path)[-1] if ext != ".jsonl" and ext != ".json": diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 82f6bf2279..11251e7462 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -219,7 +219,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.queue = QueueBuffer.get_queue(storage_config, config) st_config = deepcopy(storage_config) st_config.wrap_in_ray = False - if st_config.path is not None: + if st_config.path: if is_database_url(st_config.path): from trinity.buffer.writer.sql_writer import SQLWriter diff --git a/trinity/common/config.py b/trinity/common/config.py index 537d371bf6..ac2764506f 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -26,6 +26,11 @@ logger = get_logger(__name__) +def set_if_none(obj, attr, val): + if getattr(obj, attr, None) is None: + setattr(obj, attr, val) + + @dataclass class FormatConfig: """Configuration for data formatting""" @@ -70,6 +75,7 @@ class GenerationConfig: top_p: float = 1.0 top_k: int = -1 logprobs: int = 0 # vLLM return `logprobs + 1` elements + max_tokens: Optional[int] = None # if None, use model.max_response_tokens # repeat each task for `n` times # ! 
DO NOT SET in `buffer.explorer_input.taskset.rollout_args` n: int = 1 @@ -583,112 +589,84 @@ def _check_interval(self) -> None: def _check_buffer(self) -> None: # noqa: C901 # TODO: split this function into different buffer read/writer # check explorer_input - if self.mode != "train" and not self.buffer.explorer_input.taskset.path: + trainer_input = self.buffer.trainer_input + experience_buffer = trainer_input.experience_buffer + explorer_input = self.buffer.explorer_input + taskset = explorer_input.taskset + + if self.mode != "train" and not taskset.path: raise ValueError( "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset." ) - if not self.buffer.explorer_input.taskset.name: - self.buffer.explorer_input.taskset.name = "taskset" - if ( - self.buffer.explorer_input.taskset.repeat_times is None - or self.buffer.explorer_input.taskset.repeat_times != self.algorithm.repeat_times - ): - self.buffer.explorer_input.taskset.repeat_times = self.algorithm.repeat_times + if not taskset.name: + taskset.name = "taskset" + if taskset.repeat_times is None or taskset.repeat_times != self.algorithm.repeat_times: + taskset.repeat_times = self.algorithm.repeat_times logger.info( "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`" f" (={self.algorithm.repeat_times})." ) if self.mode == "train": assert ( - self.buffer.trainer_input.experience_buffer is not None + experience_buffer is not None ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`." - self.buffer.trainer_input.experience_buffer.total_epochs = self.buffer.total_epochs - self.buffer.trainer_input.experience_buffer.total_steps = self.buffer.total_steps + experience_buffer.total_epochs = self.buffer.total_epochs + experience_buffer.total_steps = self.buffer.total_steps else: - self.buffer.explorer_input.taskset.is_eval = False - self.buffer.explorer_input.taskset.total_epochs = self.buffer.total_epochs - self.buffer.explorer_input.taskset.total_steps = self.buffer.total_steps - if self.buffer.explorer_input.taskset.default_workflow_type is None: - self.buffer.explorer_input.taskset.default_workflow_type = ( - self.buffer.explorer_input.default_workflow_type - ) - if self.buffer.explorer_input.taskset.default_eval_workflow_type is None: - self.buffer.explorer_input.taskset.default_eval_workflow_type = ( - self.buffer.explorer_input.default_eval_workflow_type - ) - if self.buffer.explorer_input.taskset.default_reward_fn_type is None: - self.buffer.explorer_input.taskset.default_reward_fn_type = ( - self.buffer.explorer_input.default_reward_fn_type - ) - if self.buffer.explorer_input.taskset.format.system_prompt is None: - self.buffer.explorer_input.taskset.format.system_prompt = ( - self.buffer.explorer_input.system_prompt - ) - if self.buffer.explorer_input.taskset.format.reply_prefix is None: - self.buffer.explorer_input.taskset.format.reply_prefix = ( - self.buffer.explorer_input.reply_prefix - ) - if self.buffer.explorer_input.taskset.ray_namespace is None: - self.buffer.explorer_input.taskset.ray_namespace = self.ray_namespace + taskset.is_eval = False + taskset.total_epochs = self.buffer.total_epochs + taskset.total_steps = self.buffer.total_steps + + set_if_none(taskset, "default_workflow_type", explorer_input.default_workflow_type) + set_if_none( + taskset, "default_eval_workflow_type", explorer_input.default_eval_workflow_type + ) + set_if_none(taskset, "default_reward_fn_type", explorer_input.default_reward_fn_type) + set_if_none(taskset.format, 
"system_prompt", explorer_input.system_prompt) + set_if_none(taskset.format, "reply_prefix", explorer_input.reply_prefix) + set_if_none(taskset, "ray_namespace", self.ray_namespace) + set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens) remained_tasksets = [] - for idx, dataset in enumerate(self.buffer.explorer_input.eval_tasksets): + for idx, dataset in enumerate(explorer_input.eval_tasksets): if not dataset.path: logger.warning(f"Eval dataset [{dataset}]'s path is not configured. Skip.") continue dataset.is_eval = True if not dataset.name: dataset.name = f"eval_taskset_{idx}" - if dataset.repeat_times is None: - dataset.repeat_times = 1 - if dataset.default_workflow_type is None: - dataset.default_workflow_type = self.buffer.explorer_input.default_workflow_type - if dataset.default_eval_workflow_type is None: - dataset.default_eval_workflow_type = ( - self.buffer.explorer_input.default_eval_workflow_type - ) - if dataset.default_reward_fn_type is None: - dataset.default_reward_fn_type = self.buffer.explorer_input.default_reward_fn_type - if dataset.format.system_prompt is None: - dataset.format.system_prompt = self.buffer.explorer_input.system_prompt - if dataset.format.reply_prefix is None: - dataset.format.reply_prefix = self.buffer.explorer_input.reply_prefix - if dataset.ray_namespace is None: - dataset.ray_namespace = self.ray_namespace + set_if_none(dataset, "repeat_times", 1) + set_if_none(dataset, "default_workflow_type", explorer_input.default_workflow_type) + set_if_none( + dataset, "default_eval_workflow_type", explorer_input.default_eval_workflow_type + ) + set_if_none(dataset, "default_reward_fn_type", explorer_input.default_reward_fn_type) + set_if_none(dataset.format, "system_prompt", explorer_input.system_prompt) + set_if_none(dataset.format, "reply_prefix", explorer_input.reply_prefix) + set_if_none(dataset, "ray_namespace", self.ray_namespace) + set_if_none(dataset.rollout_args, "max_tokens", self.model.max_response_tokens) remained_tasksets.append(dataset) - self.buffer.explorer_input.eval_tasksets = remained_tasksets + explorer_input.eval_tasksets = remained_tasksets # check trainer_input.experience_buffer - if self.buffer.trainer_input.experience_buffer is None: - self.buffer.trainer_input.experience_buffer = StorageConfig( + if experience_buffer is None: + experience_buffer = trainer_input.experience_buffer = StorageConfig( name="experience_buffer", storage_type=StorageType.QUEUE, ) - logger.info( - f"Auto set `buffer.trainer_input.experience_buffer` to {self.buffer.trainer_input.experience_buffer}" - ) - elif ( - self.buffer.trainer_input.experience_buffer.storage_type is StorageType.FILE - and self.mode == "both" - ): + logger.info(f"Auto set `buffer.trainer_input.experience_buffer` to {experience_buffer}") + elif experience_buffer.storage_type is StorageType.FILE and self.mode == "both": logger.warning( "`FILE` storage is not supported to use as experience_buffer in `both` mode, use `QUEUE` instead." 
) - self.buffer.trainer_input.experience_buffer.storage_type = StorageType.QUEUE + experience_buffer.storage_type = StorageType.QUEUE from trinity.algorithm.algorithm import ALGORITHM_TYPE - self.buffer.trainer_input.experience_buffer.schema_type = ALGORITHM_TYPE.get( - self.algorithm.algorithm_type - ).schema - - if self.buffer.trainer_input.experience_buffer.ray_namespace is None: - self.buffer.trainer_input.experience_buffer.ray_namespace = self.ray_namespace + experience_buffer.schema_type = ALGORITHM_TYPE.get(self.algorithm.algorithm_type).schema - if self.buffer.trainer_input.experience_buffer.format.chat_template is None: - self.buffer.trainer_input.experience_buffer.format.chat_template = ( - self.model.custom_chat_template - ) + set_if_none(experience_buffer, "ray_namespace", self.ray_namespace) + set_if_none(experience_buffer.format, "chat_template", self.model.custom_chat_template) # create buffer.cache_dir at ///buffer self.buffer.cache_dir = os.path.abspath(os.path.join(self.checkpoint_job_dir, "buffer")) @@ -701,39 +679,35 @@ def _check_buffer(self) -> None: # noqa: C901 ) # check input/output buffers in pipelines - if self.data_processor.experience_pipeline is not None: - if ( - self.data_processor.experience_pipeline.save_input - and self.data_processor.experience_pipeline.input_save_path is None - ): - self.data_processor.experience_pipeline.input_save_path = os.path.join( + experience_pipeline = self.data_processor.experience_pipeline + if experience_pipeline is not None: + if experience_pipeline.save_input and experience_pipeline.input_save_path is None: + experience_pipeline.input_save_path = os.path.join( self.buffer.cache_dir, "explorer_output.jsonl" ) logger.info( - f"Auto set `data_processor.experience_pipeline.input_save_path` to {self.data_processor.experience_pipeline.input_save_path}" + f"Auto set `data_processor.experience_pipeline.input_save_path` to {experience_pipeline.input_save_path}" ) - if self.data_processor.task_pipeline is not None: - if self.data_processor.task_pipeline.output is None: - if self.buffer.explorer_input.taskset.path is not None: - self.data_processor.task_pipeline.output = self.buffer.explorer_input.taskset + + task_pipeline = self.data_processor.task_pipeline + if task_pipeline is not None: + if task_pipeline.output is None: + if taskset.path is not None: + task_pipeline.output = taskset elif ( - self.buffer.trainer_input.experience_buffer.schema_type in {"dpo", "sft"} - and self.buffer.trainer_input.experience_buffer.path is not None + experience_buffer.schema_type in {"dpo", "sft"} + and experience_buffer.path is not None ): - self.data_processor.task_pipeline.output = ( - self.buffer.trainer_input.experience_buffer - ) + task_pipeline.output = experience_buffer else: raise ValueError( "`data_processor.task_pipeline.output` is required when both " "`buffer.explorer_input.taskset.path` and `buffer.trainer_input.experience_buffer.path` are " "None" ) - if self.data_processor.task_pipeline.output.path and os.path.exists( - self.data_processor.task_pipeline.output.path - ): + if task_pipeline.output.path and os.path.exists(task_pipeline.output.path): raise ValueError( - f"Task pipeline output path {self.data_processor.task_pipeline.output.path} already exists.\n" + f"Task pipeline output path {task_pipeline.output.path} already exists.\n" "Please choose a different output path to avoid overwriting." 
) @@ -790,15 +764,13 @@ def _check_algorithm(self) -> None: } default_config.update(algorithm.default_config()) for key, value in default_config.items(): - if getattr(self.algorithm, key, None) is None: - setattr(self.algorithm, key, value) + set_if_none(self.algorithm, key, value) def check_and_set(name, registry, args_attr): fn_cls = registry.get(getattr(self.algorithm, name)) if fn_cls is None: raise ValueError(f"Invalid {name}: {getattr(self.algorithm, name)}") - if getattr(self.algorithm, args_attr) is None: - setattr(self.algorithm, args_attr, fn_cls.default_args()) + set_if_none(self.algorithm, args_attr, fn_cls.default_args()) return fn_cls check_and_set("sample_strategy", SAMPLE_STRATEGY, "sample_strategy_args") @@ -809,44 +781,37 @@ def check_and_set(name, registry, args_attr): check_and_set("entropy_loss_fn", ENTROPY_LOSS_FN, "entropy_loss_fn_args") def _check_model(self) -> None: - if not self.model.critic_model_path: - self.model.critic_model_path = self.model.model_path + model = self.model + if not model.critic_model_path: + model.critic_model_path = model.model_path # check max_model_len, max_prompt_tokens, max_response_tokens # if all three are set, check if they are valid if ( - self.model.max_model_len is not None - and self.model.max_prompt_tokens is not None - and self.model.max_response_tokens is not None + model.max_model_len is not None + and model.max_prompt_tokens is not None + and model.max_response_tokens is not None ): - if ( - self.model.max_prompt_tokens + self.model.max_response_tokens - > self.model.max_model_len - ): + if model.max_prompt_tokens + model.max_response_tokens > model.max_model_len: raise ValueError( - f"`max_prompt_tokens` + `max_response_tokens` ({self.model.max_prompt_tokens} + {self.model.max_response_tokens}) " - f"exceeds `max_model_len` ({self.model.max_model_len}). Please adjust them accordingly." + f"`max_prompt_tokens` + `max_response_tokens` ({model.max_prompt_tokens} + {model.max_response_tokens}) " + f"exceeds `max_model_len` ({model.max_model_len}). Please adjust them accordingly." ) # check max_model_len first - if self.model.max_model_len is None: - if ( - self.model.max_prompt_tokens is not None - and self.model.max_response_tokens is not None - ): - self.model.max_model_len = ( - self.model.max_prompt_tokens + self.model.max_response_tokens - ) + if model.max_model_len is None: + if model.max_prompt_tokens is not None and model.max_response_tokens is not None: + model.max_model_len = model.max_prompt_tokens + model.max_response_tokens logger.warning( - f"`max_model_len` is set to {self.model.max_model_len} from `max_prompt_tokens` and `max_response_tokens`." + f"`max_model_len` is set to {model.max_model_len} from `max_prompt_tokens` and `max_response_tokens`." ) else: from transformers import AutoConfig, AutoTokenizer from transformers.tokenization_utils_base import LARGE_INTEGER - tokenizer = AutoTokenizer.from_pretrained(self.model.model_path) - config = AutoConfig.from_pretrained(self.model.model_path) + tokenizer = AutoTokenizer.from_pretrained(model.model_path) + config = AutoConfig.from_pretrained(model.model_path) max_model_len = min( getattr(tokenizer, "model_max_length", LARGE_INTEGER), getattr(config, "max_position_embeddings", LARGE_INTEGER), @@ -854,42 +819,38 @@ def _check_model(self) -> None: if max_model_len >= LARGE_INTEGER: max_model_len = MAX_MODEL_LEN logger.warning( - f"Failed to get `max_model_len` from model {self.model.model_path}, use {MAX_MODEL_LEN} instead." 
+ f"Failed to get `max_model_len` from model {model.model_path}, use {MAX_MODEL_LEN} instead." ) - self.model.max_model_len = max_model_len + model.max_model_len = max_model_len # both max_prompt_tokens and max_response_tokens are None - if self.model.max_prompt_tokens is None and self.model.max_response_tokens is None: + if model.max_prompt_tokens is None and model.max_response_tokens is None: # default to max_model_len / 2 - self.model.max_prompt_tokens = self.model.max_model_len // 2 - self.model.max_response_tokens = self.model.max_model_len - self.model.max_prompt_tokens + model.max_prompt_tokens = model.max_model_len // 2 + model.max_response_tokens = model.max_model_len - model.max_prompt_tokens logger.warning( - f"`max_prompt_tokens` and `max_response_tokens` are not set, set to {self.model.max_prompt_tokens} and {self.model.max_response_tokens} respectively." + f"`max_prompt_tokens` and `max_response_tokens` are not set, set to {model.max_prompt_tokens} and {model.max_response_tokens} respectively." ) # only max_prompt_tokens is None - if self.model.max_prompt_tokens is None and self.model.max_response_tokens is not None: - self.model.max_response_tokens = min( - self.model.max_response_tokens, self.model.max_model_len - 1 - ) - self.model.max_prompt_tokens = self.model.max_model_len - self.model.max_response_tokens + if model.max_prompt_tokens is None and model.max_response_tokens is not None: + model.max_response_tokens = min(model.max_response_tokens, model.max_model_len - 1) + model.max_prompt_tokens = model.max_model_len - model.max_response_tokens logger.warning( - f"`max_prompt_tokens` is set to {self.model.max_prompt_tokens}, `max_response_tokens` is set to {self.model.max_response_tokens}." + f"`max_prompt_tokens` is set to {model.max_prompt_tokens}, `max_response_tokens` is set to {model.max_response_tokens}." ) # only max_response_tokens is None - if self.model.max_response_tokens is None and self.model.max_prompt_tokens is not None: - self.model.max_prompt_tokens = min( - self.model.max_prompt_tokens, self.model.max_model_len - 1 - ) - self.model.max_response_tokens = self.model.max_model_len - self.model.max_prompt_tokens + if model.max_response_tokens is None and model.max_prompt_tokens is not None: + model.max_prompt_tokens = min(model.max_prompt_tokens, model.max_model_len - 1) + model.max_response_tokens = model.max_model_len - model.max_prompt_tokens logger.warning( - f"`max_response_tokens` is set to {self.model.max_response_tokens}, `max_prompt_tokens` is set to {self.model.max_prompt_tokens}." + f"`max_response_tokens` is set to {model.max_response_tokens}, `max_prompt_tokens` is set to {model.max_prompt_tokens}." ) - if self.model.min_response_tokens >= self.model.max_response_tokens: # type: ignore [operator] - self.model.min_response_tokens = max(self.model.max_response_tokens - 1, 0) # type: ignore [operator] - logger.warning(f"`min_response_tokens` is set to {self.model.min_response_tokens}.") + if model.min_response_tokens >= model.max_response_tokens: # type: ignore [operator] + model.min_response_tokens = max(model.max_response_tokens - 1, 0) # type: ignore [operator] + logger.warning(f"`min_response_tokens` is set to {model.min_response_tokens}.") def __iter__(self): """Iterate over configs with each stage applied in order. 
@@ -954,14 +915,10 @@ def check_and_update(self) -> Config: # noqa: C901 for aux_model in self.explorer.auxiliary_models: if not aux_model.model_path: raise ValueError("auxiliary model's model_path is required.") - if aux_model.max_model_len is None: - aux_model.max_model_len = self.model.max_model_len - if aux_model.max_prompt_tokens is None: - aux_model.max_prompt_tokens = self.model.max_prompt_tokens - if aux_model.max_response_tokens is None: - aux_model.max_response_tokens = self.model.max_response_tokens - if aux_model.min_response_tokens is None: - aux_model.min_response_tokens = self.model.min_response_tokens + set_if_none(aux_model, "max_model_len", self.model.max_model_len) + set_if_none(aux_model, "max_prompt_tokens", self.model.max_prompt_tokens) + set_if_none(aux_model, "max_response_tokens", self.model.max_response_tokens) + set_if_none(aux_model, "min_response_tokens", self.model.min_response_tokens) # check synchronizer self.synchronizer.ray_namespace = self.ray_namespace @@ -986,8 +943,7 @@ def check_and_update(self) -> Config: # noqa: C901 monitor_cls = MONITOR.get(self.monitor.monitor_type) if monitor_cls is None: raise ValueError(f"Invalid monitor type: {self.monitor.monitor_type}") - if self.monitor.monitor_args is None: - self.monitor.monitor_args = monitor_cls.default_args() + set_if_none(self.monitor, "monitor_args", monitor_cls.default_args()) # create a job dir in ///monitor self.monitor.cache_dir = os.path.join(self.checkpoint_job_dir, "monitor") try: diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index be0a5ed2fe..35c4ec8b2a 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -68,8 +68,8 @@ def sync_wrapper(self, *args, **kwargs): class ModelWrapper: """A wrapper for the InferenceModel Ray Actor""" - def __init__(self, model: Any, model_type: str = "vllm", enable_history: bool = False): - assert model_type.startswith("vllm"), "Only vLLM model is supported for now." + def __init__(self, model: Any, engine_type: str = "vllm", enable_history: bool = False): + assert engine_type.startswith("vllm"), "Only vLLM model is supported for now." self.model = model self.api_address: str = None self.openai_client: openai.OpenAI = None diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 3ccc85c1b3..a40df7aca5 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -580,6 +580,8 @@ def _gen_buffer_config(self): }, }, } + if not experience_buffer_path: + del buffer_config["trainer_input"]["experience_buffer"]["path"] if st.session_state["train_batch_size"] is None: del buffer_config["train_batch_size"] if st.session_state["algorithm_type"] != "dpo":
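
A few notes on the changes above, with minimal Python sketches. The central refactor in trinity/common/config.py replaces long runs of `if getattr(...) is None: setattr(...)` with the small `set_if_none` helper, and uses it to propagate the new `GenerationConfig.max_tokens` field from `model.max_response_tokens`. The sketch below isolates that pattern with toy stand-in dataclasses (illustrative names only, not the real Trinity-RFT config classes): only attributes that are currently `None` receive the default, while values the user already set, including falsy ones such as `0` or `1.0`, are preserved.

```python
from dataclasses import dataclass, field
from typing import Optional


def set_if_none(obj, attr, val):
    # Same helper as introduced in trinity/common/config.py: fill the attribute
    # only when it is missing or explicitly None; keep any value already set.
    if getattr(obj, attr, None) is None:
        setattr(obj, attr, val)


# Toy stand-ins for the real config dataclasses (names are illustrative only).
@dataclass
class RolloutArgs:
    max_tokens: Optional[int] = None  # mirrors the new GenerationConfig.max_tokens
    temperature: float = 1.0


@dataclass
class TasksetConfig:
    ray_namespace: Optional[str] = None
    rollout_args: RolloutArgs = field(default_factory=RolloutArgs)


taskset = TasksetConfig()
set_if_none(taskset, "ray_namespace", "trinity")
set_if_none(taskset.rollout_args, "max_tokens", 1024)   # None -> falls back to default
set_if_none(taskset.rollout_args, "temperature", 0.5)   # no-op: already set

assert taskset.ray_namespace == "trinity"
assert taskset.rollout_args.max_tokens == 1024
assert taskset.rollout_args.temperature == 1.0
```

This is the same fallback semantics as the call `set_if_none(taskset.rollout_args, "max_tokens", self.model.max_response_tokens)` added in `_check_buffer`.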
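
The storage changes in trinity/buffer/storage/file.py and queue.py switch from `path is None` to a truthiness check, so an empty string now also triggers the auto-generated default path, which is presumably why the `path: ''` entries could be dropped from the benchmark templates. A minimal sketch of the new behaviour, using a stand-in `default_storage_path` (the real helper derives the path from the buffer cache dir and storage name):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StorageConfig:  # illustrative subset of the real dataclass
    name: str = "experience_buffer"
    path: Optional[str] = None


def default_storage_path(storage_config: StorageConfig) -> str:
    # Stand-in for Trinity's default_storage_path() helper.
    return f"/tmp/buffer/{storage_config.name}.jsonl"


def resolve_path(storage_config: StorageConfig) -> str:
    # New behaviour: `not path` treats both None and "" as "unset",
    # so `path: ''` in a YAML template no longer bypasses the default.
    if not storage_config.path:
        storage_config.path = default_storage_path(storage_config)
    return storage_config.path


assert resolve_path(StorageConfig(path=None)) == "/tmp/buffer/experience_buffer.jsonl"
assert resolve_path(StorageConfig(path="")) == "/tmp/buffer/experience_buffer.jsonl"
assert resolve_path(StorageConfig(path="data/buf.jsonl")) == "data/buf.jsonl"
```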
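
The refactored `_check_model` keeps the same token-budget rules while reading much shorter. The sketch below restates that derivation as a standalone function for reference; it assumes `max_model_len` is already known and ignores the branch that derives it from the tokenizer/config, so it is a simplified mirror rather than the actual method:

```python
def resolve_token_budget(max_model_len, max_prompt_tokens=None, max_response_tokens=None):
    # Simplified mirror of Config._check_model: validate when both limits are
    # given, otherwise derive the missing one(s) from max_model_len.
    if max_prompt_tokens is not None and max_response_tokens is not None:
        if max_prompt_tokens + max_response_tokens > max_model_len:
            raise ValueError("max_prompt_tokens + max_response_tokens exceeds max_model_len")
    if max_prompt_tokens is None and max_response_tokens is None:
        max_prompt_tokens = max_model_len // 2
        max_response_tokens = max_model_len - max_prompt_tokens
    elif max_prompt_tokens is None:
        max_response_tokens = min(max_response_tokens, max_model_len - 1)
        max_prompt_tokens = max_model_len - max_response_tokens
    elif max_response_tokens is None:
        max_prompt_tokens = min(max_prompt_tokens, max_model_len - 1)
        max_response_tokens = max_model_len - max_prompt_tokens
    return max_prompt_tokens, max_response_tokens


assert resolve_token_budget(4096) == (2048, 2048)
assert resolve_token_budget(4096, max_prompt_tokens=1024) == (1024, 3072)
assert resolve_token_budget(4096, max_response_tokens=3000) == (1096, 3000)
```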
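
Finally, `ModelWrapper` now takes `engine_type` instead of `model_type`, and the tests pass `"vllm"` rather than `"vllm_async"`. The trimmed stand-in below mirrors only the new signature and its assertion from trinity/common/models/model.py (the real class also wires up the API address, OpenAI client, and history), to show how call sites are updated:

```python
from typing import Any


class ModelWrapper:
    # Trimmed stand-in mirroring the renamed parameter; not the full implementation.
    def __init__(self, model: Any, engine_type: str = "vllm", enable_history: bool = False):
        assert engine_type.startswith("vllm"), "Only vLLM model is supported for now."
        self.model = model
        self.engine_type = engine_type
        self.enable_history = enable_history


# Old call sites: ModelWrapper(engine, model_type="vllm_async", enable_history=True)
# New call sites, as updated in tests/common/vllm_test.py:
wrapper = ModelWrapper(model=object(), engine_type="vllm", enable_history=True)
```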