@@ -558,14 +558,18 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
 
 world_size = 256
 
-backend = "fake"
-kwargs = {"rank": 0, "world_size": world_size}
-if True:
-    backend = "nccl"
-    fake_store = None
-    kwargs = {}
-    world_size = 8
-torch.distributed.init_process_group(backend, store=fake_store, **kwargs)
+fake_store = FakeStore()
+torch.distributed.init_process_group(
+    "fake", store=fake_store, rank=0, world_size=world_size
+)
+# backend = "fake"
+# kwargs = {"rank": 0, "world_size": world_size}
+# if True:
+#     backend = "nccl"
+#     fake_store = None
+#     kwargs = {}
+#     world_size = 8
+# torch.distributed.init_process_group(backend, store=fake_store, **kwargs)
 
 use_1d_mesh = False
 
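For anyone reproducing this hunk outside the full script, here is a minimal, self-contained sketch of the single-process "fake" process-group setup it switches to. The `FakeStore` import path from `torch.testing._internal` is an assumption (the import is not shown in this hunk); the init call itself mirrors the added lines above.

```python
# Minimal sketch (assumed FakeStore import path): initialize a 256-rank "fake"
# process group on a single process so collectives can be traced/compiled
# without launching real workers or using NCCL.
import torch
import torch.distributed as dist
from torch.testing._internal.distributed.fake_pg import FakeStore  # assumption

world_size = 256
fake_store = FakeStore()
dist.init_process_group(
    "fake", store=fake_store, rank=0, world_size=world_size
)

# Collectives now complete without real communication, which is enough for
# compile-time/memory analysis on one host.
dist.all_reduce(torch.zeros(4))
dist.destroy_process_group()
```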
@@ -618,9 +622,8 @@ def input_fn():
 torch._inductor.config.allow_buffer_reuse = False
 torch._inductor.config.reorder_for_peak_memory = False
 torch._inductor.config.reorder_for_compute_comm_overlap = True
-simplefsdp_autobucketing_config.save_estimation_path = (
-    "/storage/home/fmassa/work/projects/autoparallel/estimation_mast.pkl"
-)
+simplefsdp_autobucketing_config.calibrate_number = 5
+simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
 simple_fsdp_autobucketing_reordering_pass = partial(
     simple_fsdp_autobucketing_reordering_pass,
     configs=simplefsdp_autobucketing_config,
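As context for the second hunk, a hedged sketch of the wiring it feeds into. The `autoparallel.auto_bucketing` import path and the final pass registration are assumptions (they are not part of this hunk); the `torch._inductor.config` flags and the config attributes are the ones edited above.

```python
# Sketch only: how the autobucketing reordering pass configured above is
# typically handed to Inductor's compute/communication overlap scheduler.
from functools import partial

import torch
from autoparallel.auto_bucketing import (  # assumed module path
    simple_fsdp_autobucketing_reordering_pass,
    simplefsdp_autobucketing_config,
)

torch._inductor.config.allow_buffer_reuse = False
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = True

# Calibrate runtime estimates over 5 benchmark runs and cache them in a
# relative path instead of a hard-coded user-specific one.
simplefsdp_autobucketing_config.calibrate_number = 5
simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"

# Bind the config into the pass and register it with Inductor (registration
# line is an assumption; it is not shown in this hunk).
bucketing_pass = partial(
    simple_fsdp_autobucketing_reordering_pass,
    configs=simplefsdp_autobucketing_config,
)
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [bucketing_pass]
```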