
[modeling utils] revamp from_pretrained(..., low_cpu_mem_usage=True) + tests #16657

Merged: 17 commits, Apr 15, 2022

Conversation

stas00 (Contributor) commented Apr 7, 2022

The initial from_pretrained(..., low_cpu_mem_usage=True) implementation was a quick hack to enable loading gptj models on low CPU memory setups. It didn't work with all models.

This PR takes it one step further. It revamps the implementation to support many features it wasn't supporting, by delegating all the work to the normal from_pretrained code path except for the final step of the state_dict => model param overwrite.

This PR:

  1. revamps low_cpu_mem_usage=True
  2. adds a functional test that checks that from_pretrained(mname, low_cpu_mem_usage=True) works with sharded and non-sharded checkpoints (see the sketch after this list)
  3. adds a quality test that measures CPU memory and checks that low_cpu_mem_usage=True indeed uses less memory
  4. adds various testing utils helper functions to support the new tests
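
As a rough illustration of the functional check in item 2 (an editorial sketch, not the exact test added in this PR; the tiny test model name is an assumption): the low-memory path should end up with exactly the same weights as the default path.

    import torch
    from transformers import AutoModel

    def check_low_cpu_mem_equivalence(mname="hf-internal-testing/tiny-random-bert"):
        # load the same checkpoint through both code paths
        model_default = AutoModel.from_pretrained(mname, low_cpu_mem_usage=False)
        model_low_mem = AutoModel.from_pretrained(mname, low_cpu_mem_usage=True)
        # the resulting parameters must match exactly
        for (name, p_default), (_, p_low) in zip(
            model_default.named_parameters(), model_low_mem.named_parameters()
        ):
            assert torch.equal(p_default, p_low), f"mismatch in {name}"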

The low cpu memory usage code path is still not 100% complete feature-wise, but it's getting there. Though I'm contemplating a different approach to solving the issue of low cpu memory. That is by introducing several new from_pretrained args that should allow loading the model and/or state_dict directly on GPU for single GPU or DDP. But that's for another PR.

@sgugger, @LysandreJik, @patrickvonplaten

import threading # noqa


class CPUMemoryTracker:
stas00 (Contributor, Author):

should I put CPUMemoryTracker in testing_utils.py before I polish this up?
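
(For context, a hypothetical sketch of what such a thread-based tracker could look like, assuming psutil for the RSS readings; the actual helper in this PR may differ:)

    import threading
    import time

    import psutil  # assumed dependency for reading the process RSS

    class CPUMemoryTracker:
        """Poll the current process's resident set size from a background
        thread and record the peak observed while the measured code runs."""

        def __init__(self, interval=0.001):
            self.process = psutil.Process()
            self.interval = interval
            self.peak_rss = 0
            self._stop = threading.Event()

        def _monitor(self):
            while not self._stop.is_set():
                self.peak_rss = max(self.peak_rss, self.process.memory_info().rss)
                time.sleep(self.interval)

        def __enter__(self):
            self._thread = threading.Thread(target=self._monitor, daemon=True)
            self._thread.start()
            return self

        def __exit__(self, *exc):
            self._stop.set()
            self._thread.join()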

stas00 (Contributor, Author) replied on Apr 8, 2022:
I think my question is moot.

Unfortunately, I think I have to discard the measuring test (leaving the functional one in place), since measuring CPU memory is super fickle - the test works well on my desktop but fails on CI.

I tried another version with tracemalloc, but it doesn't work well either.

If I re-run the same tests in a loop I get different numbers due to memory being cached; gc.collect() doesn't seem to help.

I may try some more tomorrow.

Collaborator:

Understood. It looked like a great tool though! Too bad CPU usage can't be measured without being so finicky.

HuggingFaceDocBuilderDev commented Apr 7, 2022

The documentation is not available anymore as the PR was closed or merged.

stas00 (Contributor, Author) commented Apr 8, 2022

Hmm, trying to write a test that shows the memory saving proved to be a puzzle: I'm getting inconsistent results between my desktop and the CI. That was using process.memory_info().rss; I also tried tracemalloc, but I think that one is problematic if pytorch uses kernels that don't go through Python's memory allocator.

I think I may try /usr/bin/time -f %M via an external process, since it gives me the max RSS peak for the whole independent process.
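
As a rough illustration of that approach (assumed helper, not the PR's test code): /usr/bin/time writes its report to stderr, and -f %M prints the child's peak resident set size in kilobytes, so a test can shell out and parse it.

    import subprocess

    def peak_rss_kb(python_one_liner: str) -> int:
        # /usr/bin/time writes its report to stderr; with -f %M the last line
        # is the child's maximum resident set size, in kbytes
        cmd = ["/usr/bin/time", "-f", "%M", "python", "-c", python_one_liner]
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        return int(result.stderr.strip().splitlines()[-1])

    low = peak_rss_kb(
        'from transformers import AutoModel; '
        'AutoModel.from_pretrained("bert-base-uncased", low_cpu_mem_usage=True)'
    )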

The results are very peculiar:

  • for a 1GB model there is no saving
  • for a 10GB model, low_mem saves 1/4 of memory,
  • for a 40GB model, low_mem saves 1/2 of memory,

but at least that explains why my memory tracking wasn't showing the saving consistently since I was using a 0.5GB model for the test.

So what I'm doing is:


# 420 MB https://huggingface.co/bert-base-uncased

/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bert-base-uncased", low_cpu_mem_usage=True)'
1139516
/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bert-base-uncased", low_cpu_mem_usage=False)'
1140324

# 1.25GB https://huggingface.co/bert-large-uncased/

/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bert-large-uncased", low_cpu_mem_usage=True)'
2906584
/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bert-large-uncased", low_cpu_mem_usage=False)'
2908236

# 10.6 GB https://huggingface.co/bigscience/T0_3B

/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bigscience/T0_3B", low_cpu_mem_usage=True)'
16122900
/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bigscience/T0_3B", low_cpu_mem_usage=False)'
22299560

# 41.5 GB https://huggingface.co/bigscience/T0

/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bigscience/T0", low_cpu_mem_usage=True)'
43788452
/usr/bin/time -f %M python -c 'from transformers import AutoModel; AutoModel.from_pretrained("bigscience/T0", low_cpu_mem_usage=False)'
86765944

Update: the culprit proved to be that my original low_cpu_mem code could not handle models whose checkpoint keys carry a custom prefix, like the bert. prefix; this PR fixes it.
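
To illustrate the prefix issue (an editorial example, not code from the PR): the bert-base-uncased checkpoint was saved from a model with a task head, so its keys carry the base_model_prefix, while a bare BertModel's keys do not.

    from transformers import BertForMaskedLM, BertModel

    print(BertModel.base_model_prefix)  # "bert"

    with_head = BertForMaskedLM.from_pretrained("bert-base-uncased")
    print(list(with_head.state_dict())[0])  # e.g. "bert.embeddings.word_embeddings.weight"

    bare = BertModel.from_pretrained("bert-base-uncased")
    print(list(bare.state_dict())[0])       # e.g. "embeddings.word_embeddings.weight"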

sgugger (Collaborator) left a review comment:

Thank you so much for adding those tests!

return submodule, split_key[0]


def move_model_to_meta(model, loaded_state_dict_keys, start_prefix):
stas00 (Contributor, Author):

The original _load_pretrained_model_low_mem hack got split into two functions: one that moves the model's params to the meta device, and another that replaces those meta params with the values from the loaded state_dict keys.

That way I was able to integrate this functionality into the normal, complex code path that checks the keys and everything else.
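
Roughly, the "move to meta" half could look like the following simplified sketch (helper name and module traversal approximated; the real implementation differs in its details):

    import torch

    def move_loaded_params_to_meta(model, loaded_state_dict_keys, start_prefix=""):
        # switch only the params/buffers that will be overwritten from the
        # checkpoint onto the meta device, so their CPU memory is released
        for key in loaded_state_dict_keys:
            name = key[len(start_prefix):] if key.startswith(start_prefix) else key
            module_path, _, param_name = name.rpartition(".")
            try:
                submodule = model.get_submodule(module_path) if module_path else model
            except AttributeError:
                continue
            old_val = getattr(submodule, param_name, None)
            if old_val is None:
                continue
            if isinstance(old_val, torch.nn.Parameter):
                setattr(submodule, param_name, torch.nn.Parameter(old_val.to("meta")))
            else:
                setattr(submodule, param_name, old_val.to("meta"))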

@stas00 stas00 changed the title [modeling utils] add low_cpu_mem_usage tests [modeling utils] revamp from_pretrained(..., low_cpu_mem_usage) + tests Apr 10, 2022
@stas00 stas00 changed the title [modeling utils] revamp from_pretrained(..., low_cpu_mem_usage) + tests [modeling utils] revamp from_pretrained(..., low_cpu_mem_usage=True) + tests Apr 10, 2022
@stas00 stas00 marked this pull request as ready for review April 10, 2022 18:24
@stas00 stas00 requested a review from sgugger April 10, 2022 18:25
sgugger (Collaborator) left a review:

This PR has moved in a totally different direction from the original intent, touching the core method of Transformers in a way that is hard to read in a git diff. Although the changes are welcome, I'm not sure we can catch any regression in this highly sensitive code in the current format.

Could we temporarily revert the refactor and first merge this PR with just the tests? Then have a PR that refactors the missing-keys part into a function as you did, without code changes, and finally do the code changes in a third PR?

stas00 (Contributor, Author) commented Apr 11, 2022

The quality test will not work, as the original implementation doesn't handle bert or any other model with a custom prefix (like bert.) in its keys. It doesn't put the model on the meta device and thus doesn't save any memory.

I totally hear you about the complexity and that the PR is difficult to review in several places.

So I propose this plan:

  1. a new PR that just refactors _find_mismatched_keys
  2. merge (1) then re-base this PR and revisit?

does that sound OK?

sgugger (Collaborator) commented Apr 11, 2022

That sounds right, thanks for understanding!

stas00 (Contributor, Author) commented Apr 11, 2022

Step 1 is ready: #16706

stas00 (Contributor, Author) commented Apr 13, 2022

@sgugger, as planned I rebased on #16706, so the diff should be much easier to read now.

sgugger (Collaborator) left a review:

LGTM! Since it's core, would need to have @LysandreJik and @patrickvonplaten look at this as well, to make sure we don't break anything.

@@ -2091,77 +2210,6 @@ def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=Fal

return retrieved_modules

@staticmethod
def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file):
patrickvonplaten (Contributor):

Note that this method is called in src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py - it might be nice to change it to the standard one now.

stas00 (Contributor, Author) replied on Apr 14, 2022:

Nice catch, Patrick

It's all modular now, so if you agree we can add a convenience wrapper:

    @staticmethod
    def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""):
        """
        This is an experimental function that loads the model using ~1.x model size CPU memory

        Before it gets called we do:

        1. save which state_dict keys we have
        2. drop state_dict before model is created, since the latter takes 1x model size memory

        Here then we continue:

        3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict
        4. load state_dict 2nd time
        5. replace the params/buffers from the state_dict

        Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed.
        """

        _move_model_to_meta(model, loaded_state_dict_keys, start_prefix)
        state_dict = load_state_dict(resolved_archive_file)
        error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix)
        return error_msgs

which restores the original function.

stas00 (Contributor, Author):

and if so, how can I test src/transformers/models/regnet/convert_regnet_seer_10b_to_pytorch.py?

stas00 (Contributor, Author):

I went ahead and added it, so just need to test that conversion script once I know how.

patrickvonplaten (Contributor):

Yeah, maybe it's a bit overkill to test the script, since the model is huge and it's just a conversion script, and those aren't tested anyway 😅 I'd be fine with just changing the function and "trusting" that it works.

sgugger (Collaborator):

We don't test conversion scripts. (And the conversion script shouldn't use a private method from modeling_utils; I missed that in the review...)

stas00 (Contributor, Author) replied on Apr 14, 2022:

It probably indicates a need for a low-memory-usage "update model from state_dict" functionality. Perhaps once it's exercised some more we can make it a public util function.

patrickvonplaten (Contributor) left a review:

Thanks a lot for working on this! Super useful feature.

Left two nits, but I'd also be OK with merging this anyway, as they are not too important.

LysandreJik (Member) left a review:

Yes, this looks great to me. Thanks for refactoring this, @stas00, the new method-based approach is very clean.

@stas00 stas00 merged commit 5da33f8 into huggingface:main Apr 15, 2022
@stas00 stas00 deleted the low_cpu_mem_usage-test branch April 15, 2022 01:10
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…` + tests (huggingface#16657)

* add low_cpu_mem_usage tests

* wip: revamping

* wip

* install /usr/bin/time

* wip

* cleanup

* cleanup

* cleanup

* cleanup

* cleanup

* fix assert

* put the wrapper back

* cleanup; switch to bert-base-cased

* Trigger CI

* Trigger CI