Skip to content

Conversation

dstaay-fb
Copy link
Contributor

Summary:
Currently, we have KT.regroup as a functional call. Issue with this two fold:
(1) we don't caching values we effectively know after first batch, leading to marginally higher cpu computation
(2) this values look like unbacked SymInt in PT2 IR and most graph captures. Reality is they are known.

So while a user change, we are adding a new module, to leverage these above insights.

Benchmark (fwd+backward)
[fallback] _regroup_keyed_tenors | B: 512 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 72.0
[prod] KeyedTensor.regroup | B: 512 | F: 80 | device: cuda | Runtime (P90): 2.8 ms | Memory (P90): 72.0
[prod] KTRegroupAsDict | B: 512 | F: 80 | device: cuda | Runtime (P90): 2.3 ms | Memory (P90): 72.0
[fallback] _regroup_keyed_tenors | B: 512 | F: 160 | device: cuda | Runtime (P90): 7.7 ms | Memory (P90): 144.0
[prod] KeyedTensor.regroup | B: 512 | F: 160 | device: cuda | Runtime (P90): 4.6 ms | Memory (P90): 144.0
[prod] KTRegroupAsDict | B: 512 | F: 160 | device: cuda | Runtime (P90): 3.9 ms | Memory (P90): 144.0
[fallback] _regroup_keyed_tenors | B: 512 | F: 320 | device: cuda | Runtime (P90): 10.8 ms | Memory (P90): 288.0
[prod] KeyedTensor.regroup | B: 512 | F: 320 | device: cuda | Runtime (P90): 7.5 ms | Memory (P90): 288.0
[prod] KTRegroupAsDict | B: 512 | F: 320 | device: cuda | Runtime (P90): 9.9 ms | Memory (P90): 288.0
[fallback] _regroup_keyed_tenors | B: 512 | F: 640 | device: cuda | Runtime (P90): 22.7 ms | Memory (P90): 576.0
[prod] KeyedTensor.regroup | B: 512 | F: 640 | device: cuda | Runtime (P90): 13.8 ms | Memory (P90): 576.0
[prod] KTRegroupAsDict | B: 512 | F: 640 | device: cuda | Runtime (P90): 18.6 ms | Memory (P90): 576.0
[fallback] _regroup_keyed_tenors | B: 512 | F: 1280 | device: cuda | Runtime (P90): 58.0 ms | Memory (P90): 1152.0
[prod] KeyedTensor.regroup | B: 512 | F: 1280 | device: cuda | Runtime (P90): 27.9 ms | Memory (P90): 1152.0
[prod] KTRegroupAsDict | B: 512 | F: 1280 | device: cuda | Runtime (P90): 25.7 ms | Memory (P90): 1152.0
[fallback] _regroup_keyed_tenors | B: 1024 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 144.0
[prod] KeyedTensor.regroup | B: 1024 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 144.0
[prod] KTRegroupAsDict | B: 1024 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 144.0
[fallback] _regroup_keyed_tenors | B: 1024 | F: 160 | device: cuda | Runtime (P90): 6.6 ms | Memory (P90): 288.0
[prod] KeyedTensor.regroup | B: 1024 | F: 160 | device: cuda | Runtime (P90): 6.4 ms | Memory (P90): 288.0
[prod] KTRegroupAsDict | B: 1024 | F: 160 | device: cuda | Runtime (P90): 4.1 ms | Memory (P90): 288.0
[fallback] _regroup_keyed_tenors | B: 1024 | F: 320 | device: cuda | Runtime (P90): 15.0 ms | Memory (P90): 576.0
[prod] KeyedTensor.regroup | B: 1024 | F: 320 | device: cuda | Runtime (P90): 8.0 ms | Memory (P90): 576.0
[prod] KTRegroupAsDict | B: 1024 | F: 320 | device: cuda | Runtime (P90): 8.0 ms | Memory (P90): 576.0
[fallback] _regroup_keyed_tenors | B: 1024 | F: 640 | device: cuda | Runtime (P90): 23.6 ms | Memory (P90): 1152.0
[prod] KeyedTensor.regroup | B: 1024 | F: 640 | device: cuda | Runtime (P90): 19.3 ms | Memory (P90): 1152.0
[prod] KTRegroupAsDict | B: 1024 | F: 640 | device: cuda | Runtime (P90): 13.6 ms | Memory (P90): 1152.0
[fallback] _regroup_keyed_tenors | B: 1024 | F: 1280 | device: cuda | Runtime (P90): 55.7 ms | Memory (P90): 2304.0
[prod] KeyedTensor.regroup | B: 1024 | F: 1280 | device: cuda | Runtime (P90): 28.4 ms | Memory (P90): 2304.0
[prod] KTRegroupAsDict | B: 1024 | F: 1280 | device: cuda | Runtime (P90): 26.8 ms | Memory (P90): 2304.0
[fallback] _regroup_keyed_tenors | B: 2048 | F: 80 | device: cuda | Runtime (P90): 3.6 ms | Memory (P90): 288.0
[prod] KeyedTensor.regroup | B: 2048 | F: 80 | device: cuda | Runtime (P90): 3.5 ms | Memory (P90): 288.0
[prod] KTRegroupAsDict | B: 2048 | F: 80 | device: cuda | Runtime (P90): 3.6 ms | Memory (P90): 288.0
[fallback] _regroup_keyed_tenors | B: 2048 | F: 160 | device: cuda | Runtime (P90): 7.0 ms | Memory (P90): 576.0
[prod] KeyedTensor.regroup | B: 2048 | F: 160 | device: cuda | Runtime (P90): 6.4 ms | Memory (P90): 576.0
[prod] KTRegroupAsDict | B: 2048 | F: 160 | device: cuda | Runtime (P90): 4.6 ms | Memory (P90): 576.0
[fallback] _regroup_keyed_tenors | B: 2048 | F: 320 | device: cuda | Runtime (P90): 11.2 ms | Memory (P90): 1152.0
[prod] KeyedTensor.regroup | B: 2048 | F: 320 | device: cuda | Runtime (P90): 8.2 ms | Memory (P90): 1152.0
[prod] KTRegroupAsDict | B: 2048 | F: 320 | device: cuda | Runtime (P90): 8.8 ms | Memory (P90): 1152.0
[fallback] _regroup_keyed_tenors | B: 2048 | F: 640 | device: cuda | Runtime (P90): 23.9 ms | Memory (P90): 2304.0
[prod] KeyedTensor.regroup | B: 2048 | F: 640 | device: cuda | Runtime (P90): 20.6 ms | Memory (P90): 2304.0
[prod] KTRegroupAsDict | B: 2048 | F: 640 | device: cuda | Runtime (P90): 14.6 ms | Memory (P90): 2304.0
[fallback] _regroup_keyed_tenors | B: 2048 | F: 1280 | device: cuda | Runtime (P90): 54.5 ms | Memory (P90): 4608.0
[prod] KeyedTensor.regroup | B: 2048 | F: 1280 | device: cuda | Runtime (P90): 28.3 ms | Memory (P90): 4608.0
[prod] KTRegroupAsDict | B: 2048 | F: 1280 | device: cuda | Runtime (P90): 25.7 ms | Memory (P90): 4608.0
[fallback] _regroup_keyed_tenors | B: 4096 | F: 80 | device: cuda | Runtime (P90): 3.3 ms | Memory (P90): 576.0
[prod] KeyedTensor.regroup | B: 4096 | F: 80 | device: cuda | Runtime (P90): 2.7 ms | Memory (P90): 576.0
[prod] KTRegroupAsDict | B: 4096 | F: 80 | device: cuda | Runtime (P90): 2.3 ms | Memory (P90): 576.0
[fallback] _regroup_keyed_tenors | B: 4096 | F: 160 | device: cuda | Runtime (P90): 5.8 ms | Memory (P90): 1152.0
[prod] KeyedTensor.regroup | B: 4096 | F: 160 | device: cuda | Runtime (P90): 4.4 ms | Memory (P90): 1152.0
[prod] KTRegroupAsDict | B: 4096 | F: 160 | device: cuda | Runtime (P90): 3.9 ms | Memory (P90): 1152.0
[fallback] _regroup_keyed_tenors | B: 4096 | F: 320 | device: cuda | Runtime (P90): 11.1 ms | Memory (P90): 2304.0
[prod] KeyedTensor.regroup | B: 4096 | F: 320 | device: cuda | Runtime (P90): 7.8 ms | Memory (P90): 2304.0
[prod] KTRegroupAsDict | B: 4096 | F: 320 | device: cuda | Runtime (P90): 7.0 ms | Memory (P90): 2304.0
[fallback] _regroup_keyed_tenors | B: 4096 | F: 640 | device: cuda | Runtime (P90): 23.9 ms | Memory (P90): 4608.0
[prod] KeyedTensor.regroup | B: 4096 | F: 640 | device: cuda | Runtime (P90): 14.5 ms | Memory (P90): 4608.0
[prod] KTRegroupAsDict | B: 4096 | F: 640 | device: cuda | Runtime (P90): 13.3 ms | Memory (P90): 4608.0
[fallback] _regroup_keyed_tenors | B: 4096 | F: 1280 | device: cuda | Runtime (P90): 64.0 ms | Memory (P90): 9216.0
[prod] KeyedTensor.regroup | B: 4096 | F: 1280 | device: cuda | Runtime (P90): 26.9 ms | Memory (P90): 9216.0
[prod] KTRegroupAsDict | B: 4096 | F: 1280 | device: cuda | Runtime (P90): 25.1 ms | Memory (P90): 9216.0

