Commit d909935
committed
Enable 2D sharding (#17)
Summary:
This pull request adds 2D SPMD sharding to the table. It will shard both weights and activations. Here is the sharding strategy.
Let's say we have a 2D mesh (data, model) and data x model == num_devices:
1. input (data,, None, model)
2. embedding (model, data)
3. attn QKV (data, model)
4. attn O (model, data)
5. mlp gate, up (model, data)
6. mlp down (data, model)
7. activation (data,, None, model)
Currently you can specify the model dimension using a new option --spmd_2d_sharding, then the data dimension will be auto-calculated.
TODO: maybe we should have another option to specify whether or not we should shard the activations/inputs or shard them differently.1 parent 674ab35 commit d909935
File tree
3 files changed
+88
-3
lines changed- examples/pytorch/language-modeling
- src/transformers
- models/llama
3 files changed
+88
-3
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
189 | 189 | | |
190 | 190 | | |
191 | 191 | | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
192 | 200 | | |
193 | 201 | | |
194 | 202 | | |
| |||
297 | 305 | | |
298 | 306 | | |
299 | 307 | | |
| 308 | + | |
300 | 309 | | |
301 | 310 | | |
302 | 311 | | |
| |||
469 | 478 | | |
470 | 479 | | |
471 | 480 | | |
| 481 | + | |
| 482 | + | |
472 | 483 | | |
473 | 484 | | |
474 | 485 | | |
| |||
539 | 550 | | |
540 | 551 | | |
541 | 552 | | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
| 574 | + | |
| 575 | + | |
| 576 | + | |
| 577 | + | |
| 578 | + | |
| 579 | + | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
| 583 | + | |
| 584 | + | |
| 585 | + | |
| 586 | + | |
| 587 | + | |
| 588 | + | |
542 | 589 | | |
543 | 590 | | |
544 | 591 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
392 | 392 | | |
393 | 393 | | |
394 | 394 | | |
| 395 | + | |
| 396 | + | |
395 | 397 | | |
396 | 398 | | |
397 | 399 | | |
| |||
540 | 542 | | |
541 | 543 | | |
542 | 544 | | |
| 545 | + | |
| 546 | + | |
| 547 | + | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
543 | 561 | | |
544 | 562 | | |
545 | 563 | | |
| |||
935 | 953 | | |
936 | 954 | | |
937 | 955 | | |
| 956 | + | |
| 957 | + | |
| 958 | + | |
938 | 959 | | |
939 | 960 | | |
940 | 961 | | |
| |||
1015 | 1036 | | |
1016 | 1037 | | |
1017 | 1038 | | |
| 1039 | + | |
1018 | 1040 | | |
| 1041 | + | |
| 1042 | + | |
| 1043 | + | |
| 1044 | + | |
| 1045 | + | |
| 1046 | + | |
| 1047 | + | |
| 1048 | + | |
| 1049 | + | |
| 1050 | + | |
| 1051 | + | |
| 1052 | + | |
| 1053 | + | |
| 1054 | + | |
| 1055 | + | |
1019 | 1056 | | |
1020 | 1057 | | |
1021 | 1058 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1427 | 1427 | | |
1428 | 1428 | | |
1429 | 1429 | | |
1430 | | - | |
1431 | | - | |
| 1430 | + | |
| 1431 | + | |
| 1432 | + | |
1432 | 1433 | | |
1433 | | - | |
| 1434 | + | |
1434 | 1435 | | |
1435 | 1436 | | |
1436 | 1437 | | |
| |||
0 commit comments