Merged
45 commits
ea2973d
initial commit
qgallouedec Aug 7, 2025
0f9aa4c
Merge branch 'main' into native-vlm-support
qgallouedec Aug 7, 2025
f72dd39
proper image token ids
qgallouedec Aug 7, 2025
3b78c35
fix tiny model
qgallouedec Aug 7, 2025
95d7767
consistency
qgallouedec Aug 7, 2025
3c5aab9
fix tiny model
qgallouedec Aug 7, 2025
d4a5f67
fix test
qgallouedec Aug 7, 2025
590f997
this should work
qgallouedec Aug 7, 2025
015fd2b
fix gemma
qgallouedec Aug 7, 2025
c124d73
Merge branch 'main' into native-vlm-support
sergiopaniego Aug 7, 2025
1b76e66
Merge branch 'main' into native-vlm-support
kashif Aug 7, 2025
533ba8c
dtype check and scripts update
sergiopaniego Aug 7, 2025
1b967c4
add vision requirement to test_train_vlm in SFTTrainerTester2
qgallouedec Aug 8, 2025
4f677fa
remove force option from push_to_hub in generate_tiny_models.py
qgallouedec Aug 8, 2025
d4a122b
add test case for tiny-Qwen2VLForConditionalGeneration in SFTTrainerT…
qgallouedec Aug 8, 2025
7207802
generate idefics3
qgallouedec Aug 9, 2025
5c53f48
update test
qgallouedec Aug 9, 2025
9cabca2
a lot better
qgallouedec Aug 9, 2025
aea083d
Merge branch 'main' into native-vlm-support
qgallouedec Aug 9, 2025
fda5c1e
Update trl/trainer/sft_trainer.py
qgallouedec Aug 9, 2025
fb203ff
Update scripts/generate_tiny_models.py
qgallouedec Aug 9, 2025
8d5ff49
clean test
qgallouedec Aug 9, 2025
df05be1
Add llava_instruct_mix.py dataset processing script
qgallouedec Aug 9, 2025
b9a01a2
doc
qgallouedec Aug 9, 2025
230d691
update doc
qgallouedec Aug 9, 2025
0ed5e80
fix llava mix dataset
qgallouedec Aug 9, 2025
329667c
fix doc
qgallouedec Aug 9, 2025
3b99cee
remove training vlm sft
qgallouedec Aug 9, 2025
8aa2381
imageS
qgallouedec Aug 9, 2025
4be3118
Update docs/source/sft_trainer.md
qgallouedec Aug 9, 2025
e06b2fb
Small docs nits
sergiopaniego Aug 11, 2025
fbcc78d
Updated sft_vlm.py example
sergiopaniego Aug 11, 2025
055a86a
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
15a2605
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
30b7a2d
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
dbc4b65
Merge branch 'main' into native-vlm-support
qgallouedec Aug 12, 2025
fa05ed2
style
qgallouedec Aug 12, 2025
3ae6682
ignore failing test
qgallouedec Aug 12, 2025
f9c1fec
Clarify behavior of `skip_prepare_dataset` for VLM models in SFTConfi…
qgallouedec Aug 12, 2025
089b732
Add documentation for DataCollator classes in SFTTrainer
qgallouedec Aug 12, 2025
e9c3b82
new tiny style
qgallouedec Aug 13, 2025
7c070b9
mnior + clean
qgallouedec Aug 13, 2025
e8ef1d3
fix tiny qwen2
qgallouedec Aug 13, 2025
eb22b9c
fix doc and comments
qgallouedec Aug 13, 2025
13ed3ee
final example cleaning
qgallouedec Aug 13, 2025
2 changes: 0 additions & 2 deletions docs/source/_toctree.yml
@@ -55,8 +55,6 @@
title: Detoxifying a Language Model
- local: multi_adapter_rl
title: Multi Adapter RLHF
- local: training_vlm_sft
title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
title: Examples
- sections:
- sections: # Sorted alphabetically
3 changes: 1 addition & 2 deletions docs/source/dataset_formats.md
@@ -1033,7 +1033,7 @@ dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions

## Vision datasets

Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.

A conversational vision dataset differs from a standard conversational dataset in two key ways:

@@ -1061,4 +1061,3 @@ An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](h
width="100%"
height="560px"
></iframe>
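
For illustration only, a single conversational vision example might look roughly like the following sketch: the `images` column holds the actual image objects, while the messages reference them through `{"type": "image"}` placeholders (the file name here is hypothetical):

```python
from PIL import Image

# A minimal sketch of one conversational vision example (hypothetical image file)
example = {
    "images": [Image.open("obama.png")],
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Who is this?"},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "Barack Obama"}],
        },
    ],
}
```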

127 changes: 27 additions & 100 deletions docs/source/sft_trainer.md
@@ -13,7 +13,7 @@ This post-training method was contributed by [Younes Belkada](https://huggingfac
This example demonstrates how to train a language model using the [`SFTTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara), a compact, diverse multi-turn dataset to benchmark reasoning and generalization.

```python
from trl import SFTTrainer, SFTConfig
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

trainer = SFTTrainer(
@@ -91,7 +91,7 @@ This section breaks down how SFT works in practice, covering the key steps: **pr
### Preprocessing and tokenization

During training, each example is expected to contain a **text field** or a **(prompt, completion)** pair, depending on the dataset format. For more details on the expected formats, see [Dataset formats](dataset_formats).
The `SFTTrainer` tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
The [`SFTTrainer`] tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
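
As a rough sketch of that idea (not the trainer's exact internals), a prompt-completion pair can be concatenated and then tokenized with the model's tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

example = {"prompt": "What color is the sky?\n", "completion": "It is blue."}

# Concatenate prompt and completion, then tokenize the full sequence
full_text = example["prompt"] + example["completion"]
input_ids = tokenizer(full_text)["input_ids"]
print(len(input_ids))
```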

### Computing the loss

@@ -241,7 +241,7 @@ Unsloth is an open‑source framework for fine‑tuning and reinforcement learni
This example shows how to transform the [Qwen 3 0.6B Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) model into an instruction-following model using the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara) and a chat template from [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B). The SFT Trainer automatically handles tokenizer updates and special token configuration.

```python
from trl import SFTTrainer, SFTConfig
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

trainer = SFTTrainer(
@@ -280,122 +280,41 @@ Alternatively, use the structured conversation format (recommended):

## Tool Calling with SFT

The SFT trainer fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
The [`SFTTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:

* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
* The list of available tools in the `tools` column, typically provided as JSON schemas

For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
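
For illustration only, a single tool-calling row might look roughly like this sketch; the field names follow the chat-template conventions of `transformers`, and the `get_weather` function is a hypothetical example:

```python
from transformers.utils import get_json_schema

def get_weather(city: str) -> str:
    """Get the current weather for a city.

    Args:
        city: The city to get the weather for.
    """
    return "sunny"

example = {
    "messages": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {
            "role": "assistant",
            "tool_calls": [
                {"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
            ],
        },
        {"role": "tool", "name": "get_weather", "content": "sunny"},
        {"role": "assistant", "content": "It is sunny in Paris."},
    ],
    # The available tools, provided as JSON schemas generated from the function above
    "tools": [get_json_schema(get_weather)],
}
```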

## Extending `SFTTrainer` for Vision Language Models
## Training Vision Language Models

`SFTTrainer` does not yet inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.

### Preparing the Data

The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:

```python
images = ["obama.png"]
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Who is this?"},
{"type": "image"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Barack Obama"}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "What is he famous for?"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "He is the 44th President of the United States."}
]
}
]
```

To illustrate how this data format will be processed using the LLaVA model, you can use the following code:

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))
```

The output will be formatted as follows:

```txt
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
```

<iframe src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train" frameborder="0" width="100%" height="560px"></iframe>

### A custom collator for processing multi-modal data

Unlike the default behavior of [`SFTTrainer`], processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:

```python
def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
images = [example["images"][0] for example in examples]

# Tokenize the texts and process the images
batch = processor(images=images, text=texts, return_tensors="pt", padding=True)

# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels

return batch
```

We can verify that the collator works as expected by running the following code:
[`SFTTrainer`] fully supports training Vision-Language Models (VLMs). To train a VLM, you need to provide a dataset with an additional `images` column containing the images to be processed. For more information on the expected dataset structure, see the [Dataset Formats — Vision datasets](dataset_formats#vision-datasets) section.
An example of such a dataset is the [LLaVA Instruct Mix](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).

```python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
trainer = SFTTrainer(
model="Qwen/Qwen2.5-VL-3B-Instruct",
args=SFTConfig(max_length=None),
train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"),
)
trainer.train()
```

### Training the vision-language model
<Tip>

Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the [`SFTConfig`], specifically `remove_unused_columns` and `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.
For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [`SFTConfig`]. This allows the model to process the full sequence length without truncating image tokens.

```python
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=train_dataset,
processing_class=processor,
)
SFTConfig(max_length=None, ...)
```

A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py).
Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset.
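
If you do want a finite `max_length`, one rough way to check it (a sketch, assuming the processor and dataset below and a hypothetical budget of 2048 tokens) is to count how many examples would be truncated:

```python
from datasets import load_dataset
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
dataset = load_dataset("trl-lib/llava-instruct-mix", split="train")

max_length = 2048  # hypothetical budget to validate
too_long = 0
for example in dataset.select(range(100)):  # small sample for speed
    text = processor.apply_chat_template(example["messages"], tokenize=False)
    inputs = processor(images=example["images"], text=text, return_tensors="pt")
    too_long += int(inputs["input_ids"].shape[1] > max_length)
print(f"{too_long} of 100 sampled examples exceed max_length={max_length}")
```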

* [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
* [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)
</Tip>

## SFTTrainer

@@ -407,3 +326,11 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs
## SFTConfig

[[autodoc]] SFTConfig

## DataCollatorForLanguageModeling

[[autodoc]] trainer.sft_trainer.DataCollatorForLanguageModeling

## DataCollatorForVisionLanguageModeling

[[autodoc]] trainer.sft_trainer.DataCollatorForVisionLanguageModeling