Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add pad argument to TensorDict.where #539

Merged
merged 9 commits into from
Oct 10, 2023
Merged

[Feature] Add pad argument to TensorDict.where #539

merged 9 commits into from
Oct 10, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Oct 9, 2023

Description

Adds a pad argument in TensorDict.where to allow the method to be called on partially matching tensordicts.
The pad value will be used for any key that is present in one but not the other tensordict.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 9, 2023
@github-actions
Copy link

github-actions bot commented Oct 9, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 105. Improved: $\large\color{#35bf28}4$. Worsened: $\large\color{#d91a1a}17$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 0.2099ms 23.8093μs 42.0004 KOps/s 43.7642 KOps/s $\color{#d91a1a}-4.03\%$
test_plain_set_stack_nested 0.7977ms 0.2309ms 4.3305 KOps/s 4.7125 KOps/s $\textbf{\color{#d91a1a}-8.11\%}$
test_plain_set_nested_inplace 2.5216ms 30.0819μs 33.2426 KOps/s 37.1543 KOps/s $\textbf{\color{#d91a1a}-10.53\%}$
test_plain_set_stack_nested_inplace 8.3801ms 0.2867ms 3.4880 KOps/s 3.9635 KOps/s $\textbf{\color{#d91a1a}-12.00\%}$
test_items 0.6318ms 4.4752μs 223.4519 KOps/s 238.5537 KOps/s $\textbf{\color{#d91a1a}-6.33\%}$
test_items_nested 4.4490ms 0.4802ms 2.0825 KOps/s 2.4291 KOps/s $\textbf{\color{#d91a1a}-14.27\%}$
test_items_nested_locked 3.1876ms 0.4684ms 2.1348 KOps/s 2.4212 KOps/s $\textbf{\color{#d91a1a}-11.83\%}$
test_items_nested_leaf 1.6348ms 0.2746ms 3.6421 KOps/s 3.9120 KOps/s $\textbf{\color{#d91a1a}-6.90\%}$
test_items_stack_nested 5.5966ms 2.5792ms 387.7147 Ops/s 414.9062 Ops/s $\textbf{\color{#d91a1a}-6.55\%}$
test_items_stack_nested_leaf 4.8573ms 2.3239ms 430.3044 Ops/s 477.6039 Ops/s $\textbf{\color{#d91a1a}-9.90\%}$
test_items_stack_nested_locked 2.9872ms 1.2314ms 812.0751 Ops/s 872.8363 Ops/s $\textbf{\color{#d91a1a}-6.96\%}$
test_keys 6.9045ms 6.5560μs 152.5311 KOps/s 166.3764 KOps/s $\textbf{\color{#d91a1a}-8.32\%}$
test_keys_nested 1.9697ms 0.2292ms 4.3634 KOps/s 4.3095 KOps/s $\color{#35bf28}+1.25\%$
test_keys_nested_locked 7.0827ms 0.2402ms 4.1629 KOps/s 4.7031 KOps/s $\textbf{\color{#d91a1a}-11.49\%}$
test_keys_nested_leaf 0.9731ms 0.2041ms 4.9003 KOps/s 4.8886 KOps/s $\color{#35bf28}+0.24\%$
test_keys_stack_nested 2.2217ms 2.1213ms 471.4009 Ops/s 471.0828 Ops/s $\color{#35bf28}+0.07\%$
test_keys_stack_nested_leaf 2.1888ms 2.1191ms 471.9023 Ops/s 470.7128 Ops/s $\color{#35bf28}+0.25\%$
test_keys_stack_nested_locked 1.0697ms 0.9385ms 1.0655 KOps/s 1.0209 KOps/s $\color{#35bf28}+4.37\%$
test_values 32.1020μs 1.9181μs 521.3398 KOps/s 528.4415 KOps/s $\color{#d91a1a}-1.34\%$
test_values_nested 0.1045ms 73.5823μs 13.5902 KOps/s 13.5767 KOps/s $\color{#35bf28}+0.10\%$
test_values_nested_locked 0.1664ms 73.7949μs 13.5511 KOps/s 13.5300 KOps/s $\color{#35bf28}+0.16\%$
test_values_nested_leaf 0.1016ms 66.0246μs 15.1459 KOps/s 15.1399 KOps/s $\color{#35bf28}+0.04\%$
test_values_stack_nested 1.9415ms 1.8473ms 541.3286 Ops/s 534.2261 Ops/s $\color{#35bf28}+1.33\%$
test_values_stack_nested_leaf 2.0213ms 1.8420ms 542.8804 Ops/s 543.3160 Ops/s $\color{#d91a1a}-0.08\%$
test_values_stack_nested_locked 0.8133ms 0.7380ms 1.3550 KOps/s 1.3144 KOps/s $\color{#35bf28}+3.09\%$
test_membership 29.7020μs 2.1781μs 459.1203 KOps/s 477.8933 KOps/s $\color{#d91a1a}-3.93\%$
test_membership_nested 24.7010μs 4.1524μs 240.8224 KOps/s 240.2343 KOps/s $\color{#35bf28}+0.24\%$
test_membership_nested_leaf 32.3020μs 4.1739μs 239.5812 KOps/s 242.0039 KOps/s $\color{#d91a1a}-1.00\%$
test_membership_stacked_nested 45.1040μs 17.1177μs 58.4190 KOps/s 59.2371 KOps/s $\color{#d91a1a}-1.38\%$
test_membership_stacked_nested_leaf 0.1069ms 17.2944μs 57.8223 KOps/s 59.3947 KOps/s $\color{#d91a1a}-2.65\%$
test_membership_nested_last 0.1245ms 8.7900μs 113.7660 KOps/s 113.3575 KOps/s $\color{#35bf28}+0.36\%$
test_membership_nested_leaf_last 39.8030μs 8.8239μs 113.3286 KOps/s 112.2305 KOps/s $\color{#35bf28}+0.98\%$
test_membership_stacked_nested_last 0.3139ms 0.2646ms 3.7800 KOps/s 3.8209 KOps/s $\color{#d91a1a}-1.07\%$
test_membership_stacked_nested_leaf_last 86.0070μs 20.0859μs 49.7862 KOps/s 50.4802 KOps/s $\color{#d91a1a}-1.37\%$
test_nested_getleaf 62.5050μs 18.1287μs 55.1610 KOps/s 56.0397 KOps/s $\color{#d91a1a}-1.57\%$
test_nested_get 45.7030μs 17.1131μs 58.4346 KOps/s 59.1189 KOps/s $\color{#d91a1a}-1.16\%$
test_stacked_getleaf 1.1532ms 1.0128ms 987.3684 Ops/s 990.9482 Ops/s $\color{#d91a1a}-0.36\%$
test_stacked_get 1.0217ms 0.9626ms 1.0389 KOps/s 1.0405 KOps/s $\color{#d91a1a}-0.16\%$
test_nested_getitemleaf 85.0070μs 18.1171μs 55.1963 KOps/s 56.1451 KOps/s $\color{#d91a1a}-1.69\%$
test_nested_getitem 76.0060μs 17.2244μs 58.0573 KOps/s 59.2101 KOps/s $\color{#d91a1a}-1.95\%$
test_stacked_getitemleaf 1.1810ms 1.0146ms 985.6020 Ops/s 988.5928 Ops/s $\color{#d91a1a}-0.30\%$
test_stacked_getitem 1.0164ms 0.9652ms 1.0360 KOps/s 1.0376 KOps/s $\color{#d91a1a}-0.15\%$
test_lock_nested 85.7782ms 1.8388ms 543.8383 Ops/s 585.5171 Ops/s $\textbf{\color{#d91a1a}-7.12\%}$
test_lock_stack_nested 0.1097s 24.1400ms 41.4251 Ops/s 41.7694 Ops/s $\color{#d91a1a}-0.82\%$
test_unlock_nested 81.8578ms 1.8383ms 543.9832 Ops/s 557.8651 Ops/s $\color{#d91a1a}-2.49\%$
test_unlock_stack_nested 0.1086s 24.6308ms 40.5995 Ops/s 40.3764 Ops/s $\color{#35bf28}+0.55\%$
test_flatten_speed 5.0562ms 1.3318ms 750.8797 Ops/s 828.9012 Ops/s $\textbf{\color{#d91a1a}-9.41\%}$
test_unflatten_speed 7.2961ms 2.3150ms 431.9689 Ops/s 470.2679 Ops/s $\textbf{\color{#d91a1a}-8.14\%}$
test_common_ops 5.6305ms 1.2869ms 777.0794 Ops/s 768.9811 Ops/s $\color{#35bf28}+1.05\%$
test_creation 92.8070μs 7.1288μs 140.2760 KOps/s 136.8237 KOps/s $\color{#35bf28}+2.52\%$
test_creation_empty 51.8040μs 15.7867μs 63.3444 KOps/s 61.3301 KOps/s $\color{#35bf28}+3.28\%$
test_creation_nested_1 62.3050μs 27.9821μs 35.7371 KOps/s 34.5929 KOps/s $\color{#35bf28}+3.31\%$
test_creation_nested_2 57.0050μs 30.3968μs 32.8982 KOps/s 32.3650 KOps/s $\color{#35bf28}+1.65\%$
test_clone 0.1541ms 28.2866μs 35.3524 KOps/s 35.7399 KOps/s $\color{#d91a1a}-1.08\%$
test_getitem[int] 70.6060μs 32.4622μs 30.8051 KOps/s 31.1104 KOps/s $\color{#d91a1a}-0.98\%$
test_getitem[slice_int] 0.2641ms 64.4811μs 15.5084 KOps/s 15.7186 KOps/s $\color{#d91a1a}-1.34\%$
test_getitem[range] 0.2141ms 93.9731μs 10.6413 KOps/s 10.4876 KOps/s $\color{#35bf28}+1.47\%$
test_getitem[tuple] 0.1245ms 53.2096μs 18.7936 KOps/s 19.0668 KOps/s $\color{#d91a1a}-1.43\%$
test_getitem[list] 0.1776ms 88.8234μs 11.2583 KOps/s 11.1059 KOps/s $\color{#35bf28}+1.37\%$
test_setitem_dim[int] 0.1438ms 40.3658μs 24.7734 KOps/s 25.4656 KOps/s $\color{#d91a1a}-2.72\%$
test_setitem_dim[slice_int] 0.1079ms 70.4382μs 14.1968 KOps/s 14.3342 KOps/s $\color{#d91a1a}-0.96\%$
test_setitem_dim[range] 0.1280ms 93.1611μs 10.7341 KOps/s 10.3788 KOps/s $\color{#35bf28}+3.42\%$
test_setitem_dim[tuple] 87.1070μs 58.7019μs 17.0352 KOps/s 17.3079 KOps/s $\color{#d91a1a}-1.58\%$
test_setitem 0.1886ms 36.1743μs 27.6439 KOps/s 27.1884 KOps/s $\color{#35bf28}+1.68\%$
test_set 0.2145ms 35.2661μs 28.3558 KOps/s 28.2893 KOps/s $\color{#35bf28}+0.23\%$
test_set_shared 0.4740ms 0.2363ms 4.2324 KOps/s 4.2158 KOps/s $\color{#35bf28}+0.39\%$
test_update 0.2259ms 40.1891μs 24.8823 KOps/s 24.5848 KOps/s $\color{#35bf28}+1.21\%$
test_update_nested 0.2453ms 59.4428μs 16.8229 KOps/s 16.6685 KOps/s $\color{#35bf28}+0.93\%$
test_set_nested 0.2322ms 38.6256μs 25.8895 KOps/s 25.7685 KOps/s $\color{#35bf28}+0.47\%$
test_set_nested_new 0.2220ms 61.8130μs 16.1778 KOps/s 16.3575 KOps/s $\color{#d91a1a}-1.10\%$
test_select 0.1537ms 0.1143ms 8.7471 KOps/s 8.8593 KOps/s $\color{#d91a1a}-1.27\%$
test_unbind_speed 1.1394ms 0.7603ms 1.3152 KOps/s 1.3165 KOps/s $\color{#d91a1a}-0.10\%$
test_unbind_speed_stack0 88.4244ms 10.2953ms 97.1319 Ops/s 94.6757 Ops/s $\color{#35bf28}+2.59\%$
test_unbind_speed_stack1 20.0020μs 1.3083μs 764.3402 KOps/s 761.1382 KOps/s $\color{#35bf28}+0.42\%$
test_creation[device0] 5.3428ms 0.5369ms 1.8626 KOps/s 1.8354 KOps/s $\color{#35bf28}+1.48\%$
test_creation_from_tensor 8.8058ms 0.6182ms 1.6176 KOps/s 1.6666 KOps/s $\color{#d91a1a}-2.94\%$
test_add_one[memmap_tensor0] 2.0944ms 37.4748μs 26.6846 KOps/s 24.9996 KOps/s $\textbf{\color{#35bf28}+6.74\%}$
test_contiguous[memmap_tensor0] 29.4020μs 9.8858μs 101.1557 KOps/s 95.7618 KOps/s $\textbf{\color{#35bf28}+5.63\%}$
test_stack[memmap_tensor0] 0.1083ms 30.5318μs 32.7527 KOps/s 31.6551 KOps/s $\color{#35bf28}+3.47\%$
test_memmaptd_index 0.4191ms 0.3637ms 2.7497 KOps/s 2.7458 KOps/s $\color{#35bf28}+0.14\%$
test_memmaptd_index_astensor 1.5451ms 1.4028ms 712.8552 Ops/s 700.5247 Ops/s $\color{#35bf28}+1.76\%$
test_memmaptd_index_op 3.4050ms 3.0649ms 326.2777 Ops/s 320.1331 Ops/s $\color{#35bf28}+1.92\%$
test_reshape_pytree 0.1093ms 37.9807μs 26.3292 KOps/s 25.3405 KOps/s $\color{#35bf28}+3.90\%$
test_reshape_td 0.1359ms 47.7957μs 20.9224 KOps/s 20.5486 KOps/s $\color{#35bf28}+1.82\%$
test_view_pytree 0.1626ms 37.9548μs 26.3471 KOps/s 25.5798 KOps/s $\color{#35bf28}+3.00\%$
test_view_td 48.4040μs 10.1521μs 98.5021 KOps/s 95.8496 KOps/s $\color{#35bf28}+2.77\%$
test_unbind_pytree 0.1074ms 43.1083μs 23.1974 KOps/s 22.4264 KOps/s $\color{#35bf28}+3.44\%$
test_unbind_td 0.2138ms 0.1115ms 8.9693 KOps/s 8.7332 KOps/s $\color{#35bf28}+2.70\%$
test_split_pytree 98.5080μs 42.3805μs 23.5957 KOps/s 22.7748 KOps/s $\color{#35bf28}+3.60\%$
test_split_td 0.9152ms 0.1232ms 8.1151 KOps/s 8.1039 KOps/s $\color{#35bf28}+0.14\%$
test_add_pytree 0.1097ms 54.3457μs 18.4007 KOps/s 18.4356 KOps/s $\color{#d91a1a}-0.19\%$
test_add_td 0.1376ms 88.5309μs 11.2955 KOps/s 11.1246 KOps/s $\color{#35bf28}+1.54\%$
test_distributed 66.7060μs 10.4668μs 95.5404 KOps/s 88.7686 KOps/s $\textbf{\color{#35bf28}+7.63\%}$
test_tdmodule 3.6691ms 35.7705μs 27.9560 KOps/s 28.8771 KOps/s $\color{#d91a1a}-3.19\%$
test_tdmodule_dispatch 0.3146ms 62.0713μs 16.1105 KOps/s 15.7515 KOps/s $\color{#35bf28}+2.28\%$
test_tdseq 69.6060μs 37.9481μs 26.3518 KOps/s 25.4544 KOps/s $\color{#35bf28}+3.53\%$
test_tdseq_dispatch 0.2240ms 76.5109μs 13.0700 KOps/s 12.7411 KOps/s $\color{#35bf28}+2.58\%$
test_instantiation_functorch 2.0867ms 1.9155ms 522.0482 Ops/s 511.6111 Ops/s $\color{#35bf28}+2.04\%$
test_instantiation_td 4.0991ms 1.6129ms 619.9957 Ops/s 618.6667 Ops/s $\color{#35bf28}+0.21\%$
test_exec_functorch 0.3830ms 0.2350ms 4.2548 KOps/s 4.3681 KOps/s $\color{#d91a1a}-2.59\%$
test_exec_td 0.2628ms 0.2153ms 4.6457 KOps/s 4.6338 KOps/s $\color{#35bf28}+0.26\%$
test_vmap_mlp_speed[True-True] 18.8597ms 1.8099ms 552.5016 Ops/s 695.1540 Ops/s $\textbf{\color{#d91a1a}-20.52\%}$
test_vmap_mlp_speed[True-False] 12.0141ms 0.8569ms 1.1671 KOps/s 1.2691 KOps/s $\textbf{\color{#d91a1a}-8.04\%}$
test_vmap_mlp_speed[False-True] 8.2086ms 1.2347ms 809.9439 Ops/s 660.4465 Ops/s $\textbf{\color{#35bf28}+22.64\%}$
test_vmap_mlp_speed[False-False] 4.4898ms 0.6029ms 1.6587 KOps/s 1.6887 KOps/s $\color{#d91a1a}-1.77\%$

Copy link
Contributor

@matteobettini matteobettini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

what was the treatment of missing keys before?

To me it could make sense that someone calls self.where with an other that has less keys and the where is only applied to the shared keys

try:
tensor = _other.empty()
except NotImplementedError:
# H5 tensordicts do not support select()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is H5?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

H5DF

_other = pad
else:
raise KeyError(
f"Key {key} not found and no pad value provided."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR was an error thrown for these cases? What was the treatment for missing keys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was a bug in the sense that td0.where(cond, td1) and td1.where(~cond, td0) where not returning the same thing. Only the keys of the first (self) were scanned. Keys in self but not other would raise an exception, but keys in other but not self where just ignored. This is super bad IMO

@vmoens
Copy link
Contributor Author

vmoens commented Oct 10, 2023

You'll get an exception. If you want to do that, you need to call select first. It's best to avoid having a default behaviour that does not assume anything about user's intentions IMO, and raise an exception if something's wrong

@vmoens vmoens added the enhancement New feature or request label Oct 10, 2023
@vmoens vmoens merged commit 7747505 into main Oct 10, 2023
37 of 41 checks passed
@vmoens vmoens deleted the pad_for_where branch October 10, 2023 09:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants