This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

Commit b0d9836: Use QueryConstructor/QueryHandler consistently
JakobGM committed Dec 19, 2022
1 parent 54b8e7f
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions src/patito/sources.py
@@ -40,7 +40,7 @@
CACHE_VERSION = 1


- class QueryFunc(Protocol[P]):
+ class QueryConstructor(Protocol[P]):
"""A function taking arbitrary arguments and returning an SQL query string."""

__name__: str
@@ -56,26 +56,27 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> str:
... # pragma: no cover


- class WrappedQueryFunc(Generic[P, DF]):
+ class QueryHandler(Generic[P, DF]):
"""A class acting as a function that returns a polars.DataFrame when called."""

_cache: Union[bool, Path]

def __init__( # noqa: C901
self,
- wrapped_function: QueryFunc[P],
+ query_constructor: QueryConstructor[P],
cache_directory: Path,
query_handler: Callable[..., pa.Table],
ttl: timedelta,
lazy: bool = False,
cache: Union[str, Path, bool] = False,
model: Union[Type["Model"], None] = None,
- query_executor_kwargs: Optional[Dict[Any, Any]] = None,
+ query_handler_kwargs: Optional[Dict[Any, Any]] = None,
) -> None:
"""Convert SQL string query function to polars.DataFrame function.
"""
Convert SQL string query function to polars.DataFrame function.
Args:
- wrapped_function: A function that takes arbitrary arguments and returns
+ query_constructor: A function that takes arbitrary arguments and returns
an SQL query string.
cache_directory: Path to directory to store parquet cache files in.
query_handler: Function used to execute SQL queries and return arrow
@@ -98,27 +99,27 @@ def __init__( # noqa: C901
self._cache = cache_directory.joinpath(cache)
else:
self._cache = cache
- self._wrapped_function = wrapped_function
+ self._query_constructor = query_constructor
self.cache_directory = cache_directory

- self._query_executor_kwargs = query_executor_kwargs or {}
+ self._query_handler_kwargs = query_handler_kwargs or {}
# Unless explicitly specified otherwise by the end-user, we retrieve query
# results as arrow tables with column types directly supported by polars.
# Otherwise the resulting parquet files that are written to disk can not be
# lazily read with polars.scan_parquet.
- self._query_executor_kwargs.setdefault("cast_to_polars_equivalent_types", True)
+ self._query_handler_kwargs.setdefault("cast_to_polars_equivalent_types", True)

# We construct the new function with the same parameter signature as
# wrapped_function, but with polars.DataFrame as the return type.
- @wraps(wrapped_function)
+ @wraps(query_constructor)
def cached_func(*args: P.args, **kwargs: P.kwargs) -> DF:
- sql_query = wrapped_function(*args, **kwargs)
+ query = query_constructor(*args, **kwargs)
cache_path = self.cache_path(*args, **kwargs)
if cache_path and cache_path.exists():
metadata: Dict[bytes, bytes] = pq.read_schema(cache_path).metadata or {}

# Check if the cache file was produced by an identical SQL query
- is_same_query = metadata.get(b"query") == sql_query.encode("utf-8")
+ is_same_query = metadata.get(b"query") == query.encode("utf-8")

# Check if the cache is too old to be re-used
cache_created_time = datetime.fromisoformat(
@@ -145,7 +146,7 @@ def cached_func(*args: P.args, **kwargs: P.kwargs) -> DF:
else:
return pl.read_parquet(cache_path) # type: ignore

- arrow_table = query_handler(sql_query, **self._query_executor_kwargs)
+ arrow_table = query_handler(query, **self._query_handler_kwargs)
if cache_path:
cache_path.parent.mkdir(parents=True, exist_ok=True)
# We write the cache *before* any potential model validation since
@@ -158,7 +159,7 @@ def cached_func(*args: P.args, **kwargs: P.kwargs) -> DF:
metadata = arrow_table.schema.metadata
metadata[
b"wrapped_function_name"
- ] = self._wrapped_function.__name__.encode("utf-8")
+ ] = self._query_constructor.__name__.encode("utf-8")
# Store the cache version as an 16-bit unsigned little-endian number
metadata[b"cache_version"] = CACHE_VERSION.to_bytes(
length=16,
@@ -203,14 +204,14 @@ def cache_path(self, *args: P.args, **kwargs: P.kwargs) -> Optional[Path]:
A deterministic path to a parquet cache. None if caching is disabled.
"""
# We convert args+kwargs to kwargs-only and use it to format the string
- function_signature = inspect.signature(self._wrapped_function)
+ function_signature = inspect.signature(self._query_constructor)
bound_arguments = function_signature.bind(*args, **kwargs)

if isinstance(self._cache, Path):
# Interpret relative paths relative to the main query cache directory
return Path(str(self._cache).format(**bound_arguments.arguments))
elif self._cache is True:
- directory: Path = self.cache_directory / self._wrapped_function.__name__
+ directory: Path = self.cache_directory / self._query_constructor.__name__
directory.mkdir(exist_ok=True, parents=True)
sql_query = self.sql_query(*args, **kwargs)
sql_query_hash = hashlib.sha1( # noqa: S324,S303
@@ -234,7 +235,7 @@ def sql_query(self, *args: P.args, **kwargs: P.kwargs) -> str:
Returns:
The SQL query string produced for the given input parameters.
"""
- return self._wrapped_function(*args, **kwargs)
+ return self._query_constructor(*args, **kwargs)

def refresh_cache(self, *args: P.args, **kwargs: P.kwargs) -> DF:
"""
@@ -260,7 +261,7 @@ def clear_caches(self) -> None:

if self._cache is True:
glob_pattern = str(
- self.cache_directory / self._wrapped_function.__name__ / "*.parquet"
+ self.cache_directory / self._query_constructor.__name__ / "*.parquet"
)
else:
# We replace all formatting specifiers of the form '{variable}' with
@@ -282,7 +283,7 @@ def clear_caches(self) -> None:
)
if metadata.get(
b"wrapped_function_name"
- ) == self._wrapped_function.__name__.encode("utf-8"):
+ ) == self._query_constructor.__name__.encode("utf-8"):
Path(parquet_path).unlink()
except Exception: # noqa: S112
# If we can't read the parquet metadata for some reason,
@@ -418,7 +419,7 @@ def query(
ttl: Optional[timedelta] = None,
model: Union[Type["Model"], None] = None,
**kwargs: Any, # noqa: ANN401
- ) -> Callable[[QueryFunc[P]], WrappedQueryFunc[P, pl.DataFrame]]:
+ ) -> Callable[[QueryConstructor[P]], QueryHandler[P, pl.DataFrame]]:
... # pragma: no cover

# With lazy = True a LazyFrame-producing wrapper is returned
@@ -431,7 +432,7 @@ def query(
ttl: Optional[timedelta] = None,
model: Union[Type["Model"], None] = None,
**kwargs: Any, # noqa: ANN401
- ) -> Callable[[QueryFunc[P]], WrappedQueryFunc[P, pl.LazyFrame]]:
+ ) -> Callable[[QueryConstructor[P]], QueryHandler[P, pl.LazyFrame]]:
... # pragma: no cover

def query(
@@ -443,7 +444,7 @@ def query(
model: Union[Type["Model"], None] = None,
**kwargs: Any, # noqa: ANN401
) -> Callable[
- [QueryFunc[P]], WrappedQueryFunc[P, Union[pl.DataFrame, pl.LazyFrame]]
+ [QueryConstructor[P]], QueryHandler[P, Union[pl.DataFrame, pl.LazyFrame]]
]:
"""
Execute the returned query string and return a polars dataframe.
@@ -478,16 +479,16 @@ def query(
specified by the original function's return string.
"""

- def wrapper(wrapped_function: QueryFunc) -> WrappedQueryFunc:
- return WrappedQueryFunc(
- wrapped_function=wrapped_function,
+ def wrapper(query_constructor: QueryConstructor) -> QueryHandler:
+ return QueryHandler(
+ query_constructor=query_constructor,
lazy=lazy,
cache=cache,
ttl=ttl if ttl is not None else self.default_ttl,
cache_directory=self.cache_directory,
model=model,
query_handler=_with_query_metadata(self.query_handler),
- query_executor_kwargs=kwargs,
+ query_handler_kwargs=kwargs,
)

return wrapper
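
To make the renamed pieces concrete, here is a minimal usage sketch of the decorator shown in this diff. It is not taken from the repository: `db` stands for a hypothetical, pre-configured query source exposing the `query` method, and the table and column names are invented.

from datetime import timedelta

@db.query(cache=True, ttl=timedelta(hours=12))
def user_orders(user_id: int) -> str:
    # The decorated function is a QueryConstructor: it only builds the SQL string.
    return f"SELECT * FROM orders WHERE user_id = {user_id}"

# The decorator replaces the function with a QueryHandler, which executes the
# query, caches the arrow result as a parquet file, and returns a polars DataFrame.
df = user_orders(user_id=42)

# Identical calls within the TTL are served from the parquet cache.
df_again = user_orders(user_id=42)

# Cache management helpers provided by QueryHandler:
user_orders.refresh_cache(user_id=42)  # re-execute the query and overwrite the cache
user_orders.clear_caches()             # delete all cache files for this query constructor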

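The cache_path hunk above also shows how user-supplied cache path templates are resolved: the constructor's arguments are bound to their parameter names with inspect.signature(...).bind(...) and then substituted into the '{variable}' placeholders. A self-contained sketch of that mechanism, with a made-up products constructor and template:

import inspect
from pathlib import Path

def products(category: str, limit: int = 10) -> str:
    # Hypothetical query constructor, used only to illustrate argument binding.
    return f"SELECT * FROM products WHERE category = '{category}' LIMIT {limit}"

# Mirrors the cache_path logic: positional and keyword arguments are bound
# to their parameter names, then formatted into the path template.
template = Path("products/{category}.parquet")
bound_arguments = inspect.signature(products).bind("toys", limit=5)
print(str(template).format(**bound_arguments.arguments))  # products/toys.parquet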