-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add pad
argument to TensorDict.where
#539
Conversation
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_plain_set_nested | 0.2099ms | 23.8093μs | 42.0004 KOps/s | 43.7642 KOps/s | |
test_plain_set_stack_nested | 0.7977ms | 0.2309ms | 4.3305 KOps/s | 4.7125 KOps/s | |
test_plain_set_nested_inplace | 2.5216ms | 30.0819μs | 33.2426 KOps/s | 37.1543 KOps/s | |
test_plain_set_stack_nested_inplace | 8.3801ms | 0.2867ms | 3.4880 KOps/s | 3.9635 KOps/s | |
test_items | 0.6318ms | 4.4752μs | 223.4519 KOps/s | 238.5537 KOps/s | |
test_items_nested | 4.4490ms | 0.4802ms | 2.0825 KOps/s | 2.4291 KOps/s | |
test_items_nested_locked | 3.1876ms | 0.4684ms | 2.1348 KOps/s | 2.4212 KOps/s | |
test_items_nested_leaf | 1.6348ms | 0.2746ms | 3.6421 KOps/s | 3.9120 KOps/s | |
test_items_stack_nested | 5.5966ms | 2.5792ms | 387.7147 Ops/s | 414.9062 Ops/s | |
test_items_stack_nested_leaf | 4.8573ms | 2.3239ms | 430.3044 Ops/s | 477.6039 Ops/s | |
test_items_stack_nested_locked | 2.9872ms | 1.2314ms | 812.0751 Ops/s | 872.8363 Ops/s | |
test_keys | 6.9045ms | 6.5560μs | 152.5311 KOps/s | 166.3764 KOps/s | |
test_keys_nested | 1.9697ms | 0.2292ms | 4.3634 KOps/s | 4.3095 KOps/s | |
test_keys_nested_locked | 7.0827ms | 0.2402ms | 4.1629 KOps/s | 4.7031 KOps/s | |
test_keys_nested_leaf | 0.9731ms | 0.2041ms | 4.9003 KOps/s | 4.8886 KOps/s | |
test_keys_stack_nested | 2.2217ms | 2.1213ms | 471.4009 Ops/s | 471.0828 Ops/s | |
test_keys_stack_nested_leaf | 2.1888ms | 2.1191ms | 471.9023 Ops/s | 470.7128 Ops/s | |
test_keys_stack_nested_locked | 1.0697ms | 0.9385ms | 1.0655 KOps/s | 1.0209 KOps/s | |
test_values | 32.1020μs | 1.9181μs | 521.3398 KOps/s | 528.4415 KOps/s | |
test_values_nested | 0.1045ms | 73.5823μs | 13.5902 KOps/s | 13.5767 KOps/s | |
test_values_nested_locked | 0.1664ms | 73.7949μs | 13.5511 KOps/s | 13.5300 KOps/s | |
test_values_nested_leaf | 0.1016ms | 66.0246μs | 15.1459 KOps/s | 15.1399 KOps/s | |
test_values_stack_nested | 1.9415ms | 1.8473ms | 541.3286 Ops/s | 534.2261 Ops/s | |
test_values_stack_nested_leaf | 2.0213ms | 1.8420ms | 542.8804 Ops/s | 543.3160 Ops/s | |
test_values_stack_nested_locked | 0.8133ms | 0.7380ms | 1.3550 KOps/s | 1.3144 KOps/s | |
test_membership | 29.7020μs | 2.1781μs | 459.1203 KOps/s | 477.8933 KOps/s | |
test_membership_nested | 24.7010μs | 4.1524μs | 240.8224 KOps/s | 240.2343 KOps/s | |
test_membership_nested_leaf | 32.3020μs | 4.1739μs | 239.5812 KOps/s | 242.0039 KOps/s | |
test_membership_stacked_nested | 45.1040μs | 17.1177μs | 58.4190 KOps/s | 59.2371 KOps/s | |
test_membership_stacked_nested_leaf | 0.1069ms | 17.2944μs | 57.8223 KOps/s | 59.3947 KOps/s | |
test_membership_nested_last | 0.1245ms | 8.7900μs | 113.7660 KOps/s | 113.3575 KOps/s | |
test_membership_nested_leaf_last | 39.8030μs | 8.8239μs | 113.3286 KOps/s | 112.2305 KOps/s | |
test_membership_stacked_nested_last | 0.3139ms | 0.2646ms | 3.7800 KOps/s | 3.8209 KOps/s | |
test_membership_stacked_nested_leaf_last | 86.0070μs | 20.0859μs | 49.7862 KOps/s | 50.4802 KOps/s | |
test_nested_getleaf | 62.5050μs | 18.1287μs | 55.1610 KOps/s | 56.0397 KOps/s | |
test_nested_get | 45.7030μs | 17.1131μs | 58.4346 KOps/s | 59.1189 KOps/s | |
test_stacked_getleaf | 1.1532ms | 1.0128ms | 987.3684 Ops/s | 990.9482 Ops/s | |
test_stacked_get | 1.0217ms | 0.9626ms | 1.0389 KOps/s | 1.0405 KOps/s | |
test_nested_getitemleaf | 85.0070μs | 18.1171μs | 55.1963 KOps/s | 56.1451 KOps/s | |
test_nested_getitem | 76.0060μs | 17.2244μs | 58.0573 KOps/s | 59.2101 KOps/s | |
test_stacked_getitemleaf | 1.1810ms | 1.0146ms | 985.6020 Ops/s | 988.5928 Ops/s | |
test_stacked_getitem | 1.0164ms | 0.9652ms | 1.0360 KOps/s | 1.0376 KOps/s | |
test_lock_nested | 85.7782ms | 1.8388ms | 543.8383 Ops/s | 585.5171 Ops/s | |
test_lock_stack_nested | 0.1097s | 24.1400ms | 41.4251 Ops/s | 41.7694 Ops/s | |
test_unlock_nested | 81.8578ms | 1.8383ms | 543.9832 Ops/s | 557.8651 Ops/s | |
test_unlock_stack_nested | 0.1086s | 24.6308ms | 40.5995 Ops/s | 40.3764 Ops/s | |
test_flatten_speed | 5.0562ms | 1.3318ms | 750.8797 Ops/s | 828.9012 Ops/s | |
test_unflatten_speed | 7.2961ms | 2.3150ms | 431.9689 Ops/s | 470.2679 Ops/s | |
test_common_ops | 5.6305ms | 1.2869ms | 777.0794 Ops/s | 768.9811 Ops/s | |
test_creation | 92.8070μs | 7.1288μs | 140.2760 KOps/s | 136.8237 KOps/s | |
test_creation_empty | 51.8040μs | 15.7867μs | 63.3444 KOps/s | 61.3301 KOps/s | |
test_creation_nested_1 | 62.3050μs | 27.9821μs | 35.7371 KOps/s | 34.5929 KOps/s | |
test_creation_nested_2 | 57.0050μs | 30.3968μs | 32.8982 KOps/s | 32.3650 KOps/s | |
test_clone | 0.1541ms | 28.2866μs | 35.3524 KOps/s | 35.7399 KOps/s | |
test_getitem[int] | 70.6060μs | 32.4622μs | 30.8051 KOps/s | 31.1104 KOps/s | |
test_getitem[slice_int] | 0.2641ms | 64.4811μs | 15.5084 KOps/s | 15.7186 KOps/s | |
test_getitem[range] | 0.2141ms | 93.9731μs | 10.6413 KOps/s | 10.4876 KOps/s | |
test_getitem[tuple] | 0.1245ms | 53.2096μs | 18.7936 KOps/s | 19.0668 KOps/s | |
test_getitem[list] | 0.1776ms | 88.8234μs | 11.2583 KOps/s | 11.1059 KOps/s | |
test_setitem_dim[int] | 0.1438ms | 40.3658μs | 24.7734 KOps/s | 25.4656 KOps/s | |
test_setitem_dim[slice_int] | 0.1079ms | 70.4382μs | 14.1968 KOps/s | 14.3342 KOps/s | |
test_setitem_dim[range] | 0.1280ms | 93.1611μs | 10.7341 KOps/s | 10.3788 KOps/s | |
test_setitem_dim[tuple] | 87.1070μs | 58.7019μs | 17.0352 KOps/s | 17.3079 KOps/s | |
test_setitem | 0.1886ms | 36.1743μs | 27.6439 KOps/s | 27.1884 KOps/s | |
test_set | 0.2145ms | 35.2661μs | 28.3558 KOps/s | 28.2893 KOps/s | |
test_set_shared | 0.4740ms | 0.2363ms | 4.2324 KOps/s | 4.2158 KOps/s | |
test_update | 0.2259ms | 40.1891μs | 24.8823 KOps/s | 24.5848 KOps/s | |
test_update_nested | 0.2453ms | 59.4428μs | 16.8229 KOps/s | 16.6685 KOps/s | |
test_set_nested | 0.2322ms | 38.6256μs | 25.8895 KOps/s | 25.7685 KOps/s | |
test_set_nested_new | 0.2220ms | 61.8130μs | 16.1778 KOps/s | 16.3575 KOps/s | |
test_select | 0.1537ms | 0.1143ms | 8.7471 KOps/s | 8.8593 KOps/s | |
test_unbind_speed | 1.1394ms | 0.7603ms | 1.3152 KOps/s | 1.3165 KOps/s | |
test_unbind_speed_stack0 | 88.4244ms | 10.2953ms | 97.1319 Ops/s | 94.6757 Ops/s | |
test_unbind_speed_stack1 | 20.0020μs | 1.3083μs | 764.3402 KOps/s | 761.1382 KOps/s | |
test_creation[device0] | 5.3428ms | 0.5369ms | 1.8626 KOps/s | 1.8354 KOps/s | |
test_creation_from_tensor | 8.8058ms | 0.6182ms | 1.6176 KOps/s | 1.6666 KOps/s | |
test_add_one[memmap_tensor0] | 2.0944ms | 37.4748μs | 26.6846 KOps/s | 24.9996 KOps/s | |
test_contiguous[memmap_tensor0] | 29.4020μs | 9.8858μs | 101.1557 KOps/s | 95.7618 KOps/s | |
test_stack[memmap_tensor0] | 0.1083ms | 30.5318μs | 32.7527 KOps/s | 31.6551 KOps/s | |
test_memmaptd_index | 0.4191ms | 0.3637ms | 2.7497 KOps/s | 2.7458 KOps/s | |
test_memmaptd_index_astensor | 1.5451ms | 1.4028ms | 712.8552 Ops/s | 700.5247 Ops/s | |
test_memmaptd_index_op | 3.4050ms | 3.0649ms | 326.2777 Ops/s | 320.1331 Ops/s | |
test_reshape_pytree | 0.1093ms | 37.9807μs | 26.3292 KOps/s | 25.3405 KOps/s | |
test_reshape_td | 0.1359ms | 47.7957μs | 20.9224 KOps/s | 20.5486 KOps/s | |
test_view_pytree | 0.1626ms | 37.9548μs | 26.3471 KOps/s | 25.5798 KOps/s | |
test_view_td | 48.4040μs | 10.1521μs | 98.5021 KOps/s | 95.8496 KOps/s | |
test_unbind_pytree | 0.1074ms | 43.1083μs | 23.1974 KOps/s | 22.4264 KOps/s | |
test_unbind_td | 0.2138ms | 0.1115ms | 8.9693 KOps/s | 8.7332 KOps/s | |
test_split_pytree | 98.5080μs | 42.3805μs | 23.5957 KOps/s | 22.7748 KOps/s | |
test_split_td | 0.9152ms | 0.1232ms | 8.1151 KOps/s | 8.1039 KOps/s | |
test_add_pytree | 0.1097ms | 54.3457μs | 18.4007 KOps/s | 18.4356 KOps/s | |
test_add_td | 0.1376ms | 88.5309μs | 11.2955 KOps/s | 11.1246 KOps/s | |
test_distributed | 66.7060μs | 10.4668μs | 95.5404 KOps/s | 88.7686 KOps/s | |
test_tdmodule | 3.6691ms | 35.7705μs | 27.9560 KOps/s | 28.8771 KOps/s | |
test_tdmodule_dispatch | 0.3146ms | 62.0713μs | 16.1105 KOps/s | 15.7515 KOps/s | |
test_tdseq | 69.6060μs | 37.9481μs | 26.3518 KOps/s | 25.4544 KOps/s | |
test_tdseq_dispatch | 0.2240ms | 76.5109μs | 13.0700 KOps/s | 12.7411 KOps/s | |
test_instantiation_functorch | 2.0867ms | 1.9155ms | 522.0482 Ops/s | 511.6111 Ops/s | |
test_instantiation_td | 4.0991ms | 1.6129ms | 619.9957 Ops/s | 618.6667 Ops/s | |
test_exec_functorch | 0.3830ms | 0.2350ms | 4.2548 KOps/s | 4.3681 KOps/s | |
test_exec_td | 0.2628ms | 0.2153ms | 4.6457 KOps/s | 4.6338 KOps/s | |
test_vmap_mlp_speed[True-True] | 18.8597ms | 1.8099ms | 552.5016 Ops/s | 695.1540 Ops/s | |
test_vmap_mlp_speed[True-False] | 12.0141ms | 0.8569ms | 1.1671 KOps/s | 1.2691 KOps/s | |
test_vmap_mlp_speed[False-True] | 8.2086ms | 1.2347ms | 809.9439 Ops/s | 660.4465 Ops/s | |
test_vmap_mlp_speed[False-False] | 4.4898ms | 0.6029ms | 1.6587 KOps/s | 1.6887 KOps/s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
what was the treatment of missing keys before?
To me it could make sense that someone calls self.where
with an other
that has less keys and the where
is only applied to the shared keys
try: | ||
tensor = _other.empty() | ||
except NotImplementedError: | ||
# H5 tensordicts do not support select() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what is H5?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
H5DF
_other = pad | ||
else: | ||
raise KeyError( | ||
f"Key {key} not found and no pad value provided." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before this PR was an error thrown for these cases? What was the treatment for missing keys?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it was a bug in the sense that td0.where(cond, td1)
and td1.where(~cond, td0)
where not returning the same thing. Only the keys of the first (self) were scanned. Keys in self
but not other
would raise an exception, but keys in other
but not self
where just ignored. This is super bad IMO
You'll get an exception. If you want to do that, you need to call select first. It's best to avoid having a default behaviour that does not assume anything about user's intentions IMO, and raise an exception if something's wrong |
Description
Adds a
pad
argument inTensorDict.where
to allow the method to be called on partially matching tensordicts.The
pad
value will be used for any key that is present in one but not the other tensordict.