Commit fb1ef04
This commit updates the converter scripts to work with recent updates to JobConfig and ConfigManager.
It also changes the state dict adapter from a static class to an instance class that consumes the model args in its `__init__`, eliminating guesswork during state dict conversion. In addition, it adds support for building `config.json` when converting to HF, since HF requires this file for important tasks such as inference, and it moves `model_args` out of `train_spec` into a separate file to resolve a circular import with `state_dict_adapter`.
1 parent 9841bdb commit fb1ef04
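For reference, the adapter call pattern after this change, as exercised by the conversion scripts in this diff (a sketch; `train_spec`, `model_args`, and `state_dict` are defined by the surrounding scripts):

```python
# Adapter usage after this commit, taken from the conversion scripts below.
sd_adapter = train_spec.state_dict_adapter(model_args)  # instance; args bound once in __init__

# to_hf now also returns the data needed to write HF's config.json
hf_state_dict, config_json = sd_adapter.to_hf(state_dict)

# from_hf no longer needs model_args passed per call
state_dict = sd_adapter.from_hf(hf_state_dict)
```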

File tree

10 files changed: +194 −143 lines changed
docs/checkpoint.md

Lines changed: 59 additions & 33 deletions
````diff
@@ -1,19 +1,9 @@
-## How to convert a Llama 3 checkpoint for use in torchtitan
+# How to use checkpoints in TorchTitan
 
-If you want to continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager.
-An example script for converting the original Llama3 checkpoints into the expected DCP format can be found in `scripts/convert_llama_to_dcp.py`.
-
-The script expects a path to the original checkpoint files, and a path to an output directory:
-```bash
-python -m scripts.convert_llama_to_dcp <input_dir> <output_dir>
-```
+You may want to enable checkpointing in TorchTitan for better fault tolerance during training, or to enable easier importing and exporting of weights between TorchTitan and other libraries. TorchTitan offers varying degrees of support for other checkpoint formats, which are listed further below.
 
+## A general guide to use checkpoints during training
 
-## How to convert a torchtitan checkpoint for use in torchtune
-
-This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune.
-
-### Steps
 1. ENABLE CHECKPOINTING
 In your torchtitan training config, ensure that `enable_checkpoint` is set to True.
 ```
@@ -22,8 +12,6 @@ enable_checkpoint = true
 folder = "checkpoint"
 interval = 500
 ```
-
-
 2. SAVE MODEL ONLY
 By setting `last_save_model_only` to `True`, the checkpoint will only contain the model and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size.
 ```
@@ -41,7 +29,17 @@ last_save_model_only = true
 export_dtype = "bfloat16"
 ```
 
-4. EXAMPLE CHECKPOINT CONFIGURATION
+4. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING
+In some cases, you may want to partially load from a previously trained checkpoint and modify certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading.
+This parameter takes a list of strings that should be excluded from loading.
+```
+[checkpoint]
+enable_checkpoint = true
+exclude_from_loading = ["data_loader", "lr_scheduler"]
+```
+When used on the command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`.
+
+5. EXAMPLE CHECKPOINT CONFIGURATION
 ```
 [checkpoint]
 enable_checkpoint = true
@@ -52,30 +50,63 @@ last_save_model_only = true
 export_dtype = "bfloat16"
 ```
 
-5. SAVE THE FINAL CHECKPOINT\
-Once the above have been set, the final checkpoint at the end of the training step will consist of model only with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed.
+A more exhaustive and up-to-date list of checkpoint config options can be found in torchtitan/config/job_config.py
 
-6. CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE\
-Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file that can be loaded into torchtune:
+## Conversion support
 
-```
-python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt
+### PyTorch Meta Llama
+
+If you want to continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager.
+An example script for converting the original Llama3 checkpoints into the expected DCP format can be found in `scripts/checkpoint_conversion/convert_from_llama.py`.
+
+The script expects a path to the original checkpoint files, and a path to an output directory:
+```bash
+python -m scripts.checkpoint_conversion.convert_from_llama <input_dir> <output_dir>
 ```
 
-7. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING
-In some cases, you may want to partially load from a previous-trained checkpoint and modify certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading.
-This parameter takes a list of string that should be excluded from loading.
+
+### Torchtune
+
+This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune.
+
+1. CHECKPOINT CONFIGURATION
 ```
 [checkpoint]
 enable_checkpoint = true
-exclude_from_loading = ["data_loader", "lr_scheduler"]
+folder = "checkpoint"
+interval = 10
+last_save_model_only = true
+export_dtype = "bfloat16"
 ```
-When used in command line, the parameter should be a comma-separated list of strings. For example: `--checkpoint.exclude_from_loading data_loader,lr_scheduler`.
+
+2. SAVE THE FINAL CHECKPOINT\
+Once the above have been set, the final checkpoint at the end of the training step will consist of the model only, with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed.
+
+3. CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE\
+Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file that can be loaded into torchtune:
+
+```
+python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt
+```
+
 
 That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune.
 
+### HuggingFace
+TorchTitan now supports two methods of HuggingFace interoperability: directly saving and loading an HF checkpoint during training, or using an example conversion script to reformat the weights directly.
 
-## How to create a seed checkpoint
+1. You can directly save HuggingFace model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a torchtitan training session from a HuggingFace safetensors file, simply enable `--checkpoint.initial_load_model_only` and set `--checkpoint.initial_load_path` to the directory containing the HuggingFace checkpoint.
+
+2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is torchtitan-centric, e.g. convert_from_hf means convert hf->tt.
+
+```
+python ./scripts/checkpoint_conversion/convert_from_hf.py <input_dir> <output_dir> --model_name <model_name> --model_flavor <model_flavor>
+python ./scripts/checkpoint_conversion/convert_to_hf.py <input_dir> <output_dir> --model_name <model_name> --model_flavor <model_flavor>
+# e.g.
+python ./scripts/checkpoint_conversion/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ ./outputs/checkpoint/step-0 --model_name llama3 --model_flavor 8B
+```
+
+### Seed Checkpoint
 Sometimes one needs to create a seed checkpoint to initialize a model from step 0.
 E.g. it is hard, if not impossible, for meta initialization on multiple devices to reproduce the initialization on a single device.
 A seed checkpoint does initialization of the model on a single CPU, and can be loaded from another job on an arbitrary number of GPUs via DCP resharding.
@@ -85,8 +116,3 @@ e.g.
 ```bash
 NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
 ```
-
-
-## How to load / save a checkpoint in HF safetensors format
-For save, users need to set `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` to save the last checkpoint in HF format (intermediate ones are always in DCP format).
-For load, users need to either put the checkpoint in the `step-0` folder if using `--checkpoint.folder`, or specify `--checkpoint.initial_load_path` to load from a different folder. They also need to set `--checkpoint.initial_load_model_only` to load the checkpoint in HF format.
````
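As a concrete example of the HF save/load flow described in the new docs (a hypothetical invocation; the checkpoint path is illustrative, the flags are as documented above):

```bash
# Save the last checkpoint in HF safetensors format (intermediate ones stay DCP)
./run_train.sh --checkpoint.enable_checkpoint \
  --checkpoint.last_save_model_only \
  --checkpoint.last_save_in_safetensors_format

# Resume directly from a directory containing an HF checkpoint (path illustrative)
./run_train.sh --checkpoint.enable_checkpoint \
  --checkpoint.initial_load_model_only \
  --checkpoint.initial_load_path ./hf_checkpoints/Meta-Llama-3-8B/
```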

scripts/convert_from_hf.py renamed to scripts/checkpoint_conversion/convert_from_hf.py

Lines changed: 3 additions & 15 deletions
```diff
@@ -12,8 +12,6 @@
 import torchtitan.protocols.train_spec as train_spec_module
 from torch.distributed.checkpoint import HuggingFaceStorageReader
 from torchtitan.components.checkpoint import ModelWrapper
-from torchtitan.components.tokenizer import build_hf_tokenizer
-from torchtitan.config_manager import ConfigManager
 
 
 @torch.inference_mode()
@@ -22,41 +20,31 @@ def convert_from_hf(input_dir, output_dir, model_name, model_flavor):
     train_spec = train_spec_module.get_train_spec(model_name)
     model_args = train_spec.model_args[model_flavor]
 
-    config_manager = ConfigManager()
-    config = config_manager.parse_args(
-        [
-            "--model.tokenizer-path",
-            "./assets/tokenizer/Llama-3.1-8B",
-        ]
-    )
-    tokenizer = build_hf_tokenizer(config)
-    model_args.update_from_config(config, tokenizer)
     with torch.device("cpu"):
         model = train_spec.model_cls(model_args)
     model = ModelWrapper(model)
 
-    sd_adapter = train_spec.state_dict_adapter
+    sd_adapter = train_spec.state_dict_adapter(model_args)
     assert (
         sd_adapter is not None
     ), "trying to convert checkpoint from HF to DCP safetensors format, but sd_adapter is not provided."
     # get state dict in tt format with allocated memory
     state_dict = model._get_state_dict()
     # convert empty state dict to hf format so that hf weights can be loaded into it
-    hf_state_dict = sd_adapter.to_hf(state_dict, model_args)
+    hf_state_dict, _ = sd_adapter.to_hf(state_dict)
     dcp.load(
         hf_state_dict,
         storage_reader=HuggingFaceStorageReader(path=input_dir),
     )
     # convert state dict format back hf->tt and save
-    state_dict = sd_adapter.from_hf(hf_state_dict, model_args)
+    state_dict = sd_adapter.from_hf(hf_state_dict)
     dcp.save(
         state_dict,
         checkpoint_id=output_dir,
     )
 
 
 if __name__ == "__main__":
-    init_logger()
     parser = argparse.ArgumentParser(description="Convert Llama weights to DCP format.")
     parser.add_argument(
         "input_dir", type=Path, help="Input directory with original Llama weights."
```

scripts/convert_llama_to_dcp.py renamed to scripts/checkpoint_conversion/convert_from_llama.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -14,7 +14,7 @@
 
 
 @torch.inference_mode()
-def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
+def convert_from_llama(input_dir, output_dir, max_seq_len: int):
     with open(input_dir / "params.json", "r") as f:
         params = json.load(f)
     n_layers = params["n_layers"]
@@ -143,4 +143,4 @@ def convert_llama_weights(input_dir, output_dir, max_seq_len: int):
     )
     args = parser.parse_args()
 
-    convert_llama_weights(args.input_dir, args.output_dir, max_seq_len=args.max_seq_len)
+    convert_from_llama(args.input_dir, args.output_dir, max_seq_len=args.max_seq_len)
```
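Assuming the script's argparse interface matches the function signature (an assumption: only `max_seq_len` and the positional paths are visible in these hunks), the renamed entry point would be invoked along these lines:

```bash
# hypothetical invocation; the --max_seq_len flag name is inferred from args.max_seq_len
python -m scripts.checkpoint_conversion.convert_from_llama <input_dir> <output_dir> --max_seq_len 4096
```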

scripts/convert_to_hf.py renamed to scripts/checkpoint_conversion/convert_to_hf.py

Lines changed: 7 additions & 13 deletions
```diff
@@ -5,15 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 import argparse
+import json
 from pathlib import Path
 
 import torch
 import torch.distributed.checkpoint as dcp
 import torchtitan.protocols.train_spec as train_spec_module
 from torch.distributed.checkpoint import HuggingFaceStorageWriter
 from torchtitan.components.checkpoint import ModelWrapper
-from torchtitan.components.tokenizer import build_hf_tokenizer
-from torchtitan.config_manager import ConfigManager
 
 
 @torch.inference_mode()
@@ -22,20 +21,11 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor):
     train_spec = train_spec_module.get_train_spec(model_name)
     model_args = train_spec.model_args[model_flavor]
 
-    config_manager = ConfigManager()
-    config = config_manager.parse_args(
-        [
-            "--model.tokenizer-path",
-            "./assets/tokenizer/Llama-3.1-8B",
-        ]
-    )
-    tokenizer = build_hf_tokenizer(config)
-    model_args.update_from_config(config, tokenizer)
     with torch.device("cpu"):
         model = train_spec.model_cls(model_args)
     model = ModelWrapper(model)
 
-    sd_adapter = train_spec.state_dict_adapter
+    sd_adapter = train_spec.state_dict_adapter(model_args)
     assert (
         sd_adapter is not None
     ), "trying to convert checkpoint from DCP to HF safetensors format, but sd_adapter is not provided."
@@ -48,7 +38,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor):
     )
 
     # convert state dict tt->hf
-    hf_state_dict = sd_adapter.to_hf(state_dict, model_args)
+    hf_state_dict, config_json = sd_adapter.to_hf(state_dict)
 
     fqn_to_index_mapping = {}
     num_fqns_per_file = 30
@@ -70,6 +60,10 @@
         storage_writer=storage_writer,
     )
 
+    config_path = output_dir / "config.json"
+    with config_path.open("w") as f:
+        json.dump(config_json, f, indent=4)
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Convert Llama weights to HF format.")
```
torchtitan/components/checkpoint.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -6,13 +6,15 @@
 
 import enum
 import functools
+import json
 import os
 import queue
 import re
 import shutil
 import threading
 import time
 from concurrent.futures import Future
+from pathlib import Path
 from typing import Any
 
 import torch
@@ -192,7 +194,7 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: dict[str, Any],
         checkpoint_config: CheckpointConfig,
-        sd_adapter: type[StateDictAdapter] | None,
+        sd_adapter: StateDictAdapter | None,
         base_folder: str = "",
         ft_manager: FTManager | None = None,
     ) -> None:
@@ -202,7 +204,6 @@
             assert (
                 sd_adapter is not None
             ), "job_config.checkpoint.last_save_in_hf is True, but sd_adapter is not provided."
-            self.sd_adapter = sd_adapter
 
         self.ft_manager = (
             ft_manager.manager if ft_manager and ft_manager.enabled else None
@@ -358,7 +359,7 @@ def dcp_save(
             assert (
                 self.sd_adapter is not None
             ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
-            state_dict = self.sd_adapter.to_hf(state_dict)
+            state_dict, config_json = self.sd_adapter.to_hf(state_dict)
 
             fqn_to_index_mapping = {}
             num_fqns_per_file = 30
@@ -376,6 +377,9 @@
                 enable_consolidation=True,
                 thread_count_consolidation=5,
             )
+            config_path = Path(checkpoint_id) / "config.json"
+            with config_path.open("w") as f:
+                json.dump(config_json, f, indent=4)
         else:
             checkpoint_save_id = checkpoint_id
 
@@ -425,7 +429,7 @@ def dcp_load(
             assert (
                 self.sd_adapter is not None
             ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
-            hf_state_dict = self.sd_adapter.to_hf(
+            hf_state_dict, _ = self.sd_adapter.to_hf(
                 state_dict, self.states["train_state"].model_args
             )
 
```
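The signature change above implies a corresponding change at the construction site, which is not part of this diff; a sketch of what it presumably looks like:

```python
# Inferred call-site change (hypothetical; the trainer wiring is not in this diff):
#
# before: CheckpointManager(..., sd_adapter=train_spec.state_dict_adapter, ...)
# after:  CheckpointManager(..., sd_adapter=train_spec.state_dict_adapter(model_args), ...)
#
# i.e. the caller now constructs the adapter instance from the model args and
# passes it in, instead of passing the adapter class for CheckpointManager to use.
```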