Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Athena read_sql_query cache errors for qmark style parametrized queries #2957

42 changes: 26 additions & 16 deletions awswrangler/athena/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import threading
from heapq import heappop, heappush
from typing import TYPE_CHECKING, Any, Match, NamedTuple
from typing import TYPE_CHECKING, Match, NamedTuple

import boto3

Expand All @@ -23,23 +23,23 @@ class _CacheInfo(NamedTuple):
has_valid_cache: bool
file_format: str | None = None
query_execution_id: str | None = None
query_execution_payload: dict[str, Any] | None = None
query_execution_payload: "QueryExecutionTypeDef" | None = None


class _LocalMetadataCacheManager:
def __init__(self) -> None:
self._lock: threading.Lock = threading.Lock()
self._cache: dict[str, Any] = {}
self._cache: dict[str, "QueryExecutionTypeDef"] = {}
self._pqueue: list[tuple[datetime.datetime, str]] = []
self._max_cache_size = 100

def update_cache(self, items: list[dict[str, Any]]) -> None:
def update_cache(self, items: list["QueryExecutionTypeDef"]) -> None:
"""
Update the local metadata cache with new query metadata.

Parameters
----------
items : List[Dict[str, Any]]
items
List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.
"""
with self._lock:
Expand All @@ -62,18 +62,17 @@ def update_cache(self, items: list[dict[str, Any]]) -> None:
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
self._cache[item["QueryExecutionId"]] = item

def sorted_successful_generator(self) -> list[dict[str, Any]]:
def sorted_successful_generator(self) -> list["QueryExecutionTypeDef"]:
"""
Sorts the entries in the local cache based on query Completion DateTime.

This is useful to guarantee LRU caching rules.

Returns
-------
List[Dict[str, Any]]
Returns successful DDL and DML queries sorted by query completion time.
"""
filtered: list[dict[str, Any]] = []
filtered: list["QueryExecutionTypeDef"] = []
for query in self._cache.values():
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
filtered.append(query)
Expand Down Expand Up @@ -111,13 +110,13 @@ def _parse_select_query_from_possible_ctas(possible_ctas: str) -> str | None:
return None


def _compare_query_string(sql: str, other: str) -> bool:
def _compare_query_string(
sql: str, other: str, sql_params: list[str] | None = None, other_params: list[str] | None = None
) -> bool:
comparison_query = _prepare_query_string_for_comparison(query_string=other)
_logger.debug("sql: %s", sql)
_logger.debug("comparison_query: %s", comparison_query)
if sql == comparison_query:
return True
return False
return sql == comparison_query and sql_params == other_params


