Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

Implement Relation.cast #37

Merged
merged 2 commits into from
Sep 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/patito/Model/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Class properties
sql_types <sql_types>
unique_columns <unique_columns>
valid_dtypes <valid_dtypes>
valid_sql_types <valid_sql_types>

Class methods
-------------
Expand Down
8 changes: 8 additions & 0 deletions docs/api/patito/Model/valid_sql_types.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _Model.valid_sql_types:

patito.Model.valid_sql_types
============================

.. currentmodule:: patito._docs

.. autoproperty:: Model.valid_sql_types
8 changes: 8 additions & 0 deletions docs/api/patito/Relation/cast.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _Relation.cast:

patito.Relation.cast
====================

.. currentmodule:: patito

.. automethod:: Relation.cast
1 change: 1 addition & 0 deletions docs/api/patito/Relation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Methods
aggregate <aggregate>
all <all>
case <case>
cast <cast>
coalesce <coalesce>
count <count>
create_table <create_table>
Expand Down
138 changes: 122 additions & 16 deletions src/patito/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np
import polars as pl
import pyarrow as pa # type: ignore
import pyarrow as pa # type: ignore[import]
from pydantic import create_model
from typing_extensions import Literal

Expand Down Expand Up @@ -62,27 +62,30 @@

# The SQL type names supported by DuckDB, with each name listed exactly once.
# Rows keep alternative spellings next to each other (e.g. "BIGINT", "INT8",
# "LONG"); NOTE(review): the row grouping is presumably by alias -- confirm
# against the overview page linked below before relying on it.
# See: https://duckdb.org/docs/sql/data_types/overview
# fmt: off
DuckDBSQLType = Literal[
    "BIGINT", "INT8", "LONG",
    "BLOB", "BYTEA", "BINARY", "VARBINARY",
    "BOOLEAN", "BOOL", "LOGICAL",
    "DATE",
    "DOUBLE", "FLOAT8", "NUMERIC", "DECIMAL",
    "HUGEINT",
    "INTEGER", "INT4", "INT", "SIGNED",
    "INTERVAL",
    "REAL", "FLOAT4", "FLOAT",
    "SMALLINT", "INT2", "SHORT",
    "TIME",
    "TIMESTAMP", "DATETIME",
    "TIMESTAMP WITH TIMEZONE", "TIMESTAMPTZ",
    "TINYINT", "INT1",
    "UBIGINT",
    "UINTEGER",
    "USMALLINT",
    "UTINYINT",
    "UUID",
    "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
]
# fmt: on

