diff --git a/trinity/service/data_juicer/server/session.py b/trinity/service/data_juicer/server/session.py
index 5d70e0e5f1..c641f573d8 100644
--- a/trinity/service/data_juicer/server/session.py
+++ b/trinity/service/data_juicer/server/session.py
@@ -11,6 +11,7 @@
     group_scores,
     parse_config,
 )
+from trinity.utils.log import get_logger
 
 
 def extract_metrics(dataset: Dataset) -> Dict:
@@ -39,6 +40,12 @@ def __init__(self, config: DJConfig):
                 "usage_frequency": -0.5,
                 "quality": 1.0,
             }
+        self.order_method = self.config.order_method
+        self.order_args = self.config.order_args or {
+            "folding_layers": 3,
+        }
+
+        self.logger = get_logger(__name__)
 
     def process_experience(self, ds: Dataset) -> Tuple[Dataset, Dict]:
         """Process a batch of experiences.
@@ -74,11 +81,63 @@ def process_task(self) -> Dict:
             )
             ds = ds.map(compute_priority_scores_func)
         # sort the output dataset in priority
-        if "priority" in ds.features:
-            top_k = self.config.top_k
-            if top_k == -1:
-                top_k = ds.num_rows
-            ds = ds.sort("priority", reverse=True).take(top_k)
+        ds = self.order_task(ds)
         # export to the target directory
         ds.to_json(os.path.join(self.config.output_dir, "output.jsonl"))  # type: ignore [arg-type]
         return {"sample_num": ds.num_rows}
+
+    def order_task(self, dataset: Dataset) -> Dataset:
+        """Order the dataset with the configured method.
+
+        ``top_k`` is applied only by the "sort" and "folding" methods; "keep"
+        and "shuffle" return every row. If the "priority" field is missing,
+        "sort" and "folding" fall back to "keep" (this also updates
+        ``self.order_method``).
+
+        Args:
+            dataset: The dataset to order.
+
+        Returns:
+            The ordered dataset.
+
+        Raises:
+            ValueError: If the order method is unknown or ``folding_layers``
+                is smaller than 1.
+        """
+        # check if priority field exists
+        if "priority" not in dataset.features and self.order_method in {"sort", "folding"}:
+            self.logger.warning(
+                f'"priority" field not found for {self.order_method}. Use "keep" instead.'
+            )
+            self.order_method = "keep"
+
+        # get top-k
+        top_k = self.config.top_k
+        if top_k == -1:
+            top_k = dataset.num_rows
+
+        if self.order_method == "keep":
+            # keep the original order
+            return dataset
+        elif self.order_method == "shuffle":
+            # shuffle the dataset
+            return dataset.shuffle()
+        elif self.order_method == "sort":
+            # sort the dataset according to priority
+            return dataset.sort("priority", reverse=True).take(top_k)
+        elif self.order_method == "folding":
+            # folding the dataset to repeat the curriculum learning
+            # Reference: https://arxiv.org/abs/2506.21545
+            sorted_dataset = dataset.sort("priority", reverse=True).take(top_k)
+            folding_layers = self.order_args.get("folding_layers", 3)
+            if folding_layers < 1:
+                raise ValueError(f"folding_layers must be >= 1, got {folding_layers}")
+            # Index the truncated dataset, not the input: after take(top_k) it
+            # may hold fewer rows than `dataset`, and select() must stay in range.
+            num_sorted = sorted_dataset.num_rows
+            folding_indices = []
+            for j in range(folding_layers):
+                folding_indices.extend(range(j, num_sorted, folding_layers))
+            return sorted_dataset.select(folding_indices)
+        else:
+            raise ValueError(f"Invalid order method: {self.order_method}")
diff --git a/trinity/service/data_juicer/server/utils.py b/trinity/service/data_juicer/server/utils.py
index 7ea7961bf4..0a8d6cc784 100644
--- a/trinity/service/data_juicer/server/utils.py
+++ b/trinity/service/data_juicer/server/utils.py
@@ -24,6 +24,8 @@ class DJConfig(BaseModel):
     target_fields: List[str] = []  # fields in the output dataset
     priority_weights: Dict[str, float] = {}  # weights for priority computing
     top_k: int = -1  # number of samples to select after task pipeline. -1 means all
+    order_method: Literal["keep", "shuffle", "sort", "folding"] = "sort"
+    order_args: Dict = {}
 
     @model_validator(mode="after")
     def check_dj_config(self):