
Fix ModelCheckpoint race condition in file existence check #5155

Merged

merged 30 commits on Jan 27, 2021

Changes from 20 commits

Commits (30)
1af3218
fix
awaelchli Dec 16, 2020
0f2a986
remove repro script
awaelchli Dec 16, 2020
6ffb64b
type annotation
awaelchli Dec 16, 2020
9b7512a
added changelog
awaelchli Dec 16, 2020
b545a0f
fix import problem
awaelchli Dec 17, 2020
622f929
Merge branch 'master' into bugfix/ddp-ckpt
awaelchli Dec 17, 2020
5d1b297
fix NoneType error
awaelchli Dec 18, 2020
2c1a07d
Merge branch 'master' into bugfix/ddp-ckpt
awaelchli Dec 18, 2020
e5d12cd
fix test
Dec 20, 2020
c13345d
Merge branch 'master' into bugfix/ddp-ckpt
Dec 20, 2020
f8529a4
debugging
Dec 25, 2020
f7b528a
debug
Dec 25, 2020
324c7c7
debug
Dec 25, 2020
cbb8b81
Merge branch 'master' into bugfix/ddp-ckpt
Dec 25, 2020
b717dcc
debug
awaelchli Dec 29, 2020
988167b
skip horovod apex test to see if all others pass
awaelchli Dec 29, 2020
87357eb
remove debug message
awaelchli Dec 29, 2020
16fbebf
Merge branch 'master' into bugfix/ddp-ckpt
awaelchli Jan 2, 2021
f2cad67
Merge branch 'master' into bugfix/ddp-ckpt
awaelchli Jan 8, 2021
fb7533c
no-op broadcast for single core tpu
awaelchli Jan 8, 2021
34e76b9
spelling
awaelchli Jan 11, 2021
7ced05b
OMG add a barrier
awaelchli Jan 24, 2021
437b5a2
skip
awaelchli Jan 24, 2021
920df05
Merge branch 'master' into bugfix/ddp-ckpt
awaelchli Jan 24, 2021
68b0d34
Update CHANGELOG.md
Borda Jan 26, 2021
3393868
add back skip in test
awaelchli Jan 26, 2021
3398164
Merge branch 'master' into bugfix/ddp-ckpt
awaelchli Jan 26, 2021
7f3dd80
add changelog
awaelchli Jan 26, 2021
81f7c14
Merge branch 'master' into bugfix/ddp-ckpt
Borda Jan 27, 2021
3fccc5b
Merge branch 'master' into bugfix/ddp-ckpt
SkafteNicki Jan 27, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -106,6 +106,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `DDPHPCAccelerator` hangs in DDP construction by calling `init_device` ([#5157](https://github.com/PyTorchLightning/pytorch-lightning/pull/5157))

- Fixed a race condition in `ModelCheckpoint` when checking if a checkpoint file exists ([#5144](https://github.com/PyTorchLightning/pytorch-lightning/pull/5144))


## [1.1.0] - 2020-12-09

3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -331,6 +331,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
            mp_queue.put(last_path)

    def broadcast(self, obj, src=0):
        if self.trainer.tpu_id is not None:
            # running on a single core
            return obj
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
19 changes: 15 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -520,12 +520,13 @@ def _get_metric_interpolated_filepath_name(
        ckpt_name_metrics: Dict[str, Any],
        epoch: int,
        step: int,
        del_filepath: Optional[str] = None
        trainer,
        del_filepath: Optional[str] = None,
    ) -> str:
        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)

        version_cnt = 0
        while self._fs.exists(filepath) and filepath != del_filepath:
        while self.file_exists(filepath, trainer) and filepath != del_filepath:
            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
            version_cnt += 1

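To illustrate the race this hunk addresses: when every rank queries the filesystem on its own, a rank that checks after rank 0 has already written the checkpoint resolves a different, versioned filename, and the ranks' internal state diverges. A minimal, self-contained sketch of that failure mode (format_name and pick_filepath are hypothetical stand-ins for the Lightning methods above, not real APIs):

# Minimal sketch of the race the PR fixes; helper names are illustrative.
import os
import tempfile
from typing import Optional


def format_name(base: str, ver: Optional[int] = None) -> str:
    return f"{base}.ckpt" if ver is None else f"{base}-v{ver}.ckpt"


def pick_filepath(dirpath: str, base: str) -> str:
    """Mimic the version-counter loop above with a purely local filesystem check."""
    filepath = os.path.join(dirpath, format_name(base))
    version_cnt = 0
    while os.path.exists(filepath):
        filepath = os.path.join(dirpath, format_name(base, ver=version_cnt))
        version_cnt += 1
    return filepath


dirpath = tempfile.mkdtemp()

# Rank 0 resolves the path first and writes the checkpoint.
rank0_path = pick_filepath(dirpath, "epoch=0-step=1")
open(rank0_path, "w").close()

# A slower rank runs the same check afterwards: the file already exists, so it
# resolves a *different* path and the ranks' internal state diverges.
rank1_path = pick_filepath(dirpath, "epoch=0-step=1")
assert rank0_path != rank1_path  # "....ckpt" vs "...-v0.ckpt"

Broadcasting rank 0's existence check, as file_exists() below does, keeps every rank on the same path.
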
@@ -555,7 +556,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
            last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
        else:
            last_filepath = self._get_metric_interpolated_filepath_name(
                ckpt_name_metrics, trainer.current_epoch, trainer.global_step
                ckpt_name_metrics, trainer.current_epoch, trainer.global_step, trainer,
            )

        accelerator_backend = trainer.accelerator_backend
@@ -618,7 +619,7 @@ def _update_best_and_save(
        if torch.isnan(current):
            current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))

        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
        filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, trainer, del_filepath)

        # save the current score
        self.current_score = current
@@ -656,3 +657,13 @@ def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
            filepath = os.path.join(self.dirpath, "best_k_models.yaml")
        with self._fs.open(filepath, "w") as fp:
            yaml.dump(best_k, fp)

    def file_exists(self, filepath: Union[str, Path], trainer) -> bool:
        """
        Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing
        the internal state from diverging between ranks.
        """
        exists = self._fs.exists(filepath)
        if trainer.accelerator_backend is not None:
            exists = trainer.accelerator_backend.broadcast(exists)
awaelchli (Contributor Author) commented on Dec 28, 2020:

Hi @tgaddair, sorry to ping you out of nowhere, but I am stuck here: this broadcast is causing a Horovod test to hang (test_horovod_apex).
Do you see anything obviously wrong with the broadcasting I'm trying to do here?

(The print statements and SystemExit above were just for debugging.)

Contributor replied:

Without having taken a close look at the code, I'm guessing that, because we're in model_checkpoint.py, this code is only being executed on rank 0? Is that possible?

There should be some messages printed out by Horovod when such a stall occurs after about 30s or so. Are they being printed? Can you share them here?

awaelchli (Contributor Author) replied:

(lightning) [aw18f408@vnode03 pytorch-lightning]$ py.test -v tests/models/test_horovod.py::test_horovod_apex -s
================================================================================================ test session starts =================================================================================================
platform linux -- Python 3.8.3, pytest-6.0.1, py-1.9.0, pluggy-0.13.1 -- /home/aw18f408/.conda/envs/lightning/bin/python
cachedir: .pytest_cache
rootdir: /home/aw18f408/repositories/pytorch-lightning, configfile: setup.cfg
plugins: hydra-core-1.0.4, flake8-1.0.6, cov-2.10.0
collected 1 item

tests/models/test_horovod.py::test_horovod_apex [0]<stderr>:/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/graphql/type/directives.py:55: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
[1]<stderr>:/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/graphql/type/directives.py:55: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
[0]<stderr>:  assert isinstance(locations, collections.Iterable), 'Must provide locations for directive.'
[1]<stderr>:  assert isinstance(locations, collections.Iterable), 'Must provide locations for directive.'
[0]<stderr>:/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/graphql/type/typemap.py:1: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
[0]<stderr>:  from collections import OrderedDict, Sequence, defaultdict
[1]<stderr>:/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/graphql/type/typemap.py:1: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
[1]<stderr>:  from collections import OrderedDict, Sequence, defaultdict
[1]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py:49: DeprecationWarning: mode='auto' is deprecated in v1.1 and will be removed in v1.3. Default value for mode with be 'min' in v1.3.
[0]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py:49: DeprecationWarning: mode='auto' is deprecated in v1.1 and will be removed in v1.3. Default value for mode with be 'min' in v1.3.
[0]<stderr>:  warnings.warn(*args, **kwargs)
[1]<stderr>:  warnings.warn(*args, **kwargs)
[0]<stderr>:GPU available: True, used: True
[0]<stderr>:TPU available: None, using: 0 TPU cores
[1]<stderr>:GPU available: True, used: True
[1]<stderr>:TPU available: None, using: 0 TPU cores
[0]<stderr>:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
[1]<stderr>:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
[0]<stderr>:Using APEX 16bit precision.
[1]<stderr>:Using APEX 16bit precision.
[0]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/precision_connector.py:69: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
[1]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/connectors/precision_connector.py:69: DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead
[0]<stderr>:  log.warn("LightningOptimizer doesn't support Apex")
[1]<stderr>:  log.warn("LightningOptimizer doesn't support Apex")
[0]<stderr>:LightningOptimizer doesn't support Apex
[1]<stderr>:LightningOptimizer doesn't support Apex
[0]<stdout>:Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
[0]<stdout>:
[0]<stdout>:Defaults for this optimization level are:
[0]<stdout>:enabled                : True
[1]<stdout>:Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
[0]<stdout>:opt_level              : O2
[1]<stdout>:
[0]<stdout>:cast_model_type        : torch.float16
[0]<stdout>:patch_torch_functions  : False
[0]<stdout>:keep_batchnorm_fp32    : True
[1]<stdout>:Defaults for this optimization level are:
[0]<stdout>:master_weights         : True
[1]<stdout>:enabled                : True
[0]<stdout>:loss_scale             : dynamic
[1]<stdout>:opt_level              : O2
[0]<stdout>:Processing user overrides (additional kwargs that are not None)...
[1]<stdout>:cast_model_type        : torch.float16
[0]<stdout>:After processing overrides, optimization options are:
[1]<stdout>:patch_torch_functions  : False
[0]<stdout>:enabled                : True
[1]<stdout>:keep_batchnorm_fp32    : True
[0]<stdout>:opt_level              : O2
[1]<stdout>:master_weights         : True
[0]<stdout>:cast_model_type        : torch.float16
[1]<stdout>:loss_scale             : dynamic
[0]<stdout>:patch_torch_functions  : False
[1]<stdout>:Processing user overrides (additional kwargs that are not None)...
[0]<stdout>:keep_batchnorm_fp32    : True
[1]<stdout>:After processing overrides, optimization options are:
[0]<stdout>:master_weights         : True
[1]<stdout>:enabled                : True
[0]<stdout>:loss_scale             : dynamic
[1]<stdout>:opt_level              : O2
[1]<stdout>:cast_model_type        : torch.float16
[1]<stdout>:patch_torch_functions  : False
[1]<stdout>:keep_batchnorm_fp32    : True
[1]<stdout>:master_weights         : True
[1]<stdout>:loss_scale             : dynamic
[0]<stderr>:
[0]<stderr>:  | Name      | Type        | Params | In sizes  | Out sizes
[0]<stderr>:------------------------------------------------------------------
[0]<stderr>:0 | c_d1      | Linear      | 785 K  | [5, 784]  | [5, 1000]
[0]<stderr>:1 | c_d1_bn   | BatchNorm1d | 2.0 K  | [5, 1000] | [5, 1000]
[0]<stderr>:2 | c_d1_drop | Dropout     | 0      | [5, 1000] | [5, 1000]
[0]<stderr>:3 | c_d2      | Linear      | 10.0 K | [5, 1000] | [5, 10]
[0]<stderr>:------------------------------------------------------------------
[0]<stderr>:797 K     Trainable params
[0]<stderr>:0         Non-trainable params
[0]<stderr>:797 K     Total params
[0]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
[0]<stderr>:  warnings.warn(*args, **kwargs)
[0]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py:49: UserWarning: The validation_epoch_end should not return anything as of 9.1. To log, use self.log(...) or self.write(...) directly in the LightningModule
[0]<stderr>:  warnings.warn(*args, **kwargs)
[0]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py:49: UserWarning: The {progress_bar:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
[0]<stderr>:Please use self.log(...) inside the lightningModule instead.
[0]<stderr>:
[0]<stderr>:# log on a step or aggregate epoch metric to the logger and/or progress bar
[0]<stderr>:# (inside LightningModule)
[0]<stderr>:self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
[0]<stderr>:  warnings.warn(*args, **kwargs)
[0]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py:49: UserWarning: The {log:dict keyword} was deprecated in 0.9.1 and will be removed in 1.0.0
[0]<stderr>:Please use self.log(...) inside the lightningModule instead.
[0]<stderr>:
[0]<stderr>:# log on a step or aggregate epoch metric to the logger and/or progress bar
[0]<stderr>:# (inside LightningModule)
[0]<stderr>:self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
[0]<stderr>:  warnings.warn(*args, **kwargs)
[0]<stderr>:/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
[0]<stderr>:  warnings.warn(*args, **kwargs)
[1]<stdout>:Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32768.0
[0]<stderr>:/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/horovod/torch/optimizer.py:252: UserWarning: optimizer.step() called without optimizer.skip_synchronize() context after optimizer.synchronize(). This can cause training slowdown. You may want to consider using optimizer.skip_synchronize() context if you use optimizer.synchronize() in your code.
[0]<stderr>:  warnings.warn("optimizer.step() called without "
[1]<stderr>:/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/horovod/torch/optimizer.py:252: UserWarning: optimizer.step() called without optimizer.skip_synchronize() context after optimizer.synchronize(). This can cause training slowdown. You may want to consider using optimizer.skip_synchronize() context if you use optimizer.synchronize() in your code.
[1]<stderr>:  warnings.warn("optimizer.step() called without "
[1]<stdout>:1 <pytorch_lightning.accelerators.horovod_accelerator.HorovodAccelerator object at 0x7f9efa5c2490>
[1]<stdout>:1 exists: False /tmp/pytest-of-aw18f408/pytest-0/test_horovod_apex0/epoch=0-step=1.ckpt
[0]<stderr>:[2020-12-29 16:49:57.723209: W /tmp/pip-install-uk4q7bn5/horovod/horovod/common/stall_inspector.cc:105] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:[2020-12-29 16:50:57.724380: W /tmp/pip-install-uk4q7bn5/horovod/horovod/common/stall_inspector.cc:105] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:[2020-12-29 16:51:57.726092: W /tmp/pip-install-uk4q7bn5/horovod/horovod/common/stall_inspector.cc:105] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:1: [allreduce.c_d1.bias, allreduce.c_d1.weight, allreduce.c_d1_bn.bias, allreduce.c_d1_bn.weight, allreduce.c_d2.bias, allreduce.c_d2.weight ...]
[0]<stderr>:[2020-12-29 16:52:57.726326: W /tmp/pip-install-uk4q7bn5/horovod/horovod/common/stall_inspector.cc:105] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:1: [allreduce.c_d1.bias, allreduce.c_d1.weight, allreduce.c_d1_bn.bias, allreduce.c_d1_bn.weight, allreduce.c_d2.bias, allreduce.c_d2.weight ...]
[0]<stderr>:[2020-12-29 16:53:57.726988: W /tmp/pip-install-uk4q7bn5/horovod/horovod/common/stall_inspector.cc:105] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:1: [allreduce.c_d1.bias, allreduce.c_d1.weight, allreduce.c_d1_bn.bias, allreduce.c_d1_bn.weight, allreduce.c_d2.bias, allreduce.c_d2.weight ...]
[0]<stderr>:[2020-12-29 16:54:57.727621: W /tmp/pip-install-uk4q7bn5/horovod/horovod/common/stall_inspector.cc:105] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:1: [allreduce.c_d1.bias, allreduce.c_d1.weight, allreduce.c_d1_bn.bias, allreduce.c_d1_bn.weight, allreduce.c_d2.bias, allreduce.c_d2.weight ...]
[0]<stderr>:[2020-12-29 16:55:57.728667: W /tmp/pip-install-uk4q7bn5/horovod/horovod/common/stall_inspector.cc:105] One or more tensors were submitted to be reduced, gathered or broadcasted by subset of ranks and are waiting for remainder of ranks for more than 60 seconds. This may indicate that different ranks are trying to submit different tensors or that only subset of ranks is submitting tensors, which will cause deadlock.
[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:1: [allreduce.c_d1.bias, allreduce.c_d1.weight, allreduce.c_d1_bn.bias, allreduce.c_d1_bn.weight, allreduce.c_d2.bias, allreduce.c_d2.weight ...]

Thanks for the hint. I found out that I can append the -s option to pytest and get the messages, and indeed, as you said, there is the message about the stall:

[0]<stderr>:Missing ranks:
[0]<stderr>:0: [broadcast.bool.sz]
[0]<stderr>:1: [allreduce.c_d1.bias, allreduce.c_d1.weight, allreduce.c_d1_bn.bias, allreduce.c_d1_bn.weight, allreduce.c_d2.bias, allreduce.c_d2.weight ...]

Does the above message mean that rank 0 missed the broadcast?
The model checkpoint code should be executed on all ranks; the only difference should be that only rank 0 is allowed to write to disk.

tgaddair (Contributor) commented on Dec 29, 2020:

Yeah, looks like rank 1 is entering the checkpoint logic while rank 0 is still training the model. So it seems there is some non-deterministic behavior causing rank 1 to write a checkpoint. For example, there could be something like this going on:

if some_local_metric < threshold:
    write_checkpoint()
continue_training_loop()

That's hypothetical, but if the above logic existed, and rank 1 satisfied the condition but rank 0 didn't, it could lead to the situation above.

Contributor commented:

Hey @awaelchli, any update there?

awaelchli (Contributor Author) replied:

I understand what @tgaddair explained, but I can't find where in Lightning the source of the problem is. There is one test that fails, and the only difference between that test and the other Horovod tests is that apex is turned on.

awaelchli (Contributor Author) replied:

@tgaddair, does horovod.broadcast_object not block? It looks like adding a barrier solves the problem. The failing apex test now passes locally, but the Drone CI is still in trouble.

        return exists
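
As context for the thread above, here is a rough sketch of the broadcast-then-barrier pattern being discussed, assuming horovod.torch has been initialized on every rank (broadcast_with_barrier is a hypothetical helper, not the accelerator code from this PR):

# Rough sketch, assuming Horovod has been initialized (hvd.init()) on every rank.
# `broadcast_with_barrier` is a hypothetical helper, not a Lightning or Horovod API.
import horovod.torch as hvd


def broadcast_with_barrier(obj, src: int = 0):
    # hvd.join() blocks until every rank has reached this point, so no rank can
    # start the broadcast while another rank is still inside its training step
    # (the stall shown in the log above).
    hvd.join()
    # broadcast_object serializes the object on the source rank and returns the
    # source rank's value on every rank.
    return hvd.broadcast_object(obj, src)
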
3 changes: 2 additions & 1 deletion tests/checkpointing/test_model_checkpoint.py
@@ -169,7 +169,8 @@ def on_train_end(self, trainer, pl_module):
        assert self.best_model_score
        assert self.on_save_checkpoint_count == self.expected_count
        if trainer.is_global_zero:
            assert torch.save.call_count == self.expected_count
            # twice the calls expected because ddp broadcast also uses torch.save
            assert torch.save.call_count == self.expected_count * 2
        else:
            assert torch.save.call_count == 0

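For context on the comment above: the distributed broadcast serializes the object with torch.save into an in-memory buffer before sending the raw bytes, which is why rank 0 now records one extra torch.save call per checkpoint. A rough sketch of that serialize-then-broadcast pattern, assuming torch.distributed has been initialized (e.g. with the gloo backend); broadcast_object_via_torch_save is an illustrative helper, not Lightning's implementation:

# Rough sketch of serialize-then-broadcast; helper name is illustrative.
import io

import torch
import torch.distributed as dist


def broadcast_object_via_torch_save(obj, src: int = 0):
    if dist.get_rank() == src:
        buffer = io.BytesIO()
        torch.save(obj, buffer)  # the extra torch.save call the test counts
        payload = torch.tensor(list(buffer.getvalue()), dtype=torch.uint8)
        length = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(length, src=src)
        dist.broadcast(payload, src=src)
        return obj
    # other ranks first learn the payload size, then receive and deserialize it
    length = torch.zeros(1, dtype=torch.long)
    dist.broadcast(length, src=src)
    payload = torch.empty(int(length.item()), dtype=torch.uint8)
    dist.broadcast(payload, src=src)
    return torch.load(io.BytesIO(payload.numpy().tobytes()))
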
1 change: 1 addition & 0 deletions tests/models/test_horovod.py
@@ -125,6 +125,7 @@ def test_horovod_multi_gpu(tmpdir):
    _run_horovod(trainer_options, on_gpu=True)


@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?")
Contributor commented:

So does it or does it not?

awaelchli (Contributor Author) replied:

I wasn't able to investigate yet; the only difference between this test and the one above is apex. It's as if apex is affected by my broadcast operations in the model checkpoint, or the other way around.

@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")