
[RFC] Improve TorchTune Extensibility and Build Interop with Ecosystem #442

Merged: 20 commits into main, Mar 11, 2024

Conversation

@kartikayk (Contributor) commented Mar 4, 2024

Updates since last review:

Since the last review I've updated our checkpointing stack. The changes include:

  • Add FullModelTorchTuneCheckpointer, FullModelMetaCheckpointer and FullModelHFCheckpointer, which handle all of the logic associated with Meta, HF and TorchTune checkpoints. This drastically simplifies the recipe UX (see full_finetune.py).
  • Fix the weirdness in our checkpointing story where we created state dicts with model as a key even for final checkpoints. This was unintuitive and different from how other libraries and frameworks handle it.
  • Since the change is only applied to full finetuning, I modify convert_checkpoint so that the LoRA recipe doesn't break. This is a temporary hack and will go away once I make the same change for LoRA.

PR Context

A couple of notes as you read through this RFC:

  • Please read through the context before looking at the code. Yes, I know it's a long read.
  • The code is a WIP prototype (a.k.a. it kinda sucks), but it has been thoroughly tested locally. I still need to do some refactoring, add detailed unit tests and docstrings, and update the docs and tutorials. Before that effort, I'd like some initial thoughts.

Context

Building interoperability with the surrounding ecosystem is critical for TorchTune. This "off-ramp", i.e. the ease with which users can use their fine-tuned checkpoints with their favorite tools, is as important as the fine-tuning capabilities provided by the library. It's not an exaggeration to say that without a strong interop story, it will be hard for TorchTune to gain traction within the community.

 

Understanding the Landscape

Before we go deeper into building interoperability with the ecosystem, let's take a quick look at the current ecosystem.

HF Model Hub is the de facto source for most (if not all) popular LLMs. Each model is associated with a checkpoint format, which is distinct from inference-time formats like GGUF: it refers to the model state dict and how it's presented to model users. At a high level, checkpoint formats can be divided into two popular buckets:

  • Original. This refers to the original checkpoints from the LLM authors. For example, Meta provides a consolidated .pth file through the meta-llama repository on HF Model Hub. Various tools like llama.cpp directly build on top of this checkpoint by assuming the keys have a certain format.
  • HF Format. Popular models like Llama 7B are also available in the HF format through the HF Hub. Similar to the above, tools like llama.cpp make assumptions about the state dict format when exporting models for inference. These checkpoints are usually split across multiple .bin files (or stored as safetensors) with an associated index.json file that describes how to rebuild the full state dict (a rough sketch of this index follows below).
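To make the sharded HF layout concrete, here is a minimal Python sketch of reading such an index file; the file name follows the usual HF convention and the key shown is illustrative:

```python
import json

# The index file maps every state-dict key to the .bin shard that contains it.
with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

# index["weight_map"] looks roughly like:
#   {"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin", ...}
shard_file = index["weight_map"]["model.embed_tokens.weight"]
```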

Popular code bases like gpt-fast [script], GPTQ-for-Llama [script], llama.cpp [script] etc. all depend on the above formats or provide the option to write custom converters.

 

Given the above state, my claim is that we should build TorchTune to be "state-dict invariant", i.e. convert checkpoints from popular formats into TorchTune's format -> train -> convert back to the original format. The rest of this RFC goes over this idea.

 

But a few FAQs before that:

[Question] Why do we need to convert into a "TorchTune format"? Can't we just directly use the model classes from these popular frameworks like Transformers?

The TorchTune modeling components and classes are built with modularity and flexibility in mind. Using Transformers negates this design principle and takes away "extensibility" as one of our core value propositions to users. It also negates our goal of being "native PyTorch", since these frameworks and libraries impose a strict structure which needs to be followed. gpt-fast has a similar setup, where the code base first [converts the checkpoint].

 

[Question] What about models like Mistral and Gemma which are neither from Meta nor from HF?

The above applies to both Mistral [HF Repo] and Gemma [HF Repo].

 

What does "be state-dict invariant" mean for TorchTune?

This has a sizable impact on our current user experience, but the ROI is high since not only do we get a "built-in off-ramp", but adding new models becomes easier.

Our current flow looks something like this:

  • Download model from HF Hub (we artificially constrain this to the meta-llama/Llama-2-7b repo)
  • Convert this into TorchTune format using our convert script
  • Train, and save intermediate checkpoints using a (very) custom format
  • Save final checkpoint in the same format

The above flow means that for inference, we need to first convert the final checkpoint into a standard format (e.g. GGUF or an ExecuTorch-friendly format) by writing a custom converter, which can be substantial work [example]. Alternatively, we need to adopt a standard implementation of popular models, which is also a no-go as mentioned above. As a result, adding new model support will be slow since we will need to build a new off-ramp for each model implementation.

The flow proposed by this RFC looks something like this:

  • Download model from HF Hub without major constraints. E.g. any Llama2 7B model in the original format or the HF format (e.g. those from technium) should be usable OOTB
  • Convert the checkpoint into TorchTune's expected format
  • Train, and save intermediate checkpoints using our custom format
  • Convert final checkpoint BACK to the original format before writing out to file

To minimize cognitive load on the users, TorchTune recipes will handle conversions "to" and "from" the above formats instead of delegating this to the cli tool.

Concretely, the user would run the following:

# Download any model which is in a format supported by TorchTune; this will be a sizable list
tune download ... 

# Copy the config and recipe to custom folder; Modify config and/or recipe
tune cp ...

# Run fine-tuning without the need for any conversion
tune run ...

# Take model checkpoint and directly use with tool of choice
python3 llama.cpp/convert.py ...

With the above flow, I'm directly able to convert the model into GGUF to run inference and quantization using llama.cpp as well as use gpt-fast for running generation and quantization.

 

What changes in code?

The user no longer needs to know about TorchTune's checkpoint format unless they're resuming training from a previously failed run. Specifically, we:

  • Add support for META_FORMAT for llama-2-7b, which includes authoring the convert_llama2_from_meta_format and convert_llama2_to_meta_format functions for translating state dicts (a rough sketch of such a function follows this list).
  • Expose CheckpointFormat through the recipe and config and explicitly ask the user to specify the checkpoint format (detailed documentation will make this clear)
  • Update utils.load_checkpoint to extract the state_dict from the original checkpoint rather than the translated checkpoint. This means adding support for both checkpoint files and checkpoint directories (checkpoints can be split across multiple files), which is done through the _fetch_meta_format_state_dict function, which in turn is hooked up to utils.load_checkpoint through the load_external_checkpoint function
  • Update load_checkpoint in the recipe to translate state dict from original format to TorchTune's format. Behavior for resuming training remains the same (see below)
  • Keep the same functionality for when resume_from_checkpoint is True. We need the additional information in the state dict
  • Update save_checkpoint in the recipe to translate final state dict back to the original format.
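As referenced in the first bullet, here is a hedged sketch of what such a translation function can look like; the key names on both sides are illustrative and the real mapping lives in this PR:

```python
import re
from typing import Any, Dict

# Illustrative template mapping from Meta-format keys to TorchTune-format keys;
# the real table covers every parameter in the model.
_FROM_META = {
    "tok_embeddings.weight": "tok_embeddings.weight",
    "layers.{}.attention.wq.weight": "layers.{}.attn.q_proj.weight",
    "layers.{}.attention.wk.weight": "layers.{}.attn.k_proj.weight",
    # ... remaining keys elided
}


def convert_llama2_from_meta_format(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Translate a Meta-format state dict into a TorchTune-format state dict."""
    converted = {}
    for key, value in state_dict.items():
        # Swap the layer index for "{}" to look up the template, then put it back
        layer_nums = re.findall(r"\.(\d+)\.", key)
        template = re.sub(r"\.\d+\.", ".{}.", key)
        new_key = _FROM_META[template]
        converted[new_key.format(*layer_nums) if layer_nums else new_key] = value
    return converted
```

convert_llama2_to_meta_format would then just be the inverse, built from the same table with keys and values swapped.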

 

How does this make things better?

We get two big advantages.

 

Adding support for a new format is straightforward

For adding the HF_FORMAT and opening up the repo to a large number of llama-7b models trained using transformers, we only need to:

  • Author convert_llama2_from_hf_format and convert_llama2_to_hf_format functions for translating state dicts.
  • Author _fetch_hf_format_state_dict, which contains the logic for extracting state dicts from multiple .bin files (sketched after this list)
  • And that's it!
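As mentioned in the second bullet, here is a hedged sketch of _fetch_hf_format_state_dict: load each .bin file on CPU and merge the per-file state dicts into one dict (the directory argument and glob pattern are assumptions for illustration):

```python
from pathlib import Path
from typing import Any, Dict

import torch


def _fetch_hf_format_state_dict(checkpoint_dir: Path) -> Dict[str, Any]:
    """Merge the partial state dicts stored across multiple HF .bin files."""
    merged_state_dict: Dict[str, Any] = {}
    # Each .bin file holds a disjoint subset of complete tensors, so a plain
    # dict merge on CPU is enough -- no torch.distributed needed.
    for ckpt_file in sorted(checkpoint_dir.glob("pytorch_model-*.bin")):
        state_dict = torch.load(ckpt_file, map_location="cpu")
        merged_state_dict.update(state_dict)
    return merged_state_dict
```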

With the HF_FORMAT added, support for llama2 13B and 30B models (at least on an 80GB A100) should be straightforward as well. Adding Mistral-7B and Gemma-7B should not be as much work as before either.

 

Running your favorite inference tools is straightforward

The following conversion to GGUF using llama.cpp works OOTB

python3 convert.py ~/cpts/finetuned_llama/ --ctx 4096

where finetuned_llama contains the output checkpoints from a full-finetuning run in TorchTune.

 

Other FAQs

How does this work for PEFT?

For the MVP, we plan to provide "merged weights" (inference-time weight merging and downstream support is out of scope). This should be a straightforward swap since the state dict keys remain the same.
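For reference, a minimal sketch of what "merged weights" means for a single LoRA-adapted linear layer, assuming the standard LoRA parametrization W' = W + (alpha / r) * B A; the argument names are illustrative, not TorchTune attribute names:

```python
import torch


def merge_lora_weight(
    weight: torch.Tensor,   # (out_dim, in_dim) frozen base weight
    lora_a: torch.Tensor,   # (rank, in_dim) low-rank down-projection
    lora_b: torch.Tensor,   # (out_dim, rank) low-rank up-projection
    alpha: float,
    rank: int,
) -> torch.Tensor:
    # The merged weight has the same key and shape as the original full-model
    # checkpoint, so the "state-dict invariant" contract still holds.
    return weight + (alpha / rank) * (lora_b @ lora_a)
```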

 

Why not just keep the existing flow and add a checkpoint conversion at the end?

Great question! A couple of reasons for this:

  • Asking the user to worry about checkpoint conversions and keeping track of different formats at different stages of the flow adds unnecessary cognitive load at best and is extremely annoying at worst. The user should specify this once in config and not have to worry about this anymore. This would still be worth considering if not for the second point below.
  • Our current flow is really hard to generalize to new models. Just adding support for a new format creates a 500-line (for now) PR.

 

The current code kinda sucks - I sort of understand what's going on, but there aren't any tests or doc strings. How do I know this actually works?

For now, take my word that it works! I'm working on updating the tests, adding detailed docstrings, and even a tutorial on the overall flow. But before I put in all of that work, I'd like some initial feedback.


@rohan-varma self-requested a review March 4, 2024 05:08

@rohan-varma (Member) commented:

High level comment - totally agree with the proposal to support HF formats for popular models as well. Though wanted to get into detail about this point -

With the HF_FORMAT added, support for llama2 13B and 30B models (at least on an 80GB A100) should be straight forward as well.

Curious how the HF format specifically enables easier support for llama2 13B and 30B? AFAIK, Meta-format checkpoints are published for these models as well, so we should be able to use those just as easily. Just wondering whether the HF format provides something additional for the larger-scale models specifically.

@kartikayk (Contributor, Author) commented:

@rohan-varma Sorry this was poorly worded. Adding the HF_FORMAT for 7B meant adding support for reading the state dict from multiple checkpoint files. Llama 13B is split across multiple .pth files (each ~13GB in size) which would have needed custom handling. The overall point was that with this change the repo becomes less specific to llama-7B which makes adding new models and weights easier.

@rohan-varma (Member) commented:

@rohan-varma Sorry this was poorly worded. Adding the HF_FORMAT for 7B meant adding support for reading the state dict from multiple checkpoint files. Llama 13B is split across multiple .pth files (each ~13GB in size) which would have needed custom handling. The overall point was that with this change the repo becomes less specific to llama-7B which makes adding new models and weights easier.

@kartikayk That makes sense. It seems like we should discuss how we read and write multiple files. IIUC, the files are checkpoint shards, so what I've been doing is using distributed to stitch them together so we can load into torchtune models, i.e. in https://github.com/pytorch-labs/torchtune/pull/404/files. Curious about the approach that you took to load 13B?

@ebsmothers (Contributor) commented:

Thanks for creating this RFC! A few questions I have on the proposal:

  1. This makes perfect sense for Llama2-style models (and there are enough of them that this is a big lift), but what about other models that do not match this format (e.g. Mixtral)? Do we consider this a unique new format that we would support that then gets all the same treatment as Llama2 format in this proposal?

  2. Sounds like we are making an explicit contract with users here.. "give us a checkpoint in format X and we will output a fine-tuned checkpoint in format X". I know you mention PEFT, I am curious how this contract will hold up in adapters that are not linear like LoRA (e.g. bottleneck adapters come to mind). In such cases, do we break the contract, opt not to support them, or something else entirely?

  3. One thing I am still trying to reason about is our interaction with Hugging Face here. Currently we have an implicit dependency via our conversion scripts, but in a sense this is integrating this implicit dependency more tightly into our core offering (e.g. if something about their checkpoint format changes we will be completely broken until we can integrate it into our lib). I think it's a reasonable claim that the UX lift is enough to justify this, just want to call it out.

  4. Kind of related to both (1) and (3).. how do we envision maintenance of this flow down the line? Right now it is easy enough to track one set of state dict mappings, but want to make sure we do not just become a state dict remapping library. E.g. lit-gpt has their own format for Llama2, if we want to support checkpoints fine-tuned on there we need a separate utility. Prior to this change I think it's easier for us to tell users to just figure it out, but now I can imagine more of an expectation that we will support such things out of the box. And (I know I don't need to tell you) maintaining a ton of different state dict mappings can get quite messy.

  5. Just for my own understanding, what is our exact definition of "TorchTune format"? Is it state dict key equivalence (plus I guess the assumption that our model is TransformerDecoder class)? Are there any other constraints on parameter shapes? (Obv 7B, 13B, etc. all have their expected shapes, I am thinking more from the inference side of things here.)

  6. Similar question to @rohan-varma: how do we define our output format in the case where our final checkpoint has to be saved in a distributed fashion?

Sorry this wound up being more than a few questions .. 😅 .. let me know if any of these comments are unclear.

@kartikayk (Contributor, Author) commented:

@rohan-varma why do I need distributed for this? Here's a simple code pointer from gpt-fast on how they handle the bin files from HF (seems pretty standard) and I should be able to handle the pth files similarly. Or do I misunderstand?
https://github.com/pytorch-labs/gpt-fast/blob/main/scripts/convert_hf_checkpoint.py#L67

@kartikayk (Contributor, Author) commented:

@ebsmothers great questions! A few thoughts:

but what about other models that do not match this format

If you take a closer look, the format isn't model specific. It's "training framework" specific. Adding support for a model like Mistral isn't a new format. It's a new convert function which likely has a similar/same key mapping dict.

Sounds like we are making an explicit contract with users here.. "give us a checkpoint in format X and we will output a fine-tuned checkpoint in format X". I know you mention PEFT, I am curious how this contract will hold up in adapters that are not linear like LoRA

We'll need to figure this out once we have a concrete method that we can take a look at. But the answer is that we should align with the inference tooling on this. And that's the contract I propose here.

One thing I am still trying to reason about is our interaction with Hugging Face here

I view this as an interaction with the ecosystem, which happens to depend on HF formats for everything. You're right that if things break we'll have to go in and fix stuff, but the claim is that this is true generally. If tomorrow some other format becomes popular, we should just align with that. Does this make sense?

Kind of related to both (1) and (3).. how do we envision maintenance of this flow down the line? Right now it is easy enough to track one set of state dict mappings, but want to make sure we do not just become a state dict remapping library. E.g. lit-gpt has their own format for Llama2, if we want to support checkpoints fine-tuned on there we need a separate utility. Prior to this change I think it's easier for us to tell users to just figure it out, but now I can imagine more of an expectation that we will support such things out of the box. And (I know I don't need to tell you) maintaining a ton of different state dict mappings can get quite messy.

A few thoughts on this:

  • I don't think we need to explicitly support lit-gpt format or any other custom format. I think it's a pretty fair ask for users to take care of these conversions themselves. We have enough examples for how this can be done.
  • The maintenance story is going to be rough, but we'll need to figure this out even if we write custom adapters for every popular inference, eval and quantization tool out there.

Just for my own understanding, what is our exact definition of "TorchTune format"

Simply put: the format of a state dict that you can load directly into our model class, i.e. model.load_state_dict() works.
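In other words, something like the following; the builder path and checkpoint path are lifted from config snippets elsewhere in this PR, so treat them as illustrative:

```python
import torch
from torchtune.models.llama2 import llama2_7b

model = llama2_7b()
state_dict = torch.load("/tmp/llama2/llama2_native.pt", map_location="cpu")
model.load_state_dict(state_dict)  # "TorchTune format" == this call succeeds
```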

how do we define our output format in the case where our final checkpoint has to be saved in a distributed fashion?

Responded a bit above; I don't think I understand what "distributed fashion" means here. See the code pointer above and the files in the HF repo for the 13B and 70B models. Maybe I'm missing something, so I need to understand this more.

@rohan-varma (Member) commented Mar 5, 2024:

Responded a bit above; I don't think I understand what "distributed fashion" means here. See the code pointer above and the files in the HF repo for the 13B and 70B models

So IIUC, in those repos, what @ebsmothers means by "the files are in a distributed fashion" is that each model-related file is a shard of the entire model checkpoint.

For example, the Llama 2 70B checkpoint has 8 checkpoint files: https://huggingface.co/meta-llama/Llama-2-70b/tree/main. Each file contains that rank's shard from row/column-parallel sharding. When converting to/from the torchtune format, we'll need to appropriately unshard, which to me requires knowing whether a particular tensor was column- or row-parallel sharded.
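To illustrate the concern, a hedged sketch of what unsharding could involve, assuming the usual convention that column-parallel weights are split along dim 0 and row-parallel weights along dim 1; this is illustrative, not a torchtune utility:

```python
from typing import List

import torch


def unshard(shards: List[torch.Tensor], dim: int) -> torch.Tensor:
    """Concatenate the per-rank shards of one tensor back into the full weight."""
    return torch.cat(shards, dim=dim)


# e.g. the attention q/k/v projections are typically column-parallel (dim 0),
# while the attention output projection is typically row-parallel (dim 1),
# so the converter would need to know which case applies for each key.
```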

@kartikayk (Contributor, Author) commented:

Each file contains that rank's shard from row/column-parallel sharding

I don't think this is necessarily true for the HF checkpoint, but I also need to do a lot more homework for 70B. Let's punt the discussion on 70B though; I think the complexity that 70B brings is orthogonal to this RFC. I'm also in the process of redesigning the entire checkpointing stack since it's sub-optimally written in its current form. This will allow us to decouple the checkpointing across model sizes (7B vs 70B) and across finetuning methods (full vs LoRA), which makes sense to me since a user fine-tuning a 7B or smaller model doesn't need to care about any of the complexity associated with the 70B model.

@ebsmothers (Contributor) left a comment:

The FullModelCheckpointer generally makes sense to me, but I also think this is the easy case 😃. The main thing I'd push on here is more details around common PEFT partial save/load flows and making sure we properly support distributed checkpointing APIs (on that second point I would take a look at #443 if you haven't already)


# different formats for the same model have different dtypes. Track this
# to convert weights to the right format before writing to file
self._checkpoint_dtype = None
Contributor:

Any use case for different weights in different dtypes?

Comment on lines 193 to 207
if intermediate_checkpoint:
    # Add the relevant checkpoint format information to the state dict
    checkpoint_dict["checkpoint_dtype"] = self._checkpoint_dtype
    checkpoint_dict["weight_map"] = self._weight_map
    checkpoint_dict["checkpoint_format"] = self._checkpoint_format.name

    # We write to a single ".pt" file irrespective of the extension provided by the recipe
    output_path = Path.joinpath(self._output_dir, intermediate_checkpoint_name).with_suffix(".pt")
    torch.save(checkpoint_dict, output_path)
    logger.info(
        "Model checkpoint of size "
        f"{os.path.getsize(output_path) / 1000**3:.2f} GB "
        f"saved to {output_path}"
    )
else:
Contributor:

more of a nit, but if the whole method is an if/else maybe it should be split into e.g. save_intermediate_checkpoint and save_final_checkpoint methods?

Comment on lines 38 to 47
* Mid-training Checkpointing. In this case the state-dict contains more information
than just the model weights. It also contains the optimizer state dict and the training
state of the recipe needed to correctly restart training. The construction of the
state-dict, including figuring out what information is needed to correctly resume
training, is handled by the recipe. The checkpointer doesn't know or care about how the
state-dict is constructed. In this scenario, the checkpointer simply adds additional
information about the original checkpoint (eg: format, weight map etc) to ensure the
final checkpoint is constructed correctly in case the current training run fails and
needs to be resumed. Intermediate checkpoints don't require any conversion since these
are directly saved in the ``TORCHTUNE_FORMAT``.
Contributor:

If the claim is that we are using this Checkpointer class to define and enforce a unified TorchTune format, the mid-training checkpointing contract here feels a little strange. It also introduces a lot of coupling between the recipe and the checkpointing class (for instance, the recipe is responsible for defining the logic to save an intermediate checkpoint, but the checkpointer then needs to know about this when loading the intermediate checkpoint). Lmk if I'm misunderstanding this though

Contributor Author:

@ebsmothers actually the intent here is the opposite of what you mentioned. Right now the checkpointer and the recipe are really intertwined, with the checkpointing utils having to know about recipe-specific keys and how the checkpoint is formatted. With this change, the separation of concerns is a lot cleaner, i.e.:

  • The recipe handles all of the logic for preparing intermediate checkpoints (it knows what state it needs to resume correctly)
  • The checkpointer handles external checkpoint loading and final checkpoint saving.

There will be some coupling between the two because the checkpointer is meant to be used in the recipe, but generally checkpointers should be recipe-agnostic. Let me know if this makes sense.

torchtune/utils/_checkpointing/_checkpointer.py (outdated comment, resolved)
Comment on lines 359 to 361
for key in state_dict.keys():
    self._weight_map[key] = ckpt_file.name
merged_state_dict.update(state_dict)
Contributor:

What if params are sharded across files? I.e. each individual file contains a subset of the weights for a single state dict key

Contributor Author:

This should be a new converter. I don't have a case for this right now until we add support for 13B or 70B. But I expect that to go in a different checkpointer class (or maybe we can update this one if it makes sense).

@kartikayk (Contributor, Author) commented:

@ebsmothers thank you for the review

The main thing I'd push on here is more details around common PEFT partial save/load flows and making sure we properly support distributed checkpointing APIs

So this PR won't address this complexity. My main point is that having a separate training component will make things easier to separate, i.e. not have FFT users worry about PEFT checkpointing. Is there anything in the design specifically that stops us from doing this? In the worst case I can map utils.save_checkpoint to the save_checkpoint in this class and things will work just fine.

@kimishpatel commented:

One question I have about:

# Take model checkpoint and directly use with tool of choice
python3 llama.cpp/convert.py ...

Given "Convert final checkpoint BACK to the original format before writing out to file", the existence of llama.cpp/convert.py is just a convenience util, right? I can take the "checkpoint that is converted back to original format" to llama.cpp repo as well, right? To me the latter feels more composable compare to writing thin wrapper. Thoughts?

@kartikayk (Contributor, Author) commented:

@kimishpatel Sorry I should clarify this point better.

convert.py is a file in the llama.cpp repo which is used to convert a given checkpoint to the GGUF format for use with their inference stack. The point I'm making is that since we convert the checkpoint back, you can directly use the standard convert script in their repo instead of having to write a new one. Does this make sense?

README.md (outdated)
```

The argument passed to `--nproc_per_node` can be varied depending on how many GPUs you have. A full finetune can be memory-intensive, so make sure you are running on enough devices. See [this table](https://github.com/pytorch-labs/torchtune/blob/main/README.md#finetuning-resource-requirements) for resource requirements on common hardware setups.

Similarly, you can finetune with LoRA on the Alpaca dataset on two devices via
Similarly, you can finetune with LoRA on the Alpaca dataset on two devices via the following. Remember to convert your
model with ```train_type``` set to ```lora'```
Contributor:

Suggested change:
- model with ```train_type``` set to ```lora'```
+ model with `train_type` set to `lora`

Comment on lines +99 to +107
ckpts = (
    ["llama2.llama2_7b"]
    if large_scale
    else [
        "small_test_ckpt_tune",
        "small_test_ckpt_hf",
        "small_test_ckpt_meta",
    ]
)
Contributor:

This should just be a @pytest.mark.parametrize, no?

Contributor Author:

I thought about that and I don't think that'll work when we pass in --large-scale? I actually don't know how that works, so I went with a manual for loop. Let me know if you think that will work.

Contributor:

Hmmm good point. I don't immediately see how to do this. Ideally we should do this more cleanly, but no need to block this PR on it

@@ -164,15 +228,21 @@ def test_gradient_accumulation(
# We use a tiny model to reduce the error accumulation in the test
# It's impossible to make a large model produce the same loss values
# in the same way as the full batch size.
- model_ckpt = "llama2_tiny_test_ckpt"
+ model_ckpt = "small_test_ckpt_tune"
Contributor:

Oh this works? I thought we needed the tiny checkpoint for the accumulation of errors to be small enough

Contributor Author:

Hmm, it works for me :)

cmd = f"""
tune full_finetune \
--config {_CONFIG_PATH} \
--override \
model._component_=torchtune.models.{model_ckpt} \
model_checkpoint={fetch_ckpt_model_path(model_ckpt)} \
checkpointer._component_=torchtune.utils.FullModelTorchTuneCheckpointer \
Contributor:

nit: you could just write a single base command, then append micro_batch_size cfg for the first invocation and gradient_accumulation_steps cfg for the second invocation, just to make it clearer that's all that's changing. But obv it was in this state before you got here so not a huge deal if you don't do this

Contributor Author:

I'll follow up with changes beyond this PR in a separate PR. It's complex as it is.

Comment on lines +314 to +315
assert len(output_state_dict_1.keys()) + 1 == len(orig_state_dict_1.keys())
assert len(output_state_dict_2.keys()) + 1 == len(orig_state_dict_2.keys())
Contributor:

inv freq accounts for the +1 here as well?

@kartikayk (Contributor, Author) commented Mar 11, 2024:

Yes, we have inv_freq in each layer. Since each dict corresponds to a single layer, both of them are off by 1.


class FullModelTorchTuneCheckpointer(_CheckpointerInterface):
"""
Checkpointer which reads and writes "full-model" checkpoints in a format compatible with
Contributor:

nit: I think the definition of "full-model" is not sufficiently clear from the context here

self._checkpoint_path = Path.joinpath(self._checkpoint_dir, checkpoint_files[0])
if (
    not self._checkpoint_path.is_file()
    or not self._checkpoint_path.suffix == ".pt"
Contributor:

Just curious: is the .pt extension a hard requirement? Clearly we've had non-pt extensions floating around for a while without really breaking anything

Contributor Author:

It's very unintuitive; I'd just make this a convention for checkpoints we write.

# if resume_from_checkpoint is True, recipe_state.pt should contain the recipe state
if self._resume_from_checkpoint:
    self._recipe_state_file = get_recipe_checkpoint_path(
        self._checkpoint_dir, filename="recipe_state.pt"
Contributor:

Would consider defining recipe_state.pt as a constant somewhere (maybe class-level) for increased visibility

Contributor:

Also I know you mentioned saving as a JSON previously, curious if there's any particular reason for switching to .pt?

Contributor Author:

Good point about the constants. Let me generalize this a bit when I do the LoRA PR.

Optim state can be quite heavy (as large as the model checkpoint), so saving it as JSON was a bad idea.

Comment on lines +175 to +181
Model:
{
"key_1": weight
...
}

Recipe State:
Contributor:

Would maybe explicitly add a comment about the filename each of these is saved to (I know you already mention recipe_state.pt above). At first glance I kinda thought this was all a single state dict, I think adding filenames explicitly may make it harder to make that mistake.

self._resume_from_checkpoint = resume_from_checkpoint

# weight_map contains the state_dict key -> checkpoint file mapping so we can correctly
# parition the state dict into output checkpoint files. This is updated during checkpoint
Contributor:

nit

Suggested change:
- # parition the state dict into output checkpoint files. This is updated during checkpoint
+ # partition the state dict into output checkpoint files. This is updated during checkpoint

f"Found {type(value)} instead."
)
# idx is written in the 4 digit format (eg: 0001, 0002, etc.)
self._weight_map[key] = f"{cpt_idx+1:04}"
Contributor:

Probably a dumb q: why don't we just use the filename directly as the key here rather than going through all the sorting and indexing business?

Contributor Author:

So the file names here are of the form pytorch_model-00001-of-00003.bin, pytorch_model-00002-of-00003.bin and so on. The ID (00001) is important since the weights are written in this order, but the entire filename doesn't serve any purpose (for now). So I simplified all of the filename logic and just kept the ID, after making sure the incoming names are lexicographically sorted.
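For illustration, the indexing scheme described above amounts to something like the following (file names are examples):

```python
# Sort the incoming shard names lexicographically and keep only their 1-based,
# zero-padded position; the full filename is not needed downstream (for now).
files = sorted([
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00001-of-00003.bin",
    "pytorch_model-00003-of-00003.bin",
])
ids = {name: f"{idx + 1:04}" for idx, name in enumerate(files)}
# ids["pytorch_model-00001-of-00003.bin"] == "0001"
```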

Comment on lines +344 to +346
num_heads=self._config["num_attention_heads"],
num_kv_heads=self._config["num_key_value_heads"],
dim=self._config["hidden_size"],
Contributor:

Do these keys always exist?

Contributor Author:

For Llama models they always do. I checked a few others and they do as well, but it might be a good idea to add a try-catch block here.
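A minimal sketch of that guard, with the keys taken from the snippet under review; the error message and file path are illustrative:

```python
import json

with open("config.json") as f:
    config = json.load(f)

try:
    num_heads = config["num_attention_heads"]
    num_kv_heads = config["num_key_value_heads"]
    dim = config["hidden_size"]
except KeyError as e:
    raise ValueError(f"config.json is missing an expected key: {e}") from e
```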

Comment on lines +413 to +416
# If the recipe state needs to be output, first remove the model state dict
if intermediate_checkpoint:
    _ = state_dict.pop("model")
    torch.save(state_dict, Path.joinpath(self._output_dir, "recipe_state.pt"))
Contributor:

Just to check my understanding: when saving intermediate checkpoints, we still save in the input format (as opposed to TorchTune format, which we were doing previously)? And then we just supplement with the additional recipe_state.pt file?

Contributor Author:

Yeah, that's right. This way, if I want to run inference or eval on intermediate checkpoints, I don't need to do any conversion. It also simplifies save significantly, which is nice.

Comment on lines +421 to +422
Checkpointer which reads and writes "full-model" checkpoints in Meta's format. Example includes
the Llama-2-7b model from the meta-llama repo (https://huggingface.co/meta-llama/Llama-2-7b)
Contributor:

Same link as the HF checkpointer?

Member:

I think it's different?

Contributor:

Missed those last two letters..


def load_checkpoint(self) -> Dict[str, Any]:
"""
Load TorchTune checkpoint from file. Currently only loading from a single file is supported.
Contributor:

Similar comment here, is this line copy-paste from the equivalent in FullModelTorchTuneCheckpointer? Just generally, make sure to update docstrings for all of these methods

}
"""
state_dict: Dict[str, Any] = {}
state_dict["model"] = safe_torch_load(self._checkpoint_path)
Contributor:

I thought we talked about getting rid of the "model" key? I might be missing something/misremembering though, lmk if so

Contributor Author:

I get rid of "model" for checkpoint save. For the state dicts being sent to the recipe we still need this, so we have a single dict with both model and recipe state.

@ebsmothers (Contributor) left a comment:

Left a bunch more comments, but I think all the major design concerns from my side are addressed in this latest version. So modulo my open comments, this looks good to me.

@rohan-varma (Member) left a comment:

This is a legendary PR, as usual! Couple of points I want to make sure we discuss prior to landing (also mostly okay with discussing these post land if absolutely needed, save for the question around security and extending to sharded checkpoints):

  1. User having to specify the train-type when converting their ckpts into our training format seems a bit unintuitive to me. Could you clarify why we have to do this for the moment, what's blocking us from not having to specify this, and how we can get there?
  2. Seems like we would need to write a LoRA specific checkpointer for this to work with LoRA. Is the reason the same as (1), and can we have a discussion here?
  3. Curious about the del state_dict and gc collect - were you seeing memory issues here?
  4. Not directly related to this PR, but I'm realizing having tune download + operating on the checkpoints could introduce a security risk if we allow users to specify arbitrary checkpoint paths - we're downloading somewhat arbitrary checkpoint files from a third-party and running python code on them. This could also be on sensitive HW such as company devices. See https://www.darkreading.com/application-security/hugging-face-ai-platform-100-malicious-code-execution-models for potential attacks. Any thoughts on this security risk and how we can mitigate? One super easy win seems to just be to have an allowlist of checkpoints we've vetted, and crash tune download on any other checkpoints by default?
  5. Extensibility to sharded checkpoints. Details on this point are in the CR comments.

@@ -56,6 +56,9 @@ jobs:
mkdir -p /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/llama2-7b/tokenizer.model /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/llama2-7b-01242024 /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/small-ckpt-hf-03082024.pt /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/small-ckpt-tune-03082024.pt /tmp/test-artifacts
aws s3 cp s3://pytorch-multimodal/small-ckpt-meta-03082024.pt /tmp/test-artifacts
Member:

Hardcoding this tempfile directory is going to get pretty annoying, as users on the same box can overwrite each other's stuff or not have access to the directory. We should at least add some sort of unique ID to this.

parser.add_argument(
    "--train-type",
    type=str,
    help="Type of finetuning. Currently Full-Finetuning and LoRA have slightly different formats. "
Member:

Hmm, I'd like to discuss and understand this more deeply. The user having to specify the train type for checkpoint conversion is quite unintuitive, and ideally checkpoint conversion shouldn't have to know whether the model is going to be used for a full or LoRA finetune at all - it should just produce a checkpoint format that any torchtune training recipe (at least the ones that we write and endorse) can consume.

It also presents additional overhead to the user - I have to run separate conversion scripts to use different finetune techniques - I know currently it's a one time overhead, but worth pointing out + overhead will scale as we introduce more models.

I'm not able to tell from the PR description what's blocking us from just enabling a consistent format for both finetune techniques - mind elaborating? Thanks!

Contributor Author:

This is a temp hack till I fix LoRA. This PR just makes the change for full finetune. I'll follow up shortly with a PR that fixes this for LoRA and will remove this. I'll add this to the PR description

@@ -33,6 +34,7 @@ def convert_checkpoint(
checkpoint_path (Path): Path to the checkpoint path.
model (str): Model name
output_path (Optional[Path]): Path to the output checkpoint.
train_type (str): Type of finetuning
Member:

Add "must be full or lora"?

checkpointer:
_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_dir: /tmp/llama2/
checkpoint_files: [llama2_native.pt]
Member:

Is this the file we will write out? what's the plan for scaling this when we need to produce checkpoint shards (at least for intermediate checkpoints)?

Contributor Author:

This is the file we read from. The output file will be written to output_dir and matches the input format. See the checkpointer doc string

_component_: torchtune.utils.FullModelTorchTuneCheckpointer
checkpoint_dir: /tmp/llama2/
checkpoint_files: [llama2_native.pt]
model_type: LLAMA2
Member:

where can I, as a user, find the supported "model_type"'s?

Contributor Author:

I'll add more info on this after the LoRA change. For now it's a copy-paste for the config.


def load_checkpoint(self) -> Dict[str, Any]:
"""
Load TorchTune checkpoint from file. Currently only loading from a single file is supported.
Member:

Do we mean load a meta checkpoint?

intermediate_checkpoint: bool = False,
) -> None:
"""
Save TorchTune checkpoint to file. If ``intermediate_checkpoint`` is True, an additional
Member:

"Save torchtune checkpoint to file, in meta format"?

torch.save(state_dict, Path.joinpath(self._output_dir, "recipe_state.pt"))


class FullModelMetaCheckpointer(_CheckpointerInterface):
Member:

This is great, but I'm curious about the utility of the save portion of this file. Are Meta checkpoints really used for any off-ramps at the moment? They might be and I'm just unaware - do you have a list of off-ramps that directly consume the Meta-style checkpoints?

Contributor Author:

llama.cpp, for example, was built with the Meta format in mind and then extended to handle HF checkpoints. Generally it's a bad idea to not support the original model format. For future Llama versions we'll just launch with this format.

state_dict.update(recipe_state)
return state_dict

def save_checkpoint(
Member:

Another comment about saving in Meta format. Do we anticipate use cases where a user wants to load a Meta checkpoint --> train in torchtune --> output an HF checkpoint? This seems decently natural if the HF checkpoint format supports a lot of off-ramps. Should we support this, or is there a way the user can work around it with this work?

Contributor Author:

I definitely don't want to include the cross conversion complexity here. We can think about providing some conversion scripts for this scenario

TORCHTUNE_RESTART = "torchtune_restart"


class ModelType(Enum):
Member:

I don't fully follow why LoRA should have a different checkpoint, any details we can get here?

@kartikayk (Contributor, Author) commented:

Thanks for the detailed reviews @ebsmothers and @rohan-varma.

User having to specify the train-type when converting their ckpts into our training format seems a bit unintuitive to me.

This is just a hack since I don't have LoRA addressed in this PR. I'll work on that later today and remove this.

Seems like we would need to write a LoRA specific checkpointer for this to work with LoRA. Is the reason the same as (1), and can we have a discussion here?

Not for the MVP. We'll assume we always merge the weights and write a full checkpoint. Once we add adapter checkpointing we'll need to alter save and load and I'd rather have this as a separate class than to just add complexity to the full checkpointer

Curious about the del state_dict and gc collect - were you seeing memory issues here?

This needs a bit more testing before I can remove this. Seems like no harm to leave it there for now?

Not directly related to this PR, but I'm realizing having tune download + operating on the checkpoints could introduce a security risk if we allow users to specify arbitrary checkpoint paths - we're downloading somewhat arbitrary checkpoint files from a third-party and running python code on them. This could also be on sensitive HW such as company devices. See https://www.darkreading.com/application-security/hugging-face-ai-platform-100-malicious-code-execution-models for potential attacks. Any thoughts on this security risk and how we can mitigate? One super easy win seems to just be to have an allowlist of checkpoints we've vetted, and crash tune download on any other checkpoints by default?

It's unclear. We can restrict it, but it's an arbitrary restriction and is annoying. Currently we restrict and I just go uncomment this. But I'm not sure what the right way to handle this is other than just letting the Hub handle it. We can also support safetensors, but I don't want to take the transformers dependency.

Extensibility to sharded checkpoints. Details on this point are in the CR comments.

Responded

@kartikayk merged commit 4c1fa52 into main Mar 11, 2024
17 checks passed
@kartikayk deleted the b_interop branch March 11, 2024 20:00