Skip to content

Commit a368fbc

Browse files
committed
fix test errors
1 parent 798d879 commit a368fbc

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

examples/example_llama3.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -558,14 +558,18 @@ def forward(self, tokens: torch.Tensor, input_batch: torch.Tensor | None = None)
558558

559559
world_size = 256
560560

561-
backend = "fake"
562-
kwargs = {"rank": 0, "world_size": world_size}
563-
if True:
564-
backend = "nccl"
565-
fake_store = None
566-
kwargs = {}
567-
world_size = 8
568-
torch.distributed.init_process_group(backend, store=fake_store, **kwargs)
561+
fake_store = FakeStore()
562+
torch.distributed.init_process_group(
563+
"fake", store=fake_store, rank=0, world_size=world_size
564+
)
565+
# backend = "fake"
566+
# kwargs = {"rank": 0, "world_size": world_size}
567+
# if True:
568+
# backend = "nccl"
569+
# fake_store = None
570+
# kwargs = {}
571+
# world_size = 8
572+
# torch.distributed.init_process_group(backend, store=fake_store, **kwargs)
569573

570574
use_1d_mesh = False
571575

@@ -618,9 +622,8 @@ def input_fn():
618622
torch._inductor.config.allow_buffer_reuse = False
619623
torch._inductor.config.reorder_for_peak_memory = False
620624
torch._inductor.config.reorder_for_compute_comm_overlap = True
621-
simplefsdp_autobucketing_config.save_estimation_path = (
622-
"/storage/home/fmassa/work/projects/autoparallel/estimation_mast.pkl"
623-
)
625+
simplefsdp_autobucketing_config.calibrate_number = 5
626+
simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl"
624627
simple_fsdp_autobucketing_reordering_pass = partial(
625628
simple_fsdp_autobucketing_reordering_pass,
626629
configs=simplefsdp_autobucketing_config,

0 commit comments

Comments
 (0)