Commit eb6f4ab
enable customized emb lookup kernel for TorchRec
Summary:
# context
* The NVIDIA dynamicemb package depends on an old TorchRec release (r0.7) plus a PR ([meta-pytorch#2533](meta-pytorch#2533)).
* The goal is to refactor that PR ([meta-pytorch#2533](meta-pytorch#2533)) on trunk so that TorchRec can accept customized kernels.
# design rationales
* Because [`EmbeddingComputeKernel`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_types.py#L64-L72) is an Enum class that can't be dynamically extended outside the TorchRec codebase, we add a placeholder member named `customized_kernel` that covers all customized compute kernels.
* `compute_kernel` is set in [ParameterSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L694), along with `sharding_type`, `sharding_spec`, etc. Users can subclass the `ParameterSharding` dataclass to add the extra configs and parameters needed by the customized compute kernel, including a field like `customized_compute_kernel` to identify the exact kernel when there are several.
* To propagate some [extra config](https://fburl.com/code/bnwp44sz) to the customized kernel, we add a `get_additional_fused_params` hook that forwards those params into `fused_params` (see the sketch after the NOTE below). We might later move the [`add_params_from_parameter_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359) function into a class method of `ParameterSharding`, so that users can override it when necessary.
NOTE: `fused_params` was originally used to pass necessary parameters to the fbgemm lookup kernels (e.g., TBE, see below). It now seems to be mostly a convenient way of [propagating configs from `ParameterSharding` to the kernel](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/utils.py#L359).
```
(Pdb) group_fused_params
{'optimizer': <EmbOptimType.EXACT_ADAGRAD: 'exact_adagrad'>, 'learning_rate': 0.1}
```
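A minimal sketch of the intended usage, assuming `get_additional_fused_params` is overridable on a `ParameterSharding` subclass as this commit describes; the class name `HKVParameterSharding`, its extra fields, and the concrete values below are illustrative assumptions, not the shipped dynamicemb code.
```python
# Sketch only: HKVParameterSharding, hkv_cache_ratio, and the field defaults are
# illustrative assumptions, not part of this commit or of dynamicemb.
from dataclasses import dataclass
from typing import Any, Dict

from torchrec.distributed.types import ParameterSharding


@dataclass
class HKVParameterSharding(ParameterSharding):
    # extra config needed by the customized compute kernel
    hkv_cache_ratio: float = 0.2
    # identifies the exact kernel, since `customized_kernel` is only a placeholder
    customized_compute_kernel: str = "hkv"

    def get_additional_fused_params(self) -> Dict[str, Any]:
        # picked up by add_params_from_parameter_sharding and merged into
        # `fused_params`, so the customized lookup kernel sees them at build time
        return {
            "customized_compute_kernel": self.customized_compute_kernel,
            "hkv_cache_ratio": self.hkv_cache_ratio,
        }


# usage: point the table at the placeholder compute kernel and carry the extra config
per_table_sharding = HKVParameterSharding(
    sharding_type="row_wise",
    compute_kernel="customized_kernel",  # the new placeholder value
    ranks=[0, 1],
)
```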
* Besides the lookup module, the customized kernel very often also needs a customized input_dist and/or output_dist. These all come from [EmbeddingSharding](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding_sharding.py#L964) and its [child classes](https://github.com/pytorch/torchrec/tree/main/torchrec/distributed/sharding) such as cw_sharding, tw_sharding, etc.
* We make the main API, the [`create_embedding_sharding`](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/embedding.py#L150) function, public; it returns a subclass of EmbeddingSharding, which in turn creates the user-defined input_dist, output_dist, lookup modules, and so on (see the dispatch sketch right below).
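One way such a user-defined EmbeddingSharding subclass could be wired in, as a schematic dispatcher around the now-public `create_embedding_sharding`; the helper `pick_embedding_sharding` and the injected `custom_sharding_cls` are hypothetical and not TorchRec API, and the fall-through call assumes the argument order shown in `torchrec/distributed/embedding.py`.
```python
# Schematic sketch: pick_embedding_sharding and custom_sharding_cls are hypothetical.
from typing import Any, Optional, Sequence, Type

import torch

from torchrec.distributed.embedding import create_embedding_sharding
from torchrec.distributed.embedding_sharding import EmbeddingSharding


def pick_embedding_sharding(
    sharding_type: str,
    sharding_infos: Sequence[Any],  # EmbeddingShardingInfo objects
    env: Any,                       # ShardingEnv
    custom_sharding_cls: Optional[Type[EmbeddingSharding]] = None,
    device: Optional[torch.device] = None,
) -> EmbeddingSharding:
    # Route tables whose compute_kernel is the `customized_kernel` placeholder to the
    # user-defined EmbeddingSharding subclass (which builds its own input_dist,
    # lookup, and output_dist); otherwise fall back to TorchRec's own factory.
    uses_custom_kernel = all(
        info.param_sharding.compute_kernel == "customized_kernel"
        for info in sharding_infos
    )
    if uses_custom_kernel and custom_sharding_cls is not None:
        return custom_sharding_cls(sharding_infos, env, device)
    return create_embedding_sharding(sharding_type, sharding_infos, env, device)
```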
WARNING: the HKV-based customized compute kernel can't go through the default `_initialize_torch_state`, most likely because the table.weight tensor is no longer on the GPU and therefore can't be represented as a ShardedTensor or DTensor. It is the user's responsibility to handle the state_dict correctly by overriding the `_initialize_torch_state` function (see the `_initialize_torch_state` sketch below).
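For that state_dict issue, a minimal sketch of what the override could look like; the subclass name `HKVShardedEmbeddingCollection` and its body are assumptions rather than the dynamicemb implementation, and the exact signature of `_initialize_torch_state` may differ across TorchRec versions.
```python
# Sketch only: HKVShardedEmbeddingCollection is hypothetical; the real override must
# expose whatever state the customized (host-memory, HKV-style) kernel actually keeps.
from typing import Any

from torchrec.distributed.embedding import ShardedEmbeddingCollection


class HKVShardedEmbeddingCollection(ShardedEmbeddingCollection):
    def _initialize_torch_state(self, *args: Any, **kwargs: Any) -> None:
        # Skip the default path: it assumes each table.weight is a GPU tensor that can
        # be wrapped into a ShardedTensor / DTensor, which no longer holds once the
        # weights live inside the customized kernel.
        #
        # Register custom state_dict / load_state_dict hooks here that export and
        # import the kernel's own storage instead.
        pass
```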
Differential Revision: D707235831
File tree: 7 files changed (+259, −194) under torchrec/distributed and torchrec/distributed/planner/tests.