@@ -558,10 +558,18 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
 
 world_size = 256
 
-fake_store = FakeStore()
-torch.distributed.init_process_group(
-    "fake", store=fake_store, rank=0, world_size=world_size
-)
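+# The "fake" backend simulates a 256-rank job in a single process without real
+# communication; the nccl branch below targets an actual 8-GPU run instead.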
+backend = "fake"
+fake_store = FakeStore()  # keep the store so the fake path stays runnable
+kwargs = {"rank": 0, "world_size": world_size}
+use_nccl = True  # named flag instead of a bare `if True:`
+if use_nccl:
+    backend = "nccl"
+    fake_store = None
+    kwargs = {}
+    world_size = 8
+torch.distributed.init_process_group(backend, store=fake_store, **kwargs)
 
 use_1d_mesh = False
 
@@ -588,7 +596,8 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
 
 def model_fn():
     model_args = TransformerModelArgs(
-        n_layers=32,
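+        # presumably reduced from 32 to keep this debug run small and fast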
+        n_layers=2,
         vocab_size=vocab_size,
         max_seq_len=seqlen,
         multiple_of=1024,
@@ -604,6 +613,32 @@ def input_fn():
     return x
 
 
+from functools import partial
+
+from autoparallel.auto_bucketing import (
+    simple_fsdp_autobucketing_reordering_pass,
+    simplefsdp_autobucketing_config,
+)
+
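+# Inductor tuning for SimpleFSDP autobucketing: buffer reuse and peak-memory
+# reordering are disabled, likely so the compute/comm-overlap pass registered
+# below can bucket and reorder collectives on its own.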
+torch._inductor.config.allow_buffer_reuse = False
+torch._inductor.config.reorder_for_peak_memory = False
+torch._inductor.config.reorder_for_compute_comm_overlap = True
+simplefsdp_autobucketing_config.save_estimation_path = (
+    "/storage/home/fmassa/work/projects/autoparallel/estimation_mast.pkl"
+)
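+# Bind the config to the reordering pass and register it as inductor's
+# compute/comm-overlap pass.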
+simple_fsdp_autobucketing_reordering_pass = partial(
+    simple_fsdp_autobucketing_reordering_pass,
+    configs=simplefsdp_autobucketing_config,
+)
+torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+    simple_fsdp_autobucketing_reordering_pass
+]
+
+
 # parallelize the model
 with torch.device("meta"):
     model = model_fn()
@@ -659,3 +694,38 @@ def input_fn():
 out = parallel_mod(*x)
 out.backward(torch.randn_like(out))
 print("All good!")
+
+
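+# Dump one chrome trace per rank once the profiler's active window completes.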
+def _dump_trace(prof):
+    prof.export_chrome_trace(
+        f"/home/fmassa/work/projects/autoparallel/traces/rank{torch.distributed.get_rank()}.json"
+    )
+
+
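+# Schedule: step 0 waits, steps 1-2 warm up, steps 3-4 are recorded; with
+# repeat=1 the trace is written once even though the loop runs 10 steps.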
+prof = torch.profiler.profile(
+    activities=[
+        torch.profiler.ProfilerActivity.CUDA,
+        torch.profiler.ProfilerActivity.CPU,
+    ],
+    schedule=torch.profiler.schedule(
+        warmup=2,
+        active=2,
+        wait=1,
+        repeat=1,
+        skip_first=0,
+    ),
+    record_shapes=False,
+    profile_memory=False,
+    with_stack=True,
+    with_flops=False,
+    on_trace_ready=_dump_trace,
+)
+
+with prof:
+    for i in range(10):
+        out = parallel_mod(*x)
+        out.backward(torch.randn_like(out))
+        prof.step()