refactor: support launch with python cmd
Hard-codes the launch address tcp://localhost:12355 when one is not provided.

Ref #15
flymin committed Dec 20, 2024
1 parent d537ecf commit 8279ff6
Showing 2 changed files with 32 additions and 51 deletions.
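
The commit title refers to starting these scripts as an ordinary Python process instead of through a distributed launcher. The diff does not show the is_distributed() helper itself, so the sketch below is a hypothetical illustration of the usual torchrun-environment check and of the two launch styles; the command lines are assumptions, not taken from this repository.

import os

# Distributed launch (assumed example command):
#   torchrun --nproc_per_node=8 scripts/inference_magicdrive.py <config> ...
# torchrun exports RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT, so
# init_process_group() can read everything from the environment.
#
# Plain launch, newly supported by this commit (assumed example command):
#   python scripts/inference_magicdrive.py <config> ...
# none of those variables are set, so the scripts fall back to a hard-coded
# single-rank group at tcp://localhost:12355.

def is_distributed() -> bool:
    # Hypothetical stand-in for the helper used in both scripts: treat the
    # presence of torchrun's WORLD_SIZE variable as a distributed launch.
    return os.environ.get("WORLD_SIZE", None) is not None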
44 changes: 17 additions & 27 deletions scripts/inference_magicdrive.py
@@ -80,17 +80,6 @@
 }
 
 
-class FakeCoordinator:
-    def block_all(self):
-        pass
-
-    def is_master(self):
-        return True
-
-    def destroy(self):
-        pass
-
-
 def set_omegaconf_key_value(cfg, key, value):
     p, m = key.rsplit(".", 1)
     node = cfg
@@ -178,25 +167,26 @@ def main():
         # colossalai.launch_from_torch({})
         dist.init_process_group(backend="nccl", timeout=timedelta(hours=1))
         torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
-        coordinator = DistCoordinator()
         cfg.sp_size = dist.get_world_size()
-        if cfg.sp_size > 1:
-            DP_AXIS, SP_AXIS = 0, 1
-            dp_size = dist.get_world_size() // cfg.sp_size
-            pg_mesh = ProcessGroupMesh(dp_size, cfg.sp_size)
-            dp_group = pg_mesh.get_group_along_axis(DP_AXIS)
-            sp_group = pg_mesh.get_group_along_axis(SP_AXIS)
-            set_sequence_parallel_group(sp_group)
-            print(f"Using sp_size={cfg.sp_size}")
-        else:
-            # TODO: sequence_parallel_group unset!
-            dp_group = dist.group.WORLD
-        set_data_parallel_group(dp_group)
-        enable_sequence_parallelism = cfg.sp_size > 1
     else:
+        dist.init_process_group(
+            backend="nccl", world_size=1, rank=0,
+            init_method="tcp://localhost:12355")
         cfg.sp_size = 1
-        coordinator = FakeCoordinator()
-        enable_sequence_parallelism = False
+    coordinator = DistCoordinator()
+    if cfg.sp_size > 1:
+        DP_AXIS, SP_AXIS = 0, 1
+        dp_size = dist.get_world_size() // cfg.sp_size
+        pg_mesh = ProcessGroupMesh(dp_size, cfg.sp_size)
+        dp_group = pg_mesh.get_group_along_axis(DP_AXIS)
+        sp_group = pg_mesh.get_group_along_axis(SP_AXIS)
+        set_sequence_parallel_group(sp_group)
+        print(f"Using sp_size={cfg.sp_size}")
+    else:
+        # TODO: sequence_parallel_group unset!
+        dp_group = dist.group.WORLD
+    set_data_parallel_group(dp_group)
+    enable_sequence_parallelism = cfg.sp_size > 1
     set_random_seed(seed=cfg.get("seed", 1024))
 
     # == init exp_dir ==
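
Both scripts now follow the same shape, shown above for scripts/inference_magicdrive.py: always create a real process group, then build the coordinator and the parallel groups outside the launch branch. A self-contained sketch of that pattern, assuming DistCoordinator is imported from colossalai.cluster (the imports are not part of this hunk):

from datetime import timedelta

import torch
import torch.distributed as dist
from colossalai.cluster import DistCoordinator  # assumed import path


def init_process_group_and_coordinator(distributed: bool) -> DistCoordinator:
    """Sketch of the launch logic after this commit."""
    if distributed:
        # torchrun-style launch: rank and world size come from the environment.
        dist.init_process_group(backend="nccl", timeout=timedelta(hours=1))
        torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    else:
        # Plain `python` launch: one-rank group on the hard-coded local address.
        dist.init_process_group(
            backend="nccl", world_size=1, rank=0,
            init_method="tcp://localhost:12355")
    # A real coordinator is valid in both cases, replacing FakeCoordinator.
    return DistCoordinator()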
39 changes: 15 additions & 24 deletions scripts/test_magicdrive.py
@@ -75,17 +75,6 @@ def make_file_dirs(path):
     os.makedirs(os.path.dirname(path), exist_ok=True)
 
 
-class FakeCoordinator:
-    def block_all(self):
-        pass
-
-    def is_master(self):
-        return True
-
-    def destroy(self):
-        pass
-
-
 def set_omegaconf_key_value(cfg, key, value):
     p, m = key.rsplit(".", 1)
     node = cfg
@@ -166,20 +155,22 @@ def main():
     cfg.sp_size = cfg.get("sp_size", 1)
     if is_distributed():
         colossalai.launch_from_torch({})
-        coordinator = DistCoordinator()
-        if cfg.sp_size > 1:
-            DP_AXIS, SP_AXIS = 0, 1
-            dp_size = dist.get_world_size() // cfg.sp_size
-            pg_mesh = ProcessGroupMesh(dp_size, cfg.sp_size)
-            dp_group = pg_mesh.get_group_along_axis(DP_AXIS)
-            sp_group = pg_mesh.get_group_along_axis(SP_AXIS)
-            set_sequence_parallel_group(sp_group)
-        else:
-            # TODO: sequence_parallel_group unset!
-            dp_group = dist.group.WORLD
-        set_data_parallel_group(dp_group)
     else:
-        coordinator = FakeCoordinator()
+        dist.init_process_group(
+            backend="nccl", world_size=1, rank=0,
+            init_method="tcp://localhost:12355")
+    coordinator = DistCoordinator()
+    if cfg.sp_size > 1:
+        DP_AXIS, SP_AXIS = 0, 1
+        dp_size = dist.get_world_size() // cfg.sp_size
+        pg_mesh = ProcessGroupMesh(dp_size, cfg.sp_size)
+        dp_group = pg_mesh.get_group_along_axis(DP_AXIS)
+        sp_group = pg_mesh.get_group_along_axis(SP_AXIS)
+        set_sequence_parallel_group(sp_group)
+    else:
+        # TODO: sequence_parallel_group unset!
+        dp_group = dist.group.WORLD
+    set_data_parallel_group(dp_group)
     set_random_seed(seed=cfg.get("seed", 1024))
 
     # == init exp_dir ==
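
Design note: the removed FakeCoordinator only stubbed out the coordinator calls these scripts rely on (block_all, is_master, destroy) for the case where no process group existed. With a real single-rank group always initialized, DistCoordinator covers both launch modes, and teardown can be uniform as well. A small hedged sketch; whether the scripts call this at exit is not shown in these hunks:

import torch.distributed as dist

# Uniform cleanup once a real process group exists in every launch mode.
if dist.is_initialized():
    dist.destroy_process_group()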
