Commit 2fe0200
Enable 2D sharding (#17)
Summary:
This pull request adds 2D SPMD sharding, which shards both weights and activations. The sharding strategy is as follows.
Assume a 2D mesh (data, model) where data x model == num_devices:
1. input (data, None, model)
2. embedding (model, data)
3. attn QKV (data, model)
4. attn O (model, data)
5. mlp gate, up (model, data)
6. mlp down (data, model)
7. activation (data, None, model)
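To make the list above concrete, here is a minimal plain-Python sketch of what each partition spec implies for per-device tensor shapes. It is not the code from this commit: the dictionary keys, the `shard_shape` helper, the mesh sizes, and the tensor shapes are all hypothetical and chosen only for illustration. It also assumes every sharded dimension divides evenly by the mesh axis size, which a real SPMD partitioner may not require.

```python
# Hypothetical illustration: the partition specs from the list above, written
# as tuples of mesh-axis names (None means that tensor dimension is replicated).
SHARDING_SPECS = {
    "input": ("data", None, "model"),
    "embedding": ("model", "data"),
    "attn_qkv": ("data", "model"),
    "attn_o": ("model", "data"),
    "mlp_gate_up": ("model", "data"),
    "mlp_down": ("data", "model"),
    "activation": ("data", None, "model"),
}


def shard_shape(global_shape, spec, mesh):
    """Return the per-device shape of a tensor sharded according to spec.

    mesh maps axis names to sizes, e.g. {"data": 2, "model": 4}.
    """
    if len(global_shape) != len(spec):
        raise ValueError("spec must have one entry per tensor dimension")
    local = []
    for dim, axis in zip(global_shape, spec):
        parts = 1 if axis is None else mesh[axis]
        if dim % parts != 0:
            # Simplifying assumption: only evenly divisible shapes are handled.
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        local.append(dim // parts)
    return tuple(local)


mesh = {"data": 2, "model": 4}  # 8 devices, assumed for the example
qkv_local = shard_shape((4096, 4096), SHARDING_SPECS["attn_qkv"], mesh)
act_local = shard_shape((8, 2048, 4096), SHARDING_SPECS["activation"], mesh)
print(qkv_local)  # (2048, 1024)
print(act_local)  # (4, 2048, 1024)
```

Note that weights consumed back to back (for example attn QKV and attn O) use transposed specs, so each weight is split along both mesh axes while the activation keeps the (data, None, model) layout.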
Currently, the model dimension is specified with a new option, --spmd_2d_sharding; the data dimension is then calculated automatically so that data x model == num_devices.
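The auto-calculation can be sketched as follows. This is an assumption about the arithmetic implied by data x model == num_devices, not the commit's actual implementation; the function name `mesh_shape_2d` and the divisibility check are hypothetical.

```python
def mesh_shape_2d(num_devices, model_dim):
    """Return (data, model) such that data * model == num_devices.

    model_dim stands in for the value passed to --spmd_2d_sharding.
    """
    if model_dim < 1:
        raise ValueError("model dimension must be a positive integer")
    if num_devices % model_dim != 0:
        raise ValueError(
            f"num_devices ({num_devices}) is not divisible by "
            f"the model dimension ({model_dim})"
        )
    return (num_devices // model_dim, model_dim)


print(mesh_shape_2d(8, 4))  # (2, 4)
print(mesh_shape_2d(8, 1))  # (8, 1): pure data parallelism
```

Rejecting a model dimension that does not divide the device count keeps the mesh rectangular; how the real option handles that case is not stated in the commit message.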
TODO: maybe add another option to specify whether the activations/inputs should be sharded at all, or sharded differently.
1 parent 1b9c7d8 commit 2fe0200
3 files changed (+69, -3 lines):
- examples/pytorch/language-modeling
- src/transformers
- models/llama