
Adding FLOPs and size to model metadata #6936

Merged: 19 commits, Nov 11, 2022

Conversation

@toni057 (Contributor) commented Nov 9, 2022:

Adds FLOPs and weight size in MB for classification, video, segmentation, optical flow and detection models. FLOPs and sizes were calculated using the script below, where sizes are file sizes obtained from the file system.

Tests were updated to accommodate the new metadata and are run with PYTORCH_TEST_WITH_EXTENDED="1" pytest test/test_extended_models.py -vv.
The documentation generation was slightly refactored so that tables can select which columns to show, since not all models currently support every field (e.g. quantization).

To execute the script, please first install torchvision with these changes, as the script asserts that the values were pasted into the torchvision code correctly.

The script below is based on @Chillee's work on FLOPs estimation.

import torch
import torchvision
import torchvision.models as models
import math
import os

from torch.utils._pytree import tree_map
from typing import List, Any
from numbers import Number
from collections import defaultdict
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch.ops.aten
quantized = torch.ops.quantized


def get_shape(i):
    if isinstance(i, torch.Tensor):
        return i.shape
    elif hasattr(i, "weight"):
        return i.weight().shape
    else:
        raise ValueError(f"Unknown type {type(i)}")


def prod(x):
    res = 1
    for i in x:
        res *= i
    return res


def matmul_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two matrices.
    input_shapes = [get_shape(v) for v in inputs]
    assert len(input_shapes) == 2, input_shapes
    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    flop = prod(input_shapes[0]) * input_shapes[-1][-1]
    return flop


def addmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [get_shape(v) for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops


def bmm_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the two batched tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [get_shape(v) for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flop = n * c * t * d
    return flop


def conv_flop_count(
        x_shape: List[int],
        w_shape: List[int],
        out_shape: List[int],
        transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note only multiplication is
    counted. Computation for addition and bias is ignored.
    Flops for a transposed convolution are calculated as
    flops = batch_size * prod(w_shape) * prod(x_shape[2:]).
    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): is the convolution transposed
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flop = batch_size * prod(w_shape) * prod(conv_shape)
    return flop


def conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))
    transposed = inputs[6]

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


def quant_conv_flop(inputs: List[Any], outputs: List[Any]):
    """
    Count flops for quantized convolution.
    """
    x, w = inputs[:2]
    x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0]))

    return conv_flop_count(x_shape, w_shape, out_shape, transposed=False)


def transpose_shape(shape):
    return [shape[1], shape[0]] + list(shape[2:])


def conv_backward_flop(inputs: List[Any], outputs: List[Any]):
    grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]]
    output_mask = inputs[-1]
    fwd_transposed = inputs[7]
    flop_count = 0

    if output_mask[0]:
        grad_input_shape = get_shape(outputs[0])
        flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
    if output_mask[1]:
        grad_weight_shape = get_shape(outputs[1])
        flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)

    return flop_count


flop_mapping = {
    aten.mm: matmul_flop,
    aten.matmul: matmul_flop,
    aten.addmm: addmm_flop,
    aten.bmm: bmm_flop,
    aten.convolution: conv_flop,
    aten._convolution: conv_flop,
    aten.convolution_backward: conv_backward_flop,
    quantized.conv2d: quant_conv_flop,
    quantized.conv2d_relu: quant_conv_flop,
}

unmapped_ops = set()


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


class FlopCounterMode(TorchDispatchMode):
    def __init__(self, model=None):
        self.flop_counts = defaultdict(lambda: defaultdict(int))
        self.parents = ['Global']
        # global mod
        if model is not None:
            for name, module in dict(model.named_children()).items():
                module.register_forward_pre_hook(self.enter_module(name))
                module.register_forward_hook(self.exit_module(name))

    def enter_module(self, name):
        def f(module, inputs):
            self.parents.append(name)
            inputs = normalize_tuple(inputs)
            out = self.create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(self, name):
        def f(module, inputs, outputs):
            assert (self.parents[-1] == name)
            self.parents.pop()
            outputs = normalize_tuple(outputs)
            return self.create_backwards_push(name)(*outputs)

        return f

    def create_backwards_push(self, name):
        class PushState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                self.parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(self, name):
        class PopState(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                assert (self.parents[-1] == name)
                self.parents.pop()
                return grad_outs

        return PopState.apply

    def __enter__(self):
        self.flop_counts.clear()
        super().__enter__()

    def __exit__(self, *args):
        # print(f"Total: {sum(self.flop_counts['Global'].values()) / 1e9} GFLOPS")
        # for mod in self.flop_counts.keys():
        #     print(f"Module: ", mod)
        #     for k, v in self.flop_counts[mod].items():
        #         print(f"{k}: {v / 1e9} GFLOPS")
        #     print()
        super().__exit__(*args)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        out = func(*args, **kwargs)
        func_packet = func._overloadpacket
        if func_packet in flop_mapping:
            flop_count = flop_mapping[func_packet](args, normalize_tuple(out))
            for par in self.parents:
                self.flop_counts[par][func_packet] += flop_count
        else:
            unmapped_ops.add(func_packet)

        return out

    def get_flops(self):
        return sum(self.flop_counts['Global'].values()) / 1e9


detection_models_input_dims = {
    "fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
    "fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
    "fasterrcnn_resnet50_fpn": (800, 800),
    "fasterrcnn_resnet50_fpn_v2": (800, 800),
    "fcos_resnet50_fpn": (800, 800),
    "keypointrcnn_resnet50_fpn": (1333, 1333),
    "maskrcnn_resnet50_fpn": (800, 800),
    "maskrcnn_resnet50_fpn_v2": (800, 800),
    "retinanet_resnet50_fpn": (800, 800),
    "retinanet_resnet50_fpn_v2": (800, 800),
    "ssd300_vgg16": (300, 300),
    "ssdlite320_mobilenet_v3_large": (320, 320)
}


def get_flops(model_list, h=512, w=512, t=None, optical_flow=False, quantize=None):
    flops = {}
    for model_name in model_list:
        weights_enum = models.get_model_weights(model_name)
        for weight in weights_enum:

            # detection models have curated input sizes
            if model_name in detection_models_input_dims:
                # detection models expect a list of 3D (C, H, W) images, so build a single image without a batch dimension
                h, w = detection_models_input_dims[model_name]
                dims = (3, h, w)
            else:
                # if the input has the time dimension for video
                if t is None:
                    dims = (1, 3, h, w)
                else:
                    dims = (1, t, 3, h, w)

            input_tensor = torch.randn(dims)

            try:
                preprocess = weight.transforms()
                if optical_flow:
                    inp = preprocess(input_tensor, input_tensor)
                else:
                    # hack to enable mod(*inp) for optical_flow models
                    inp = [preprocess(input_tensor)]
                # todo: else: add the variant with (c, h, w) for detection , list[tensor]

                kwargs = {} if quantize is None else {"quantize": quantize}
                mod = models.get_model(model_name, weights=weight, **kwargs)
                mod.eval()

                flop_counter = FlopCounterMode(mod)
                with flop_counter:

                    # detection models expect a list of 3d tensors as inputs
                    if model_name in detection_models_input_dims:
                        mod(inp)
                    else:
                        mod(*inp)

                    flops[weight] = flop_counter.get_flops()

            except Exception as e:
                # print(f"Failed: ", weight, e)
                flops[weight] = float('nan')

            weights_path = os.path.join(os.getenv("HOME"), ".cache/torch/hub/checkpoints",
                                        weight.url.split("/")[-1])
            weights_size_mb = os.path.getsize(weights_path) / 1024 / 1024

            if "_ops" in weight.meta and not math.isnan(flops[weight]) and not quantize:
                assert (round(flops[weight], 3) == weight.meta["_ops"])
            if "_weight_size" in weight.meta:
                assert (round(weights_size_mb, 3) == weight.meta["_weight_size"])

        print("%-60s| %10.3f | %10.3f" % (weight, round(flops[weight], 3), round(weights_size_mb, 3)))

    return flops


if __name__ == "__main__":
    print("%-60s| %10s | %10s" % ("Name", "GFLOPs", "Size (MB)"))
    print("-" * 85)
    get_flops(models.list_models(module=torchvision.models))  # OK
    get_flops(models.list_models(module=torchvision.models.video), t=16)  # OK
    get_flops(models.list_models(module=torchvision.models.segmentation))  # OK
    get_flops(models.list_models(module=torchvision.models.optical_flow), optical_flow=True)  # OK
    get_flops(models.list_models(module=torchvision.models.detection))  # OK
    get_flops(models.list_models(module=torchvision.models.quantization), quantize=True)

    if unmapped_ops:
        print("\nThe following ops were not counted in the estimations:")
        print([str(op) for op in unmapped_ops])

Running the above yields the following output:

Name                                                        |     GFLOPs |  Size (MB)
-------------------------------------------------------------------------------------
AlexNet_Weights.IMAGENET1K_V1                               |      0.714 |    233.087
ConvNeXt_Base_Weights.IMAGENET1K_V1                         |     15.355 |    338.064
ConvNeXt_Large_Weights.IMAGENET1K_V1                        |     34.361 |    754.537
ConvNeXt_Small_Weights.IMAGENET1K_V1                        |      8.684 |    191.703
ConvNeXt_Tiny_Weights.IMAGENET1K_V1                         |      4.456 |    109.119
DenseNet121_Weights.IMAGENET1K_V1                           |      2.834 |     30.845
DenseNet161_Weights.IMAGENET1K_V1                           |      7.728 |    110.369
DenseNet169_Weights.IMAGENET1K_V1                           |      3.360 |     54.708
DenseNet201_Weights.IMAGENET1K_V1                           |      4.291 |     77.373
EfficientNet_B0_Weights.IMAGENET1K_V1                       |      0.386 |     20.451
EfficientNet_B1_Weights.IMAGENET1K_V1                       |      0.687 |     30.134
EfficientNet_B1_Weights.IMAGENET1K_V2                       |      0.687 |     30.136
EfficientNet_B2_Weights.IMAGENET1K_V1                       |      1.088 |     35.174
EfficientNet_B3_Weights.IMAGENET1K_V1                       |      1.827 |     47.184
EfficientNet_B4_Weights.IMAGENET1K_V1                       |      4.394 |     74.489
EfficientNet_B5_Weights.IMAGENET1K_V1                       |     10.266 |    116.864
EfficientNet_B6_Weights.IMAGENET1K_V1                       |     19.068 |    165.362
EfficientNet_B7_Weights.IMAGENET1K_V1                       |     37.746 |    254.675
EfficientNet_V2_L_Weights.IMAGENET1K_V1                     |     56.080 |    454.573
EfficientNet_V2_M_Weights.IMAGENET1K_V1                     |     24.582 |    208.010
EfficientNet_V2_S_Weights.IMAGENET1K_V1                     |      8.366 |     82.704
GoogLeNet_Weights.IMAGENET1K_V1                             |      1.498 |     49.731
Inception_V3_Weights.IMAGENET1K_V1                          |      5.713 |    103.903
MaxVit_T_Weights.IMAGENET1K_V1                              |      5.558 |    118.769
MNASNet0_5_Weights.IMAGENET1K_V1                            |      0.104 |      8.591
MNASNet0_75_Weights.IMAGENET1K_V1                           |      0.215 |     12.303
MNASNet1_0_Weights.IMAGENET1K_V1                            |      0.314 |     16.915
MNASNet1_3_Weights.IMAGENET1K_V1                            |      0.526 |     24.246
MobileNet_V2_Weights.IMAGENET1K_V1                          |      0.301 |     13.555
MobileNet_V2_Weights.IMAGENET1K_V2                          |      0.301 |     13.598
MobileNet_V3_Large_Weights.IMAGENET1K_V1                    |      0.217 |     21.114
MobileNet_V3_Large_Weights.IMAGENET1K_V2                    |      0.217 |     21.107
MobileNet_V3_Small_Weights.IMAGENET1K_V1                    |      0.057 |      9.829
RegNet_X_16GF_Weights.IMAGENET1K_V1                         |     15.941 |    207.627
RegNet_X_16GF_Weights.IMAGENET1K_V2                         |     15.941 |    207.627
RegNet_X_1_6GF_Weights.IMAGENET1K_V1                        |      1.603 |     35.339
RegNet_X_1_6GF_Weights.IMAGENET1K_V2                        |      1.603 |     35.339
RegNet_X_32GF_Weights.IMAGENET1K_V1                         |     31.736 |    412.039
RegNet_X_32GF_Weights.IMAGENET1K_V2                         |     31.736 |    412.039
RegNet_X_3_2GF_Weights.IMAGENET1K_V1                        |      3.177 |     58.756
RegNet_X_3_2GF_Weights.IMAGENET1K_V2                        |      3.177 |     58.756
RegNet_X_400MF_Weights.IMAGENET1K_V1                        |      0.414 |     21.258
RegNet_X_400MF_Weights.IMAGENET1K_V2                        |      0.414 |     21.257
RegNet_X_800MF_Weights.IMAGENET1K_V1                        |      0.800 |     27.945
RegNet_X_800MF_Weights.IMAGENET1K_V2                        |      0.800 |     27.945
RegNet_X_8GF_Weights.IMAGENET1K_V1                          |      7.995 |    151.456
RegNet_X_8GF_Weights.IMAGENET1K_V2                          |      7.995 |    151.456
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_E2E_V1               |    374.570 |   2461.564
RegNet_Y_128GF_Weights.IMAGENET1K_SWAG_LINEAR_V1            |    127.518 |   2461.564
RegNet_Y_16GF_Weights.IMAGENET1K_V1                         |     15.912 |    319.490
RegNet_Y_16GF_Weights.IMAGENET1K_V2                         |     15.912 |    319.490
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_E2E_V1                |     46.735 |    319.490
RegNet_Y_16GF_Weights.IMAGENET1K_SWAG_LINEAR_V1             |     15.912 |    319.490
RegNet_Y_1_6GF_Weights.IMAGENET1K_V1                        |      1.612 |     43.152
RegNet_Y_1_6GF_Weights.IMAGENET1K_V2                        |      1.612 |     43.152
RegNet_Y_32GF_Weights.IMAGENET1K_V1                         |     32.280 |    554.076
RegNet_Y_32GF_Weights.IMAGENET1K_V2                         |     32.280 |    554.076
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_E2E_V1                |     94.826 |    554.076
RegNet_Y_32GF_Weights.IMAGENET1K_SWAG_LINEAR_V1             |     32.280 |    554.076
RegNet_Y_3_2GF_Weights.IMAGENET1K_V1                        |      3.176 |     74.567
RegNet_Y_3_2GF_Weights.IMAGENET1K_V2                        |      3.176 |     74.567
RegNet_Y_400MF_Weights.IMAGENET1K_V1                        |      0.402 |     16.806
RegNet_Y_400MF_Weights.IMAGENET1K_V2                        |      0.402 |     16.806
RegNet_Y_800MF_Weights.IMAGENET1K_V1                        |      0.834 |     24.774
RegNet_Y_800MF_Weights.IMAGENET1K_V2                        |      0.834 |     24.774
RegNet_Y_8GF_Weights.IMAGENET1K_V1                          |      8.473 |    150.701
RegNet_Y_8GF_Weights.IMAGENET1K_V2                          |      8.473 |    150.701
ResNet101_Weights.IMAGENET1K_V1                             |      7.801 |    170.511
ResNet101_Weights.IMAGENET1K_V2                             |      7.801 |    170.530
ResNet152_Weights.IMAGENET1K_V1                             |     11.514 |    230.434
ResNet152_Weights.IMAGENET1K_V2                             |     11.514 |    230.474
ResNet18_Weights.IMAGENET1K_V1                              |      1.814 |     44.661
ResNet34_Weights.IMAGENET1K_V1                              |      3.664 |     83.275
ResNet50_Weights.IMAGENET1K_V1                              |      4.089 |     97.781
ResNet50_Weights.IMAGENET1K_V2                              |      4.089 |     97.790
ResNeXt101_32X8D_Weights.IMAGENET1K_V1                      |     16.414 |    339.586
ResNeXt101_32X8D_Weights.IMAGENET1K_V2                      |     16.414 |    339.673
ResNeXt101_64X4D_Weights.IMAGENET1K_V1                      |     15.460 |    319.318
ResNeXt50_32X4D_Weights.IMAGENET1K_V1                       |      4.230 |     95.789
ResNeXt50_32X4D_Weights.IMAGENET1K_V2                       |      4.230 |     95.833
ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1                    |      0.040 |      5.282
ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1                    |      0.145 |      8.791
ShuffleNet_V2_X1_5_Weights.IMAGENET1K_V1                    |      0.296 |     13.557
ShuffleNet_V2_X2_0_Weights.IMAGENET1K_V1                    |      0.583 |     28.433
SqueezeNet1_0_Weights.IMAGENET1K_V1                         |      0.819 |      4.778
SqueezeNet1_1_Weights.IMAGENET1K_V1                         |      0.349 |      4.729
Swin_B_Weights.IMAGENET1K_V1                                |     15.431 |    335.364
Swin_S_Weights.IMAGENET1K_V1                                |      8.741 |    189.786
Swin_T_Weights.IMAGENET1K_V1                                |      4.491 |    108.190
Swin_V2_B_Weights.IMAGENET1K_V1                             |     20.325 |    336.372
Swin_V2_S_Weights.IMAGENET1K_V1                             |     11.546 |    190.675
Swin_V2_T_Weights.IMAGENET1K_V1                             |      5.940 |    108.626
VGG11_Weights.IMAGENET1K_V1                                 |      7.609 |    506.840
VGG11_BN_Weights.IMAGENET1K_V1                              |      7.609 |    506.881
VGG13_Weights.IMAGENET1K_V1                                 |     11.308 |    507.545
VGG13_BN_Weights.IMAGENET1K_V1                              |     11.308 |    507.590
VGG16_Weights.IMAGENET1K_V1                                 |     15.470 |    527.796
VGG16_Weights.IMAGENET1K_FEATURES                           |     15.470 |    527.802
VGG16_BN_Weights.IMAGENET1K_V1                              |     15.470 |    527.866
VGG19_Weights.IMAGENET1K_V1                                 |     19.632 |    548.051
VGG19_BN_Weights.IMAGENET1K_V1                              |     19.632 |    548.143
ViT_B_16_Weights.IMAGENET1K_V1                              |     17.564 |    330.285
ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1                     |     55.484 |    331.398
ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1                  |     17.564 |    330.285
ViT_B_32_Weights.IMAGENET1K_V1                              |      4.409 |    336.604
ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1                     |   1016.717 |   2416.643
ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1                  |    167.295 |   2411.209
ViT_L_16_Weights.IMAGENET1K_V1                              |     61.555 |   1161.023
ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1                     |    361.986 |   1164.258
ViT_L_16_Weights.IMAGENET1K_SWAG_LINEAR_V1                  |     61.555 |   1161.023
ViT_L_32_Weights.IMAGENET1K_V1                              |     15.378 |   1169.449
Wide_ResNet101_2_Weights.IMAGENET1K_V1                      |     22.753 |    242.896
Wide_ResNet101_2_Weights.IMAGENET1K_V2                      |     22.753 |    484.747
Wide_ResNet50_2_Weights.IMAGENET1K_V1                       |     11.398 |    131.820
Wide_ResNet50_2_Weights.IMAGENET1K_V2                       |     11.398 |    263.124
MC3_18_Weights.KINETICS400_V1                               |     43.343 |     44.672
MViT_V1_B_Weights.KINETICS400_V1                            |     70.599 |    139.764
MViT_V2_S_Weights.KINETICS400_V1                            |     64.224 |    131.884
R2Plus1D_18_Weights.KINETICS400_V1                          |     40.519 |    120.318
R3D_18_Weights.KINETICS400_V1                               |     40.697 |    127.359
S3D_Weights.KINETICS400_V1                                  |     17.979 |     31.972
DeepLabV3_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1|     10.452 |     42.301
DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1         |    258.743 |    233.217
DeepLabV3_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1          |    178.722 |    160.515
FCN_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1               |    232.738 |    207.711
FCN_ResNet50_Weights.COCO_WITH_VOC_LABELS_V1                |    152.717 |    135.009
LRASPP_MobileNet_V3_Large_Weights.COCO_WITH_VOC_LABELS_V1   |      2.086 |     12.490
Raft_Large_Weights.C_T_V1                                   |    211.007 |     20.129
Raft_Large_Weights.C_T_V2                                   |    211.007 |     20.129
Raft_Large_Weights.C_T_SKHT_V1                              |    211.007 |     20.129
Raft_Large_Weights.C_T_SKHT_V2                              |    211.007 |     20.129
Raft_Large_Weights.C_T_SKHT_K_V1                            |    211.007 |     20.129
Raft_Large_Weights.C_T_SKHT_K_V2                            |    211.007 |     20.129
Raft_Small_Weights.C_T_V1                                   |     47.655 |      3.821
Raft_Small_Weights.C_T_V2                                   |     47.655 |      3.821
FasterRCNN_MobileNet_V3_Large_320_FPN_Weights.COCO_V1       |      0.719 |     74.239
FasterRCNN_MobileNet_V3_Large_FPN_Weights.COCO_V1           |      4.494 |     74.239
FasterRCNN_ResNet50_FPN_Weights.COCO_V1                     |    134.380 |    159.743
FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1                  |    280.371 |    167.104
FCOS_ResNet50_FPN_Weights.COCO_V1                           |    128.207 |    123.608
KeypointRCNN_ResNet50_FPN_Weights.COCO_LEGACY               |    133.924 |    226.054
KeypointRCNN_ResNet50_FPN_Weights.COCO_V1                   |    137.420 |    226.054
MaskRCNN_ResNet50_FPN_Weights.COCO_V1                       |    134.380 |    169.840
MaskRCNN_ResNet50_FPN_V2_Weights.COCO_V1                    |    333.577 |    177.219
RetinaNet_ResNet50_FPN_Weights.COCO_V1                      |    151.540 |    130.267
RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1                   |    152.238 |    146.037
SSD300_VGG16_Weights.COCO_V1                                |     34.858 |    135.988
SSDLite320_MobileNet_V3_Large_Weights.COCO_V1               |      0.583 |     13.418
GoogLeNet_QuantizedWeights.IMAGENET1K_FBGEMM_V1             |        nan |     12.618
Inception_V3_QuantizedWeights.IMAGENET1K_FBGEMM_V1          |        nan |     23.146
MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1         |        nan |      3.423
MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1   |        nan |     21.554
ResNet18_QuantizedWeights.IMAGENET1K_FBGEMM_V1              |        nan |     11.238
ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1              |        nan |     24.759
ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2              |        nan |     24.953
ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V1      |        nan |     86.034
ResNeXt101_32X8D_QuantizedWeights.IMAGENET1K_FBGEMM_V2      |        nan |     86.645
ResNeXt101_64X4D_QuantizedWeights.IMAGENET1K_FBGEMM_V1      |        nan |     81.556
ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1    |        nan |      1.501
ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1    |        nan |      2.334
ShuffleNet_V2_X1_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1    |        nan |      3.672
ShuffleNet_V2_X2_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1    |        nan |      7.467
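For quick one-off checks, the counter class above can also be used on its own, e.g. (a minimal sketch; the 256x256 random input and the ResNet-50 choice are arbitrary):

weights = models.ResNet50_Weights.IMAGENET1K_V2
model = models.resnet50(weights=weights).eval()
inp = weights.transforms()(torch.randn(1, 3, 256, 256))  # preset resizes/crops to the expected input size

counter = FlopCounterMode(model)
with counter:
    model(inp)
print(counter.get_flops())  # ~4.089 GFLOPs for this model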

cc @datumbox

@datumbox (Contributor) left a comment:

Thanks for the contribution @toni057! Overall it looks like a great addition.

I've added a few comments. Let me know what you think.

(inline review comments on docs/source/conf.py, test/test_extended_models.py, and torchvision/models/vision_transformer.py, now resolved)
@datumbox (Contributor) left a comment:

A few mistakes:

(inline review comments on torchvision/models/regnet.py and torchvision/models/resnet.py, now resolved)
@@ -145,6 +147,8 @@ class VGG11_BN_Weights(WeightsEnum):
"acc@5": 89.810,
}
},
"_flops": 7.609090,
Contributor:

I was expecting this to be higher for the BN version.

@toni057 (Contributor Author):

Seems to be the same up to a decimal, and for all vgg models.

VGG11_Weights.IMAGENET1K_V1                           |    7609090048 | 506.840077
VGG11_BN_Weights.IMAGENET1K_V1                        |    7609090048 | 506.881400
VGG13_Weights.IMAGENET1K_V1                           |   11308466176 | 507.545068
VGG13_BN_Weights.IMAGENET1K_V1                        |   11308466176 | 507.589627
VGG16_Weights.IMAGENET1K_V1                           |   15470264320 | 527.795678
VGG16_Weights.IMAGENET1K_FEATURES                     |   15470264320 | 527.801824
VGG16_BN_Weights.IMAGENET1K_V1                        |   15470264320 | 527.866207
VGG19_Weights.IMAGENET1K_V1                           |   19632062464 | 548.051225
VGG19_BN_Weights.IMAGENET1K_V1                        |   19632062464 | 548.142819

Contributor:

Yeap, it's because the FLOPs script doesn't estimate all ops. It covers mostly convs, linear layers, etc., which usually account for 99% of the computation. Ops like aten.mean/aten.mul used in BN layers aren't factored in. Maybe in the future we can have a better, more precise utility to estimate the FLOPs, but for now this will do. We reduced the precision to 3 decimal points, as more precision for the total giga-flops/ops doesn't make much sense.
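For reference, if someone wanted to roughly account for those elementwise ops too, the mapping in the script above could be extended with a handler along these lines (a sketch only; the handler name, the one-flop-per-output-element assumption, and the choice of ops mentioned above are illustrative, not part of the PR):

def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Very rough count for elementwise ops: assume one flop per element
    of the first output tensor.
    """
    return prod(get_shape(outputs[0]))


# hypothetical extension of the flop_mapping defined in the script above
flop_mapping.update({
    aten.add: elementwise_flop,
    aten.mul: elementwise_flop,
    aten.mean: elementwise_flop,
})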

@datumbox (Contributor) left a comment:

LGTM, thanks @toni057. I think we can merge on green CI. Let's follow up with adding some tests to confirm the numbers automatically for new models.

@datumbox (Contributor):

@toni057 I noticed that the names on the model builder page could be improved. Here is what we show now:
[screenshot of the metadata table on the model builder page]

Perhaps we can replace _flops with _flops (in GFLOPs) and _weight_size with _weight_size (Weights file size in MB), or something similar, to explain clearly what each one is? It should be a straightforward change in the inject_weight_metadata() method:

vision/docs/source/conf.py

Lines 356 to 365 in d72e906

for k, v in meta.items():
    if k in {"recipe", "license"}:
        v = f"`link <{v}>`__"
    elif k == "min_size":
        v = f"height={v[0]}, width={v[1]}"
    elif k in {"categories", "keypoint_names"} and isinstance(v, list):
        max_visible = 3
        v_sample = ", ".join(v[:max_visible])
        v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
    table.append((str(k), str(v)))
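For illustration, a rough sketch of what that could look like as two extra branches in the loop quoted above (the display strings are just the ones suggested in this comment, not final wording):

for k, v in meta.items():
    if k in {"recipe", "license"}:
        v = f"`link <{v}>`__"
    elif k == "min_size":
        v = f"height={v[0]}, width={v[1]}"
    elif k == "_flops":
        # show the unit next to the private key
        k = "_flops (in GFLOPs)"
    elif k == "_weight_size":
        k = "_weight_size (Weights file size in MB)"
    elif k in {"categories", "keypoint_names"} and isinstance(v, list):
        max_visible = 3
        v_sample = ", ".join(v[:max_visible])
        v = f"{v_sample}, ... ({len(v)-max_visible} omitted)" if len(v) > max_visible else v_sample
    table.append((str(k), str(v)))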

@datumbox (Contributor):

@toni057 After speaking with @Chillee, I've updated your script to estimate an approximation of the operations in quantized models. I've modified the code in the PR description in place. Let's chat tomorrow to see if we want to add an adjusted number of ops for the quantized models as well.

@datumbox (Contributor) left a comment:

Nits as per our discussion:

(inline review comments on .gitignore and docs/source/conf.py, now resolved)
@datumbox (Contributor):

@pmeier FYI flaky test:

=========================== short test summary info ============================
FAILED test/test_prototype_transforms_functional.py::TestKernels::test_against_reference[rotate_image_tensor-38] - AssertionError: The 'mean' of the absolute difference is 108.33549783549783, but only 100.0 is allowed.
Failure happened for the following parameters:

'angle'

(further inline review comments on docs/source/conf.py, now resolved)
@datumbox (Contributor):

@toni057 Thank you very much for your contribution. I'll merge on green CI. :)

@datumbox merged commit deba056 into pytorch:main on Nov 11, 2022
@oke-aditya (Contributor) commented Nov 11, 2022:

Aah, I just finished doing the Video Swin Transformer, and I would need to estimate GFLOPs 😛

Can we somehow keep this script handy? (Maybe as a gist or in the wiki pages?)

cc @datumbox @toni057

@datumbox (Contributor):

@oke-aditya We'll follow up with @toni057 on adding the tests, so we will definitely add the util somewhere.

@oke-aditya (Contributor):

Is weight_size in MB? I feel it would be better in GB, considering how big models have gotten these days.

@NicolasHug (Member):

Thanks for this

I noticed some of the doc renders like this:
[screenshot of the rendered metadata table in the docs]

Is the leading underscore intended here?

Also regarding the weight size, does the file size accurately describe the weight size, or is there some compression going on? (IIRC all of our .pth files are actually zip files, but I don't know the compression level)

@oke-aditya (Contributor):

Yeah, good point. A .pth file is a zip file. Does the weight size refer to the compressed size or not?
When we load a .pth model, does it get uncompressed when loading into RAM? In that case the model size on disk is not the same as the size it will take in RAM.
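For what it's worth, a quick way to sanity-check the two numbers locally would be something like the sketch below (assuming the default hub cache location; the model choice is arbitrary and this is not part of the PR):

import os
import torch
from torchvision.models import ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
state_dict = weights.get_state_dict(progress=True)  # downloads into the hub cache if needed

# in-memory size: sum of the tensor element sizes in the state dict
in_memory_mb = sum(t.numel() * t.element_size() for t in state_dict.values()) / 1024 / 1024

# on-disk size of the cached checkpoint (.pth checkpoints are zip archives)
ckpt_path = os.path.join(torch.hub.get_dir(), "checkpoints", weights.url.split("/")[-1])
on_disk_mb = os.path.getsize(ckpt_path) / 1024 / 1024

print(f"state_dict in memory: {in_memory_mb:.3f} MB | file on disk: {on_disk_mb:.3f} MB")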

@datumbox (Contributor):

The leading underscore is intended and is similar to what we did with _metrics. I don't mind if you want to send a PR and change the naming, but my understanding is that the column names should match the source code. Also note that _ops is an approximate number and can change in the future (if we find a better way to measure it), hence the reason we didn't make it public yet.

Concerning the size, it's the file size of the weights file. See the script in the PR description for exactly how this is estimated.

@NicolasHug (Member) commented Nov 14, 2022:

I'm happy to keep the "_ops" key "private" with a leading underscore. However, I'm not sure this underscore should be reflected in the doc rendering. I could definitely be wrong, but from my understanding these tables aren't intended to show which keys are available and which aren't; they're purely informational and decoupled from the underlying implementation. For example, there would be no issue rendering "number of parameters" in the docs, even though the key is still "num_params"?

Perhaps the original intent was to keep the keys and the doc rendering similar (I don't remember)? If that's the case, then we can still address both our concerns by rendering "_ops" as "FLOPS" and "_weight_size" as "Weights size": these hide the implementation detail (underscore) and make these private keys non-obvious from the docs, so there's little risk of users relying on them.

Concerning the size, it's the file size of the weights file

Got it - my question was: is the file size the same as the size of the weights dict (which is what we ultimately want to document)?

@datumbox (Contributor):

Got it - my question was: is the file size the same as the size of the weights dict (which is what we ultimately want to document)?

That's what it is. We don't store other information in those published weights (optimizer, EMA weights etc) other than the weights and any crucial meta-data (like version of layers) needed for loading them.

Perhaps the original intent was to keep the keys and the doc rendering similar (I don't remember)?

My understanding is that the original intent was for these to have the same keys, so that people know how to fetch them programmatically. If you change their names, then for the public ones it becomes unclear how to access them. Given that this one is private though, we can do whatever we want for now. If you want to bring a PR, I'm happy to review it.

rendering "_ops" as "FLOPS"

I would use operations per second instead of FLOPs. The latter is not accurate for quantized models.
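For context, this is the kind of programmatic access being discussed (a small example; the values are the ones from the table in the PR description):

from torchvision.models import ResNet50_Weights

meta = ResNet50_Weights.IMAGENET1K_V2.meta
print(meta["_ops"])          # 4.089 (approximate giga-operations)
print(meta["_weight_size"])  # 97.79 (checkpoint file size in MB)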

@NicolasHug (Member):

We don't store other information in those published weights (optimizer, EMA weights etc) other than the weights and any crucial meta-data (like version of layers) needed for loading them.

Are these files compressed in any way? Those .pth weight files are .zip files under the hood, hence my original question.

If you want to bring a PR, I'm happy to review it.

Historically in this repo, review comments that aren't nits are addressed by the PR author / advocate rather than by the reviewer. If you don't mind, I'd prefer not to have to open a PR for this.

facebook-github-bot pushed a commit that referenced this pull request Nov 14, 2022
Summary:
* Adding FLOPs and size to model metadata

* Adding weight size to quantization models

* Small refactor of rich metadata

* Removing unused code

* Fixing wrong entries

* Adding .DS_Store to gitignore

* Renaming _flops to _ops

* Adding number of operations to quantization models

* Reflecting _flops change to _ops

* Renamed ops and weight size in individual model doc pages

* Linter fixes

* Rounding ops to first decimal

* Rounding num ops and sizes to 3 decimals

* Change naming of columns.

* Update tables

Reviewed By: NicolasHug

Differential Revision: D41265180

fbshipit-source-id: e6f8629ba3f2177411716113430b87c1710982c0

Co-authored-by: Toni Blaslov <tblaslov@fb.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
@oke-aditya (Contributor):

One more thing to point out is that the above script runs on CPU. Could we possibly run it on GPU?
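In case it helps, an untested sketch of the change that would be needed inside get_flops; the dispatch-based counting itself should be device-agnostic, but whether every model in the loop runs cleanly on GPU hasn't been checked here:

device = "cuda" if torch.cuda.is_available() else "cpu"

mod = models.get_model(model_name, weights=weight, **kwargs).to(device)
mod.eval()

# move whatever the preprocessing produced onto the same device before calling mod
inp = [t.to(device) if isinstance(t, torch.Tensor) else t for t in inp]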
