Adding dpo training #1209

Open

wants to merge 30 commits into main from adding-dpo-training

Changes from all commits (30 commits)
1ff7888
initial commit
Goekdeniz-Guelmez Jan 18, 2025
582f979
fixing reference model loading and freezing
Goekdeniz-Guelmez Jan 18, 2025
1b4e196
update LORA.md
Goekdeniz-Guelmez Jan 18, 2025
06a9f5d
update lora_config.yaml
Goekdeniz-Guelmez Jan 18, 2025
040f7c3
update ACKNOWLEDGMENTS.md
Goekdeniz-Guelmez Jan 18, 2025
51fd621
nits
Goekdeniz-Guelmez Jan 19, 2025
477000e
removing unneeded functions
Goekdeniz-Guelmez Jan 19, 2025
69a8f11
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Jan 22, 2025
b0ece88
nits
Goekdeniz-Guelmez Jan 22, 2025
e1d549b
nits
Goekdeniz-Guelmez Jan 22, 2025
aefe4ba
nits
Goekdeniz-Guelmez Jan 22, 2025
54fcd8e
update DPODataset and added in system field too
Goekdeniz-Guelmez Jan 24, 2025
531c334
nits
Goekdeniz-Guelmez Jan 24, 2025
86b315f
nits and quality of life improvements
Goekdeniz-Guelmez Jan 24, 2025
0ff1289
updates
Goekdeniz-Guelmez Jan 25, 2025
4d0e52f
more metrics
Goekdeniz-Guelmez Jan 26, 2025
557649d
removing tokenizer and updates
Goekdeniz-Guelmez Jan 26, 2025
9e5482e
Merge branch 'main' into adding-dpo-training
Goekdeniz-Guelmez Jan 26, 2025
b3d6fc3
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Jan 29, 2025
b31d9cb
removing is-reference-free argument
Goekdeniz-Guelmez Jan 30, 2025
b379359
small fix
Goekdeniz-Guelmez Jan 31, 2025
5998272
cleaning up some namings
Goekdeniz-Guelmez Jan 31, 2025
a03d434
clean up
Goekdeniz-Guelmez Jan 31, 2025
fbb51f6
small fix
Goekdeniz-Guelmez Feb 1, 2025
9b489a6
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Feb 4, 2025
c2fcb67
fix testing
Goekdeniz-Guelmez Feb 4, 2025
43f2451
nits
Goekdeniz-Guelmez Feb 4, 2025
069431b
adding test_ppl in testing
Goekdeniz-Guelmez Feb 4, 2025
b1c1e13
nice printing the test metrics
Goekdeniz-Guelmez Feb 4, 2025
6710671
Merge branch 'ml-explore:main' into adding-dpo-training
Goekdeniz-Guelmez Feb 6, 2025
2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning` and `Direct Preference Optimization (DPO)`.
34 changes: 34 additions & 0 deletions llms/mlx_lm/LORA.md
@@ -12,12 +12,14 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Gemma
- OLMo
- MiniCPM
- Mamba
- InternLM2

## Contents

- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [DPO Training](#DPO-Training)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@@ -76,6 +78,38 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.

### DPO Training

Direct Preference Optimization (DPO) fine-tunes a model on human preference data, i.e. pairs of preferred and rejected responses to the same prompt. To use it, set the training mode to `dpo`:

```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--training-mode dpo \
--data <path_to_data> \
--beta 0.1
```

The DPO training mode accepts the following additional parameters (the sketch after the list shows how they might enter the loss):

- `--beta`: Temperature that controls the strength of the DPO loss (default: 0.1)
- `--dpo-loss-type`: Loss variant to use: `sigmoid` (default), `hinge`, `ipo`, or `dpop`
- `--delta`: Delta parameter used by the `dpop` loss (default: 50.0)
- `--reference-model-path`: Path to the reference model; if omitted, a frozen copy of the model being trained is used
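
The loss itself lives in `tuner/dpo_trainer.py` (added by this PR). The sketch below is only an illustration of how `--beta`, `--delta`, and `--dpo-loss-type` might enter the per-example loss; it assumes the summed log-probabilities of the chosen and rejected completions have already been computed for both the policy and the frozen reference model, and the function name and the exact form of each variant (in particular `dpop`) are assumptions rather than the PR's code.

```python
# Illustrative sketch of the loss variants selected by --dpo-loss-type.
# Inputs are per-example summed log-probabilities (mx.array of shape [batch]).
import mlx.core as mx


def dpo_loss_sketch(policy_chosen, policy_rejected, ref_chosen, ref_rejected,
                    beta=0.1, delta=50.0, loss_type="sigmoid"):
    # Difference of policy/reference log-ratios ("logits" in the DPO paper).
    logits = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)

    if loss_type == "sigmoid":
        # Standard DPO: -log sigmoid(beta * logits), via logaddexp for stability.
        losses = mx.logaddexp(0.0, -beta * logits)
    elif loss_type == "hinge":
        # Hinge on the scaled margin.
        losses = mx.maximum(0.0, 1.0 - beta * logits)
    elif loss_type == "ipo":
        # IPO: squared distance of the log-ratio difference from 1 / (2 * beta).
        losses = mx.square(logits - 1.0 / (2.0 * beta))
    elif loss_type == "dpop":
        # DPO-positive: also penalize drops in the chosen log-probability
        # relative to the reference, weighted by delta (approximate form).
        penalty = mx.maximum(0.0, ref_chosen - policy_chosen)
        losses = mx.logaddexp(0.0, -beta * (logits - delta * penalty))
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    return mx.mean(losses)
```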

For DPO training, the data should be in JSONL format with the following structure:

```jsonl
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```

If the prompt template accepts a system message, you can extend each record with an additional `"system"` field:

```jsonl
{"system": "You are a helpfull assistant", "prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```
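
For illustration, a minimal reader for this format could look like the following sketch. The field names mirror the JSONL examples above; the function name, the empty-string default for `system`, and the skipping of blank lines are assumptions made for this example, not the PR's `DPODataset` implementation.

```python
# Minimal sketch of loading the DPO preference data described above.
import json
from pathlib import Path


def load_preference_jsonl(path):
    examples = []
    with open(Path(path), "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            record = json.loads(line)
            examples.append({
                "system": record.get("system", ""),  # optional field
                "prompt": record["prompt"],
                "chosen": record["chosen"],
                "rejected": record["rejected"],
            })
    return examples


# Example: train_examples = load_preference_jsonl("data/train.jsonl")
```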

### Evaluate

To compute test set perplexity use:
13 changes: 13 additions & 0 deletions llms/mlx_lm/examples/lora_config.yaml
@@ -7,6 +7,19 @@ train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora

# The training mode: "normal" or "dpo"
training_mode: normal

# Options used when training_mode is "dpo":
# beta: 0.1
# The DPO loss type: "sigmoid", "hinge", "ipo", or "dpop"
# dpo_loss_type: "sigmoid"
# delta: 50.0
# If reference_model_path is not given, the model being trained is also used
# as the (frozen) reference model.
# reference_model_path: "mlx_model"
# train_bias_only: False

# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"

166 changes: 133 additions & 33 deletions llms/mlx_lm/lora.py
@@ -15,6 +15,7 @@
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -43,6 +44,7 @@
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"training_mode": "normal",
"data": "data/",
"seed": 0,
"num_layers": 16,
@@ -62,6 +64,12 @@
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},

# DPO args
"beta": 0.1,
"dpo_loss_type": "sigmoid",
"delta": 50.0,
"reference_model_path": None
}


@@ -94,6 +102,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "dpo"],
help="Training mode: normal or DPO",
)
parser.add_argument(
"--num-layers",
type=int,
@@ -161,6 +175,34 @@ def build_parser():
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")

# DPO args
parser.add_argument(
"--beta",
type=float,
help="Temperature parameter for DPO training.",
default=0.1
)
parser.add_argument(
"--dpo-loss-type",
type=str,
help="DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'.",
choices=["sigmoid", "hinge", "ipo", "dpop"],
default="sigmoid"
)
parser.add_argument(
"--delta",
type=float,
help="Delta parameter for DPOP loss type.",
default=50.0
)
parser.add_argument(
"--reference-model-path",
type=str,
help="Path to reference model weights. If None, uses the same model.",
default=None
)

return parser


@@ -200,52 +242,110 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)

model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)
if args.training_mode == "dpo":
training_args = DPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
loss_type=args.dpo_loss_type,
delta=args.delta,
reference_model_path=args.reference_model_path
)

if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model, _ = load(args.model)

train_dpo(
model=model,
ref_model=reference_model.freeze(),
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)

train(
model=model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()

test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)
if args.training_mode == "dpo":
if args.reference_model_path:
reference_model, _ = load(args.reference_model_path)
else:
reference_model = model

test_loss, _, _, test_metrics = evaluate_dpo(
model=model,
ref_model=reference_model,
dataset=test_set,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
delta=args.delta,
loss_type=args.dpo_loss_type,
)

test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}")
print("DPO Test Metrics:")
for metric_name, metric_value in test_metrics.items():
print(f" {metric_name}: {float(metric_value):.3f}")

else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

test_ppl = math.exp(test_loss)
test_ppl = math.exp(test_loss)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
@@ -263,7 +363,7 @@ def run(args, training_callback: TrainingCallback = None):
load_adapters(model, args.adapter_path)

elif args.train:
print("Training")
print(f"Training in {args.training_mode} mode")
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
else:
raise ValueError("Must provide at least one of --train or --test")