Simplify access to MCR-DL #17

Merged: 2 commits, Apr 22, 2024
7 changes: 2 additions & 5 deletions benchmarks/all_gather.py
@@ -67,10 +67,7 @@ def timed_all_gather(input, output, start_event, end_event, args):


def run_all_gather(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()

    # Prepare benchmark header
    print_header(args, 'all_gather')
@@ -149,5 +146,5 @@ def run_all_gather(local_rank, args):
if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
    run_all_gather(local_rank=rank, args=args)
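Every benchmark below repeats the same two-step pattern: initialize distributed state once through the package, then fetch the active engine wherever collectives are issued. A minimal sketch of the new calling convention, not part of the diff itself; it assumes the init_processes and get_distributed_engine helpers added to mcr_dl/__init__.py later in this PR and a launcher that sets LOCAL_RANK:

import mcr_dl

# First argument selects the engine ('torch' or 'mcr_dl'); second selects the backend ('nccl' or 'mpi').
mcr_dl.init_processes('mcr_dl', 'mpi')

# Any module can then obtain the same engine without passing it around.
dist = mcr_dl.get_distributed_engine()
print(f"rank {dist.get_rank()} of {dist.get_world_size()}")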
15 changes: 6 additions & 9 deletions benchmarks/all_reduce.py
@@ -28,10 +28,8 @@


def timed_all_reduce(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    import mcr_dl
+    dist = mcr_dl.get_distributed_engine()

    sync_all()
    # Warmups, establish connections, etc.
@@ -62,10 +60,8 @@ def timed_all_reduce(input, start_event, end_event, args):


def run_all_reduce(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    import mcr_dl
+    dist = mcr_dl.get_distributed_engine()

    # Prepare benchmark header
    print_header(args, 'all_reduce')
@@ -125,7 +121,8 @@ def run_all_reduce(local_rank, args):


if __name__ == "__main__":
+    import mcr_dl
    args = benchmark_parser().parse_args()
    rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
    run_all_reduce(local_rank=rank, args=args)
12 changes: 3 additions & 9 deletions benchmarks/all_to_all.py
@@ -28,10 +28,7 @@


def timed_all_to_all(input, output, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()

    sync_all()
    # Warmups, establish connections, etc.
@@ -62,10 +59,7 @@ def timed_all_to_all(input, output, start_event, end_event, args):


def run_all_to_all(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()

    world_size = dist.get_world_size()
    global_rank = dist.get_rank()
@@ -147,5 +141,5 @@ def run_all_to_all(local_rank, args):
if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
    run_all_to_all(local_rank=rank, args=args)
12 changes: 3 additions & 9 deletions benchmarks/broadcast.py
@@ -28,10 +28,7 @@


def timed_broadcast(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()

    sync_all()
    # Warmups, establish connections, etc.
@@ -62,10 +59,7 @@ def timed_broadcast(input, start_event, end_event, args):


def run_broadcast(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()

    # Prepare benchmark header
    print_header(args, 'broadcast')
@@ -125,5 +119,5 @@ def run_broadcast(local_rank, args):
if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
    run_broadcast(local_rank=rank, args=args)
12 changes: 3 additions & 9 deletions benchmarks/pt2pt.py
@@ -28,10 +28,7 @@


def timed_pt2pt(input, start_event, end_event, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()

    sync_all()
    # Warmups, establish connections, etc.
@@ -81,10 +78,7 @@ def timed_pt2pt(input, start_event, end_event, args):


def run_pt2pt(local_rank, args):
-    if args.dist == 'torch':
-        import torch.distributed as dist
-    elif args.dist == 'mcr_dl':
-        import mcr_dl as dist
+    dist = mcr_dl.get_distributed_engine()

    # Prepare benchmark header
    print_header(args, 'pt2pt')
@@ -144,5 +138,5 @@ def run_pt2pt(local_rank, args):
if __name__ == "__main__":
    args = benchmark_parser().parse_args()
    rank = args.local_rank
-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)
    run_pt2pt(local_rank=rank, args=args)
2 changes: 1 addition & 1 deletion benchmarks/run_all.py
@@ -33,7 +33,7 @@
# For importing
def main(args, rank):

-    init_processes(local_rank=rank, args=args)
+    mcr_dl.init_processes(args.dist, args.backend)

    ops_to_run = []
    if args.all_reduce:
38 changes: 6 additions & 32 deletions benchmarks/utils.py
@@ -28,9 +28,7 @@
from mcr_dl.cuda_accelerator import get_accelerator
from mcr_dl.comm import mpi_discovery
from mcr_dl.utils import set_mpi_dist_environemnt

-global dist

+import mcr_dl

def env2int(env_list, default=-1):
    for e in env_list:
@@ -39,41 +37,14 @@ def env2int(env_list, default=-1):
    return default


-def init_torch_distributed(backend):
-    global dist
-    import torch.distributed as dist
-    if backend == 'nccl':
-        mpi_discovery()
-    elif backend == 'mpi':
-        set_mpi_dist_environemnt()
-    dist.init_process_group(backend)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-def init_mcr_dl_comm(backend):
-    global dist
-    import mcr_dl as dist
-    dist.init_distributed(dist_backend=backend, use_mcr_dl=True)
-    local_rank = int(os.environ['LOCAL_RANK'])
-    get_accelerator().set_device(local_rank)
-
-
-def init_processes(local_rank, args):
-    if args.dist == 'mcr_dl':
-        init_mcr_dl_comm(args.backend)
-    elif args.dist == 'torch':
-        init_torch_distributed(args.backend)
-    else:
-        print_rank_0(f"distributed framework {args.dist} not supported")
-        exit(0)


def print_rank_0(message):
+    dist = mcr_dl.get_distributed_engine()
    if dist.get_rank() == 0:
        print(message)


def print_header(args, comm_op):
+    dist = mcr_dl.get_distributed_engine()
    if comm_op == 'pt2pt':
        world_size = 2
    else:
@@ -90,6 +61,7 @@ def print_header(args, comm_op):


def get_bw(comm_op, size, duration, args):
+    dist = mcr_dl.get_distributed_engine()
    n = dist.get_world_size()
    tput = 0
    busbw = 0
@@ -133,11 +105,13 @@ def get_metric_strings(args, tput, busbw, duration):


def sync_all():
+    dist = mcr_dl.get_distributed_engine()
    get_accelerator().synchronize()
    dist.barrier()


def max_numel(comm_op, dtype, mem_factor, local_rank, args):
+    dist = mcr_dl.get_distributed_engine()
    dtype_size = _element_size(dtype)
    max_memory_per_gpu = get_accelerator().total_memory(local_rank) * mem_factor
    if comm_op == 'all_reduce' or comm_op == 'pt2pt' or comm_op == 'broadcast':
52 changes: 51 additions & 1 deletion mcr_dl/__init__.py
@@ -17,4 +17,54 @@
# limitations under the License.

from .utils import *
-from .comm import *
+from .comm import *
+
+global __dist_engine
+global __dist_backend
+
+__dist_engine = None
+__dist_backend = None
+
+def init_torch_distributed(backend):
+    import torch.distributed as dist
+    if backend == 'nccl':
+        mpi_discovery()
+    elif backend == 'mpi':
+        set_mpi_dist_environemnt()
+    dist.init_process_group(backend=backend)
+    local_rank = int(os.environ['LOCAL_RANK'])
+    # get_accelerator().set_device(local_rank)
+    print(f'Rank : {dist.get_rank()} World_Size : {dist.get_world_size()}', flush = True)
+
+def init_mcr_dl_comm(backend):
+    import mcr_dl
+    mcr_dl.init_distributed(dist_backend=backend, use_mcr_dl=True)
+    local_rank = int(os.environ['LOCAL_RANK'])
+    #get_accelerator().set_device(local_rank)
+
+def init_processes(dist_engine, dist_backend, world_size = -1, rank = -1, timeout = None, init_method = None):
+    print(f'Comm : {dist_engine} Backend : {dist_backend}')
+
+    global __dist_engine
+    global __dist_backend
+    __dist_engine = dist_engine
+    __dist_backend = dist_backend
+    if dist_engine == 'mcr_dl':
+        init_mcr_dl_comm(dist_backend)
+    elif dist_engine == 'torch':
+        init_torch_distributed(dist_backend)
+    else:
+        print(f"distributed framework {dist_engine} not supported")
+        exit(0)
+
+def get_distributed_engine():
+    global __dist_engine
+    if __dist_engine is None:
+        return None
+    if __dist_engine == 'torch':
+        return torch.distributed
+    elif __dist_engine == 'mcr_dl':
+        import mcr_dl
+        return mcr_dl
+    print(f"Unsupported values for __dist_engine. Expected values 'torch' or 'mcr_dl'")
+    exit(0)
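For reference, a short sketch of how the new module-level state behaves at runtime. This is illustrative only; it assumes a PyTorch build with the chosen backend and a launch environment (MPI or RANK/WORLD_SIZE variables) that init_torch_distributed can discover:

import torch
import mcr_dl

mcr_dl.init_processes(dist_engine='torch', dist_backend='mpi')

dist = mcr_dl.get_distributed_engine()   # torch.distributed for the 'torch' engine
x = torch.ones(4)
dist.all_reduce(x)                       # default op is SUM, so every element becomes world_size
if dist.get_rank() == 0:
    print(x)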
5 changes: 3 additions & 2 deletions mcr_dl/mpi.py
@@ -72,12 +72,13 @@ def destroy_process_group(self, group=None):
        pass

    def new_group(self, ranks):
-        # TODO: Change this to use comm_op.new_group when the impl. is ready.
+        # TODO: Change this to use self.mpi_comm_op.new_group(ranks) when the impl. is ready.
        if not torch.distributed.is_initialized():
            from mcr_dl.torch import TorchBackend
-            d = TorchBackend(rank=self.rank, size=self.size)
+            d = TorchBackend(rank=self.rank, world_size=self.size)
        logger.info(f"new group called with {ranks}")
        return torch.distributed.new_group(ranks)
+        # return self.mpi_comm_op.new_group(ranks)

    def get_rank(self, group=None):
        return self.mpi_comm_op.get_rank(0)
5 changes: 3 additions & 2 deletions mcr_dl/torch.py
@@ -23,6 +23,7 @@
from .utils import *
from .backend import *
from .comm import *
+from .constants import default_pg_timeout

DS_COMM_ALL_GATHER_OFF = False
DS_COMM_REDUCE_SCATTER_OFF = False
@@ -119,7 +120,7 @@ class TorchBackend(Backend):
    needed.
    """

-    def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='torch'):
+    def __init__(self, backend="mpi", init_method = None, timeout = default_pg_timeout, rank=-1, world_size=-1, name='torch'):
        super(TorchBackend, self).__init__()
        self.has_all_reduce_coalesced = has_all_reduce_coalesced()
        self.has_coalescing_manager = has_coalescing_manager()
@@ -131,7 +132,7 @@ def __init__(self, backend, timeout, init_method, rank=-1, world_size=-1, name='
        # The idea is to fake that dist backend is initialized even when
        # it is not so we can run on a single GPU without doing any init_process_group
        self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method, rank, world_size)
+        self.init_process_group(backend=backend, init_method=init_method, timeout= timeout, rank=rank, world_size= world_size)

    @classmethod
    def get_all_gather_function(self):
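A usage note on the TorchBackend change above: every constructor parameter now has a default, so the class can be instantiated without arguments. A hedged sketch; with the "mpi" default the surrounding environment must still supply rank and world-size information for init_process_group to succeed:

from mcr_dl.torch import TorchBackend

# Defaults from this PR: backend="mpi", init_method=None, timeout=default_pg_timeout,
# rank=-1, world_size=-1. The bare call is therefore equivalent to spelling them out.
backend = TorchBackend()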