Improve to_torch/to_numpy converters #147
Conversation
Maybe a more clever mechanism could be implemented here: if all elements have the same type and shape (except for the first dim), then they can be converted as a whole into a torch.Tensor/np.ndarray; otherwise each element is handled separately. What do you guys think? @youkaichao @Trinkle23897
I'm ok with this. If they finally get into a Batch, they will be converted to a whole array anyway. If not, it makes sense to handle them separately.
Agree with Kaichao.
I think the point is that the previous code returns one merged tensor for …
Exactly, it is a step forward toward supporting the action space.
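To make the proposal above concrete, here is a minimal sketch of the idea under discussion; the helper name and the exact branching are hypothetical illustrations, not the code merged in this PR. The sequence is first viewed as a regular array when all elements share a type and shape; otherwise each element is converted separately.

```python
from numbers import Number

import numpy as np
import torch


def to_torch_sketch(x, dtype=None, device='cpu'):
    """Hypothetical converter: whole-sequence conversion when possible."""
    if isinstance(x, (list, tuple)):
        try:
            arr = np.asanyarray(x)
        except ValueError:
            arr = None  # ragged shapes: newer numpy raises here
        if arr is not None and arr.dtype != np.object_:
            # all elements share type and shape -> one tensor
            return to_torch_sketch(arr, dtype, device)
        # otherwise handle each element separately
        return [to_torch_sketch(e, dtype, device) for e in x]
    if isinstance(x, (Number, np.number, np.bool_)):
        return to_torch_sketch(np.asanyarray(x), dtype, device)
    if isinstance(x, np.ndarray):
        t = torch.from_numpy(x).to(device)
        return t if dtype is None else t.type(dtype)
    return x
```

With this sketch, `to_torch_sketch([np.zeros(3), np.zeros(3)])` would yield a single 2x3 tensor, while `to_torch_sketch([np.zeros(3), np.zeros(4)])` would fall back to a list of two tensors.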
A bug:

```
In [3]: d = np.zeros([3, 6, 6])

In [4]: to_torch(d[[]])

~/github/tianshou-new/tianshou/data/utils.py in to_torch(x, dtype, device)
     41         x = to_torch(np.asanyarray(x), dtype, device)
     42     elif isinstance(x, np.ndarray) and \
---> 43             isinstance(x.item(0), (np.number, np.bool_, Number)):
     44         x = torch.from_numpy(x).to(device)
     45     if dtype is not None:

IndexError: index 0 is out of bounds for size 0
```

However:

```
In [6]: to_torch(Batch(d=d)[[]])
Out[6]:
Batch(
    d: tensor([], size=(0, 6, 6), dtype=torch.float64),
)
```

So, do not use `x.item(0)` here.
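For reference, one way to avoid touching the first element at all is to inspect the array's dtype, which is well defined even for empty arrays. This is only a sketch of the idea (with a hypothetical helper name), not necessarily the exact check used in the merged code:

```python
import numpy as np
import torch


def has_numeric_dtype(x: np.ndarray) -> bool:
    # dtype-based test: safe for empty arrays, unlike x.item(0)
    return issubclass(x.dtype.type, (np.number, np.bool_))


d = np.zeros([3, 6, 6])
empty = d[[]]                      # shape (0, 6, 6)
if has_numeric_dtype(empty):
    t = torch.from_numpy(empty)    # tensor([], size=(0, 6, 6), dtype=torch.float64)
```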
Yes, I noticed this. Thank you!
So could you please add this to the test cases?
Sure! On it.
Ready for review.
Personally, I don't think this is the best way to support gym.spaces.Tuple in either the observation space or the action space. A nice workaround is to wrap the original environment to return a dict state and accept a dict action, which is natively supported by tianshou.data.Batch (a minimal wrapper sketch is given after this comment). Take the observation space as an example: suppose an environment returns a tuple observation with two items, one an image observation with shape [224, 224, 3], the other a vector observation with shape [10]. By writing an environment wrapper, the observation becomes a dict:

```python
for i in range(100):
    buffer.add(obs={'img': np.zeros((224, 224, 3)), 'vec': np.zeros((10,))})
buffer.obs.img  # tensor of shape [100, 224, 224, 3], ready to feed into the neural network policy
buffer.obs.vec  # tensor of shape [100, 10], ready to feed into the neural network policy
```

But if we support tuple observations as this PR does, the observation in the buffer is a 2-D object array. The policy has to unpack the observation array and construct separate tensors for the image and vector observations.

```python
for i in range(100):
    buffer.add(obs=[np.zeros((224, 224, 3)), np.zeros((10,))])
buffer.obs  # array of shape [100, 2], with data type np.object
img = np.stack(buffer.obs[:, 0])  # user has to do this explicitly!
vec = np.stack(buffer.obs[:, 1])  # user has to do this explicitly!
```

In addition, slicing only creates a view that shares memory, while stacking copies the data:
```
a = np.zeros((3, 4))

a
Out[104]:
array([[0., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

b = a[:, 0]  # sliced objects share memory
b[0] = 1

In [107]: a
Out[107]:
array([[1., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])

d = np.stack([a, a])  # stacking objects requires additional memory cost
d
Out[109]:
array([[[1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]])

d[:] = 0
d
Out[111]:
array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]])

a
Out[112]:
array([[1., 0., 0., 0.],
       [0., 0., 0., 0.],
       [0., 0., 0., 0.]])
```
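As a sketch of the wrapper-based workaround described above (class name, dict keys, and the two-element tuple layout are assumptions for illustration, not part of this PR):

```python
import gym
import numpy as np


class TupleToDictObs(gym.ObservationWrapper):
    """Hypothetical wrapper: expose an (image, vector) tuple observation as a dict."""

    def __init__(self, env):
        super().__init__(env)
        img_space, vec_space = env.observation_space.spaces
        self.observation_space = gym.spaces.Dict({'img': img_space, 'vec': vec_space})

    def observation(self, obs):
        img, vec = obs
        return {'img': np.asarray(img), 'vec': np.asarray(vec)}
```

A dict observation produced this way can go straight into Batch and the replay buffer as shown above, without creating object arrays.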
@youkaichao yes, I know that. That's why I'm not planning to implement anything else regarding tuple support; Batch is obviously not designed for that at this point. As I said, it is just a step forward. Currently, it only consists of increasing the versatility of the generic helpers, which are not intended to be used only internally in conjunction with Batch.
Nice work👍
This case can bypass the check:

```
In [10]: d = Batch(a=[1, np.zeros([3, 3]), np.zeros([3, 3]), torch.zeros(3, 3)])

In [11]: d
Out[11]:
Batch(
    a: array([1, array([[0., 0., 0.],
                        [0., 0., 0.],
                        [0., 0., 0.]]),
              array([[0., 0., 0.],
                     [0., 0., 0.],
                     [0., 0., 0.]]),
              tensor([[0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.]])], dtype=object),
)

In [12]: Batch.cat([d, d])
Out[12]:
Batch(
    a: array([1, array([[0., 0., 0.],
                        [0., 0., 0.],
                        [0., 0., 0.]]),
              array([[0., 0., 0.],
                     [0., 0., 0.],
                     [0., 0., 0.]]),
              tensor([[0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.]]),
              1, array([[0., 0., 0.],
                        [0., 0., 0.],
                        [0., 0., 0.]]),
              array([[0., 0., 0.],
                     [0., 0., 0.],
                     [0., 0., 0.]]),
              tensor([[0., 0., 0.],
                      [0., 0., 0.],
                      [0., 0., 0.]])], dtype=object),
)
```

I'm not sure whether this case matters enough, though; I think it is too much of a corner case.
Indeed, I thought it was forbidden. I will fix it and add a test case.
It is fine now.
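For context, the commit history below mentions that numpy arrays of torch tensors end up being forbidden. A rough sketch of that kind of guard (hypothetical helper name; the merged implementation lives in Batch's value parser, not in a standalone function):

```python
import numpy as np
import torch


def reject_tensors_in_object_array(arr: np.ndarray) -> np.ndarray:
    # Hypothetical guard: object arrays mixing torch tensors with other values
    # cannot be merged into one consistent array, so refuse them early.
    if arr.dtype == np.object_ and \
            any(isinstance(v, torch.Tensor) for v in arr.reshape(-1)):
        raise ValueError(
            "Numpy arrays of torch tensors are not supported; "
            "convert all elements to numpy or to torch first.")
    return arr
```

With such a guard in the value parser, the mixed example above would raise immediately instead of silently producing an object array.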
I think `list` and `tuple` should also be handled by `.to_numpy`.
Indeed. Fixed. Added unit test.
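A hedged sketch of the corresponding to_numpy behaviour, mirroring the to_torch sketch earlier (hypothetical helper name, not the exact merged code):

```python
import numpy as np
import torch


def to_numpy_sketch(x):
    """Hypothetical converter: lists/tuples become one ndarray when homogeneous."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    if isinstance(x, (list, tuple)):
        converted = [to_numpy_sketch(e) for e in x]
        try:
            return np.stack(converted)   # homogeneous -> one ndarray
        except ValueError:
            return converted             # ragged -> keep per-element results
    return np.asanyarray(x)
```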
* Enable converting list/tuple back and forth from/to numpy/torch.
* Add fallbacks.
* Fix PEP8.
* Update unit tests.
* Type annotation. Robust dtype check.
* Lists of objects are converted individually, as a single tensor otherwise.
* Improve robustness of _to_array_with_correct_type.
* Add unit tests.
* Do not catch exceptions at the _to_array_with_correct_type level.
* Use _parse_value.
* Fix PEP8.
* Fix _parse_value list output type fallback.
* Catch torch exception.
* Do not convert torch tensor during fallback.
* Improve unit tests.
* Add unit tests.
* Fix missing import.
* Remove support of numpy arrays of tensors in the Batch value parser.
* Forbid numpy arrays of tensors.
* Fix PEP8.
* Fix comment.
* Reduce _parse_value branch number.
* Fix None value.
* Forward error message for debugging purposes.
* Fix _is_scalar.
* More specific try/catch blocks.
* Fix exception chaining.
* Fix PEP8.
* Fix _is_scalar.
* Fix missing corner case.
* Fix PEP8.
* Allow Batch empty key.
* Fix multi-dim array datatype check.

Co-authored-by: Alexis Duburcq <alexis.duburcq@wandercraft.eu>