[Feature] torch.func.functional_call compatibility #526
Draft
vmoens wants to merge 8 commits into main from patch_functional_call
facebook-github-bot added the CLA Signed label Sep 13, 2023
(This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.)
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_plain_set_nested | 36.6000μs | 20.1319μs | 49.6723 KOps/s | 49.6262 KOps/s | |
test_plain_set_stack_nested | 0.2248ms | 0.1860ms | 5.3758 KOps/s | 5.3589 KOps/s | |
test_plain_set_nested_inplace | 0.2070ms | 23.7512μs | 42.1031 KOps/s | 42.1784 KOps/s | |
test_plain_set_stack_nested_inplace | 0.3177ms | 0.2227ms | 4.4896 KOps/s | 4.5046 KOps/s | |
test_items | 19.6010μs | 3.5576μs | 281.0898 KOps/s | 282.8163 KOps/s | |
test_items_nested | 2.1115ms | 0.3710ms | 2.6956 KOps/s | 2.7500 KOps/s | |
test_items_nested_locked | 0.4602ms | 0.3682ms | 2.7161 KOps/s | 2.7585 KOps/s | |
test_items_nested_leaf | 5.2283ms | 0.2303ms | 4.3431 KOps/s | 4.5226 KOps/s | |
test_items_stack_nested | 2.1070ms | 1.9909ms | 502.2932 Ops/s | 506.9492 Ops/s | |
test_items_stack_nested_leaf | 2.5761ms | 1.8127ms | 551.6657 Ops/s | 556.4046 Ops/s | |
test_items_stack_nested_locked | 2.0614ms | 1.0141ms | 986.1396 Ops/s | 1.0056 KOps/s | |
test_keys | 28.8000μs | 5.0823μs | 196.7602 KOps/s | 196.4588 KOps/s | |
test_keys_nested | 2.2534ms | 0.1818ms | 5.4996 KOps/s | 5.5013 KOps/s | |
test_keys_nested_locked | 0.2952ms | 0.1807ms | 5.5337 KOps/s | 5.5338 KOps/s | |
test_keys_nested_leaf | 0.3196ms | 0.1741ms | 5.7442 KOps/s | 5.3741 KOps/s | |
test_keys_stack_nested | 1.9641ms | 1.8284ms | 546.9374 Ops/s | 549.5832 Ops/s | |
test_keys_stack_nested_leaf | 1.9843ms | 1.8207ms | 549.2492 Ops/s | 548.1867 Ops/s | |
test_keys_stack_nested_locked | 0.9288ms | 0.8019ms | 1.2470 KOps/s | 1.1979 KOps/s | |
test_values | 23.9000μs | 1.5817μs | 632.2247 KOps/s | 636.3023 KOps/s | |
test_values_nested | 0.1408ms | 67.6041μs | 14.7920 KOps/s | 14.7986 KOps/s | |
test_values_nested_locked | 0.1509ms | 67.6046μs | 14.7919 KOps/s | 14.8450 KOps/s | |
test_values_nested_leaf | 0.2538ms | 59.4498μs | 16.8209 KOps/s | 16.9379 KOps/s | |
test_values_stack_nested | 1.7298ms | 1.5916ms | 628.2874 Ops/s | 624.6547 Ops/s | |
test_values_stack_nested_leaf | 1.7003ms | 1.5867ms | 630.2225 Ops/s | 622.3616 Ops/s | |
test_values_stack_nested_locked | 0.7704ms | 0.6353ms | 1.5741 KOps/s | 1.5482 KOps/s | |
test_membership | 23.0000μs | 1.9144μs | 522.3595 KOps/s | 530.7736 KOps/s | |
test_membership_nested | 21.6000μs | 3.6825μs | 271.5529 KOps/s | 266.4596 KOps/s | |
test_membership_nested_leaf | 67.8000μs | 3.6739μs | 272.1914 KOps/s | 268.3146 KOps/s | |
test_membership_stacked_nested | 40.3000μs | 14.3076μs | 69.8928 KOps/s | 69.3910 KOps/s | |
test_membership_stacked_nested_leaf | 69.9010μs | 14.3796μs | 69.5430 KOps/s | 69.1615 KOps/s | |
test_membership_nested_last | 35.4010μs | 7.5629μs | 132.2236 KOps/s | 131.1132 KOps/s | |
test_membership_nested_leaf_last | 85.1000μs | 7.5631μs | 132.2210 KOps/s | 133.0973 KOps/s | |
test_membership_stacked_nested_last | 0.2533ms | 0.2272ms | 4.4022 KOps/s | 4.4106 KOps/s | |
test_membership_stacked_nested_leaf_last | 88.5000μs | 16.7396μs | 59.7386 KOps/s | 59.2961 KOps/s | |
test_nested_getleaf | 84.1010μs | 15.6565μs | 63.8711 KOps/s | 64.0844 KOps/s | |
test_nested_get | 39.4010μs | 15.0614μs | 66.3949 KOps/s | 67.4414 KOps/s | |
test_stacked_getleaf | 1.0147ms | 0.8827ms | 1.1328 KOps/s | 1.1394 KOps/s | |
test_stacked_get | 0.9467ms | 0.8436ms | 1.1854 KOps/s | 1.1876 KOps/s | |
test_nested_getitemleaf | 74.7010μs | 15.6914μs | 63.7290 KOps/s | 64.0191 KOps/s | |
test_nested_getitem | 84.4000μs | 15.1988μs | 65.7948 KOps/s | 66.9395 KOps/s | |
test_stacked_getitemleaf | 1.0049ms | 0.8815ms | 1.1345 KOps/s | 1.1364 KOps/s | |
test_stacked_getitem | 0.9503ms | 0.8402ms | 1.1902 KOps/s | 1.1889 KOps/s | |
test_lock_nested | 70.3423ms | 1.5607ms | 640.7476 Ops/s | 690.8584 Ops/s | |
test_lock_stack_nested | 92.8937ms | 20.6021ms | 48.5388 Ops/s | 52.9614 Ops/s | |
test_unlock_nested | 72.1171ms | 1.5901ms | 628.9043 Ops/s | 643.9713 Ops/s | |
test_unlock_stack_nested | 94.7629ms | 21.1120ms | 47.3664 Ops/s | 51.9070 Ops/s | |
test_flatten_speed | 1.1341ms | 1.0387ms | 962.7431 Ops/s | 992.0428 Ops/s | |
test_unflatten_speed | 2.0052ms | 1.8551ms | 539.0414 Ops/s | 550.2088 Ops/s | |
test_common_ops | 4.9764ms | 1.1073ms | 903.0866 Ops/s | 912.9710 Ops/s | |
test_creation | 43.3000μs | 6.4060μs | 156.1043 KOps/s | 158.4443 KOps/s | |
test_creation_empty | 33.1010μs | 14.3194μs | 69.8351 KOps/s | 72.4853 KOps/s | |
test_creation_nested_1 | 0.1133ms | 25.7018μs | 38.9078 KOps/s | 40.3493 KOps/s | |
test_creation_nested_2 | 56.4010μs | 27.9063μs | 35.8342 KOps/s | 37.0007 KOps/s | |
test_clone | 0.1500ms | 24.4826μs | 40.8454 KOps/s | 40.8449 KOps/s | |
test_getitem[int] | 95.8010μs | 28.4730μs | 35.1210 KOps/s | 35.1962 KOps/s | |
test_getitem[slice_int] | 89.4010μs | 56.5025μs | 17.6983 KOps/s | 18.2454 KOps/s | |
test_getitem[range] | 0.1585ms | 82.7485μs | 12.0848 KOps/s | 12.2953 KOps/s | |
test_getitem[tuple] | 0.1289ms | 46.9888μs | 21.2817 KOps/s | 22.0473 KOps/s | |
test_getitem[list] | 0.3231ms | 78.2674μs | 12.7767 KOps/s | 13.1838 KOps/s | |
test_setitem_dim[int] | 52.1000μs | 32.8513μs | 30.4402 KOps/s | 30.4121 KOps/s | |
test_setitem_dim[slice_int] | 95.4000μs | 59.5256μs | 16.7995 KOps/s | 17.0605 KOps/s | |
test_setitem_dim[range] | 0.1806ms | 81.3022μs | 12.2998 KOps/s | 12.7291 KOps/s | |
test_setitem_dim[tuple] | 68.6010μs | 49.1259μs | 20.3559 KOps/s | 20.6272 KOps/s | |
test_setitem | 0.1872ms | 32.8443μs | 30.4467 KOps/s | 30.7730 KOps/s | |
test_set | 0.1658ms | 31.9336μs | 31.3150 KOps/s | 32.0812 KOps/s | |
test_set_shared | 0.3701ms | 0.1763ms | 5.6717 KOps/s | 5.7306 KOps/s | |
test_update | 0.1895ms | 35.8138μs | 27.9222 KOps/s | 28.0489 KOps/s | |
test_update_nested | 0.2073ms | 53.4312μs | 18.7156 KOps/s | 19.1434 KOps/s | |
test_set_nested | 0.2136ms | 35.5812μs | 28.1047 KOps/s | 28.8723 KOps/s | |
test_set_nested_new | 0.2091ms | 53.5749μs | 18.6654 KOps/s | 18.7271 KOps/s | |
test_select | 0.2634ms | 97.2734μs | 10.2803 KOps/s | 10.2777 KOps/s | |
test_unbind_speed | 0.7404ms | 0.6459ms | 1.5482 KOps/s | 1.5445 KOps/s | |
test_unbind_speed_stack0 | 7.2364ms | 7.0178ms | 142.4945 Ops/s | 112.0550 Ops/s | |
test_unbind_speed_stack1 | 9.4667μs | 0.9427μs | 1.0607 MOps/s | 871.3604 KOps/s | |
test_creation[device0] | 0.5448ms | 0.4428ms | 2.2584 KOps/s | 2.2232 KOps/s | |
test_creation_from_tensor | 3.7764ms | 0.4954ms | 2.0185 KOps/s | 1.9995 KOps/s | |
test_add_one[memmap_tensor0] | 1.7058ms | 32.6880μs | 30.5923 KOps/s | 30.1323 KOps/s | |
test_contiguous[memmap_tensor0] | 38.7000μs | 8.7294μs | 114.5557 KOps/s | 111.2680 KOps/s | |
test_stack[memmap_tensor0] | 74.2010μs | 26.9615μs | 37.0899 KOps/s | 36.9764 KOps/s | |
test_memmaptd_index | 0.4179ms | 0.3209ms | 3.1158 KOps/s | 3.1464 KOps/s | |
test_memmaptd_index_astensor | 1.5184ms | 1.3754ms | 727.0707 Ops/s | 724.9918 Ops/s | |
test_memmaptd_index_op | 2.8106ms | 2.6604ms | 375.8810 Ops/s | 377.7502 Ops/s | |
test_reshape_pytree | 0.1028ms | 38.0584μs | 26.2754 KOps/s | 26.4072 KOps/s | |
test_reshape_td | 0.1390ms | 46.6509μs | 21.4358 KOps/s | 22.0892 KOps/s | |
test_view_pytree | 96.8000μs | 35.4726μs | 28.1908 KOps/s | 28.3203 KOps/s | |
test_view_td | 40.2010μs | 8.8874μs | 112.5184 KOps/s | 112.9799 KOps/s | |
test_unbind_pytree | 92.2010μs | 38.5714μs | 25.9259 KOps/s | 25.2648 KOps/s | |
test_unbind_td | 0.1769ms | 95.4387μs | 10.4779 KOps/s | 10.3660 KOps/s | |
test_split_pytree | 0.1085ms | 39.7945μs | 25.1291 KOps/s | 22.0737 KOps/s | |
test_split_td | 0.8899ms | 0.1089ms | 9.1821 KOps/s | 8.7008 KOps/s | |
test_add_pytree | 96.5000μs | 47.4330μs | 21.0824 KOps/s | 20.7851 KOps/s | |
test_add_td | 0.1951ms | 77.0191μs | 12.9838 KOps/s | 13.2048 KOps/s | |
test_distributed | 34.4010μs | 8.9963μs | 111.1562 KOps/s | 108.8611 KOps/s | |
test_tdmodule | 0.2004ms | 28.7771μs | 34.7498 KOps/s | 34.4102 KOps/s | |
test_tdmodule_dispatch | 0.2746ms | 55.8500μs | 17.9051 KOps/s | 18.1104 KOps/s | |
test_tdseq | 0.5336ms | 32.6006μs | 30.6742 KOps/s | 30.9244 KOps/s | |
test_tdseq_dispatch | 0.2107ms | 67.3290μs | 14.8524 KOps/s | 15.2107 KOps/s | |
test_instantiation_functorch | 1.7436ms | 1.6196ms | 617.4208 Ops/s | 610.1817 Ops/s | |
test_instantiation_td | 2.0428ms | 1.3482ms | 741.7113 Ops/s | 669.4873 Ops/s | |
test_exec_functorch | 0.2320ms | 0.1891ms | 5.2889 KOps/s | 5.3180 KOps/s | |
test_exec_td | 0.3168ms | 0.1814ms | 5.5139 KOps/s | 5.5206 KOps/s | |
test_vmap_mlp_speed[True-True] | 10.4489ms | 1.2150ms | 823.0291 Ops/s | 841.6600 Ops/s | |
test_vmap_mlp_speed[True-False] | 3.5011ms | 0.6210ms | 1.6103 KOps/s | 1.6085 KOps/s | |
test_vmap_mlp_speed[False-True] | 7.6494ms | 1.0180ms | 982.3164 Ops/s | 993.3317 Ops/s | |
test_vmap_mlp_speed[False-False] | 7.2815ms | 0.4669ms | 2.1416 KOps/s | 2.1228 KOps/s | |
test_vmap_transformer_speed[True-True] | 16.7915ms | 13.7518ms | 72.7178 Ops/s | 69.5630 Ops/s | |
test_vmap_transformer_speed[True-False] | 12.7389ms | 8.8963ms | 112.4066 Ops/s | 106.8358 Ops/s | |
test_vmap_transformer_speed[False-True] | 20.2985ms | 14.0073ms | 71.3913 Ops/s | 72.2573 Ops/s | |
test_vmap_transformer_speed[False-False] | 14.1718ms | 8.5581ms | 116.8477 Ops/s | 114.0857 Ops/s |
Labels
CLA Signed
enhancement
New feature or request
Description
Depends on #525.
It's a dirty monkey patch (see this comment). It will be cleaner in the tensordict -> core PR.
Unfortunately, functional calls don't work with torch.func.dim.Tensor, so our custom test doesn't currently work. We can still make functional calls with batched tensors through the tensordict functional API, but the goal of this PR is to slowly prepare the deprecation of that feature.
Is there a roadmap to have that functionality @zou3519? Maybe accept torch.func.dim.Tensor as input for functional calls?
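For readers unfamiliar with the API this PR targets, here is a minimal sketch of a plain `torch.func.functional_call` on a stock `nn.Linear`, assuming PyTorch 2.0 or later. It uses an ordinary dict of tensors rather than a TensorDict, and it does not reproduce the patch in this PR; the module, shapes, and variable names are illustrative only.

```python
import torch
from torch import nn
from torch.func import functional_call

# Illustrative module; not taken from the PR.
module = nn.Linear(3, 4)

# Replacement parameters, keyed by the names from module.named_parameters().
# All-zero weights and bias make the result easy to check.
params = {name: torch.zeros_like(p) for name, p in module.named_parameters()}

x = torch.randn(2, 3)

# Run the forward pass with `params` substituted for the module's own
# parameters. The parameters stored on the module are left untouched.
out = functional_call(module, params, (x,))
```

With zeroed weights and bias, `out` is an all-zero tensor of shape `(2, 4)`, while `module.weight` keeps its original initialization.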