33 changes: 13 additions & 20 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -12,7 +12,7 @@ Below is a table summarizing the modules and components that developers with dif
| Enhance the RL process from the data perspective. | *Buffer* | `Operator` |

```{note}
-Trinity-RFT is still under development, and the following interfaces may change. Please read this section in conjunction with the latest code.
+Trinity-RFT is under active development, and the following interfaces may change. Please refer to the latest code when using this guide.
```

---
@@ -216,12 +216,6 @@ __all__ = [
]
```

-For workflows that are not intended to be contributed to Trinity-RFT project, you can just place the above code in `trinity/plugins`. Trinity-RFT will automatically detect and load all custom modules in this folder.
-
-```{tip}
-You can specify the directory where your custom modules are located by setting `--plugin-dir` when starting Trinity-RFT. If you don't specify `--plugin-dir`, Trinity-RFT will use `<Trinity_RFT_ROOT_DIR>/trinity/plugins` as the default directory.
-```

#### Avoid Re-initialization

For heavy workflows, re-initializing every time can incur extra computational costs.
@@ -728,21 +722,20 @@ By meticulously following these steps, you can ensure that new parameters are su

---

-## Check Code Style
+## Contributing Your Code

-Before submitting the code, make sure it passes the code style check. Follow these steps:
+For modules you plan to contribute to the Trinity-RFT project, follow the steps below:

-```shell
-# Install code style checking tools
-cd <path_to_trinity_rft>
-# bash
-pip install -e .[dev]
-# zsh
-# pip install -e .\[dev\]
-
-# Run code style checks
-pre-commit run --all-files
-
-# Commit the code after all checks pass
-git commit -am "create example workflow"
-```
+1. Implement your code in the appropriate directory, such as `trinity/common/workflows` for workflows, `trinity/algorithm` for algorithms, or `trinity/buffer/operators` for operators.
+
+2. Register your module in the corresponding `__init__.py` file of that directory (a minimal sketch follows this list).
+
+3. Add tests for your module in the `tests` directory, following the naming conventions and structure of existing tests.
+
+4. Before submitting the code, make sure it passes the code style check with `pre-commit run --all-files`.
+
+5. Submit a pull request to the Trinity-RFT repository, including a clear description of your changes.
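
For step 2, here is a minimal sketch of what the registration could look like; `MyWorkflow` and `my_workflow.py` are hypothetical names used only for illustration:

```python
# trinity/common/workflows/__init__.py (sketch; MyWorkflow is a hypothetical
# example class defined in trinity/common/workflows/my_workflow.py)
from trinity.common.workflows.my_workflow import MyWorkflow

__all__ = [
    # ... existing exports kept as-is ...
    "MyWorkflow",
]
```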

```{tip}
For modules that are only used for local testing or are not intended for contribution, you can place them in the `trinity/plugins` directory. Trinity-RFT will automatically load all modules in this directory, and you can use them without adding them to the `__init__.py` file. You can specify another directory by setting the `--plugin-dir` option when running Trinity-RFT, e.g., `trinity run --config /path/to/your/config --plugin-dir /path/to/your/plugins`.
```
35 changes: 11 additions & 24 deletions examples/grpo_gsm8k_task_pipeline/gsm8k.yaml
@@ -5,27 +5,19 @@ algorithm:
  algorithm_type: grpo
  repeat_times: 8
data_processor:
-  data_processor_url: 'http://127.0.0.1:5005/data_processor'
  # task pipeline related
  task_pipeline:
-    # I/O buffers
-    input_buffers:
-      - name: 'raw_input'
-        path: 'openai/gsm8k'
-        storage_type: 'file'
-        raw: true
-    output_buffer:
-      name: 'raw_output'
-      path: './outputs/task_pipeline_output/prioritized_gsm8k.jsonl'
-      storage_type: 'file'
-    # format mapping
-    format:
-      prompt_key: 'question'
-      response_key: 'answer'
-    # data active iterator related
-    dj_process_desc: 'Please compute difficulty scores for these math questions.'
-    agent_model_name: 'qwen-max'
-    clean_strategy: 'iterative'
+    operators:
+      - name: "difficulty_score_filter"
+        args:
+          api_or_hf_model: "qwen2.5-7b-instruct"
+          min_score: 0.0
+          text_key: "question"
+    inputs:
+      - /PATH/TO/GSM8K/JSONL
+service:
+  data_juicer:
+    auto_start: true
model:
  model_path: /PATH/TO/MODEL/
  max_response_tokens: 1024
@@ -64,11 +56,6 @@ buffer:
    name: gsm8k_buffer
    storage_type: queue
    path: 'sqlite:///gsm8k.db'
-  # sft_warmup_steps: 0
-  # sft_warmup_dataset: # Uncomment these to enable sft warmup
-  #   name: warmup_data
-  #   storage_type: file
-  #   path: '/PATH/TO/WARMUP_DATA/'
explorer:
  eval_interval: 50
  runner_num: 32
82 changes: 78 additions & 4 deletions tests/service/data_juicer_test.py
@@ -1,3 +1,5 @@
import os
import shutil
import time
import unittest
from multiprocessing import Process
@@ -12,16 +14,23 @@
import torch
from jsonargparse import Namespace

-from tests.tools import RayUnittestBaseAysnc, get_template_config
+from tests.tools import RayUnittestBase, RayUnittestBaseAysnc, get_template_config
from trinity.buffer.buffer import get_buffer_reader
-from trinity.buffer.pipelines import ExperiencePipeline
-from trinity.common.config import DataJuicerServiceConfig, OperatorConfig
+from trinity.buffer.pipelines import ExperiencePipeline, check_and_run_task_pipeline
+from trinity.common.config import (
+    DataJuicerServiceConfig,
+    OperatorConfig,
+    StorageConfig,
+    TaskPipelineConfig,
+)
from trinity.common.experience import Experience
from trinity.service.data_juicer.client import DataJuicerClient
from trinity.service.data_juicer.server.server import main
from trinity.service.data_juicer.server.utils import DJConfig, parse_config
from trinity.utils.distributed import get_available_port

TASKSET_OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "taskset_output")


class TestDataJuicer(unittest.TestCase):
    def test_config(self):
@@ -129,7 +138,7 @@ def start_server(port):
        self.assertIsNone(client.server)


-class TestDataJuicerOperators(RayUnittestBaseAysnc):
+class TestDataJuicerExperiencePipeline(RayUnittestBaseAysnc):
    async def test_data_juicer_operators(self):
        config = get_template_config()
        config.service.data_juicer = DataJuicerServiceConfig(
@@ -198,3 +207,68 @@ async def test_data_juicer_operators(self):
        with self.assertRaises(TimeoutError):
            reader.read(batch_size=1)
        await pipeline.close.remote()


class TestDataJuicerTaskPipeline(RayUnittestBase):
    def setUp(self):
        if os.path.exists(TASKSET_OUTPUT_DIR):
            shutil.rmtree(TASKSET_OUTPUT_DIR)

    def test_data_juicer_task_pipeline(self):
        config = get_template_config()
        config.service.data_juicer = DataJuicerServiceConfig(
            auto_start=True,
        )
        config.data_processor.task_pipeline = TaskPipelineConfig(
            operators=[
                OperatorConfig(
                    name="text_length_filter",
                    args={
                        "min_len": 10,
                        "max_len": 500,
                        "text_key": "question",
                    },
                ),
                OperatorConfig(
                    name="word_repetition_filter",
                    args={
                        "rep_len": 3,
                        "min_ratio": 0.0,
                        "max_ratio": 0.2,
                        "text_key": "question",
                    },
                ),
            ],
            inputs=[
                os.path.join(
                    os.path.dirname(os.path.dirname(__file__)),
                    "template",
                    "data",
                    "gsm8k",
                    "train.jsonl",
                ),
                os.path.join(
                    os.path.dirname(os.path.dirname(__file__)),
                    "template",
                    "data",
                    "countdown",
                    "train.jsonl",
                ),
            ],
            target_fields=["question", "answer"],
        )
        config.buffer.explorer_input.taskset = StorageConfig(
            name="taskset",
            path=TASKSET_OUTPUT_DIR,
        )
        config.check_and_update()
        metrics = check_and_run_task_pipeline(config)
        self.assertTrue("sample_num" in metrics)
        self.assertEqual(metrics["sample_num"], 16)
        from datasets import load_dataset

        ds = load_dataset(
            TASKSET_OUTPUT_DIR,
            split="train",
        )
        self.assertEqual(ds.num_rows, 16)
11 changes: 9 additions & 2 deletions trinity/buffer/operators/data_juicer_operator.py
@@ -29,11 +29,18 @@ def __init__(

        Note:
            - Must include one of the following, and the priority is from high to low:
-                - `operators` (`List[Dict]`)
                - `config_path` (`str`)
+                - `operators` (`List[Dict]`)
        """
        self.client = DataJuicerClient(config=service_config)
-        self.client.initialize({"operators": operators, "config_path": config_path, "np": np})
+        self.client.initialize(
+            {
+                "operators": operators,
+                "config_path": config_path,
+                "np": np,
+                "pipeline_type": "experience",
+            }
+        )

    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        return self.client.process_experience(exps)
6 changes: 6 additions & 0 deletions trinity/buffer/pipelines/__init__.py
@@ -1,5 +1,11 @@
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.buffer.pipelines.task_pipeline import (
    TaskPipeline,
    check_and_run_task_pipeline,
)

__all__ = [
    "ExperiencePipeline",
    "TaskPipeline",
    "check_and_run_task_pipeline",
]
73 changes: 73 additions & 0 deletions trinity/buffer/pipelines/task_pipeline.py
@@ -0,0 +1,73 @@
from typing import Any, Dict

from trinity.common.config import Config, OperatorConfig, TaskPipelineConfig
from trinity.utils.log import get_logger


def check_and_run_task_pipeline(config: Config) -> Dict:
    if config.mode not in ("explore", "both"):
        # task pipeline is only available when using Explorer
        return {}
    if config.data_processor.task_pipeline is None:
        return {}

    task_pipeline = TaskPipeline(config)
    try:
        return task_pipeline.process()
    except Exception as e:
        raise RuntimeError(f"Task pipeline failed: {e}") from e
    finally:
        task_pipeline.close()


class TaskPipeline:
    """
    A class to process task datasets through DataJuicer.
    """

    def __init__(self, config: Config):
        self.logger = get_logger(__name__)
        from trinity.service.data_juicer.client import DataJuicerClient

        self.client = DataJuicerClient(config.service.data_juicer)  # type: ignore [arg-type]
        self.pipeline_config = config.data_processor.task_pipeline

    def convert_pipeline_config(self, pipeline_config: TaskPipelineConfig) -> Dict[str, Any]:
        """
        Convert the TaskPipelineConfig to a format suitable for DataJuicer.
        """

        def _convert_operator(operator: OperatorConfig) -> Dict:
            return {operator.name: {key: value for key, value in operator.args.items()}}

        if pipeline_config.output.path is None:
            raise ValueError("When using task pipeline, taskset.path must be set.")

        converted_config = {
            "pipeline_type": "task",
            "operators": [_convert_operator(op) for op in pipeline_config.operators],
            "np": pipeline_config.num_process,
            "config_path": pipeline_config.config_path,
            "inputs": [path for path in pipeline_config.inputs],
            "target_fields": pipeline_config.target_fields,
            "output_dir": pipeline_config.output.path,
        }
        return converted_config

    def process(self) -> Dict[str, Any]:
        """
        Process the task datasets using DataJuicer.

        Returns:
            Dict[str, Any]: Metrics for logging.
        """
        # Convert the pipeline configuration
        converted_config = self.convert_pipeline_config(self.pipeline_config)  # type: ignore [arg-type]
        self.client.initialize(converted_config)
        self.logger.info("Starting task processing...")
        metrics = self.client.process_task()
        self.logger.info("Task processing completed.")
        return metrics

    def close(self):
        self.client.close()
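
For reference, a minimal usage sketch of this entry point, assuming a config file like the GSM8K example above (the config path is a placeholder):

```python
from trinity.buffer.pipelines import check_and_run_task_pipeline
from trinity.common.config import load_config

config = load_config("/path/to/your/config.yaml")  # placeholder path
config.check_and_update()
# Returns {} unless mode is "explore"/"both" and a task_pipeline is configured.
metrics = check_and_run_task_pipeline(config)
print(metrics.get("sample_num"))
```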
11 changes: 3 additions & 8 deletions trinity/cli/launcher.py
@@ -8,8 +8,8 @@

import ray

+from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline
from trinity.common.config import Config, load_config
-from trinity.data.utils import check_and_activate_data_processor, stop_data_processor
from trinity.explorer.explorer import Explorer
from trinity.trainer.trainer import Trainer
from trinity.utils.log import get_logger
@@ -122,9 +122,8 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
    config = load_config(config_path)
    config.check_and_update()
    pprint(config)
-    # try to activate task pipeline for raw data
-    data_processor_config = config.data_processor
-    check_and_activate_data_processor(data_processor_config, config_path)
+    # try to run task pipeline for raw data
+    check_and_run_task_pipeline(config)
    if dlc:
        from trinity.utils.dlc_utils import setup_ray_cluster

@@ -156,10 +155,6 @@

    stop_ray_cluster(namespace=config.ray_namespace)

-    # stop all pipelines
-    if data_processor_config.data_processor_url is not None:
-        stop_data_processor(data_processor_config.data_processor_url)


def studio(port: int = 8501):
    from streamlit.web import cli as stcli