restore hidden states added #306
This PR adds two features to the DRQN agent:
hive/agents/drqn.py
Outdated
if self._store_hidden == True:
    if self._rnn_type == "lstm":
        preprocessed_update_info.update(
            {
                "hidden_state": self._prev_hidden_state,
                "cell_state": self._prev_cell_state,
            }
        )

    elif self._rnn_type == "gru":
        preprocessed_update_info.update(
            {
                "hidden_state": self._prev_hidden_state,
            }
        )
This works too, but it might be easier to adapt this code to other memory-based architectures if we instead store a "memory" that can contain any kind of memory (hidden, hidden+cell, TrXL memory, etc.) and have functions that pack and unpack this memory. That way, the user only has to modify the pack/unpack functions when adding a new memory-based architecture.
@kshitijkg With the most recent commit, the update and act functions are almost independent of the type of memory (I'm still calling it hidden_state because it is a DRQN agent, but we can replace everything with memory). Could you please review it again?
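The pack/unpack idea suggested in this thread could look roughly like the sketch below. It is only an illustration, not the code in this PR: the names `MEMORY_CODECS`, `pack_memory`, and `unpack_memory` are hypothetical, and the states are treated as opaque objects so the same shape of code would apply to tensors.

```python
from typing import Any, Dict, Tuple


# Hypothetical per-architecture helpers; none of these names come from the PR.
def pack_lstm(state: Tuple[Any, Any]) -> Dict[str, Any]:
    hidden, cell = state
    return {"hidden_state": hidden, "cell_state": cell}


def unpack_lstm(memory: Dict[str, Any]) -> Tuple[Any, Any]:
    return memory["hidden_state"], memory["cell_state"]


def pack_gru(state: Tuple[Any]) -> Dict[str, Any]:
    (hidden,) = state
    return {"hidden_state": hidden}


def unpack_gru(memory: Dict[str, Any]) -> Tuple[Any]:
    return (memory["hidden_state"],)


# Adding a new memory-based architecture means adding one entry here.
MEMORY_CODECS = {
    "lstm": (pack_lstm, unpack_lstm),
    "gru": (pack_gru, unpack_gru),
}


def pack_memory(rnn_type: str, state: Tuple[Any, ...]) -> Dict[str, Any]:
    """Wrap the recurrent state under a single 'memory' key."""
    pack, _ = MEMORY_CODECS[rnn_type]
    return {"memory": pack(state)}


def unpack_memory(rnn_type: str, update_info: Dict[str, Any]) -> Tuple[Any, ...]:
    """Recover the recurrent state tuple from the stored 'memory' entry."""
    _, unpack = MEMORY_CODECS[rnn_type]
    return unpack(update_info["memory"])
```

With something like this, the agent's update and act paths would only call `pack_memory` and `unpack_memory`, and the `if`/`elif` on the RNN type would live in one place.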
mask[self._burn_frames :] = 1.0
mask = mask.view(1, -1)
interm_loss *= mask
loss = interm_loss.mean()
Isn't the correct way of doing it interm_loss.sum() / mask.sum()?
That's true. This should be fixed after merging #339.
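To make the difference discussed in this thread concrete, here is a small pure-Python sketch comparing the two normalizations. It is not the fix from #339: the function names are hypothetical, and it assumes the mask is broadcast over the batch dimension, so the divisor is the number of unmasked entries across the whole batch.

```python
from typing import List


def _burn_in_mask(seq_len: int, burn_frames: int) -> List[float]:
    """0.0 for the first `burn_frames` steps, 1.0 for the rest."""
    if not 0 <= burn_frames < seq_len:
        raise ValueError("burn_frames must be in [0, seq_len)")
    return [0.0] * burn_frames + [1.0] * (seq_len - burn_frames)


def naive_mean_loss(per_step_loss: List[List[float]], burn_frames: int) -> float:
    """Mask, then divide by ALL entries (what `.mean()` does after masking)."""
    mask = _burn_in_mask(len(per_step_loss[0]), burn_frames)
    total = sum(l * m for row in per_step_loss for l, m in zip(row, mask))
    return total / (len(per_step_loss) * len(mask))


def masked_mean_loss(per_step_loss: List[List[float]], burn_frames: int) -> float:
    """Mask, then divide only by the number of unmasked entries."""
    mask = _burn_in_mask(len(per_step_loss[0]), burn_frames)
    total = sum(l * m for row in per_step_loss for l, m in zip(row, mask))
    kept = len(per_step_loss) * sum(mask)
    return total / kept


losses = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]  # batch of 2, 4 steps each
print(naive_mean_loss(losses, burn_frames=2))   # 2.75
print(masked_mean_loss(losses, burn_frames=2))  # 5.5
```

The naive version shrinks the loss as the number of burn-in frames grows, because the zeroed entries still count in the denominator; the masked version averages only over the steps that actually contribute.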
hive/agents/drqn.py
Outdated
if self._store_hidden == True:
    hidden_state = (
        torch.tensor(
            batch["hidden_state"][:, 0].squeeze(1).squeeze(1).unsqueeze(0),
I'd suggest using view() or reshape() to potentially make the code cleaner.
* first version. testing.
* pylint
* pylint
* pylint
* pylint
* pylint
* test fix
* pylint
* pylint
* resolve discussion
---------
Co-authored-by: artem.zholus <artem.zholus@login-3.server.mila.quebec>
Co-authored-by: Darshan Patil <dapatil211@gmail.com>
* Added option to initialize separate components differently
* Made minor fixes
* Fixed init and registration
* Fix term trunc (#336) (#341)
* Fixed issues with moving from done to terminated, truncated
* Undo change to logging' scales
* Revert changes to this file, SAC agents don't exist yet
* Clean up test file
Co-authored-by: Darshan Patil <dapatil211@gmail.com>
---------
Co-authored-by: Darshan Patil <dapatil211@gmail.com>