diff --git a/src/sparsezoo/utils/authentication.py b/src/sparsezoo/utils/authentication.py index 8765fe44..5694cdae 100644 --- a/src/sparsezoo/utils/authentication.py +++ b/src/sparsezoo/utils/authentication.py @@ -45,103 +45,73 @@ ) -class SparseZooCredentials: - """ - Class wrapping around the sparse zoo credentials file. - """ - - def __init__(self): - if os.path.exists(CREDENTIALS_YAML): - _LOGGER.debug(f"Loading sparse zoo credentials from {CREDENTIALS_YAML}") - with open(CREDENTIALS_YAML) as credentials_file: - credentials_yaml = yaml.safe_load(credentials_file) - if credentials_yaml and CREDENTIALS_YAML_TOKEN_KEY in credentials_yaml: - self._token = credentials_yaml[CREDENTIALS_YAML_TOKEN_KEY]["token"] - self._created = credentials_yaml[CREDENTIALS_YAML_TOKEN_KEY][ - "created" - ] - else: - self._token = None - self._created = None - else: - _LOGGER.debug( - f"No sparse zoo credentials files found at {CREDENTIALS_YAML}" - ) - self._token = None - self._created = None - - def save_token(self, token: str, created: float): - """ - Save the jwt for accessing sparse zoo APIs. Will create the credentials file - if it does not exist already. - - :param token: the jwt for accessing sparse zoo APIs - :param created: the approximate time the token was created - """ - _LOGGER.debug(f"Saving sparse zoo credentials at {CREDENTIALS_YAML}") - if not os.path.exists(CREDENTIALS_YAML): - create_parent_dirs(CREDENTIALS_YAML) - with open(CREDENTIALS_YAML, "w+") as credentials_file: - credentials_yaml = yaml.safe_load(credentials_file) - if credentials_yaml is None: - credentials_yaml = {} - credentials_yaml[CREDENTIALS_YAML_TOKEN_KEY] = { - "token": token, - "created": created, - } - self._token = token - self._created = created - - yaml.safe_dump(credentials_yaml, credentials_file) - - @property - def token(self): - """ - :return: obtain the token if under 1 day old, else return None - """ - _LOGGER.debug(f"Obtaining sparse zoo credentials from {CREDENTIALS_YAML}") - if self._token and self._created is not None: - creation_date = datetime.fromtimestamp(self._created, tz=timezone.utc) - creation_difference = datetime.now(tz=timezone.utc) - creation_date - if creation_difference.days < 30: - return self._token - else: - _LOGGER.debug(f"Expired sparse zoo credentials at {CREDENTIALS_YAML}") - return None - else: - _LOGGER.debug(f"No sparse zoo credentials found at {CREDENTIALS_YAML}") - return None - - def get_auth_header( - authentication_type: str = PUBLIC_AUTH_TYPE, - force_token_refresh: bool = False, + force_token_refresh: bool = False, path: str = CREDENTIALS_YAML ) -> Dict: """ Obtain an authentication header token from either credentials file or from APIs - if token is over 1 day old. Location of credentials file can be changed by setting + if token is over 30 day old. Location of credentials file can be changed by setting the environment variable `NM_SPARSE_ZOO_CREDENTIALS`. Currently only 'public' authentication type is supported. - :param authentication_type: authentication type for generating token :param force_token_refresh: forces a new token to be generated :return: An authentication header with key 'nm-token-header' containing the header token """ - credentials = SparseZooCredentials() - token = credentials.token - if token and not force_token_refresh: - return {NM_TOKEN_HEADER: token} - elif authentication_type.lower() == PUBLIC_AUTH_TYPE: + token = _maybe_load_token(path) + if token is None or force_token_refresh: _LOGGER.info("Obtaining new sparse zoo credentials token") - created = time.time() response = requests.post( url=AUTH_API, data=json.dumps({"authentication_type": PUBLIC_AUTH_TYPE}) ) response.raise_for_status() token = response.json()["token"] - credentials.save_token(token, created) - return {NM_TOKEN_HEADER: token} - else: - raise Exception(f"Authentication type {PUBLIC_AUTH_TYPE} not supported.") + created = time.time() + _save_token(token, created, path) + return {NM_TOKEN_HEADER: token} + + +def _maybe_load_token(path: str): + if not os.path.exists(path): + _LOGGER.debug(f"No sparse zoo credentials files found at {path}") + return None + + _LOGGER.debug(f"Loading sparse zoo credentials from {path}") + + with open(path) as fp: + creds = yaml.safe_load(fp) + + if creds is None or CREDENTIALS_YAML_TOKEN_KEY not in creds: + _LOGGER.debug(f"No sparse zoo credentials found at {path}") + return None + + info = creds[CREDENTIALS_YAML_TOKEN_KEY] + if "token" not in info or "created" not in info: + _LOGGER.debug(f"No sparse zoo credentials found at {path}") + return None + + date_created = datetime.fromtimestamp(info["created"], tz=timezone.utc) + creation_difference = datetime.now(tz=timezone.utc) - date_created + + if creation_difference.days > 30: + _LOGGER.debug(f"Expired sparse zoo credentials at {path}") + return None + + return info["token"] + + +def _save_token(token: str, created: float, path: str): + """ + Save the jwt for accessing sparse zoo APIs. Will create the credentials file + if it does not exist already. + + :param token: the jwt for accessing sparse zoo APIs + :param created: the approximate time the token was created + """ + _LOGGER.debug(f"Saving sparse zoo credentials at {CREDENTIALS_YAML}") + if not os.path.exists(path): + create_parent_dirs(path) + with open(path, "w+") as fp: + auth = {CREDENTIALS_YAML_TOKEN_KEY: dict(token=token, created=created)} + yaml.safe_dump(auth, fp) diff --git a/tests/sparsezoo/utils/test_authentication.py b/tests/sparsezoo/utils/test_authentication.py new file mode 100644 index 00000000..aff526d4 --- /dev/null +++ b/tests/sparsezoo/utils/test_authentication.py @@ -0,0 +1,87 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest +import yaml + +from sparsezoo.utils.authentication import ( + CREDENTIALS_YAML_TOKEN_KEY, + NM_TOKEN_HEADER, + _maybe_load_token, + _save_token, + get_auth_header, +) + + +def test_load_token_no_path(tmp_path): + path = str(tmp_path / "token.yaml") + assert _maybe_load_token(path) is None + + +def test_load_token_yaml_fail(tmp_path): + path = str(tmp_path / "token.yaml") + with open(path, "w") as fp: + fp.write("asdf") + assert _maybe_load_token(path) is None + + +_OLD_DATE = (datetime.now() - timedelta(days=40)).timestamp() + + +@pytest.mark.parametrize( + "content", + [ + {}, + {CREDENTIALS_YAML_TOKEN_KEY: {}}, + {CREDENTIALS_YAML_TOKEN_KEY: {"token": "asdf"}}, + {CREDENTIALS_YAML_TOKEN_KEY: {"created": "asdf"}}, + {CREDENTIALS_YAML_TOKEN_KEY: {"created": _OLD_DATE}}, + ], +) +def test_load_token_failure_cases(tmp_path, content): + path = str(tmp_path / "token.yaml") + with open(path, "w") as fp: + yaml.dump(content, fp) + assert _maybe_load_token(path) is None + + +def test_load_token_valid(tmp_path): + auth = { + CREDENTIALS_YAML_TOKEN_KEY: { + "created": datetime.now().timestamp(), + "token": "asdf", + } + } + path = str(tmp_path / "token.yaml") + with open(path, "w") as fp: + yaml.dump(auth, fp) + assert _maybe_load_token(path) == "asdf" + + +def test_load_saved_token(tmp_path): + path = str(tmp_path / "some" / "dirs" / "token.yaml") + _save_token("asdf", datetime.now().timestamp(), path) + assert _maybe_load_token(path) == "asdf" + + +@patch("requests.post", return_value=MagicMock(json=lambda: {"token": "qwer"})) +def test_get_auth_token(post_mock, tmp_path): + path = tmp_path / "creds.yaml" + assert not path.exists() + assert get_auth_header(path=str(path)) == {NM_TOKEN_HEADER: "qwer"} + assert path.exists() + post_mock.assert_called()