Adding orpo training #1210

Open · wants to merge 23 commits into main
2 changes: 1 addition & 1 deletion ACKNOWLEDGMENTS.md
@@ -14,4 +14,4 @@ MLX Examples was developed with contributions from the following individuals:
- Markus Enzweiler: Added the `cvae` examples.
- Prince Canuma: Helped add support for `Starcoder2` models.
- Shiyu Li: Added the `Segment Anything Model`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, `Mamba version 1` and support for `full-fine-tuning`.
- Gökdeniz Gülmez: Added support for `MiniCPM`, `Helium`, and `Mamba version 1`, as well as `full-fine-tuning` and `Odds Ratio Preference Optimization (ORPO)` training.
53 changes: 53 additions & 0 deletions llms/mlx_lm/LORA.md
@@ -12,12 +12,14 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- Gemma
- OLMo
- MiniCPM
- Mamba
- InternLM2

## Contents

- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [ORPO Training](#ORPO-Training)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@@ -76,6 +78,57 @@ You can specify the output location with `--adapter-path`.
You can resume fine-tuning with an existing adapter with
`--resume-adapter-file <path_to_adapters.safetensors>`.

### ORPO Training

Odds Ratio Preference Optimization (ORPO) training fine-tunes a model on human preference data: pairs of chosen and rejected responses for the same prompt. Usage:

```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--training-mode orpo \
--data <path_to_data> \
--beta 0.1
```

Parameters:

- `--beta`: Temperature for the logistic (sigmoid) function in the preference loss (default: `0.1`)
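
The `--beta` value controls the preference part of the ORPO objective, which combines an ordinary negative log-likelihood loss on the chosen response with a log-odds-ratio term passed through a sigmoid. The sketch below is for intuition only; the actual implementation lives in `llms/mlx_lm/tuner/orpo_trainer.py` (not shown in this diff) and may place `beta` differently.

```python
# Rough sketch of the ORPO objective; not the code in orpo_trainer.py.
import mlx.core as mx


def orpo_loss_sketch(chosen_logps, rejected_logps, beta=0.1):
    # chosen_logps / rejected_logps: average per-token log-probabilities of the
    # chosen and rejected completions under the model being trained.
    # Log-odds of a sequence from its length-normalized likelihood: log(p / (1 - p)).
    log_odds_chosen = chosen_logps - mx.log(1 - mx.exp(chosen_logps))
    log_odds_rejected = rejected_logps - mx.log(1 - mx.exp(rejected_logps))
    # Log-odds ratio pushed through a logistic function:
    # log_sigmoid(z) = -log(1 + exp(-z)) = -logaddexp(0, -z).
    z = log_odds_chosen - log_odds_rejected
    log_sigmoid_z = -mx.logaddexp(mx.zeros_like(z), -z)
    # Negative log-likelihood of the chosen response plus the beta-scaled
    # preference term.
    return -chosen_logps - beta * log_sigmoid_z
```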

Data format (JSONL):

```jsonl
# Basic format with string responses
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}

# With custom preference score
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "preference_score": 8.0}

# With system message
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "system": "System instruction"}

# With full conversation objects
{
"prompt": "User prompt",
"chosen": {
"messages": [
{"role": "system", "content": "System instruction"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant response"}
]
},
"rejected": {
"messages": [
{"role": "system", "content": "System instruction"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant response"}
]
}
}
```

If a record supplies no explicit reward via `preference_score`, the trainer assigns binary rewards: 1.0 for the chosen response and 0.0 for the rejected one. Note that the `#` comments and the pretty-printed object above are shown for readability only; in the actual `.jsonl` files each record must be a single line of JSON.

### Evaluate

To compute test set perplexity use:
13 changes: 13 additions & 0 deletions llms/mlx_lm/examples/lora_config.yaml
@@ -7,6 +7,19 @@ train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora

# The training mode: "normal", "dpo", or "orpo"
training_mode: normal

# Additional options when training_mode is "dpo" or "orpo"
# beta: 0.1
# The dpo-loss-type: "sigmoid", "hinge", "ipo", or "dpop"
# dpo_loss_type: "sigmoid"
# is_reference_free: False
# delta: 50.0
# If reference_model_path is not given, the model being trained is used as the reference
# reference_model_path: "mlx_model"
# train_bias_only: False

# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"

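As a reference, a minimal configuration that switches on ORPO training might look like the sketch below; the model and data paths are placeholders, and only options that appear in this diff are used.

```yaml
# Hypothetical example; adjust paths to your setup.
model: "mlx_model"
train: true
fine_tune_type: lora

training_mode: orpo
beta: 0.1

# Directory with {train, valid, test}.jsonl preference files
data: "/path/to/preference/data"
```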
131 changes: 93 additions & 38 deletions llms/mlx_lm/lora.py
@@ -15,6 +15,7 @@
from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.orpo_trainer import ORPOTrainingArgs, evaluate_orpo, train_orpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@@ -43,6 +44,7 @@
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"training_mode": "normal",
"data": "data/",
"seed": 0,
"num_layers": 16,
@@ -62,6 +64,12 @@
"grad_checkpoint": False,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
"beta": 0.1,
"dpo_loss_type": "sigmoid",
"is_reference_free": False,
"delta": 50.0,
"reference_model_path": None,
"reward_scaling": 1.0,
}


@@ -94,6 +102,12 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--training-mode",
type=str,
choices=["normal", "dpo", "orpo"],
help="Training mode: normal, DPO or ORPO.",
)
parser.add_argument(
"--num-layers",
type=int,
@@ -135,7 +149,7 @@ def build_parser():
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
help="Evaluate on the test set after training.",
default=None,
)
parser.add_argument(
@@ -152,15 +166,21 @@
"-c",
"--config",
type=str,
help="A YAML configuration file with the training options",
help="A YAML configuration file with the training options.",
)
parser.add_argument(
"--grad-checkpoint",
action="store_true",
help="Use gradient checkpointing to reduce memory use.",
default=None,
)
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument("--beta", type=float)
parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpo"])
parser.add_argument("--is-reference-free", action="store_true")
parser.add_argument("--delta", type=float)
parser.add_argument("--reference-model-path", type=str)
parser.add_argument("--reward-scaling", type=float, help="Scaling factor for offline rewards.")
parser.add_argument("--seed", type=int, help="The PRNG seed.")
return parser


@@ -200,52 +220,87 @@ def train_model(
adapter_file = adapter_path / "adapters.safetensors"
save_config(vars(args), adapter_path / "adapter_config.json")

# init training args
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
)

model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Train model
train(
model=model,
tokenizer=tokenizer,
args=training_args,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
)

# Train model based on training mode
if args.training_mode == "orpo":
training_args = ORPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta
)

train_orpo(
model=model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint
)

train(
model=model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
model.eval()

test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

test_ppl = math.exp(test_loss)
if args.training_mode == "orpo":
test_loss, test_rewards = evaluate_orpo(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta
)
print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
test_loss = evaluate(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
)

print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: TrainingCallback = None):
Expand All @@ -263,7 +318,7 @@ def run(args, training_callback: TrainingCallback = None):
load_adapters(model, args.adapter_path)

elif args.train:
print("Training")
print(f"Training in {args.training_mode} mode")
train_model(args, model, tokenizer, train_set, valid_set, training_callback)
else:
raise ValueError("Must provide at least one of --train or --test")
66 changes: 64 additions & 2 deletions llms/mlx_lm/tuner/datasets.py
@@ -1,9 +1,67 @@
import json
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from transformers import PreTrainedTokenizer

class ORPODataset:
"""
Preference dataset for ORPO training. Each record holds a prompt plus a
chosen and a rejected response (plain strings or full conversation objects),
and may optionally carry a system message and a preference_score.
"""
def __init__(
self,
data: List[Dict[str, Union[str, Dict]]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
chosen_key: str = "chosen",
rejected_key: str = "rejected",
preference_score_key: str = "preference_score",
system_key: Optional[str] = None
):
self._chosen_data = []
self._rejected_data = []
self._scores = []

for d in data:
if system_key and system_key in d:
base_messages = [{"role": "system", "content": d[system_key]}]
chosen_messages = base_messages + [{"role": "user", "content": d[prompt_key]}]
if isinstance(d[chosen_key], str):
chosen_messages.append({"role": "assistant", "content": d[chosen_key]})
else:
chosen_messages.extend(d[chosen_key]["messages"])
rejected_messages = base_messages + [{"role": "user", "content": d[prompt_key]}]
if isinstance(d[rejected_key], str):
rejected_messages.append({"role": "assistant", "content": d[rejected_key]})
else:
rejected_messages.extend(d[rejected_key]["messages"])
chosen_text = tokenizer.apply_chat_template(chosen_messages)
rejected_text = tokenizer.apply_chat_template(rejected_messages)
else:
chosen_text = tokenizer.apply_chat_template([
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key] if isinstance(d[chosen_key], str) else d[chosen_key]["messages"][-1]["content"]},
])
rejected_text = tokenizer.apply_chat_template([
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key] if isinstance(d[rejected_key], str) else d[rejected_key]["messages"][-1]["content"]},
])

self._chosen_data.append(chosen_text)
self._rejected_data.append(rejected_text)

if preference_score_key in d:
self._scores.append(float(d[preference_score_key]))
else:
self._scores.append(1.0)

def __len__(self):
return len(self._chosen_data)

def __getitem__(self, idx: int):
return {
"chosen": self._chosen_data[idx],
"rejected": self._rejected_data[idx],
"preference_score": self._scores[idx]
}


class Dataset:
"""
@@ -90,7 +148,11 @@ def create_dataset(
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample:

# Preference (ORPO) dataset: records with "chosen" and "rejected" keys
if "chosen" in sample and "rejected" in sample:
return ORPODataset(data, tokenizer)
elif "messages" in sample:
return ChatDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
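
For orientation, here is a hypothetical sketch of constructing an `ORPODataset` by hand from a preference JSONL file; in normal use `create_dataset` above selects it automatically when a record contains `chosen` and `rejected` keys. The model path is a placeholder.

```python
# Hypothetical usage sketch; not part of this diff.
import json

from mlx_lm import load
from mlx_lm.tuner.datasets import ORPODataset

model, tokenizer = load("mlx_model")  # placeholder model path or HF repo

with open("data/train.jsonl") as f:
    records = [json.loads(line) for line in f]

# Pass system_key="system" if the records include a system message.
dataset = ORPODataset(records, tokenizer, system_key="system")
print(len(dataset), dataset[0]["preference_score"])
```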