Skip to content

Commit

Permalink
use gfile to support remote directories
Browse files Browse the repository at this point in the history
Tests all use the `tmpfile` fixture which provides a py.path.local which is
incompatible with the compat.gfile. The contract in many places is type str or
Optional[str] which py.path.local is not.

I hope that folks are not passing in path.local objects, if so this change will
break them. The type annotations say to use str, so this should be ok. The
other option is to just explicitly convert to str as to not break people using
an incorrect type (like the tests were doing)
  • Loading branch information
f4hy committed Jun 13, 2020
1 parent 5fd01b0 commit cdc4cb2
Show file tree
Hide file tree
Showing 12 changed files with 120 additions and 93 deletions.
25 changes: 18 additions & 7 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
from pytorch_lightning.utilities.io import gfile


class ModelCheckpoint(Callback):
Expand Down Expand Up @@ -97,7 +98,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
mode: str = 'auto', period: int = 1, prefix: str = ''):
super().__init__()
if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
if(filepath):
filepath = str(filepath) # the tests pass in a py.path.local but we want a str
if save_top_k > 0 and filepath is not None and gfile.isdir(filepath) and len(gfile.listdir(filepath)) > 0:
rank_zero_warn(
f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
"All files in this directory will be deleted when a checkpoint is saved!"
Expand All @@ -109,12 +112,13 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
if filepath is None: # will be determined by trainer at runtime
self.dirpath, self.filename = None, None
else:
if os.path.isdir(filepath):
if gfile.isdir(filepath):
self.dirpath, self.filename = filepath, '{epoch}'
else:
filepath = os.path.realpath(filepath)
self.dirpath, self.filename = os.path.split(filepath)
os.makedirs(self.dirpath, exist_ok=True)
if not gfile.exists(self.dirpath):
gfile.makedirs(self.dirpath)
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
Expand Down Expand Up @@ -156,12 +160,19 @@ def kth_best_model(self):
return self.kth_best_model_path

def _del_model(self, filepath):
if os.path.isfile(filepath):
os.remove(filepath)
if gfile.exists(filepath):
try:
# in compat mode, remove is not implemented so if running this
# against an actual remove file system and the correct remote
# dependencies exist then this will work fine.
gfile.remove(filepath)
except AttributeError:
os.remove(filepath)

def _save_model(self, filepath):
# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
if not gfile.exists(os.path.dirname(filepath)):
gfile.makedirs(os.path.dirname(filepath))

# delegate the saving to the model
if self.save_function is not None:
Expand Down Expand Up @@ -249,7 +260,7 @@ def on_validation_end(self, trainer, pl_module):

filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
while os.path.isfile(filepath):
while gfile.exists(filepath):
filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
Expand Down
21 changes: 13 additions & 8 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.io import load as pl_load

# we want this for tf.iogfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
# not all
from tensorboard.compat import tf

PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, dict, Namespace)
try:
Expand Down Expand Up @@ -269,25 +274,25 @@ def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
True
>>> os.remove(path_csv)
"""
if not os.path.isfile(tags_csv):
if not tf.io.gfile.exists(tags_csv):
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
return {}

with open(tags_csv) as fp:
with tf.io.gfile.GFile(tags_csv, "r") as fp:
csv_reader = csv.reader(fp, delimiter=',')
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}

return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
if not os.path.isdir(os.path.dirname(tags_csv)):
if not tf.io.gfile.isdir(os.path.dirname(tags_csv)):
raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')

if isinstance(hparams, Namespace):
hparams = vars(hparams)

with open(tags_csv, 'w') as fp:
with tf.io.gfile.GFile(tags_csv, 'w') as fp:
fieldnames = ['key', 'value']
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({'key': 'key', 'value': 'value'})
Expand All @@ -306,24 +311,24 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
True
>>> os.remove(path_yaml)
"""
if not os.path.isfile(config_yaml):
if not tf.io.gfile.exists(config_yaml):
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
return {}

with open(config_yaml) as fp:
with tf.io.gfile.GFile(config_yaml, "r") as fp:
tags = yaml.load(fp, Loader=yaml.SafeLoader)

return tags


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
if not os.path.isdir(os.path.dirname(config_yaml)):
if not tf.io.gfile.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')

if isinstance(hparams, Namespace):
hparams = vars(hparams)

with open(config_yaml, 'w', newline='') as fp:
with tf.io.gfile.GFile(config_yaml, 'w') as fp:
yaml.dump(hparams, fp)


Expand Down
12 changes: 7 additions & 5 deletions pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.io import gfile


class TensorBoardLogger(LightningLoggerBase):
Expand Down Expand Up @@ -97,7 +98,8 @@ def experiment(self) -> SummaryWriter:
if self._experiment is not None:
return self._experiment

os.makedirs(self.root_dir, exist_ok=True)
if not gfile.exists(self.root_dir):
gfile.makedirs(self.root_dir)
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

