diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py
index 2b88b6bdad6..5c3f55b7a0a 100644
--- a/parlai/core/torch_agent.py
+++ b/parlai/core/torch_agent.py
@@ -983,7 +983,6 @@ def init_optim(
             boolean indicating whether the optimizer failed to initialize with
             optim_states.
         """
-
         if hasattr(self, 'resized_embeddings') and self.resized_embeddings:
             optim_states = None
             logging.warning('Not loading optimizer due to resize in token embeddings')
@@ -2145,6 +2144,7 @@ def act(self):
         """
         # BatchWorld handles calling self_observe, but we're in a Hogwild or Interactive
         # world, so we need to handle this ourselves.
+
         response = self.batch_act([self.observation])[0]
         self.self_observe(response)
         return response
@@ -2180,7 +2180,7 @@ def batch_act(self, observations):
 
         self.is_training = batch.is_training
 
-        # truncation statistics 
+        # truncation statistics
         if batch._context_original_length is not None:
             self.record_local_metric(
                 'clen', AverageMetric.many(batch._context_original_length)
@@ -2251,7 +2251,7 @@ def batch_act(self, observations):
             for k, values in self._local_metrics.items():
                 if len(values) != len(batch.valid_indices):
                     raise IndexError(
-                        f"Batchsize mismatch on metric {k} got {len(values)}, "
-                        f"expected {len(batch.valid_indices)}"
+                        f"Batchsize mismatch on metric {k} (got {len(values)}, "
+                        f"expected {len(batch.valid_indices)})"
                     )
                 for i, value in zip(batch.valid_indices, values):
@@ -2296,7 +2296,7 @@ def set_interactive_mode(self, mode, shared):
             # Only print in the non-shared version.
             logging.info(f'{self.id}: full interactive mode on.')
 
-    def backward(self, loss):
+    def backward(self, loss, **kwargs):
         """
         Perform a backward pass.
 
@@ -2316,15 +2316,15 @@ def backward(self, loss):
             # accumulate without syncing
             with self.model.no_sync():
                 if self.fp16:
-                    self.optimizer.backward(loss, update_main_grads=False)
+                    self.optimizer.backward(loss, update_main_grads=False, **kwargs)
                 else:
-                    loss.backward()
+                    loss.backward(**kwargs)
                 return
 
         if self.fp16:
-            self.optimizer.backward(loss, update_main_grads=False)
+            self.optimizer.backward(loss, update_main_grads=False, **kwargs)
         else:
-            loss.backward()
+            loss.backward(**kwargs)
 
     def update_params(self):
         """
diff --git a/parlai/utils/fp16.py b/parlai/utils/fp16.py
index 157a57042ef..bc888ed9b55 100644
--- a/parlai/utils/fp16.py
+++ b/parlai/utils/fp16.py
@@ -182,7 +182,7 @@ def load_state_dict(self, state_dict):
             self.scaler.loss_scale = state_dict['loss_scaler']
         self.optimizer.load_state_dict(state_dict)
 
-    def backward(self, loss, update_main_grads=False):
+    def backward(self, loss, update_main_grads=False, retain_graph=False):
         """
         Computes the sum of gradients of the given tensor w.r.t. graph leaves.
 
@@ -191,7 +191,7 @@ def backward(self, loss, update_main_grads=False):
         """
         if self.scaler is not None:
             loss = loss * self.scaler.loss_scale
-        loss.backward()
+        loss.backward(retain_graph=retain_graph)
         self._needs_sync = True
         if update_main_grads:
             self.update_main_grads()