diff --git a/src/python/detectors/pytorch_miss_call_to_zero_grad/pytorch_miss_call_to_zero_grad.py b/src/python/detectors/pytorch_miss_call_to_zero_grad/pytorch_miss_call_to_zero_grad.py index e96e0a6..e6e19fd 100644 --- a/src/python/detectors/pytorch_miss_call_to_zero_grad/pytorch_miss_call_to_zero_grad.py +++ b/src/python/detectors/pytorch_miss_call_to_zero_grad/pytorch_miss_call_to_zero_grad.py @@ -18,22 +18,6 @@ def pytorch_miss_call_to_zero_grad_noncompliant( # zero before doing a backward pass. loss.backward() optimizer.step() - - avg_loss += loss.item() - # train_error += torch.sum((output > 0) != label) - true_pos += torch.sum((output >= 0).float() * label) - false_pos += torch.sum((output >= 0).float() * (1.0 - label)) - true_neg += torch.sum((output < 0).float() * (1.0 - label)) - false_neg += torch.sum((output < 0).float() * label) - - print(f'\rEpoch {i_epoch},\ - Training {i_batch+1:3d}/{len(dataloader):3d} batch, ' - f'loss {loss.item():0.6f} ', end='') - - avg_loss /= len(dataloader) - tpr = float(true_pos) / float(true_pos + false_neg) - fpr = float(false_pos) / float(false_pos + true_neg) - return avg_loss, tpr, fpr # {/fact} @@ -54,20 +38,4 @@ def pytorch_miss_call_to_zero_grad_compliant( optimizer.zero_grad() loss.backward() optimizer.step() - - avg_loss += loss.item() - # train_error += torch.sum((output > 0) != label) - true_pos += torch.sum((output >= 0).float() * label) - false_pos += torch.sum((output >= 0).float() * (1.0 - label)) - true_neg += torch.sum((output < 0).float() * (1.0 - label)) - false_neg += torch.sum((output < 0).float() * label) - - print(f'\rEpoch {i_epoch},\ - Training {i_batch+1:3d}/{len(dataloader):3d} batch, ' - f'loss {loss.item():0.6f} ', end='') - - avg_loss /= len(dataloader) - tpr = float(true_pos) / float(true_pos + false_neg) - fpr = float(false_pos) / float(false_pos + true_neg) - return avg_loss, tpr, fpr # {/fact}