Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] CompositeDistribution #517

Merged
merged 8 commits into from
Sep 1, 2023
Merged

[Feature] CompositeDistribution #517

merged 8 commits into from
Sep 1, 2023

Conversation

vmoens
Copy link
Contributor

@vmoens vmoens commented Aug 31, 2023

Description

Introduces CompositeDistribution, which allows to build bags of distributions.

Addresses pytorch/rl#1473

cc @hersh

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 31, 2023
@vmoens vmoens marked this pull request as ready for review August 31, 2023 08:40
@vmoens vmoens added the enhancement New feature or request label Aug 31, 2023
@vmoens
Copy link
Contributor Author

vmoens commented Aug 31, 2023

Note: we pass a structure {"sample_name": {"params": stuff}} to the constructor, which structure is similar to the samples that we'll get. But the value pointed by "sample_name" is a dict (or better, a tensordict) within the constructor, and it will be a tensor during sampling. This means that this will never be possible:

params = make_params(...)
dist = CompositeDistribution(params, ...)
sample = dist.sample()
params.update(sample) # breaks as tensor and tensordicts conflict

@github-actions
Copy link

github-actions bot commented Aug 31, 2023

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 109. Improved: $\large\color{#35bf28}4$. Worsened: $\large\color{#d91a1a}7$.

Expand to view detailed results
Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 39.8010μs 20.2503μs 49.3820 KOps/s 49.5431 KOps/s $\color{#d91a1a}-0.33\%$
test_plain_set_stack_nested 0.2083ms 0.1840ms 5.4347 KOps/s 5.4026 KOps/s $\color{#35bf28}+0.59\%$
test_plain_set_nested_inplace 49.4000μs 23.4199μs 42.6987 KOps/s 42.4369 KOps/s $\color{#35bf28}+0.62\%$
test_plain_set_stack_nested_inplace 0.2678ms 0.2163ms 4.6242 KOps/s 4.5545 KOps/s $\color{#35bf28}+1.53\%$
test_items 20.1000μs 3.4605μs 288.9763 KOps/s 323.7241 KOps/s $\textbf{\color{#d91a1a}-10.73\%}$
test_items_nested 1.8761ms 0.3773ms 2.6505 KOps/s 2.7196 KOps/s $\color{#d91a1a}-2.54\%$
test_items_nested_locked 0.4472ms 0.3798ms 2.6328 KOps/s 2.7557 KOps/s $\color{#d91a1a}-4.46\%$
test_items_nested_leaf 1.0086ms 0.2220ms 4.5036 KOps/s 4.5058 KOps/s $\color{#d91a1a}-0.05\%$
test_items_stack_nested 2.0158ms 1.9289ms 518.4208 Ops/s 522.3239 Ops/s $\color{#d91a1a}-0.75\%$
test_items_stack_nested_leaf 1.8009ms 1.7486ms 571.8924 Ops/s 571.1794 Ops/s $\color{#35bf28}+0.12\%$
test_items_stack_nested_locked 1.8997ms 0.9604ms 1.0413 KOps/s 1.0410 KOps/s $\color{#35bf28}+0.02\%$
test_keys 27.3000μs 4.5401μs 220.2617 KOps/s 220.6494 KOps/s $\color{#d91a1a}-0.18\%$
test_keys_nested 0.7804ms 0.1727ms 5.7899 KOps/s 5.7643 KOps/s $\color{#35bf28}+0.44\%$
test_keys_nested_locked 0.1948ms 0.1701ms 5.8800 KOps/s 5.8351 KOps/s $\color{#35bf28}+0.77\%$
test_keys_nested_leaf 0.3174ms 0.1679ms 5.9546 KOps/s 5.5358 KOps/s $\textbf{\color{#35bf28}+7.56\%}$
test_keys_stack_nested 1.8273ms 1.7584ms 568.7009 Ops/s 569.9005 Ops/s $\color{#d91a1a}-0.21\%$
test_keys_stack_nested_leaf 1.8196ms 1.7556ms 569.6102 Ops/s 552.0341 Ops/s $\color{#35bf28}+3.18\%$
test_keys_stack_nested_locked 0.8595ms 0.7853ms 1.2733 KOps/s 1.2828 KOps/s $\color{#d91a1a}-0.74\%$
test_values 27.7000μs 1.3494μs 741.0450 KOps/s 856.2964 KOps/s $\textbf{\color{#d91a1a}-13.46\%}$
test_values_nested 87.3000μs 64.6039μs 15.4790 KOps/s 14.7073 KOps/s $\textbf{\color{#35bf28}+5.25\%}$
test_values_nested_locked 89.8010μs 64.5158μs 15.5001 KOps/s 14.8053 KOps/s $\color{#35bf28}+4.69\%$
test_values_nested_leaf 99.3010μs 56.7072μs 17.6345 KOps/s 16.7389 KOps/s $\textbf{\color{#35bf28}+5.35\%}$
test_values_stack_nested 1.5880ms 1.5314ms 653.0030 Ops/s 653.6041 Ops/s $\color{#d91a1a}-0.09\%$
test_values_stack_nested_leaf 1.5992ms 1.5182ms 658.6799 Ops/s 655.0439 Ops/s $\color{#35bf28}+0.56\%$
test_values_stack_nested_locked 0.6689ms 0.6304ms 1.5864 KOps/s 1.6028 KOps/s $\color{#d91a1a}-1.02\%$
test_membership 24.1000μs 1.8716μs 534.3064 KOps/s 525.1446 KOps/s $\color{#35bf28}+1.74\%$
test_membership_nested 29.0000μs 3.6604μs 273.1906 KOps/s 272.6447 KOps/s $\color{#35bf28}+0.20\%$
test_membership_nested_leaf 30.3010μs 3.6849μs 271.3807 KOps/s 272.2172 KOps/s $\color{#d91a1a}-0.31\%$
test_membership_stacked_nested 41.4000μs 14.4860μs 69.0322 KOps/s 68.3295 KOps/s $\color{#35bf28}+1.03\%$
test_membership_stacked_nested_leaf 38.8000μs 14.5396μs 68.7777 KOps/s 68.5738 KOps/s $\color{#35bf28}+0.30\%$
test_membership_nested_last 66.2010μs 7.6301μs 131.0593 KOps/s 132.7090 KOps/s $\color{#d91a1a}-1.24\%$
test_membership_nested_leaf_last 32.6010μs 7.7291μs 129.3816 KOps/s 132.1923 KOps/s $\color{#d91a1a}-2.13\%$
test_membership_stacked_nested_last 0.2463ms 0.2206ms 4.5323 KOps/s 4.4722 KOps/s $\color{#35bf28}+1.34\%$
test_membership_stacked_nested_leaf_last 56.4010μs 17.0579μs 58.6238 KOps/s 57.4849 KOps/s $\color{#35bf28}+1.98\%$
test_nested_getleaf 39.4000μs 15.3839μs 65.0032 KOps/s 64.3176 KOps/s $\color{#35bf28}+1.07\%$
test_nested_get 36.2010μs 14.5558μs 68.7011 KOps/s 68.3581 KOps/s $\color{#35bf28}+0.50\%$
test_stacked_getleaf 0.9273ms 0.8307ms 1.2038 KOps/s 1.2126 KOps/s $\color{#d91a1a}-0.72\%$
test_stacked_get 0.8400ms 0.7915ms 1.2634 KOps/s 1.2600 KOps/s $\color{#35bf28}+0.27\%$
test_nested_getitemleaf 59.1010μs 15.3355μs 65.2082 KOps/s 63.9584 KOps/s $\color{#35bf28}+1.95\%$
test_nested_getitem 59.7010μs 14.5745μs 68.6128 KOps/s 68.1363 KOps/s $\color{#35bf28}+0.70\%$
test_stacked_getitemleaf 0.9363ms 0.8245ms 1.2128 KOps/s 1.2003 KOps/s $\color{#35bf28}+1.04\%$
test_stacked_getitem 0.8249ms 0.7902ms 1.2655 KOps/s 1.2577 KOps/s $\color{#35bf28}+0.61\%$
test_lock_nested 58.6446ms 1.4257ms 701.4280 Ops/s 745.1056 Ops/s $\textbf{\color{#d91a1a}-5.86\%}$
test_lock_stack_nested 77.1524ms 18.2208ms 54.8824 Ops/s 58.6493 Ops/s $\textbf{\color{#d91a1a}-6.42\%}$
test_unlock_nested 60.0001ms 1.4306ms 698.9883 Ops/s 705.4364 Ops/s $\color{#d91a1a}-0.91\%$
test_unlock_stack_nested 78.4839ms 18.7180ms 53.4246 Ops/s 57.2617 Ops/s $\textbf{\color{#d91a1a}-6.70\%}$
test_flatten_speed 1.0395ms 1.0019ms 998.1316 Ops/s 1.0181 KOps/s $\color{#d91a1a}-1.96\%$
test_unflatten_speed 1.7843ms 1.7489ms 571.7991 Ops/s 576.3803 Ops/s $\color{#d91a1a}-0.79\%$
test_common_ops 1.2188ms 1.0183ms 982.0044 Ops/s 991.8204 Ops/s $\color{#d91a1a}-0.99\%$
test_creation 34.5010μs 6.1087μs 163.7017 KOps/s 164.0742 KOps/s $\color{#d91a1a}-0.23\%$
test_creation_empty 28.2000μs 13.4715μs 74.2307 KOps/s 75.0640 KOps/s $\color{#d91a1a}-1.11\%$
test_creation_nested_1 48.4010μs 23.1575μs 43.1825 KOps/s 43.6259 KOps/s $\color{#d91a1a}-1.02\%$
test_creation_nested_2 50.0000μs 26.1107μs 38.2984 KOps/s 38.8446 KOps/s $\color{#d91a1a}-1.41\%$
test_clone 0.1353ms 24.6129μs 40.6292 KOps/s 40.9783 KOps/s $\color{#d91a1a}-0.85\%$
test_getitem[int] 68.4010μs 27.0798μs 36.9278 KOps/s 38.1420 KOps/s $\color{#d91a1a}-3.18\%$
test_getitem[slice_int] 0.1035ms 50.8922μs 19.6494 KOps/s 20.3932 KOps/s $\color{#d91a1a}-3.65\%$
test_getitem[range] 0.1274ms 77.4016μs 12.9196 KOps/s 13.1579 KOps/s $\color{#d91a1a}-1.81\%$
test_getitem[tuple] 69.4000μs 41.0449μs 24.3635 KOps/s 24.9299 KOps/s $\color{#d91a1a}-2.27\%$
test_getitem[list] 0.3286ms 72.5514μs 13.7833 KOps/s 14.0271 KOps/s $\color{#d91a1a}-1.74\%$
test_setitem_dim[int] 55.4010μs 31.3464μs 31.9016 KOps/s 31.5166 KOps/s $\color{#35bf28}+1.22\%$
test_setitem_dim[slice_int] 72.8010μs 55.7467μs 17.9383 KOps/s 17.7620 KOps/s $\color{#35bf28}+0.99\%$
test_setitem_dim[range] 0.1127ms 76.4307μs 13.0837 KOps/s 13.0108 KOps/s $\color{#35bf28}+0.56\%$
test_setitem_dim[tuple] 62.0010μs 46.8501μs 21.3447 KOps/s 21.1462 KOps/s $\color{#35bf28}+0.94\%$
test_setitem 0.1343ms 30.1375μs 33.1812 KOps/s 33.7225 KOps/s $\color{#d91a1a}-1.61\%$
test_set 0.1327ms 29.3817μs 34.0348 KOps/s 34.2969 KOps/s $\color{#d91a1a}-0.76\%$
test_set_shared 3.9862ms 0.1516ms 6.5944 KOps/s 6.5692 KOps/s $\color{#35bf28}+0.38\%$
test_update 0.2062ms 33.4112μs 29.9301 KOps/s 30.4614 KOps/s $\color{#d91a1a}-1.74\%$
test_update_nested 0.1818ms 50.7340μs 19.7106 KOps/s 19.6309 KOps/s $\color{#35bf28}+0.41\%$
test_set_nested 0.1420ms 32.2907μs 30.9686 KOps/s 32.6212 KOps/s $\textbf{\color{#d91a1a}-5.07\%}$
test_set_nested_new 0.1640ms 50.3807μs 19.8489 KOps/s 20.0957 KOps/s $\color{#d91a1a}-1.23\%$
test_select 0.2084ms 95.4921μs 10.4721 KOps/s 10.4275 KOps/s $\color{#35bf28}+0.43\%$
test_unbind_speed 0.6754ms 0.6445ms 1.5516 KOps/s 1.5702 KOps/s $\color{#d91a1a}-1.19\%$
test_unbind_speed_stack0 65.0108ms 8.1771ms 122.2932 Ops/s 121.7479 Ops/s $\color{#35bf28}+0.45\%$
test_unbind_speed_stack1 20.1000μs 1.1359μs 880.3476 KOps/s 851.5773 KOps/s $\color{#35bf28}+3.38\%$
test_creation[device0] 0.5861ms 0.3353ms 2.9824 KOps/s 3.0215 KOps/s $\color{#d91a1a}-1.30\%$
test_creation_from_tensor 2.3506ms 0.3761ms 2.6590 KOps/s 2.6028 KOps/s $\color{#35bf28}+2.16\%$
test_add_one[memmap_tensor0] 1.6077ms 30.6159μs 32.6627 KOps/s 32.9233 KOps/s $\color{#d91a1a}-0.79\%$
test_contiguous[memmap_tensor0] 32.0000μs 8.5419μs 117.0704 KOps/s 124.0116 KOps/s $\textbf{\color{#d91a1a}-5.60\%}$
test_stack[memmap_tensor0] 99.8010μs 25.4036μs 39.3644 KOps/s 40.3244 KOps/s $\color{#d91a1a}-2.38\%$
test_memmaptd_index 0.3306ms 0.2996ms 3.3382 KOps/s 3.3277 KOps/s $\color{#35bf28}+0.31\%$
test_memmaptd_index_astensor 1.3311ms 1.2462ms 802.4354 Ops/s 825.0703 Ops/s $\color{#d91a1a}-2.74\%$
test_memmaptd_index_op 2.6824ms 2.3972ms 417.1594 Ops/s 426.7196 Ops/s $\color{#d91a1a}-2.24\%$
test_reshape_pytree 97.5010μs 37.2768μs 26.8263 KOps/s 26.9943 KOps/s $\color{#d91a1a}-0.62\%$
test_reshape_td 80.5010μs 45.0255μs 22.2096 KOps/s 22.9611 KOps/s $\color{#d91a1a}-3.27\%$
test_view_pytree 81.1010μs 34.3484μs 29.1134 KOps/s 28.9407 KOps/s $\color{#35bf28}+0.60\%$
test_view_td 21.9010μs 8.5527μs 116.9218 KOps/s 113.4358 KOps/s $\color{#35bf28}+3.07\%$
test_unbind_pytree 93.9010μs 38.7783μs 25.7876 KOps/s 26.1889 KOps/s $\color{#d91a1a}-1.53\%$
test_unbind_td 0.1696ms 95.9546μs 10.4216 KOps/s 10.6831 KOps/s $\color{#d91a1a}-2.45\%$
test_split_pytree 75.4000μs 42.8856μs 23.3179 KOps/s 23.0864 KOps/s $\color{#35bf28}+1.00\%$
test_split_td 0.7825ms 0.1136ms 8.8063 KOps/s 8.8744 KOps/s $\color{#d91a1a}-0.77\%$
test_add_pytree 80.3000μs 44.8779μs 22.2827 KOps/s 21.9086 KOps/s $\color{#35bf28}+1.71\%$
test_add_td 0.1202ms 71.8087μs 13.9259 KOps/s 14.5245 KOps/s $\color{#d91a1a}-4.12\%$
test_distributed 31.8000μs 8.1432μs 122.8024 KOps/s 119.1901 KOps/s $\color{#35bf28}+3.03\%$
test_tdmodule 0.1820ms 26.1766μs 38.2021 KOps/s 38.7678 KOps/s $\color{#d91a1a}-1.46\%$
test_tdmodule_dispatch 0.2566ms 50.3726μs 19.8521 KOps/s 19.7219 KOps/s $\color{#35bf28}+0.66\%$
test_tdseq 0.4612ms 26.4047μs 37.8721 KOps/s 36.4417 KOps/s $\color{#35bf28}+3.93\%$
test_tdseq_dispatch 0.1722ms 54.1136μs 18.4796 KOps/s 18.2786 KOps/s $\color{#35bf28}+1.10\%$
test_instantiation_functorch 1.9532ms 1.5363ms 650.9201 Ops/s 652.9604 Ops/s $\color{#d91a1a}-0.31\%$
test_instantiation_td 1.9518ms 1.2582ms 794.7636 Ops/s 738.4640 Ops/s $\textbf{\color{#35bf28}+7.62\%}$
test_exec_functorch 0.3218ms 0.1793ms 5.5761 KOps/s 5.6800 KOps/s $\color{#d91a1a}-1.83\%$
test_exec_td 0.2002ms 0.1684ms 5.9388 KOps/s 5.9438 KOps/s $\color{#d91a1a}-0.09\%$
test_vmap_mlp_speed[True-True] 1.7938ms 1.0540ms 948.7583 Ops/s 947.1470 Ops/s $\color{#35bf28}+0.17\%$
test_vmap_mlp_speed[True-False] 5.3164ms 0.5207ms 1.9203 KOps/s 1.9085 KOps/s $\color{#35bf28}+0.62\%$
test_vmap_mlp_speed[False-True] 6.6471ms 0.9187ms 1.0885 KOps/s 1.1162 KOps/s $\color{#d91a1a}-2.47\%$
test_vmap_mlp_speed[False-False] 4.9565ms 0.4149ms 2.4102 KOps/s 2.4458 KOps/s $\color{#d91a1a}-1.45\%$
test_vmap_transformer_speed[True-True] 15.8794ms 12.4663ms 80.2166 Ops/s 80.0810 Ops/s $\color{#35bf28}+0.17\%$
test_vmap_transformer_speed[True-False] 13.0686ms 7.7094ms 129.7120 Ops/s 130.1117 Ops/s $\color{#d91a1a}-0.31\%$
test_vmap_transformer_speed[False-True] 22.1632ms 12.3564ms 80.9298 Ops/s 81.9336 Ops/s $\color{#d91a1a}-1.23\%$
test_vmap_transformer_speed[False-False] 12.4010ms 7.6057ms 131.4796 Ops/s 132.8867 Ops/s $\color{#d91a1a}-1.06\%$


def log_prob(self, sample: TensorDictBase):
d = {
_add_suffix(name, "_log_prob"): dist.log_prob(sample.get(name))
Copy link
Contributor

@matteobettini matteobettini Aug 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do not want these keys to be parametric?

Copy link
Contributor

@matteobettini matteobettini left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@vmoens vmoens merged commit 11b8a8c into main Sep 1, 2023
26 of 27 checks passed
@vmoens vmoens deleted the composite_dist branch September 1, 2023 08:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants