Tensor Parallel #153

Merged 32 commits on Aug 29, 2023
Changes from 1 commit
Commits
32 commits
0d4ea37
fix is_first_layer
Aug 22, 2023
3063afb
tensor parallel
Aug 23, 2023
bdc1ed9
Merge branch 'fix_first_layer' into tensor_parallel
Aug 23, 2023
8648f5b
rm unused code
Aug 23, 2023
763b408
refactor nccl group; remove partition_modules in pipe_layer.py
Aug 24, 2023
4c50567
fix by review comment
Aug 24, 2023
825139c
fix topology
Aug 24, 2023
4ff0f41
fix topology
Aug 24, 2023
a5d7ba6
fix
Aug 24, 2023
2951d70
use ParallelEmbedding
Aug 24, 2023
39319e1
overlap parallel linear backward
Aug 24, 2023
df3fd8f
add tp_comm_stream
Aug 24, 2023
99efba3
fix tp
Achazwl Aug 24, 2023
85dd5ab
Merge branch 'tensor_parallel' into tp
Achazwl Aug 24, 2023
76abcb4
Merge pull request #1 from Achazwl/tp
Aug 24, 2023
f1b4fd7
fix load_state_dict
Aug 25, 2023
677a316
test parallel linear
Aug 25, 2023
743253e
mv zero_level to CheckpointBlock
Aug 25, 2023
4e8c462
merge dev
Aug 25, 2023
604ddfe
fix overlap
Aug 25, 2023
0aee817
gather once in atten
Aug 25, 2023
bd0bad0
fix sub grad_input in parallel linear
Aug 25, 2023
50cdcaf
Merge branch 'dev' into tensor_parallel
zkh2016 Aug 26, 2023
15460b6
fix gather_output
Aug 26, 2023
0e0e05c
Merge branch 'tensor_parallel' of https://github.com/zkh2016/BMTrain …
Aug 26, 2023
b44a62e
fix train.py
Aug 26, 2023
100cd55
fused q,k,v
Aug 26, 2023
fa09468
fix row parallel linear
Aug 26, 2023
37bc403
fix cross entropy
Aug 26, 2023
15c2c48
Update setup.py
zkh2016 Aug 28, 2023
42663c8
overlap send communication in pipeline
Aug 28, 2023
207912b
Merge branch 'tensor_parallel' of https://github.com/zkh2016/BMTrain …
Aug 29, 2023
Viewing changes from commit 100cd55bae8e32d1f46aacc97485c3f00703204f — fused q,k,v (zhangkaihuo committed Aug 26, 2023)
example/layers/attention.py — 16 changes: 11 additions & 5 deletions
@@ -44,12 +44,18 @@ def forward(self,
         batch_size, seq_q, dim_model = hidden_q.size()
         seq_kv = hidden_kv.size(1)

         if config['tp_size'] > 1:
             hidden_q = all_gather(hidden_q, comm=config['tp_comm']).flatten(0,1)
             assert hidden_q.data_ptr() == hidden_kv.data_ptr()

+        hidden_q = bmt.nn.OpParallelLinear.apply(
+            hidden_q,
+            torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
+            torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0),
+            True, False,
+            False, None
+        )
+
+        h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
-        h_q : torch.Tensor = self.project_q(hidden_q)
-        h_k : torch.Tensor = self.project_k(hidden_q)
-        h_v : torch.Tensor = self.project_v(hidden_q)
         if config['tp_size'] > 1:
             #batch_size will changed in TensorParallel
             batch_size = h_v.shape[0]
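
For context, the "fused q,k,v" change computes all three attention projections with a single matrix multiply by concatenating the per-projection weights and biases along the output dimension, then splitting the result back into Q, K, and V. Below is a minimal sketch of that idea in plain PyTorch, under the assumption that the three projections are ordinary nn.Linear layers of equal output size; torch.nn.functional.linear stands in for bmt.nn.OpParallelLinear, and fused_qkv is an illustrative helper name, not part of BMTrain.

    import torch
    import torch.nn.functional as F

    def fused_qkv(hidden_q, project_q, project_k, project_v):
        # Stack the three projection weights (and biases) along the output
        # dimension so Q, K and V come from one matmul instead of three.
        w = torch.cat([project_q.weight, project_k.weight, project_v.weight], dim=0)
        b = torch.cat([project_q.bias, project_k.bias, project_v.bias], dim=0)
        fused = F.linear(hidden_q, w, b)   # shape [..., 3 * out_features]
        return fused.chunk(3, dim=-1)      # split back into h_q, h_k, h_v

    # Usage (hypothetical module with project_q/k/v as nn.Linear layers):
    # h_q, h_k, h_v = fused_qkv(hidden_q, self.project_q, self.project_k, self.project_v)

The single fused matmul is typically cheaper than three separate launches, and it lets the tensor-parallel linear gather the input once for all three projections instead of once per projection.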