Merge pull request #843 from materialsproject/aws_store_botocore_fix
Aws store botocore fix
munrojm authored Aug 15, 2023
2 parents 9890001 + d10c174 commit d95ba2b
Showing 2 changed files with 32 additions and 46 deletions.
72 changes: 29 additions & 43 deletions src/maggma/stores/aws.py
@@ -1,14 +1,11 @@
"""
Advanced Stores for connecting to AWS data
"""
"""Advanced Stores for connecting to AWS data."""
import threading
import warnings
import zlib
from concurrent.futures import wait
from concurrent.futures.thread import ThreadPoolExecutor
from hashlib import sha1
from io import BytesIO
-from typing import Dict, Iterator, List, Optional, Tuple, Union
+from json import dumps
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

@@ -30,7 +27,7 @@
class S3Store(Store):
"""
GridFS like storage using Amazon S3 and a regular store for indexing
-Assumes Amazon AWS key and secret key are set in environment or default config file
+Assumes Amazon AWS key and secret key are set in environment or default config file.
"""

def __init__(
@@ -50,7 +47,7 @@ def __init__(
**kwargs,
):
"""
-Initializes an S3 Store
+Initializes an S3 Store.
Args:
index: a store to use to index the S3 Bucket
@@ -103,15 +100,12 @@ def __init__(
def name(self) -> str:
"""
Returns:
-a string representing this data source
+a string representing this data source.
"""
return f"s3://{self.bucket}"

def connect(self, *args, **kwargs): # lgtm[py/conflicting-attributes]
"""
Connect to the source data
"""

"""Connect to the source data."""
session = self._get_session()
resource = session.resource("s3", endpoint_url=self.endpoint_url, **self.s3_resource_kwargs)

@@ -126,9 +120,7 @@ def connect(self, *args, **kwargs): # lgtm[py/conflicting-attributes]
self.index.connect(*args, **kwargs)

def close(self):
"""
Closes any connections
"""
"""Closes any connections."""
self.index.close()

self.s3.meta.client.close()
@@ -139,7 +131,7 @@ def close(self):
def _collection(self):
"""
Returns:
-a handle to the pymongo collection object
+a handle to the pymongo collection object.
Important:
Not guaranteed to exist in the future
@@ -149,12 +141,11 @@

def count(self, criteria: Optional[Dict] = None) -> int:
"""
-Counts the number of documents matching the query criteria
+Counts the number of documents matching the query criteria.
Args:
criteria: PyMongo filter for documents to count in
"""
-
return self.index.count(criteria)

def query(
@@ -166,7 +157,7 @@
limit: int = 0,
) -> Iterator[Dict]:
"""
-Queries the Store for a set of documents
+Queries the Store for a set of documents.
Args:
criteria: PyMongo filter for documents to search in
@@ -191,12 +182,15 @@
# TODO: THis is ugly and unsafe, do some real checking before pulling data
data = self.s3_bucket.Object(self.sub_dir + str(doc[self.key])).get()["Body"].read()
except botocore.exceptions.ClientError as e:
-# If a client error is thrown, then check that it was a 404 error.
-# If it was a 404 error, then the object does not exist.
-error_code = int(e.response["Error"]["Code"])
-if error_code == 404:
-self.logger.error(f"Could not find S3 object {doc[self.key]}")
-break
+# If a client error is thrown, then check that it was a NoSuchKey or NoSuchBucket error.
+# If it was a NoSuchKey error, then the object does not exist.
+error_code = e.response["Error"]["Code"]
+if error_code in ["NoSuchKey", "NoSuchBucket"]:
+error_message = e.response["Error"]["Message"]
+self.logger.error(
+f"S3 returned '{error_message}' while querying '{self.bucket}' for '{doc[self.key]}'"
+)
+continue
else:
raise e

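Note on the fix above: for GetObject calls, botocore reports e.response["Error"]["Code"] as a string such as "NoSuchKey" or "NoSuchBucket", so the old int(...) conversion raised ValueError instead of ever yielding 404. A minimal sketch of the new pattern (bucket and key names hypothetical):

import boto3
import botocore.exceptions

s3 = boto3.resource("s3")  # assumes credentials in the environment or config file

try:
    data = s3.Object("some-bucket", "missing-key").get()["Body"].read()
except botocore.exceptions.ClientError as e:
    code = e.response["Error"]["Code"]  # e.g. "NoSuchKey" or "NoSuchBucket"
    message = e.response["Error"]["Message"]
    if code in ["NoSuchKey", "NoSuchBucket"]:
        print(f"S3 returned '{message}'")  # object or bucket absent: log and move on
    else:
        raise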
@@ -223,7 +217,7 @@ def _unpack(data: bytes, compressed: bool):

def distinct(self, field: str, criteria: Optional[Dict] = None, all_exist: bool = False) -> List:
"""
-Get all distinct values for a field
+Get all distinct values for a field.
Args:
field: the field(s) to get distinct values for
@@ -268,7 +262,7 @@ def groupby(

def ensure_index(self, key: str, unique: bool = False) -> bool:
"""
-Tries to create an index and return true if it succeeded
+Tries to create an index and return true if it succeeded.
Args:
key: single key to index
@@ -286,7 +280,7 @@ def update(
additional_metadata: Union[str, List[str], None] = None,
):
"""
-Update documents into the Store
+Update documents into the Store.
Args:
docs: the document or list of documents to update
@@ -335,9 +329,7 @@ def _get_session(self):
return None

def _get_bucket(self):
"""
If on the main thread return the bucket created above, else create a new bucket on each thread
"""
"""If on the main thread return the bucket created above, else create a new bucket on each thread."""
if threading.current_thread().name == "MainThread":
return self.s3_bucket
if not hasattr(self._thread_local, "s3_bucket"):
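The branch above exists because boto3 sessions and resources are not thread-safe, so each worker thread needs its own bucket handle. A standalone sketch of the pattern (function and variable names hypothetical):

import threading

import boto3

_thread_local = threading.local()

def get_bucket(bucket_name: str):
    # Lazily build one Bucket handle per worker thread instead of
    # sharing the main thread's boto3 resource across threads.
    if not hasattr(_thread_local, "s3_bucket"):
        _thread_local.s3_bucket = boto3.Session().resource("s3").Bucket(bucket_name)
    return _thread_local.s3_bucket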
@@ -348,7 +340,7 @@

def write_doc_to_s3(self, doc: Dict, search_keys: List[str]):
"""
-Write the data to s3 and return the metadata to be inserted into the index db
+Write the data to s3 and return the metadata to be inserted into the index db.
Args:
doc: the document
@@ -388,9 +380,7 @@ def write_doc_to_s3(self, doc: Dict, search_keys: List[str]):
s3_bucket.upload_fileobj(
Fileobj=BytesIO(data),
Key=self.sub_dir + str(doc[self.key]),
-ExtraArgs={
-"Metadata": {s3_to_mongo_keys[k]: str(v) for k, v in search_doc.items()}
-},
+ExtraArgs={"Metadata": {s3_to_mongo_keys[k]: str(v) for k, v in search_doc.items()}},
)

if lu_info is not None:
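The ExtraArgs reflow above is purely cosmetic: the call still attaches the sanitized index document as S3 object metadata. A standalone sketch of the same upload (bucket, key, and document values hypothetical):

from io import BytesIO

import boto3

bucket = boto3.resource("s3").Bucket("some-bucket")
search_doc = {"task--id": "mp-149"}  # keys sanitized, underscores encoded as "--"
bucket.upload_fileobj(
    Fileobj=BytesIO(b"zlib-compressed document bytes"),
    Key="prefix/mp-149",
    ExtraArgs={"Metadata": {k: str(v) for k, v in search_doc.items()}},  # values must be strings
)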
@@ -405,10 +395,7 @@ def write_doc_to_s3(self, doc: Dict, search_keys: List[str]):

@staticmethod
def _sanitize_key(key):
"""
Sanitize keys to store in S3/MinIO metadata.
"""

"""Sanitize keys to store in S3/MinIO metadata."""
# Any underscores are encoded as double dashes in metadata, since keys with
# underscores may be result in the corresponding HTTP header being stripped
# by certain server configurations (e.g. default nginx), leading to:
@@ -422,7 +409,7 @@ def _sanitize_key(key):

def remove_docs(self, criteria: Dict, remove_s3_object: bool = False):
"""
-Remove docs matching the query dictionary
+Remove docs matching the query dictionary.
Args:
criteria: query dictionary to match
@@ -467,7 +454,7 @@ def rebuild_index_from_s3_data(self, **kwargs):
"""
Rebuilds the index Store from the data in S3
Relies on the index document being stores as the metadata for the file
-This can help recover lost databases
+This can help recover lost databases.
"""
bucket = self.s3_bucket
objects = bucket.objects.filter(Prefix=self.sub_dir)
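Since write_doc_to_s3 stores each index document as object metadata, the index can be recovered by walking the bucket. A rough sketch of that idea (bucket name, prefix, and key handling hypothetical):

import boto3

bucket = boto3.resource("s3").Bucket("some-bucket")
for summary in bucket.objects.filter(Prefix="prefix/"):
    metadata = bucket.Object(summary.key).metadata  # dict of lower-cased string keys
    # decode the sanitized keys back ("--" -> "_") before re-inserting into the index
    index_doc = {k.replace("--", "_"): v for k, v in metadata.items()}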
@@ -485,9 +472,8 @@ def rebuild_metadata_from_index(self, index_query: Optional[dict] = None):
Read data from the index store and populate the metadata of the S3 bucket
Force all of the keys to be lower case to be Minio compatible
Args:
-index_query: query on the index store
+index_query: query on the index store.
"""
-
qq = {} if index_query is None else index_query
for index_doc in self.index.query(qq):
key_ = self.sub_dir + index_doc[self.key]
@@ -508,7 +494,7 @@ def rebuild_metadata_from_index(self, index_query: Optional[dict] = None):
def __eq__(self, other: object) -> bool:
"""
Check equality for S3Store
-other: other S3Store to compare with
+other: other S3Store to compare with.
"""
if not isinstance(other, S3Store):
return False
6 changes: 3 additions & 3 deletions tests/stores/test_aws.py
@@ -214,8 +214,8 @@ def test_bad_import(mocker):


def test_aws_error(s3store):
-def raise_exception_404(data):
-error_response = {"Error": {"Code": 404}}
+def raise_exception_NoSuchKey(data):
+error_response = {"Error": {"Code": "NoSuchKey", "Message": "The specified key does not exist."}}
raise ClientError(error_response, "raise_exception")

def raise_exception_other(data):
@@ -227,7 +227,7 @@ def raise_exception_other(data):
s3store.query_one()

# Should just pass
-s3store.s3_bucket.Object = raise_exception_404
+s3store.s3_bucket.Object = raise_exception_NoSuchKey
s3store.query_one()

