Skip to content
This repository has been archived by the owner on Nov 2, 2023. It is now read-only.

Commit

Permalink
Complete test coverage for DuckDB module
Browse files Browse the repository at this point in the history
  • Loading branch information
JakobGM committed Sep 1, 2022
1 parent 21445d5 commit 470e072
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 42 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ source = ["src", "*/site-packages"]
[tool.coverage.run]
branch = true
source = ["patito"]
# TODO: Remove once DuckDB-API becomes public
omit = ["src/patito/duckdb.py"]

[tool.coverage.report]
exclude_lines = [
Expand Down
65 changes: 43 additions & 22 deletions src/patito/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
Tuple[int, int, int],
tuple(map(int, pl.__version__.split("."))),
)
except ValueError:
except ValueError: # pragma: no cover
POLARS_VERSION = None


Expand Down Expand Up @@ -265,10 +265,10 @@ def __init__( # noqa: C901
relation = self.database.connection.from_csv_auto(str(derived_from))
else:
raise ValueError(
f"Unsupported file suffix {derived_from.suffix} for data import!"
f"Unsupported file suffix {derived_from.suffix!r} for data import!"
)
else:
raise TypeError
raise TypeError # pragma: no cover

self._relation = relation
if model is not None:
Expand Down Expand Up @@ -1139,7 +1139,7 @@ def project(
)
try:
relation = self._relation.project(projection)
except RuntimeError as exc:
except RuntimeError as exc: # pragma: no cover
# We might get a RunTime error if the enum type has not
# been created yet. If so, we create all enum types for
# this model.
Expand Down Expand Up @@ -1326,16 +1326,6 @@ def to_df(self) -> DataFrame:
# Here we do a star-select to work around certain weird issues with DuckDB
self._relation = self._relation.project("*")
arrow_table = cast(pa.lib.Table, self._relation.to_arrow_table())
if POLARS_VERSION and POLARS_VERSION <= (0, 13, 38):
# Fix for https://github.com/pola-rs/polars/issues/3500
schema = arrow_table.schema
for index, field in enumerate(schema):
if isinstance(field.type, pa.DictionaryType):
dict_field = field.with_type(
pa.dictionary(index_type=pa.int8(), value_type=pa.utf8())
)
schema = schema.set(index, dict_field)
arrow_table = arrow_table.cast(schema)
try:
return DataFrame._from_arrow(arrow_table)
except pa.ArrowInvalid: # pragma: no cover
Expand Down Expand Up @@ -1537,7 +1527,7 @@ def with_missing_defaultable_columns(
raise TypeError(
f"{class_name}.with_missing_default_columns() invoked without "
f"{class_name}.model having been set! "
"You should invoke {class_name}.set_model() first!"
f"You should invoke {class_name}.set_model() first!"
)
elif include is not None and exclude is not None:
raise TypeError("Both include and exclude provided at the same time!")
Expand All @@ -1559,7 +1549,7 @@ def with_missing_defaultable_columns(

try:
relation = self._relation.project(projection)
except Exception as exc:
except Exception as exc: # pragma: no cover
# We might get a RunTime error if the enum type has not
# been created yet. If so, we create all enum types for
# this model.
Expand Down Expand Up @@ -1625,7 +1615,7 @@ def with_missing_nullable_columns(
raise TypeError(
f"{class_name}.with_missing_nullable_columns() invoked without "
f"{class_name}.model having been set! "
"You should invoke {class_name}.set_model() first!"
f"You should invoke {class_name}.set_model() first!"
)
elif include is not None and exclude is not None:
raise TypeError("Both include and exclude provided at the same time!")
Expand All @@ -1645,7 +1635,7 @@ def with_missing_nullable_columns(

try:
relation = self._relation.project(projection)
except Exception as exc:
except Exception as exc: # pragma: no cover
# We might get a RunTime error if the enum type has not
# been created yet. If so, we create all enum types for
# this model.
Expand Down Expand Up @@ -2132,7 +2122,40 @@ def table(self, name: str) -> Relation:
│ 2 ┆ 4 │
└─────┴─────┘
"""
return Relation(self.connection.table(name))
return Relation(
self.connection.table(name),
database=self.from_connection(self.connection),
)

def view(self, name: str) -> Relation:
"""
Return relation representing all the data in the given view.
Args:
name: The name of the view.
Example:
>>> import patito as pt
>>> df = pt.DataFrame({"a": [1, 2], "b": [3, 4]})
>>> db = pt.Database()
>>> relation = db.to_relation(df)
>>> relation.create_view(name="my_view")
>>> db.view("my_view").to_df()
shape: (2, 2)
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
│ i64 ┆ i64 │
╞═════╪═════╡
│ 1 ┆ 3 │
├╌╌╌╌╌┼╌╌╌╌╌┤
│ 2 ┆ 4 │
└─────┴─────┘
"""
return Relation(
self.connection.view(name),
database=self.from_connection(self.connection),
)

def create_table(
self,
Expand Down Expand Up @@ -2217,7 +2240,7 @@ def create_enum_types(self, model: Type[ModelType]) -> None:
)
except duckdb.CatalogException as e:
if "already exists" not in str(e):
raise e
raise e # pragma: no cover
self.enum_types.add(enum_type_name)

def create_view(
Expand Down Expand Up @@ -2245,10 +2268,8 @@ def __getattr__(
"from_parquet",
"from_query",
"query",
"table",
"table_function",
"values",
"view",
}
if name in relation_methods:
return lambda *args, **kwargs: Relation(
Expand Down
4 changes: 3 additions & 1 deletion src/patito/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ def sql_types( # type: ignore

types = {}
for column, props in cls._schema_properties().items():
if "enum" in props:
if "enum" in props and all(
isinstance(variant, str) for variant in props["enum"]
):
types[column] = _enum_type_name(field_properties=props)
else:
types[column] = PYDANTIC_TO_DUCKDB_TYPES[props["type"]]
Expand Down
20 changes: 20 additions & 0 deletions tests/test_duckdb/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ class Model(BaseModel):
]


def test_create_view():
"""It should be able to create a view from a relation source."""
db = pt.Database()
df = pt.DataFrame({"a": [1, 2], "b": [3.0, 4.0]})
db.create_view(name="my_view", data=df)
assert db.view("my_view").to_df().frame_equal(df)


def test_validate_non_nullable_enum_columns():
"""Enum columns should be null-validated."""

Expand Down Expand Up @@ -197,6 +205,18 @@ class Model(pt.Model):
assert "test_table" in db


def test_creating_enums_several_tiems():
"""Enums should be able to be defined several times."""

class EnumModel(pt.Model):
enum_column: Literal["a", "b", "c"]

db = pt.Database()
db.create_enum_types(EnumModel)
db.enum_types = set()
db.create_enum_types(EnumModel)


def test_use_of_same_enum_types_from_literal_annotation():
"""Identical literals should get the same DuckDB SQL enum type."""

Expand Down
Loading

0 comments on commit 470e072

Please sign in to comment.