Skip to content

Commit

Permalink
refactor: use pandas to store row data (#240)
Browse files Browse the repository at this point in the history
### Summary of Changes

In #214 we changed the implementation of `Row` so its data was stored in
a `polars.DataFrame`. As explained
in a comment on #196,
`pandas` works better for us for now. We might undo this change in the
future if the type inference of `polars` gets improved (or we decide to
implement this ourselves).

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
  • Loading branch information
lars-reimann and megalinter-bot authored Apr 22, 2023
1 parent 2b58c82 commit f2769f5
Show file tree
Hide file tree
Showing 11 changed files with 55 additions and 298 deletions.
102 changes: 2 additions & 100 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ pandas = "^2.0.0"
pillow = "^9.5.0"
scikit-learn = "^1.2.0"
seaborn = "^0.12.2"
polars = {extras = ["pandas", "pyarrow", "xlsx2csv"], version = "^0.17.5"}

[tool.poetry.group.dev.dependencies]
pytest = "^7.2.1"
Expand Down
30 changes: 17 additions & 13 deletions src/safeds/data/tabular/containers/_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any

import polars as pl
import pandas as pd

from safeds.data.tabular.exceptions import UnknownColumnNameError
from safeds.data.tabular.typing import ColumnType, Schema
Expand Down Expand Up @@ -54,13 +54,13 @@ def from_dict(data: dict[str, Any]) -> Row:
return Row(data)

@staticmethod
def _from_polars_dataframe(data: pl.DataFrame, schema: Schema | None = None) -> Row:
def _from_pandas_dataframe(data: pd.DataFrame, schema: Schema | None = None) -> Row:
"""
Create a row from a `polars.DataFrame`.
Create a row from a `pandas.DataFrame`.
Parameters
----------
data : polars.DataFrame
data : pd.DataFrame
The data.
schema : Schema | None
The schema. If None, the schema is inferred from the data.
Expand All @@ -72,16 +72,18 @@ def _from_polars_dataframe(data: pl.DataFrame, schema: Schema | None = None) ->
Examples
--------
>>> import polars as pl
>>> import pandas as pd
>>> from safeds.data.tabular.containers import Row
>>> row = Row._from_polars_dataframe(pl.DataFrame({"a": [1], "b": [2]}))
>>> row = Row._from_pandas_dataframe(pd.DataFrame({"a": [1], "b": [2]}))
"""
data = data.reset_index(drop=True)

result = object.__new__(Row)
result._data = data

if schema is None:
# noinspection PyProtectedMember
result._schema = Schema._from_polars_dataframe(data)
result._schema = Schema._from_pandas_dataframe(data)
else:
result._schema = schema

Expand All @@ -108,9 +110,11 @@ def __init__(self, data: Mapping[str, Any] | None = None):
if data is None:
data = {}

self._data: pl.DataFrame = pl.DataFrame(data)
data = {key: [value] for key, value in data.items()}

self._data: pd.DataFrame = pd.DataFrame(data)
# noinspection PyProtectedMember
self._schema: Schema = Schema._from_polars_dataframe(self._data)
self._schema: Schema = Schema._from_pandas_dataframe(self._data)

def __contains__(self, obj: Any) -> bool:
"""
Expand Down Expand Up @@ -169,7 +173,7 @@ def __eq__(self, other: Any) -> bool:
return NotImplemented
if self is other:
return True
return self._schema == other._schema and self._data.frame_equal(other._data)
return self._schema == other._schema and self._data.equals(other._data)

def __getitem__(self, column_name: str) -> Any:
"""
Expand Down Expand Up @@ -233,7 +237,7 @@ def __len__(self) -> int:
>>> len(row)
2
"""
return self._data.width
return self._data.shape[1]

def __repr__(self) -> str:
"""
Expand Down Expand Up @@ -319,7 +323,7 @@ def n_columns(self) -> int:
>>> row.n_columns
2
"""
return self._data.width
return self._data.shape[1]

@property
def schema(self) -> Schema:
Expand Down Expand Up @@ -372,7 +376,7 @@ def get_value(self, column_name: str) -> Any:
if not self.has_column(column_name):
raise UnknownColumnNameError([column_name])

return self._data[0, column_name]
return self._data.loc[0, column_name]

def has_column(self, column_name: str) -> bool:
"""
Expand Down
22 changes: 7 additions & 15 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
from IPython.core.display_functions import DisplayHandle, display
from pandas import DataFrame
Expand Down Expand Up @@ -211,7 +210,7 @@ def from_rows(rows: list[Row]) -> Table:
for row in rows:
if schema_compare != row._schema:
raise SchemaMismatchError
row_array.append(row._data.to_pandas())
row_array.append(row._data)

dataframe: DataFrame = pd.concat(row_array, ignore_index=True)
dataframe.columns = schema_compare.column_names
Expand Down Expand Up @@ -251,10 +250,7 @@ def __eq__(self, other: Any) -> bool:
return True
table1 = self.sort_columns()
table2 = other.sort_columns()
return table1._data.equals(table2._data) and table1._schema == table2._schema

def __hash__(self) -> int:
return hash(self._data)
return table1._schema == table2._schema and table1._data.equals(table2._data)

def __repr__(self) -> str:
tmp = self._data.copy(deep=True)
Expand Down Expand Up @@ -416,7 +412,7 @@ def get_row(self, index: int) -> Row:
if len(self._data.index) - 1 < index or index < 0:
raise IndexOutOfBoundsError(index)

return Row._from_polars_dataframe(pl.DataFrame(self._data.iloc[[index]]), self._schema)
return Row._from_pandas_dataframe(self._data.iloc[[index]], self._schema)

# ------------------------------------------------------------------------------------------------------------------
# Information
Expand Down Expand Up @@ -549,9 +545,7 @@ def add_row(self, row: Row) -> Table:
if self._schema != row.schema:
raise SchemaMismatchError

row_frame = row._data.to_pandas()

new_df = pd.concat([self._data, row_frame]).infer_objects()
new_df = pd.concat([self._data, row._data]).infer_objects()
new_df.columns = self.column_names
return Table(new_df)

Expand All @@ -576,9 +570,7 @@ def add_rows(self, rows: list[Row] | Table) -> Table:
if self._schema != row.schema:
raise SchemaMismatchError

row_frames = [row._data.to_pandas() for row in rows]
for row_frame in row_frames:
row_frame.columns = self.column_names
row_frames = (row._data for row in rows)

result = pd.concat([result, *row_frames]).infer_objects()
result.columns = self.column_names
Expand Down Expand Up @@ -1266,8 +1258,8 @@ def to_rows(self) -> list[Row]:
List of rows.
"""
return [
Row._from_polars_dataframe(
pl.DataFrame([list(series_row)], schema=self._schema.column_names),
Row._from_pandas_dataframe(
pd.DataFrame([list(series_row)], columns=self._schema.column_names),
self._schema,
)
for (_, series_row) in self._data.iterrows()
Expand Down
41 changes: 0 additions & 41 deletions src/safeds/data/tabular/typing/_column_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,6 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING

from polars import FLOAT_DTYPES as POLARS_FLOAT_DTYPES
from polars import INTEGER_DTYPES as POLARS_INTEGER_DTYPES
from polars import TEMPORAL_DTYPES as POLARS_TEMPORAL_DTYPES
from polars import Boolean as PolarsBoolean
from polars import Decimal as PolarsDecimal
from polars import Object as PolarsObject
from polars import PolarsDataType
from polars import Utf8 as PolarsUtf8

if TYPE_CHECKING:
import numpy as np

Expand Down Expand Up @@ -52,38 +43,6 @@ def _from_numpy_data_type(data_type: np.dtype) -> ColumnType:
message = f"Unsupported numpy data type '{data_type}'."
raise NotImplementedError(message)

@staticmethod
def _from_polars_data_type(data_type: PolarsDataType) -> ColumnType:
"""
Return the column type for a given `polars` data type.
Parameters
----------
data_type : PolarsDataType
The `polars` data type.
Returns
-------
column_type : ColumnType
The ColumnType.
Raises
------
NotImplementedError
If the given data type is not supported.
"""
if data_type in POLARS_INTEGER_DTYPES:
return Integer()
if data_type is PolarsBoolean:
return Boolean()
if data_type in POLARS_FLOAT_DTYPES or data_type is PolarsDecimal:
return RealNumber()
if data_type is PolarsUtf8 or data_type is PolarsObject or data_type in POLARS_TEMPORAL_DTYPES:
return String()

message = f"Unsupported polars data type '{data_type}'."
raise NotImplementedError(message)

@abstractmethod
def is_nullable(self) -> bool:
"""
Expand Down
Loading

0 comments on commit f2769f5

Please sign in to comment.