Skip to content

Commit

Permalink
Add no_grad to validation steps (NVIDIA#3071)
Browse files Browse the repository at this point in the history
* add no_grad to validation steps

* small edit
  • Loading branch information
holgerroth authored Nov 25, 2024
1 parent 5534af5 commit 0af7099
Showing 1 changed file with 16 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ def _train_step_data_side(self, batch_indices):
def _val_step_data_side(self, batch_indices):
t_start = timer()
self.model.eval()
with torch.no_grad():
inputs = self.valid_dataset.get_batch(batch_indices)
inputs = inputs.to(self.device)

inputs = self.valid_dataset.get_batch(batch_indices)
inputs = inputs.to(self.device)

_val_activations = self.model.forward(inputs) # keep on site-1
_val_activations = self.model.forward(inputs) # keep on site-1

self.compute_stats_pool.record_value(category="_val_step_data_side", value=timer() - t_start)

Expand Down Expand Up @@ -295,23 +295,24 @@ def _train_step_label_side(self, batch_indices, activations, fl_ctx: FLContext):
def _val_step_label_side(self, batch_indices, activations, fl_ctx: FLContext):
t_start = timer()
self.model.eval()
with torch.no_grad():
labels = self.valid_dataset.get_batch(batch_indices)
labels = labels.to(self.device)

labels = self.valid_dataset.get_batch(batch_indices)
labels = labels.to(self.device)
if self.fp16:
activations = activations.type(torch.float32) # return to default pytorch precision

if self.fp16:
activations = activations.type(torch.float32) # return to default pytorch precision
activations = activations.to(self.device)

activations = activations.to(self.device)
pred = self.model.forward(activations)

pred = self.model.forward(activations)
loss = self.criterion(pred, labels)
self.val_loss.append(loss.unsqueeze(0)) # unsqueeze needed for later concatenation
loss = self.criterion(pred, labels)
self.val_loss.append(loss.unsqueeze(0)) # unsqueeze needed for later concatenation

_, pred_labels = torch.max(pred, 1)
_, pred_labels = torch.max(pred, 1)

self.val_pred_labels.extend(pred_labels.unsqueeze(0))
self.val_labels.extend(labels.unsqueeze(0))
self.val_pred_labels.extend(pred_labels.unsqueeze(0))
self.val_labels.extend(labels.unsqueeze(0))

self.compute_stats_pool.record_value(category="_val_step_label_side", value=timer() - t_start)

Expand Down

0 comments on commit 0af7099

Please sign in to comment.