diff --git a/earthaccess/store.py b/earthaccess/store.py index 68fec997..f279e0f9 100644 --- a/earthaccess/store.py +++ b/earthaccess/store.py @@ -2,12 +2,11 @@ import os import shutil import traceback -from copy import deepcopy from functools import lru_cache from itertools import chain from pathlib import Path from pickle import dumps, loads -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from uuid import uuid4 import earthaccess @@ -97,8 +96,9 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None: """ if auth.authenticated is True: self.auth = auth - self.s3_fs = None - self.initial_ts = datetime.datetime.now() + self._s3_credentials: Dict[ + Tuple, Tuple[datetime.datetime, Dict[str, str]] + ] = {} oauth_profile = "https://urs.earthdata.nasa.gov/profile" # sets the initial URS cookie self._requests_cookies: Dict[str, Any] = {} @@ -182,7 +182,6 @@ def set_requests_session( elif resp.status_code >= 500: resp.raise_for_status() - @lru_cache def get_s3fs_session( self, daac: Optional[str] = None, @@ -200,39 +199,54 @@ def get_s3fs_session( Returns: a s3fs file instance """ - if self.auth is not None: - if not any([concept_id, daac, provider, endpoint]): - raise ValueError( - "At least one of the concept_id, daac, provider or endpoint" - "parameters must be specified. " - ) - if endpoint is not None: - s3_credentials = self.auth.get_s3_credentials(endpoint=endpoint) - elif concept_id is not None: - provider = self._derive_concept_provider(concept_id) - s3_credentials = self.auth.get_s3_credentials(provider=provider) - elif daac is not None: - s3_credentials = self.auth.get_s3_credentials(daac=daac) - elif provider is not None: - s3_credentials = self.auth.get_s3_credentials(provider=provider) - now = datetime.datetime.now() - delta_minutes = now - self.initial_ts - # TODO: test this mocking the time or use https://github.com/dbader/schedule - # if we exceed 1 hour - if ( - self.s3_fs is None or round(delta_minutes.seconds / 60, 2) > 59 - ) and s3_credentials is not None: - self.s3_fs = s3fs.S3FileSystem( - key=s3_credentials["accessKeyId"], - secret=s3_credentials["secretAccessKey"], - token=s3_credentials["sessionToken"], - ) - self.initial_ts = datetime.datetime.now() - return deepcopy(self.s3_fs) - else: + if self.auth is None: raise ValueError( "A valid Earthdata login instance is required to retrieve S3 credentials" ) + if not any([concept_id, daac, provider, endpoint]): + raise ValueError( + "At least one of the concept_id, daac, provider or endpoint" + "parameters must be specified. " + ) + + if concept_id is not None: + provider = self._derive_concept_provider(concept_id) + + # Get existing S3 credentials if we already have them + location = ( + daac, + provider, + endpoint, + ) # Identifier for where to get S3 credentials from + need_new_creds = False + try: + dt_init, creds = self._s3_credentials[location] + except KeyError: + need_new_creds = True + else: + # If cached credentials are expired, invalidate the cache + delta = datetime.datetime.now() - dt_init + if round(delta.seconds / 60, 2) > 55: + need_new_creds = True + self._s3_credentials.pop(location) + + if need_new_creds: + # Don't have existing valid S3 credentials, so get new ones + now = datetime.datetime.now() + if endpoint is not None: + creds = self.auth.get_s3_credentials(endpoint=endpoint) + elif daac is not None: + creds = self.auth.get_s3_credentials(daac=daac) + elif provider is not None: + creds = self.auth.get_s3_credentials(provider=provider) + # Include new credentials in the cache + self._s3_credentials[location] = now, creds + + return s3fs.S3FileSystem( + key=creds["accessKeyId"], + secret=creds["secretAccessKey"], + token=creds["sessionToken"], + ) @lru_cache def get_fsspec_session(self) -> fsspec.AbstractFileSystem: diff --git a/tests/unit/test_store.py b/tests/unit/test_store.py index 388cb5af..2faed5b8 100644 --- a/tests/unit/test_store.py +++ b/tests/unit/test_store.py @@ -5,6 +5,7 @@ import fsspec import pytest import responses +import s3fs from earthaccess import Auth, Store @@ -60,12 +61,22 @@ def test_store_can_create_s3_fsspec_session(self): "https://api.giovanni.earthdata.nasa.gov/s3credentials", "https://data.laadsdaac.earthdatacloud.nasa.gov/s3credentials", ] + mock_creds = { + "accessKeyId": "sure", + "secretAccessKey": "correct", + "sessionToken": "whynot", + } + expected_storage_options = { + "key": mock_creds["accessKeyId"], + "secret": mock_creds["secretAccessKey"], + "token": mock_creds["sessionToken"], + } for endpoint in custom_endpoints: responses.add( responses.GET, endpoint, - json={}, + json=mock_creds, status=200, ) @@ -74,17 +85,13 @@ def test_store_can_create_s3_fsspec_session(self): responses.add( responses.GET, daac["s3-credentials"], - json={ - "accessKeyId": "sure", - "secretAccessKey": "correct", - "sessionToken": "whynot", - }, + json=mock_creds, status=200, ) responses.add( responses.GET, "https://urs.earthdata.nasa.gov/profile", - json={}, + json=mock_creds, status=200, ) @@ -92,22 +99,25 @@ def test_store_can_create_s3_fsspec_session(self): self.assertTrue(isinstance(store.auth, Auth)) for daac in ["NSIDC", "PODAAC", "LPDAAC", "ORNLDAAC", "GES_DISC", "ASF"]: s3_fs = store.get_s3fs_session(daac=daac) - self.assertEqual(type(s3_fs), type(fsspec.filesystem("s3"))) + assert isinstance(s3_fs, s3fs.S3FileSystem) + assert s3_fs.storage_options == expected_storage_options for endpoint in custom_endpoints: s3_fs = store.get_s3fs_session(endpoint=endpoint) - self.assertEqual(type(s3_fs), type(fsspec.filesystem("s3"))) + assert isinstance(s3_fs, s3fs.S3FileSystem) + assert s3_fs.storage_options == expected_storage_options for provider in [ "NSIDC_CPRD", "POCLOUD", "LPCLOUD", - "ORNLCLOUD", + "ORNL_CLOUD", "GES_DISC", "ASF", ]: s3_fs = store.get_s3fs_session(provider=provider) - assert isinstance(s3_fs, fsspec.AbstractFileSystem) + assert isinstance(s3_fs, s3fs.S3FileSystem) + assert s3_fs.storage_options == expected_storage_options # Ensure informative error is raised with pytest.raises(ValueError, match="parameters must be specified"):