Extend save_pretrained to offloaded models #27412
Conversation
added hidden subset
debugged hidden subset contrastive search
added contrastive search compression
debugged compressed contrastive search
memory reduction for contrastive search
debugged mem red
added low memory option feature
debugged low mem
added low mem cache
fixed 2047 tensor view
debugged 2042 past key val inputs
reformatted tensors
changed low mem output
Thanks for all your work @blbadger, and sorry for not being active! I tested the PR locally on a big model and it works great. The RAM usage is just as I expected. I left a few comments from testing this PR. I think we can merge soon! (We need to do a patch or a release for accelerate before that.) LMK if you want to finish the PR; otherwise, I can do it.
tests/test_modeling_utils.py
Outdated
"transformer.wte": 0, | ||
"transformer.wpe": 0, | ||
"transformer.h.0": "cpu", | ||
"transformer.h.1": "cpu", | ||
"transformer.h.2": "cpu", | ||
"transformer.h.3": "disk", | ||
"transformer.h.4": "disk", | ||
"transformer.ln_f": 0, | ||
"lm_head": 0, |
Could you make it device-agnostic just like the tests above? You need to pull the latest changes! A minimal sketch of what that could look like is below, assuming the torch_device helper from transformers.testing_utils.
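```python
# Hedged sketch: replace hard-coded GPU indices with the test suite's
# torch_device so the map resolves to cuda, cpu, or another accelerator.
from transformers.testing_utils import torch_device

device_map = {
    "transformer.wte": torch_device,
    "transformer.wpe": torch_device,
    "transformer.h.0": "cpu",
    "transformer.h.1": "cpu",
    "transformer.h.2": "cpu",
    "transformer.h.3": "disk",
    "transformer.h.4": "disk",
    "transformer.ln_f": torch_device,
    "lm_head": torch_device,
}
```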
tests/test_modeling_utils.py
Outdated
model_id = "hf-internal-testing/tiny-random-gpt2"
onloaded_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_tokens = tokenizer.encode("Four score and seven years ago", return_tensors="pt")
Let's use the same input as the other tests: inputs = torch.tensor([[1, 2, 3]]).to(0)
tests/test_modeling_utils.py
Outdated
self.assertTrue(
    postsaved_memory - presaved_memory < 7e5
)  # shard size (2e5) plus buffer (~4e5), will fail if shard is too large
Let's remove this assert: it will be too flaky for our CI. I tested the PR and it works pretty well for big models. I'm not sure this assert can capture the fact that we won't use more than shard_size of RAM when loading offloaded modules.
src/transformers/modeling_utils.py
Outdated
# Save the model
if state_dict is None:
    # if any model parameters are offloaded to the disk, make module map
    if hasattr(self, "hf_device_map") and "disk" in self.hf_device_map.values():
You forgot to also include the case where "cpu" is in self.hf_device_map.values(). You can use the following check to see if we have offloading: if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):. If you replace the "disk" value with "cpu" in the test, the test will fail. A self-contained sketch of the check follows.
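```python
# Hedged sketch of the suggested offload check; the helper name is
# hypothetical. hf_device_map is the attribute accelerate attaches to a
# dispatched model, mapping module names to "cpu", "disk", or a GPU index.
def has_offloaded_modules(model) -> bool:
    device_map = getattr(model, "hf_device_map", None)
    return isinstance(device_map, dict) and (
        "cpu" in device_map.values() or "disk" in device_map.values()
    )
```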
src/transformers/modeling_utils.py
Outdated
@@ -117,6 +117,7 @@
     save_offload_index,
     set_module_tensor_to_device,
 )
+from accelerate.utils.modeling import get_state_dict_from_offload
You need to protect the import, since we don't force users to install the latest version of accelerate. You can see how we do it here. For illustration, one minimal way to guard it is sketched below (the fallback pattern is illustrative, not the exact code used in the PR).
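```python
# Guarded import: older accelerate releases do not define
# get_state_dict_from_offload, so fall back to None instead of crashing
# at import time.
try:
    from accelerate.utils.modeling import get_state_dict_from_offload
except ImportError:
    get_state_dict_from_offload = None
```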
src/transformers/modeling_utils.py
Outdated
# remake shard with onloaded parameters if necessary
if module_map:
    # init state_dict for this shard
    state_dict = {name: "" for name in shard}
    for module_name in state_dict.keys():
        module = module_map[module_name]
        # update state dict with onloaded parameters
        state_dict = get_state_dict_from_offload(module, module_name, state_dict)

    # assign shard to be the completed state dict
    shard = state_dict
    del state_dict
    gc.collect()
Could you add a check to see if the users indeed have the latest version of accelerate? See a similar example here. A hedged sketch of such a check follows; the minimum version string is a placeholder, not the actual cutoff the PR settled on.
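```python
import importlib.metadata

from packaging import version


def require_min_accelerate(min_version: str = "0.26.0") -> None:
    # Placeholder threshold: raise early if the installed accelerate
    # predates get_state_dict_from_offload.
    installed = version.parse(importlib.metadata.version("accelerate"))
    if installed < version.parse(min_version):
        raise ImportError(
            f"Saving offloaded models requires accelerate >= {min_version}, "
            f"but found {installed}. Run `pip install -U accelerate`."
        )
```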
@SunMarc thanks very much for taking a look! No worries, I have been very busy too and would not have had much time to work on this before now anyway. I will plan to make time to go through your suggestions tomorrow and will let you know if I can't make the finishing touches myself, in which case you would be more than welcome to do so.
Thanks for these iterations @blbadger. Just a few nits concerning the accelerate version. I've merged the PR on the accelerate side, and we should release a new version this week.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks! This indeed looks great. cc @amyeroberts for a final look
Thanks for enabling this - it'll be great to have this feature!
postsaved_output = saved_model(inputs)[0]

self.assertTrue(torch.allclose(cpu_output, presaved_output, atol=1e-4))
self.assertTrue(torch.allclose(presaved_output, postsaved_output))
Very nice :)
Thank you again @blbadger for your patience and your work! I really appreciate your contribution 🔥 Congrats on merging this amazing feature!
Happy to contribute! Thanks very much @SunMarc for shepherding this through, and @amyeroberts @muellerzr @ArthurZucker for your reviews.
* added hidden subset * debugged hidden subset contrastive search * added contrastive search compression * debugged compressed contrastive search * memory reduction for contrastive search * debugged mem red * added low memory option feature * debugged mem optmimization output stack * debugged mem optmimization output stack * debugged low mem * added low mem cache * fixed 2047 tensor view * debugged 2042 past key val inputs * reformatted tensors * changed low mem output * final clean * removed subset hidden csearch * fixed hidden device * fixed hidden device * changed compressor dtype * removed hstate compression * integrated csearch in generate * test csearch integration into generation exit() * fixed csearch kwarg integration with generation * final wrap and added doc * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * added debug print * direct hstate cat * direct hstate cat * direct hstate cat debug * direct hstate cat debug * expanded full hidden state stack * expanded full hidden state stack * matched dims for hstates * matched dims for hstates * logits fix * equality test * equality hidden debug * debug * added prints for debug * added prints for debug * equality check * switched squeeze dim * input format debug * tracing top_k_ids * removed trace * added test context * added jitter * added jitter * added jitter * returned state * rebuilt past key value reconstruction * debugged * cleaned traces * added selection for pkv * changed output to dict * cleaned * cleaned * cleaned up contrastive search test * moved low_memory kwarg * debugged * changed low mem test batch size to 1 * removed output * debugged test input shape * reformatted csearch test * added trace * removed unsqueeze on final forward pass * replaced unsqueeze with view * removed traces * cleaned * debugged model kwargs * removed special models from test * ran make quality * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/generation/configuration_utils.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * refactored * refactored * refactored * make fixup * renamed flag sequential * renamed flag sequential * iterative onloading * black style and test utils * added traces for integrated test * debugged * added traces * make style * removed traces, make style * included suggestions and added test * debugged test * added offload module check and make style * is_accelerate_available and make style * added test decorator * changed test model and config spec * added offload condition * added lazy loading for each shard * debugged * modified sharding * debugged * added traces * removed safe serialization * no index overload; * trace on safe save ptrs * added ptr condition * debugged * debugged ptr * moved module map init * remake shard only for offloaded modules * refactored * debugged * refactored * debugged * cleaned and make style * cleaned and make style * added trace * sparse module map * debugged * removed module map conditional * refactored * debug * debugged * added traces * added shard mem trace * added shard mem trace * removed underlying storage check * refactored * memory leak removal and make style * cleaned * swapped test 
decs and make style * added mem checks and make style * added free mem warning * implemented some suggestions * moved onloading to accelerate * refactored for accelerate integration * cleaned test * make style * debugged offload map name * cleaned and make style * replaced meta device check for sharding * cleaned and make style * implemented some suggestions * more suggestions * update warning Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * more suggestions * make style * new make style * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Final version looks nice and simple. Thanks all for your hard work!
What does this PR do?
Fixes #20072 and addresses the second part of huggingface/peft#868
Models with offloaded weights are currently incompatible with save_pretrained. This PR allows large models that are loaded onto the GPU and CPU to be saved, which is particularly useful for big models that have undergone merging and unloading via huggingface/peft#1063. The implementation is to iterate through modules and onload parameters to the execution device (typically GPU) before sending the appropriate elements of the state dict to the CPU in-place, where the final state dictionary is assembled and saved. A hedged end-to-end sketch of what this enables is shown below (model id and folder names are illustrative).
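```python
from transformers import AutoModelForCausalLM

# device_map="auto" may place some modules on CPU or disk for large models.
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-gpt2",
    device_map="auto",
    offload_folder="offload",
)

# Previously this raised for offloaded models; with this PR, each shard's
# offloaded parameters are onloaded before the state dict is written.
model.save_pretrained("saved_model")
```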
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Still working on the tests (some small models are not compatible with offloading due to architectural considerations), but am happy to submit a Colab version with a large model in the meantime :)
Who can review?
Anyone!
@pacman100