Skip to content

Commit

Permalink
more review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom Birch committed Nov 6, 2020
1 parent cb44a25 commit 4618a37
Show file tree
Hide file tree
Showing 13 changed files with 312 additions and 464 deletions.
12 changes: 4 additions & 8 deletions benchmarks/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,13 +530,10 @@ def bench_single_process(args):
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]

if args.deepspeed:
p = PipelineModule(layers=model, num_stages=min(num_devices, 4))
else:
balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
del model
del blob["model"]

Expand Down Expand Up @@ -665,7 +662,6 @@ def bench_mpi(args):
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp")
parser.add_argument("--deepspeed", action="store_true", default=False, help="use deepspeed instead of fairscale pipe")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
Expand Down
2 changes: 0 additions & 2 deletions fairscale/nn/model_parallel/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ def initialize_model_parallel(
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.
"""
# if torch.distributed.get_rank() == 0:
# print("> initializing model parallel with size {}".format(model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
Expand Down
10 changes: 0 additions & 10 deletions fairscale/nn/model_parallel/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor:
return input_

# All-reduce.
print(f">> doing all_reduce on {torch.distributed.get_rank()}")
torch.distributed.all_reduce(input_, group=group)
print(f"<< doing all_reduce on {torch.distributed.get_rank()}")

return input_

Expand Down Expand Up @@ -94,12 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function):

@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward")
return input_

@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward")
return _reduce(None, grad_output)


Expand All @@ -108,12 +104,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):

@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward")
return _reduce(ctx, input_)

@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward")
return grad_output


Expand All @@ -122,12 +116,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):

@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward")
return _split(input_)

@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward")
return _gather(grad_output)


Expand All @@ -136,12 +128,10 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):

@staticmethod
def forward(ctx, input_): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward")
return _gather(input_)

@staticmethod
def backward(ctx, grad_output): # type: ignore
print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward")
return _split(grad_output)


Expand Down
Loading

0 comments on commit 4618a37

Please sign in to comment.