Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions trinity/service/data_juicer/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
group_scores,
parse_config,
)
from trinity.utils.log import get_logger


def extract_metrics(dataset: Dataset) -> Dict:
Expand Down Expand Up @@ -39,6 +40,12 @@ def __init__(self, config: DJConfig):
"usage_frequency": -0.5,
"quality": 1.0,
}
self.order_method = self.config.order_method
self.order_args = self.config.order_args or {
"folding_layers": 3,
}

self.logger = get_logger(__name__)

def process_experience(self, ds: Dataset) -> Tuple[Dataset, Dict]:
"""Process a batch of experiences.
Expand Down Expand Up @@ -74,11 +81,45 @@ def process_task(self) -> Dict:
)
ds = ds.map(compute_priority_scores_func)
# sort the output dataset in priority
if "priority" in ds.features:
top_k = self.config.top_k
if top_k == -1:
top_k = ds.num_rows
ds = ds.sort("priority", reverse=True).take(top_k)
ds = self.order_task(ds)
# export to the target directory
ds.to_json(os.path.join(self.config.output_dir, "output.jsonl")) # type: ignore [arg-type]
return {"sample_num": ds.num_rows}

def order_task(self, dataset: Dataset) -> Dataset:
"""
Order the dataset with specified method.
"""
# check if priority field exists
if "priority" not in dataset.features and self.order_method in {"sort", "folding"}:
self.logger.warning(
f'"priority" field not found for {self.order_method}. Use "keep" instead.'
)
self.order_method = "keep"

# get top-k
top_k = self.config.top_k
if top_k == -1:
top_k = dataset.num_rows

if self.order_method == "keep":
# keep the original order
return dataset
elif self.order_method == "shuffle":
# shuffle the dataset
return dataset.shuffle()
elif self.order_method == "sort":
# sort the dataset acording to priority
return dataset.sort("priority", reverse=True).take(top_k)
elif self.order_method == "folding":
# folding the dataset to repeat the curriculum learning
# Reference: https://arxiv.org/abs/2506.21545
sorted_dataset = dataset.sort("priority", reverse=True).take(top_k)
folding_layers = self.order_args.get("folding_layers", 3)
folding_indices = []
for j in range(folding_layers):
partition = list(range(j, dataset.num_rows, folding_layers))
folding_indices.extend(partition)
return sorted_dataset.select(folding_indices)
else:
raise ValueError(f"Invalid order method: {self.order_method}")
2 changes: 2 additions & 0 deletions trinity/service/data_juicer/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class DJConfig(BaseModel):
target_fields: List[str] = [] # fields in the output dataset
priority_weights: Dict[str, float] = {} # weights for priority computing
top_k: int = -1 # number of samples to select after task pipeline. -1 means all
order_method: Literal["keep", "shuffle", "sort", "folding"] = "sort"
order_args: Dict = {}

@model_validator(mode="after")
def check_dj_config(self):
Expand Down