S3: select_object_content() now returns the proper Stats (#8436)
bblommers authored Dec 26, 2024
1 parent 9a62985 commit 0b2b687
Showing 5 changed files with 101 additions and 25 deletions.
8 changes: 4 additions & 4 deletions moto/s3/models.py
@@ -2994,7 +2994,7 @@ def select_object_content(
select_query: str,
input_details: Dict[str, Any],
output_details: Dict[str, Any],
) -> List[bytes]:
) -> Tuple[List[bytes], int]:
"""
Highly experimental. Please raise an issue if you find any inconsistencies/bugs.
@@ -3023,7 +3023,7 @@ def select_object_content(
"FileHeaderInfo", ""
) == "USE"
query_input = csv_to_json(query_input, use_headers)
query_result = parse_query(query_input, select_query) # type: ignore
query_result, bytes_scanned = parse_query(query_input, select_query) # type: ignore

record_delimiter = "\n"
if "JSON" in output_details:
@@ -3041,7 +3041,7 @@ def select_object_content(
from py_partiql_parser import json_to_csv

query_result = json_to_csv(query_result, field_delim, record_delimiter)
return [query_result.encode("utf-8")] # type: ignore
return [query_result.encode("utf-8")], bytes_scanned # type: ignore

else:
from py_partiql_parser import SelectEncoder
@@ -3052,7 +3052,7 @@
+ record_delimiter
).encode("utf-8")
for x in query_result
]
], bytes_scanned

def restore_object(
self, bucket_name: str, key_name: str, days: Optional[str], type_: Optional[str]
4 changes: 2 additions & 2 deletions moto/s3/responses.py
@@ -2582,10 +2582,10 @@ def _key_response_post(
select_query = request["Expression"]
input_details = request["InputSerialization"]
output_details = request["OutputSerialization"]
results = self.backend.select_object_content(
results, bytes_scanned = self.backend.select_object_content(
bucket_name, key_name, select_query, input_details, output_details
)
return 200, {}, serialize_select(results)
return 200, {}, serialize_select(results, bytes_scanned)

else:
raise NotImplementedError(
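From a client's point of view, the handler above now streams a Stats event whose byte counts come from the actual query rather than hard-coded values. A minimal boto3 sketch of consuming that stream — bucket and key names are illustrative; the event shapes match the tests further down:

import boto3

client = boto3.client("s3", region_name="us-east-1")
response = client.select_object_content(
    Bucket="my-bucket",          # illustrative bucket/key
    Key="data.json",
    Expression="SELECT * FROM s3object",
    ExpressionType="SQL",
    InputSerialization={"JSON": {"Type": "DOCUMENT"}},
    OutputSerialization={"JSON": {"RecordDelimiter": ","}},
)
for event in response["Payload"]:
    if "Records" in event:
        print(event["Records"]["Payload"].decode("utf-8"))
    elif "Stats" in event:
        # e.g. {"BytesScanned": 36, "BytesProcessed": 36, "BytesReturned": 9}
        print(event["Stats"]["Details"])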
31 changes: 21 additions & 10 deletions moto/s3/select_object_content.py
@@ -1,12 +1,14 @@
import binascii
import struct
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple


def parse_query(text_input: str, query: str) -> List[Dict[str, Any]]:
def parse_query(text_input: str, query: str) -> Tuple[List[Dict[str, Any]], int]:
from py_partiql_parser import S3SelectParser

return S3SelectParser(source_data={"s3object": text_input}).parse(query)
parser = S3SelectParser(source_data=text_input)
result = parser.parse(query)
return result, parser.bytes_scanned


def _create_header(key: bytes, value: bytes) -> bytes:
@@ -33,9 +35,11 @@ def _create_message(
return prelude + prelude_crc + headers + payload + message_crc


def _create_stats_message() -> bytes:
stats = b"""<Stats><BytesScanned>24</BytesScanned><BytesProcessed>24</BytesProcessed><BytesReturned>22</BytesReturned></Stats>"""
return _create_message(content_type=b"text/xml", event_type=b"Stats", payload=stats)
def _create_stats_message(bytes_scanned: int, bytes_returned: int) -> bytes:
stats = f"<Stats><BytesScanned>{bytes_scanned}</BytesScanned><BytesProcessed>{bytes_scanned}</BytesProcessed><BytesReturned>{bytes_returned}</BytesReturned></Stats>"
return _create_message(
content_type=b"text/xml", event_type=b"Stats", payload=stats.encode("utf-8")
)


def _create_data_message(payload: bytes) -> bytes:
@@ -49,8 +53,15 @@ def _create_end_message() -> bytes:
return _create_message(content_type=None, event_type=b"End", payload=b"")


def serialize_select(data_list: List[bytes]) -> bytes:
response = b""
def serialize_select(data_list: List[bytes], bytes_scanned: int) -> bytes:
bytes_returned = 0
all_data = b""
for data in data_list:
response += _create_data_message(data)
return response + _create_stats_message() + _create_end_message()
bytes_returned += len(data)
all_data += data
response = _create_data_message(all_data)
return (
response
+ _create_stats_message(bytes_scanned, bytes_returned)
+ _create_end_message()
)
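serialize_select frames each event with _create_message in the standard AWS event-stream layout (length prelude, headers, payload, trailing CRC32). Below is a rough decoder sketch, assuming that layout; decode_events and _decode_headers are illustrative helpers, not part of moto's API:

import struct
from typing import Dict, Iterator, Tuple


def decode_events(stream: bytes) -> Iterator[Tuple[Dict[str, str], bytes]]:
    # Walk a concatenation of event-stream messages, yielding (headers, payload).
    offset = 0
    while offset < len(stream):
        total_len, headers_len = struct.unpack_from(">II", stream, offset)
        headers_start = offset + 12                  # 4B total len + 4B headers len + 4B prelude CRC
        payload_start = headers_start + headers_len
        payload_end = offset + total_len - 4         # last 4 bytes are the message CRC
        yield _decode_headers(stream[headers_start:payload_start]), stream[payload_start:payload_end]
        offset += total_len


def _decode_headers(raw: bytes) -> Dict[str, str]:
    headers, pos = {}, 0
    while pos < len(raw):
        name_len = raw[pos]
        name = raw[pos + 1 : pos + 1 + name_len].decode("utf-8")
        pos += 1 + name_len + 1                      # the skipped byte is the value type (7 = string)
        (value_len,) = struct.unpack_from(">H", raw, pos)
        pos += 2
        headers[name] = raw[pos : pos + value_len].decode("utf-8")
        pos += value_len
    return headers

Applied to the bytes serialize_select returns, this should yield one Records event carrying the concatenated query results, a Stats event whose XML payload now reports the real BytesScanned/BytesProcessed/BytesReturned, and a final End event.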
18 changes: 9 additions & 9 deletions setup.cfg
@@ -56,7 +56,7 @@ all =
jsonschema
openapi-spec-validator>=0.5.0
pyparsing>=3.0.7
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
aws-xray-sdk!=0.96,>=0.93
setuptools
multipart
@@ -70,7 +70,7 @@ proxy =
cfn-lint>=0.40.0
openapi-spec-validator>=0.5.0
pyparsing>=3.0.7
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
aws-xray-sdk!=0.96,>=0.93
setuptools
multipart
@@ -84,7 +84,7 @@ server =
cfn-lint>=0.40.0
openapi-spec-validator>=0.5.0
pyparsing>=3.0.7
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
aws-xray-sdk!=0.96,>=0.93
setuptools
flask!=2.2.0,!=2.2.1
@@ -118,7 +118,7 @@ cloudformation =
cfn-lint>=0.40.0
openapi-spec-validator>=0.5.0
pyparsing>=3.0.7
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
aws-xray-sdk!=0.96,>=0.93
setuptools
cloudfront =
@@ -140,10 +140,10 @@ dms =
ds =
dynamodb =
docker>=3.0.0
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
dynamodbstreams =
docker>=3.0.0
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
ebs =
ec2 =
ec2instanceconnect =
@@ -206,15 +206,15 @@ resourcegroupstaggingapi =
cfn-lint>=0.40.0
openapi-spec-validator>=0.5.0
pyparsing>=3.0.7
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
route53 =
route53resolver =
s3 =
PyYAML>=5.1
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
s3crc32c =
PyYAML>=5.1
py-partiql-parser==0.5.6
py-partiql-parser==0.6.1
crc32c
s3control =
sagemaker =
65 changes: 65 additions & 0 deletions tests/test_s3/test_s3_select.py
@@ -24,6 +24,9 @@
"country": "USA",
}
]
LONG_JSON = "".join(
[json.dumps(x) for x in [{"a": {f"b{i}": "s"}, "c": f"d{i}"} for i in range(5)]]
)
SIMPLE_LIST = [SIMPLE_JSON, SIMPLE_JSON2]
SIMPLE_CSV = """a,b,c
e,r,f
@@ -66,6 +69,9 @@ def create_test_files(bucket_name):
client.put_object(
Bucket=bucket_name, Key="csv.bz2", Body=bz2.compress(SIMPLE_CSV.encode("utf-8"))
)
client.put_object(
Bucket=bucket_name, Key="long.json", Body=LONG_JSON.encode("utf-8")
)


@pytest.mark.aws_verified
@@ -110,6 +116,11 @@ def test_count_function(bucket_name=None):
)
result = list(content["Payload"])
assert {"Records": {"Payload": b'{"_1":1},'}} in result
assert {
"Stats": {
"Details": {"BytesScanned": 36, "BytesProcessed": 36, "BytesReturned": 9}
}
} in result


@pytest.mark.aws_verified
@@ -249,6 +260,60 @@ def test_nested_json__select_all(bucket_name=None):
assert json.loads(records[:-1]) == NESTED_JSON


@pytest.mark.aws_verified
@s3_aws_verified
def test_long_json__select_subdocument(bucket_name=None):
client = boto3.client("s3", "us-east-1")
create_test_files(bucket_name)
content = client.select_object_content(
Bucket=bucket_name,
Key="long.json",
Expression="select * from s3object[*].a",
ExpressionType="SQL",
InputSerialization={"JSON": {"Type": "DOCUMENT"}},
OutputSerialization={"JSON": {"RecordDelimiter": ","}},
)
result = list(content["Payload"])
record = result[0]["Records"]
payload = record["Payload"].decode("utf-8")
assert '{"b0":"s"}' in payload
assert '{"b1":"s"}' in payload
assert '{"b2":"s"}' in payload
assert '{"b3":"s"}' in payload
assert '{"b0":"s"}' in payload

assert result[1]["Stats"]["Details"] == {
"BytesScanned": 145,
"BytesProcessed": 145,
"BytesReturned": 55,
}


@pytest.mark.aws_verified
@s3_aws_verified
def test_long_json__where_filter(bucket_name=None):
client = boto3.client("s3", "us-east-1")
create_test_files(bucket_name)
content = client.select_object_content(
Bucket=bucket_name,
Key="long.json",
Expression="select * from s3object[*] s where s.c = 'd2'",
ExpressionType="SQL",
InputSerialization={"JSON": {"Type": "DOCUMENT"}},
OutputSerialization={"JSON": {"RecordDelimiter": ","}},
)
result = list(content["Payload"])
record = result[0]["Records"]
payload = record["Payload"].decode("utf-8")
assert '{"a":{"b2":"s"},"c":"d2"}' in payload

assert result[1]["Stats"]["Details"] == {
"BytesScanned": 145,
"BytesProcessed": 145,
"BytesReturned": 26,
}


@pytest.mark.aws_verified
@s3_aws_verified
def test_gzipped_json(bucket_name=None):
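The Stats figures asserted in the new tests follow directly from the test data: LONG_JSON concatenates five json.dumps documents of 29 characters each, and BytesReturned counts the emitted records including the ',' record delimiter. A quick sanity check of that arithmetic:

import json

LONG_JSON = "".join(
    [json.dumps(x) for x in [{"a": {f"b{i}": "s"}, "c": f"d{i}"} for i in range(5)]]
)
assert len(LONG_JSON) == 145                   # BytesScanned / BytesProcessed

subdocuments = ",".join('{"b%d":"s"}' % i for i in range(5)) + ","
assert len(subdocuments) == 55                 # BytesReturned for "select * from s3object[*].a"

filtered = '{"a":{"b2":"s"},"c":"d2"},'
assert len(filtered) == 26                     # BytesReturned for the s.c = 'd2' filter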
