Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

expose retain_graph to user within TorchAgent #4720

Merged
merged 12 commits into from
Aug 8, 2022

Conversation

prajjwal1
Copy link
Contributor

@prajjwal1 prajjwal1 commented Aug 6, 2022

Patch description
Although loss.backward() is the norm in usual workflow, there are many instances in which a user might want to pass in retain_graph=True within loss. Considering that parlai doesn't expose the loss api to user, since TorchAgent expects user to call self.backward(), it would be a good addition to let the user pass in retain_graph=True if they want to. User can also pass in create_graph=False, inputs=None params if they want to, to backward().

Testing steps

Other information

@prajjwal1 prajjwal1 requested a review from klshuster August 6, 2022 20:08
Copy link
Contributor

@klshuster klshuster left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@prajjwal1 prajjwal1 merged commit ae53361 into facebookresearch:main Aug 8, 2022
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants