
Commit 89626c4

[asynctp] Optimize agmm lastdim via addmm_
1 parent 526849d commit 89626c4
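In outline, the commit replaces the scratch-buffer-and-add accumulation in the all-gather matmul (agmm) last-gather-dim consumer with a single in-place addmm_ per non-first shard, dropping one torch.empty_like allocation and one elementwise add per output. A minimal sketch of the before/after pattern (shapes and names here are illustrative, not the module's actual variables):

import torch

# Stand-ins for the all-gathered A shards and one B matrix.
shards = [torch.randn(4, 6) for _ in range(3)]
B = torch.randn(6, 8)

# Before: each non-first partial goes through a scratch buffer, then a reduce.
out_before = torch.empty(4, 8)
partial = torch.empty_like(out_before)          # extra allocation
for i, shard in enumerate(shards):
    if i == 0:
        torch.mm(shard, B, out=out_before)
    else:
        torch.mm(shard, B, out=partial)
        out_before += partial                   # extra elementwise add

# After: accumulate directly into the output; out.addmm_(a, b) is out += a @ b.
out_after = torch.empty(4, 8)
for i, shard in enumerate(shards):
    if i == 0:
        torch.mm(shard, B, out=out_after)
    else:
        out_after.addmm_(shard, B)

torch.testing.assert_close(out_before, out_after)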


autoparallel/asynctp_ops.py

Lines changed: 7 additions & 14 deletions
@@ -625,11 +625,6 @@ def _fused_all_gather_matmul_last_gather_dim_impl(
     def unflatten(t: torch.Tensor) -> torch.Tensor:
         return t.view(*leading_dims, -1)
 
-    A_out_leading_dims = list(A_shard.shape[:-1])
-
-    def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
-        return t.view(*A_out_leading_dims, -1)
-
     A_flat_out = A_shard_flat.new_empty(
         A_shard_flat.shape[0] * group.size(),
         A_shard_flat.shape[1],
@@ -645,19 +640,17 @@ def unflatten_A_out(t: torch.Tensor) -> torch.Tensor:
         for B, out_dtype in zip(Bs, out_dtypes)
     ]
 
-    # Additional allocation for partials output,
-    # That will be reduced into output.
-    output_partials = [torch.empty_like(out) for out in outputs]
-
     first = True
 
     def default_consumer(shard: torch.Tensor, rank: int) -> None:
         nonlocal first
         for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)):
-            out = outputs[idx] if first else output_partials[idx]
-            mm_out_op(shard, B_shards[idx][rank], **kwargs, out=out)
-            if not first:
-                outputs[idx] += output_partials[idx]
+            out = outputs[idx]
+            if first:
+                torch.ops.aten.mm.out(shard, B_shards[idx][rank], **kwargs, out=out)
+            else:
+                out.addmm_(shard, B_shards[idx][rank])
+
         first = False
 
     _pipelined_all_gather_and_consume_last_dim(
@@ -672,7 +665,7 @@ def default_consumer(shard: torch.Tensor, rank: int) -> None:
     # This path is inefficient and will be filtered out at passes stage
     # Added only for completness.
     A_split_cat_out_flat = torch.cat(A_flat_out.chunk(group_size), dim=-1)
-    ret_A = unflatten_A_out(A_split_cat_out_flat)
+    ret_A = unflatten(A_split_cat_out_flat)
 
     return ret_A, [unflatten(output) for output in outputs]

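For reference, Tensor.addmm_ is the in-place variant of torch.addmm: out.addmm_(mat1, mat2) computes out = beta * out + alpha * (mat1 @ mat2) with beta = alpha = 1 by default. In the consumer above it is therefore equivalent to outputs[idx] += shard @ B_shards[idx][rank], but it typically lowers to a single GEMM with accumulation (beta = 1) rather than a matmul followed by a separate add. A quick check of the semantics:

import torch

out = torch.randn(4, 8)
a, b = torch.randn(4, 6), torch.randn(6, 8)

expected = out + a @ b
out.addmm_(a, b)  # in place: out = 1 * out + 1 * (a @ b)
torch.testing.assert_close(out, expected)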