# Used for backward-compatible patches
POLARS_VERSION: Optional[Tuple[int, int, int]]
Expand Down Expand Up @@ -566,6 +569,109 @@ def case(
new_relation = self._relation.project(f"*, {case_statement}")
return self._wrap(relation=new_relation, schema_change=True)

def cast(
    self: RelationType,
    model: Optional[ModelType] = None,
    strict: bool = False,
    include: Optional[Collection[str]] = None,
    exclude: Optional[Collection[str]] = None,
) -> RelationType:
    """
    Convert the relation's columns to the SQL types dictated by a model.

    The model is taken from the ``model`` argument when given; otherwise the
    model previously attached with
    :ref:`Relation.set_model() <Relation.set_model>` is used.

    Columns that are not fields of the model are passed through untouched.

    Args:
        model: Model whose schema decides the target types. Takes precedence
            over any model already attached to the relation.
        strict: When ``False``, a column is only cast if its current SQL type
            is not among the types the model considers valid for that field.
            For example, a ``SMALLINT`` column is acceptable for an
            ``int``-annotated field, even though ``INTEGER`` is the default
            SQL type for ``int``. When ``True``, every selected model column
            is cast to the model's default SQL type for that field.
        include: If provided, only these columns are candidates for casting.
        exclude: If provided, these columns are never cast.

    Returns:
        New relation where the selected columns have been cast according to
        the model schema.

    Raises:
        TypeError: If no ``model`` is given and the relation has no model set.
        ValueError: If both ``include`` and ``exclude`` are given.

    Examples:
        >>> import patito as pt
        >>> class Schema(pt.Model):
        ...     float_column: float
        ...
        >>> relation = pt.Relation("select 1 as float_column")
        >>> relation.types["float_column"]
        'INTEGER'
        >>> relation.cast(model=Schema).types["float_column"]
        'DOUBLE'

        >>> relation = pt.Relation("select 1::FLOAT as float_column")
        >>> relation.cast(model=Schema).types["float_column"]
        'FLOAT'
        >>> relation.cast(model=Schema, strict=True).types["float_column"]
        'DOUBLE'

        >>> class Schema(pt.Model):
        ...     column_1: float
        ...     column_2: float
        ...
        >>> relation = pt.Relation("select 1 as column_1, 2 as column_2").set_model(
        ...     Schema
        ... )
        >>> relation.types
        {'column_1': 'INTEGER', 'column_2': 'INTEGER'}
        >>> relation.cast(include=["column_1"]).types
        {'column_1': 'DOUBLE', 'column_2': 'INTEGER'}
        >>> relation.cast(exclude=["column_1"]).types
        {'column_1': 'INTEGER', 'column_2': 'DOUBLE'}
    """
    # Guard: a schema must come from either the argument or the relation.
    if model is None and self.model is None:
        class_name = self.__class__.__name__
        raise TypeError(
            f"{class_name}.cast() invoked without "
            f"{class_name}.model having been set! "
            f"You should invoke {class_name}.set_model() first "
            "or explicitly provide a model to .cast()."
        )

    if model is not None:
        # An explicitly provided model overrides the one already attached.
        source = self.set_model(model)
        schema = model
    else:
        source = self
        schema = cast(ModelType, self.model)

    if include is not None and exclude is not None:
        raise ValueError(
            f"Both include and exclude provided to {self.__class__.__name__}.cast()!"
        )

    # Resolve include/exclude into the set of column names eligible for casting.
    if include is not None:
        castable = set(include)
    else:
        excluded = set() if exclude is None else set(exclude)
        castable = set(source.columns) - excluded

    def requires_cast(name: str, sql_type: str) -> bool:
        # Only model fields that were selected for casting are considered.
        if name not in schema.columns or name not in castable:
            return False
        # In strict mode the default type is always enforced; otherwise the
        # column is left alone when its current type is already valid.
        return strict or sql_type not in schema.valid_sql_types[name]

    expressions = [
        f"{column}::{schema.sql_types[column]} as {column}"
        if requires_cast(column, current_type)
        else column
        for column, current_type in source.types.items()
    ]
    return cast(RelationType, self.select(*expressions))

def coalesce(
self: RelationType,
**column_expressions: Union[str, int, float],
Expand Down Expand Up @@ -1871,7 +1977,7 @@ def with_missing_defaultable_columns(
┌────────────────────┬────────────────┬────────────────────────┐
│ non_default_column ┆ default_column ┆ another_default_column │
│ --- ┆ --- ┆ --- │
│ i32 ┆ i32 ┆ i64
│ i32 ┆ i32 ┆ i32
╞════════════════════╪════════════════╪════════════════════════╡
│ 1 ┆ 2 ┆ 42 │
└────────────────────┴────────────────┴────────────────────────┘
Expand Down Expand Up @@ -1959,7 +2065,7 @@ def with_missing_nullable_columns(
┌─────────────────┬─────────────────────────┐
│ nullable_column ┆ another_nullable_column │
│ --- ┆ --- │
│ i32 ┆ i64
│ i32 ┆ i32
╞═════════════════╪═════════════════════════╡
│ 1 ┆ null │
└─────────────────┴─────────────────────────┘
Expand Down Expand Up @@ -2021,9 +2127,9 @@ def __getitem__(self, key: Union[str, Iterable[str]]) -> Relation:
"""
Return Relation with selected columns.

Uses :ref:`Relation.project()<Relation.project>` under-the-hood in order to
Uses :ref:`Relation.select()<Relation.select>` under-the-hood in order to
perform the selection. Can technically be used to rename columns,
define derived columns, and so on, but prefer the use of Relation.project() for
define derived columns, and so on, but prefer the use of Relation.select() for
such use cases.

Args:
Expand Down
Loading