def _prepare_query_string_for_comparison(query_string: str) -> str:
Expand All @@ -135,7 +134,7 @@ def _get_last_query_infos(
max_remote_cache_entries: int,
boto3_session: boto3.Session | None = None,
workgroup: str | None = None,
) -> list[dict[str, Any]]:
) -> list["QueryExecutionTypeDef"]:
"""Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
client_athena = _utils.client(service_name="athena", session=boto3_session)
page_size = 50
Expand All @@ -160,14 +159,15 @@ def _get_last_query_infos(
QueryExecutionIds=uncached_ids[i : i + page_size],
).get("QueryExecutions")
)
_cache_manager.update_cache(new_execution_data) # type: ignore[arg-type]
_cache_manager.update_cache(new_execution_data)
return _cache_manager.sorted_successful_generator()


def _check_for_cached_results(
sql: str,
boto3_session: boto3.Session | None,
workgroup: str | None,
params: list[str] | None = None,
athena_cache_settings: typing.AthenaCacheSettings | None = None,
) -> _CacheInfo:
"""
Expand Down Expand Up @@ -207,15 +207,25 @@ def _check_for_cached_results(
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
parsed_query: str | None = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
if parsed_query is not None:
if _compare_query_string(sql=comparable_sql, other=parsed_query):
if _compare_query_string(
sql=comparable_sql,
other=parsed_query,
sql_params=params,
other_params=query_info.get("ExecutionParameters"),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this effectively disable caching for qmark queries since parameter values are not returned? https://docs.aws.amazon.com/athena/latest/APIReference/API_QueryExecution.html#athena-Type-QueryExecution-ExecutionParameters "The list of parameters is not returned in the response."

):
return _CacheInfo(
has_valid_cache=True,
file_format="parquet",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
if _compare_query_string(
sql=comparable_sql,
other=query_info["Query"],
sql_params=params,
other_params=query_info.get("ExecutionParameters"),
):
return _CacheInfo(
has_valid_cache=True,
file_format="csv",
Expand Down
1 change: 1 addition & 0 deletions awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,7 @@ def read_sql_query(
if not client_request_token:
cache_info: _CacheInfo = _check_for_cached_results(
sql=sql,
params=params if paramstyle == "qmark" else None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we not cache paramstyle="named" parameters as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The named parameters end up built into the SQL query itself client-side. So the query itself contains the values already.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, makes sense.

boto3_session=boto3_session,
workgroup=workgroup,
athena_cache_settings=athena_cache_settings,
Expand Down
22 changes: 13 additions & 9 deletions awswrangler/athena/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ._cache import _cache_manager, _LocalMetadataCacheManager

if TYPE_CHECKING:
from mypy_boto3_athena.type_defs import QueryExecutionTypeDef
from mypy_boto3_glue.type_defs import ColumnOutputTypeDef

_QUERY_FINAL_STATES: list[str] = ["FAILED", "SUCCEEDED", "CANCELLED"]
Expand All @@ -53,7 +54,7 @@ class _QueryMetadata(NamedTuple):
binaries: list[str]
output_location: str | None
manifest_location: str | None
raw_payload: dict[str, Any]
raw_payload: "QueryExecutionTypeDef"


class _WorkGroupConfig(NamedTuple):
Expand Down Expand Up @@ -214,7 +215,7 @@ def _get_query_metadata(
query_execution_id: str,
boto3_session: boto3.Session | None = None,
categories: list[str] | None = None,
query_execution_payload: dict[str, Any] | None = None,
query_execution_payload: "QueryExecutionTypeDef" | None = None,
metadata_cache_manager: _LocalMetadataCacheManager | None = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
execution_params: list[str] | None = None,
Expand All @@ -225,12 +226,15 @@ def _get_query_metadata(
if query_execution_payload["Status"]["State"] != "SUCCEEDED":
reason: str = query_execution_payload["Status"]["StateChangeReason"]
raise exceptions.QueryFailed(f"Query error: {reason}")
_query_execution_payload: dict[str, Any] = query_execution_payload
_query_execution_payload = query_execution_payload
else:
_query_execution_payload = _executions.wait_query(
query_execution_id=query_execution_id,
boto3_session=boto3_session,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
_query_execution_payload = cast(
"QueryExecutionTypeDef",
_executions.wait_query(
query_execution_id=query_execution_id,
boto3_session=boto3_session,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
),
)
cols_types: dict[str, str] = get_query_columns_types(
query_execution_id=query_execution_id, boto3_session=boto3_session
Expand Down Expand Up @@ -266,8 +270,8 @@ def _get_query_metadata(
if "ResultConfiguration" in _query_execution_payload:
output_location = _query_execution_payload["ResultConfiguration"].get("OutputLocation")

athena_statistics: dict[str, int | str] = _query_execution_payload.get("Statistics", {})
manifest_location: str | None = str(athena_statistics.get("DataManifestLocation"))
athena_statistics = _query_execution_payload.get("Statistics", {})
manifest_location: str | None = athena_statistics.get("DataManifestLocation")

if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager:
metadata_cache_manager.update_cache(items=[_query_execution_payload])
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,65 @@ def test_athena_paramstyle_qmark_parameters(
assert len(df_out) == 1


@pytest.mark.parametrize(
"ctas_approach,unload_approach",
[
pytest.param(False, False, id="regular"),
pytest.param(True, False, id="ctas"),
pytest.param(False, True, id="unload"),
],
)
def test_athena_paramstyle_qmark_with_caching(
path: str,
path2: str,
glue_database: str,
glue_table: str,
workgroup0: str,
ctas_approach: bool,
unload_approach: bool,
) -> None:
wr.s3.to_parquet(
df=get_df(),
path=path,
index=False,
dataset=True,
mode="overwrite",
database=glue_database,
table=glue_table,
partition_cols=["par0", "par1"],
)

df_out = wr.athena.read_sql_query(
sql=f"SELECT * FROM {glue_table} WHERE string = ?",
database=glue_database,
ctas_approach=ctas_approach,
unload_approach=unload_approach,
workgroup=workgroup0,
params=["Washington"],
paramstyle="qmark",
keep_files=False,
s3_output=path2,
athena_cache_settings={"max_cache_seconds": 300},
)

assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Washington"

df_out = wr.athena.read_sql_query(
sql=f"SELECT * FROM {glue_table} WHERE string = ?",
database=glue_database,
ctas_approach=ctas_approach,
unload_approach=unload_approach,
workgroup=workgroup0,
params=["Seattle"],
paramstyle="qmark",
keep_files=False,
s3_output=path2,
athena_cache_settings={"max_cache_seconds": 300},
)

assert len(df_out) == 1 and df_out.iloc[0]["string"] == "Seattle"


def test_read_sql_query_parameter_formatting_respects_prefixes(path, glue_database, glue_table, workgroup0):
wr.s3.to_parquet(
df=get_df(),
Expand Down
Loading