Generalize Model Size Code (pytorch#364)
* Generalize Model Size Code

Summary: Previously this worked only on model-swap quantized classes, and
the version in generate.py was specific to a few cases. The new version
is significantly more general and is now consolidated in a single place.

Test Plan: python test_integration.py -k "test_get_model_size"

Reviewers:

Subscribers:

Tasks:

Tags:

* fix test failures

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles committed Jun 14, 2024
1 parent 924ebdc commit ee95a69
Showing 3 changed files with 82 additions and 24 deletions.
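For orientation, here is a minimal usage sketch of the consolidated helper. The toy model below is an illustrative assumption and not part of this change; the import path matches the one added to generate.py in this diff.

import torch
from torchao.utils import get_model_size_in_bytes

# A plain fp32 stack: parameters and buffers are counted directly.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
)

size = get_model_size_in_bytes(model)
print(f"model size: {size / 1e6:.1f} MB")  # ~8.4 MB in fp32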
54 changes: 54 additions & 0 deletions test/integration/test_integration.py
@@ -1373,6 +1373,60 @@ def forward(self, x):
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))

class TestUtils(unittest.TestCase):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
def test_get_model_size_autoquant(self, device, dtype):
if device != "cuda" and dtype != torch.bfloat16:
self.skipTest(f"autoquant currently does not support {device}")
if device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest(f"bfloat16 requires sm80+")
m, k, n = 16, 128, 128
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k,n),
torch.nn.ReLU(),
).to(device).to(dtype)
example_input = torch.randn(m, k, device=device, dtype=dtype)
size = torchao.utils.get_model_size_in_bytes(model)

from torchao.quantization.autoquant import (
AQWeightOnlyQuantizedLinearWeight2,
)
qtensor_class_list = (
AQWeightOnlyQuantizedLinearWeight2,

)

mod = torchao.autoquant(torch.compile(model), qtensor_class_list = qtensor_class_list)
mod(example_input)
size2 = torchao.utils.get_model_size_in_bytes(mod)
self.assertTrue(size2 < size)

@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
def test_get_model_size_aqt(self, api, test_device, test_dtype):
if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not supported yet")
if test_device != "cuda" or not torch.cuda.is_available():
self.skipTest(f"{api} currently does not support {test_device}")
k, n = 1024, 1024
model = torch.nn.Sequential(
torch.nn.ReLU(),
torch.nn.Linear(k,n),
torch.nn.ReLU(),
).to(test_device).to(test_dtype)
size = torchao.utils.get_model_size_in_bytes(model)
api(model)
size2 = torchao.utils.get_model_size_in_bytes(model)
self.assertTrue(size2 < size)




if __name__ == "__main__":
unittest.main()
18 changes: 2 additions & 16 deletions torchao/_models/llama/generate.py
@@ -13,6 +13,7 @@
import torchao
import torch._dynamo.config
import torch._inductor.config
from torchao.utils import get_model_size_in_bytes

def device_sync(device):
if "cuda" in device:
@@ -143,21 +144,6 @@ def _load_model(checkpoint_path, device, precision):

return model.eval()

def _get_model_size(model):
model_size = 0
for name, child in model.named_children():
if not isinstance(child, torch.nn.Embedding):
for p in itertools.chain(child.parameters(), child.buffers()):
# handling for tensor subclasses
if isinstance(p, torchao.dtypes.aqt.AffineQuantizedTensor):
layout_tensor = p.layout_tensor
for attr_name in layout_tensor.__tensor_flatten__()[0]:
sub_tensor = getattr(layout_tensor, attr_name)
model_size += sub_tensor.numel() * sub_tensor.element_size()
else:
model_size += p.numel() * p.element_size()
return model_size

B_INST, E_INST = "[INST]", "[/INST]"

def main(
@@ -226,7 +212,7 @@ def main(
interactive=False
)

model_size = _get_model_size(model) / 1e9
model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

if compile:
global decode_one_token, prefill
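For context on the ignore_embeddings flag used above, a small sketch follows; the vocabulary and hidden sizes are assumptions chosen only to illustrate why a large embedding table is excluded from the reported size.

import torch
from torchao.utils import get_model_size_in_bytes

# A model whose embedding table dwarfs the weights that actually get quantized.
model = torch.nn.Sequential(
    torch.nn.Embedding(32000, 4096),  # ~0.52 GB in fp32
    torch.nn.Linear(4096, 4096),      # ~0.07 GB in fp32
)

full = get_model_size_in_bytes(model)
no_emb = get_model_size_in_bytes(model, ignore_embeddings=True)
print(f"with embeddings:    {full / 1e9:.2f} GB")
print(f"without embeddings: {no_emb / 1e9:.2f} GB")  # counts only the Linear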
34 changes: 26 additions & 8 deletions torchao/utils.py
@@ -5,6 +5,7 @@
from math import gcd
from packaging import version
import torch.nn.utils.parametrize as parametrize
import itertools

__all__ = [
"benchmark_model",
@@ -82,14 +83,31 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
return n
return n + k - (n % k)

# https://discuss.pytorch.org/t/finding-model-size/130275
def get_model_size_in_bytes(model):
s = 0
for p in model.parameters():
s += p.nelement() * p.element_size()
for b in model.buffers():
s += b.nelement() * b.element_size()
return s
def get_model_size_in_bytes(model, ignore_embeddings=False):
"""
Returns the model size in bytes. The option to ignore embeddings
is useful for models with disproportionately large embeddings compared
to other model parameters that get quantized/sparsified.
"""
def flat_size(tensor):
if hasattr(tensor, "__tensor_flatten__"):
size = 0
# 0th element is a list of attributes that
# hold tensors
for attr_name in tensor.__tensor_flatten__()[0]:
sub_tensor = getattr(tensor, attr_name)
size += flat_size(sub_tensor)
return size
else:
return tensor.numel() * tensor.element_size()

model_size = 0
for name, child in model.named_children():
if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
for p in itertools.chain(child.parameters(recurse=False), child.buffers(recurse=False)):
model_size += flat_size(p)
model_size += get_model_size_in_bytes(child, ignore_embeddings)
return model_size

class UnwrapTensorSubclass(torch.nn.Module):
def forward(self, *tensors):
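To make the __tensor_flatten__ path above concrete: a quantized weight wrapping an int8 payload plus a per-row scale reports the sum of its constituent tensors' sizes, not numel() * element_size() of the wrapper. The attribute names and shapes below are invented for illustration; the arithmetic mirrors flat_size exactly.

import torch

# Hypothetical payload of a quantized 1024x1024 weight. A tensor subclass
# wrapping these would return ["int_data", "scale"] as the first element of
# __tensor_flatten__(), which is what flat_size walks recursively.
int_data = torch.zeros(1024, 1024, dtype=torch.int8)
scale = torch.ones(1024, dtype=torch.float32)

quantized_bytes = sum(t.numel() * t.element_size() for t in (int_data, scale))
fp32_bytes = 1024 * 1024 * 4  # what the unquantized fp32 weight would report

print(quantized_bytes)  # 1052672 -> ~1 MiB of int8 data + 4 KiB of scales
print(fp32_bytes)       # 4194304 -> ~4 MiB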
