
Commit 798d879

Test Ruisi's commit
1 parent 017b8e2 commit 798d879

1 file changed

examples/example_llama3.py

Lines changed: 63 additions & 5 deletions
@@ -558,10 +558,14 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)

 world_size = 256

-fake_store = FakeStore()
-torch.distributed.init_process_group(
-    "fake", store=fake_store, rank=0, world_size=world_size
-)
+backend = "fake"
+kwargs = {"rank": 0, "world_size": world_size}
+if True:
+    backend = "nccl"
+    fake_store = None
+    kwargs = {}
+    world_size = 8
+torch.distributed.init_process_group(backend, store=fake_store, **kwargs)

 use_1d_mesh = False

@@ -588,7 +592,7 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)

 def model_fn():
     model_args = TransformerModelArgs(
-        n_layers=32,
+        n_layers=2,
         vocab_size=vocab_size,
         max_seq_len=seqlen,
         multiple_of=1024,
@@ -604,6 +608,28 @@ def input_fn():
     return x


+from functools import partial
+
+from autoparallel.auto_bucketing import (
+    simple_fsdp_autobucketing_reordering_pass,
+    simplefsdp_autobucketing_config,
+)
+
+torch._inductor.config.allow_buffer_reuse = False
+torch._inductor.config.reorder_for_peak_memory = False
+torch._inductor.config.reorder_for_compute_comm_overlap = True
+simplefsdp_autobucketing_config.save_estimation_path = (
+    "/storage/home/fmassa/work/projects/autoparallel/estimation_mast.pkl"
+)
+simple_fsdp_autobucketing_reordering_pass = partial(
+    simple_fsdp_autobucketing_reordering_pass,
+    configs=simplefsdp_autobucketing_config,
+)
+torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+    simple_fsdp_autobucketing_reordering_pass
+]
+
+
 # parallelize the model
 with torch.device("meta"):
     model = model_fn()
@@ -659,3 +685,35 @@ def input_fn():
 out = parallel_mod(*x)
 out.backward(torch.randn_like(out))
 print("All good!")
+
+
+def _dump_trace(prof):
+    prof.export_chrome_trace(
+        f"/home/fmassa/work/projects/autoparallel/traces/rank{torch.distributed.get_rank()}.json"
+    )
+
+
+prof = torch.profiler.profile(
+    activities=[
+        torch.profiler.ProfilerActivity.CUDA,
+        torch.profiler.ProfilerActivity.CPU,
+    ],
+    schedule=torch.profiler.schedule(
+        warmup=2,
+        active=2,
+        wait=1,
+        repeat=1,
+        skip_first=0,
+    ),
+    record_shapes=False,
+    profile_memory=False,
+    with_stack=True,
+    with_flops=False,
+    on_trace_ready=_dump_trace,
+)
+
+with prof:
+    for i in range(10):
+        out = parallel_mod(*x)
+        out.backward(torch.randn_like(out))
+        prof.step()
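
A note on the profiler added at the end of the diff: torch.profiler.schedule cycles through wait, warmup and active phases, so with skip_first=0, wait=1, warmup=2, active=2 and repeat=1, step 0 is idle, steps 1-2 warm up, steps 3-4 are recorded, and on_trace_ready (here _dump_trace) fires after step 4; the remaining iterations of the 10-step loop run unprofiled. A minimal standalone sketch (not part of the commit) that prints the action the schedule assigns to each step:

from torch.profiler import schedule

# Same schedule as in the commit: 1 idle step, 2 warmup steps, 2 recorded steps, a single cycle.
sched = schedule(wait=1, warmup=2, active=2, repeat=1, skip_first=0)
for step in range(10):
    # Prints ProfilerAction.NONE / WARMUP / RECORD / RECORD_AND_SAVE;
    # RECORD_AND_SAVE (step 4) is the point at which on_trace_ready is called.
    print(step, sched(step))

Since the commit switches the process group from the fake backend to nccl with world_size = 8, the example presumably needs to be launched across 8 GPUs, e.g. with torchrun --nproc-per-node=8 examples/example_llama3.py (the launch command is an assumption, not part of the commit).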
