diff --git a/parlai/utils/torch.py b/parlai/utils/torch.py
index 5ca1f51636d..b2b0f06f205 100644
--- a/parlai/utils/torch.py
+++ b/parlai/utils/torch.py
@@ -370,6 +370,8 @@ def _place_modulelist(self, submodule: torch.nn.Module) -> None:
         if not isinstance(submodule, torch.nn.ModuleList):
             # not a ModuleList, leave it untouched
             return
+        if getattr(submodule, 'model_parallel_exempt', False):
+            return
         assert isinstance(submodule, torch.nn.ModuleList)  # for typechecker
         layers = submodule
 
@@ -396,7 +398,7 @@ def _place_modulelist(self, submodule: torch.nn.Module) -> None:
             # mark a layer as going to the given element
             layer_assignments[mostfree] += 1
 
-        devices = self.devices[:]
+        devices = [d for i, d in enumerate(self.devices[:]) if layer_assignments[d] > 0]
         for layer_no, layer in enumerate(layers):
             layer_gpu = devices[0]
             assert layer_assignments[layer_gpu] > 0
@@ -502,7 +504,9 @@ def join(items: List[Chunk], dim=0) -> Chunk:
             # base case
             return torch.cat(items, dim=dim)  # type: ignore
         elif isinstance(item0, tuple):
-            return tuple(PipelineHelper.join(x, dim=dim) for x in zip(*items))  # type: ignore
+            return tuple(
+                PipelineHelper.join(x, dim=dim) for x in zip(*items)
+            )  # type: ignore
         elif isinstance(item0, dict):
             keys = item0.keys()
             return {  # type: ignore
@@ -522,9 +526,13 @@ def chunk_to(chunk: Chunk, device: str) -> Chunk:
         if isinstance(chunk, torch.Tensor):
             return chunk.to(device)  # type: ignore
         elif isinstance(chunk, tuple):
-            return tuple(PipelineHelper.chunk_to(c, device) for c in chunk)  # type: ignore
+            return tuple(
+                PipelineHelper.chunk_to(c, device) for c in chunk
+            )  # type: ignore
         elif isinstance(chunk, dict):
-            return {k: PipelineHelper.chunk_to(v, device) for k, v in chunk.items()}  # type: ignore
+            return {
+                k: PipelineHelper.chunk_to(v, device) for k, v in chunk.items()
+            }  # type: ignore
         else:
             raise TypeError('chunk_to only compatible with tensors, tuples or dicts.')
 
diff --git a/tests/test_utils_torch.py b/tests/test_utils_torch.py
index f0ed8366292..4e114733ced 100644
--- a/tests/test_utils_torch.py
+++ b/tests/test_utils_torch.py
@@ -252,3 +252,27 @@ def test_schedule_work_items(self):
         assert work_items[5].layer_nos == [6, 7] and work_items[5].chunk_idx == 0
         assert work_items[6].layer_nos == [4, 5] and work_items[6].chunk_idx == 1
         assert work_items[7].layer_nos == [6, 7] and work_items[7].chunk_idx == 1
+
+    def test_model_parallel_exempt(self):
+        # Test that we ignore module lists explicitly marked as exempt.
+        def _get_model():
+            model = torch.nn.Module()
+            model.layers = torch.nn.ModuleList([IdentityLayer() for _ in range(8)])
+            return model
+
+        def _exempt_mp(submodule):
+            submodule.model_parallel_exempt = True
+
+        pipeline = PipelineHelper()
+        pipeline.num_devices = 8
+        pipeline.devices = [f'cuda:{i}' for i in range(8)]
+        pipeline._PipelineHelper__device_allocations = {d: 0 for d in pipeline.devices}
+
+        model1 = _get_model()
+        model1 = pipeline.make_parallel(model1)
+        assert getattr(model1.layers, 'is_model_parallel', False)
+
+        model2 = _get_model()
+        model2.apply(_exempt_mp)
+        model2 = pipeline.make_parallel(model2)
+        assert not getattr(model2.layers, 'is_model_parallel', False)
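
Below is a minimal usage sketch (not part of the patch) of the behavior this diff introduces: a ModuleList carrying `model_parallel_exempt = True` is skipped by `_place_modulelist`, while sibling lists are still distributed by `PipelineHelper.make_parallel`. The encoder class, attribute names (`memory_layers`), and layer sizes here are hypothetical and a multi-GPU environment is assumed; only `model_parallel_exempt`, `is_model_parallel`, and `make_parallel` come from the code above.

    import torch

    from parlai.utils.torch import PipelineHelper


    class MyEncoder(torch.nn.Module):
        # hypothetical module, for illustration only
        def __init__(self):
            super().__init__()
            # distributed across GPUs as usual by make_parallel()
            self.layers = torch.nn.ModuleList(
                [torch.nn.Linear(8, 8) for _ in range(8)]
            )
            # opted out: _place_modulelist returns early for this list
            self.memory_layers = torch.nn.ModuleList(
                [torch.nn.Linear(8, 8) for _ in range(8)]
            )
            self.memory_layers.model_parallel_exempt = True


    # assumes multiple CUDA devices are visible
    model = PipelineHelper().make_parallel(MyEncoder())
    assert getattr(model.layers, 'is_model_parallel', False)
    assert not getattr(model.memory_layers, 'is_model_parallel', False)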