Skip to content

Commit

Permalink
Skip EarlyStopping and ModelCheckpoint Callbacks for Torch 1.13+ (horovod#3778)
Browse files Browse the repository at this point in the history

* skip EarlyStopping and ModelCheckpoint Callbacks for Torch 1.13+

Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
  • Loading branch information
chongxiaoc authored Nov 21, 2022
1 parent 0268506 commit e392eb9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
39 changes: 28 additions & 11 deletions examples/spark/pytorch/pytorch_lightning_spark_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,34 @@ def on_train_end(self, trainer, model):

callbacks = [MyDummyCallback()]

# added EarlyStopping and ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
save_top_k=1, verbose=True))

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
callbacks.append(EarlyStopping(monitor='val_loss',
min_delta=0.001,
patience=3,
verbose=True,
mode='min'))
if version.parse(torch.__version__) < version.parse('1.13'):
"""
torch.distributed.ReduceOp is used by ModelCheckpoint and EarlyStopping.
Since torch 1.13, a ReduceOp can no longer be compared with None, so the membership check in Lightning's Horovod strategy fails.
Broken line in Lightning code (https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/strategies/horovod.py#L179)
The following error is raised:
>>> from torch.distributed import ReduceOp
>>> op = None
>>> op in (ReduceOp.SUM, None)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: __eq__(): incompatible function arguments. The following argument types are supported:
1. (self: torch._C._distributed_c10d.ReduceOp, arg0: c10d::ReduceOp::RedOpType) -> bool
2. (self: torch._C._distributed_c10d.ReduceOp, arg0: torch._C._distributed_c10d.ReduceOp) -> bool
Invoked with: <torch.distributed.distributed_c10d.ReduceOp object at 0x7fba78c9e0b0>, None
"""
# ModelCheckpoint
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
callbacks.append(ModelCheckpoint(monitor='val_loss', mode="min",
save_top_k=1, verbose=True))
# EarlyStopping
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
callbacks.append(EarlyStopping(monitor='val_loss',
min_delta=0.001,
patience=3,
verbose=True,
mode='min'))

torch_estimator = hvd.TorchEstimator(backend=backend,
store=store,
Expand Down
3 changes: 3 additions & 0 deletions test/integration/test_spark_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,9 @@ def test_early_stop_callback(self):
self.skipTest('Spark PyTorch Lightning tests conflict with Tensorflow 2.5.x: '
'https://github.com/horovod/horovod/pull/3263')

if version.parse(torch.__version__) >= version.parse('1.13'):
self.skipTest('Torch 1.13+ fails EarlyStopping CB usage with Horovod.')

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

with spark_session('test_fit_model') as spark:
Expand Down

0 comments on commit e392eb9

Please sign in to comment.