Spark/Lightning: fix the usage of checkpoint callback (horovod#3186)
Signed-off-by: Chongxiao Cao <chongxiaoc@uber.com>
chongxiaoc authored Sep 30, 2021
1 parent d6de12d commit 592d209
Showing 3 changed files with 56 additions and 35 deletions.
4 changes: 3 additions & 1 deletion examples/spark/pytorch/pytorch_lightning_spark_mnist.py
@@ -43,6 +43,8 @@
help='temporary working directory to write intermediate files (prefix with hdfs:// to use HDFS)')
parser.add_argument('--data-dir', default='/tmp',
help='location of the training dataset in the local filesystem (will be downloaded if needed)')
parser.add_argument('--enable-profiler', action='store_true',
help='Enable profiler')


def train_model(args):
@@ -195,7 +197,7 @@ def on_train_end(self, trainer, model):
validation=0.1,
verbose=1,
callbacks=callbacks,
profiler="simple")
profiler="simple" if args.enable_profiler else None)

torch_model = torch_estimator.fit(train_df).setOutputCols(['label_prob'])

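The example change above gates profiling behind the new --enable-profiler flag instead of always passing profiler="simple". A minimal, self-contained sketch of that wiring (the argument name comes from the script; the hard-coded argv is only for illustration):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--enable-profiler', action='store_true',
                    help='Enable profiler')
# Hard-coded argv for demonstration; in the example script this comes from the command line.
args = parser.parse_args(['--enable-profiler'])

# None disables profiling; "simple" selects Lightning's SimpleProfiler.
profiler = "simple" if args.enable_profiler else None
print(profiler)  # -> simple
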
16 changes: 14 additions & 2 deletions horovod/spark/lightning/estimator.py
@@ -206,6 +206,9 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,

profiler = Param(Params._dummy(), 'profiler', 'lightning profiler to use')

checkpoint_callback = Param(Params._dummy(), 'checkpoint_callback',
'model checkpointing callback')

@keyword_only
def __init__(self,
num_proc=None,
@@ -246,7 +249,8 @@ def __init__(self,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None):
profiler=None,
checkpoint_callback=None):

super(TorchEstimator, self).__init__()
self._setDefault(loss_constructors=None,
@@ -260,7 +264,8 @@ def __init__(self,
data_module=None,
loader_num_epochs=None,
terminate_on_nan=False,
profiler=None)
profiler=None,
checkpoint_callback=None)

kwargs = self._input_kwargs

@@ -333,6 +338,12 @@ def setTerminateOnNan(self, value):
def getTerminateOnNan(self):
return self.getOrDefault(self.terminate_on_nan)

def setCheckpointCallback(self, value):
return self._set(checkpoint_callback=value)

def getCheckpointCallback(self):
return self.getOrDefault(self.checkpoint_callback)

def getProfiler(self):
return self.getOrDefault(self.profiler)

@@ -401,6 +412,7 @@ def _fit_on_prepared_data(self, backend, train_rows, val_rows, metadata, avg_row
validation=self.getValidation())

serialized_model = serialize_fn()(model)
# FIXME: checkpoint bytes should be loaded into serialized_model, same as Keras Estimator.
ckpt_bytes = self._read_checkpoint(run_id) if self._has_checkpoint(run_id) else None
trainer = remote.RemoteTrainer(self,
metadata=metadata,
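With the new checkpoint_callback parameter, a user-defined Lightning ModelCheckpoint can be handed to the estimator rather than being dropped. A hedged user-side sketch (assumes pytorch_lightning is installed; the monitored metric 'val_loss' is illustrative and must match a metric the LightningModule actually logs, and the TorchEstimator call is shown only in a comment because its remaining arguments depend on the application):

from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the best model by validation loss; dirpath/filename can be left unset,
# since the remote trainer points the callback at the run's output directory.
ckpt_callback = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1, verbose=True)

# Passed at construction time (remaining estimator arguments elided):
#   torch_estimator = TorchEstimator(..., checkpoint_callback=ckpt_callback)
# or set afterwards:
#   torch_estimator.setCheckpointCallback(ckpt_callback)
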
71 changes: 39 additions & 32 deletions horovod/spark/lightning/remote.py
@@ -53,6 +53,7 @@ def RemoteTrainer(estimator, metadata, ckpt_bytes, run_id, dataset_idx, train_ro
transformation = transformation_fn if transformation_fn else None
inmemory_cache_all = estimator.getInMemoryCacheAll()
callbacks = estimator.getCallbacks() or []
checkpoint_callback = estimator.getCheckpointCallback()
train_steps_per_epoch = estimator.getTrainStepsPerEpoch()
val_steps_per_epoch = estimator.getValidationStepsPerEpoch()
num_gpus = estimator.getNumGPUs()
@@ -88,16 +89,12 @@ def train(serialized_model):
# Horovod: initialize library.
hvd.init()

with tempfile.TemporaryDirectory() as last_ckpt_dir, remote_store.get_local_output_dir() as run_output_dir:
last_ckpt_file = os.path.join(last_ckpt_dir, 'last.ckpt')
if ckpt_bytes:
with open(last_ckpt_file, 'wb') as f:
f.write(ckpt_bytes)

# TODO: Pass the logger from estimator constructor
with remote_store.get_local_output_dir() as run_output_dir:
logs_path = os.path.join(run_output_dir, remote_store.logs_subdir)
os.makedirs(logs_path, exist_ok=True)
print(f"Made directory {logs_path} for horovod rank {hvd.rank()}")
ckpt_dir = run_output_dir
ckpt_filename = remote_store.checkpoint_filename

# Use default logger if no logger is supplied
train_logger = logger
@@ -106,22 +103,25 @@
if train_logger is None:
train_logger = TensorBoardLogger(logs_path)

# TODO: find out a way to use ckpt_path created from remote store, but all other parameters ingest from estimator config
# ckpt_path = os.path.join(run_output_dir, remote_store.checkpoint_filename)
# os.makedirs(ckpt_path, exist_ok=True)
# model_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path)
# callbacks.append(model_checkpoint_callback)

is_model_checkpoint_callback_exist = False
for cb in callbacks:
if isinstance(cb, ModelCheckpoint):
is_model_checkpoint_callback_exist = True
break
# Lightning requires the checkpoint callback to be added on all ranks;
# otherwise training hangs.
_checkpoint_callback = checkpoint_callback
if _checkpoint_callback:
_checkpoint_callback.dir_path = ckpt_dir
_checkpoint_callback.filename = ckpt_filename
else:
# By default 'monitor'=None which saves a checkpoint only for the last epoch.
_checkpoint_callback = ModelCheckpoint(dirpath=ckpt_dir,
filename=ckpt_filename,
verbose=True)
callbacks.append(_checkpoint_callback)

if remote_store.saving_runs and hvd.rank() == 0:
# Horovod: sync checkpoint and logging files only on rank 0 to
# prevent other ranks from corrupting them.
class _SyncCallback(Callback):
def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
remote_store.sync(logs_path)
remote_store.sync(run_output_dir)

callbacks.append(_SyncCallback())

@@ -133,7 +133,11 @@ def on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
_val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else \
int(math.floor(float(val_rows) / val_batch_size / hvd.size()))

print(f"Training data of rank[{hvd.local_rank()}]: train_rows:{train_rows}, batch_size:{batch_size}, _train_steps_per_epoch:{_train_steps_per_epoch}.")
if verbose:
print(f"Training data of rank[{hvd.local_rank()}]: Epochs: {epochs}\n"
f"Train rows: {train_rows}, Train batch size: {batch_size}, Train_steps_per_epoch: {_train_steps_per_epoch}\n"
f"Val rows: {val_rows}, Val batch size: {val_batch_size}, Val_steps_per_epoch: {_val_steps_per_epoch}\n"
f"Checkpoint file: {remote_store.checkpoint_path}, Logs dir: {remote_store.logs_path}\n")

cuda_available = torch.cuda.is_available()
# We need to check that all ranks have the same device type for training.
@@ -158,8 +162,6 @@
'max_epochs': epochs,
'logger': train_logger,
'log_every_n_steps': log_every_n_steps,
'resume_from_checkpoint': (last_ckpt_file if ckpt_bytes else None),
'checkpoint_callback': is_model_checkpoint_callback_exist,
'num_sanity_val_steps': 0,
'reload_dataloaders_every_epoch': False,
'progress_bar_refresh_rate': _train_steps_per_epoch // 10,
@@ -172,6 +174,9 @@
if trainer.profiler:
print(f"Set profiler's logs_path to {logs_path}")
trainer.profiler.dirpath = logs_path
# filename where the profiler results will be saved instead of
# printing to stdout. The .txt extension will be used automatically.
trainer.profiler.filename = "profile"

print(f"pytorch_lightning version={pl.__version__}")

@@ -191,19 +196,21 @@
verbose=verbose)
trainer.fit(model, dataset)

serialized_checkpoint = io.BytesIO()
module = model if not is_legacy else model._model
if hvd.rank() == 0:
if remote_store.saving_runs and trainer.profiler:
# One more file sync to push profiler result.
remote_store.sync(logs_path)

# TODO: find a way to pass trainer.logged_metrics out.
output = {'model': module.state_dict()}
# rank 0 overwrites model with best checkpoint and returns.
best_model = model.load_from_checkpoint(_checkpoint_callback.best_model_path)
serialized_checkpoint = io.BytesIO()
module = best_model if not is_legacy else best_model._model

torch.save(output, serialized_checkpoint)

if remote_store.saving_runs and hvd.rank() == 0:
remote_store.sync(logs_path)
# TODO: find a way to pass trainer.logged_metrics out.
output = {'model': module.state_dict()}

serialized_checkpoint.seek(0)
return serialized_checkpoint
torch.save(output, serialized_checkpoint)
return serialized_checkpoint
return train


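Taken together, the per-rank logic in remote.py reduces to: reuse the supplied callback after pointing it at the run's output directory, otherwise create a default one, and let rank 0 reload the best checkpoint afterwards. A standalone sketch of that behavior (the helper name is invented here, and it uses ModelCheckpoint's public dirpath attribute where the commit assigns dir_path/filename inline):

from pytorch_lightning.callbacks import ModelCheckpoint

def resolve_checkpoint_callback(user_callback, ckpt_dir, ckpt_filename):
    # Every rank needs a checkpoint callback; otherwise training can hang.
    if user_callback:
        # Redirect the user's callback to the run's output directory.
        user_callback.dirpath = ckpt_dir
        user_callback.filename = ckpt_filename
        return user_callback
    # Default: monitor=None, so only the last epoch's checkpoint is kept.
    return ModelCheckpoint(dirpath=ckpt_dir, filename=ckpt_filename, verbose=True)

# After trainer.fit(...), rank 0 reloads the best checkpoint and returns its state_dict:
#   best_model = model.load_from_checkpoint(callback.best_model_path)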