Expand Down Expand Up @@ -145,7 +147,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
def save(self) -> None:
super().save()
dir_path = self.log_dir
if not os.path.isdir(dir_path):
if not gfile.isdir(dir_path):
dir_path = self.save_dir

# prepare the file path
Expand All @@ -171,13 +173,13 @@ def version(self) -> int:
def _get_next_version(self):
root_dir = os.path.join(self.save_dir, self.name)

if not os.path.isdir(root_dir):
if not gfile.isdir(root_dir):
log.warning('Missing logger folder: %s', root_dir)
return 0

existing_versions = []
for d in os.listdir(root_dir):
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
for d in gfile.listdir(root_dir):
if gfile.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
existing_versions.append(int(d.split("_")[1]))

if len(existing_versions) == 0:
Expand Down
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/callback_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pytorch_lightning.callbacks import Callback, ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.io import gfile


class TrainerCallbackConfigMixin(ABC):
Expand Down Expand Up @@ -67,7 +68,8 @@ def configure_checkpoint_callback(self):
monitor_key = 'loss' if train_step_only else 'val_loss'

if self.checkpoint_callback is True:
os.makedirs(ckpt_path, exist_ok=True)
if not gfile.exists(ckpt_path):
gfile.makedirs(ckpt_path)
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path,
monitor=monitor_key
Expand All @@ -77,7 +79,9 @@ def configure_checkpoint_callback(self):
and self.checkpoint_callback.dirpath is None:
self.checkpoint_callback.dirpath = ckpt_path
self.checkpoint_callback.filename = '{epoch}'
os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
if not gfile.exists(self.checkpoint_callback.dirpath):
gfile.makedirs(self.checkpoint_callback.dirpath)

elif self.checkpoint_callback is False:
self.checkpoint_callback = None

Expand Down
14 changes: 8 additions & 6 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
)
from pytorch_lightning.utilities import rank_zero_warn, parsing
from pytorch_lightning.utilities.io import load as pl_load
from pytorch_lightning.utilities.io import gfile

try:
import torch_xla
Expand Down Expand Up @@ -375,8 +376,8 @@ def restore_hpc_weights_if_needed(self, model: LightningModule):

# look for hpc weights
folderpath = self.weights_save_path
if os.path.exists(folderpath):
files = os.listdir(folderpath)
if gfile.exists(folderpath):
files = gfile.listdir(folderpath)
hpc_weight_paths = [x for x in files if 'hpc_ckpt' in x]

# if hpc weights exist restore model
Expand Down Expand Up @@ -451,15 +452,16 @@ def restore_training_state(self, checkpoint):
# ----------------------------------
def hpc_save(self, folderpath: str, logger):
# make sure the checkpoint folder exists
os.makedirs(folderpath, exist_ok=True)
if not gfile.exists(folderpath):
gfile.makedirs(folderpath)

# save logger to make sure we get all the metrics
logger.save()

ckpt_number = self.max_ckpt_in_folder(folderpath) + 1

if not os.path.exists(folderpath):
os.makedirs(folderpath, exist_ok=True)
if not gfile.exists(folderpath):
gfile.makedirs(folderpath)
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
Expand Down Expand Up @@ -509,7 +511,7 @@ def hpc_load(self, folderpath, on_gpu):
log.info(f'restored hpc model from: {filepath}')

def max_ckpt_in_folder(self, path, name_key='ckpt_'):
files = os.listdir(path)
files = gfile.listdir(path)
files = [x for x in files if name_key in x]
if len(files) == 0:
return 0
Expand Down
7 changes: 7 additions & 0 deletions pytorch_lightning/utilities/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

from urllib.parse import urlparse

# we want this for tf.io.gfile, which if tf is installed gives full tf,
# otherwise gives a pruned down version which works for some file backends but
# not all
from tensorboard.compat import tf

gfile = tf.io.gfile


def load(path_or_url: str, map_location=None):
parsed = urlparse(path_or_url)
Expand Down
6 changes: 3 additions & 3 deletions tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def log_metrics(self, metrics, step):
super().log_metrics(metrics, step)
self.history.append((step, metrics))

logger_args = _get_logger_args(logger_class, tmpdir)
logger_args = _get_logger_args(logger_class, str(tmpdir))
logger = StoreHistoryLogger(**logger_args)

trainer = Trainer(
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)

logger_args = _get_logger_args(logger_class, tmpdir)
logger_args = _get_logger_args(logger_class, str(tmpdir))
logger = logger_class(**logger_args)

# test pickling loggers
Expand All @@ -109,7 +109,7 @@ def test_logger_reset_correctly(tmpdir, extra_params):
model = EvalModelTemplate()

trainer = Trainer(
default_root_dir=tmpdir,
default_root_dir=str(tmpdir),
**extra_params
)
logger1 = trainer.logger
Expand Down
Loading

0 comments on commit cdc4cb2

Please sign in to comment.