
[Torch] Model Parallel Customization (#2839)
* two mp updates

* black
klshuster authored Jul 23, 2020
1 parent e1c7894 commit 1b0ef77
Showing 2 changed files with 36 additions and 4 deletions.
parlai/utils/torch.py (16 changes: 12 additions, 4 deletions)
@@ -370,6 +370,8 @@ def _place_modulelist(self, submodule: torch.nn.Module) -> None:
         if not isinstance(submodule, torch.nn.ModuleList):
             # not a ModuleList, leave it untouched
             return
+        if getattr(submodule, 'model_parallel_exempt', False):
+            return

         assert isinstance(submodule, torch.nn.ModuleList)  # for typechecker
         layers = submodule
@@ -396,7 +398,7 @@ def _place_modulelist(self, submodule: torch.nn.Module) -> None:
             # mark a layer as going to the given element
             layer_assignments[mostfree] += 1

-        devices = self.devices[:]
+        devices = [d for i, d in enumerate(self.devices[:]) if layer_assignments[d] > 0]
         for layer_no, layer in enumerate(layers):
             layer_gpu = devices[0]
             assert layer_assignments[layer_gpu] > 0
@@ -502,7 +504,9 @@ def join(items: List[Chunk], dim=0) -> Chunk:
             # base case
             return torch.cat(items, dim=dim)  # type: ignore
         elif isinstance(item0, tuple):
-            return tuple(PipelineHelper.join(x, dim=dim) for x in zip(*items))  # type: ignore
+            return tuple(
+                PipelineHelper.join(x, dim=dim) for x in zip(*items)
+            )  # type: ignore
         elif isinstance(item0, dict):
             keys = item0.keys()
             return {  # type: ignore
@@ -522,9 +526,13 @@ def chunk_to(chunk: Chunk, device: str) -> Chunk:
         if isinstance(chunk, torch.Tensor):
             return chunk.to(device)  # type: ignore
         elif isinstance(chunk, tuple):
-            return tuple(PipelineHelper.chunk_to(c, device) for c in chunk)  # type: ignore
+            return tuple(
+                PipelineHelper.chunk_to(c, device) for c in chunk
+            )  # type: ignore
         elif isinstance(chunk, dict):
-            return {k: PipelineHelper.chunk_to(v, device) for k, v in chunk.items()}  # type: ignore
+            return {
+                k: PipelineHelper.chunk_to(v, device) for k, v in chunk.items()
+            }  # type: ignore
         else:
             raise TypeError('chunk_to only compatible with tensors, tuples or dicts.')

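The last two hunks in parlai/utils/torch.py are pure black reformatting of PipelineHelper.join and PipelineHelper.chunk_to; behaviour is unchanged, and both static methods recurse through tuples and dicts of tensors. A small usage sketch, not part of the commit: the sample data is illustrative, and 'cpu' is used as the target device so it runs without GPUs.

import torch
from parlai.utils.torch import PipelineHelper

# Two "chunks" of a batch, each a dict holding a tensor and a tuple of tensors.
chunks = [
    {'scores': torch.ones(2, 4), 'extras': (torch.zeros(2),)},
    {'scores': torch.ones(3, 4), 'extras': (torch.zeros(3),)},
]

# chunk_to moves every nested tensor to the target device ('cuda:0' on a GPU box).
moved = [PipelineHelper.chunk_to(c, 'cpu') for c in chunks]

# join concatenates the chunks back together along dim 0, mirroring the nesting.
joined = PipelineHelper.join(moved, dim=0)
assert joined['scores'].shape == (5, 4)
assert joined['extras'][0].shape == (5,)
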
tests/test_utils_torch.py (24 changes: 24 additions, 0 deletions)
@@ -252,3 +252,27 @@ def test_schedule_work_items(self):
         assert work_items[5].layer_nos == [6, 7] and work_items[5].chunk_idx == 0
         assert work_items[6].layer_nos == [4, 5] and work_items[6].chunk_idx == 1
         assert work_items[7].layer_nos == [6, 7] and work_items[7].chunk_idx == 1
+
+    def test_model_parallel_exempt(self):
+        # Test that we ignore module lists explicitly marked as exempt.
+        def _get_model():
+            model = torch.nn.Module()
+            model.layers = torch.nn.ModuleList([IdentityLayer() for _ in range(8)])
+            return model
+
+        def _exempt_mp(submodule):
+            submodule.model_parallel_exempt = True
+
+        pipeline = PipelineHelper()
+        pipeline.num_devices = 8
+        pipeline.devices = [f'cuda:{i}' for i in range(8)]
+        pipeline._PipelineHelper__device_allocations = {d: 0 for d in pipeline.devices}
+
+        model1 = _get_model()
+        model1 = pipeline.make_parallel(model1)
+        assert getattr(model1.layers, 'is_model_parallel', False)
+
+        model2 = _get_model()
+        model2.apply(_exempt_mp)
+        model2 = pipeline.make_parallel(model2)
+        assert not getattr(model2.layers, 'is_model_parallel', False)
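
Taken together, the two mp updates mean that a ModuleList carrying a truthy model_parallel_exempt attribute is skipped by _place_modulelist, and the placement devices list now omits any GPU that was assigned zero layers. The test above builds a PipelineHelper by hand; in model code the opt-out would look roughly like the following. This is a minimal sketch assuming a multi-GPU host; MyModel and its layer sizes are illustrative, not part of the commit.

import torch
from parlai.utils.torch import PipelineHelper


class MyModel(torch.nn.Module):
    """Illustrative model with two ModuleLists; only one should be sharded."""

    def __init__(self):
        super().__init__()
        self.encoder_layers = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8) for _ in range(4)]
        )
        self.scorer_layers = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8) for _ in range(4)]
        )
        # Opt this ModuleList out of model parallelism: _place_modulelist now
        # returns early when it sees model_parallel_exempt set to True.
        self.scorer_layers.model_parallel_exempt = True


model = PipelineHelper().make_parallel(MyModel())
# encoder_layers is spread across GPUs; scorer_layers stays on its original device.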
