diff --git a/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py b/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py index 8dc4da0..f967575 100644 --- a/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py +++ b/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py @@ -7,6 +7,14 @@ def disable_gradient_calculation_noncompliant(): # Noncompliant: disables gradient calculation using `torch.no_grad()`. with torch.no_grad(): model.eval() + # some code + + +def disable_gradient_calculation_noncompliant(): + import torch + # Noncompliant: gradient calculation not disabled during evaluation. + model.eval() + # some code # {/fact} @@ -16,4 +24,5 @@ def disable_gradient_calculation_compliant(): # Compliant: disables gradient calculation using `torch.inference_mode()`. with torch.inference_mode(): model.eval() + # some code # {/fact}