Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

Commit

Permalink
Implement Relation.cast (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobGM authored Sep 4, 2022
1 parent e500868 commit 8fcccae
Show file tree
Hide file tree
Showing 8 changed files with 437 additions and 60 deletions.
1 change: 1 addition & 0 deletions docs/api/patito/Model/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Class properties
sql_types <sql_types>
unique_columns <unique_columns>
valid_dtypes <valid_dtypes>
valid_sql_types <valid_sql_types>

Class methods
-------------
Expand Down
8 changes: 8 additions & 0 deletions docs/api/patito/Model/valid_sql_types.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _Model.valid_sql_types:

patito.Model.valid_sql_types
============================

.. currentmodule:: patito._docs

.. autoproperty:: Model.valid_sql_types
8 changes: 8 additions & 0 deletions docs/api/patito/Relation/cast.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
.. _Relation.cast:

patito.Relation.cast
====================

.. currentmodule:: patito

.. automethod:: Relation.cast
1 change: 1 addition & 0 deletions docs/api/patito/Relation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Methods
aggregate <aggregate>
all <all>
case <case>
cast <cast>
coalesce <coalesce>
count <count>
create_table <create_table>
Expand Down
138 changes: 122 additions & 16 deletions src/patito/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np
import polars as pl
import pyarrow as pa # type: ignore
import pyarrow as pa # type: ignore[import]
from pydantic import create_model
from typing_extensions import Literal

Expand Down Expand Up @@ -62,27 +62,30 @@

# The SQL types supported by DuckDB
# See: https://duckdb.org/docs/sql/data_types/overview
# DuckDBSQLType is a typing Literal, so it restricts annotated values to exactly
# the strings listed below; it performs no runtime validation on its own.
# NOTE(review): entries that share a line appear to be alternative spellings of
# one underlying type — confirm the grouping against the DuckDB overview above.
# Formatting is switched off so that each group can stay on a single line.
# fmt: off
DuckDBSQLType = Literal[
"BIGINT",
"BOOLEAN",
"BLOB",
"BIGINT", "INT8", "LONG",
"BLOB", "BYTEA", "BINARY", "VARBINARY",
"BOOLEAN", "BOOL", "LOGICAL",
"DATE",
"DOUBLE",
"DECIMAL",
"DOUBLE", "FLOAT8", "NUMERIC", "DECIMAL",
"HUGEINT",
"INTEGER",
"REAL",
"SMALLINT",
"INTEGER", "INT4", "INT", "SIGNED",
"INTERVAL",
"REAL", "FLOAT4", "FLOAT",
"SMALLINT", "INT2", "SHORT",
"TIME",
"TIMESTAMP",
"TINYINT",
"TIMESTAMP", "DATETIME",
"TIMESTAMP WITH TIMEZONE", "TIMESTAMPTZ",
"TINYINT", "INT1",
"UBIGINT",
"UINTEGER",
"USMALLINT",
"UTINYINT",
"UUID",
"VARCHAR",
"VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
]
# fmt: on

# Used for backward-compatible patches
POLARS_VERSION: Optional[Tuple[int, int, int]]
Expand Down Expand Up @@ -566,6 +569,109 @@ def case(
new_relation = self._relation.project(f"*, {case_statement}")
return self._wrap(relation=new_relation, schema_change=True)

def cast(
    self: RelationType,
    model: Optional[ModelType] = None,
    strict: bool = False,
    include: Optional[Collection[str]] = None,
    exclude: Optional[Collection[str]] = None,
) -> RelationType:
    """
    Cast relation columns to SQL types that are compatible with a model.

    The model is taken from the ``model`` argument when given, and otherwise from
    the model previously attached with
    :ref:`Relation.set_model() <Relation.set_model>`. Columns of the relation
    that are not fields of that model are passed through unchanged.

    Args:
        model: Model to cast against. Needed when
            :ref:`Relation.set_model() <Relation.set_model>` has not been
            invoked; when given, it takes precedence over a model already set.
        strict: When ``False``, a column whose current SQL type is already one
            of the valid SQL types for its field is left alone. For example, a
            ``SMALLINT`` column satisfies an ``int``-annotated field even though
            ``INTEGER`` is the default SQL type for ``int``. When ``True``,
            every selected column is cast to the default SQL type of its field.
        include: If provided, only these columns are considered for casting.
        exclude: If provided, these columns are `not` considered for casting.

    Returns:
        New relation where the columns have been cast according to the model
        schema.

    Raises:
        TypeError: If no ``model`` is given and none has been set on the
            relation.
        ValueError: If both ``include`` and ``exclude`` are given.

    Examples:
        >>> import patito as pt
        >>> class Schema(pt.Model):
        ...     float_column: float
        ...
        >>> relation = pt.Relation("select 1 as float_column")
        >>> relation.types["float_column"]
        'INTEGER'
        >>> relation.cast(model=Schema).types["float_column"]
        'DOUBLE'

        >>> relation = pt.Relation("select 1::FLOAT as float_column")
        >>> relation.cast(model=Schema).types["float_column"]
        'FLOAT'
        >>> relation.cast(model=Schema, strict=True).types["float_column"]
        'DOUBLE'

        >>> class Schema(pt.Model):
        ...     column_1: float
        ...     column_2: float
        ...
        >>> relation = pt.Relation("select 1 as column_1, 2 as column_2").set_model(
        ...     Schema
        ... )
        >>> relation.types
        {'column_1': 'INTEGER', 'column_2': 'INTEGER'}
        >>> relation.cast(include=["column_1"]).types
        {'column_1': 'DOUBLE', 'column_2': 'INTEGER'}
        >>> relation.cast(exclude=["column_1"]).types
        {'column_1': 'INTEGER', 'column_2': 'DOUBLE'}
    """
    # Resolve which model describes the target types; an explicitly passed
    # model wins over one that was attached to the relation earlier.
    if model is not None:
        relation, schema = self.set_model(model), model
    elif self.model is not None:
        relation, schema = self, cast(ModelType, self.model)
    else:
        class_name = self.__class__.__name__
        raise TypeError(
            f"{class_name}.cast() invoked without "
            f"{class_name}.model having been set! "
            f"You should invoke {class_name}.set_model() first "
            "or explicitly provide a model to .cast()."
        )

    # include and exclude are two ways of expressing the same selection, so
    # accepting both at once would be ambiguous.
    if include is not None and exclude is not None:
        raise ValueError(
            f"Both include and exclude provided to {self.__class__.__name__}.cast()!"
        )
    if include is not None:
        castable = set(include)
    elif exclude is not None:
        castable = set(relation.columns) - set(exclude)
    else:
        castable = set(relation.columns)

    def _projection(column: str, current_type: str) -> str:
        """Return the select expression for one column, cast only if needed."""
        if (
            column in schema.columns
            and column in castable
            and (strict or current_type not in schema.valid_sql_types[column])
        ):
            return f"{column}::{schema.sql_types[column]} as {column}"
        # Either not a model field, not selected, or already compliant.
        return column

    projections = [
        _projection(column, current_type)
        for column, current_type in relation.types.items()
    ]
    return cast(RelationType, self.select(*projections))

def coalesce(
self: RelationType,
**column_expressions: Union[str, int, float],
Expand Down Expand Up @@ -1871,7 +1977,7 @@ def with_missing_defaultable_columns(
┌────────────────────┬────────────────┬────────────────────────┐
│ non_default_column ┆ default_column ┆ another_default_column │
│ --- ┆ --- ┆ --- │
│ i32 ┆ i32 ┆ i64
│ i32 ┆ i32 ┆ i32
╞════════════════════╪════════════════╪════════════════════════╡
│ 1 ┆ 2 ┆ 42 │
└────────────────────┴────────────────┴────────────────────────┘
Expand Down Expand Up @@ -1959,7 +2065,7 @@ def with_missing_nullable_columns(
┌─────────────────┬─────────────────────────┐
│ nullable_column ┆ another_nullable_column │
│ --- ┆ --- │
│ i32 ┆ i64
│ i32 ┆ i32
╞═════════════════╪═════════════════════════╡
│ 1 ┆ null │
└─────────────────┴─────────────────────────┘
Expand Down Expand Up @@ -2021,9 +2127,9 @@ def __getitem__(self, key: Union[str, Iterable[str]]) -> Relation:
"""
Return Relation with selected columns.
Uses :ref:`Relation.project()<Relation.project>` under-the-hood in order to
Uses :ref:`Relation.select()<Relation.select>` under-the-hood in order to
perform the selection. Can technically be used to rename columns,
define derived columns, and so on, but prefer the use of Relation.project() for
define derived columns, and so on, but prefer the use of Relation.select() for
such use cases.
Args:
Expand Down
Loading

0 comments on commit 8fcccae

Please sign in to comment.