
Update checkpointing directory -> using vLLM and from_pretrained #2074


Merged

felipemello1 merged 39 commits from checkpointer into pytorch:main on Dec 6, 2024

Conversation

@felipemello1 felipemello1 (Contributor) commented Nov 26, 2024

Co-authored-by: vancoyendall <vancoykendall@gmail.com>

Context
What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

TL;DR: using vLLM and Hugging Face .from_pretrained

  1. Train your model using torchtune main (nightlies). It will produce a folder like this:
    [screenshot: resulting output_dir layout]
  2. Copy the contents of your latest epoch to the base_model folder, which holds the original checkpoint_dir's contents without the model weight files (.pt, .bin, .safetensors):
cp /tmp/llama_3_2_1b/lora_single_device/epoch_2/* /tmp/llama_3_2_1b/lora_single_device/base_model

Making it look like this:
[screenshot: output_dir after copying into base_model]

  3. Now you can use it with vLLM and Hugging Face. There is one catch: when using LoRA, we output MERGED weights AND the adapter. You should NOT use both at the same time. Either use ONLY the merged weights OR the BASE UNTRAINED MODEL + adapter.

3.1. Using Hugging Face .from_pretrained with the BASE UNTRAINED MODEL + adapter

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Define the model and adapter paths
original_model_name = "meta-llama/Llama-3.2-1B-Instruct"
trained_model_path = "/tmp/torchtune/llama3_2_1B/lora_single_device/base_model"

model = AutoModelForCausalLM.from_pretrained(original_model_name)

# huggingface will look for adapter_model.safetensors and adapter_config.json
peft_model = PeftModel.from_pretrained(model, trained_model_path)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(original_model_name)

# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

prompt = "Complete the sentence: 'Once upon a time...'"
print("Base model output:", generate_text(peft_model, tokenizer, prompt))

3.2. Using Hugging Face with the FULLY TRAINED model

from transformers import AutoModelForCausalLM, AutoTokenizer

# Define the trained model path
trained_model_path = "/tmp/torchtune/llama3_2_1B/full_single_device/base_model"

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=trained_model_path,
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(trained_model_path, safetensors=True)


# Function to generate text
def generate_text(model, tokenizer, prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


prompt = "Complete the sentence: 'Once upon a time...'"
print("Base model output:", generate_text(model, tokenizer, prompt))

3.3. Using the MERGED TRAINED MODEL with vLLM

IMPORTANT: this will not work right away. Your output directory has 2 files:

  • ft-model-00001-of-00001.safetensors
  • adapter_model.safetensors

vLLM doesn't know which file is the model and which is the adapter. When it tries to load the adapter as model weights, it raises an error. Therefore, delete adapter_model.safetensors from the folder and it will work:

rm /tmp/torchtune/llama3_2_1B/lora_single_device/base_model/adapter_model.safetensors

Now you can run vLLM locally

from vllm import LLM, SamplingParams

def print_outputs(outputs):
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    print("-" * 80)

llm = LLM(
    model="/tmp/torchtune/llama3_2_1B/lora_single_device",
    load_format="safetensors",
    kv_cache_dtype="auto",
)
sampling_params = SamplingParams(max_tokens=16, temperature=0.5)

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
    {
        "role": "user",
        "content": "Write an essay about the importance of higher education.",
    },
]
outputs = llm.chat(conversation, sampling_params=sampling_params, use_tqdm=False)
print_outputs(outputs)
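
For plain (non-chat) prompts, the same LLM object also exposes generate(); a minimal sketch reusing llm, sampling_params, and print_outputs from above:

prompts = ["Once upon a time"]
outputs = llm.generate(prompts, sampling_params)
print_outputs(outputs)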

Context

In torchtune's current state, if checkpoint_dir != output_dir, it breaks. Since the files are all mixed together, saved as .pt and without the proper configs, it's hard for users to readily use them with vLLM/Hugging Face, resulting in issues such as #2048, #2025 and #2118.

This PR is NOT a major refactor. Everything is backwards compatible. The intention here is just to organize the output_dir and allow users to quickly use their models with HF and vLLM.

Changelog

  1. The output folder is automatically created/populated as described above

  2. Initially, base_model has all the files from the checkpoint_dir, except those that end in .pt, .safetensors, .bin, etc. (a rough sketch follows this list)

  3. We save the original model info in the adapter config

  4. Naming was standardized

  5. safetensors is now the default format for HF checkpoints

  6. Solved bugs when input_dir != output_dir
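
As a rough illustration of item 2 (not torchtune's actual implementation; the helper name is hypothetical), populating base_model amounts to copying everything from checkpoint_dir except the weight files:

import shutil
from pathlib import Path

WEIGHT_SUFFIXES = {".pt", ".bin", ".safetensors"}  # the suffixes listed in item 2

def populate_base_model(checkpoint_dir: str, output_dir: str) -> None:
    """Copy every non-weight file from checkpoint_dir into output_dir/base_model."""
    dest = Path(output_dir) / "base_model"
    dest.mkdir(parents=True, exist_ok=True)
    for f in Path(checkpoint_dir).iterdir():
        if f.is_file() and f.suffix not in WEIGHT_SUFFIXES:
            shutil.copy2(f, dest / f.name)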

Next steps

Update docs

Unresolved issues

Left TODOs in the code, to be addressed in follow-up PRs. It makes the code ugly, but we are due for a refactor.

Test plan

  • unit tests
  • resumed training from a pretrained checkpoint
    [screenshot: resume run]
  • ran vLLM and Hugging Face

Felipe Mello and others added 5 commits November 22, 2024 15:27

pytorch-bot bot commented Nov 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2074

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 34c12ff with merge base 2b1ee6d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Nov 26, 2024
Comment on lines 211 to 218
# save the repo_id. This is necessary because the download step is a separate command
# from the rest of the CLI. When saving a model adapter, we have to add the repo_id
# to the adapter config.
file_path = os.path.join(output_dir, training.REPO_ID_FNAME).with_suffix(
    ".json"
)
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
Contributor Author

related: #2026

Contributor Author

kaggle download
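
For context, the hunk above persists the repo_id at download time so it can later be injected into the adapter config. A minimal sketch of reading it back (helper name and filename are assumptions, not the PR's code):

import json
import os

def read_repo_id(output_dir: str, fname: str = "repo_id.json") -> str:
    """Load the repo_id saved at download time (filename is an assumption)."""
    with open(os.path.join(output_dir, fname)) as f:
        return json.load(f)["repo_id"]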

Comment on lines -216 to -229
def save_config(path: Path, config: Dict[str, Any]) -> None:
    """
    Save a configuration dictionary to a file.

    Args:
        path (Path): Path to save the configuration file.
        config (Dict[str, Any]): Configuration dictionary to save.
    """
    if not path.is_dir():
        path.mkdir(exist_ok=True)
    file_path = Path.joinpath(path, "config.json")
    if not file_path.exists():
        with open(file_path, "w") as f:
            json.dump(config, f)
Contributor Author

Replaced it with "copy_files", so we save every file, not only the config.

# TODO: this needs to be updated when we start using HF cache
file_path = os.path.join(true_output_dir, training.REPO_ID_FNAME + ".json")
with open(file_path, "w") as json_file:
    json.dump({"repo_id": args.repo_id}, json_file, indent=4)
Contributor Author

hf download


Comment on lines -350 to -360
if not isinstance(checkpoint_files, List):
    formatted_checkpoint_files = FormattedCheckpointFiles.from_dict(
        checkpoint_files
    )
    checkpoint_files = formatted_checkpoint_files.build_checkpoint_filenames()
self._checkpoint_paths = self._validate_hf_checkpoint_files(checkpoint_files)
self._adapter_checkpoint = (
    get_path(self._checkpoint_dir, adapter_checkpoint)
    if adapter_checkpoint
    else None
)
Contributor Author

moved down

Comment on lines 466 to 472
logger.warning(
    f"When resuming from ckpt, we could not find all model files in {self._output_dir=}. "
    "This is expected if you set `save_adapter_weights_only=True`. In this case, we will load from checkpoint_dir. "
    "However, if you set `save_adapter_weights_only=False`, this is unexpected. "
    "Perhaps you forgot to add `epoch_{epoch}/` to your filename? "
    "Using checkpoint_dir instead..."
)
Contributor Author

I don't really like this. The other options are (1) to actually fix the issue, which requires knowing save_adapter_weights_only, or (2) silently let it happen, which is dangerous if the adapter + model was trained (e.g. embeddings) and the user forgot to change the file names.

Not sure if there is a third option.

Contributor

I'm not sure I fully follow this, but for (1) couldn't it be done by just e.g. saving save_adapter_weights_only as part of the recipe state and pulling it from there?

@felipemello1 felipemello1 (Contributor, Author) Dec 4, 2024

are you ok with saving it as part of the recipe_state? that would work.

I'm not sure I fully follow this,

  1. User trains a model that has both LoRA + finetuning (e.g. a vision model)
  2. User resumes from ckpt and forgets to update the ckpt files
  3. Since we are looking at the ckpt_dir, we will get the untrained model, which is a silent bug. Looking at ckpt_dir is only safe IF save_adapter_weights_only=True. Otherwise, we always have to look at output_dir, which will raise a "file not found" error if the user forgot to update the filenames
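
A minimal sketch of the recipe-state idea discussed above (key name and helper are assumptions, not torchtune's actual recipe code):

def resolve_weights_dir(recipe_state: dict, checkpoint_dir: str, output_dir: str) -> str:
    """On resume, pick where to load the model weights from.

    If only the adapter was saved, the full weights still live in checkpoint_dir;
    otherwise the trained weights are expected in output_dir.
    """
    if recipe_state.get("save_adapter_weights_only", False):
        return checkpoint_dir
    return output_dir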

Contributor Author

Removed this. When lora=True, it will always get it from ckpt_dir (which is our current state anyway). Needs a follow-up PR to address it.

Contributor

Seems like there are a few follow-ups on this PR.. can we file an issue so we can track the different todos in a single place?

output_path = Path.joinpath(
    self._output_dir, f"hf_model_{cpt_idx}_{epoch}"
).with_suffix(".pt")
output_path = output_path.with_suffix(".bin")
Contributor Author

No more .pt. Let's do .bin.
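
For reference, Path.with_suffix simply replaces the existing suffix, so in the hunk above the later call wins; a tiny standalone check:

from pathlib import Path

p = Path("hf_model_0001_0").with_suffix(".pt")
print(p.with_suffix(".bin"))  # hf_model_0001_0.bin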

@felipemello1 felipemello1 marked this pull request as ready for review December 2, 2024 22:57
@@ -231,6 +230,7 @@ def _permute(t, n_heads):

def tune_to_peft_adapter_config(
adapter_config: Dict[str, Any],
base_model_name_or_path: Optional[str] = None,
Member

@ebsmothers How was the PEFT model figuring this out before?

@ebsmothers ebsmothers (Contributor) Dec 4, 2024

Oh it wasn't. That's why we loaded PEFT models in two steps:

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_model = PeftModel.from_pretrained(model, checkpoint_dir)

instead of

AutoModelForCausalLM.from_pretrained(checkpoint_dir)

I had a hacky version of this in #2026 but it was pointed out by @pbontrager that this shouldn't be present for models that are gonna get pushed to the hub (in that case we would want the hub model ID here, not a local path). (Edited) Looks like that is addressed here though
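
With base_model_name_or_path recorded in adapter_config.json (as this PR now does), the two-step load can also collapse into a single call via PEFT's AutoPeftModelForCausalLM; a minimal sketch using the adapter directory from the PR description:

from peft import AutoPeftModelForCausalLM

# PEFT reads base_model_name_or_path from adapter_config.json, loads that base model,
# then attaches the adapter on top of it.
peft_model = AutoPeftModelForCausalLM.from_pretrained(
    "/tmp/torchtune/llama3_2_1B/lora_single_device/base_model"
)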



# TODO: instead of copying, make it a symlink when we start using HF cache
def copy_files(
Member

Can you make ignore_suffixes kwarg-only?

Contributor Author

I don't follow. Can you give an example?

Member

def copy_files(
    input,
    output,
    *,
    ignore_suffixes,
): ...

Contributor

Call me crazy but do we need a full Python function reinventing recursive copy here? E.g.

os.system(
    "rsync -av --ignore-existing "
    + " ".join(f"--exclude '*{suffix}'" for suffix in ignore_suffixes)
    + f" {input_dir} {output_dir}"
)

@SalmanMohammadi SalmanMohammadi (Contributor) Dec 4, 2024

I just use shutil.copytree (edit: though I may be missing some of the nuances of this function)

Member

Yeah wait, actually is there a reason why we can't use copytree? I'd rather that be maintained by core Python.

@felipemello1 felipemello1 (Contributor, Author) Dec 6, 2024

Beyond the suffixes, I also have to ignore .cache and .git.

idk, the function gives us some flexibility and it's readable. But it seems to be 3x1.

Can we leave it for the ckpt refactoring?

Contributor

Hmm I do like copy_tree and feel like it should be workable. But won't block on it, just include it in the follow-up task (as I mentioned above)
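
For reference, a minimal sketch of the copytree-based alternative discussed in this thread (ignore rules are assumptions based on the suffixes, .cache, and .git mentioned above):

import shutil
from pathlib import Path

def copy_files(input_dir: str, output_dir: str, *, ignore_suffixes=(".pt", ".bin", ".safetensors")):
    """Recursively copy input_dir into output_dir, skipping weight files, .cache, and .git."""

    def _ignore(directory, names):
        skipped = {name for name in names if name in (".cache", ".git")}
        skipped |= {name for name in names if Path(name).suffix in ignore_suffixes}
        return skipped

    shutil.copytree(input_dir, output_dir, ignore=_ignore, dirs_exist_ok=True)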

@@ -196,6 +208,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
        with pytest.raises(SystemExit, match=""):
            runpy.run_path(TUNE_PATH, run_name="__main__")

        raise NotImplementedError("")
Contributor

remove

@ebsmothers ebsmothers (Contributor) left a comment

Thanks for getting all of this sorted out! I know this was not an easy one. Once CI is green I think we're good to merge here. One request: can you update the summary to lead with the example usage within vLLM and PEFT? In case anyone is coming to this PR later I think that's what they will want to see. We should also add a section in the readme giving this process explicitly for better visibility.

@felipemello1 felipemello1 merged commit 424ffc3 into pytorch:main Dec 6, 2024
17 checks passed
@felipemello1 felipemello1 deleted the checkpointer branch December 6, 2024 22:02
@felipemello1 felipemello1 changed the title from "Update checkpointing directory" to "Update checkpointing directory -> using vLLM and from_pretrained" on Dec 6, 2024
This was referenced Dec 6, 2024
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 8, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <ebs@meta.com>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <felipemello@fb.com>

---------

Co-authored-by: Philip Bontrager <pbontrager@gmail.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 9, 2024
rahul-sarvam added a commit to sarvamai/torchtune that referenced this pull request Dec 18, 2024
* Llama 3.3 70B (pytorch#2124)

* Llama 3.3 readme updates (pytorch#2125)

* update configs (pytorch#2107)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* Reduce logging output for distributed KD (pytorch#2120)

* Support Early Exit Loss and/or Layer Dropout (pytorch#1076)

Co-authored-by: ebsmothers <ebs@meta.com>

* Update checkpointing directory (pytorch#2074)

Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>

* pass correct arg (pytorch#2127)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* update configs (pytorch#2128)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* fix qat_lora_test (pytorch#2131)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* guard ckpt imports (pytorch#2133)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [bug fix] add parents=True (pytorch#2136)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [bug fix] re-add model (pytorch#2135)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* Update save sizes into GiB (pytorch#2143)

* [bug fix] remove config download when source is kaggle (pytorch#2144)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* [fix] remove "with_suffix" (pytorch#2146)

Co-authored-by: Felipe Mello <felipemello@fb.com>

* DoRA fixes (pytorch#2139)



Co-authored-by: Mircea Mironenco <5738815+mirceamironenco@users.noreply.github.com>

* [Fix] Llama 3.2 Vision decoder_trainable flag fixed (pytorch#2150)

* Small readme, config updates (pytorch#2157)

* Using `FormattedCheckpointFiles` in configs (pytorch#2147)

* Move ``get_world_size_and_rank`` to utils (pytorch#2155)

* Faster intermediate checkpoints with DCP async save in TorchTune (pytorch#2006)

Co-authored-by: Saurabh Mishra <msaurabh@fb.com>

* torchdata integration - multi-dataset and streaming support (pytorch#1929)

* Allow higher version of lm-eval (pytorch#2165)

* Using `FormattedCheckpointFiles` in configs... round 2 (pytorch#2167)

* [EZ] Fix set_torch_num_threads in multi-node. (pytorch#2164)

---------

Co-authored-by: Philip Bontrager <pbontrager@gmail.com>
Co-authored-by: ebsmothers <ebs@meta.com>
Co-authored-by: Felipe Mello <fmellomascarenhas@gmail.com>
Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: Joe Cummings <jrcummings27@gmail.com>
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>
Co-authored-by: Mircea Mironenco <5738815+mirceamironenco@users.noreply.github.com>
Co-authored-by: salman <salman.mohammadi@outlook.com>
Co-authored-by: Saurabh Mishra <msaurabh@meta.com>
Co-authored-by: Saurabh Mishra <msaurabh@fb.com>
Co-authored-by: Andrew Ho <andrew.kenneth.ho@gmail.com>
Co-authored-by: Eugen Hotaj <eugen_hotaj_91@hotmail.com>
rahul-sarvam pushed a commit to sarvamai/torchtune that referenced this pull request Dec 23, 2024
Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>
rahul-sarvam pushed a commit to sarvamai/torchtune that referenced this pull request Dec 23, 2024
Co-authored-by: Felipe Mello <felipemello@fb.com>
Co-authored-by: vancoyendall <vancoykendall@gmail.com>
@RdoubleA RdoubleA mentioned this pull request Jan 21, 2025
@nouranali (Contributor)

How would this method affect GPU consumption, having to load two models while using only one?
