diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index ed8f2fd635..58ed9c6035 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -12,7 +12,7 @@ Below is a table summarizing the modules and components that developers with dif
 | Enhance the RL process from the data perspective. | *Buffer* | `Operator` |

 ```{note}
-Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
+Trinity-RFT is under active development, and the following interfaces may change. Please refer to the latest code when using this guide.
 ```

 ---
@@ -216,12 +216,6 @@ __all__ = [
 ]
 ```

-For workflows that are not intended to be contributed to Trinity-RFT project, you can just place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder.
-
-```{tip}
-You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `/trinity/plugins` as the default directory.
-```
-
 #### Avoid Re-initialization

 For heavy workflows, re-initializing every time can incurs extra computational costs.
@@ -728,21 +722,20 @@ By meticulously following these steps, you can ensure that new parameters are su

 ---

-## Check Code Style
+## Contributing Your Code

-Before submitting the code, make sure it passes the code style check. Follow these steps:
+For modules that you intend to contribute to the Trinity-RFT project, please follow the steps below:

-```shell
-# Install code style checking tools
-cd
-# bash
-pip install -e .[dev]
-# zsh
-# pip install -e .\[dev\]
+1. Implement your code in the appropriate directory, such as `trinity/common/workflows` for workflows, `trinity/algorithm` for algorithms, or `trinity/buffer/operators` for operators.
-# Run code style checks
-pre-commit run --all-files
+2. Register your module in the corresponding `__init__.py` file of that directory.
-# Commit the code after all checks pass
-git commit -am "create example workflow"
+3. Add tests for your module in the `tests` directory, following the naming conventions and structure of existing tests.
+
+4. Before submitting the code, make sure it passes the code style check with `pre-commit run --all-files`.
+
+5. Submit a pull request to the Trinity-RFT repository, including a clear description of your changes.
+
+```{tip}
+For modules that are only used for local testing or are not intended for contribution, you can place them in the `trinity/plugins` directory. Trinity-RFT will automatically load all modules in this directory, and you can use them without adding them to any `__init__.py` file. You can specify another directory by setting the `--plugin-dir` option when running Trinity-RFT, e.g., `trinity run --config /path/to/your/config --plugin-dir /path/to/your/plugins`.
+```
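For convenience, the commands referenced in step 4 and in the tip of the revised guide above can be run directly from the repository root; paths below are placeholders:

```shell
# Step 4: run the style checks locally before opening a pull request
pre-commit run --all-files

# Local-only modules: keep them under trinity/plugins, or point Trinity-RFT at another directory
trinity run --config /path/to/your/config --plugin-dir /path/to/your/plugins
```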
diff --git a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
index f2204c7ac9..f63ea94145 100644
--- a/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
+++ b/examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
@@ -5,27 +5,19 @@ algorithm:
   algorithm_type: grpo
   repeat_times: 8
 data_processor:
-  data_processor_url: 'http://127.0.0.1:5005/data_processor'
   # task pipeline related
   task_pipeline:
-    # I/O buffers
-    input_buffers:
-      - name: 'raw_input'
-        path: 'openai/gsm8k'
-        storage_type: 'file'
-        raw: true
-    output_buffer:
-      name: 'raw_output'
-      path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl'
-      storage_type: 'file'
-    # format mapping
-    format:
-      prompt_key: 'question'
-      response_key: 'answer'
-    # data active iterator related
-    dj_process_desc: 'Please compute difficulty scores for these math questions.'
-    agent_model_name: 'qwen-max'
-    clean_strategy: 'iterative'
+    operators:
+      - name: "difficulty_score_filter"
+        args:
+          api_or_hf_model: "qwen2.5-7b-instruct"
+          min_score: 0.0
+          text_key: "question"
+    inputs:
+      - /PATH/TO/GSM8K/JSONL
+service:
+  data_juicer:
+    auto_start: true
 model:
   model_path: /PATH/TO/MODEL/
   max_response_tokens: 1024
@@ -64,11 +56,6 @@ buffer:
     name: gsm8k_buffer
     storage_type: queue
     path: 'sqlite:///gsm8k.db'
-  # sft_warmup_steps: 0
-  # sft_warmup_dataset: # Uncomment these to enable sft warmup
-  #   name: warmup_data
-  #   storage_type: file
-  #   path: '/PATH/TO/WARMUP_DATA/'
 explorer:
   eval_interval: 50
   runner_num: 32
diff --git a/tests/service/data_juicer_test.py b/tests/service/data_juicer_test.py
index 9b0487f2ce..096cc62a36 100644
--- a/tests/service/data_juicer_test.py
+++ b/tests/service/data_juicer_test.py
@@ -1,3 +1,5 @@
+import os
+import shutil
 import time
 import unittest
 from multiprocessing import Process
@@ -12,16 +14,23 @@ import torch
 from jsonargparse import Namespace

-from tests.tools import RayUnittestBaseAysnc, get_template_config
+from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config
 from trinity.buffer.buffer import get_buffer_reader
-from trinity.buffer.pipelines import ExperiencePipeline
-from trinity.common.config import DataJuicerServiceConfig, OperatorConfig
+from trinity.buffer.pipelines import ExperiencePipeline, check_and_run_task_pipeline
+from trinity.common.config import (
+    DataJuicerServiceConfig,
+    OperatorConfig,
+    StorageConfig,
+    TaskPipelineConfig,
+)
 from trinity.common.experience import Experience
 from trinity.service.data_juicer.client import DataJuicerClient
 from trinity.service.data_juicer.server.server import main
 from trinity.service.data_juicer.server.utils import DJConfig, parse_config
 from trinity.utils.distributed import get_available_port

+TASKSET_OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "taskset_output")
+

 class TestDataJuicer(unittest.TestCase):
     def test_config(self):
@@ -129,7 +138,7 @@ def start_server(port):
         self.assertIsNone(client.server)


-class TestDataJuicerOperators(RayUnittestBaseAysnc):
+class TestDataJuicerExperiencePipeline(RayUnittestBaseAysnc):
     async def test_data_juicer_operators(self):
         config = get_template_config()
         config.service.data_juicer = DataJuicerServiceConfig(
@@ -198,3 +207,68 @@ async def test_data_juicer_operators(self):
         with self.assertRaises(TimeoutError):
             reader.read(batch_size=1)
         await pipeline.close.remote()
+
+
+class TestDataJuicerTaskPipeline(RayUnittestBase):
+    def setUp(self):
+        if os.path.exists(TASKSET_OUTPUT_DIR):
+            shutil.rmtree(TASKSET_OUTPUT_DIR)
+
+    def test_data_juicer_task_pipeline(self):
+        config = get_template_config()
+        config.service.data_juicer = DataJuicerServiceConfig(
+            auto_start=True,
+        )
+        config.data_processor.task_pipeline = TaskPipelineConfig(
+            operators=[
+                OperatorConfig(
+                    name="text_length_filter",
+                    args={
+                        "min_len": 10,
+                        "max_len": 500,
+                        "text_key": "question",
+                    },
+                ),
+                OperatorConfig(
+                    name="word_repetition_filter",
+                    args={
+                        "rep_len": 3,
+                        "min_ratio": 0.0,
+                        "max_ratio": 0.2,
+                        "text_key": "question",
+                    },
+                ),
+            ],
+            inputs=[
+                os.path.join(
+                    os.path.dirname(os.path.dirname(__file__)),
+                    "template",
+                    "data",
+                    "gsm8k",
+                    "train.jsonl",
+                ),
+                os.path.join(
+                    os.path.dirname(os.path.dirname(__file__)),
+                    "template",
+                    "data",
+                    "countdown",
+                    "train.jsonl",
+                ),
+            ],
+            target_fields=["question", "answer"],
+        )
+        config.buffer.explorer_input.taskset = StorageConfig(
+            name="taskset",
+            path=TASKSET_OUTPUT_DIR,
+        )
+        config.check_and_update()
+        metrics = check_and_run_task_pipeline(config)
+        self.assertTrue("sample_num" in metrics)
+        self.assertEqual(metrics["sample_num"], 16)
+        from datasets import load_dataset
+
+        ds = load_dataset(
+            TASKSET_OUTPUT_DIR,
+            split="train",
+        )
+        self.assertEqual(ds.num_rows, 16)
diff --git a/trinity/buffer/operators/data_juicer_operator.py b/trinity/buffer/operators/data_juicer_operator.py
index 2e10023082..d651885fea 100644
--- a/trinity/buffer/operators/data_juicer_operator.py
+++ b/trinity/buffer/operators/data_juicer_operator.py
@@ -29,11 +29,18 @@ def __init__(
         Note:
             - Must include one of the following, and the priority is from high to low:
-                - `operators` (`List[Dict]`)
                 - `config_path` (`str`)
+                - `operators` (`List[Dict]`)
         """
         self.client = DataJuicerClient(config=service_config)
-        self.client.initialize({"operators": operators, "config_path": config_path, "np": np})
+        self.client.initialize(
+            {
+                "operators": operators,
+                "config_path": config_path,
+                "np": np,
+                "pipeline_type": "experience",
+            }
+        )

     def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
         return self.client.process_experience(exps)
diff --git a/trinity/buffer/pipelines/__init__.py b/trinity/buffer/pipelines/__init__.py
index a954f83add..c24f9507ba 100644
--- a/trinity/buffer/pipelines/__init__.py
+++ b/trinity/buffer/pipelines/__init__.py
@@ -1,5 +1,11 @@
 from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
+from trinity.buffer.pipelines.task_pipeline import (
+    TaskPipeline,
+    check_and_run_task_pipeline,
+)

 __all__ = [
     "ExperiencePipeline",
+    "TaskPipeline",
+    "check_and_run_task_pipeline",
 ]
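The new exports above also make the task pipeline easy to drive outside the launcher. A minimal sketch (not part of this patch; the config path is a placeholder) that mirrors what `trinity/cli/launcher.py` does further below:

```python
from trinity.buffer.pipelines import check_and_run_task_pipeline
from trinity.common.config import load_config

# Load and validate a user config, then run the task pipeline once before exploring.
config = load_config("/path/to/your/config.yaml")  # placeholder path
config.check_and_update()

# Returns {} when no task pipeline is configured, otherwise metrics such as "sample_num".
metrics = check_and_run_task_pipeline(config)
print(metrics.get("sample_num"))
```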
+ """ + + def __init__(self, config: Config): + self.logger = get_logger(__name__) + from trinity.service.data_juicer.client import DataJuicerClient + + self.client = DataJuicerClient(config.service.data_juicer) # type: ignore [arg-type] + self.pipeline_config = config.data_processor.task_pipeline + + def convert_pipeline_config(self, pipeline_config: TaskPipelineConfig) -> Dict[str, Any]: + """ + Convert the TaskPipelineConfig to a format suitable for DataJuicer. + """ + + def _convert_operator(operator: OperatorConfig) -> Dict: + return {operator.name: {key: value for key, value in operator.args.items()}} + + if pipeline_config.output.path is None: + raise ValueError("When using task pipeline, taskset.path must be set.") + + converted_config = { + "pipeline_type": "task", + "operators": [_convert_operator(op) for op in pipeline_config.operators], + "np": pipeline_config.num_process, + "config_path": pipeline_config.config_path, + "inputs": [path for path in pipeline_config.inputs], + "target_fields": pipeline_config.target_fields, + "output_dir": pipeline_config.output.path, + } + return converted_config + + def process(self) -> Dict[str, Any]: + """ + Process the task datasets using DataJuicer. + + Returns: + Dict[str, Any]: Metrics for logging. + """ + # Convert the pipeline configuration + converted_config = self.convert_pipeline_config(self.pipeline_config) # type: ignore [arg-type] + self.client.initialize(converted_config) + self.logger.info("Starting task processing...") + metrics = self.client.process_task() + self.logger.info("Task processing completed.") + return metrics + + def close(self): + self.client.close() diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 1c0f7a51e5..81fbc462b4 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -8,8 +8,8 @@ import ray +from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline from trinity.common.config import Config, load_config -from trinity.data.utils import check_and_activate_data_processor, stop_data_processor from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger @@ -122,9 +122,8 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): config = load_config(config_path) config.check_and_update() pprint(config) - # try to activate task pipeline for raw data - data_processor_config = config.data_processor - check_and_activate_data_processor(data_processor_config, config_path) + # try to run task pipeline for raw data + check_and_run_task_pipeline(config) if dlc: from trinity.utils.dlc_utils import setup_ray_cluster @@ -156,10 +155,6 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): stop_ray_cluster(namespace=config.ray_namespace) - # stop all pipelines - if data_processor_config.data_processor_url is not None: - stop_data_processor(data_processor_config.data_processor_url) - def studio(port: int = 8501): from streamlit.web import cli as stcli diff --git a/trinity/common/config.py b/trinity/common/config.py index 8e37e03633..6a05c72eb8 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -141,7 +141,10 @@ class OperatorConfig: @Experimental @dataclass class ExperiencePipelineConfig: - """Config for experience pipeline.""" + """Config for experience pipeline. + + Experience Pipeline is used to pre-process rollout experiences for better training. 
+ """ # The list of experience operators to apply, operators will be applied in the order they are defined operators: List[OperatorConfig] = field(default_factory=list) @@ -159,6 +162,32 @@ class ExperiencePipelineConfig: output: Optional[StorageConfig] = None +@Experimental +@dataclass +class TaskPipelineConfig: + """Config for task pipeline. + + Task Pipeline is used to pre-process raw tasks for better exploring. Currently, we only support using + Data-Juicer operators for task pipeline. + """ + + # The list of data-juicer operators to apply, operators will be applied in the order they are defined + operators: List[OperatorConfig] = field(default_factory=list) + # number of process + num_process: int = 4 + # The path to the Data-Juicer config file. If set, operators and num_process will be ignored + config_path: Optional[str] = None + + # Raw input tasksets. Currently, task pipeline only support local file as inputs, + # e.g., /path/to/file.jsonl or /path/to/file.parquet, not a directory or huggingface path + inputs: List[str] = field(default_factory=list) + # Output task buffer, if not set, use `buffer.explorer_input.taskset`. In most cases, users do not need to set this field. + output: Optional[StorageConfig] = None + + # The list of fields extracted from the input tasksets and processed into the output taskset + target_fields: List[str] = field(default_factory=list) + + @dataclass class DataPipelineConfig: """Config for data pipeline.""" @@ -193,7 +222,7 @@ class DataProcessorConfig: # support two types of data pipelines for now # 1. For task. Data preprocessing from raw dataset to the task set - task_pipeline: Optional[DataPipelineConfig] = None + task_pipeline: Optional[TaskPipelineConfig] = None # 2. For experience. Data processing for rollouts experience_pipeline: Optional[ExperiencePipelineConfig] = field( default_factory=ExperiencePipelineConfig @@ -656,7 +685,7 @@ def _check_buffer(self) -> None: # noqa: C901 if self.buffer.trainer_input.sft_warmup_dataset.ray_namespace is None: self.buffer.trainer_input.sft_warmup_dataset.ray_namespace = self.ray_namespace - # check input/output buffers in experience pipelines + # check input/output buffers in pipelines if self.data_processor.experience_pipeline is not None: if ( self.data_processor.experience_pipeline.save_input @@ -668,6 +697,16 @@ def _check_buffer(self) -> None: # noqa: C901 logger.info( f"Auto set `data_processor.experience_pipeline.input_save_path` to {self.data_processor.experience_pipeline.input_save_path}" ) + if self.data_processor.task_pipeline is not None: + if self.data_processor.task_pipeline.output is None: + self.data_processor.task_pipeline.output = self.buffer.explorer_input.taskset + if self.data_processor.task_pipeline.output.path and os.path.exists( + self.data_processor.task_pipeline.output.path + ): + raise ValueError( + f"Task pipeline output path {self.data_processor.task_pipeline.output.path} already exists.\n" + "Please choose a different output path to avoid overwriting." 
+                )

         # check train_batch_size
         if not self.buffer.train_batch_size:
diff --git a/trinity/data/utils.py b/trinity/data/utils.py
deleted file mode 100644
index 38ed10d6cd..0000000000
--- a/trinity/data/utils.py
+++ /dev/null
@@ -1,87 +0,0 @@
-from trinity.common.config import DataPipelineConfig, DataProcessorConfig
-from trinity.common.constants import DataProcessorPipelineType
-from trinity.utils.log import get_logger
-
-logger = get_logger(__name__)
-
-
-def check_and_activate_data_processor(data_processor_config: DataProcessorConfig, config_path: str):
-    if (
-        data_processor_config.data_processor_url is not None
-        and data_processor_config.task_pipeline is not None
-        and validate_data_pipeline(
-            data_processor_config.task_pipeline, DataProcessorPipelineType.TASK
-        )
-    ):
-        activate_data_processor(
-            f"{data_processor_config.data_processor_url}/{DataProcessorPipelineType.TASK.value}",
-            config_path,
-        )
-    # TODO: check and activate experience pipeline
-
-
-def activate_data_processor(data_processor_url: str, config_path: str):
-    """Check whether to activate data module and preprocess datasets."""
-    from trinity.cli.client import request
-
-    logger.info(f"Activating data module of {data_processor_url}...")
-    res = request(
-        url=data_processor_url,
-        configPath=config_path,
-    )
-    if res["return_code"] != 0:
-        logger.error(f"Failed to activate data module: {res['return_msg']}.")
-        return
-
-
-def stop_data_processor(base_data_processor_url: str):
-    """Stop all pipelines in the data processor"""
-    from trinity.cli.client import request
-
-    logger.info(f"Stopping all pipelines in {base_data_processor_url}...")
-    res = request(url=f"{base_data_processor_url}/stop_all")
-    if res["return_code"] != 0:
-        logger.error(f"Failed to stop all data pipelines: {res['return_msg']}.")
-        return
-
-
-def validate_data_pipeline(
-    data_pipeline_config: DataPipelineConfig, pipeline_type: DataProcessorPipelineType
-):
-    """
-    Check if the data pipeline is valid. The config should:
-    1. Non-empty input buffer
-    2. Different input/output buffers
-
-    :param data_pipeline_config: the input data pipeline to be validated.
-    :param pipeline_type: the type of pipeline, should be one of DataProcessorPipelineType
-    """
-    input_buffers = data_pipeline_config.input_buffers
-    output_buffer = data_pipeline_config.output_buffer
-    # common checks
-    # check if the input buffer list is empty
-    if len(input_buffers) == 0:
-        logger.warning("Empty input buffers in the data pipeline. Won't activate it.")
-        return False
-    # check if the input and output buffers are different
-    input_buffer_names = [buffer.name for buffer in input_buffers]
-    if output_buffer.name in input_buffer_names:
-        logger.warning("Output buffer exists in input buffers. Won't activate it.")
-        return False
-    if pipeline_type == DataProcessorPipelineType.TASK:
-        # task pipeline specific
-        # "raw" field should be True for task pipeline because the data source must be raw data files
-        for buffer in input_buffers:
-            if not buffer.raw:
-                logger.warning(
-                    'Input buffers should be raw data files for task pipeline ("raw" field should be True). Won\'t activate it.'
-                )
-                return False
-    elif pipeline_type == DataProcessorPipelineType.EXPERIENCE:
-        # experience pipeline specific
-        # No special items need to be checked.
-        pass
-    else:
-        logger.warning(f"Invalid pipeline type: {pipeline_type}..")
-        return False
-    return True
diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py
index 21346c375e..20b3c87008 100644
--- a/trinity/explorer/explorer.py
+++ b/trinity/explorer/explorer.py
@@ -218,7 +218,7 @@ async def explore_step(self) -> bool:
                 if self.last_sync_successful
                 else RunningStatus.REQUIRE_SYNC,
             )
-            await self.experience_pipeline.close.remote()
+            await self.shutdown()
             return False
         self.scheduler.schedule(tasks, batch_id=self.explore_step_num + 1)
         self.explore_step_num += 1
@@ -371,8 +371,18 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
             self.monitor.log(metric, step)

     async def shutdown(self) -> None:
-        await self.scheduler.stop()
-        self.monitor.close()
+        if self.scheduler:
+            await self.scheduler.stop()
+            self.scheduler = None
+        if self.experience_pipeline:
+            await self.experience_pipeline.close.remote()
+            self.experience_pipeline = None
+        if self.monitor:
+            self.monitor.close()
+            self.monitor = None
+        self.logger.info(
+            f"Explorer ({self.config.explorer.name}) shutdown successfully at step {self.explore_step_num}."
+        )

     def is_alive(self) -> bool:
         """Check if the explorer is alive."""
diff --git a/trinity/service/data_juicer/client.py b/trinity/service/data_juicer/client.py
index 8d072f4da6..9109e095a2 100644
--- a/trinity/service/data_juicer/client.py
+++ b/trinity/service/data_juicer/client.py
@@ -10,13 +10,15 @@
 from trinity.common.config import DataJuicerServiceConfig
 from trinity.common.experience import Experience, from_hf_datasets, to_hf_datasets
-from trinity.utils.distributed import get_available_port
+from trinity.utils.distributed import get_available_port, is_port_available
+from trinity.utils.log import get_logger


 class DataJuicerClient:
     """Client for interacting with the DataJuicer server."""

     def __init__(self, config: DataJuicerServiceConfig):
+        self.logger = get_logger(__name__)
         self.config = config
         self.url = config.server_url
         self.session_id = None
@@ -35,6 +37,11 @@ def _start_server(self):
         if not self.config.port:
             self.config.port = get_available_port()
+        elif not is_port_available(self.config.port):
+            self.config.port = get_available_port()
+        self.logger.info(
+            f"Starting DataJuicer server at {self.config.server_url} on port {self.config.port}"
+        )
         self.url = f"http://localhost:{self.config.port}"
         server_process = Process(
             target=main, kwargs={"host": "localhost", "port": self.config.port, "debug": False}
@@ -47,6 +54,7 @@ def _start_server(self):
                 break
             except ConnectionError:
                 time.sleep(5)
+        self.logger.info(f"DataJuicer server at {self.url} started successfully.")
         return server_process

     def _check_connection(self) -> bool:
@@ -85,6 +93,18 @@ def process_experience(self, exps: List[Experience]) -> Tuple[List[Experience],
         exps = from_hf_datasets(deserialize_arrow_to_dataset(response.content))
         return exps, metrics

+    def process_task(self) -> Dict:
+        """Process the task datasets using the Data-Juicer service."""
+        if not self.session_id:
+            raise ValueError("DataJuicer session is not initialized.")
+        json_data = {"session_id": self.session_id}
+        response = requests.post(f"{self.url}/process_task", json=json_data)
+        if response.status_code != 200:
+            raise RuntimeError(
+                f"Failed to process task: {response.status_code}, {response.json().get('error')}"
+            )
+        return response.json().get("metrics")
+
     def close(self):
         """Close the DataJuicer client connection."""
         if self.session_id:
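To make the new endpoint easier to follow, here is a minimal sketch (not part of this patch) of the client-side call sequence that `TaskPipeline.process` drives; operator arguments, paths, and the output directory are placeholders:

```python
from trinity.common.config import DataJuicerServiceConfig
from trinity.service.data_juicer.client import DataJuicerClient

# Auto-start a local Data-Juicer server and open a session configured for a task pipeline.
client = DataJuicerClient(DataJuicerServiceConfig(auto_start=True))
client.initialize(
    {
        "pipeline_type": "task",
        "operators": [
            {"text_length_filter": {"min_len": 10, "max_len": 500, "text_key": "question"}},
        ],
        "np": 4,
        "config_path": None,
        "inputs": ["/path/to/train.jsonl"],        # placeholder input file
        "target_fields": ["question", "answer"],
        "output_dir": "/path/to/taskset_output",   # placeholder output directory
    }
)

# POSTs to /process_task on the server and returns metrics such as {"sample_num": ...}.
metrics = client.process_task()
client.close()
```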
diff --git a/trinity/service/data_juicer/server/server.py b/trinity/service/data_juicer/server/server.py
index 9b8b6bd291..9244777d62 100644
--- a/trinity/service/data_juicer/server/server.py
+++ b/trinity/service/data_juicer/server/server.py
@@ -27,9 +27,7 @@ def create():
     """Create a new data juicer session.

     Args:
-        config (dict): Configuration parameters for the session. Must include one of the following, and the priority is from high to low:
-
-
+        config (dict): Configuration parameters for the session. For example, the config should look like this:

     .. code-block:: python

@@ -85,11 +83,9 @@ def process_experience():
     session = sessions[session_id]
     try:
         # from_hf_datasets and to_hf_datasets should be imported from trinity.common.experience
-        processed_ds, metrics = session.process(ds)
-        print(f"Processed {len(ds)} experiences, got {len(processed_ds)} after processing.")
-    except Exception as e:
-        print(f"Error processing experiences: {traceback.format_exc()}")
-        return jsonify({"error": f"Processing failed: {e}"}), 500
+        processed_ds, metrics = session.process_experience(ds)
+    except Exception:
+        return jsonify({"error": f"Experience processing failed:\n{traceback.format_exc()}"}), 500

     # Serialize processed experiences to parquet in-memory
     return_bytes = serialize_dataset_to_arrow(processed_ds)
@@ -102,6 +98,25 @@ def process_experience():
     return response


+@app.route("/process_task", methods=["POST"])
+def process_task():
+    """Process the task datasets for a given session.
+
+    Unlike `process_experience`, which processes a batch of experiences,
+    this endpoint processes a whole taskset.
+    """
+    data = request.json
+    session_id = data.get("session_id")
+    if not session_id or session_id not in sessions:
+        return jsonify({"error": "Session ID not found."}), 404
+
+    session = sessions[session_id]
+    try:
+        metrics = session.process_task()
+        return jsonify({"metrics": metrics})
+    except Exception:
+        return jsonify({"error": f"Task processing failed:\n{traceback.format_exc()}"}), 500
+
+
 @app.route("/close", methods=["POST"])
 def close():
     """Close a data juicer session."""
diff --git a/trinity/service/data_juicer/server/session.py b/trinity/service/data_juicer/server/session.py
index 3a993a299e..72682f6315 100644
--- a/trinity/service/data_juicer/server/session.py
+++ b/trinity/service/data_juicer/server/session.py
@@ -1,3 +1,4 @@
+import os
 from typing import Dict, Tuple

 from datasets import Dataset
@@ -24,15 +25,38 @@ def __init__(self, config: DJConfig):
         Args:
             config (DataJuicerConfigModel): Configuration parameters provided by Trinity.
         """
-        self.config: Namespace = parse_config(config)
+        self.config = config
+        self.dj_config: Namespace = parse_config(config)

-    def process(self, ds: Dataset) -> Tuple[Dataset, Dict]:
-        # TODO: Implement the processing logic using data juicer executor
+    def process_experience(self, ds: Dataset) -> Tuple[Dataset, Dict]:
+        """Process a batch of experiences.
+
+        Args:
+            ds (Dataset): The input dataset containing a batch of experiences.
+
+        Returns:
+            Tuple[Dataset, Dict]: The processed dataset and extracted metrics.
+ """ from data_juicer.core.data import NestedDataset from data_juicer.core.executor.default_executor import DefaultExecutor - dj_executor = DefaultExecutor(cfg=self.config) + dj_executor = DefaultExecutor(cfg=self.dj_config) ds = dj_executor.run(NestedDataset(ds)) metrics = extract_metrics(ds) return ds, metrics + + def process_task(self) -> Dict: + """ + Process task datasets using Data-Juicer + """ + from data_juicer.core.executor.default_executor import DefaultExecutor + + dj_executor = DefaultExecutor(cfg=self.dj_config) + + ds: Dataset = dj_executor.run() + # sort the output dataset in priority + if "priority" in ds.features: + ds.sort_by("priority", reverse=True) + ds.to_json(os.path.join(self.config.output_dir, "output.jsonl")) # type: ignore [arg-type] + return {"sample_num": ds.num_rows} diff --git a/trinity/service/data_juicer/server/utils.py b/trinity/service/data_juicer/server/utils.py index f8346893c8..afd9bbffbe 100644 --- a/trinity/service/data_juicer/server/utils.py +++ b/trinity/service/data_juicer/server/utils.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Optional +import os +from typing import Any, Dict, List, Literal, Optional from data_juicer.config import get_init_configs, prepare_side_configs from jsonargparse import Namespace @@ -6,10 +7,19 @@ class DJConfig(BaseModel): + pipeline_type: Literal["task", "experience"] = "experience" + + # For both `task` and `experience` operators: Optional[List[Dict[str, Dict[str, Any]]]] = None config_path: Optional[str] = None np: int = 4 + # For `task` only + executor_type: Literal["ray", "default"] = "default" + inputs: List[str] = [] # List of input files + output_dir: Optional[str] = None + target_fields: List[str] = [] # fields in the output dataset + @model_validator(mode="after") def check_dj_config(self): if not (self.config_path or self.operators): @@ -21,12 +31,59 @@ def check_dj_config(self): def parse_config(config: DJConfig) -> Namespace: """Convert Trinity config to DJ config""" + if config.config_path is not None: + task_config = prepare_side_configs(config.config_path) + task_config = get_init_configs(task_config) + return task_config + + if config.pipeline_type == "experience": + return _parse_experience_pipeline_config(config) + elif config.pipeline_type == "task": + return _parse_task_pipeline_config(config) + else: + raise ValueError(f"Unknown pipeline type: {config.pipeline_type}") + + +def _parse_experience_pipeline_config(config: DJConfig) -> Namespace: + """Parse the experience pipeline configuration.""" + if config.operators is not None: + exp_config = Namespace(process=[op for op in config.operators], np=config.np) + exp_config = get_init_configs(exp_config) + else: + raise ValueError("At least one of operators or config_path should be provided.") + return exp_config + + +def _parse_task_pipeline_config(config: DJConfig) -> Namespace: + """Parse the task pipeline configuration.""" if config.operators is not None: - dj_config = Namespace(process=[op for op in config.operators], np=config.np) - dj_config = get_init_configs(dj_config) - elif config.config_path is not None: - dj_config = prepare_side_configs(config.config_path) - dj_config = get_init_configs(dj_config) + for input in config.inputs: + if not os.path.exists(input): + raise FileNotFoundError(f"{input} does not exist.") + if not os.path.isfile(input): + raise ValueError( + f"{input} is not a file. Currently, the task pipeline only supports processing files." 
+                )
+        if config.output_dir is None:
+            raise ValueError("`output_dir` must be set for task pipeline.")
+        os.makedirs(config.output_dir, exist_ok=True)
+        task_config = Namespace(
+            process=[op for op in config.operators],
+            np=config.np,
+            dataset={
+                "configs": [
+                    {
+                        "type": "local",
+                        "weight": 1.0,
+                        "path": path,
+                    }
+                    for path in config.inputs
+                ]
+            },
+            text_keys=config.target_fields,
+            export_shard_size=128 * 1024 * 1024,  # 128 MB
+        )
+        task_config = get_init_configs(task_config)
     else:
         raise ValueError("At least one of operators or config_path should be provided.")
-    return dj_config
+    return task_config
diff --git a/trinity/utils/distributed.py b/trinity/utils/distributed.py
index cd234eac4a..7b178aba6a 100644
--- a/trinity/utils/distributed.py
+++ b/trinity/utils/distributed.py
@@ -30,6 +30,15 @@ def get_available_port() -> int:
         return s.getsockname()[1]


+def is_port_available(port: int, host="127.0.0.1") -> bool:
+    with socket.socket() as s:
+        try:
+            s.bind((host, port))
+            return True
+        except OSError:
+            return False
+
+
 def init_process_group(
     host: str,
     port: int,
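As a quick illustration (not part of the patch) of how the two helpers above are meant to be combined, mirroring the fallback now used in `DataJuicerClient._start_server`; the port number is a placeholder:

```python
from trinity.utils.distributed import get_available_port, is_port_available

configured_port = 5005  # placeholder value from a user config
# Keep the configured port when it is free, otherwise fall back to a freshly allocated one.
port = configured_port if is_port_available(configured_port) else get_available_port()
print(f"DataJuicer server will listen on port {port}")
```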