Update to barracuda 1.3.3 and changes to the model inputs and outputs for LSTM #5236
Changes from 8 commits: 6f12ff7, 9e5f864, a541af9, e24d0c9, 9e19e80, ca70a0b, f7919e8, bfedf6a, 65c19a7, 908810c
@@ -2,6 +2,7 @@
 import abc
 from typing import Tuple
 from enum import Enum
+from mlagents.trainers.torch.model_serialization import exporting_to_onnx


 class Swish(torch.nn.Module):
@@ -206,7 +207,19 @@ def forward(
         # We don't use torch.split here since it is not supported by Barracuda
         h0 = memories[:, :, : self.hidden_size].contiguous()
         c0 = memories[:, :, self.hidden_size :].contiguous()

+        if exporting_to_onnx.is_exporting():
+            # This transpose is needed both at input and output of the LSTM when
+            # exporting because ONNX will expect (sequence_len, batch, memory_size)
+            # instead of (batch, sequence_len, memory_size)
+            h0 = torch.transpose(h0, 0, 1)
+            c0 = torch.transpose(c0, 0, 1)

         hidden = (h0, c0)
         lstm_out, hidden_out = self.lstm(input_tensor, hidden)
         output_mem = torch.cat(hidden_out, dim=-1)

+        if exporting_to_onnx.is_exporting():
+            output_mem = torch.transpose(output_mem, 0, 1)

         return lstm_out, output_mem

Review comment (on the torch.split comment): Is the comment above about …
Reply: Yes, this corresponds to a slice operator, not a split.
Reply: I actually think …
Review comment (on the h0 transpose): I think we should transpose it before the split into …
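The axis swap in the hunk above can be sketched without torch: the recurrent memories arrive as one (batch, 1, 2 * hidden_size) block, get sliced (not split) into the h0/c0 halves, and, when exporting, have their first two axes swapped because ONNX expects (sequence_len, batch, memory_size) rather than (batch, sequence_len, memory_size). A minimal pure-Python sketch; `transpose01` and the nested-list "tensors" are illustrative stand-ins, not the mlagents API:

```python
def transpose01(t):
    # Swap the first two axes of a nested-list "tensor",
    # e.g. (batch, seq, feat) -> (seq, batch, feat).
    return [list(group) for group in zip(*t)]

hidden_size = 2
# memories: shape (batch=2, 1, 2 * hidden_size); first half is h0, second c0.
memories = [[[1, 2, 3, 4]],
            [[5, 6, 7, 8]]]

# Slice into the two recurrent states, as in the diff above.
h0 = [[row[:hidden_size] for row in b] for b in memories]
c0 = [[row[hidden_size:] for row in b] for b in memories]

# When exporting, swap axes 0 and 1 so ONNX sees (1, batch, hidden_size).
h0_onnx = transpose01(h0)  # [[[1, 2], [5, 6]]]
c0_onnx = transpose01(c0)  # [[[3, 4], [7, 8]]]
```

Applying `transpose01` again undoes the swap, which is why the same call works at both the input and the output of the LSTM.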
@@ -546,7 +546,7 @@ def forward(

 class SimpleActor(nn.Module, Actor):
-    MODEL_EXPORT_VERSION = 3
+    MODEL_EXPORT_VERSION = 3  # Corresponds to ModelApiVersion.MLAgents2_0

     def __init__(
         self,

@@ -643,6 +643,7 @@ def forward(
         At this moment, torch.onnx.export() doesn't accept None as tensor to be exported,
         so the size of return tuple varies with action spec.
         """
+
         encoding, memories_out = self.network_body(
             inputs, memories=memories, sequence_length=1
         )

Review comment (on the added blank line): Extra line?
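The docstring's point about torch.onnx.export() not accepting None tensors can be illustrated with a small helper that assembles only the outputs a given action spec actually produces; `build_export_outputs` and its argument names are hypothetical, for illustration only, not the mlagents code:

```python
def build_export_outputs(continuous_action=None, discrete_action=None,
                         memories_out=None):
    # torch.onnx.export() won't export a None tensor, so the returned
    # tuple contains only the outputs that exist, and its size therefore
    # varies with the action spec.
    candidates = (continuous_action, discrete_action, memories_out)
    return tuple(t for t in candidates if t is not None)

# A continuous-only policy without memories exports a 1-tuple:
build_export_outputs(continuous_action="cont_act")  # -> ("cont_act",)
```

This is why the exported model's output count differs between continuous, discrete, and hybrid action specs.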
Review comment: Maybe add something about LSTM as well.