From c92562f674d5a2de5cc6a50fc7dbb54f1038f4cb Mon Sep 17 00:00:00 2001 From: ankit-amazon <125257518+ankit-amazon@users.noreply.github.com> Date: Tue, 9 May 2023 23:16:09 +0530 Subject: [PATCH 1/2] Update pytorch-disable-gradient-calculation.py --- .../pytorch-disable-gradient-calculation.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py b/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py index 8dc4da0..3b6d54e 100644 --- a/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py +++ b/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py @@ -7,6 +7,14 @@ def disable_gradient_calculation_noncompliant(): # Noncompliant: disables gradient calculation using `torch.no_grad()`. with torch.no_grad(): model.eval() + # some code + + +def disable_gradient_calculation_noncompliant(): + import torch + # Noncompliant: gradient calculation not disabled during evaluation. + model.eval() + # some code # {/fact} @@ -16,4 +24,5 @@ def disable_gradient_calculation_compliant(): # Compliant: disables gradient calculation using `torch.inference_mode()`. with torch.inference_mode(): model.eval() + # some code # {/fact} From 95ed1e4fe9f3b82eb4b2eeb3037c4f5345efa626 Mon Sep 17 00:00:00 2001 From: ankit-amazon <125257518+ankit-amazon@users.noreply.github.com> Date: Tue, 9 May 2023 23:21:47 +0530 Subject: [PATCH 2/2] Update pytorch-disable-gradient-calculation.py --- .../pytorch-disable-gradient-calculation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py b/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py index 3b6d54e..f967575 100644 --- a/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py +++ b/src/python/detectors/pytorch-disable-gradient-calculation/pytorch-disable-gradient-calculation.py @@ -8,8 +8,8 @@ def disable_gradient_calculation_noncompliant(): with torch.no_grad(): model.eval() # some code - - + + def disable_gradient_calculation_noncompliant(): import torch # Noncompliant: gradient calculation not disabled during evaluation.