From ee95a69aef81e0e2ce10ec50f8d3f22b5457ed05 Mon Sep 17 00:00:00 2001
From: HDCharles <39544797+HDCharles@users.noreply.github.com>
Date: Fri, 14 Jun 2024 17:36:57 -0400
Subject: [PATCH] Generalize Model Size Code (#364)

* Generalize Model Size Code

Summary:

Previously this only worked on module-swap quantized classes, and the
version in generate.py was specific to a few cases. The new version is
significantly more general and is now consolidated into a single place.

Test Plan:

python test_integration.py -k "test_get_model_size"

Reviewers:

Subscribers:

Tasks:

Tags:

* fix test failures

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/integration/test_integration.py | 54 ++++++++++++++++++++++++++++
 torchao/_models/llama/generate.py    | 18 ++--------
 torchao/utils.py                     | 34 +++++++++++++-----
 3 files changed, 82 insertions(+), 24 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index b853b0589d..3859d8039b 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1373,6 +1373,60 @@ def forward(self, x):
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
 
+class TestUtils(unittest.TestCase):
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.")
+    def test_get_model_size_autoquant(self, device, dtype):
+        if device != "cuda" and dtype != torch.bfloat16:
+            self.skipTest(f"autoquant currently does not support {device}")
+        if device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"autoquant currently does not support {device}")
+        if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
+            if dtype == torch.bfloat16:
+                self.skipTest(f"bfloat16 requires sm80+")
+        m, k, n = 16, 128, 128
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).to(device).to(dtype)
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+        size = torchao.utils.get_model_size_in_bytes(model)
+
+        from torchao.quantization.autoquant import (
+            AQWeightOnlyQuantizedLinearWeight2,
+        )
+        qtensor_class_list = (
+            AQWeightOnlyQuantizedLinearWeight2,
+        )
+
+        mod = torchao.autoquant(torch.compile(model), qtensor_class_list=qtensor_class_list)
+        mod(example_input)
+        size2 = torchao.utils.get_model_size_in_bytes(mod)
+        self.assertTrue(size2 < size)
+
+    @parameterized.expand(
+        list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
+    )
+    def test_get_model_size_aqt(self, api, test_device, test_dtype):
+        if test_dtype != torch.bfloat16:
+            self.skipTest(f"{api} in {test_dtype} is not supported yet")
+        if test_device != "cuda" or not torch.cuda.is_available():
+            self.skipTest(f"{api} currently does not support {test_device}")
+        k, n = 1024, 1024
+        model = torch.nn.Sequential(
+            torch.nn.ReLU(),
+            torch.nn.Linear(k, n),
+            torch.nn.ReLU(),
+        ).to(test_device).to(test_dtype)
+        size = torchao.utils.get_model_size_in_bytes(model)
+        api(model)
+        size2 = torchao.utils.get_model_size_in_bytes(model)
+        self.assertTrue(size2 < size)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 1f5380a888..ea7200ea6b 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -13,6 +13,7 @@
 import torchao
 import torch._dynamo.config
 import torch._inductor.config
+from torchao.utils import get_model_size_in_bytes
 
 def device_sync(device):
     if "cuda" in device:
@@ -143,21 +144,6 @@ def _load_model(checkpoint_path, device, precision):
 
     return model.eval()
 
-def _get_model_size(model):
-    model_size = 0
-    for name, child in model.named_children():
-        if not isinstance(child, torch.nn.Embedding):
-            for p in itertools.chain(child.parameters(), child.buffers()):
-                # handling for tensor subclasses
-                if isinstance(p, torchao.dtypes.aqt.AffineQuantizedTensor):
-                    layout_tensor = p.layout_tensor
-                    for attr_name in layout_tensor.__tensor_flatten__()[0]:
-                        sub_tensor = getattr(layout_tensor, attr_name)
-                        model_size += sub_tensor.numel() * sub_tensor.element_size()
-                else:
-                    model_size += p.numel() * p.element_size()
-    return model_size
-
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def main(
@@ -226,7 +212,7 @@ def main(
         interactive=False
     )
 
-    model_size = _get_model_size(model) / 1e9
+    model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9
 
     if compile:
         global decode_one_token, prefill
diff --git a/torchao/utils.py b/torchao/utils.py
index 381a302645..991257a9cb 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -5,6 +5,7 @@
 from math import gcd
 from packaging import version
 import torch.nn.utils.parametrize as parametrize
+import itertools
 
 __all__ = [
     "benchmark_model",
@@ -82,14 +83,31 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
         return n
     return n + k - (n % k)
 
-# https://discuss.pytorch.org/t/finding-model-size/130275
-def get_model_size_in_bytes(model):
-    s = 0
-    for p in model.parameters():
-        s += p.nelement() * p.element_size()
-    for b in model.buffers():
-        s += b.nelement() * b.element_size()
-    return s
+def get_model_size_in_bytes(model, ignore_embeddings=False):
+    """
+    Returns the model size in bytes. The option to ignore embeddings
+    is useful for models with disproportionately large embeddings compared
+    to other model parameters that get quantized/sparsified.
+    """
+    def flat_size(tensor):
+        if hasattr(tensor, "__tensor_flatten__"):
+            size = 0
+            # 0th element is a list of attributes that
+            # hold tensors
+            for attr_name in tensor.__tensor_flatten__()[0]:
+                sub_tensor = getattr(tensor, attr_name)
+                size += flat_size(sub_tensor)
+            return size
+        else:
+            return tensor.numel() * tensor.element_size()
+
+    model_size = 0
+    for name, child in model.named_children():
+        if not (isinstance(child, torch.nn.Embedding) and ignore_embeddings):
+            for p in itertools.chain(child.parameters(recurse=False), child.buffers(recurse=False)):
+                model_size += flat_size(p)
+            model_size += get_model_size_in_bytes(child, ignore_embeddings)
+    return model_size
 
 class UnwrapTensorSubclass(torch.nn.Module):
     def forward(self, *tensors):
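
Usage note: the sketch below shows how the consolidated helper might be called once this patch lands. Only the get_model_size_in_bytes signature comes from the diff above; the toy model, its dimensions, and the printed report are hypothetical.

import torch
from torchao.utils import get_model_size_in_bytes

# Hypothetical toy model: a large embedding plus a small linear head.
model = torch.nn.Sequential(
    torch.nn.Embedding(50_000, 256),  # ~51.2 MB of fp32 parameters, dominates the total
    torch.nn.Linear(256, 64),
)

total = get_model_size_in_bytes(model)                           # embeddings included
no_emb = get_model_size_in_bytes(model, ignore_embeddings=True)  # nn.Embedding children skipped

# generate.py above makes the same call and divides by 1e9 to report gigabytes.
print(f"{total / 1e9:.3f} GB total, {no_emb / 1e9:.3f} GB excluding embeddings")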
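
Outside the unittest scaffolding, the same size comparison can be exercised directly. This is only a sketch that mirrors the new test_get_model_size_autoquant case; it assumes a CUDA device with bfloat16 support (sm80+), torch.compile, and the AQWeightOnlyQuantizedLinearWeight2 class used in the test.

import torch
import torchao
from torchao.quantization.autoquant import AQWeightOnlyQuantizedLinearWeight2
from torchao.utils import get_model_size_in_bytes

model = torch.nn.Sequential(
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
).to("cuda").to(torch.bfloat16)
example_input = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16)

size_before = get_model_size_in_bytes(model)

# Restrict autoquant to a single weight-only class, as the test does.
quantized = torchao.autoquant(
    torch.compile(model),
    qtensor_class_list=(AQWeightOnlyQuantizedLinearWeight2,),
)
quantized(example_input)  # running an example input lets autoquant pick and swap in quantized weights

# Quantized weights are tensor subclasses; flat_size() walks __tensor_flatten__
# and counts their packed storage, so the reported size should shrink.
size_after = get_model_size_in_bytes(quantized)
assert size_after < size_before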