From 9292a34eb72ca739d46fcc8d55173bc0c38b4bf1 Mon Sep 17 00:00:00 2001
From: Laksh Aithani
Date: Sat, 13 Jun 2020 16:19:06 +0100
Subject: [PATCH 1/6] Initial commit

---
 pytorch_lightning/core/saving.py        | 29 +++++++++--------
 pytorch_lightning/utilities/cloud_io.py | 43 +++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 4ff019a20c6ae..5e385c79c3df3 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -40,19 +40,20 @@ def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
         """
         rank_zero_warn(
             "`load_from_metrics` method has been unified with `load_from_checkpoint` in v0.7.0."
-            " The deprecated method will be removed in v0.9.0.", DeprecationWarning
+            " The deprecated method will be removed in v0.9.0.",
+            DeprecationWarning,
         )
         return cls.load_from_checkpoint(weights_path, tags_csv=tags_csv, map_location=map_location)

     @classmethod
     def load_from_checkpoint(
-            cls,
-            checkpoint_path: str,
-            *args,
-            map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
-            hparams_file: Optional[str] = None,
-            tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
-            **kwargs
+        cls,
+        checkpoint_path: str,
+        *args,
+        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+        hparams_file: Optional[str] = None,
+        tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
+        **kwargs,
     ):
         r"""
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
@@ -136,10 +137,12 @@ def load_from_checkpoint(
             pretrained_model.freeze()
             y_hat = pretrained_model(x)
         """
-        if map_location is not None:
-            checkpoint = pl_load(checkpoint_path, map_location=map_location)
-        else:
-            checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
+        if map_location is None:  # `is None`, so valid falsy values such as a device index 0 are kept
+
+            def map_location(storage, loc):
+                return storage
+
+        checkpoint = pl_load(checkpoint_path, map_location=map_location)

         # add the hparams from csv file to checkpoint
         if tags_csv is not None:
@@ -193,7 +196,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
             if args_name in init_args_name:
                 kwargs.update({args_name: model_args})
             else:
-                args = (model_args, ) + args
+                args = (model_args,) + args

     # load the state_dict on the model automatically
     model = cls(*args, **kwargs)
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index de5bd1918e03b..cf9e95c43de71 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -9,4 +9,47 @@ def load(path_or_url: str, map_location=None):
     if parsed.scheme == '' or Path(path_or_url).is_file():
         # no scheme or local file
         return torch.load(path_or_url, map_location=map_location)
+    elif parsed.scheme == 's3':
+        return load_s3_checkpoint(path_or_url, map_location=map_location)
     return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
+
+
+def load_s3_checkpoint(checkpoint_path, map_location, **pickle_load_args):
+    from torch.serialization import _legacy_load
+    import pickle
+
+    # Attempt s3fs import
+    try:
+        import s3fs
+    except ImportError:
+        raise ImportError(
+            f"Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed. "
+            f"Please `pip install s3fs` and try again."
+        )
+
+    if 'encoding' not in pickle_load_args.keys():
+        pickle_load_args['encoding'] = 'utf-8'
+    fs = s3fs.S3FileSystem()
+    with fs.open(checkpoint_path, "rb") as f:
+        checkpoint = _legacy_load(f, map_location, pickle, **pickle_load_args)
+    return checkpoint
+
+
+def save_s3_checkpoint(checkpoint, checkpoint_path):
+    from torch.serialization import _legacy_save
+    import pickle
+
+    # Attempt s3fs import
+    try:
+        import s3fs
+    except ImportError:
+        raise ImportError(
+            f'Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed. '
+            f'Please `pip install s3fs` and try again.'
+        )
+
+    DEFAULT_PROTOCOL = 2  # from torch.serialization.py line 19
+    fs = s3fs.S3FileSystem()
+    with fs.open(checkpoint_path, "wb") as f:
+        _legacy_save(checkpoint, f, pickle, DEFAULT_PROTOCOL)
+    return checkpoint
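A minimal usage sketch of the `load` entry point from the patch above, assuming `s3fs` is installed and AWS credentials are configured; the bucket and checkpoint names are hypothetical placeholders:

from pytorch_lightning.utilities.cloud_io import load

# An "s3://" scheme is routed to load_s3_checkpoint(); the bucket and key are hypothetical.
checkpoint = load("s3://my-bucket/checkpoints/epoch=3.ckpt")

# Local paths keep their existing behaviour and go through torch.load.
checkpoint = load("/tmp/epoch=3.ckpt", map_location="cpu")

Note that `load_s3_checkpoint` relies on `torch.serialization._legacy_load`, so it only reads checkpoints saved in torch's legacy (non-zipfile) serialization format.
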
From 60848c40e2840f84febd4908a53acab4c6780e33 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sat, 13 Jun 2020 19:30:03 +0200
Subject: [PATCH 2/6] flake8

---
 pytorch_lightning/utilities/cloud_io.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index cf9e95c43de71..635d8fb978b39 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -1,8 +1,9 @@
-import torch
-
+import pickle
 from pathlib import Path
 from urllib.parse import urlparse

+import torch
+

 def load(path_or_url: str, map_location=None):
     parsed = urlparse(path_or_url)
@@ -14,18 +14,15 @@ def load(path_or_url: str, map_location=None):
-def load_s3_checkpoint(checkpoint_path, map_location, **pickle_load_args):
+def load_s3_checkpoint(checkpoint_path: str, map_location, **pickle_load_args):
     from torch.serialization import _legacy_load
-    import pickle

     # Attempt s3fs import
     try:
         import s3fs
     except ImportError:
-        raise ImportError(
-            f"Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed. "
-            f"Please `pip install s3fs` and try again."
-        )
+        raise ImportError("Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed."
+                          " Please `pip install s3fs` and try again.")

     if 'encoding' not in pickle_load_args.keys():
         pickle_load_args['encoding'] = 'utf-8'
@@ -37,16 +34,13 @@ def load_s3_checkpoint(checkpoint_path: str, map_location, **pickle_load_args):
 def save_s3_checkpoint(checkpoint, checkpoint_path):
     from torch.serialization import _legacy_save
-    import pickle

     # Attempt s3fs import
     try:
         import s3fs
     except ImportError:
-        raise ImportError(
-            f'Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed. '
-            f'Please `pip install s3fs` and try again.'
-        )
+        raise ImportError("Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed."
+                          " Please `pip install s3fs` and try again.")

     DEFAULT_PROTOCOL = 2  # from torch.serialization.py line 19
     fs = s3fs.S3FileSystem()
From 166489b222461e7f662ae47e36e925a052c1d36d Mon Sep 17 00:00:00 2001
From: Jirka
Date: Sat, 13 Jun 2020 19:38:40 +0200
Subject: [PATCH 3/6] add s3fs

---
 requirements/devel.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/requirements/devel.txt b/requirements/devel.txt
index 3d68cdcbc16e0..71d12187bfea2 100644
--- a/requirements/devel.txt
+++ b/requirements/devel.txt
@@ -7,4 +7,5 @@
 # extended list of dependencies for development and running lint and tests
 -r ./test.txt
-cloudpickle>=1.2
\ No newline at end of file
+cloudpickle>=1.2
+s3fs
\ No newline at end of file
From d3bbdd5686563a0eb024c3db3aa935bbe5deafba Mon Sep 17 00:00:00 2001
From: Laksh Aithani
Date: Sat, 13 Jun 2020 23:20:07 +0100
Subject: [PATCH 4/6] Implement boto3 S3 I/O

---
 .../callbacks/model_checkpoint.py       |  13 +-
 pytorch_lightning/utilities/cloud_io.py | 119 +++++++++++++-----
 requirements/devel.txt                  |   3 +-
 3 files changed, 102 insertions(+), 33 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 6a47e6f58c88a..bf48ce38991a8 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -16,6 +16,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only
+import pytorch_lightning.utilities.cloud_io as cloud_io


 class ModelCheckpoint(Callback):
@@ -95,7 +96,7 @@ class ModelCheckpoint(Callback):

     def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
                  save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
-                 mode: str = 'auto', period: int = 1, prefix: str = ''):
+                 mode: str = 'auto', period: int = 1, prefix: str = '', remove_non_top_k_s3_files: bool = True):
         super().__init__()
         if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
             rank_zero_warn(
@@ -109,6 +110,11 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
+        self.save_to_s3 = False  # default, so non-S3 runs never hit an undefined attribute
         if filepath is None:  # will be determined by trainer at runtime
             self.dirpath, self.filename = None, None
         else:
+            if cloud_io.is_s3_path(filepath):
+                self.save_to_s3 = True
+                self.bucket, filepath = cloud_io.parse_s3_path(filepath)
+
             if os.path.isdir(filepath):
                 self.dirpath, self.filename = filepath, '{epoch}'
             else:
@@ -127,6 +132,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         self.best_model_score = 0
         self.best_model_path = ''
         self.save_function = None
+        self.remove_non_top_k_s3_files = remove_non_top_k_s3_files

         torch_inf = torch.tensor(np.Inf)
         mode_dict = {
@@ -158,6 +164,8 @@ def kth_best_model(self):
     def _del_model(self, filepath):
         if os.path.isfile(filepath):
             os.remove(filepath)
+        if self.save_to_s3 and self.remove_non_top_k_s3_files:
+            cloud_io.remove_checkpoint_from_s3(self.bucket, filepath)

     def _save_model(self, filepath):
         # make paths
@@ -168,6 +177,8 @@ def _save_model(self, filepath):
             self.save_function(filepath, self.save_weights_only)
         else:
             raise ValueError(".save_function() not set")
+        if self.save_to_s3:
+            cloud_io.save_checkpoint_to_s3(self.bucket, filepath)

     def check_monitor_top_k(self, current):
         less_than_k_models = len(self.best_k_models) < self.save_top_k
diff --git a/pytorch_lightning/utilities/cloud_io.py b/pytorch_lightning/utilities/cloud_io.py
index 635d8fb978b39..87cb36aab8363 100644
--- a/pytorch_lightning/utilities/cloud_io.py
+++ b/pytorch_lightning/utilities/cloud_io.py
@@ -1,8 +1,25 @@
-import pickle
+from typing import Tuple
 from pathlib import Path
 from urllib.parse import urlparse
-
+import os.path as osp
+import os

 import torch
+from torch.hub import _get_torch_home
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+torch_cache_home = _get_torch_home()
+default_cache_path = osp.join(torch_cache_home, "pl_checkpoints")
+
+
+def try_import_boto3():
+    try:
+        import boto3
+        return boto3  # hand the module back so callers don't rely on a global import
+    except ImportError:
+        raise ImportError('Could not import `boto3`. Please `pip install boto3` and try again.')


 def load(path_or_url: str, map_location=None):
@@ -10,40 +26,81 @@ def load(path_or_url: str, map_location=None):
         # no scheme or local file
         return torch.load(path_or_url, map_location=map_location)
     elif parsed.scheme == 's3':
-        return load_s3_checkpoint(path_or_url, map_location=map_location)
+        # AWS S3 file
+        filepath = download_checkpoint_from_s3(path_or_url)
+        return torch.load(filepath, map_location=map_location)
+    # URL
     return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)


-def load_s3_checkpoint(checkpoint_path: str, map_location, **pickle_load_args):
-    from torch.serialization import _legacy_load
+def is_s3_path(path: str):
+    """Checks if the path is a valid S3 path."""
+    return path.startswith("s3://")

-    # Attempt s3fs import
-    try:
-        import s3fs
-    except ImportError:
-        raise ImportError("Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed."
-                          " Please `pip install s3fs` and try again.")

-    if 'encoding' not in pickle_load_args.keys():
-        pickle_load_args['encoding'] = 'utf-8'
-    fs = s3fs.S3FileSystem()
-    with fs.open(checkpoint_path, "rb") as f:
-        checkpoint = _legacy_load(f, map_location, pickle, **pickle_load_args)
-    return checkpoint
+def parse_s3_path(s3_path: str) -> Tuple[str, str]:
+    """
+    Returns the bucket and key from an S3 path.
+    Example: "s3://my-bucket/folder/checkpoint.ckpt" -> ("my-bucket", "folder/checkpoint.ckpt")
+    """
+    parsed = urlparse(s3_path, allow_fragments=True)
+    assert parsed.scheme == 's3', f'{s3_path} is not a valid AWS S3 path. Needs to start with `s3://`'
+    bucket, key = parsed.netloc, parsed.path
+    if key.startswith('/'):
+        key = key[1:]
+    return bucket, key


-def save_s3_checkpoint(checkpoint, checkpoint_path):
-    from torch.serialization import _legacy_save
+def save_checkpoint_to_s3(bucket_name, key):
+    """
+    Saves a single checkpoint to an S3 bucket.
+    Args:
+        bucket_name: The name of the bucket to save to.
+        key: The S3 key, which is also the local path of the file to upload.
+    Returns:
+        None
+    """
+    boto3 = try_import_boto3()
+    bucket = boto3.resource("s3").Bucket(bucket_name)
+    bucket.upload_file(Filename=key, Key=key)

-    # Attempt s3fs import
-    try:
-        import s3fs
-    except ImportError:
-        raise ImportError("Tried to import `s3fs` for AWS S3 i/o, but `s3fs` is not installed."
-                          " Please `pip install s3fs` and try again.")
-
-    DEFAULT_PROTOCOL = 2  # from torch.serialization.py line 19
-    fs = s3fs.S3FileSystem()
-    with fs.open(checkpoint_path, "wb") as f:
-        _legacy_save(checkpoint, f, pickle, DEFAULT_PROTOCOL)
-    return checkpoint
+
+def download_checkpoint_from_s3(path_or_url: str, overwrite=False) -> str:
+    """
+    Downloads a file from S3 and saves it in the default cache path under its original S3 key.
+    Returns the filepath the object has been downloaded to.
+    """
+    boto3 = try_import_boto3()
+
+    # Eg "s3://bucket-name/folder/checkpoint.ckpt" --> ("bucket-name", "folder/checkpoint.ckpt")
+    bucket_name, key = parse_s3_path(path_or_url)
+
+    # ("folder", "checkpoint.ckpt")
+    directory, filename = osp.split(key)
+
+    # Make directory: '/Users/johnDoe/.cache/torch/pl_checkpoints/folder'
+    directory_to_make = osp.join(default_cache_path, directory)
+    os.makedirs(directory_to_make, exist_ok=True)
+
+    # File we will download to: '/Users/johnDoe/.cache/torch/pl_checkpoints/folder/checkpoint.ckpt'
+    filepath = osp.join(directory_to_make, filename)
+
+    def _download():
+        s3 = boto3.resource("s3")
+        bucket = s3.Bucket(bucket_name)
+        bucket.download_file(Key=key, Filename=filepath)
+
+    # download when the cached copy is missing or an overwrite is requested
+    if overwrite or not osp.exists(filepath):
+        _download()
+    return filepath
+
+
+def remove_checkpoint_from_s3(bucket, key):
+    """Removes a single object from S3."""
+    boto3 = try_import_boto3()
+    s3 = boto3.resource("s3")
+    obj = s3.Object(bucket, key)
+    obj.delete()
diff --git a/requirements/devel.txt b/requirements/devel.txt
index 71d12187bfea2..56b9a3de56bf6 100644
--- a/requirements/devel.txt
+++ b/requirements/devel.txt
@@ -8,4 +8,5 @@
 -r ./test.txt
 cloudpickle>=1.2
-s3fs
\ No newline at end of file
+boto3
+botocore
\ No newline at end of file
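A usage sketch of the S3-aware `ModelCheckpoint` from the patch above, assuming `boto3` is installed and AWS credentials are available in the environment; the bucket name and key prefix are hypothetical:

import pytorch_lightning.utilities.cloud_io as cloud_io
from pytorch_lightning.callbacks import ModelCheckpoint

# parse_s3_path() splits an S3 URI into (bucket, key); the bucket and key are hypothetical.
assert cloud_io.parse_s3_path("s3://my-bucket/ckpts/last.ckpt") == ("my-bucket", "ckpts/last.ckpt")

checkpoint_callback = ModelCheckpoint(
    filepath="s3://my-bucket/ckpts",
    save_top_k=3,
    # checkpoints that drop out of the top k are also deleted from the bucket
    remove_non_top_k_s3_files=True,
)

Note the design: the callback writes each checkpoint to a local path equal to the S3 key and then mirrors that file to the bucket, so the key doubles as the local path.
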
From d27d7becb5ddd094c466ae1a5dd8c318387eda8f Mon Sep 17 00:00:00 2001
From: Laksh Aithani
Date: Sat, 13 Jun 2020 23:36:31 +0100
Subject: [PATCH 5/6] Made start on moto testing

---
 requirements/test.txt |  1 +
 tests/conftest.py     | 20 ++++++++++++++++++++
 2 files changed, 21 insertions(+)

diff --git a/requirements/test.txt b/requirements/test.txt
index 2945bc5f968d2..4d27d26293eaa 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -9,3 +9,4 @@ check-manifest
 twine==1.13.0
 black==19.10b0
 pre-commit>=1.0
+moto
diff --git a/tests/conftest.py b/tests/conftest.py
index 91312fc848582..4ad39cabba3aa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,9 @@
 import pytest
 import torch.multiprocessing as mp

+import os
+import boto3
+from moto import mock_s3


 def pytest_configure(config):
@@ -55,3 +58,20 @@ class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):
     server_thread.start()
     yield server.server_address
     server.shutdown()
+
+
+@pytest.fixture(scope='function')
+def aws_credentials():
+    """Mocked AWS credentials for moto."""
+    os.environ['AWS_ACCESS_KEY_ID'] = 'testing'
+    os.environ['AWS_SECRET_ACCESS_KEY'] = 'testing'
+    os.environ['AWS_SECURITY_TOKEN'] = 'testing'
+    os.environ['AWS_SESSION_TOKEN'] = 'testing'
+    os.environ['AWS_DEFAULT_REGION'] = 'us-east-1'  # avoids NoRegionError in clients created by the code under test
+
+
+@pytest.fixture(scope='function')
+def s3(aws_credentials):
+    with mock_s3():
+        s3 = boto3.client('s3', region_name='us-east-1')
+        s3.create_bucket(Bucket='testing')
+        yield s3

From 9598c74f664f77f4ea4a6adcc5660db0bb69b8b4 Mon Sep 17 00:00:00 2001
From: Laksh Aithani
Date: Sat, 13 Jun 2020 23:37:08 +0100
Subject: [PATCH 6/6] Remove botocore dep

---
 requirements/devel.txt | 1 -
 1 file changed, 1 deletion(-)

diff --git a/requirements/devel.txt b/requirements/devel.txt
index 56b9a3de56bf6..6cff5050ec639 100644
--- a/requirements/devel.txt
+++ b/requirements/devel.txt
@@ -9,4 +9,3 @@
 -r ./test.txt
 cloudpickle>=1.2
 boto3
-botocore
\ No newline at end of file
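Building on the `s3` fixture from patch 5, a sketch of what a first moto-backed test could look like; the checkpoint contents and key are hypothetical, and the fixture pre-creates the 'testing' bucket:

import os

import torch

import pytorch_lightning.utilities.cloud_io as cloud_io


def test_load_checkpoint_from_s3(tmpdir, s3):
    # stage a checkpoint in the mocked bucket created by the `s3` fixture
    local_ckpt = os.path.join(str(tmpdir), "checkpoint.ckpt")
    torch.save({"epoch": 1}, local_ckpt)
    s3.upload_file(Filename=local_ckpt, Bucket="testing", Key="folder/checkpoint.ckpt")

    # load() detects the "s3://" scheme, caches the object under torch's cache
    # dir (pl_checkpoints) and deserializes it with torch.load
    checkpoint = cloud_io.load("s3://testing/folder/checkpoint.ckpt")
    assert checkpoint["epoch"] == 1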