
Adding _ops and _weight_size metadata checks to tests #6996

Merged: 9 commits into pytorch:main on Dec 1, 2022

Conversation

@toni057 (Contributor) commented on Nov 30, 2022:

Continuing from PR #6936, where the number of operations and the weight sizes were added to the model metadata, this PR adds the logic for calculating that metadata in the tests and verifies that the values hardcoded in the metadata correspond to the values computed for the weights.

Due to the relatively long run times, we are limiting the checks to default weights only.
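At a high level, each check instantiates a model from its weights and compares computed values against the hardcoded metadata. A rough, self-contained illustration of the idea (not the exact test code):

import torchvision.models as models

w = models.ResNet50_Weights.DEFAULT
model = models.resnet50(weights=w)

# num_params was already verified before this PR; _ops and _weight_size are the new fields
assert w.meta["num_params"] == sum(p.numel() for p in model.parameters())
assert "_ops" in w.meta and "_weight_size" in w.meta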

cc @datumbox @pmeier

@datumbox (Contributor) left a comment:

Thanks for the work @toni057. Just a few comments:

Comment on lines 255 to 268
detection_models_input_dims = {
"fasterrcnn_mobilenet_v3_large_320_fpn": (320, 320),
"fasterrcnn_mobilenet_v3_large_fpn": (800, 800),
"fasterrcnn_resnet50_fpn": (800, 800),
"fasterrcnn_resnet50_fpn_v2": (800, 800),
"fcos_resnet50_fpn": (800, 800),
"keypointrcnn_resnet50_fpn": (1333, 1333),
"maskrcnn_resnet50_fpn": (800, 800),
"maskrcnn_resnet50_fpn_v2": (800, 800),
"retinanet_resnet50_fpn": (800, 800),
"retinanet_resnet50_fpn_v2": (800, 800),
"ssd300_vgg16": (300, 300),
"ssdlite320_mobilenet_v3_large": (320, 320),
}
datumbox (Contributor):

Nit: I think this doesn't belong in the common_extended_utils.py file but rather in the test/test_extended_models.py file.

toni057 (Contributor, Author):

Sure, can move it there.

else:
    if w.meta.get("num_params") != sum(p.numel() for p in model_fn(weights=w).parameters()):
        incorrect_params.append(w)

    calculated_ops = get_ops(module_name, model_name, w)
datumbox (Contributor):

We need to review this logic. As written, we initialize the model multiple times: once in the model_fn call above and once within get_ops(). What we can do instead is initialize the model once and then use it in both places.
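A minimal sketch of the suggested shape, assuming get_ops is updated to accept the instantiated model (see the discussion further down):

# instantiate once and reuse it for both the parameter count and the ops calculation
model = model_fn(weights=w)

if w.meta.get("num_params") != sum(p.numel() for p in model.parameters()):
    incorrect_params.append(w)

calculated_ops = get_ops(model, weight=w)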

if module_name == "quantization":
    # parameters() count doesn't work well with quantization, so we check against the non-quantized
    unquantized_w = w.meta.get("unquantized")
    if unquantized_w is not None and w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
        incorrect_params.append(w)

    # the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs instead
    calculated_ops = get_ops(model=None, module_name="models", model_name=model_name, weight=unquantized_w)
datumbox (Contributor):

We don't have to do this estimation. We can follow the same approach as with num_params. More precisely: we fetch unquantized_w.meta.get("_ops") and confirm that it matches what we have here. Basically we reproduce the logic on lines 219-220.
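For illustration, that check could look like this, mirroring the num_params comparison above:

unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None and w.meta.get("_ops") != unquantized_w.meta.get("_ops"):
    incorrect_params.append(w)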

return sum(self.flop_counts["Global"].values()) / 1e9


def get_ops(model: torch.nn.Module, module_name: str, model_name: str, weight: Weights, h=512, w=512):
datumbox (Contributor):

Let's assume here that model is not None. Then we don't need the model_name parameter. The module_name is also unnecessary as it can be fetched from the model. More specifically:

>>> m = resnet50()
>>> m.__module__
'torchvision.models.resnet'
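A sketch of the simplified signature this implies (the module_name derivation below is illustrative, not the final code):

def get_ops(model: torch.nn.Module, weight: Weights, h=512, w=512):
    # e.g. "torchvision.models.resnet" -> "models",
    # "torchvision.models.quantization.resnet" -> "quantization"
    module_name = model.__module__.split(".")[-2]
    ...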

Comment on lines 278 to 280
if model is None:
    kwargs = {"quantize": True} if module_name == "quantization" else {}
    model = models.get_model(model_name, weights=weight, **kwargs)
datumbox (Contributor):

This can go away:

Suggested change
- if model is None:
-     kwargs = {"quantize": True} if module_name == "quantization" else {}
-     model = models.get_model(model_name, weights=weight, **kwargs)

Comment on lines 226 to 227
# loading the model and using it for parameter and ops verification
kwargs = {"quantize": True} if module_name == "quantization" else {}
datumbox (Contributor):

Not necessary. We already checked it's not quantization above.

)

# assert that weight flops are correctly pasted to metadata
assert calculated_ops == w.meta["_ops"]
datumbox (Contributor):

We shouldn't assert like this because it will immediately fail the test without showing us the other issues. Instead we should collect all issues in one list and show them to the user. Previously we had incorrect_params, which monitored issues with the number of parameters. Now that we have more checks, it's worth switching this into something like incorrect_meta and appending to it not only the weight but also the name of the meta field that failed. For example: incorrect_meta.append((w, "num_params")).
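A sketch of the suggested pattern:

# inside the loop: record the failing weight together with the meta key
if calculated_ops != w.meta["_ops"]:
    incorrect_meta.append((w, "_ops"))

# after the loop: a single assert surfaces every mismatch at once
assert not incorrect_meta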

assert not problematic_weights
assert not incorrect_params
assert not bad_names
assert weight_size_mb == w.meta["_weight_size"]
datumbox (Contributor):

Similar to the above. This needs to be asserted properly for all weights. You can use the proposed incorrect_meta to track it as well.

@datumbox (Contributor) left a comment:

Thanks a lot @toni057. Looks great. The comment below is optional.

Let's wait for the tests to see whether there is any randomness, otherwise we should be good.

incorrect_meta.append((w, "num_params"))

# the methodology for quantized ops count doesn't work as well, so we take unquantized FLOPs instead
if unquantized_w is not None:
datumbox (Contributor):

Minor Nit: Since this check is needed for both num_params and _ops we can perhaps do it once for both and simplify the code?
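For illustration, the combined version could look something like this:

unquantized_w = w.meta.get("unquantized")
if unquantized_w is not None:
    if w.meta.get("num_params") != unquantized_w.meta.get("num_params"):
        incorrect_meta.append((w, "num_params"))
    if w.meta.get("_ops") != unquantized_w.meta.get("_ops"):
        incorrect_meta.append((w, "_ops"))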

@datumbox (Contributor) left a comment:

LGTM, only one optional Nit below. Your call.

Otherwise we can merge on green CI.

test/test_extended_models.py (outdated review thread, resolved)
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
@datumbox merged commit 790f1cd into pytorch:main on Dec 1, 2022
facebook-github-bot pushed a commit that referenced this pull request on Dec 12, 2022
Summary:
* Adding _ops and _weight_size metadata checks to tests

* Fixing wrong ops value

* Changing test_schema_meta_validation to instantiate the model only once

* Moving instantiation of quantized models inside get_ops

* Small refactor of test_schema_meta_validation logic

* Reverting to previous ops value

* Simplifying unquantized models logic in test_schema_meta_validation

* Update test/test_extended_models.py

Reviewed By: datumbox

Differential Revision: D41836893

fbshipit-source-id: 9174c95ee1843d972898fcd89c3d4e1697e83bca

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
Co-authored-by: Toni Blaslov <tblaslov@fb.com>