diff --git a/docs/api/patito/Model/index.rst b/docs/api/patito/Model/index.rst
index e64f31b..074c6f4 100644
--- a/docs/api/patito/Model/index.rst
+++ b/docs/api/patito/Model/index.rst
@@ -19,6 +19,7 @@ Class properties
     sql_types
     unique_columns
     valid_dtypes
+    valid_sql_types
 
 Class methods
 -------------
diff --git a/docs/api/patito/Model/valid_sql_types.rst b/docs/api/patito/Model/valid_sql_types.rst
new file mode 100644
index 0000000..0e72469
--- /dev/null
+++ b/docs/api/patito/Model/valid_sql_types.rst
@@ -0,0 +1,8 @@
+.. _Model.valid_sql_types:
+
+patito.Model.valid_sql_types
+============================
+
+.. currentmodule:: patito._docs
+
+.. autoproperty:: Model.valid_sql_types
diff --git a/docs/api/patito/Relation/cast.rst b/docs/api/patito/Relation/cast.rst
new file mode 100644
index 0000000..7c2b9d4
--- /dev/null
+++ b/docs/api/patito/Relation/cast.rst
@@ -0,0 +1,8 @@
+.. _Relation.cast:
+
+patito.Relation.cast
+====================
+
+.. currentmodule:: patito
+
+.. automethod:: Relation.cast
diff --git a/docs/api/patito/Relation/index.rst b/docs/api/patito/Relation/index.rst
index 338998f..1695c56 100644
--- a/docs/api/patito/Relation/index.rst
+++ b/docs/api/patito/Relation/index.rst
@@ -27,6 +27,7 @@ Methods
     aggregate
     all
     case
+    cast
     coalesce
     count
     create_table
diff --git a/src/patito/duckdb.py b/src/patito/duckdb.py
index 640d6cd..b3f7959 100644
--- a/src/patito/duckdb.py
+++ b/src/patito/duckdb.py
@@ -23,7 +23,7 @@
 import numpy as np
 import polars as pl
-import pyarrow as pa  # type: ignore
+import pyarrow as pa  # type: ignore[import]
 from pydantic import create_model
 from typing_extensions import Literal
@@ -62,27 +62,30 @@
 # The SQL types supported by DuckDB
 # See: https://duckdb.org/docs/sql/data_types/overview
+# fmt: off
 DuckDBSQLType = Literal[
-    "BIGINT",
-    "BOOLEAN",
-    "BLOB",
+    "BIGINT", "INT8", "LONG",
+    "BLOB", "BYTEA", "BINARY", "VARBINARY",
+    "BOOLEAN", "BOOL", "LOGICAL",
     "DATE",
-    "DOUBLE",
-    "DECIMAL",
+    "DOUBLE", "FLOAT8", "NUMERIC", "DECIMAL",
     "HUGEINT",
-    "INTEGER",
-    "REAL",
-    "SMALLINT",
+    "INTEGER", "INT4", "INT", "SIGNED",
+    "INTERVAL",
+    "REAL", "FLOAT4", "FLOAT",
+    "SMALLINT", "INT2", "SHORT",
     "TIME",
-    "TIMESTAMP",
-    "TINYINT",
+    "TIMESTAMP", "DATETIME",
+    "TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ",
+    "TINYINT", "INT1",
     "UBIGINT",
     "UINTEGER",
     "USMALLINT",
     "UTINYINT",
     "UUID",
-    "VARCHAR",
+    "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
 ]
+# fmt: on
 
 # Used for backward-compatible patches
 POLARS_VERSION: Optional[Tuple[int, int, int]]
