From d8cab68dddde79a972ea6a46992a8ff266f6dde1 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Sun, 4 Aug 2024 12:51:22 +0200 Subject: [PATCH 01/10] fix to _df_all and _df_concat --- mesa_frames/concrete/polars/mixin.py | 32 ++++++++++++---------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index bae9b532..05eb89fc 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -1,11 +1,9 @@ -from collections.abc import Collection, Iterator, Sequence +from collections.abc import Collection, Hashable, Iterator, Sequence from typing import Literal import polars as pl from typing_extensions import Any, overload -from collections.abc import Hashable - from mesa_frames.abstract.mixin import DataFrameMixin from mesa_frames.types_ import PolarsMask @@ -64,20 +62,11 @@ def _df_add( def _df_all( self, df: pl.DataFrame, - name: str, - axis: str = "columns", - index_cols: str | None = None, - ) -> pl.DataFrame: - if axis == "index": - return df.group_by(index_cols).agg(pl.all().all().alias(index_cols)) - return df.select(pl.all().all()) - - def _df_with_columns( - self, original_df: pl.DataFrame, new_columns: list[str], data: Any - ) -> pl.DataFrame: - return original_df.with_columns( - **{col: value for col, value in zip(new_columns, data)} - ) + axis: Literal["index", "columns"] = "columns", + ) -> pl.Series: + if axis == "columns": + return df.select(pl.col("*").all()).to_series() + return df.with_columns(all=pl.all_horizontal())["all"] def _df_column_names(self, df: pl.DataFrame) -> list[str]: return df.columns @@ -113,7 +102,7 @@ def _df_concat( objs: Collection[pl.DataFrame], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, - index_cols: str | None = None, + index_cols: str | list[str] | None = None, ) -> pl.DataFrame: ... @overload @@ -471,6 +460,13 @@ def _df_set_index( return df return df.with_columns(index_name=new_index) + def _df_with_columns( + self, original_df: pl.DataFrame, new_columns: list[str], data: Any + ) -> pl.DataFrame: + return original_df.with_columns( + **{col: value for col, value in zip(new_columns, data)} + ) + def _srs_constructor( self, data: Sequence[Any] | None = None, From 98bfe1128f94b8b8a4cc5f1ccaef06b0c40a01aa Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Sun, 4 Aug 2024 14:03:22 +0200 Subject: [PATCH 02/10] adding right overload to concat --- mesa_frames/abstract/mixin.py | 16 ++++++++++++---- mesa_frames/concrete/pandas/mixin.py | 11 ++++++++++- mesa_frames/concrete/polars/mixin.py | 13 +++++++++++-- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/mesa_frames/abstract/mixin.py b/mesa_frames/abstract/mixin.py index 0f3599fa..ba90834d 100644 --- a/mesa_frames/abstract/mixin.py +++ b/mesa_frames/abstract/mixin.py @@ -1,12 +1,10 @@ from abc import ABC, abstractmethod -from collections.abc import Collection, Iterator, Sequence +from collections.abc import Collection, Hashable, Iterator, Sequence from copy import copy, deepcopy from typing import Literal from typing_extensions import Any, Self, overload -from collections.abc import Hashable - from mesa_frames.types_ import BoolSeries, DataFrame, Index, Mask, Series @@ -185,11 +183,21 @@ def _df_combine_first( def _df_concat( self, objs: Collection[Series], - how: Literal["horizontal"] | Literal["vertical"] = "vertical", + how: Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | None = None, ) -> Series: ... + @overload + @abstractmethod + def _df_concat( + self, + objs: Collection[Series], + how: Literal["horizontal"] = "horizontal", + ignore_index: bool = False, + index_cols: str | None = None, + ) -> DataFrame: ... + @overload @abstractmethod def _df_concat( diff --git a/mesa_frames/concrete/pandas/mixin.py b/mesa_frames/concrete/pandas/mixin.py index a5f99c64..bffc225b 100644 --- a/mesa_frames/concrete/pandas/mixin.py +++ b/mesa_frames/concrete/pandas/mixin.py @@ -63,7 +63,16 @@ def _df_concat( def _df_concat( self, objs: Collection[pd.Series], - how: Literal["horizontal"] | Literal["vertical"] = "vertical", + how: Literal["horizontal"] = "horizontal", + ignore_index: bool = False, + index_cols: str | None = None, + ) -> pd.DataFrame: ... + + @overload + def _df_concat( + self, + objs: Collection[pd.Series], + how: Literal["vertical"] = "vertical", ignore_index: bool = False, index_cols: str | None = None, ) -> pd.Series: ... diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 05eb89fc..75295187 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -109,11 +109,20 @@ def _df_concat( def _df_concat( self, objs: Collection[pl.Series], - how: Literal["horizontal"] | Literal["vertical"] = "vertical", + how: Literal["vertical"] = "vertical", ignore_index: bool = False, - index_cols: str | None = None, + index_cols: str | list[str] | None = None, ) -> pl.Series: ... + @overload + def _df_concat( + self, + objs: Collection[pl.Series], + how: Literal["horizontal"] = "horizontal", + ignore_index: bool = False, + index_cols: str | list[str] | None = None, + ) -> pl.DataFrame: ... + def _df_concat( self, objs: Collection[pl.DataFrame] | Collection[pl.Series], From 77ad3799e79fd204df78f705081b43870a5569b1 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Sun, 4 Aug 2024 18:02:03 +0200 Subject: [PATCH 03/10] change _df_filter to _df_get_masked_df + _df_all --- mesa_frames/abstract/mixin.py | 17 ++++------------- mesa_frames/abstract/space.py | 10 ++++------ mesa_frames/concrete/pandas/mixin.py | 27 +++++++-------------------- mesa_frames/concrete/polars/mixin.py | 1 + 4 files changed, 16 insertions(+), 39 deletions(-) diff --git a/mesa_frames/abstract/mixin.py b/mesa_frames/abstract/mixin.py index ba90834d..d5c03e02 100644 --- a/mesa_frames/abstract/mixin.py +++ b/mesa_frames/abstract/mixin.py @@ -165,10 +165,9 @@ def _df_add( def _df_all( self, df: DataFrame, - name: str, + name: str = "all", axis: str = "columns", - index_cols: str | list[str] | None = None, - ) -> DataFrame: ... + ) -> Series: ... @abstractmethod def _df_column_names(self, df: DataFrame) -> list[str]: ... @@ -259,19 +258,11 @@ def _df_drop_duplicates( keep: Literal["first", "last", False] = "first", ) -> DataFrame: ... - @abstractmethod - def _df_filter( - self, - df: DataFrame, - condition: BoolSeries, - all: bool = True, - ) -> DataFrame: ... - @abstractmethod def _df_get_bool_mask( self, df: DataFrame, - index_cols: str | list[str], + index_cols: str | list[str] | None = None, mask: Mask | None = None, negate: bool = False, ) -> BoolSeries: ... @@ -280,7 +271,7 @@ def _df_get_bool_mask( def _df_get_masked_df( self, df: DataFrame, - index_cols: str, + index_cols: str | list[str] | None = None, mask: Mask | None = None, columns: str | list[str] | None = None, negate: bool = False, diff --git a/mesa_frames/abstract/space.py b/mesa_frames/abstract/space.py index 6fe6ef09..770916c6 100644 --- a/mesa_frames/abstract/space.py +++ b/mesa_frames/abstract/space.py @@ -1341,8 +1341,8 @@ def get_neighborhood( radius_df, on=self._center_col_names, ) - neighbors_df = self._df_filter( - neighbors_df, neighbors_df["radius"] <= neighbors_df["max_radius"] + neighbors_df = self._df_get_masked_df( + neighbors_df, mask=neighbors_df["radius"] <= neighbors_df["max_radius"] ) neighbors_df = self._df_drop_columns(neighbors_df, "max_radius") @@ -1357,13 +1357,12 @@ def get_neighborhood( neighbors_df = self._df_drop_duplicates(neighbors_df, self._pos_col_names) # Filter out-of-bound neighbors - neighbors_df = self._df_filter( + neighbors_df = self._df_get_masked_df( neighbors_df, - ( + mask=self._df_all( (neighbors_df[self._pos_col_names] < self._dimensions) & (neighbors_df >= 0) ), - all=True, ) if include_center: @@ -1428,7 +1427,6 @@ def out_of_bounds(self, pos: GridCoordinate | GridCoordinates) -> DataFrame: out_of_bounds = self._df_all( (pos_df < 0) | (pos_df >= self._dimensions), name="out_of_bounds", - index_cols=self._pos_col_names, ) return self._df_concat(objs=[pos_df, out_of_bounds], how="horizontal") diff --git a/mesa_frames/concrete/pandas/mixin.py b/mesa_frames/concrete/pandas/mixin.py index bffc225b..2593ab08 100644 --- a/mesa_frames/concrete/pandas/mixin.py +++ b/mesa_frames/concrete/pandas/mixin.py @@ -1,8 +1,6 @@ -from collections.abc import Collection, Iterator, Sequence +from collections.abc import Collection, Hashable, Iterator, Sequence from typing import Literal -from collections.abc import Hashable - import numpy as np import pandas as pd from typing_extensions import Any, overload @@ -24,11 +22,10 @@ def _df_add( def _df_all( self, df: pd.DataFrame, - name: str, + name: str = "all", axis: str = "columns", - index_cols: str | list[str] | None = None, - ) -> pd.DataFrame: - return df.all(axis).to_frame(name) + ) -> pd.Series: + return df.all(axis).rename(name) def _df_column_names(self, df: pd.DataFrame) -> list[str]: return df.columns.tolist() + df.index.names @@ -116,16 +113,6 @@ def _df_contains( return pd.Series(values).isin(df.index) return pd.Series(values).isin(df[column]) - def _df_filter( - self, - df: pd.DataFrame, - condition: pd.DataFrame, - all: bool = True, - ) -> pd.DataFrame: - if all and isinstance(condition, pd.DataFrame): - return df[condition.all(axis=1)] - return df[condition] - def _df_div( self, df: pd.DataFrame, @@ -153,7 +140,7 @@ def _df_drop_duplicates( def _df_get_bool_mask( self, df: pd.DataFrame, - index_cols: str | list[str], + index_cols: str | list[str] | None = None, mask: PandasMask = None, negate: bool = False, ) -> pd.Series: @@ -162,7 +149,7 @@ def _df_get_bool_mask( isinstance(index_cols, list) and df.index.names == index_cols ): srs = df.index - else: + elif index_cols is not None: srs = df.set_index(index_cols).index if isinstance(mask, pd.Series) and mask.dtype == bool and len(mask) == len(df): mask.index = df.index @@ -190,7 +177,7 @@ def _df_get_bool_mask( def _df_get_masked_df( self, df: pd.DataFrame, - index_cols: str, + index_cols: str | list[str] | None = None, mask: PandasMask | None = None, columns: str | list[str] | None = None, negate: bool = False, diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 75295187..34d78b5e 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -62,6 +62,7 @@ def _df_add( def _df_all( self, df: pl.DataFrame, + name: str = "all", axis: Literal["index", "columns"] = "columns", ) -> pl.Series: if axis == "columns": From 9f757c8f257ed17b1b6514401ad7a207c54e07ef Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Sun, 4 Aug 2024 18:16:21 +0200 Subject: [PATCH 04/10] adding custom name to _df_groupby_cumcount --- mesa_frames/abstract/mixin.py | 4 +--- mesa_frames/abstract/space.py | 2 +- mesa_frames/concrete/pandas/mixin.py | 6 ++++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mesa_frames/abstract/mixin.py b/mesa_frames/abstract/mixin.py index d5c03e02..cc843d08 100644 --- a/mesa_frames/abstract/mixin.py +++ b/mesa_frames/abstract/mixin.py @@ -279,9 +279,7 @@ def _df_get_masked_df( @abstractmethod def _df_groupby_cumcount( - self, - df: DataFrame, - by: str | list[str], + self, df: DataFrame, by: str | list[str], name: str = "cum_count" ) -> Series: ... @abstractmethod diff --git a/mesa_frames/abstract/space.py b/mesa_frames/abstract/space.py index 770916c6..b6942b40 100644 --- a/mesa_frames/abstract/space.py +++ b/mesa_frames/abstract/space.py @@ -1263,7 +1263,7 @@ def get_neighborhood( radius_df = self._srs_to_df(radius_srs) radius_df = self._df_with_columns( radius_df, - self._df_groupby_cumcount(radius_df, "radius") + 1, + self._df_groupby_cumcount(radius_df, "radius", name="offset") + 1, new_columns="offset", ) diff --git a/mesa_frames/concrete/pandas/mixin.py b/mesa_frames/concrete/pandas/mixin.py index 2593ab08..d562fa03 100644 --- a/mesa_frames/concrete/pandas/mixin.py +++ b/mesa_frames/concrete/pandas/mixin.py @@ -187,8 +187,10 @@ def _df_get_masked_df( return df.loc[b_mask, columns] return df.loc[b_mask] - def _df_groupby_cumcount(self, df: pd.DataFrame, by: str | list[str]) -> pd.Series: - return df.groupby(by).cumcount() + def _df_groupby_cumcount( + self, df: pd.DataFrame, by: str | list[str], name: str = "cum_count" + ) -> pd.Series: + return df.groupby(by).cumcount().rename(name) def _df_iterator(self, df: pd.DataFrame) -> Iterator[dict[str, Any]]: for index, row in df.iterrows(): From 3186a126e8fef962c34e96a6edd7d6c02b6bc392 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:40:23 +0200 Subject: [PATCH 05/10] adding tests for PolarsMixin --- tests/polars/test_mixin_polars.py | 617 ++++++++++++++++++++++++++++++ 1 file changed, 617 insertions(+) create mode 100644 tests/polars/test_mixin_polars.py diff --git a/tests/polars/test_mixin_polars.py b/tests/polars/test_mixin_polars.py new file mode 100644 index 00000000..d297fbb0 --- /dev/null +++ b/tests/polars/test_mixin_polars.py @@ -0,0 +1,617 @@ +import numpy as np +import pandas as pd +import polars as pl +import pytest +import typeguard as tg + +from mesa_frames.concrete.polars.mixin import PolarsMixin + + +@tg.typechecked +class TestPolarsMixin: + @pytest.fixture + def mixin(self): + return PolarsMixin() + + @pytest.fixture + def df_0(self): + return pl.DataFrame( + { + "unique_id": ["x", "y", "z"], + "A": [1, 2, 3], + "B": ["a", "b", "c"], + "C": [True, False, True], + "D": [1, 2, 3], + }, + ) + + @pytest.fixture + def df_1(self): + return pl.DataFrame( + { + "unique_id": ["z", "a", "b"], + "A": [4, 5, 6], + "B": ["d", "e", "f"], + "C": [False, True, False], + "E": [1, 2, 3], + }, + ) + + def test_df_add(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + # Test adding a DataFrame and a sequence element-wise along the rows (axis='index') + result = mixin._df_add(df_0[["A", "D"]], df_1["A"], axis="index") + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [5, 7, 9] + assert result["D"].to_list() == [5, 7, 9] + + # Test adding a DataFrame and a sequence element-wise along the column (axis='columns') + result = mixin._df_add(df_0[["A", "D"]], [1, 2], axis="columns") + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [2, 3, 4] + assert result["D"].to_list() == [3, 4, 5] + + # Test adding DataFrames with index-column alignment + df_1 = df_1.with_columns(D=pl.col("E")) + result = mixin._df_add( + df_0[["unique_id", "A", "D"]], + df_1[["unique_id", "A", "D"]], + axis="index", + index_cols="unique_id", + ) + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [None, None, 7] + assert result["D"].to_list() == [None, None, 4] + + def test_df_all(self, mixin: PolarsMixin): + df = pl.DataFrame( + { + "A": [True, False, True], + "B": [True, True, True], + } + ) + + # Test with axis='index' + result = mixin._df_all(df["A", "B"], axis="index") + assert isinstance(result, pl.Series) + assert result.name == "all" + assert result.to_list() == [True, False, True] + + # Test with axis='columns' + result = mixin._df_all(df["A", "B"], axis="columns") + assert isinstance(result, pl.Series) + assert result.name == "all" + assert result.to_list() == [False, True] + + def test_df_column_names(self, mixin: PolarsMixin, df_0: pl.DataFrame): + cols = mixin._df_column_names(df_0) + assert isinstance(cols, list) + assert all(isinstance(c, str) for c in cols) + assert set(mixin._df_column_names(df_0)) == {"unique_id", "A", "B", "C", "D"} + + def test_df_combine_first( + self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + ): + # Test with df_0 and df_1 + result = mixin._df_combine_first(df_0, df_1, "unique_id") + result = result.sort("A") + assert isinstance(result, pl.DataFrame) + assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"} + assert result["unique_id"].to_list() == ["x", "y", "z", "a", "b"] + assert result["A"].to_list() == [1, 2, 3, 5, 6] + assert result["B"].to_list() == ["a", "b", "c", "e", "f"] + assert result["C"].to_list() == [True, False, True, True, False] + assert result["D"].to_list() == [1, 2, 3, None, None] + assert result["E"].to_list() == [None, None, 1, 2, 3] + + # Test with df_1 and df_0 + result = mixin._df_combine_first(df_1, df_0, "unique_id") + result = result.sort("E", nulls_last=True) + assert isinstance(result, pl.DataFrame) + assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"} + assert result["unique_id"].to_list() == ["z", "a", "b", "x", "y"] + assert result["A"].to_list() == [4, 5, 6, 1, 2] + assert result["B"].to_list() == ["d", "e", "f", "a", "b"] + assert result["C"].to_list() == [False, True, False, True, False] + assert result["D"].to_list() == [3, None, None, 1, 2] + assert result["E"].to_list() == [1, 2, 3, None, None] + + def test_df_concat( + self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame + ): + ### Test vertical concatenation + ## With DataFrames + for ignore_index in [False, True]: + vertical = mixin._df_concat( + [df_0, df_1], how="vertical", ignore_index=ignore_index + ) + assert isinstance(vertical, pl.DataFrame) + assert vertical.columns == ["unique_id", "A", "B", "C", "D", "E"] + assert len(vertical) == 6 + assert vertical["unique_id"].to_list() == ["x", "y", "z", "z", "a", "b"] + assert vertical["A"].to_list() == [1, 2, 3, 4, 5, 6] + assert vertical["B"].to_list() == ["a", "b", "c", "d", "e", "f"] + assert vertical["C"].to_list() == [True, False, True, False, True, False] + assert vertical["D"].to_list() == [1, 2, 3, None, None, None] + assert vertical["E"].to_list() == [None, None, None, 1, 2, 3] + + ## With Series + for ignore_index in [True, False]: + vertical = mixin._df_concat( + [df_0["A"], df_1["A"]], how="vertical", ignore_index=ignore_index + ) + assert isinstance(vertical, pl.Series) + assert len(vertical) == 6 + assert vertical.to_list() == [1, 2, 3, 4, 5, 6] + assert vertical.name == "A" + + ## Test horizontal concatenation + ## With DataFrames + # Error With same column names + with pytest.raises(pl.exceptions.DuplicateError): + mixin._df_concat([df_0, df_1], how="horizontal") + # With ignore_index = False + df_1 = df_1.rename(lambda c: f"{c}_1") + horizontal = mixin._df_concat([df_0, df_1], how="horizontal") + assert isinstance(horizontal, pl.DataFrame) + assert horizontal.columns == [ + "unique_id", + "A", + "B", + "C", + "D", + "unique_id_1", + "A_1", + "B_1", + "C_1", + "E_1", + ] + assert len(horizontal) == 3 + assert horizontal["unique_id"].to_list() == ["x", "y", "z"] + assert horizontal["A"].to_list() == [1, 2, 3] + assert horizontal["B"].to_list() == ["a", "b", "c"] + assert horizontal["C"].to_list() == [True, False, True] + assert horizontal["D"].to_list() == [1, 2, 3] + assert horizontal["unique_id_1"].to_list() == ["z", "a", "b"] + assert horizontal["A_1"].to_list() == [4, 5, 6] + assert horizontal["B_1"].to_list() == ["d", "e", "f"] + assert horizontal["C_1"].to_list() == [False, True, False] + assert horizontal["E_1"].to_list() == [1, 2, 3] + + # With ignore_index = True + horizontal_ignore_index = mixin._df_concat( + [df_0, df_1], + how="horizontal", + ignore_index=True, + ) + assert isinstance(horizontal_ignore_index, pl.DataFrame) + assert horizontal_ignore_index.columns == [ + "0", + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + ] + assert len(horizontal_ignore_index) == 3 + assert horizontal_ignore_index["0"].to_list() == ["x", "y", "z"] + assert horizontal_ignore_index["1"].to_list() == [1, 2, 3] + assert horizontal_ignore_index["2"].to_list() == ["a", "b", "c"] + assert horizontal_ignore_index["3"].to_list() == [True, False, True] + assert horizontal_ignore_index["4"].to_list() == [1, 2, 3] + assert horizontal_ignore_index["5"].to_list() == ["z", "a", "b"] + assert horizontal_ignore_index["6"].to_list() == [4, 5, 6] + assert horizontal_ignore_index["7"].to_list() == ["d", "e", "f"] + assert horizontal_ignore_index["8"].to_list() == [False, True, False] + assert horizontal_ignore_index["9"].to_list() == [1, 2, 3] + + ## With Series + # With ignore_index = False + horizontal = mixin._df_concat( + [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=False + ) + assert isinstance(horizontal, pl.DataFrame) + assert horizontal.columns == ["A", "B_1"] + assert len(horizontal) == 3 + assert horizontal["A"].to_list() == [1, 2, 3] + assert horizontal["B_1"].to_list() == ["d", "e", "f"] + + # With ignore_index = True + horizontal = mixin._df_concat( + [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=True + ) + assert isinstance(horizontal, pl.DataFrame) + assert horizontal.columns == ["0", "1"] + assert len(horizontal) == 3 + assert horizontal["0"].to_list() == [1, 2, 3] + assert horizontal["1"].to_list() == ["d", "e", "f"] + + def test_df_constructor(self, mixin: PolarsMixin): + # Test with dictionary + data = {"num": [1, 2, 3], "letter": ["a", "b", "c"]} + df = mixin._df_constructor(data) + assert isinstance(df, pl.DataFrame) + assert list(df.columns) == ["num", "letter"] + assert df["num"].to_list() == [1, 2, 3] + assert df["letter"].to_list() == ["a", "b", "c"] + + # Test with list of lists + data = [[1, "a"], [2, "b"], [3, "c"]] + df = mixin._df_constructor( + data, columns=["num", "letter"], dtypes={"num": "int64"} + ) + assert isinstance(df, pl.DataFrame) + assert list(df.columns) == ["num", "letter"] + assert df["num"].dtype == pl.Int64 + assert df["num"].to_list() == [1, 2, 3] + assert df["letter"].to_list() == ["a", "b", "c"] + + # Test with pandas DataFrame + data = pd.DataFrame({"num": [1, 2, 3], "letter": ["a", "b", "c"]}) + df = mixin._df_constructor(data) + assert isinstance(df, pl.DataFrame) + assert list(df.columns) == ["num", "letter"] + assert df["num"].to_list() == [1, 2, 3] + assert df["letter"].to_list() == ["a", "b", "c"] + + def test_df_contains(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # Test with list + result = mixin._df_contains(df_0, "A", [5, 2, 3]) + assert isinstance(result, pl.Series) + assert result.name == "contains" + assert result.to_list() == [False, True, True] + + def test_df_div(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + # Test dividing the DataFrame by a sequence element-wise along the rows (axis='index') + result = mixin._df_div(df_0[["A", "D"]], df_1["A"], axis="index") + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [0.25, 0.4, 0.5] + assert result["D"].to_list() == [0.25, 0.4, 0.5] + + # Test dividing the DataFrame by a sequence element-wise along the columns (axis='columns') + result = mixin._df_div(df_0[["A", "D"]], [1, 2], axis="columns") + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [1, 2, 3] + assert result["D"].to_list() == [0.5, 1, 1.5] + + # Test dividing DataFrames with index-column alignment + df_1 = df_1.with_columns(D=pl.col("E")) + result = mixin._df_div( + df_0[["unique_id", "A", "D"]], + df_1[["unique_id", "A", "D"]], + axis="index", + index_cols="unique_id", + ) + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [None, None, 0.75] + assert result["D"].to_list() == [None, None, 3] + + def test_df_drop_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # Test with str + dropped = mixin._df_drop_columns(df_0, "A") + assert isinstance(dropped, pl.DataFrame) + assert dropped.columns == ["unique_id", "B", "C", "D"] + # Test with list + dropped = mixin._df_drop_columns(df_0, ["A", "C"]) + assert dropped.columns == ["unique_id", "B", "D"] + + def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.DataFrame): + new_df = pl.concat([df_0, df_0], how="vertical") + assert len(new_df) == 6 + + # Test with all columns + dropped = mixin._df_drop_duplicates(new_df) + assert isinstance(dropped, pl.DataFrame) + assert len(dropped) == 3 + assert dropped.columns == ["unique_id", "A", "B", "C", "D"] + + # Test with subset (str) + other_df = pl.DataFrame( + { + "unique_id": ["x", "y", "z"], + "A": [1, 2, 3], + "B": ["d", "e", "f"], + "C": [True, True, False], + "D": [1, 2, 3], + }, + ) + new_df = pl.concat([df_0, other_df], how="vertical") + dropped = mixin._df_drop_duplicates(new_df, subset="unique_id") + assert isinstance(dropped, pl.DataFrame) + assert len(dropped) == 3 + + # Test with subset (list) + dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"]) + assert isinstance(dropped, pl.DataFrame) + assert len(dropped) == 5 + assert dropped.columns == ["unique_id", "A", "B", "C", "D"] + assert dropped["B"].to_list() == ["a", "b", "c", "e", "f"] + + # Test with subset (list) and keep='last' + dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep="last") + assert isinstance(dropped, pl.DataFrame) + assert len(dropped) == 5 + assert dropped.columns == ["unique_id", "A", "B", "C", "D"] + assert dropped["B"].to_list() == ["d", "b", "c", "e", "f"] + + # Test with subset (list) and keep=False + dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep=False) + assert isinstance(dropped, pl.DataFrame) + assert len(dropped) == 4 + assert dropped.columns == ["unique_id", "A", "B", "C", "D"] + assert dropped["B"].to_list() == ["b", "c", "e", "f"] + + def test_df_get_bool_mask(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # Test with pl.Series[bool] + mask = mixin._df_get_bool_mask(df_0, "A", pl.Series([True, False, True])) + assert mask.to_list() == [True, False, True] + + # Test with DataFrame + mask_df = pl.DataFrame({"A": [1, 3]}) + mask = mixin._df_get_bool_mask(df_0, "A", mask_df) + assert mask.to_list() == [True, False, True] + + # Test with single value + mask = mixin._df_get_bool_mask(df_0, "A", 1) + assert mask.to_list() == [True, False, False] + + # Test with list of values + mask = mixin._df_get_bool_mask(df_0, "A", [1, 3]) + assert mask.to_list() == [True, False, True] + + # Test with negate=True + mask = mixin._df_get_bool_mask(df_0, "A", [1, 3], negate=True) + assert mask.to_list() == [False, True, False] + + def test_df_get_masked_df(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # Test with pl.Series[bool] + masked_df = mixin._df_get_masked_df(df_0, "A", pl.Series([True, False, True])) + assert masked_df["A"].to_list() == [1, 3] + assert masked_df["unique_id"].to_list() == ["x", "z"] + + # Test with DataFrame + mask_df = pl.DataFrame({"A": [1, 3]}) + masked_df = mixin._df_get_masked_df(df_0, "A", mask_df) + assert masked_df["A"].to_list() == [1, 3] + assert masked_df["unique_id"].to_list() == ["x", "z"] + + # Test with single value + masked_df = mixin._df_get_masked_df(df_0, "A", 1) + assert masked_df["A"].to_list() == [1] + assert masked_df["unique_id"].to_list() == ["x"] + + # Test with list of values + masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3]) + assert masked_df["A"].to_list() == [1, 3] + assert masked_df["unique_id"].to_list() == ["x", "z"] + + # Test with columns + masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3], columns=["B"]) + assert list(masked_df.columns) == ["B"] + assert masked_df["B"].to_list() == ["a", "c"] + + # Test with negate=True + masked = mixin._df_get_masked_df(df_0, "A", [1, 3], negate=True) + assert len(masked) == 1 + + def test_df_groupby_cumcount(self, df_0: pl.DataFrame, mixin: PolarsMixin): + result = mixin._df_groupby_cumcount(df_0, "C") + assert result.to_list() == [1, 1, 2] + + def test_df_iterator(self, mixin: PolarsMixin, df_0: pl.DataFrame): + iterator = mixin._df_iterator(df_0) + first_item = next(iterator) + assert first_item == {"unique_id": "x", "A": 1, "B": "a", "C": True, "D": 1} + + def test_df_join(self, mixin: PolarsMixin): + left = pl.DataFrame({"A": [1, 2], "B": ["a", "b"]}) + right = pl.DataFrame({"A": [1, 3], "C": ["x", "y"]}) + + # Test with 'on' (left join) + joined = mixin._df_join(left, right, on="A") + assert set(joined.columns) == {"A", "B", "C"} + assert joined["A"].to_list() == [1, 2] + + # Test with 'left_on' and 'right_on' (left join) + right_1 = pl.DataFrame({"D": [1, 2], "C": ["x", "y"]}) + joined = mixin._df_join(left, right_1, left_on="A", right_on="D") + assert set(joined.columns) == {"A", "B", "C"} + assert joined["A"].to_list() == [1, 2] + + # Test with 'right' join + joined = mixin._df_join(left, right, on="A", how="right") + assert set(joined.columns) == {"A", "B", "C"} + assert joined["A"].to_list() == [1, 3] + + # Test with 'inner' join + joined = mixin._df_join(left, right, on="A", how="inner") + assert set(joined.columns) == {"A", "B", "C"} + assert joined["A"].to_list() == [1] + + # Test with 'outer' join + joined = mixin._df_join(left, right, on="A", how="outer") + assert set(joined.columns) == {"A", "B", "A_right", "C"} + assert joined["A"].to_list() == [1, None, 2] + assert joined["A_right"].to_list() == [1, 3, None] + + # Test with 'cross' join + joined = mixin._df_join(left, right, how="cross") + assert set(joined.columns) == {"A", "B", "A_right", "C"} + assert len(joined) == 4 + assert joined.row(0) == (1, "a", 1, "x") + assert joined.row(1) == (1, "a", 3, "y") + assert joined.row(2) == (2, "b", 1, "x") + assert joined.row(3) == (2, "b", 3, "y") + + # Test with different 'suffix' + joined = mixin._df_join(left, right, suffix="_r", how="cross") + assert set(joined.columns) == {"A", "B", "A_r", "C"} + assert len(joined) == 4 + assert joined.row(0) == (1, "a", 1, "x") + assert joined.row(1) == (1, "a", 3, "y") + assert joined.row(2) == (2, "b", 1, "x") + assert joined.row(3) == (2, "b", 3, "y") + + def test_df_mul(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame): + # Test multiplying the DataFrame by a sequence element-wise along the rows (axis='index') + result = mixin._df_mul(df_0[["A", "D"]], df_1["A"], axis="index") + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [4, 10, 18] + assert result["D"].to_list() == [4, 10, 18] + + # Test multiplying the DataFrame by a sequence element-wise along the columns (axis='columns') + result = mixin._df_mul(df_0[["A", "D"]], [1, 2], axis="columns") + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [1, 2, 3] + assert result["D"].to_list() == [2, 4, 6] + + # Test multiplying DataFrames with index-column alignment + df_1 = df_1.with_columns(D=pl.col("E")) + result = mixin._df_mul( + df_0[["unique_id", "A", "D"]], + df_1[["unique_id", "A", "D"]], + axis="index", + index_cols="unique_id", + ) + assert isinstance(result, pl.DataFrame) + assert result["A"].to_list() == [None, None, 12] + assert result["D"].to_list() == [None, None, 3] + + def test_df_norm(self, mixin: PolarsMixin): + df = pl.DataFrame({"A": [3, 4], "B": [4, 3]}) + # If include_cols = False + norm = mixin._df_norm(df) + assert isinstance(norm, pl.Series) + assert len(norm) == 2 + assert norm[0] == 5 + assert norm[1] == 5 + + # If include_cols = True + norm = mixin._df_norm(df, include_cols=True) + assert isinstance(norm, pl.DataFrame) + assert len(norm) == 2 + assert norm.columns == ["A", "B", "norm"] + assert norm.row(0, named=True)["norm"] == 5 + assert norm.row(1, named=True)["norm"] == 5 + + def test_df_rename_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): + renamed = mixin._df_rename_columns(df_0, ["A", "B"], ["X", "Y"]) + assert renamed.columns == ["unique_id", "X", "Y", "C", "D"] + + def test_df_reset_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # with drop = False + new_df = mixin._df_reset_index(df_0) + assert mixin._df_all(new_df == df_0).all() + + # with drop = True + new_df = mixin._df_reset_index(df_0, index_cols="unique_id", drop=True) + assert new_df.columns == ["A", "B", "C", "D"] + assert len(new_df) == len(df_0) + for col in new_df.columns: + assert (new_df[col] == df_0[col]).all() + + def test_df_remove(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # Test with list + removed = mixin._df_remove(df_0, [1, 3], "A") + assert len(removed) == 1 + assert removed["unique_id"].to_list() == ["y"] + + def test_df_sample(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # Test with n + sampled = mixin._df_sample(df_0, n=2, seed=42) + assert len(sampled) == 2 + + # Test with frac + sampled = mixin._df_sample(df_0, frac=2 / 3, seed=42) + assert len(sampled) == 2 + + # Test with replacement + sampled = mixin._df_sample(df_0, n=4, with_replacement=True, seed=42) + assert len(sampled) == 4 + assert sampled.n_unique() < 4 + + def test_df_set_index(self, mixin: PolarsMixin, df_0: pl.DataFrame): + index = pl.int_range(len(df_0), eager=True) + new_df = mixin._df_set_index(df_0, "index", index) + assert (new_df["index"] == index).all() + + def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame): + # Test with list + new_df = mixin._df_with_columns( + df_0, + data=[[4, "d"], [5, "e"], [6, "f"]], + new_columns=["D", "E"], + ) + assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] + assert new_df["D"].to_list() == [4, 5, 6] + assert new_df["E"].to_list() == ["d", "e", "f"] + + # Test with pl.DataFrame + second_df = pl.DataFrame({"D": [4, 5, 6], "E": ["d", "e", "f"]}) + new_df = mixin._df_with_columns(df_0, second_df) + assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] + assert new_df["D"].to_list() == [4, 5, 6] + assert new_df["E"].to_list() == ["d", "e", "f"] + + # Test with dictionary + new_df = mixin._df_with_columns( + df_0, data={"D": [4, 5, 6], "E": ["d", "e", "f"]} + ) + assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"] + assert new_df["D"].to_list() == [4, 5, 6] + assert new_df["E"].to_list() == ["d", "e", "f"] + + # Test with numpy array + new_df = mixin._df_with_columns(df_0, data=np.array([4, 5, 6]), new_columns="D") + assert "D" in new_df.columns + assert new_df["D"].to_list() == [4, 5, 6] + + # Test with pl.Series + new_df = mixin._df_with_columns(df_0, pl.Series([4, 5, 6]), new_columns="D") + assert "D" in new_df.columns + assert new_df["D"].to_list() == [4, 5, 6] + + def test_srs_constructor(self, mixin: PolarsMixin): + # Test with list + srs = mixin._srs_constructor([1, 2, 3], name="test", dtype="int64") + assert srs.name == "test" + assert srs.dtype == pl.Int64 + + # Test with numpy array + srs = mixin._srs_constructor(np.array([1, 2, 3]), name="test") + assert srs.name == "test" + assert len(srs) == 3 + + def test_srs_contains(self, mixin: PolarsMixin): + srs = [1, 2, 3, 4, 5] + + # Test with single value + result = mixin._srs_contains(srs, 3) + assert result.to_list() == [True] + + # Test with list + result = mixin._srs_contains(srs, [1, 3, 6]) + assert result.to_list() == [True, True, False] + + # Test with numpy array + result = mixin._srs_contains(srs, np.array([1, 3, 6])) + assert result.to_list() == [True, True, False] + + def test_srs_range(self, mixin: PolarsMixin): + # Test with default step + srs = mixin._srs_range("test", 0, 5) + assert srs.name == "test" + assert srs.to_list() == [0, 1, 2, 3, 4] + + # Test with custom step + srs = mixin._srs_range("test", 0, 10, step=2) + assert srs.to_list() == [0, 2, 4, 6, 8] + + def test_srs_to_df(self, mixin: PolarsMixin): + srs = pl.Series("test", [1, 2, 3]) + df = mixin._srs_to_df(srs) + assert isinstance(df, pl.DataFrame) + assert df["test"].to_list() == [1, 2, 3] From 93c7db373fa492c38698c7f3d140df93db711ac9 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:41:01 +0200 Subject: [PATCH 06/10] fixes to PolarsMixins --- mesa_frames/concrete/polars/mixin.py | 385 ++++++++++++--------------- 1 file changed, 173 insertions(+), 212 deletions(-) diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 34d78b5e..2ceec2cd 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -4,8 +4,10 @@ import polars as pl from typing_extensions import Any, overload +from collections.abc import Callable + from mesa_frames.abstract.mixin import DataFrameMixin -from mesa_frames.types_ import PolarsMask +from mesa_frames.types_ import DataFrame, PolarsMask class PolarsMixin(DataFrameMixin): @@ -19,45 +21,13 @@ def _df_add( axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, ) -> pl.DataFrame: - if isinstance(other, pl.DataFrame): - if axis == "index": - if index_cols is None: - raise ValueError( - "index_cols must be specified when axis is 'index'" - ) - return ( - df.join(other.select(pl.all().suffix("_add")), on=index_cols) - .with_columns( - [ - (pl.col(col) + pl.col(f"{col}_add")).alias(col) - for col in df.columns - if col not in index_cols - ] - ) - .select(df.columns) - ) - else: - return df.select( - [ - (pl.col(col) + pl.col(other.columns[i])).alias(col) - for i, col in enumerate(df.columns) - ] - ) - elif isinstance(other, Sequence): - if axis == "index": - other_series = pl.Series("addend", other) - return df.with_columns( - [(pl.col(col) + other_series).alias(col) for col in df.columns] - ) - else: - return df.with_columns( - [ - (pl.col(col) + other[i]).alias(col) - for i, col in enumerate(df.columns) - ] - ) - else: - raise ValueError("other must be a DataFrame or a Sequence") + return self._df_operation( + df=df, + other=other, + operation=lambda x, y: x + y, + axis=axis, + index_cols=index_cols, + ) def _df_all( self, @@ -66,8 +36,8 @@ def _df_all( axis: Literal["index", "columns"] = "columns", ) -> pl.Series: if axis == "columns": - return df.select(pl.col("*").all()).to_series() - return df.with_columns(all=pl.all_horizontal())["all"] + return pl.Series(name, df.select(pl.col("*").all()).row(0)) + return df.with_columns(all=pl.all_horizontal("*"))["all"] def _df_column_names(self, df: pl.DataFrame) -> list[str]: return df.columns @@ -78,24 +48,15 @@ def _df_combine_first( new_df: pl.DataFrame, index_cols: str | list[str], ) -> pl.DataFrame: - new_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right") - # Find columns with the _right suffix and update the corresponding original columns - updated_columns = [] - for col in new_df.columns: - if col.endswith("_right"): - original_col = col.replace("_right", "") - updated_columns.append( - pl.when(pl.col(col).is_not_null()) - .then(pl.col(col)) - .otherwise(pl.col(original_col)) - .alias(original_col) - ) - - # Apply the updates and remove the _right columns - new_df = new_df.with_columns(updated_columns).select( - pl.col(r"^(?!.*_right$).*") - ) - return new_df + common_cols = set(original_df.columns) & set(new_df.columns) + merged_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right") + merged_df = merged_df.with_columns( + [ + pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col) + for col in common_cols + ] + ).select(pl.exclude("^.*_right$")) + return merged_df @overload def _df_concat( @@ -131,20 +92,29 @@ def _df_concat( ignore_index: bool = False, index_cols: str | None = None, ) -> pl.Series | pl.DataFrame: - return pl.concat( - objs, how="vertical_relaxed" if how == "vertical" else "horizontal_relaxed" - ) + if isinstance(objs[0], pl.DataFrame) and how == "vertical": + how = "diagonal_relaxed" + if isinstance(objs[0], pl.Series) and how == "horizontal": + obj = pl.DataFrame().hstack(objs, in_place=True) + else: + obj = pl.concat(objs, how=how) + if isinstance(obj, pl.DataFrame) and how == "horizontal" and ignore_index: + obj = obj.rename( + {c: str(i) for c, i in zip(obj.columns, range(len(obj.columns)))} + ) + return obj def _df_constructor( self, - data: Sequence[Sequence] | dict[str | Any] | None = None, + data: dict[str | Any] | Sequence[Sequence] | DataFrame | None = None, columns: list[str] | None = None, index: Sequence[Hashable] | None = None, index_cols: str | list[str] | None = None, dtypes: dict[str, str] | None = None, ) -> pl.DataFrame: - dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()} - return pl.DataFrame(data=data, schema=dtypes if dtypes else columns) + if dtypes is not None: + dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()} + return pl.DataFrame(data=data, schema=columns, schema_overrides=dtypes) def _df_contains( self, @@ -152,70 +122,22 @@ def _df_contains( column: str, values: Sequence[Any], ) -> pl.Series: - return pl.Series(values, index=values).is_in(df[column]) + return pl.Series("contains", values).is_in(df[column]) def _df_div( self, df: pl.DataFrame, - other: pl.DataFrame | pl.Series | Sequence[float | int], + other: pl.DataFrame | Sequence[float | int], axis: Literal["index"] | Literal["columns"] = "index", index_cols: str | list[str] | None = None, ) -> pl.DataFrame: - if isinstance(other, pl.DataFrame): - if axis == "index": - if index_cols is None: - raise ValueError( - "index_cols must be specified when axis is 'index'" - ) - return ( - df.join(other.select(pl.all().suffix("_div")), on=index_cols) - .with_columns( - [ - (pl.col(col) / pl.col(f"{col}_div")).alias(col) - for col in df.columns - if col not in index_cols - ] - ) - .select(df.columns) - ) - else: # axis == "columns" - return df.select( - [ - (pl.col(col) / pl.col(other.columns[i])).alias(col) - for i, col in enumerate(df.columns) - ] - ) - elif isinstance(other, pl.Series): - if axis == "index": - return df.with_columns( - [ - (pl.col(col) / other).alias(col) - for col in df.columns - if col != other.name - ] - ) - else: # axis == "columns" - return df.with_columns( - [ - (pl.col(col) / other[i]).alias(col) - for i, col in enumerate(df.columns) - ] - ) - elif isinstance(other, Sequence): - if axis == "index": - other_series = pl.Series("divisor", other) - return df.with_columns( - [(pl.col(col) / other_series).alias(col) for col in df.columns] - ) - else: # axis == "columns" - return df.with_columns( - [ - (pl.col(col) / other[i]).alias(col) - for i, col in enumerate(df.columns) - ] - ) - else: - raise ValueError("other must be a DataFrame, Series, or Sequence") + return self._df_operation( + df=df, + other=other, + operation=lambda x, y: x / y, + axis=axis, + index_cols=index_cols, + ) def _df_drop_columns( self, @@ -233,40 +155,26 @@ def _df_drop_duplicates( # If subset is None, use all columns if subset is None: subset = df.columns - # If subset is a string, convert it to a list - elif isinstance(subset, str): - subset = [subset] - - # Determine the sort order based on 'keep' + original_col_order = df.columns if keep == "first": - sort_expr = [pl.col(col).rank("dense", reverse=True) for col in subset] + return ( + df.group_by(subset, maintain_order=True) + .first() + .select(original_col_order) + ) elif keep == "last": - sort_expr = [pl.col(col).rank("dense") for col in subset] - elif keep is False: - # If keep is False, we don't need to sort, just group and filter - return df.group_by(subset).agg(pl.all().first()).sort(subset) + return ( + df.group_by(subset, maintain_order=True) + .last() + .select(original_col_order) + ) else: - raise ValueError("'keep' must be either 'first', 'last', or False") - - # Add a rank column, sort by it, and keep only the first row of each group - return ( - df.with_columns(pl.struct(sort_expr).alias("__rank")) - .sort("__rank") - .group_by(subset) - .agg(pl.all().first()) - .sort(subset) - .drop("__rank") - ) - - def _df_filter( - self, - df: pl.DataFrame, - condition: pl.Series, - all: bool = True, - ) -> pl.DataFrame: - if all: - return df.filter(pl.all(condition)) - return df.filter(condition) + return ( + df.with_columns(pl.len().over(subset)) + .filter(pl.col("len") < 2) + .drop("len") + .select(original_col_order) + ) def _df_get_bool_mask( self, @@ -322,8 +230,10 @@ def _df_get_masked_df( return df.filter(b_mask)[columns] return df.filter(b_mask) - def _df_groupby_cumcount(self, df: pl.DataFrame, by: str | list[str]) -> pl.Series: - return df.with_columns(pl.col(by).cum_count().alias("cumcount")) + def _df_groupby_cumcount( + self, df: pl.DataFrame, by: str | list[str], name="cum_count" + ) -> pl.Series: + return df.with_columns(pl.cum_count(by).over(by).alias(name))[name] def _df_iterator(self, df: pl.DataFrame) -> Iterator[dict[str, Any]]: return iter(df.iter_rows(named=True)) @@ -342,14 +252,19 @@ def _df_join( | Literal["cross"] = "left", suffix="_right", ) -> pl.DataFrame: + if how == "outer": + how = "full" + if how == "right": + left, right = right, left + left_on, right_on = right_on, left_on + how = "left" return left.join( right, on=on, left_on=left_on, right_on=right_on, how=how, - lsuffix="", - rsuffix=suffix, + suffix=suffix, ) def _df_mul( @@ -359,45 +274,13 @@ def _df_mul( axis: Literal["index", "columns"] = "index", index_cols: str | list[str] | None = None, ) -> pl.DataFrame: - if isinstance(other, pl.DataFrame): - if axis == "index": - if index_cols is None: - raise ValueError( - "index_cols must be specified when axis is 'index'" - ) - return ( - df.join(other.select(pl.all().suffix("_mul")), on=index_cols) - .with_columns( - [ - (pl.col(col) * pl.col(f"{col}_mul")).alias(col) - for col in df.columns - if col not in index_cols - ] - ) - .select(df.columns) - ) - else: # axis == "columns" - return df.select( - [ - (pl.col(col) * pl.col(other.columns[i])).alias(col) - for i, col in enumerate(df.columns) - ] - ) - elif isinstance(other, Sequence): - if axis == "index": - other_series = pl.Series("multiplier", other) - return df.with_columns( - [(pl.col(col) * other_series).alias(col) for col in df.columns] - ) - else: - return df.with_columns( - [ - (pl.col(col) * other[i]).alias(col) - for i, col in enumerate(df.columns) - ] - ) - else: - raise ValueError("other must be a DataFrame or a Sequence") + return self._df_operation( + df=df, + other=other, + operation=lambda x, y: x * y, + axis=axis, + index_cols=index_cols, + ) @overload def _df_norm( @@ -422,15 +305,64 @@ def _df_norm( include_cols: bool = False, ) -> pl.Series | pl.DataFrame: srs = ( - df.with_columns(pl.col("*").pow(2).alias("*")) - .sum_horizontal() - .sqrt() - .rename(srs_name) + df.with_columns(pl.col("*").pow(2)).sum_horizontal().sqrt().rename(srs_name) ) if include_cols: - return df.with_columns(srs_name=srs) + return df.with_columns(srs) return srs + def _df_operation( + self, + df: pl.DataFrame, + other: pl.DataFrame | Sequence[float | int], + operation: Callable[[pl.Expr, pl.Expr], pl.Expr], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> pl.DataFrame: + if isinstance(other, pl.DataFrame): + if index_cols is not None: + op_df = df.join(other, how="left", on=index_cols, suffix="_op") + else: + assert len(df) == len( + other + ), "DataFrames must have the same length if index_cols is not specified" + index_cols = [] + other = other.rename(lambda col: col + "_op") + op_df = pl.concat([df, other], how="horizontal") + return op_df.with_columns( + [ + operation(pl.col(col), pl.col(f"{col}_op")).alias(col) + for col in df.columns + if col not in index_cols + ] + ).select(df.columns) + elif isinstance( + other, (Sequence, pl.Series) + ): # Currently, pl.Series is not a Sequence + if axis == "index": + assert len(df) == len( + other + ), "Sequence must have the same length as df if axis is 'index'" + other_series = pl.Series("operand", other) + return df.with_columns( + [ + operation(pl.col(col), other_series).alias(col) + for col in df.columns + ] + ) + else: + assert ( + len(df.columns) == len(other) + ), "Sequence must have the same length as df.columns if axis is 'columns'" + return df.with_columns( + [ + operation(pl.col(col), pl.lit(other[i])).alias(col) + for i, col in enumerate(df.columns) + ] + ) + else: + raise ValueError("other must be a DataFrame or a Sequence") + def _df_rename_columns( self, df: pl.DataFrame, old_columns: list[str], new_columns: list[str] ) -> pl.DataFrame: @@ -457,7 +389,11 @@ def _df_sample( seed: int | None = None, ) -> pl.DataFrame: return df.sample( - n=n, frac=frac, replace=with_replacement, shuffle=shuffle, seed=seed + n=n, + fraction=frac, + with_replacement=with_replacement, + shuffle=shuffle, + seed=seed, ) def _df_set_index( @@ -468,30 +404,55 @@ def _df_set_index( ) -> pl.DataFrame: if new_index is None: return df - return df.with_columns(index_name=new_index) + return df.with_columns(**{index_name: new_index}) def _df_with_columns( - self, original_df: pl.DataFrame, new_columns: list[str], data: Any + self, + original_df: pl.DataFrame, + data: Sequence | pl.DataFrame | Sequence[Sequence] | dict[str | Any] | Any, + new_columns: str | list[str] | None = None, ) -> pl.DataFrame: - return original_df.with_columns( - **{col: value for col, value in zip(new_columns, data)} - ) + if ( + (isinstance(data, Sequence) and isinstance(data[0], Sequence)) + or isinstance( + data, pl.DataFrame + ) # Currently, pl.DataFrame is not a Sequence + or ( + isinstance(data, dict) + and isinstance(data[list(data.keys())[0]], Sequence) + ) + ): + # This means that data is a Sequence of Sequences (rows) + data = pl.DataFrame(data, new_columns) + original_df = original_df.select(pl.exclude(data.columns)) + return original_df.hstack(data) + if not isinstance(data, dict): + assert new_columns is not None, "new_columns must be specified" + if isinstance(new_columns, list): + data = {col: value for col, value in zip(new_columns, data)} + else: + data = {new_columns: data} + return original_df.with_columns(**data) def _srs_constructor( self, data: Sequence[Any] | None = None, name: str | None = None, - dtype: Any | None = None, + dtype: str | None = None, index: Sequence[Any] | None = None, ) -> pl.Series: - return pl.Series(name=name, values=data, dtype=self._dtypes_mapping[dtype]) + if dtype is not None: + dtype = self._dtypes_mapping[dtype] + return pl.Series(name=name, values=data, dtype=dtype) def _srs_contains( self, - srs: Sequence[Any], + srs: Collection[Any], values: Any | Sequence[Any], ) -> pl.Series: - return pl.Series(values, index=values).is_in(srs) + if not isinstance(values, Collection): + values = [values] + return pl.Series(values).is_in(srs) def _srs_range( self, From e3e6d91e392e060a60291e4e7bf1daf716b7cd60 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 12 Aug 2024 13:28:03 +0000 Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mesa_frames/concrete/polars/mixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 2af72e74..2ceec2cd 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -405,6 +405,7 @@ def _df_set_index( if new_index is None: return df return df.with_columns(**{index_name: new_index}) + def _df_with_columns( self, original_df: pl.DataFrame, From 703b56173714d2f0b5c4ef18ca4fa9e2961ff5cb Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:04:18 +0200 Subject: [PATCH 08/10] adding index_cols to _df_join per abstract DataFrameMixin --- mesa_frames/concrete/polars/mixin.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 2ceec2cd..1cc4fdb3 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -1,11 +1,9 @@ -from collections.abc import Collection, Hashable, Iterator, Sequence +from collections.abc import Callable, Collection, Hashable, Iterator, Sequence from typing import Literal import polars as pl from typing_extensions import Any, overload -from collections.abc import Callable - from mesa_frames.abstract.mixin import DataFrameMixin from mesa_frames.types_ import DataFrame, PolarsMask @@ -242,6 +240,7 @@ def _df_join( self, left: pl.DataFrame, right: pl.DataFrame, + index_cols: str | list[str] | None = None, on: str | list[str] | None = None, left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, From 9ea50e01a0443a0b9f0d5e24968db4cb0f1a4454 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Tue, 13 Aug 2024 12:02:37 +0200 Subject: [PATCH 09/10] fixes to method logic --- mesa_frames/concrete/polars/mixin.py | 55 +++++++++++++++++----------- tests/polars/test_mixin_polars.py | 8 ++-- 2 files changed, 38 insertions(+), 25 deletions(-) diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 1cc4fdb3..4d17fa09 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -33,9 +33,9 @@ def _df_all( name: str = "all", axis: Literal["index", "columns"] = "columns", ) -> pl.Series: - if axis == "columns": + if axis == "index": return pl.Series(name, df.select(pl.col("*").all()).row(0)) - return df.with_columns(all=pl.all_horizontal("*"))["all"] + return df.with_columns(pl.all_horizontal("*").alias(name))[name] def _df_column_names(self, df: pl.DataFrame) -> list[str]: return df.columns @@ -49,10 +49,8 @@ def _df_combine_first( common_cols = set(original_df.columns) & set(new_df.columns) merged_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right") merged_df = merged_df.with_columns( - [ - pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col) - for col in common_cols - ] + pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col) + for col in common_cols ).select(pl.exclude("^.*_right$")) return merged_df @@ -112,7 +110,19 @@ def _df_constructor( ) -> pl.DataFrame: if dtypes is not None: dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()} - return pl.DataFrame(data=data, schema=columns, schema_overrides=dtypes) + df = pl.DataFrame( + data=data, schema=columns, schema_overrides=dtypes, orient="row" + ) + if index is not None: + if index_cols is not None: + index_df = pl.DataFrame(index, index_cols) + else: + index_df = pl.DataFrame(index) + if len(df) != len(index_df) and len(df) == 1: + df = index_df.join(df, how="cross") + else: + df = index_df.hstack(df) + return df def _df_contains( self, @@ -188,8 +198,15 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: and len(mask) == len(df) ): return mask + assert isinstance(index_cols, str) return df[index_cols].is_in(mask) + def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series: + mask = mask.with_columns(in_it=True) + return df.join(mask[index_cols + ["in_it"]], on=index_cols, how="left")[ + "in_it" + ].fill_null(False) + if isinstance(mask, pl.Expr): result = mask elif isinstance(mask, pl.Series): @@ -197,11 +214,13 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: elif isinstance(mask, pl.DataFrame): if index_cols in mask.columns: result = bool_mask_from_series(mask[index_cols]) + elif all(col in mask.columns for col in index_cols): + result = bool_mask_from_df(mask[index_cols]) elif len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean: result = bool_mask_from_series(mask[mask.columns[0]]) else: raise KeyError( - f"DataFrame must have an {index_cols} column or a single boolean column." + f"Mask must have {index_cols} column(s) or a single boolean column." ) elif mask is None or mask == "all": result = pl.Series([True] * len(df)) @@ -329,11 +348,9 @@ def _df_operation( other = other.rename(lambda col: col + "_op") op_df = pl.concat([df, other], how="horizontal") return op_df.with_columns( - [ - operation(pl.col(col), pl.col(f"{col}_op")).alias(col) - for col in df.columns - if col not in index_cols - ] + operation(pl.col(col), pl.col(f"{col}_op")).alias(col) + for col in df.columns + if col not in index_cols ).select(df.columns) elif isinstance( other, (Sequence, pl.Series) @@ -344,20 +361,16 @@ def _df_operation( ), "Sequence must have the same length as df if axis is 'index'" other_series = pl.Series("operand", other) return df.with_columns( - [ - operation(pl.col(col), other_series).alias(col) - for col in df.columns - ] + operation(pl.col(col), other_series).alias(col) + for col in df.columns ) else: assert ( len(df.columns) == len(other) ), "Sequence must have the same length as df.columns if axis is 'columns'" return df.with_columns( - [ - operation(pl.col(col), pl.lit(other[i])).alias(col) - for i, col in enumerate(df.columns) - ] + operation(pl.col(col), pl.lit(other[i])).alias(col) + for i, col in enumerate(df.columns) ) else: raise ValueError("other must be a DataFrame or a Sequence") diff --git a/tests/polars/test_mixin_polars.py b/tests/polars/test_mixin_polars.py index d297fbb0..310117d5 100644 --- a/tests/polars/test_mixin_polars.py +++ b/tests/polars/test_mixin_polars.py @@ -70,14 +70,14 @@ def test_df_all(self, mixin: PolarsMixin): } ) - # Test with axis='index' - result = mixin._df_all(df["A", "B"], axis="index") + # Test with axis='columns' + result = mixin._df_all(df["A", "B"], axis="columns") assert isinstance(result, pl.Series) assert result.name == "all" assert result.to_list() == [True, False, True] - # Test with axis='columns' - result = mixin._df_all(df["A", "B"], axis="columns") + # Test with axis='index' + result = mixin._df_all(df["A", "B"], axis="index") assert isinstance(result, pl.Series) assert result.name == "all" assert result.to_list() == [False, True] From c23a6f438b915ad2e544e65734ddad5bf75215d4 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:53:15 +0200 Subject: [PATCH 10/10] testing with index > 1 and 1 value --- mesa_frames/concrete/polars/mixin.py | 2 ++ tests/polars/test_mixin_polars.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index 4d17fa09..ab08973a 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -115,6 +115,8 @@ def _df_constructor( ) if index is not None: if index_cols is not None: + if isinstance(index_cols, str): + index_cols = [index_cols] index_df = pl.DataFrame(index, index_cols) else: index_df = pl.DataFrame(index) diff --git a/tests/polars/test_mixin_polars.py b/tests/polars/test_mixin_polars.py index 310117d5..2db79759 100644 --- a/tests/polars/test_mixin_polars.py +++ b/tests/polars/test_mixin_polars.py @@ -257,6 +257,16 @@ def test_df_constructor(self, mixin: PolarsMixin): assert df["num"].to_list() == [1, 2, 3] assert df["letter"].to_list() == ["a", "b", "c"] + # Test with index > 1 and 1 value + data = {"a": 5} + df = mixin._df_constructor( + data, index=pl.int_range(5, eager=True), index_cols="index" + ) + assert isinstance(df, pl.DataFrame) + assert list(df.columns) == ["index", "a"] + assert df["a"].to_list() == [5, 5, 5, 5, 5] + assert df["index"].to_list() == [0, 1, 2, 3, 4] + def test_df_contains(self, mixin: PolarsMixin, df_0: pl.DataFrame): # Test with list result = mixin._df_contains(df_0, "A", [5, 2, 3])