Reviewed By: PaulZhang12

Differential Revision: D57312926

Summary:
Currently, we have KT.regroup as a functional call.  Issue with this two fold:
(1) we don't caching values we effectively know after first batch, leading to marginally higher cpu computation
(2) this values look like unbacked SymInt in PT2 IR and most graph captures.   Reality is they are known.

So while a user change, we are adding a new module, to leverage these above insights.

Benchmark (fwd+backward)
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 80       | device: cuda     | Runtime (P90):   3.3 ms | Memory (P90):  72.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 80       | device: cuda     | Runtime (P90):   2.8 ms | Memory (P90):  72.0
  [prod] KTRegroupAsDict              | B: 512      | F: 80       | device: cuda     | Runtime (P90):   2.3 ms | Memory (P90):  72.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 160      | device: cuda     | Runtime (P90):   7.7 ms | Memory (P90): 144.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 160      | device: cuda     | Runtime (P90):   4.6 ms | Memory (P90): 144.0
  [prod] KTRegroupAsDict              | B: 512      | F: 160      | device: cuda     | Runtime (P90):   3.9 ms | Memory (P90): 144.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 320      | device: cuda     | Runtime (P90):  10.8 ms | Memory (P90): 288.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 320      | device: cuda     | Runtime (P90):   7.5 ms | Memory (P90): 288.0
  [prod] KTRegroupAsDict              | B: 512      | F: 320      | device: cuda     | Runtime (P90):   9.9 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 640      | device: cuda     | Runtime (P90):  22.7 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 640      | device: cuda     | Runtime (P90):  13.8 ms | Memory (P90): 576.0
  [prod] KTRegroupAsDict              | B: 512      | F: 640      | device: cuda     | Runtime (P90):  18.6 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 512      | F: 1280     | device: cuda     | Runtime (P90):  58.0 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 512      | F: 1280     | device: cuda     | Runtime (P90):  27.9 ms | Memory (P90): 1152.0
  [prod] KTRegroupAsDict              | B: 512      | F: 1280     | device: cuda     | Runtime (P90):  25.7 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 80       | device: cuda     | Runtime (P90):   3.3 ms | Memory (P90): 144.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 80       | device: cuda     | Runtime (P90):   3.3 ms | Memory (P90): 144.0
  [prod] KTRegroupAsDict              | B: 1024     | F: 80       | device: cuda     | Runtime (P90):   3.3 ms | Memory (P90): 144.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 160      | device: cuda     | Runtime (P90):   6.6 ms | Memory (P90): 288.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 160      | device: cuda     | Runtime (P90):   6.4 ms | Memory (P90): 288.0
  [prod] KTRegroupAsDict              | B: 1024     | F: 160      | device: cuda     | Runtime (P90):   4.1 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 320      | device: cuda     | Runtime (P90):  15.0 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 320      | device: cuda     | Runtime (P90):   8.0 ms | Memory (P90): 576.0
  [prod] KTRegroupAsDict              | B: 1024     | F: 320      | device: cuda     | Runtime (P90):   8.0 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 640      | device: cuda     | Runtime (P90):  23.6 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 640      | device: cuda     | Runtime (P90):  19.3 ms | Memory (P90): 1152.0
  [prod] KTRegroupAsDict              | B: 1024     | F: 640      | device: cuda     | Runtime (P90):  13.6 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 1024     | F: 1280     | device: cuda     | Runtime (P90):  55.7 ms | Memory (P90): 2304.0
  [prod] KeyedTensor.regroup          | B: 1024     | F: 1280     | device: cuda     | Runtime (P90):  28.4 ms | Memory (P90): 2304.0
  [prod] KTRegroupAsDict              | B: 1024     | F: 1280     | device: cuda     | Runtime (P90):  26.8 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 80       | device: cuda     | Runtime (P90):   3.6 ms | Memory (P90): 288.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 80       | device: cuda     | Runtime (P90):   3.5 ms | Memory (P90): 288.0
  [prod] KTRegroupAsDict              | B: 2048     | F: 80       | device: cuda     | Runtime (P90):   3.6 ms | Memory (P90): 288.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 160      | device: cuda     | Runtime (P90):   7.0 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 160      | device: cuda     | Runtime (P90):   6.4 ms | Memory (P90): 576.0
  [prod] KTRegroupAsDict              | B: 2048     | F: 160      | device: cuda     | Runtime (P90):   4.6 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 320      | device: cuda     | Runtime (P90):  11.2 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 320      | device: cuda     | Runtime (P90):   8.2 ms | Memory (P90): 1152.0
  [prod] KTRegroupAsDict              | B: 2048     | F: 320      | device: cuda     | Runtime (P90):   8.8 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 640      | device: cuda     | Runtime (P90):  23.9 ms | Memory (P90): 2304.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 640      | device: cuda     | Runtime (P90):  20.6 ms | Memory (P90): 2304.0
  [prod] KTRegroupAsDict              | B: 2048     | F: 640      | device: cuda     | Runtime (P90):  14.6 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 2048     | F: 1280     | device: cuda     | Runtime (P90):  54.5 ms | Memory (P90): 4608.0
  [prod] KeyedTensor.regroup          | B: 2048     | F: 1280     | device: cuda     | Runtime (P90):  28.3 ms | Memory (P90): 4608.0
  [prod] KTRegroupAsDict              | B: 2048     | F: 1280     | device: cuda     | Runtime (P90):  25.7 ms | Memory (P90): 4608.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 80       | device: cuda     | Runtime (P90):   3.3 ms | Memory (P90): 576.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 80       | device: cuda     | Runtime (P90):   2.7 ms | Memory (P90): 576.0
  [prod] KTRegroupAsDict              | B: 4096     | F: 80       | device: cuda     | Runtime (P90):   2.3 ms | Memory (P90): 576.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 160      | device: cuda     | Runtime (P90):   5.8 ms | Memory (P90): 1152.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 160      | device: cuda     | Runtime (P90):   4.4 ms | Memory (P90): 1152.0
  [prod] KTRegroupAsDict              | B: 4096     | F: 160      | device: cuda     | Runtime (P90):   3.9 ms | Memory (P90): 1152.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 320      | device: cuda     | Runtime (P90):  11.1 ms | Memory (P90): 2304.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 320      | device: cuda     | Runtime (P90):   7.8 ms | Memory (P90): 2304.0
  [prod] KTRegroupAsDict              | B: 4096     | F: 320      | device: cuda     | Runtime (P90):   7.0 ms | Memory (P90): 2304.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 640      | device: cuda     | Runtime (P90):  23.9 ms | Memory (P90): 4608.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 640      | device: cuda     | Runtime (P90):  14.5 ms | Memory (P90): 4608.0
  [prod] KTRegroupAsDict              | B: 4096     | F: 640      | device: cuda     | Runtime (P90):  13.3 ms | Memory (P90): 4608.0
  [fallback] _regroup_keyed_tenors    | B: 4096     | F: 1280     | device: cuda     | Runtime (P90):  64.0 ms | Memory (P90): 9216.0
  [prod] KeyedTensor.regroup          | B: 4096     | F: 1280     | device: cuda     | Runtime (P90):  26.9 ms | Memory (P90): 9216.0
  [prod] KTRegroupAsDict              | B: 4096     | F: 1280     | device: cuda     | Runtime (P90):  25.1 ms | Memory (P90): 9216.0

Reviewed By: PaulZhang12

Differential Revision: D57312926
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 15, 2024
@facebook-github-bot
Copy link
Contributor

This pull request was exported from Phabricator. Differential Revision: D57312926

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. fb-exported
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants