This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

expose retain_graph to user within TorchAgent #4720

Merged 12 commits on Aug 8, 2022
16 changes: 8 additions & 8 deletions parlai/core/torch_agent.py
@@ -983,7 +983,6 @@ def init_optim(
boolean indicating whether the optimizer failed to initialize with
optim_states.
"""

if hasattr(self, 'resized_embeddings') and self.resized_embeddings:
optim_states = None
logging.warning('Not loading optimizer due to resize in token embeddings')
@@ -2145,6 +2144,7 @@ def act(self):
"""
# BatchWorld handles calling self_observe, but we're in a Hogwild or Interactive
# world, so we need to handle this ourselves.

response = self.batch_act([self.observation])[0]
self.self_observe(response)
return response
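In an interactive or Hogwild setting there is no BatchWorld to record the agent's own reply, so act() calls self_observe itself, as the comment above notes. A minimal sketch of that flow, assuming `agent` is an already constructed TorchAgent subclass (construction details omitted):

```python
from parlai.core.message import Message

# Assumes `agent` is an already-built TorchAgent subclass running interactively.
agent.observe(Message({'text': 'Hello there!', 'episode_done': False}))
reply = agent.act()   # runs batch_act([observation])[0], then self_observe(reply)
print(reply.get('text'))
```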
@@ -2180,7 +2180,7 @@ def batch_act(self, observations):

self.is_training = batch.is_training

# truncation statistics
# truncation statistics
if batch._context_original_length is not None:
self.record_local_metric(
'clen', AverageMetric.many(batch._context_original_length)
@@ -2251,7 +2251,7 @@ def batch_act(self, observations):
for k, values in self._local_metrics.items():
if len(values) != len(batch.valid_indices):
raise IndexError(
f"Batchsize mismatch on metric {k} got {len(values)}, "
f"Batchsize mismatch on metric {k} (got {len(values)}, "
f"expected {len(batch.valid_indices)}"
)
for i, value in zip(batch.valid_indices, values):
@@ -2296,7 +2296,7 @@ def set_interactive_mode(self, mode, shared):
# Only print in the non-shared version.
logging.info(f'{self.id}: full interactive mode on.')

def backward(self, loss):
def backward(self, loss, **kwargs):
"""
Perform a backward pass.

@@ -2316,15 +2316,15 @@ def backward(self, loss):
# accumulate without syncing
with self.model.no_sync():
if self.fp16:
self.optimizer.backward(loss, update_main_grads=False)
self.optimizer.backward(loss, update_main_grads=False, **kwargs)
else:
loss.backward()
loss.backward(**kwargs)
return

if self.fp16:
self.optimizer.backward(loss, update_main_grads=False)
self.optimizer.backward(loss, update_main_grads=False, **kwargs)
else:
loss.backward()
loss.backward(**kwargs)

def update_params(self):
"""
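With backward() now forwarding arbitrary keyword arguments, a subclass can keep the autograd graph alive across two backward passes that share a forward computation. A minimal sketch of the intended call pattern, assuming a hypothetical agent and a hypothetical compute_losses helper (neither is part of this PR):

```python
from parlai.core.torch_agent import TorchAgent

class MultiLossAgent(TorchAgent):
    """Hypothetical agent whose two losses share one forward computation."""

    def train_step(self, batch):
        self.zero_grad()
        loss_a, loss_b = self.compute_losses(batch)  # hypothetical helper
        # retain_graph=True is forwarded through backward()'s **kwargs, so the
        # shared graph survives the first pass and can be reused by the second.
        self.backward(loss_a, retain_graph=True)
        self.backward(loss_b)
        self.update_params()
```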
4 changes: 2 additions & 2 deletions parlai/utils/fp16.py
@@ -182,7 +182,7 @@ def load_state_dict(self, state_dict):
self.scaler.loss_scale = state_dict['loss_scaler']
self.optimizer.load_state_dict(state_dict)

def backward(self, loss, update_main_grads=False):
def backward(self, loss, update_main_grads=False, retain_graph=False):
"""
Computes the sum of gradients of the given tensor w.r.t. graph leaves.

@@ -191,7 +191,7 @@ def backward(self, loss, update_main_grads=False):
"""
if self.scaler is not None:
loss = loss * self.scaler.loss_scale
loss.backward()
loss.backward(retain_graph=retain_graph)
self._needs_sync = True
if update_main_grads:
self.update_main_grads()
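The fp16 wrapper simply forwards retain_graph to loss.backward(), so the semantics are plain PyTorch: without retain_graph=True, a second backward pass through the same graph raises a RuntimeError. A small standalone illustration, independent of ParlAI:

```python
import torch

x = torch.randn(4, requires_grad=True)
shared = (x * 2).sum()              # one forward graph shared by two losses
loss_a = shared * 3
loss_b = shared * 5

loss_a.backward(retain_graph=True)  # keep the graph for a second traversal
loss_b.backward()                   # ok: the shared graph was retained
# Without retain_graph=True above, this second call would fail with
# "RuntimeError: Trying to backward through the graph a second time".
```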