@@ -566,6 +569,109 @@ def case(
         new_relation = self._relation.project(f"*, {case_statement}")
         return self._wrap(relation=new_relation, schema_change=True)
 
+    def cast(
+        self: RelationType,
+        model: Optional[ModelType] = None,
+        strict: bool = False,
+        include: Optional[Collection[str]] = None,
+        exclude: Optional[Collection[str]] = None,
+    ) -> RelationType:
+        """
+        Cast the columns of the relation to types compatible with the associated model.
+
+        The associated model must either be set by invoking
+        :ref:`Relation.set_model() <Relation.set_model>` or provided with the
+        ``model`` parameter.
+
+        Any columns of the relation that are not part of the given model schema will be
+        left as-is.
+
+        Args:
+            model: Model to cast to if :ref:`Relation.set_model() <Relation.set_model>`
+                has not been invoked or is intended to be overwritten.
+            strict: If set to ``False``, columns which are technically compliant with
+                the specified field type will not be cast. For example, a column
+                annotated with ``int`` is technically compliant with ``SMALLINT``, even
+                if ``INTEGER`` is the default SQL type associated with ``int``-annotated
+                fields. If ``strict`` is set to ``True``, the resulting SQL types will
+                be forced to the default type associated with each Python type.
+            include: If provided, only the given columns will be cast.
+            exclude: If provided, the given columns will *not* be cast.
+
+        Returns:
+            New relation where the columns have been cast according to the model
+            schema.
+
+        Examples:
+            >>> import patito as pt
+            >>> class Schema(pt.Model):
+            ...     float_column: float
+            ...
+            >>> relation = pt.Relation("select 1 as float_column")
+            >>> relation.types["float_column"]
+            'INTEGER'
+            >>> relation.cast(model=Schema).types["float_column"]
+            'DOUBLE'
+
+            >>> relation = pt.Relation("select 1::FLOAT as float_column")
+            >>> relation.cast(model=Schema).types["float_column"]
+            'FLOAT'
+            >>> relation.cast(model=Schema, strict=True).types["float_column"]
+            'DOUBLE'
+
+            >>> class Schema(pt.Model):
+            ...     column_1: float
+            ...     column_2: float
+            ...
+            >>> relation = pt.Relation("select 1 as column_1, 2 as column_2").set_model(
+            ...     Schema
+            ... )
+            >>> relation.types
+            {'column_1': 'INTEGER', 'column_2': 'INTEGER'}
+            >>> relation.cast(include=["column_1"]).types
+            {'column_1': 'DOUBLE', 'column_2': 'INTEGER'}
+            >>> relation.cast(exclude=["column_1"]).types
+            {'column_1': 'INTEGER', 'column_2': 'DOUBLE'}
+        """
+        if model is not None:
+            relation = self.set_model(model)
+            schema = model
+        elif self.model is not None:
+            relation = self
+            schema = cast(ModelType, self.model)
+        else:
+            class_name = self.__class__.__name__
+            raise TypeError(
+                f"{class_name}.cast() invoked without "
+                f"{class_name}.model having been set! "
+                f"You should invoke {class_name}.set_model() first "
+                "or explicitly provide a model to .cast()."
+            )
+
+        if include is not None and exclude is not None:
+            raise ValueError(
+                f"Both include and exclude provided to {self.__class__.__name__}.cast()!"
+            )
+        elif include is not None:
+            include = set(include)
+        elif exclude is not None:
+            include = set(relation.columns) - set(exclude)
+        else:
+            include = set(relation.columns)
+
+        new_columns = []
+        for column, current_type in relation.types.items():
+            if column not in schema.columns:
+                new_columns.append(column)
+            elif column in include and (
+                strict or current_type not in schema.valid_sql_types[column]
+            ):
+                new_type = schema.sql_types[column]
+                new_columns.append(f"{column}::{new_type} as {column}")
+            else:
+                new_columns.append(column)
+        return cast(RelationType, self.select(*new_columns))
+
     def coalesce(
         self: RelationType,
         **column_expressions: Union[str, int, float],
@@ -1871,7 +1977,7 @@ def with_missing_defaultable_columns(
             ┌────────────────────┬────────────────┬────────────────────────┐
             │ non_default_column ┆ default_column ┆ another_default_column │
             │ ---                ┆ ---            ┆ ---                    │
-            │ i32                ┆ i32            ┆ i64                    │
+            │ i32                ┆ i32            ┆ i32                    │
             ╞════════════════════╪════════════════╪════════════════════════╡
             │ 1                  ┆ 2              ┆ 42                     │
             └────────────────────┴────────────────┴────────────────────────┘
@@ -1959,7 +2065,7 @@ def with_missing_nullable_columns(
             ┌─────────────────┬─────────────────────────┐
             │ nullable_column ┆ another_nullable_column │
             │ ---             ┆ ---                     │
-            │ i32             ┆ i64                     │
+            │ i32             ┆ i32                     │
             ╞═════════════════╪═════════════════════════╡
             │ 1               ┆ null                    │
             └─────────────────┴─────────────────────────┘
@@ -2021,9 +2127,9 @@ def __getitem__(self, key: Union[str, Iterable[str]]) -> Relation:
         """
         Return Relation with selected columns.
 
-        Uses :ref:`Relation.project()` under-the-hood in order to
+        Uses :ref:`Relation.select()` under-the-hood in order to
         perform the selection. Can technically be used to rename columns,
-        define derived columns, and so on, but prefer the use of Relation.project() for
+        define derived columns, and so on, but prefer the use of Relation.select() for
         such use cases.
 
         Args:
diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py
index 6fc6080..ba620e9 100644
--- a/src/patito/pydantic.py
+++ b/src/patito/pydantic.py
@@ -35,6 +35,7 @@
 if TYPE_CHECKING:
     import patito.polars
+    from patito.duckdb import DuckDBSQLType
 
 # The generic type of a single row in given Relation.
 # Should be a typed subclass of Model.
@@ -182,9 +183,7 @@ def valid_dtypes(  # type: ignore  # noqa: C901
                 valid_dtypes[column] = [
                     props["dtype"],
                 ]
-            elif "enum" in props:
-                if props["type"] != "string":  # pragma: no cover
-                    raise NotImplementedError
+            elif "enum" in props and props["type"] == "string":
                 valid_dtypes[column] = [pl.Categorical, pl.Utf8]
             elif "type" not in props:
                 raise NotImplementedError(
@@ -238,6 +237,172 @@ def valid_dtypes(  # type: ignore  # noqa: C901
 
         return valid_dtypes
 
+    @property
+    def valid_sql_types(  # type: ignore  # noqa: C901
+        cls: Type[ModelType],
+    ) -> dict[str, List["DuckDBSQLType"]]:
+        """
+        Return a list of DuckDB SQL types which Patito considers valid for each field.
+
+        The first item of each list is the default SQL type chosen by Patito.
+
+        Returns:
+            A dictionary mapping each column string name to a list of DuckDB SQL types
+            represented as strings.
+
+        Raises:
+            NotImplementedError: If one or more model fields are annotated with types
+                not compatible with DuckDB.
+
+        Example:
+            >>> import patito as pt
+            >>> from pprint import pprint
+
+            >>> class MyModel(pt.Model):
+            ...     bool_column: bool
+            ...     str_column: str
+            ...     int_column: int
+            ...     float_column: float
+            ...
+            >>> pprint(MyModel.valid_sql_types)
+            {'bool_column': ['BOOLEAN', 'BOOL', 'LOGICAL'],
+             'float_column': ['DOUBLE',
+                              'FLOAT8',
+                              'NUMERIC',
+                              'DECIMAL',
+                              'REAL',
+                              'FLOAT4',
+                              'FLOAT'],
+             'int_column': ['INTEGER',
+                            'INT4',
+                            'INT',
+                            'SIGNED',
+                            'BIGINT',
+                            'INT8',
+                            'LONG',
+                            'HUGEINT',
+                            'SMALLINT',
+                            'INT2',
+                            'SHORT',
+                            'TINYINT',
+                            'INT1',
+                            'UBIGINT',
+                            'UINTEGER',
+                            'USMALLINT',
+                            'UTINYINT'],
+             'str_column': ['VARCHAR', 'CHAR', 'BPCHAR', 'TEXT', 'STRING']}
+        """
+        valid_sql_types: Dict[str, List["DuckDBSQLType"]] = {}
+        for column, props in cls._schema_properties().items():
+            if "sql_type" in props:
+                valid_sql_types[column] = [
+                    props["sql_type"],
+                ]
+            elif "enum" in props and props["type"] == "string":
+                from patito.duckdb import _enum_type_name
+
+                # fmt: off
+                valid_sql_types[column] = [
+                    _enum_type_name(field_properties=props),  # type: ignore
+                    "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
+                ]
+                # fmt: on
+            elif "type" not in props:
+                raise NotImplementedError(
+                    f"No valid sql_type mapping found for column '{column}'."
+                )
+            elif props["type"] == "integer":
+                # fmt: off
+                valid_sql_types[column] = [
+                    "INTEGER", "INT4", "INT", "SIGNED",
+                    "BIGINT", "INT8", "LONG",
+                    "HUGEINT",
+                    "SMALLINT", "INT2", "SHORT",
+                    "TINYINT", "INT1",
+                    "UBIGINT",
+                    "UINTEGER",
+                    "USMALLINT",
+                    "UTINYINT",
+                ]
+                # fmt: on
+            elif props["type"] == "number":
+                if props.get("format") == "time-delta":
+                    valid_sql_types[column] = [
+                        "INTERVAL",
+                    ]
+                else:
+                    # fmt: off
+                    valid_sql_types[column] = [
+                        "DOUBLE", "FLOAT8", "NUMERIC", "DECIMAL",
+                        "REAL", "FLOAT4", "FLOAT",
+                    ]
+                    # fmt: on
+            elif props["type"] == "boolean":
+                # fmt: off
+                valid_sql_types[column] = [
+                    "BOOLEAN", "BOOL", "LOGICAL",
+                ]
+                # fmt: on
+            elif props["type"] == "string":
+                string_format = props.get("format")
+                if string_format is None:
+                    # fmt: off
+                    valid_sql_types[column] = [
+                        "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
+                    ]
+                    # fmt: on
+                elif string_format == "date":
+                    valid_sql_types[column] = ["DATE"]
+                # TODO: Find out why this branch is not being hit
+                elif string_format == "date-time":  # pragma: no cover
+                    # fmt: off
+                    valid_sql_types[column] = [
+                        "TIMESTAMP", "DATETIME",
+                        "TIMESTAMP WITH TIME ZONE", "TIMESTAMPTZ",
+                    ]
+                    # fmt: on
+            elif props["type"] == "null":
+                valid_sql_types[column] = [
+                    "INTEGER",
+                ]
+            else:  # pragma: no cover
+                raise NotImplementedError(
+                    f"No valid sql_type mapping found for column '{column}'"
+                )
+
+        return valid_sql_types
+
+    @property
+    def sql_types(  # type: ignore
+        cls: Type[ModelType],
+    ) -> dict[str, str]:
+        """
+        Return compatible DuckDB SQL types for all model fields.
+
+        Returns:
+            Dictionary with column name keys and SQL type identifier strings.
+
+        Example:
+            >>> from typing import Literal
+            >>> import patito as pt
+
+            >>> class MyModel(pt.Model):
+            ...     int_column: int
+            ...     str_column: str
+            ...     float_column: float
+            ...     literal_column: Literal["a", "b", "c"]
+            ...
+            >>> MyModel.sql_types
+            {'int_column': 'INTEGER',
+             'str_column': 'VARCHAR',
+             'float_column': 'DOUBLE',
+             'literal_column': 'enum__4a496993dde04060df4e15a340651b45'}
+        """
+        return {
+            column: valid_types[0]
+            for column, valid_types in cls.valid_sql_types.items()
+        }
+
     @property
     def defaults(  # type: ignore
         cls: Type[ModelType],
@@ -338,43 +503,6 @@ def unique_columns(  # type: ignore
         props = cls._schema_properties()
         return {column for column in cls.columns if props[column].get("unique", False)}
 
-    @property
-    def sql_types(  # type: ignore
-        cls: Type[ModelType],
-    ) -> dict[str, str]:
-        """
-        Return compatible DuckDB SQL types for all model fields.
-
-        Returns:
-            Dictionary with column name keys and SQL type identifier strings.
-
-        Example:
-            >>> import patito as pt
-
-            >>> class MyModel(pt.Model):
-            ...     int_column: int
-            ...     str_column: str
-            ...     float_column: float
-            ...     literal_column: Literal["a", "b", "c"]
-            ...
-            >>> MyModel.sql_types
-            {'int_column': 'BIGINT',
-             'str_column': 'VARCHAR',
-             'float_column': 'DOUBLE',
-             'literal_column': 'enum__4a496993dde04060df4e15a340651b45'}
-        """
-        from patito.duckdb import _enum_type_name
-
-        types = {}
-        for column, props in cls._schema_properties().items():
-            if "enum" in props and all(
-                isinstance(variant, str) for variant in props["enum"]
-            ):
-                types[column] = _enum_type_name(field_properties=props)
-            else:
-                types[column] = PYDANTIC_TO_DUCKDB_TYPES[props["type"]]
-        return types
-
 
 class Model(BaseModel, metaclass=ModelMetaclass):
     """Custom pydantic class for representing table schema and constructing rows."""
@@ -397,6 +525,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
     dtypes: ClassVar[Dict[str, Type[pl.DataType]]]
     sql_types: ClassVar[Dict[str, str]]
     valid_dtypes: ClassVar[Dict[str, List[Type[pl.DataType]]]]
+    valid_sql_types: ClassVar[Dict[str, List["DuckDBSQLType"]]]
     defaults: ClassVar[Dict[str, Any]]
diff --git a/tests/test_duckdb/test_database.py b/tests/test_duckdb/test_database.py
index 171c890..79b8e54 100644
--- a/tests/test_duckdb/test_database.py
+++ b/tests/test_duckdb/test_database.py
@@ -111,8 +111,8 @@ class Model(BaseModel):
         "enum_column",
     ]
     assert list(table.types.values()) == [
-        "BIGINT",
-        "BIGINT",
+        "INTEGER",
+        "INTEGER",
         "VARCHAR",
         "VARCHAR",
         "BOOLEAN",
diff --git a/tests/test_duckdb/test_relation.py b/tests/test_duckdb/test_relation.py
index 8c4fe22..c9a0b76 100644
--- a/tests/test_duckdb/test_relation.py
+++ b/tests/test_duckdb/test_relation.py
@@ -1,4 +1,5 @@
 import re
+from datetime import date, timedelta
 from pathlib import Path
 from typing import Optional
 from unittest.mock import MagicMock
@@ -537,7 +538,7 @@ class TypeModel(pt.Model):
 
     assert TypeModel.sql_types == {
         "a": "VARCHAR",
-        "b": "BIGINT",
+        "b": "INTEGER",
         "c": "DOUBLE",
         "d": "BOOLEAN",
     }
@@ -889,7 +890,7 @@ class EnumModel(pt.Model):
     )
     enum_df = db.table("enum_table").to_df()
     assert enum_df.frame_equal(pl.DataFrame({"enum_column": [10, 11, 12]}))
-    assert enum_df.dtypes == [pl.Int64]
+    assert enum_df.dtypes == [pl.Int32]
 
 
 def test_multiple_filters():
@@ -914,3 +915,126 @@ def test_string_representation_of_relation():
     relation = pt.Relation("select 1 as my_column")
     relation_str = str(relation)
     assert "my_column" in relation_str
+
+
+def test_cast():
+    """It should be able to cast to the correct SQL types based on the model."""
+
+    class Schema(pt.Model):
+        float_column: float
+
+    relation = pt.Relation("select 1 as float_column, 2 as other_column")
+    with pytest.raises(
+        TypeError,
+        match=(
+            r"Relation\.cast\(\) invoked without Relation.model having been set\! "
+            r"You should invoke Relation\.set_model\(\) first or explicitly provide "
+            r"a model to \.cast\(\)."
+ ), + ): + relation.cast() + + # Originally the type of both columns are integers + modeled_relation = relation.set_model(Schema) + assert modeled_relation.types["float_column"] == "INTEGER" + assert modeled_relation.types["other_column"] == "INTEGER" + + # The casted variant has converted the float column to double + casted_relation = relation.set_model(Schema).cast() + assert casted_relation.types["float_column"] == "DOUBLE" + # But kept the other as-is + assert casted_relation.types["other_column"] == "INTEGER" + + # You can either set the model with .set_model() or provide it to cast + assert ( + relation.set_model(Schema) + .cast() + .to_df() + .frame_equal(relation.cast(Schema).to_df()) + ) + + # Other types that should be considered compatible should be kept as-is + compatible_relation = pt.Relation("select 1::FLOAT as float_column") + assert compatible_relation.cast(Schema).types["float_column"] == "FLOAT" + + # Unless the strict parameter is specified + assert ( + compatible_relation.cast(Schema, strict=True).types["float_column"] == "DOUBLE" + ) + + # We can also specify a specific SQL type + class SpecificSQLTypeSchema(pt.Model): + float_column: float = pt.Field(sql_type="BIGINT") + + specific_cast_relation = relation.set_model(SpecificSQLTypeSchema).cast() + assert specific_cast_relation.types["float_column"] == "BIGINT" + + # Unknown types raise + class ObjectModel(pt.Model): + object_column: object + + with pytest.raises( + NotImplementedError, + match=r"No valid sql_type mapping found for column 'object_column'\.", + ): + pt.Relation("select 1 as object_column").set_model(ObjectModel).cast() + + # Check for more specific type annotations + class TotalModel(pt.Model): + timedelta_column: timedelta + date_column: date + null_column: None + + df = pt.DataFrame( + { + "timedelta_column": [timedelta(seconds=90)], + "date_column": [date(2022, 9, 4)], + "null_column": [None], + } + ) + casted_relation = pt.Relation(df, model=TotalModel).cast() + assert casted_relation.types == { + "timedelta_column": "INTERVAL", + "date_column": "DATE", + "null_column": "INTEGER", + } + assert casted_relation.to_df().frame_equal(df) + + # It is possible to only cast a subset + class MyModel(pt.Model): + column_1: float + column_2: float + + relation = pt.Relation("select 1 as column_1, 2 as column_2").set_model(MyModel) + assert relation.cast(include=[]).types == { + "column_1": "INTEGER", + "column_2": "INTEGER", + } + assert relation.cast(include=["column_1"]).types == { + "column_1": "DOUBLE", + "column_2": "INTEGER", + } + assert relation.cast(include=["column_1", "column_2"]).types == { + "column_1": "DOUBLE", + "column_2": "DOUBLE", + } + + assert relation.cast(exclude=[]).types == { + "column_1": "DOUBLE", + "column_2": "DOUBLE", + } + assert relation.cast(exclude=["column_1"]).types == { + "column_1": "INTEGER", + "column_2": "DOUBLE", + } + assert relation.cast(exclude=["column_1", "column_2"]).types == { + "column_1": "INTEGER", + "column_2": "INTEGER", + } + + # Providing both include and exclude should raise a value error + with pytest.raises( + ValueError, + match=r"Both include and exclude provided to Relation.cast\(\)\!", + ): + relation.cast(include=["column_1"], exclude=["column_2"])