
get_model_size_mb (LightningModule.model_size) shouldn't create temporary files in the current directory #10074

Closed
RuRo opened this issue Oct 21, 2021 · 6 comments · Fixed by #10123
Labels: bug (Something isn't working), help wanted (Open to be worked on)

Comments

RuRo (Contributor) commented Oct 21, 2021

Currently, the get_model_size_mb function is implemented like this:

import os
import uuid

import torch
from torch.nn import Module


def get_model_size_mb(model: Module) -> float:
    """Calculates the size of a Module in megabytes by saving the model to a temporary file and reading its size.

    The computation includes everything in the :meth:`~torch.nn.Module.state_dict`,
    i.e., by default the parameters and buffers.

    Returns:
        Number of megabytes in the parameters of the input module.
    """
    # TODO: Implement a method without needing to download the model
    tmp_name = f"{uuid.uuid4().hex}.pt"
    torch.save(model.state_dict(), tmp_name)
    size_mb = os.path.getsize(tmp_name) / 1e6
    os.remove(tmp_name)
    return size_mb

This writes the model to disk and also introduces a race condition where the temporary file may be left behind if the user presses Ctrl+C while the torch.save and os.path.getsize calls are running. I originally found this bug while investigating what was creating the mysterious cb0e0c0d62ee4c20986c67462c1105d3.pt files in my project directory.

You can achieve the same result without touching the disk:

import io

target = io.BytesIO()
torch.save(model.state_dict(), target)
size_mb = target.getbuffer().nbytes / 1e6
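
For context, here is a minimal sketch of what the full function could look like with this change; it keeps the existing signature, and the docstring wording is only illustrative:

import io

import torch
from torch.nn import Module


def get_model_size_mb(model: Module) -> float:
    """Calculates the size of a Module in megabytes by serializing the state_dict to an in-memory buffer.

    Returns:
        Number of megabytes in the parameters of the input module.
    """
    # Serialize the state_dict into an in-memory buffer instead of a temporary file on disk.
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    # nbytes of the underlying buffer is the total serialized size in bytes.
    return buffer.getbuffer().nbytes / 1e6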
RuRo added the bug and help wanted labels on Oct 21, 2021
awaelchli (Contributor) commented
@RuRo Thanks, this seems totally reasonable to me. Would you be interested in contributing this change?

awaelchli (Contributor) commented Oct 22, 2021

@calebrob6 came up with a similar idea: #8343 (comment)

In a code comment there, he notes that the disadvantage is the memory cost.

If this is a concern, I would suggest adding a bool argument to toggle the behavior. That's suboptimal, since Lightning would have to choose a default, but it may be worth considering.
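
For illustration only, such a toggle might look roughly like the sketch below. The in_memory parameter name and its default are hypothetical, not an agreed-upon API:

import io
import os
import tempfile

import torch
from torch.nn import Module


def get_model_size_mb(model: Module, in_memory: bool = True) -> float:
    """Sketch of a size calculation with a toggle between in-memory and on-disk serialization."""
    if in_memory:
        # Serialize into RAM: leaves no files behind, at the cost of extra memory.
        buffer = io.BytesIO()
        torch.save(model.state_dict(), buffer)
        return buffer.getbuffer().nbytes / 1e6
    # Serialize to a real temporary file that is cleaned up automatically.
    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
        torch.save(model.state_dict(), tmp)
        tmp.flush()
        return os.path.getsize(tmp.name) / 1e6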

RuRo (Contributor, Author) commented Oct 22, 2021

I am "interested in contributing" in the sense that you are free to use the code I provided. Unfortunately, I don't currently have the time to create a proper PR with tests and stuff. Maybe, I'll have some time next month, but given the size of the proposed fix, I don't think that waiting until next month makes sense.


Regarding the memory cost, I'll admit I hadn't thought about that. Here's a slightly more clever version that doesn't store the whole model in memory:

class ByteCounter:
    """A write-only file-like object that only counts the bytes written to it."""

    def __init__(self):
        self.nbytes = 0

    def write(self, data):
        self.nbytes += len(data)

    def flush(self):
        pass

target = ByteCounter()
torch.save(model.state_dict(), target)
size_mb = target.nbytes / 1e6

This still allocates some memory, but the peak memory consumption should be much lower, since torch.save writes the model in chunks. For example, calling torch.save on a resnet50 model produces a total of ~97 MB split over 1614 calls to write, with a peak usage of ~9 MB.

Also, this version should have about the same memory footprint as the "write the file to disk" version.
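
One rough way to inspect the write pattern yourself (the exact numbers will vary with the torch and model versions) is to extend the ByteCounter above to also record the number of calls and the largest single chunk, which approximates the per-chunk peak:

class InstrumentedByteCounter(ByteCounter):
    """Also tracks how many times write() was called and the largest single chunk."""

    def __init__(self):
        super().__init__()
        self.ncalls = 0
        self.max_chunk = 0

    def write(self, data):
        super().write(data)
        self.ncalls += 1
        self.max_chunk = max(self.max_chunk, len(data))

counter = InstrumentedByteCounter()
torch.save(model.state_dict(), counter)
print(counter.nbytes / 1e6, counter.ncalls, counter.max_chunk / 1e6)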

calebrob6 (Contributor) commented
That's really cool!

tchaton (Contributor) commented Oct 22, 2021

Dear @RuRo, would you mind contributing ByteCounter? I believe we could also have a version for sharded models, where a summation is performed at the end.
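
As an illustrative sketch only (not an agreed design), the sharded variant could sum the per-rank byte counts, assuming torch.distributed is initialized with a backend that supports CPU tensors (e.g. gloo) and that each rank only holds its own shard; the function name here is hypothetical:

import torch
import torch.distributed as dist


def get_sharded_model_size_mb(model: torch.nn.Module) -> float:
    # Each rank serializes only its local shard and counts the bytes...
    counter = ByteCounter()
    torch.save(model.state_dict(), counter)
    local_size = torch.tensor(float(counter.nbytes))
    # ...then the per-rank byte counts are summed across all processes.
    dist.all_reduce(local_size, op=dist.ReduceOp.SUM)
    return local_size.item() / 1e6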

calebrob6 (Contributor) commented
@tchaton in case you didn't see, RuRo said, "I am "interested in contributing" in the sense that you are free to use the code I provided. Unfortunately, I don't currently have the time to create a proper PR with tests and stuff."
