Commit 63c3fc5

Removes support for Hugging Face config.json due to the unnecessary overhead and complexity it added
Updates docs/checkpoint.md
1 parent c844a20 commit 63c3fc5

9 files changed: 65 additions & 95 deletions

docs/checkpoint.md

Lines changed: 33 additions & 35 deletions
@@ -1,11 +1,11 @@
-# How to use checkpoints in TorchTitan
+# How to use checkpointing in `torchtitan`

-You may want to enable checkpointing in TorchTitan for better fault tolerance during training, or to enable easier importing and exporting of weights between TorchTitan and other libraries. TorchTitan offers varying degrees of support for other checkpoint formats which are listed further below.
+You may want to enable checkpointing in `torchtitan` for better fault tolerance during training, or to enable easier importing and exporting of weights between `torchtitan` and other libraries. `torchtitan` offers varying degrees of support for other checkpoint formats which are listed further below.

 ## A general guide to use checkpoints during training

 1. ENABLE CHECKPOINTING
-In your torchtitan training config, ensure that `enable_checkpoint` is set to True.
+In your `torchtitan` training config, ensure that `enable_checkpoint` is set to True.
 ```
 [checkpoint]
 enable_checkpoint = true
@@ -50,24 +50,38 @@ last_save_model_only = true
 export_dtype = "bfloat16"
 ```

-A more exhaustive and up-to-date list of checkpoint config options can be found in torchtitan/config/job_config.py
+A more exhaustive and up-to-date list of checkpoint config options can be found in `torchtitan/config/job_config.py`
+
+## Creating a seed checkpoint
+Sometimes one needs to create a seed checkpoint to initialize a model from step 0.
+E.g. it is hard, if not impossible, for meta initialization on multiple devices to reproduce the initialization on a single device.
+A seed checkpoint does initialization of the model on a single CPU, and can be loaded from another job on an arbitrary number of GPUs via DCP resharding.
+
+To create a seed checkpoint, use the same model config as you use for training.
+e.g.
+```bash
+NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
+```

 ## Conversion support

-### PyTorch Meta Llama
+### HuggingFace
+`torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu.

-If you want to continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager.
-An example script for converting the original Llama3 checkpoints into the expected DCP format can be found in `scripts/convert_llama_to_dcp.py`.
+1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, simply enable `--checkpoint.initial_load_model_only` and set `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint.
+
+2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt.

-The script expects a path to the original checkpoint files, and a path to an output directory:
 ```bash
-python -m scripts.convert_from_llama <input_dir> <output_dir>
+python ./scripts/checkpoint_conversion/convert_from_hf.py <input_dir> <output_dir> --model_name <model_name> --model_flavor <model_flavor>
+python ./scripts/checkpoint_conversion/convert_to_hf.py <input_dir> <output_dir> --model_name <model_name> --model_flavor <model_flavor>
+# e.g.
+python ./scripts/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ ./outputs/checkpoint/step-0 --model_name llama3 --model_flavor 8B
 ```

+### Torch

-### Torchtune
-
-This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune.
+This guide will walk you through the steps required to convert a checkpoint from `torchtitan` so that it can be loaded into pt format.

 1. CHECKPOINT CONFIGURATION
 ```
@@ -83,36 +97,20 @@ export_dtype = "bfloat16"
 Once the above have been set, the final checkpoint at the end of the training step will consist of model only with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed.

 3. CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE\
-Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file that can be loaded into torchtune:
+Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file.

-```
+```bash
 python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt
 ```


-That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune.
-
-### HuggingFace
-TorchTitan supports two methods now for supporting huggingface, directly saving and loading a hf checkpoint during training, or using an example conversion script to directly reformat the weights.
-
-1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a torchtitan training session from a huggingface safetensors file, simply enable `--checkpoint.initial_load_model_only` and set `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint.
-
-2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is torchtitan-centric, e.g. convert_from_hf means convert hf->tt.
+That's it. You have now successfully converted a sharded `torchtitan` checkpoint for use with pytorch formats.

-```
-python ./scripts/checkpoint_conversion/convert_from_hf.py <input_dir> <output_dir> --model_name <model_name> --model_flavor <model_flavor>
-python ./scripts/checkpoint_conversion/convert_to_hf.py <input_dir> <output_dir> --model_name <model_name> --model_flavor <model_flavor>
-# e.g.
-python ./scripts/convert_from_hf.py ~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/8cde5ca8380496c9a6cc7ef3a8b46a0372a1d920/ ./outputs/checkpoint/step-0 --model_name llama3 --model_flavor 8B
-```
+### PyTorch Meta Llama

-### Seed Checkpoint
-Sometimes one needs to create a seed checkpoint to initialize a model from step 0.
-E.g. it is hard, if not impossible, for meta initialization on multiple devices to reproduce the initialization on a single device.
-A seed checkpoint does initialization of the model on a single CPU, and can be loaded from another job on an arbitrary number of GPUs via DCP resharding.
+An example script for converting the original Llama3 checkpoints into DCP format to be used with `torchtitan` can be found in `scripts/convert_from_llama.py`.

-To create a seed checkpoint, use the same model config as you use for training.
-e.g.
+The script expects a path to the original checkpoint files, and a path to an output directory:
 ```bash
-NGPU=1 CONFIG=<path_to_model_config> ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 --parallelism.expert_parallel_degree 1
+python -m scripts.convert_from_llama <input_dir> <output_dir>
 ```
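
The updated "Torch" section above stops at producing a single `checkpoint.pt`. The sketch below is a hypothetical illustration, not part of this commit, of reading that file back with plain PyTorch; the `"model"` key and the need for `weights_only=False` are assumptions about the converted file's layout rather than anything the diff specifies.

```python
import torch

# Minimal sketch: the dcp_to_torch utility writes a regular torch.save() payload of
# the full training state dict, so weights_only=False may be needed; whether model
# weights live at the top level or under a "model" key is an assumption.
state = torch.load("checkpoint.pt", map_location="cpu", weights_only=False)
model_state = state.get("model", state) if isinstance(state, dict) else state
print(f"loaded {len(model_state)} entries")
```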

scripts/checkpoint_conversion/convert_from_hf.py

Lines changed: 3 additions & 3 deletions
@@ -31,7 +31,7 @@ def convert_from_hf(input_dir, output_dir, model_name, model_flavor):
     # get state dict in tt format with allocated memory
     state_dict = model._get_state_dict()
     # convert empty state dict to hf format so that hf weights can be loaded into it
-    hf_state_dict, _ = sd_adapter.to_hf(state_dict)
+    hf_state_dict = sd_adapter.to_hf(state_dict)
     dcp.load(
         hf_state_dict,
         storage_reader=HuggingFaceStorageReader(path=input_dir),
@@ -45,9 +45,9 @@ def convert_from_hf(input_dir, output_dir, model_name, model_flavor):


 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to DCP format.")
+    parser = argparse.ArgumentParser(description="Convert HF checkpoint to DCP format.")
     parser.add_argument(
-        "input_dir", type=Path, help="Input directory with original Llama weights."
+        "input_dir", type=Path, help="Input directory with HF checkpoint"
     )
     parser.add_argument("output_dir", type=Path, help="Output directory for DCP.")
     parser.add_argument("--model_name", type=str, nargs="?", default="llama3")

scripts/checkpoint_conversion/convert_to_hf.py

Lines changed: 6 additions & 9 deletions
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 import argparse
-import json
 from pathlib import Path

 import torch
@@ -38,7 +37,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor):
     )

     # convert state dict tt->hf
-    hf_state_dict, config_json = sd_adapter.to_hf(state_dict)
+    hf_state_dict = sd_adapter.to_hf(state_dict)

     fqn_to_index_mapping = {}
     num_fqns_per_file = 30
@@ -60,17 +59,15 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor):
         storage_writer=storage_writer,
     )

-    config_path = output_dir / "config.json"
-    with config_path.open("w") as f:
-        json.dump(config_json, f, indent=4)
-

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Llama weights to HF format.")
+    parser = argparse.ArgumentParser(description="Convert DCP weights to HF format.")
+    parser.add_argument(
+        "input_dir", type=Path, help="Input directory with DCP weights."
+    )
     parser.add_argument(
-        "input_dir", type=Path, help="Input directory with original Llama weights."
+        "output_dir", type=Path, help="Output directory for HF checkpoint."
     )
-    parser.add_argument("output_dir", type=Path, help="Output directory for DCP.")
     parser.add_argument("--model_name", type=str, nargs="?", default="llama3")
     parser.add_argument("--model_flavor", type=str, nargs="?", default="8B")
     args = parser.parse_args()

torchtitan/components/checkpoint.py

Lines changed: 2 additions & 9 deletions
@@ -6,15 +6,13 @@

 import enum
 import functools
-import json
 import os
 import queue
 import re
 import shutil
 import threading
 import time
 from concurrent.futures import Future
-from pathlib import Path
 from typing import Any

 import torch
@@ -359,7 +357,7 @@ def dcp_save(
             assert (
                 self.sd_adapter is not None
             ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
-            state_dict, config_json = self.sd_adapter.to_hf(state_dict)
+            state_dict = self.sd_adapter.to_hf(state_dict)

             fqn_to_index_mapping = {}
             num_fqns_per_file = 30
@@ -404,11 +402,6 @@ def dcp_save(
                 checkpoint_id=checkpoint_save_id,
             )

-        if to_hf:
-            config_path = Path(checkpoint_id) / "config.json"
-            with config_path.open("w") as f:
-                json.dump(config_json, f, indent=4)
-
         if enable_garbage_collection:
             GarbageCollection.collect("GC collection invoked by checkpointer.")

@@ -432,7 +425,7 @@ def dcp_load(
             assert (
                 self.sd_adapter is not None
             ), "trying to load checkpoint in HF safetensors format, but sd_adapter is not provided."
-            hf_state_dict, _ = self.sd_adapter.to_hf(state_dict)
+            hf_state_dict = self.sd_adapter.to_hf(state_dict)

             dcp.load(
                 hf_state_dict,

torchtitan/experiments/forge/engine.py

Lines changed: 1 addition & 1 deletion
@@ -8,9 +8,9 @@
 from typing import Generator

 import torch
+from torch.distributed.elastic.multiprocessing.errors import record

 import torchtitan.protocols.train_spec as train_spec_module
-from torch.distributed.elastic.multiprocessing.errors import record
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.distributed import ParallelDims, utils as dist_utils

torchtitan/models/llama3/model/state_dict_adapter.py

Lines changed: 3 additions & 25 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import re
-from typing import Any, Tuple
+from typing import Any

 from torchtitan.protocols.state_dict_adapter import StateDictAdapter

@@ -55,9 +55,7 @@ def _reverse_permute(self, w, n_heads_arg, dim1=None, dim2=None):
             .reshape(dim1, dim2)
         )

-    def to_hf(
-        self, state_dict: dict[str, Any]
-    ) -> Tuple[dict[str, Any], dict[str, Any]]:
+    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         to_hf_map = {v: k for k, v in self.from_hf_map.items()}

         n_heads = self.model_args.n_heads
@@ -91,27 +89,7 @@ def to_hf(

             hf_state_dict[new_key] = value

-        ffn_hidden_dim = int(self.model_args.dim * 4 * 2 / 3)
-        if self.model_args.ffn_dim_multiplier:
-            ffn_hidden_dim = int(ffn_hidden_dim * self.model_args.ffn_dim_multiplier)
-        multiple_of = self.model_args.multiple_of
-        ffn_hidden_dim = multiple_of * (
-            (ffn_hidden_dim + multiple_of - 1) // multiple_of
-        ) # hacky way to get ffn_hidden_dim, follows the calculation in models.TransformerBlock and model.FeedForward
-
-        config_json = {
-            "architectures": ["LlamaForCausalLM"],
-            "hidde": "silu",
-            "hidden_size": self.model_args.dim,
-            "intermediate_size": ffn_hidden_dim,
-            "model_type": "llama",
-            "num_attention_heads": self.model_args.n_heads,
-            "num_hidden_layers": self.model_args.n_layers,
-            "num_key_value_heads": self.model_args.n_kv_heads,
-            "vocab_size": self.model_args.vocab_size,
-        }
-
-        return hf_state_dict, config_json
+        return hf_state_dict

     def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
         n_heads = self.model_args.n_heads

torchtitan/protocols/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 from .model import BaseModelArgs, ModelProtocol
 from .model_converter import ModelConverter, ModelConvertersContainer
 from .state_dict_adapter import StateDictAdapter

torchtitan/protocols/state_dict_adapter.py

Lines changed: 2 additions & 4 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 from abc import ABC, abstractmethod
-from typing import Any, Tuple
+from typing import Any

 from torchtitan.protocols import BaseModelArgs

@@ -22,9 +22,7 @@ def __init__(self, model_args: BaseModelArgs):
         pass

     @abstractmethod
-    def to_hf(
-        self, state_dict: dict[str, Any]
-    ) -> Tuple[dict[str, Any], dict[str, Any]]:
+    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
         """Convert from native model state dict to HuggingFace format.

         Args:
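
For implementers, the slimmed-down protocol now only asks `to_hf` to return the HF-keyed state dict, with no accompanying config. The toy adapter below is a minimal sketch, not from the repo; the single key rename is purely illustrative.

```python
from typing import Any

from torchtitan.protocols.state_dict_adapter import StateDictAdapter

class ToyStateDictAdapter(StateDictAdapter):
    """Illustrative implementation of the post-commit interface."""

    def __init__(self, model_args):
        self.model_args = model_args
        # hypothetical single-entry mapping; real adapters cover the full model
        self.from_hf_map = {"model.embed_tokens.weight": "tok_embeddings.weight"}

    def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
        to_hf_map = {v: k for k, v in self.from_hf_map.items()}
        return {to_hf_map.get(k, k): v for k, v in state_dict.items()}

    def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
        return {self.from_hf_map.get(k, k): v for k, v in hf_state_dict.items()}

adapter = ToyStateDictAdapter(model_args=None)
assert "model.embed_tokens.weight" in adapter.to_hf({"tok_embeddings.weight": 0})
```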

torchtitan/train.py

Lines changed: 9 additions & 9 deletions
@@ -142,15 +142,15 @@ def __init__(self, job_config: JobConfig):
         )

         # build model (using meta init)
-        self.model_args = self.train_spec.model_args[job_config.model.flavor]
+        model_args = self.train_spec.model_args[job_config.model.flavor]
         # set the model args from training job configs
-        self.model_args.update_from_config(job_config)
+        model_args.update_from_config(job_config)

         logger.info(
-            f"Building {self.train_spec.name} {job_config.model.flavor} with {self.model_args}"
+            f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}"
         )
         with torch.device("meta"):
-            model = self.train_spec.model_cls(self.model_args)
+            model = self.train_spec.model_cls(model_args)

         # Build the collection of model converters. No-op if `model.converters` empty
         model_converters = build_model_converters(job_config, parallel_dims)
@@ -163,15 +163,15 @@ def __init__(self, job_config: JobConfig):
             else self.train_spec.build_metrics_processor_fn
         )
         self.metrics_processor = build_metrics_processor_fn(
-            job_config, parallel_dims, self.model_args
+            job_config, parallel_dims, model_args
         )
         color = self.metrics_processor.color

         # calculate model size and flops per token
         (
             model_param_count,
             self.metrics_processor.num_flops_per_token,
-        ) = self.model_args.get_nparams_and_flops(model, job_config.training.seq_len)
+        ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len)

         logger.info(
             f"{color.blue}Model {self.train_spec.name} {job_config.model.flavor} "
@@ -234,7 +234,7 @@ def __init__(self, job_config: JobConfig):
             parallel_dims,
             job_config,
             self.device,
-            self.model_args,
+            model_args,
             self.train_spec.parallelize_fn,
             self.loss_fn,
         )
@@ -303,7 +303,7 @@ def __init__(self, job_config: JobConfig):
             states={"train_state": self},
             checkpoint_config=job_config.checkpoint,
             sd_adapter=(
-                self.train_spec.state_dict_adapter(self.model_args)
+                self.train_spec.state_dict_adapter(model_args)
                 if self.train_spec.state_dict_adapter
                 else None
             ),
@@ -430,7 +430,7 @@ def forward_backward_step(
         with self.train_context(optional_context_parallel_ctx):
             assert len(model_parts) == 1
             with self.maybe_enable_amp:
-                pred = model_parts[0](inputs)
+                pred = model_parts[0](inputs, self.tokenizer.eos_id)
                 loss = self.loss_fn(pred, labels)
             # need to free to before bwd to avoid peaking memory
             del pred
