Fix the initialization of the cache when we have multi gpu #33303
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Nice, thank you for the quick jump to the solution! In general it looks great :D

I would add three more things:
- make `device` and `layer_device_mapping` mutually exclusive OR make the `device` argument also accept a device map (whichever is the most consistent across HF libraries); see the sketch below
- make sure we throw an exception somewhere: if a user is using multi GPU, then a layer map has to be passed
- tests 🤗
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
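For illustration, here is a minimal sketch of the validation described in the first bullet above. It is only a sketch using the naming from this thread (`device`, `layer_device_map`), not the code merged in this PR.

```python
# Minimal sketch of the suggested argument validation; the names and signature are
# assumptions taken from this discussion, not the actual StaticCache implementation.
from typing import Dict, Optional, Union

import torch


def resolve_layer_devices(
    num_layers: int,
    device: Optional[Union[str, torch.device]] = None,
    layer_device_map: Optional[Dict[int, Union[str, torch.device]]] = None,
) -> Dict[int, torch.device]:
    """Return one device per layer, treating `device` and `layer_device_map` as mutually exclusive."""
    if device is not None and layer_device_map is not None:
        raise ValueError("`device` and `layer_device_map` are mutually exclusive; pass only one of them.")
    if layer_device_map is not None:
        return {idx: torch.device(layer_device_map[idx]) for idx in range(num_layers)}
    # Single-device fallback: every layer's cache is allocated on the same device.
    return {idx: torch.device(device if device is not None else "cpu") for idx in range(num_layers)}
```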
The only way to know if we are using multi-gpu is with the
That's a good point... 👀 My question at the moment is the following: if a user decides to instantiate a cache manually, is using multi-gpu, and doesn't pass the new argument, how can we let the user know that they should have used the new argument?
Refactored a bit + added an integration test to check the device on the cache! Note that we do the check after generation, as we need to initialize the cache, meaning that we need the model. We already have tests for multi-gpu static cache, so I don't think we need to add those. Let me know if there are other tests you would like to see.
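A rough sketch of that kind of post-generation device check (not the exact test added in this PR; the checkpoint, the `device_map` choice, and the `model.layers.{i}` key format are assumptions that hold for Llama-style models):

```python
# Sketch of an integration check: after generation, every layer's static cache
# should live on the same device as that layer's weights.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
# "balanced" forces the model to be split across the visible GPUs.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="balanced", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Generation has to run first so that the model initializes the static cache.
out = model.generate(
    **inputs, max_new_tokens=5, cache_implementation="static", return_dict_in_generate=True
)
cache = out.past_key_values

for layer_idx, key_tensor in enumerate(cache.key_cache):
    expected = torch.device(model.hf_device_map[f"model.layers.{layer_idx}"])
    assert key_tensor.device == expected, f"layer {layer_idx}: {key_tensor.device} != {expected}"
```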
When passing the cache to the model in
Works for me to address #33178.
There is not atm, and I think we would hit a few issues if we do so 🤔 We would have to either:
What about a simple device check at update time? If the cache tensor device != the new cache data device, then throw an informative exception.
Sounds good @gante! I've updated the PR with the check! Let me know if this is better! I'll add a quick test to check if the warning is raised correctly!
LGTM, thank you for iterating 💛
Thanks for the fix. 🚀 🚀 🚀
@SunMarc can you include
Test passed!
@gante @LysandreJik can we prioritize getting this fix merged? I will need this one to unblock
Thanks! For cache-related PRs I recommend pinging @ArthurZucker for a review (pinging him now).
LGTM, but I don't think we should check in `update`. Maybe in `generate`?
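For context, a sketch of what a one-time check in `generate` could look like, as opposed to the per-call check in `update` shown below. The helper name and arguments are hypothetical, not an actual transformers API; it only assumes the cache exposes a per-layer `key_cache` list, as `StaticCache` does.

```python
# Hypothetical one-time validation, run once before the decoding loop instead of on
# every forward pass.
import torch


def _validate_cache_device_map(cache, layer_execution_devices):
    """Raise once, up front, if any cached layer lives on a different device than the layer writing to it."""
    for layer_idx, expected_device in enumerate(layer_execution_devices):
        stored = cache.key_cache[layer_idx].device
        if stored != torch.device(expected_device):
            raise ValueError(
                f"Static cache for layer {layer_idx} is on {stored}, but the layer runs on "
                f"{expected_device}. Pass `layer_device_map` when creating the cache manually, or let "
                "`model.generate(..., cache_implementation='static')` build it for you."
            )
```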
src/transformers/cache_utils.py (outdated):

```python
for state_str, state_device, self_state_device in [
    ("key_states", key_states.device, self.key_cache[layer_idx].device),
    ("value_states", value_states.device, self.value_cache[layer_idx].device),
]:
    if state_device != self_state_device:
        raise ValueError(
            f"Computed {state_str} from layer {layer_idx} is on device {state_device} "
            f"whereas stored {state_str} is on device {self_state_device}. "
            "If you are manually initializing the cache, make sure to pass the argument `layer_device_map` if you are using multi-gpu. "
            "Otherwise, you can just pass `cache_implementation` in `model.generate()` to correctly initialize the cache."
        )
```
I am really unsure this is worth it for us to run at every forward pass. I know we want to help our users, but we would need to make sure it does not cost us anything.
There is a long discussion above, but tl;dr the options are:
- don't warn at all
- check devices in `update` (this implementation)

With `torch.compile`, these lines should get ignored anyway when called correctly (at tracing they have the same device). We should benchmark compile to confirm, though. Assuming they have no throughput cost, I think it's a win to have the error.
We can also wrap `update` with a try/except, rather than using an if/else.
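A sketch of that try/except variant, written as a standalone function for illustration; it loosely follows the static cache update logic but is not the actual `StaticCache.update` signature. The cross-device copy itself raises, and we only pay for the re-raise on the error path.

```python
# Instead of comparing devices up front on every call, let the cross-device copy
# fail and re-raise with an actionable message; the happy path stays untouched.
import torch


def static_cache_update(cache, key_states, value_states, layer_idx, cache_position):
    """Write new key/value states into a pre-allocated static cache."""
    k_out = cache.key_cache[layer_idx]
    v_out = cache.value_cache[layer_idx]
    try:
        # The cache has shape (batch, num_heads, max_seq_len, head_dim); new states are
        # written at the positions given by `cache_position` along the sequence dim.
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
    except RuntimeError as e:
        raise ValueError(
            f"Failed to update the cache for layer {layer_idx}: stored tensors are on "
            f"{k_out.device} but the new states are on {key_states.device}. If you created the "
            "cache manually on a multi-gpu model, pass `layer_device_map`."
        ) from e
    return k_out, v_out
```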
Let's remove it for now and merge
cc @SunMarc
I ran a quick benchmark to see what the impact on generate with
But yeah, let's remove that for the release and I can do a follow-up PR with either the same patch or a try/except as Joao suggested. Feel free to merge the PR for the release @ArthurZucker if you are fine with the modification!
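For reference, a rough sketch of the kind of quick benchmark mentioned above: time `generate` with a static cache and compare tokens/second before and after the change (the model id and token counts are placeholders, not what was actually benchmarked here).

```python
# Time static-cache generation end to end; run once per code version and compare.
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)

# Warmup so cache allocation (and any compilation) is not counted in the measurement.
model.generate(**inputs, max_new_tokens=16, cache_implementation="static")

torch.cuda.synchronize()
start = time.perf_counter()
out = model.generate(**inputs, max_new_tokens=128, cache_implementation="static")
torch.cuda.synchronize()
elapsed = time.perf_counter() - start

new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")
```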
Thanks, tests look great
Summary: bypass-github-export-checks

[Done] ~~Require PR [Make StaticCache configurable at model construct time](huggingface/transformers#32830) in order to export, lower and run the 🤗 model OOTB.~~
[Done] ~~Require huggingface/transformers#33303 or huggingface/transformers#33287 to be merged to 🤗 `transformers` to resolve the export issue introduced by huggingface/transformers#32543~~

---

Now we can take the integration point from 🤗 `transformers` to lower compatible models to ExecuTorch OOTB.
- This PR creates a simple script with a recipe for XNNPACK.
- This PR also creates a secret `EXECUTORCH_HT_TOKEN` to allow downloading checkpoints in the CI.
- This PR connects the 🤗 "Export to ExecuTorch" e2e workflow to ExecuTorch CI.

### Instructions to run the demo:
1. Run export_hf_model.py to lower gemma-2b to ExecuTorch:
```
python -m extension.export_util.export_hf_model -hfm "google/gemma-2b" # The model is exported with static dims and static KV cache
```
2. Run tokenizer.py to generate the binary format for the ExecuTorch runtime:
```
python -m extension.llm.tokenizer.tokenizer -t <path_to_downloaded_gemma_checkpoint_dir>/tokenizer.model -o tokenizer.bin
```
3. Build the llm runner by following [step 4](https://github.com/pytorch/executorch/tree/main/examples/models/llama2#step-4-run-on-your-computer-to-validate) of this guide.
4. Run the lowered model:
```
cmake-out/examples/models/llama2/llama_main --model_path=gemma.pte --tokenizer_path=tokenizer.bin --prompt="My name is"
```

OOTB output and perf:
```
I 00:00:00.003110 executorch:cpuinfo_utils.cpp:62] Reading file /sys/devices/soc0/image_version
I 00:00:00.003360 executorch:cpuinfo_utils.cpp:78] Failed to open midr file /sys/devices/soc0/image_version
I 00:00:00.003380 executorch:cpuinfo_utils.cpp:158] Number of efficient cores 4
I 00:00:00.003384 executorch:main.cpp:65] Resetting threadpool with num threads = 6
I 00:00:00.014716 executorch:runner.cpp:51] Creating LLaMa runner: model_path=gemma.pte, tokenizer_path=tokenizer_gemma.bin
I 00:00:03.065359 executorch:runner.cpp:66] Reading metadata from model
I 00:00:03.065391 executorch:metadata_util.h:43] get_n_bos: 1
I 00:00:03.065396 executorch:metadata_util.h:43] get_n_eos: 1
I 00:00:03.065399 executorch:metadata_util.h:43] get_max_seq_len: 123
I 00:00:03.065402 executorch:metadata_util.h:43] use_kv_cache: 1
I 00:00:03.065404 executorch:metadata_util.h:41] The model does not contain use_sdpa_with_kv_cache method, using default value 0
I 00:00:03.065405 executorch:metadata_util.h:43] use_sdpa_with_kv_cache: 0
I 00:00:03.065407 executorch:metadata_util.h:41] The model does not contain append_eos_to_prompt method, using default value 0
I 00:00:03.065409 executorch:metadata_util.h:43] append_eos_to_prompt: 0
I 00:00:03.065411 executorch:metadata_util.h:41] The model does not contain enable_dynamic_shape method, using default value 0
I 00:00:03.065412 executorch:metadata_util.h:43] enable_dynamic_shape: 0
I 00:00:03.130388 executorch:metadata_util.h:43] get_vocab_size: 256000
I 00:00:03.130405 executorch:metadata_util.h:43] get_bos_id: 2
I 00:00:03.130408 executorch:metadata_util.h:43] get_eos_id: 1
My name is Melle. I am a 20 year old girl from Belgium. I am living in the southern part of Belgium. I am 165 cm tall and I weigh 45kg. I like to play sports like swimming, running and playing tennis. I am very interested in music and I like to listen to classical music. I like to sing and I can play the piano. I would like to go to the USA because I like to travel a lot.
I am looking for a boy from the USA who is between 18 and 25 years old. I
PyTorchObserver {"prompt_tokens":4,"generated_tokens":118,"model_load_start_ms":1723685715497,"model_load_end_ms":1723685718612,"inference_start_ms":1723685718612,"inference_end_ms":1723685732965,"prompt_eval_end_ms":1723685719087,"first_token_ms":1723685719087,"aggregate_sampling_time_ms":182,"SCALING_FACTOR_UNITS_PER_SECOND":1000}
I 00:00:17.482472 executorch:stats.h:70] Prompt Tokens: 4 Generated Tokens: 118
I 00:00:17.482475 executorch:stats.h:76] Model Load Time: 3.115000 (seconds)
I 00:00:17.482481 executorch:stats.h:86] Total inference time: 14.353000 (seconds) Rate: 8.221278 (tokens/second)
I 00:00:17.482483 executorch:stats.h:94] Prompt evaluation: 0.475000 (seconds) Rate: 8.421053 (tokens/second)
I 00:00:17.482485 executorch:stats.h:105] Generated 118 tokens: 13.878000 (seconds) Rate: 8.502666 (tokens/second)
I 00:00:17.482486 executorch:stats.h:113] Time to first generated token: 0.475000 (seconds)
I 00:00:17.482488 executorch:stats.h:120] Sampling time over 122 tokens: 0.182000 (seconds)
```

Pull Request resolved: #4723
Reviewed By: huydhn, kirklandsign
Differential Revision: D62543933
Pulled By: guangy10
fbshipit-source-id: 00401a39ba03d7383e4b284d25c8fc62a6695b34
…ce#33303)

* init cache multi-gpu
* Update src/transformers/generation/utils.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* switch to execution device map
* naming more consistant
* fix
* mutually exclusive device
* added an integration example
* remove useless check
* suggestion from joao + typing
* fix couple of typo and add test
* revert check

---------

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
What does this PR do?
Fixes #33287 (comment)
This PR initializes the cache on the right device when we are in a multi-gpu setup. Before this PR, we would move the tensors to the right device during `update()`, which created a few issues with `export()`. This is also a cleaner solution in general.

cc @gante I would love to have quick feedback from you
cc @ArthurZucker as you were also interested in
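To illustrate what this changes for users: with a model sharded across several GPUs, the static cache is now allocated per layer on that layer's execution device, so generation works out of the box. This is a minimal sketch with a placeholder model id; users building a `StaticCache` manually can pass `layer_device_map` instead, as discussed above.

```python
# Multi-gpu generation with a static cache; with this fix, every layer's cache tensors
# are allocated on the same GPU as that layer's weights, so no cross-device moves happen
# inside `update()`.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# `generate` creates and places the StaticCache itself.
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tokenizer.decode(out[0], skip_special_tokens=True))
```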
Tested with:
```
RUN_SLOW=True CUDA_VISIBLE_DEVICES=0,1 pytest tests/generation/test_utils.py -k "test_generate_with_static_cache_multi_gpu"
RUN_SLOW=True CUDA_VISIBLE_DEVICES=0,1 pytest tests/generation/test_utils.py -k "test_init_static_cache_multi_gpu"
RUN_SLOW=1 pytest tests/utils/test_cache_utils.py -k test_static_cache_exportability
```
and: