
Commit 80f6b0e

feat: Support LoraConfig in TorchTune BuiltinTrainer (#102)
* feat: Add lora types.
* chore: propagate lora parameters in command.
* feat(lora): Add support for QLoRA.
* fix(lora): remove extra quote symbol in lora attn module.
* fix(lora): replace direct field override with field map.
* fix(lora): remove extra flags.
* fix(lora): fix wrong default list value in LoraConfig.
* fix(lora): remove outdated code.
* test(backend): Add test for lora.

Signed-off-by: Electronic-Waste <2690692950@qq.com>
1 parent 38390f7 commit 80f6b0e
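For context, a minimal usage sketch of what this commit enables: passing a `LoraConfig` as the new `peft_config` of `TorchTuneConfig` when submitting a TrainJob. It assumes the `TrainerClient().train()` entry point and that `BuiltinTrainer`, `TorchTuneConfig`, and `TrainerClient` are exported at the package top level, as `LoraConfig` now is; the runtime name is a placeholder, and the LoRA values mirror the new unit test rather than recommended settings.

```python
# Sketch only: fine-tune with LoRA through the TorchTune builtin trainer.
# The runtime name below is a placeholder; use whichever TorchTune runtime
# your cluster installs. get_runtime()/train() signatures are assumed from
# the SDK and may differ by version.
from kubeflow.trainer import (
    BuiltinTrainer,
    LoraConfig,
    TorchTuneConfig,
    TrainerClient,
)

client = TrainerClient()

job_name = client.train(
    runtime=client.get_runtime("torchtune-llama3.2-1b"),
    trainer=BuiltinTrainer(
        config=TorchTuneConfig(
            num_nodes=2,
            peft_config=LoraConfig(
                apply_lora_to_mlp=True,
                lora_rank=8,
                lora_alpha=16,
                lora_dropout=0.1,
            ),
        ),
    ),
)
```

The backend translates the `LoraConfig` fields into `model.*` overrides appended to the `tune run` command, as shown in the `utils.py` change below.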

4 files changed, +121 -5 lines

kubeflow/trainer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,7 @@
     HuggingFaceDatasetInitializer,
     HuggingFaceModelInitializer,
     Initializer,
+    LoraConfig,
     Loss,
     Runtime,
     RuntimeTrainer,
@@ -49,6 +50,7 @@
     "HuggingFaceDatasetInitializer",
     "HuggingFaceModelInitializer",
     "Initializer",
+    "LoraConfig",
     "Loss",
     "MODEL_PATH",
     "Runtime",

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 38 additions & 3 deletions
@@ -238,12 +238,14 @@ def get_custom_trainer(
     )


-def get_builtin_trainer() -> models.TrainerV1alpha1Trainer:
+def get_builtin_trainer(
+    args: list[str],
+) -> models.TrainerV1alpha1Trainer:
     """
     Get the builtin trainer for the TrainJob.
     """
     return models.TrainerV1alpha1Trainer(
-        args=["batch_size=2", "epochs=2", "loss=Loss.CEWithChunkedOutputLoss"],
+        args=args,
         command=["tune", "run"],
         numNodes=2,
     )
@@ -707,7 +709,40 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
             expected_output=get_train_job(
                 runtime_name=TORCH_TUNE_RUNTIME,
                 train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER,
-                train_job_trainer=get_builtin_trainer(),
+                train_job_trainer=get_builtin_trainer(
+                    args=["batch_size=2", "epochs=2", "loss=Loss.CEWithChunkedOutputLoss"],
+                ),
+            ),
+        ),
+        TestCase(
+            name="valid flow with built in trainer and lora config",
+            expected_status=SUCCESS,
+            config={
+                "trainer": types.BuiltinTrainer(
+                    config=types.TorchTuneConfig(
+                        num_nodes=2,
+                        peft_config=types.LoraConfig(
+                            apply_lora_to_mlp=True,
+                            lora_rank=8,
+                            lora_alpha=16,
+                            lora_dropout=0.1,
+                        ),
+                    ),
+                ),
+                "runtime": TORCH_TUNE_RUNTIME,
+            },
+            expected_output=get_train_job(
+                runtime_name=TORCH_TUNE_RUNTIME,
+                train_job_name=TRAIN_JOB_WITH_BUILT_IN_TRAINER,
+                train_job_trainer=get_builtin_trainer(
+                    args=[
+                        "model.apply_lora_to_mlp=True",
+                        "model.lora_rank=8",
+                        "model.lora_alpha=16",
+                        "model.lora_dropout=0.1",
+                        "model.lora_attn_modules=[q_proj,v_proj,output_proj]",
+                    ],
+                ),
             ),
         ),
         TestCase(

kubeflow/trainer/types/types.py

Lines changed: 42 additions & 0 deletions
@@ -110,6 +110,44 @@ class TorchTuneInstructDataset:
     column_map: Optional[dict[str, str]] = None


+@dataclass
+class LoraConfig:
+    """Configuration for the LoRA/QLoRA/DoRA.
+    REF: https://meta-pytorch.org/torchtune/main/tutorials/memory_optimizations.html
+
+    Args:
+        apply_lora_to_mlp (`Optional[bool]`):
+            Whether to apply LoRA to the MLP in each transformer layer.
+        apply_lora_to_output (`Optional[bool]`):
+            Whether to apply LoRA to the model's final output projection.
+        lora_attn_modules (`list[str]`):
+            A list of strings specifying which layers of the model to apply LoRA,
+            default is ["q_proj", "v_proj", "output_proj"]:
+            1. "q_proj" applies LoRA to the query projection layer.
+            2. "k_proj" applies LoRA to the key projection layer.
+            3. "v_proj" applies LoRA to the value projection layer.
+            4. "output_proj" applies LoRA to the attention output projection layer.
+        lora_rank (`Optional[int]`): The rank of the low rank decomposition.
+        lora_alpha (`Optional[int]`):
+            The scaling factor that adjusts the magnitude of the low-rank matrices' output.
+        lora_dropout (`Optional[float]`):
+            The probability of applying Dropout to the low rank updates.
+        quantize_base (`Optional[bool]`): Whether to enable model quantization.
+        use_dora (`Optional[bool]`): Whether to enable DoRA.
+    """
+
+    apply_lora_to_mlp: Optional[bool] = None
+    apply_lora_to_output: Optional[bool] = None
+    lora_attn_modules: list[str] = field(
+        default_factory=lambda: ["q_proj", "v_proj", "output_proj"]
+    )
+    lora_rank: Optional[int] = None
+    lora_alpha: Optional[int] = None
+    lora_dropout: Optional[float] = None
+    quantize_base: Optional[bool] = None
+    use_dora: Optional[bool] = None
+
+
 # Configuration for the TorchTune LLM Trainer.
 @dataclass
 class TorchTuneConfig:
@@ -127,6 +165,9 @@ class TorchTuneConfig:
         loss (`Optional[Loss]`): The loss algorithm we use to fine-tune the LLM,
             e.g. `torchtune.modules.loss.CEWithChunkedOutputLoss`.
         num_nodes (`Optional[int]`): The number of nodes to use for training.
+        peft_config (`Optional[LoraConfig]`):
+            Configuration for the PEFT(Parameter-Efficient Fine-Tuning),
+            including LoRA/QLoRA/DoRA, etc.
         dataset_preprocess_config (`Optional[TorchTuneInstructDataset]`):
             Configuration for the dataset preprocessing.
         resources_per_node (`Optional[Dict]`): The computing resources to allocate per node.
@@ -137,6 +178,7 @@ class TorchTuneConfig:
     epochs: Optional[int] = None
     loss: Optional[Loss] = None
     num_nodes: Optional[int] = None
+    peft_config: Optional[LoraConfig] = None
     dataset_preprocess_config: Optional[TorchTuneInstructDataset] = None
     resources_per_node: Optional[dict] = None

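Because the same dataclass carries the QLoRA and DoRA switches (`quantize_base`, `use_dora`), a QLoRA-style setup is just a `LoraConfig` with quantization of the base weights enabled. A small illustrative sketch, with arbitrary values rather than recommended defaults:

```python
from kubeflow.trainer import LoraConfig

# QLoRA-style config: LoRA adapters on top of a quantized base model.
# Values are illustrative, not recommended defaults.
qlora = LoraConfig(
    apply_lora_to_mlp=True,
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
    quantize_base=True,  # quantize the frozen base weights (QLoRA)
    # use_dora=True,     # or enable DoRA instead
)

# lora_attn_modules keeps its default: ["q_proj", "v_proj", "output_proj"]
print(qlora.lora_attn_modules)
```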
kubeflow/trainer/utils/utils.py

Lines changed: 39 additions & 2 deletions
@@ -477,13 +477,50 @@ def get_args_using_torchtune_config(
     else:
         args.append(f"dataset.data_dir={os.path.join(constants.DATASET_PATH, relative_path)}")

+    if fine_tuning_config.peft_config:
+        args += get_args_from_peft_config(fine_tuning_config.peft_config)
+
     if fine_tuning_config.dataset_preprocess_config:
-        args += get_args_in_dataset_preprocess_config(fine_tuning_config.dataset_preprocess_config)
+        args += get_args_from_dataset_preprocess_config(
+            fine_tuning_config.dataset_preprocess_config
+        )
+
+    return args
+
+
+def get_args_from_peft_config(peft_config: types.LoraConfig) -> list[str]:
+    """
+    Get the args from the given PEFT config.
+    """
+    args = []
+
+    if not isinstance(peft_config, types.LoraConfig):
+        raise ValueError(f"Invalid PEFT config type: {type(peft_config)}.")
+
+    field_map = {
+        "apply_lora_to_mlp": "model.apply_lora_to_mlp",
+        "apply_lora_to_output": "model.apply_lora_to_output",
+        "lora_rank": "model.lora_rank",
+        "lora_alpha": "model.lora_alpha",
+        "lora_dropout": "model.lora_dropout",
+        "quantize_base": "model.quantize_base",
+        "use_dora": "model.use_dora",
+    }
+
+    # Override the PEFT fields if they are provided.
+    for field, arg_name in field_map.items():
+        value = getattr(peft_config, field, None)
+        if value:
+            args.append(f"{arg_name}={value}")
+
+    # Override the LoRA attention modules if they are provided.
+    if peft_config.lora_attn_modules:
+        args.append(f"model.lora_attn_modules=[{','.join(peft_config.lora_attn_modules)}]")

     return args


-def get_args_in_dataset_preprocess_config(
+def get_args_from_dataset_preprocess_config(
     dataset_preprocess_config: types.TorchTuneInstructDataset,
 ) -> list[str]:
     """

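To make the flag translation concrete, here is roughly what the new helper produces for the `LoraConfig` used in the test above. Calling `get_args_from_peft_config` directly is an internal detail rather than public API; the import paths follow the file locations in this commit.

```python
from kubeflow.trainer.types import types
from kubeflow.trainer.utils import utils

# Same LoraConfig as in the new backend_test.py test case.
args = utils.get_args_from_peft_config(
    types.LoraConfig(
        apply_lora_to_mlp=True,
        lora_rank=8,
        lora_alpha=16,
        lora_dropout=0.1,
    )
)
print(args)
# Expected, per the test above:
# ['model.apply_lora_to_mlp=True',
#  'model.lora_rank=8',
#  'model.lora_alpha=16',
#  'model.lora_dropout=0.1',
#  'model.lora_attn_modules=[q_proj,v_proj,output_proj]']
```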