-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] CompositeDistribution #517
Conversation
Note: we pass a structure
|
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_plain_set_nested | 39.8010μs | 20.2503μs | 49.3820 KOps/s | 49.5431 KOps/s | |
test_plain_set_stack_nested | 0.2083ms | 0.1840ms | 5.4347 KOps/s | 5.4026 KOps/s | |
test_plain_set_nested_inplace | 49.4000μs | 23.4199μs | 42.6987 KOps/s | 42.4369 KOps/s | |
test_plain_set_stack_nested_inplace | 0.2678ms | 0.2163ms | 4.6242 KOps/s | 4.5545 KOps/s | |
test_items | 20.1000μs | 3.4605μs | 288.9763 KOps/s | 323.7241 KOps/s | |
test_items_nested | 1.8761ms | 0.3773ms | 2.6505 KOps/s | 2.7196 KOps/s | |
test_items_nested_locked | 0.4472ms | 0.3798ms | 2.6328 KOps/s | 2.7557 KOps/s | |
test_items_nested_leaf | 1.0086ms | 0.2220ms | 4.5036 KOps/s | 4.5058 KOps/s | |
test_items_stack_nested | 2.0158ms | 1.9289ms | 518.4208 Ops/s | 522.3239 Ops/s | |
test_items_stack_nested_leaf | 1.8009ms | 1.7486ms | 571.8924 Ops/s | 571.1794 Ops/s | |
test_items_stack_nested_locked | 1.8997ms | 0.9604ms | 1.0413 KOps/s | 1.0410 KOps/s | |
test_keys | 27.3000μs | 4.5401μs | 220.2617 KOps/s | 220.6494 KOps/s | |
test_keys_nested | 0.7804ms | 0.1727ms | 5.7899 KOps/s | 5.7643 KOps/s | |
test_keys_nested_locked | 0.1948ms | 0.1701ms | 5.8800 KOps/s | 5.8351 KOps/s | |
test_keys_nested_leaf | 0.3174ms | 0.1679ms | 5.9546 KOps/s | 5.5358 KOps/s | |
test_keys_stack_nested | 1.8273ms | 1.7584ms | 568.7009 Ops/s | 569.9005 Ops/s | |
test_keys_stack_nested_leaf | 1.8196ms | 1.7556ms | 569.6102 Ops/s | 552.0341 Ops/s | |
test_keys_stack_nested_locked | 0.8595ms | 0.7853ms | 1.2733 KOps/s | 1.2828 KOps/s | |
test_values | 27.7000μs | 1.3494μs | 741.0450 KOps/s | 856.2964 KOps/s | |
test_values_nested | 87.3000μs | 64.6039μs | 15.4790 KOps/s | 14.7073 KOps/s | |
test_values_nested_locked | 89.8010μs | 64.5158μs | 15.5001 KOps/s | 14.8053 KOps/s | |
test_values_nested_leaf | 99.3010μs | 56.7072μs | 17.6345 KOps/s | 16.7389 KOps/s | |
test_values_stack_nested | 1.5880ms | 1.5314ms | 653.0030 Ops/s | 653.6041 Ops/s | |
test_values_stack_nested_leaf | 1.5992ms | 1.5182ms | 658.6799 Ops/s | 655.0439 Ops/s | |
test_values_stack_nested_locked | 0.6689ms | 0.6304ms | 1.5864 KOps/s | 1.6028 KOps/s | |
test_membership | 24.1000μs | 1.8716μs | 534.3064 KOps/s | 525.1446 KOps/s | |
test_membership_nested | 29.0000μs | 3.6604μs | 273.1906 KOps/s | 272.6447 KOps/s | |
test_membership_nested_leaf | 30.3010μs | 3.6849μs | 271.3807 KOps/s | 272.2172 KOps/s | |
test_membership_stacked_nested | 41.4000μs | 14.4860μs | 69.0322 KOps/s | 68.3295 KOps/s | |
test_membership_stacked_nested_leaf | 38.8000μs | 14.5396μs | 68.7777 KOps/s | 68.5738 KOps/s | |
test_membership_nested_last | 66.2010μs | 7.6301μs | 131.0593 KOps/s | 132.7090 KOps/s | |
test_membership_nested_leaf_last | 32.6010μs | 7.7291μs | 129.3816 KOps/s | 132.1923 KOps/s | |
test_membership_stacked_nested_last | 0.2463ms | 0.2206ms | 4.5323 KOps/s | 4.4722 KOps/s | |
test_membership_stacked_nested_leaf_last | 56.4010μs | 17.0579μs | 58.6238 KOps/s | 57.4849 KOps/s | |
test_nested_getleaf | 39.4000μs | 15.3839μs | 65.0032 KOps/s | 64.3176 KOps/s | |
test_nested_get | 36.2010μs | 14.5558μs | 68.7011 KOps/s | 68.3581 KOps/s | |
test_stacked_getleaf | 0.9273ms | 0.8307ms | 1.2038 KOps/s | 1.2126 KOps/s | |
test_stacked_get | 0.8400ms | 0.7915ms | 1.2634 KOps/s | 1.2600 KOps/s | |
test_nested_getitemleaf | 59.1010μs | 15.3355μs | 65.2082 KOps/s | 63.9584 KOps/s | |
test_nested_getitem | 59.7010μs | 14.5745μs | 68.6128 KOps/s | 68.1363 KOps/s | |
test_stacked_getitemleaf | 0.9363ms | 0.8245ms | 1.2128 KOps/s | 1.2003 KOps/s | |
test_stacked_getitem | 0.8249ms | 0.7902ms | 1.2655 KOps/s | 1.2577 KOps/s | |
test_lock_nested | 58.6446ms | 1.4257ms | 701.4280 Ops/s | 745.1056 Ops/s | |
test_lock_stack_nested | 77.1524ms | 18.2208ms | 54.8824 Ops/s | 58.6493 Ops/s | |
test_unlock_nested | 60.0001ms | 1.4306ms | 698.9883 Ops/s | 705.4364 Ops/s | |
test_unlock_stack_nested | 78.4839ms | 18.7180ms | 53.4246 Ops/s | 57.2617 Ops/s | |
test_flatten_speed | 1.0395ms | 1.0019ms | 998.1316 Ops/s | 1.0181 KOps/s | |
test_unflatten_speed | 1.7843ms | 1.7489ms | 571.7991 Ops/s | 576.3803 Ops/s | |
test_common_ops | 1.2188ms | 1.0183ms | 982.0044 Ops/s | 991.8204 Ops/s | |
test_creation | 34.5010μs | 6.1087μs | 163.7017 KOps/s | 164.0742 KOps/s | |
test_creation_empty | 28.2000μs | 13.4715μs | 74.2307 KOps/s | 75.0640 KOps/s | |
test_creation_nested_1 | 48.4010μs | 23.1575μs | 43.1825 KOps/s | 43.6259 KOps/s | |
test_creation_nested_2 | 50.0000μs | 26.1107μs | 38.2984 KOps/s | 38.8446 KOps/s | |
test_clone | 0.1353ms | 24.6129μs | 40.6292 KOps/s | 40.9783 KOps/s | |
test_getitem[int] | 68.4010μs | 27.0798μs | 36.9278 KOps/s | 38.1420 KOps/s | |
test_getitem[slice_int] | 0.1035ms | 50.8922μs | 19.6494 KOps/s | 20.3932 KOps/s | |
test_getitem[range] | 0.1274ms | 77.4016μs | 12.9196 KOps/s | 13.1579 KOps/s | |
test_getitem[tuple] | 69.4000μs | 41.0449μs | 24.3635 KOps/s | 24.9299 KOps/s | |
test_getitem[list] | 0.3286ms | 72.5514μs | 13.7833 KOps/s | 14.0271 KOps/s | |
test_setitem_dim[int] | 55.4010μs | 31.3464μs | 31.9016 KOps/s | 31.5166 KOps/s | |
test_setitem_dim[slice_int] | 72.8010μs | 55.7467μs | 17.9383 KOps/s | 17.7620 KOps/s | |
test_setitem_dim[range] | 0.1127ms | 76.4307μs | 13.0837 KOps/s | 13.0108 KOps/s | |
test_setitem_dim[tuple] | 62.0010μs | 46.8501μs | 21.3447 KOps/s | 21.1462 KOps/s | |
test_setitem | 0.1343ms | 30.1375μs | 33.1812 KOps/s | 33.7225 KOps/s | |
test_set | 0.1327ms | 29.3817μs | 34.0348 KOps/s | 34.2969 KOps/s | |
test_set_shared | 3.9862ms | 0.1516ms | 6.5944 KOps/s | 6.5692 KOps/s | |
test_update | 0.2062ms | 33.4112μs | 29.9301 KOps/s | 30.4614 KOps/s | |
test_update_nested | 0.1818ms | 50.7340μs | 19.7106 KOps/s | 19.6309 KOps/s | |
test_set_nested | 0.1420ms | 32.2907μs | 30.9686 KOps/s | 32.6212 KOps/s | |
test_set_nested_new | 0.1640ms | 50.3807μs | 19.8489 KOps/s | 20.0957 KOps/s | |
test_select | 0.2084ms | 95.4921μs | 10.4721 KOps/s | 10.4275 KOps/s | |
test_unbind_speed | 0.6754ms | 0.6445ms | 1.5516 KOps/s | 1.5702 KOps/s | |
test_unbind_speed_stack0 | 65.0108ms | 8.1771ms | 122.2932 Ops/s | 121.7479 Ops/s | |
test_unbind_speed_stack1 | 20.1000μs | 1.1359μs | 880.3476 KOps/s | 851.5773 KOps/s | |
test_creation[device0] | 0.5861ms | 0.3353ms | 2.9824 KOps/s | 3.0215 KOps/s | |
test_creation_from_tensor | 2.3506ms | 0.3761ms | 2.6590 KOps/s | 2.6028 KOps/s | |
test_add_one[memmap_tensor0] | 1.6077ms | 30.6159μs | 32.6627 KOps/s | 32.9233 KOps/s | |
test_contiguous[memmap_tensor0] | 32.0000μs | 8.5419μs | 117.0704 KOps/s | 124.0116 KOps/s | |
test_stack[memmap_tensor0] | 99.8010μs | 25.4036μs | 39.3644 KOps/s | 40.3244 KOps/s | |
test_memmaptd_index | 0.3306ms | 0.2996ms | 3.3382 KOps/s | 3.3277 KOps/s | |
test_memmaptd_index_astensor | 1.3311ms | 1.2462ms | 802.4354 Ops/s | 825.0703 Ops/s | |
test_memmaptd_index_op | 2.6824ms | 2.3972ms | 417.1594 Ops/s | 426.7196 Ops/s | |
test_reshape_pytree | 97.5010μs | 37.2768μs | 26.8263 KOps/s | 26.9943 KOps/s | |
test_reshape_td | 80.5010μs | 45.0255μs | 22.2096 KOps/s | 22.9611 KOps/s | |
test_view_pytree | 81.1010μs | 34.3484μs | 29.1134 KOps/s | 28.9407 KOps/s | |
test_view_td | 21.9010μs | 8.5527μs | 116.9218 KOps/s | 113.4358 KOps/s | |
test_unbind_pytree | 93.9010μs | 38.7783μs | 25.7876 KOps/s | 26.1889 KOps/s | |
test_unbind_td | 0.1696ms | 95.9546μs | 10.4216 KOps/s | 10.6831 KOps/s | |
test_split_pytree | 75.4000μs | 42.8856μs | 23.3179 KOps/s | 23.0864 KOps/s | |
test_split_td | 0.7825ms | 0.1136ms | 8.8063 KOps/s | 8.8744 KOps/s | |
test_add_pytree | 80.3000μs | 44.8779μs | 22.2827 KOps/s | 21.9086 KOps/s | |
test_add_td | 0.1202ms | 71.8087μs | 13.9259 KOps/s | 14.5245 KOps/s | |
test_distributed | 31.8000μs | 8.1432μs | 122.8024 KOps/s | 119.1901 KOps/s | |
test_tdmodule | 0.1820ms | 26.1766μs | 38.2021 KOps/s | 38.7678 KOps/s | |
test_tdmodule_dispatch | 0.2566ms | 50.3726μs | 19.8521 KOps/s | 19.7219 KOps/s | |
test_tdseq | 0.4612ms | 26.4047μs | 37.8721 KOps/s | 36.4417 KOps/s | |
test_tdseq_dispatch | 0.1722ms | 54.1136μs | 18.4796 KOps/s | 18.2786 KOps/s | |
test_instantiation_functorch | 1.9532ms | 1.5363ms | 650.9201 Ops/s | 652.9604 Ops/s | |
test_instantiation_td | 1.9518ms | 1.2582ms | 794.7636 Ops/s | 738.4640 Ops/s | |
test_exec_functorch | 0.3218ms | 0.1793ms | 5.5761 KOps/s | 5.6800 KOps/s | |
test_exec_td | 0.2002ms | 0.1684ms | 5.9388 KOps/s | 5.9438 KOps/s | |
test_vmap_mlp_speed[True-True] | 1.7938ms | 1.0540ms | 948.7583 Ops/s | 947.1470 Ops/s | |
test_vmap_mlp_speed[True-False] | 5.3164ms | 0.5207ms | 1.9203 KOps/s | 1.9085 KOps/s | |
test_vmap_mlp_speed[False-True] | 6.6471ms | 0.9187ms | 1.0885 KOps/s | 1.1162 KOps/s | |
test_vmap_mlp_speed[False-False] | 4.9565ms | 0.4149ms | 2.4102 KOps/s | 2.4458 KOps/s | |
test_vmap_transformer_speed[True-True] | 15.8794ms | 12.4663ms | 80.2166 Ops/s | 80.0810 Ops/s | |
test_vmap_transformer_speed[True-False] | 13.0686ms | 7.7094ms | 129.7120 Ops/s | 130.1117 Ops/s | |
test_vmap_transformer_speed[False-True] | 22.1632ms | 12.3564ms | 80.9298 Ops/s | 81.9336 Ops/s | |
test_vmap_transformer_speed[False-False] | 12.4010ms | 7.6057ms | 131.4796 Ops/s | 132.8867 Ops/s |
|
||
def log_prob(self, sample: TensorDictBase): | ||
d = { | ||
_add_suffix(name, "_log_prob"): dist.log_prob(sample.get(name)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we do not want these keys to be parametric?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
# Conflicts: # test/test_nn.py
Description
Introduces CompositeDistribution, which allows to build bags of distributions.
Addresses pytorch/rl#1473
cc @hersh