
restore hidden states added #306

Merged
dapatil211 merged 20 commits into dev from rnn_features on Jun 1, 2023
Conversation

hnekoeiq (Collaborator)

No description provided.

hnekoeiq requested a review from dapatil211 on November 10, 2022
hnekoeiq requested a review from kshitijkg on November 29, 2022
hnekoeiq (Collaborator, Author) commented Dec 1, 2022

This PR adds two features to the DRQN agent (a rough sketch of both follows below):
1. Restoring hidden states from the replay buffer.
2. Burn-in frames to warm up the RNN module.
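
For context, a minimal sketch of how the two features combine in a recurrent Q-learning update. Everything here is illustrative; the names q_values_with_burn_in, stored_hidden, and burn_frames are not the PR's actual API:

import torch
import torch.nn as nn

def q_values_with_burn_in(rnn, obs_seq, stored_hidden, burn_frames):
    # obs_seq: (seq_len, batch, feat); stored_hidden: the (h0, c0) pair that
    # was saved to the replay buffer when the sequence was collected.
    with torch.no_grad():
        # Burn-in: replay the first burn_frames steps only to refresh the
        # hidden state, without contributing gradients.
        _, hidden = rnn(obs_seq[:burn_frames], stored_hidden)
    # Train on the remaining steps, starting from the warmed-up hidden state.
    out, _ = rnn(obs_seq[burn_frames:], hidden)
    return out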

Comment on lines 204 to 218
if self._store_hidden == True:
if self._rnn_type == "lstm":
preprocessed_update_info.update(
{
"hidden_state": self._prev_hidden_state,
"cell_state": self._prev_cell_state,
}
)

elif self._rnn_type == "gru":
preprocessed_update_info.update(
{
"hidden_state": self._prev_hidden_state,
}
)
kshitijkg (Collaborator)

This works too, but it might be easier to adapt this code to other memory-based architectures if we instead store a generic "memory" that can contain any kind of recurrent state (hidden, hidden+cell, TrXL memory, etc.) and have a function that can pack and unpack this memory. That way, the user only has to modify the pack/unpack function when adding a new memory-based architecture.
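
A rough sketch of the suggested interface (hypothetical, just to illustrate the idea; pack_memory and unpack_memory are not functions in the repository):

def pack_memory(rnn_type, hidden_state, cell_state=None):
    # Pack architecture-specific recurrent state into one opaque dict that the
    # agent can store in the replay buffer without knowing its contents.
    if rnn_type == "lstm":
        return {"hidden_state": hidden_state, "cell_state": cell_state}
    elif rnn_type == "gru":
        return {"hidden_state": hidden_state}
    raise ValueError(f"Unsupported rnn_type: {rnn_type}")

def unpack_memory(rnn_type, memory):
    # Inverse of pack_memory: recover the state in the form the module expects
    # ((h, c) for LSTM, h alone for GRU).
    if rnn_type == "lstm":
        return memory["hidden_state"], memory["cell_state"]
    elif rnn_type == "gru":
        return memory["hidden_state"]
    raise ValueError(f"Unsupported rnn_type: {rnn_type}")

With this, the branch above collapses to a single call, and supporting a new memory-based architecture only means extending the two functions:

preprocessed_update_info.update(
    pack_memory(self._rnn_type, self._prev_hidden_state, self._prev_cell_state)
)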

hnekoeiq (Collaborator, Author) commented Jan 6, 2023

@kshitijkg With the most recent commit, the update and act functions are almost independent of the type of memory (I'm still calling it hidden_state because it is a DRQN agent, but we can replace everything with memory). Could you please review it again?

hnekoeiq requested a review from mrsamsami on February 21, 2023
mask[self._burn_frames :] = 1.0
mask = mask.view(1, -1)
interm_loss *= mask
loss = interm_loss.mean()


Isn't the correct way of doing it interm_loss.sum() / mask.sum()?

hnekoeiq (Collaborator, Author)

That's true. This should be fixed after merging #339.
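
For reference, a standalone sketch of the difference between the two normalizations (toy shapes and values, not the repository's code):

import torch

seq_len, burn_frames = 10, 4
interm_loss = torch.rand(1, seq_len)  # per-step losses, illustrative values

mask = torch.zeros(seq_len)
mask[burn_frames:] = 1.0
mask = mask.view(1, -1)
interm_loss = interm_loss * mask

# .mean() divides by seq_len, so the loss scale shrinks as burn_frames grows:
biased = interm_loss.mean()

# Dividing by mask.sum() averages only over the steps that were kept:
corrected = interm_loss.sum() / mask.sum()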

if self._store_hidden == True:
hidden_state = (
torch.tensor(
batch["hidden_state"][:, 0].squeeze(1).squeeze(1).unsqueeze(0),


I'd suggest using view() or reshape() to potentially make the code cleaner.
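
For instance, assuming the stored hidden states have shape (batch, 1, 1, hidden_size) (an assumption for illustration; the real layout may differ), the chain of squeeze/unsqueeze calls is equivalent to a single reshape:

import torch

batch_size, hidden_size = 32, 256
stored = torch.rand(batch_size, 1, 1, hidden_size)  # assumed layout

h1 = stored.squeeze(1).squeeze(1).unsqueeze(0)  # chained version
# Single reshape to the (num_layers, batch, hidden) layout nn.LSTM expects:
h2 = stored.reshape(1, batch_size, hidden_size)

assert torch.equal(h1, h2)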

hnekoeiq and others added 15 commits May 17, 2023 18:07
* first version. testing.

* pylint

* pylint

* pylint

* pylint

* pylint

* test fix

* pylint

* pylint

* resolve discussion

---------

Co-authored-by: artem.zholus <artem.zholus@login-3.server.mila.quebec>
Co-authored-by: Darshan Patil <dapatil211@gmail.com>
* Added option to initialize separate components differently

* Made minor fixes

* Fixed init and registration

* Fix term trunc (#336) (#341)

* Fixed issues with moving from done to terminated, truncated

* Undo change to logging scales

* Revert changes to this file, SAC agents don't exist yet

* Clean up test file

Co-authored-by: Darshan Patil <dapatil211@gmail.com>

---------

Co-authored-by: Darshan Patil <dapatil211@gmail.com>
dapatil211 merged commit 09a21d2 into dev on Jun 1, 2023
dapatil211 deleted the rnn_features branch on June 1, 2023 16:49