[RFC] Improve TorchTune Extensibility and Build Interop with Ecosystem #442
Fine-tune config (YAML):

@@ -18,7 +18,14 @@ shuffle: True
 # Model Arguments
 model:
   _component_: torchtune.models.llama2.llama2_7b
-model_checkpoint: /tmp/llama2_native
+
+checkpointer:
+  _component_: torchtune.utils.FullModelMetaCheckpointer
+  checkpoint_dir: /tmp/llama2
+  checkpoint_files: [consolidated.00.pth]
+  model_type: LLAMA2
+
+resume_from_checkpoint: False
 
 # Fine-tuning arguments
 batch_size: 2

Review comment (on `model_type: LLAMA2`): Where can I, as a user, find the supported `model_type`s?
Author reply: I'll add more info on this after the LoRA change. For now it's a copy-paste for the config.
@@ -32,7 +39,7 @@ max_steps_per_epoch: null
 gradient_accumulation_steps: 1
 log_every_n_steps: null
 run_generation: null
-resume_from_checkpoint: False
+
 
 # Distributed
 device: cuda
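As a quick orientation for the config change above, here is a minimal sketch of how a recipe turns the new `checkpointer` block into an object. It mirrors the `config.instantiate` call sites in the recipe diff below; the OmegaConf usage, the `from torchtune import config` import style, and the `output_dir` path are illustrative assumptions rather than part of this PR.

```python
# Sketch only: build a checkpointer from the new config block.
# Mirrors the call sites shown in the recipe diff below.
from omegaconf import OmegaConf

from torchtune import config  # assumed recipe-style import

cfg = OmegaConf.create(
    {
        "checkpointer": {
            "_component_": "torchtune.utils.FullModelMetaCheckpointer",
            "checkpoint_dir": "/tmp/llama2",
            "checkpoint_files": ["consolidated.00.pth"],
            "model_type": "LLAMA2",
        },
        "resume_from_checkpoint": False,
    }
)

# `_component_` names the class to construct; the remaining keys become kwargs.
# `output_dir` is an illustrative path chosen for this example.
checkpointer = config.instantiate(
    cfg.checkpointer,
    output_dir="/tmp/llama2_finetune_output",
    resume_from_checkpoint=cfg.resume_from_checkpoint,
)
ckpt_dict = checkpointer.load_checkpoint()  # state dict for the recipe to consume
```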
Full fine-tune recipe (Python):

@@ -4,7 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import os
 import sys
 
 from functools import partial
@@ -17,6 +16,7 @@
 from torch import nn
 from torch.cuda.amp import GradScaler
 from torch.distributed import init_process_group
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler
 
@@ -93,13 +93,20 @@ def __init__(self, cfg: DictConfig) -> None:
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.total_training_steps = 0
 
-    def load_checkpoint(self, ckpt_path: str):
+    def load_checkpoint(self, cfg: DictConfig) -> Dict[str, Any]:
         """
-        Extract the checkpoint state from file and validate.
+        Extract and load state dict from checkpoint file.
         """
-        ckpt_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
-        utils.validate_checkpoint(ckpt_dict, self._resume_from_checkpoint)
-        return ckpt_dict
+        self._checkpointer = config.instantiate(
+            cfg,
+            output_dir=self._output_dir,
+            resume_from_checkpoint=self._resume_from_checkpoint,
+        )
+        checkpoint_dict = self._checkpointer.load_checkpoint()
+
+        if self._resume_from_checkpoint:
+            self._update_recipe_state(checkpoint_dict)
+        return checkpoint_dict
 
     def setup(self, cfg: DictConfig) -> None:
         """
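To make the contract the recipe now depends on explicit, the sketch below spells out the checkpointer interface as implied by the call sites in this diff (`load_checkpoint()` returning a state dict; `save_checkpoint(...)` taking the dict, the epoch, and an intermediate flag). The `Protocol` formulation, the class name, and the docstrings are assumptions inferred from the diff, not torchtune's actual class definitions.

```python
# Interface sketch inferred from the recipe's call sites; not copied from torchtune.
from typing import Any, Dict, Protocol


class CheckpointerInterface(Protocol):
    """What the recipe assumes any configured checkpointer can do."""

    def load_checkpoint(self) -> Dict[str, Any]:
        """Read checkpoint files from disk and return a state dict
        (model weights, plus recipe state when resuming)."""
        ...

    def save_checkpoint(
        self,
        state_dict: Dict[str, Any],
        epoch: int,
        intermediate_checkpoint: bool = False,
    ) -> None:
        """Write the state dict to disk; `intermediate_checkpoint=True`
        signals that training is still in progress."""
        ...
```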
@@ -108,13 +115,7 @@ def setup(self, cfg: DictConfig) -> None:
         """
         self._metric_logger = config.instantiate(cfg.metric_logger)
 
-        ckpt_dict = self.load_checkpoint(ckpt_path=cfg.model_checkpoint)
-
-        # If we're resuming from checkpoint, the recipe's state should be updated before
-        # initializing the training components. This ensures that the seed is correctly
-        # propagated to the relevant components
-        if self._resume_from_checkpoint:
-            self._update_recipe_state(ckpt_dict)
+        ckpt_dict = self.load_checkpoint(cfg.checkpointer)
 
         # ``_setup_model`` handles initialization and loading the state dict. This method
         # should be called before ``_setup_optimizer`` since transforming the optimizer
@@ -239,6 +240,7 @@ def _setup_optimizer(
         for FSDP.
         """
         optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+
         if opt_state_dict:
             opt_state_dict = utils.transform_opt_state_dict(
                 opt_state_dict, self._model, optimizer
@@ -290,35 +292,30 @@ def _setup_data(
 
     def save_checkpoint(self, epoch: int) -> None:
         """
-        Checkpoint the relevant state of a recipe.
-
-        This makes use of the `save_checkpoint` utility which is responsible for
-        writing the checkpoint dictionary to file. The contents of the dict are dictated
-        by whether training is complete or not.
-
-        If training is ongoing, optimizer state, seed and epochs_run are saved along with the
-        model weights.
+        Save state dict to file.
         """
-        os.makedirs(self._output_dir, exist_ok=True)
-        output_loc = f"{self._output_dir}/model_{epoch}.ckpt"
-        ckpt_dict = {MODEL_KEY: self._model}
-
+        ckpt_dict = {MODEL_KEY: self._model.state_dict()}
         # if training is in-progress, checkpoint the optimizer state as well
         if epoch + 1 < self.total_epochs:
+            optimizer_state_dict = (
+                FSDP.optim_state_dict(self._model, self._optimizer)
+                if utils.contains_fsdp(self._model)
+                else self._optimizer.state_dict()
+            )
             ckpt_dict.update(
                 {
-                    OPT_KEY: self._optimizer,
+                    OPT_KEY: optimizer_state_dict,
                     SEED_KEY: self.seed,
                     EPOCHS_KEY: self.epochs_run,
                     TOTAL_EPOCHS_KEY: self.total_epochs,
                     MAX_STEPS_KEY: self.max_steps_per_epoch,
                 }
             )
-        utils.save_checkpoint(ckpt_dict, output_loc)
-
-        if self._is_rank_zero:
-            log.info(
-                f"Model checkpoint of size {os.path.getsize(output_loc) >> 20} MB saved to {output_loc}"
-            )
+
+        self._checkpointer.save_checkpoint(
+            ckpt_dict,
+            epoch=epoch,
+            intermediate_checkpoint=(epoch + 1 < self.total_epochs),
+        )
 
     def _should_update_weights(self, current_iteration: int) -> bool:

Review comment (on the inlined optimizer state dict logic): Didn't we have a helper function for this? Any reason we prefer it directly in the recipe for now?
Author reply: It was just a two-line function specific to this recipe, so I just pulled it in here. (kartikayk marked this conversation as resolved.)
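For context on the review exchange above, the inlined FSDP-aware optimizer-state logic is equivalent to a small helper along these lines (the function name is hypothetical; the PR keeps the logic inline because it was only used by this recipe):

```python
# Hypothetical helper equivalent to the inlined logic above. Only
# FSDP.optim_state_dict and utils.contains_fsdp come from the diff itself.
from torch import nn, optim
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtune import utils


def full_optimizer_state_dict(model: nn.Module, optimizer: optim.Optimizer) -> dict:
    """Return the optimizer state dict, consolidating shards first when the
    model is wrapped in FSDP."""
    if utils.contains_fsdp(model):
        return FSDP.optim_state_dict(model, optimizer)
    return optimizer.state_dict()
```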
Requirements file:

@@ -1,6 +1,3 @@
-# Install PyTorch
-torch>=2.2.0
-
 # HuggingFace Integration Reqs
 datasets==2.15.0
 huggingface_hub==0.19.4
Recipe test script (shell):

@@ -10,7 +10,13 @@
 # Longer Test using llama 7b checkpoint: ./run_test.sh --large-scale
 
 LOCAL_DIR="/tmp/test-artifacts"
-S3_URLS=("s3://pytorch-multimodal/llama2-7b/tokenizer.model" "s3://pytorch-multimodal/small-ckpt-01242024")
+S3_URLS=(
+    "s3://pytorch-multimodal/llama2-7b/tokenizer.model"
+    "s3://pytorch-multimodal/small-ckpt-01242024"
+    "s3://pytorch-multimodal/small-ckpt-tune-03082024.pt"
+    "s3://pytorch-multimodal/small-ckpt-meta-03082024.pt"
+    "s3://pytorch-multimodal/small-ckpt-hf-03082024.pt"
+)
 PYTEST_COMMAND="pytest tests/recipes -s"
 
 if [[ $# -gt 0 ]]; then

Review comment (on the `small-ckpt-01242024` entry): Why is this one still around?
Author reply: This is needed for LoRA.
Review comment (on the hardcoded `LOCAL_DIR`): Hardcoding this tempfile directory is going to get pretty annoying, since users on the same box can overwrite each other's artifacts or may not have access to the directory. We should at least add some sort of unique id to it.
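One way to act on that suggestion, sketched in Python rather than in the shell script itself and not part of this PR, is to derive a per-user, per-run artifact directory instead of a fixed path:

```python
# Sketch of the reviewer's suggestion: give each user/run its own artifact
# directory so concurrent runs on a shared box don't collide. Illustrative only.
import getpass
import tempfile
from pathlib import Path


def make_artifact_dir(base: str = "/tmp") -> Path:
    # e.g. /tmp/test-artifacts-alice-k3j2x_
    prefix = f"test-artifacts-{getpass.getuser()}-"
    return Path(tempfile.mkdtemp(prefix=prefix, dir=base))


local_dir = make_artifact_dir()
print(f"Downloading test checkpoints into {local_dir}")
```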