Skip to content

Commit

Permalink
Reuse model_size implementation in get_model_size
Browse files Browse the repository at this point in the history
  • Loading branch information
roshikouhai committed Jul 22, 2021
1 parent 9cf9d18 commit 36b82f6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 27 deletions.
30 changes: 12 additions & 18 deletions pytorch_lightning/utilities/model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
import os
from functools import partial
from typing import Optional, Type, Union
from unittest.mock import Mock

import torch
import torch.nn as nn

import pytorch_lightning as pl
Expand Down Expand Up @@ -75,24 +78,15 @@ def is_overridden(

def get_model_size(model: nn.Module) -> float:
    """
    Calculates the size of a nn.Module in megabytes by saving its ``state_dict()`` to a
    temporary file and reading the size of that file.

    The reported value is the size of the serialized file on disk divided by ``1e6``, so it
    is larger than the raw sum of the tensor payloads.

    Args:
        model: the module whose ``state_dict()`` is serialized and measured

    Returns:
        Number of megabytes (1 MB = 1e6 bytes) occupied by the serialized ``state_dict()``
    """
    # Imported locally so this block does not depend on a module-level import being present.
    import tempfile

    # TODO: Implement a method without needing to write the model to disk
    # The scratch file lives in a private temporary directory so that nothing is left behind
    # in the current working directory, even when ``torch.save`` or ``getsize`` raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # NOTE(review): the random 32-character stem is kept as before; the serialized size
        # may depend on the file name length -- confirm before changing the naming scheme.
        tmp_name = os.path.join(tmp_dir, f"{uuid.uuid4().hex}.pt")
        torch.save(model.state_dict(), tmp_name)
        return os.path.getsize(tmp_name) / 1e6
6 changes: 4 additions & 2 deletions tests/callbacks/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import QuantizationAwareTraining
from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error
from pytorch_lightning.utilities.model_helpers import get_model_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.datamodules import RegressDataModule
from tests.helpers.runif import RunIf
Expand All @@ -44,7 +45,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):

trainer = Trainer(**trainer_args)
trainer.fit(model, datamodule=dm)
org_size = model.model_size
org_size = get_model_size(model)
org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))

fusing_layers = [(f'layer_{i}', f'layer_{i}a') for i in range(3)] if fuse else None
Expand All @@ -66,7 +67,8 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
qmodel.eval()
torch.quantization.convert(qmodel, inplace=True)

quant_size = qmodel.model_size
quant_size = get_model_size(qmodel)
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
# test that the trained model is smaller than the initial one
size_ratio = quant_size / org_size
assert size_ratio < 0.65
Expand Down
10 changes: 3 additions & 7 deletions tests/utilities/test_model_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,7 @@ def test_get_model_size():
model = BoringModel()

size_bytes = get_model_size(model)

# The BoringModel has a fully connected layer of size 32x2 with a bias resulting in
# 67 weights. Each weight is a float32 -- 4 bytes, therefore we expect a size of
# 264.
assert size_bytes == 264
assert size_bytes == 0.001319


def test_get_sparse_model_size():
Expand All @@ -116,6 +112,6 @@ def __init__(self):
self.layer = nn.Parameter(torch.ones(32).to_sparse())

model = BoringSparseModel()
size_bytes = get_model_size(model)

with pytest.raises(NotImplementedError):
get_model_size(model)
assert size_bytes == 0.001511

0 comments on commit 36b82f6

Please sign in to comment.