Skip to content

Commit b890a2a

Browse files
KEP-2401: Support mutating dataset preprocessing config in SDK (kubeflow/trainer#2638)
* feat(sdk): Add InstructDataset and dataset_preprocess_config. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(doc): update dataset preprocessing API definition in KEP. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* chore(sdk): Add get_args_in_dataset_preprocess_config func. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(sdk): Add the prefix path for dataset. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* chore(manifests): Load local datasets. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(sdk): Add TorchTune prefix to dataset class. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* chore(initializer): Update HF dataset initializer to support data_dir and data_files. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(sdk): remove data_files and data_dir definition. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(sdk): extract data_files and data_dir from storage_uri. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(sdk): fix errors in UT. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(manifest): Update dataset.data_dir in torchtune CTRs. Signed-off-by: Electronic-Waste <2690692950@qq.com>
* fix(sdk): fix dataset_uri. Signed-off-by: Electronic-Waste <2690692950@qq.com>
---------
Signed-off-by: Electronic-Waste <2690692950@qq.com>
1 parent 8e67ea3 commit b890a2a

File tree

5 files changed

+135
-2
lines changed

5 files changed

+135
-2
lines changed

kubeflow/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
from kubeflow.trainer.types.types import (
2828
BuiltinTrainer,
2929
CustomTrainer,
30+
DataFormat,
3031
DataType,
3132
Framework,
3233
HuggingFaceDatasetInitializer,
3334
HuggingFaceModelInitializer,
3435
Initializer,
36+
TorchTuneInstructDataset,
3537
Loss,
3638
Runtime,
3739
Trainer,

kubeflow/trainer/api/trainer_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,9 @@ def train(
193193

194194
# If users choose to use a builtin trainer for post-training.
195195
elif isinstance(trainer, types.BuiltinTrainer):
196-
trainer_crd = utils.get_trainer_crd_from_builtin_trainer(trainer)
196+
trainer_crd = utils.get_trainer_crd_from_builtin_trainer(
197+
trainer, initializer
198+
)
197199

198200
else:
199201
raise ValueError(

kubeflow/trainer/constants/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,6 @@
122122

123123
# The default entrypoint for mpirun.
124124
MPI_ENTRYPOINT = "mpirun"
125+
126+
# The Instruct Datasets class in torchtune
127+
TORCHTUNE_INSTRUCT_DATASET = "torchtune.datasets.instruct_dataset"

kubeflow/trainer/types/types.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,49 @@ class DataType(Enum):
6161
FP32 = "fp32"
6262

6363

64+
# Data file type for the TorchTune LLM Trainer.
65+
class DataFormat(Enum):
66+
"""Data file type for the TorchTune LLM Trainer."""
67+
68+
JSON = "json"
69+
CSV = "csv"
70+
PARQUET = "parquet"
71+
ARROW = "arrow"
72+
TEXT = "text"
73+
XML = "xml"
74+
75+
76+
# Configuration for the TorchTune Instruct dataset.
77+
@dataclass
78+
class TorchTuneInstructDataset:
79+
"""
80+
Configuration for the custom dataset with user instruction prompts and model responses.
81+
REF: https://pytorch.org/torchtune/main/generated/torchtune.datasets.instruct_dataset.html
82+
83+
Args:
84+
source (`Optional[DataFormat]`): Data file type.
85+
split (`Optional[str]`):
86+
The split of the dataset to use. You can use this argument to load a subset of
87+
a given split, e.g. split="train[:10%]". Default is `train`.
88+
train_on_input (`Optional[bool]`):
89+
Whether the model is trained on the user prompt or not. Default is False.
90+
new_system_prompt (`Optional[str]`):
91+
The new system prompt to use. If specified, prepend a system message.
92+
This can serve as instructions to guide the model response. Default is None.
93+
column_map (`Optional[Dict[str, str]]`):
94+
A mapping to change the expected "input" and "output" column names to the actual
95+
column names in the dataset. Keys should be "input" and "output" and values should
96+
be the actual column names. Default is None, keeping the default "input" and
97+
"output" column names.
98+
"""
99+
100+
source: Optional[DataFormat] = None
101+
split: Optional[str] = None
102+
train_on_input: Optional[bool] = None
103+
new_system_prompt: Optional[str] = None
104+
column_map: Optional[Dict[str, str]] = None
105+
106+
64107
# Configuration for the TorchTune LLM Trainer.
65108
@dataclass
66109
class TorchTuneConfig:
@@ -78,6 +121,8 @@ class TorchTuneConfig:
78121
loss (`Optional[Loss]`): The loss algorithm we use to fine-tune the LLM,
79122
e.g. `torchtune.modules.loss.CEWithChunkedOutputLoss`.
80123
num_nodes (`Optional[int]`): The number of nodes to use for training.
124+
dataset_preprocess_config (`Optional[TorchTuneInstructDataset]`):
125+
Configuration for the dataset preprocessing.
81126
resources_per_node (`Optional[Dict]`): The computing resources to allocate per node.
82127
"""
83128

@@ -86,6 +131,7 @@ class TorchTuneConfig:
86131
epochs: Optional[int] = None
87132
loss: Optional[Loss] = None
88133
num_nodes: Optional[int] = None
134+
dataset_preprocess_config: Optional[TorchTuneInstructDataset] = None
89135
resources_per_node: Optional[Dict] = None
90136

91137

kubeflow/trainer/utils/utils.py

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import inspect
1616
import os
1717
import queue
18+
import re
1819
import textwrap
1920
import threading
2021
from typing import Any, Callable, Dict, List, Optional, Tuple
22+
from urllib.parse import urlparse
2123

2224
import kubeflow.trainer.models as models
2325
from kubeflow.trainer.constants import constants
@@ -327,6 +329,7 @@ def get_entrypoint_using_train_func(
327329

328330
def get_args_using_torchtune_config(
329331
fine_tuning_config: types.TorchTuneConfig,
332+
initializer: Optional[types.Initializer] = None,
330333
) -> Tuple[List[str], List[str]]:
331334
"""
332335
Get the Trainer args from the TorchTuneConfig.
@@ -352,6 +355,32 @@ def get_args_using_torchtune_config(
352355
if fine_tuning_config.loss:
353356
args.append(f"loss={fine_tuning_config.loss}")
354357

358+
# Override the data dir or data files if it is provided.
359+
if isinstance(initializer, types.Initializer) and isinstance(
360+
initializer.dataset, types.HuggingFaceDatasetInitializer
361+
):
362+
storage_uri = (
363+
"hf://" + initializer.dataset.storage_uri
364+
if not initializer.dataset.storage_uri.startswith("hf://")
365+
else initializer.dataset.storage_uri
366+
)
367+
storage_uri_parsed = urlparse(storage_uri)
368+
relative_path = re.sub(r"^/[^/]+", "", storage_uri_parsed.path)
369+
370+
if "." in relative_path:
371+
args.append(
372+
f"dataset.data_files={os.path.join(constants.DATASET_PATH, relative_path)}"
373+
)
374+
else:
375+
args.append(
376+
f"dataset.data_dir={os.path.join(constants.DATASET_PATH, relative_path)}"
377+
)
378+
379+
if fine_tuning_config.dataset_preprocess_config:
380+
args += get_args_in_dataset_preprocess_config(
381+
fine_tuning_config.dataset_preprocess_config
382+
)
383+
355384
return constants.DEFAULT_TORCHTUNE_COMMAND, args
356385

357386

@@ -390,6 +419,7 @@ def get_trainer_crd_from_custom_trainer(
390419

391420
def get_trainer_crd_from_builtin_trainer(
392421
trainer: types.BuiltinTrainer,
422+
initializer: Optional[types.Initializer] = None,
393423
) -> models.TrainerV1alpha1Trainer:
394424
"""
395425
Get the Trainer CRD from the builtin trainer.
@@ -413,7 +443,7 @@ def get_trainer_crd_from_builtin_trainer(
413443
# the torchtune config in the runtime plugin.
414444
# Ref:https://github.com/kubeflow/trainer/tree/master/docs/proposals/2401-llm-trainer-v2
415445
trainer_crd.command, trainer_crd.args = get_args_using_torchtune_config(
416-
trainer.config
446+
trainer.config, initializer
417447
)
418448

419449
return trainer_crd
@@ -507,3 +537,53 @@ def get_log_queue_pool(log_streams: List[Any]) -> List[queue.Queue]:
507537
pool.append(q)
508538
threading.Thread(target=wrap_log_stream, args=(q, log_stream)).start()
509539
return pool
540+
541+
542+
def get_args_in_dataset_preprocess_config(
543+
dataset_preprocess_config: types.TorchTuneInstructDataset,
544+
) -> List[str]:
545+
"""
546+
Get the args from the given dataset preprocess config.
547+
"""
548+
args = []
549+
550+
if not isinstance(dataset_preprocess_config, types.TorchTuneInstructDataset):
551+
raise ValueError(
552+
f"Invalid dataset preprocess config type: {type(dataset_preprocess_config)}."
553+
)
554+
555+
# Override the dataset type field in the torchtune config.
556+
args.append(f"dataset={constants.TORCHTUNE_INSTRUCT_DATASET}")
557+
558+
# Override the dataset source field if it is provided.
559+
if dataset_preprocess_config.source:
560+
if not isinstance(dataset_preprocess_config.source, types.DataFormat):
561+
raise ValueError(
562+
f"Invalid data format: {dataset_preprocess_config.source}."
563+
)
564+
565+
args.append(f"dataset.source={dataset_preprocess_config.source}")
566+
567+
# Override the data dir or data files if it is provided.
568+
569+
# Override the split field if it is provided.
570+
if dataset_preprocess_config.split:
571+
args.append(f"dataset.split={dataset_preprocess_config.split}")
572+
573+
# Override the train_on_input field if it is provided.
574+
if dataset_preprocess_config.train_on_input:
575+
args.append(
576+
f"dataset.train_on_input={dataset_preprocess_config.train_on_input}"
577+
)
578+
579+
# Override the new_system_prompt field if it is provided.
580+
if dataset_preprocess_config.new_system_prompt:
581+
args.append(
582+
f"dataset.new_system_prompt={dataset_preprocess_config.new_system_prompt}"
583+
)
584+
585+
# Override the column_map field if it is provided.
586+
if dataset_preprocess_config.column_map:
587+
args.append(f"dataset.column_map={dataset_preprocess_config.column_map}")
588+
589+
return args

Comments (0)