diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py
index 21e0c030..ab08973a 100644
--- a/mesa_frames/concrete/polars/mixin.py
+++ b/mesa_frames/concrete/polars/mixin.py
@@ -1,13 +1,11 @@
-from collections.abc import Collection, Iterator, Sequence
+from collections.abc import Callable, Collection, Hashable, Iterator, Sequence
 from typing import Literal
 
 import polars as pl
 from typing_extensions import Any, overload
 
-from collections.abc import Hashable
-
 from mesa_frames.abstract.mixin import DataFrameMixin
-from mesa_frames.types_ import PolarsMask
+from mesa_frames.types_ import DataFrame, PolarsMask
 
 
 class PolarsMixin(DataFrameMixin):
@@ -21,45 +19,13 @@ def _df_add(
         axis: Literal["index"] | Literal["columns"] = "index",
         index_cols: str | list[str] | None = None,
     ) -> pl.DataFrame:
-        if isinstance(other, pl.DataFrame):
-            if axis == "index":
-                if index_cols is None:
-                    raise ValueError(
-                        "index_cols must be specified when axis is 'index'"
-                    )
-                return (
-                    df.join(other.select(pl.all().suffix("_add")), on=index_cols)
-                    .with_columns(
-                        [
-                            (pl.col(col) + pl.col(f"{col}_add")).alias(col)
-                            for col in df.columns
-                            if col not in index_cols
-                        ]
-                    )
-                    .select(df.columns)
-                )
-            else:
-                return df.select(
-                    [
-                        (pl.col(col) + pl.col(other.columns[i])).alias(col)
-                        for i, col in enumerate(df.columns)
-                    ]
-                )
-        elif isinstance(other, Sequence):
-            if axis == "index":
-                other_series = pl.Series("addend", other)
-                return df.with_columns(
-                    [(pl.col(col) + other_series).alias(col) for col in df.columns]
-                )
-            else:
-                return df.with_columns(
-                    [
-                        (pl.col(col) + other[i]).alias(col)
-                        for i, col in enumerate(df.columns)
-                    ]
-                )
-        else:
-            raise ValueError("other must be a DataFrame or a Sequence")
+        return self._df_operation(
+            df=df,
+            other=other,
+            operation=lambda x, y: x + y,
+            axis=axis,
+            index_cols=index_cols,
+        )
 
     def _df_all(
         self,
@@ -67,9 +33,9 @@
         name: str = "all",
         axis: Literal["index", "columns"] = "columns",
     ) -> pl.Series:
-        if axis == "columns":
-            return df.select(pl.col("*").all()).to_series()
-        return df.with_columns(all=pl.all_horizontal())["all"]
+        if axis == "index":
+            return pl.Series(name, df.select(pl.col("*").all()).row(0))
+        return df.with_columns(pl.all_horizontal("*").alias(name))[name]
 
     def _df_column_names(self, df: pl.DataFrame) -> list[str]:
         return df.columns
@@ -80,24 +46,13 @@ def _df_combine_first(
         new_df: pl.DataFrame,
         index_cols: str | list[str],
     ) -> pl.DataFrame:
-        new_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right")
-        # Find columns with the _right suffix and update the corresponding original columns
-        updated_columns = []
-        for col in new_df.columns:
-            if col.endswith("_right"):
-                original_col = col.replace("_right", "")
-                updated_columns.append(
-                    pl.when(pl.col(col).is_not_null())
-                    .then(pl.col(col))
-                    .otherwise(pl.col(original_col))
-                    .alias(original_col)
-                )
-
-        # Apply the updates and remove the _right columns
-        new_df = new_df.with_columns(updated_columns).select(
-            pl.col(r"^(?!.*_right$).*")
-        )
-        return new_df
+        common_cols = set(original_df.columns) & set(new_df.columns)
+        merged_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right")
+        merged_df = merged_df.with_columns(
+            pl.coalesce(pl.col(col), pl.col(f"{col}_right")).alias(col)
+            for col in common_cols
+        ).select(pl.exclude("^.*_right$"))
+        return merged_df
 
     @overload
     def _df_concat(
         self,
         objs: Collection[pl.DataFrame],
         how: Literal["horizontal"] | Literal["vertical"] = "vertical",
         ignore_index: bool = False,
-        index_cols: str | None = None,
+        index_cols: str | list[str] | None = None,
     ) -> pl.DataFrame: ...
 
     @overload
@@ -133,20 +88,43 @@
         ignore_index: bool = False,
         index_cols: str | None = None,
     ) -> pl.Series | pl.DataFrame:
-        return pl.concat(
-            objs, how="vertical_relaxed" if how == "vertical" else "horizontal_relaxed"
-        )
+        if isinstance(objs[0], pl.DataFrame) and how == "vertical":
+            how = "diagonal_relaxed"
+        if isinstance(objs[0], pl.Series) and how == "horizontal":
+            obj = pl.DataFrame().hstack(objs, in_place=True)
+        else:
+            obj = pl.concat(objs, how=how)
+        if isinstance(obj, pl.DataFrame) and how == "horizontal" and ignore_index:
+            obj = obj.rename(
+                {c: str(i) for c, i in zip(obj.columns, range(len(obj.columns)))}
+            )
+        return obj
 
     def _df_constructor(
         self,
-        data: Sequence[Sequence] | dict[str | Any] | None = None,
+        data: dict[str | Any] | Sequence[Sequence] | DataFrame | None = None,
        columns: list[str] | None = None,
         index: Sequence[Hashable] | None = None,
         index_cols: str | list[str] | None = None,
         dtypes: dict[str, str] | None = None,
     ) -> pl.DataFrame:
-        dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()}
-        return pl.DataFrame(data=data, schema=dtypes if dtypes else columns)
+        if dtypes is not None:
+            dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()}
+        df = pl.DataFrame(
+            data=data, schema=columns, schema_overrides=dtypes, orient="row"
+        )
+        if index is not None:
+            if index_cols is not None:
+                if isinstance(index_cols, str):
+                    index_cols = [index_cols]
+                index_df = pl.DataFrame(index, index_cols)
+            else:
+                index_df = pl.DataFrame(index)
+            if len(df) != len(index_df) and len(df) == 1:
+                df = index_df.join(df, how="cross")
+            else:
+                df = index_df.hstack(df)
+        return df
 
     def _df_contains(
         self,
@@ -154,70 +132,22 @@
         column: str,
         values: Sequence[Any],
     ) -> pl.Series:
-        return pl.Series(values, index=values).is_in(df[column])
+        return pl.Series("contains", values).is_in(df[column])
 
     def _df_div(
         self,
         df: pl.DataFrame,
-        other: pl.DataFrame | pl.Series | Sequence[float | int],
+        other: pl.DataFrame | Sequence[float | int],
         axis: Literal["index"] | Literal["columns"] = "index",
         index_cols: str | list[str] | None = None,
     ) -> pl.DataFrame:
-        if isinstance(other, pl.DataFrame):
-            if axis == "index":
-                if index_cols is None:
-                    raise ValueError(
-                        "index_cols must be specified when axis is 'index'"
-                    )
-                return (
-                    df.join(other.select(pl.all().suffix("_div")), on=index_cols)
-                    .with_columns(
-                        [
-                            (pl.col(col) / pl.col(f"{col}_div")).alias(col)
-                            for col in df.columns
-                            if col not in index_cols
-                        ]
-                    )
-                    .select(df.columns)
-                )
-            else:  # axis == "columns"
-                return df.select(
-                    [
-                        (pl.col(col) / pl.col(other.columns[i])).alias(col)
-                        for i, col in enumerate(df.columns)
-                    ]
-                )
-        elif isinstance(other, pl.Series):
-            if axis == "index":
-                return df.with_columns(
-                    [
-                        (pl.col(col) / other).alias(col)
-                        for col in df.columns
-                        if col != other.name
-                    ]
-                )
-            else:  # axis == "columns"
-                return df.with_columns(
-                    [
-                        (pl.col(col) / other[i]).alias(col)
-                        for i, col in enumerate(df.columns)
-                    ]
-                )
-        elif isinstance(other, Sequence):
-            if axis == "index":
-                other_series = pl.Series("divisor", other)
-                return df.with_columns(
-                    [(pl.col(col) / other_series).alias(col) for col in df.columns]
-                )
-            else:  # axis == "columns"
-                return df.with_columns(
-                    [
-                        (pl.col(col) / other[i]).alias(col)
-                        for i, col in enumerate(df.columns)
-                    ]
-                )
-        else:
-            raise ValueError("other must be a DataFrame, Series, or Sequence")
+        return self._df_operation(
+            df=df,
+            other=other,
+            operation=lambda x, y: x / y,
+            axis=axis,
+            index_cols=index_cols,
+        )
 
     def _df_drop_columns(
         self,
@@ -235,40 +165,26 @@
         # If subset is None, use all columns
         if subset is None:
             subset = df.columns
-        # If subset is a string, convert it to a list
-        elif isinstance(subset, str):
-            subset = [subset]
-
-        # Determine the sort order based on 'keep'
+        original_col_order = df.columns
         if keep == "first":
-            sort_expr = [pl.col(col).rank("dense", reverse=True) for col in subset]
+            return (
+                df.group_by(subset, maintain_order=True)
+                .first()
+                .select(original_col_order)
+            )
         elif keep == "last":
-            sort_expr = [pl.col(col).rank("dense") for col in subset]
-        elif keep is False:
-            # If keep is False, we don't need to sort, just group and filter
-            return df.group_by(subset).agg(pl.all().first()).sort(subset)
+            return (
+                df.group_by(subset, maintain_order=True)
+                .last()
+                .select(original_col_order)
+            )
         else:
-            raise ValueError("'keep' must be either 'first', 'last', or False")
-
-        # Add a rank column, sort by it, and keep only the first row of each group
-        return (
-            df.with_columns(pl.struct(sort_expr).alias("__rank"))
-            .sort("__rank")
-            .group_by(subset)
-            .agg(pl.all().first())
-            .sort(subset)
-            .drop("__rank")
-        )
-
-    def _df_filter(
-        self,
-        df: pl.DataFrame,
-        condition: pl.Series,
-        all: bool = True,
-    ) -> pl.DataFrame:
-        if all:
-            return df.filter(pl.all(condition))
-        return df.filter(condition)
+            return (
+                df.with_columns(pl.len().over(subset))
+                .filter(pl.col("len") < 2)
+                .drop("len")
+                .select(original_col_order)
+            )
 
     def _df_get_bool_mask(
         self,
@@ -284,8 +200,15 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
                 and len(mask) == len(df)
             ):
                 return mask
+            assert isinstance(index_cols, str)
            return df[index_cols].is_in(mask)
 
+        def bool_mask_from_df(mask: pl.DataFrame) -> pl.Series:
+            mask = mask.with_columns(in_it=True)
+            return df.join(mask[index_cols + ["in_it"]], on=index_cols, how="left")[
+                "in_it"
+            ].fill_null(False)
+
         if isinstance(mask, pl.Expr):
             result = mask
         elif isinstance(mask, pl.Series):
@@ -293,11 +216,13 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
         elif isinstance(mask, pl.DataFrame):
             if index_cols in mask.columns:
                 result = bool_mask_from_series(mask[index_cols])
+            elif all(col in mask.columns for col in index_cols):
+                result = bool_mask_from_df(mask[index_cols])
             elif len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean:
                 result = bool_mask_from_series(mask[mask.columns[0]])
             else:
                 raise KeyError(
-                    f"DataFrame must have an {index_cols} column or a single boolean column."
+                    f"Mask must have {index_cols} column(s) or a single boolean column."
                 )
         elif mask is None or mask == "all":
             result = pl.Series([True] * len(df))
@@ -336,6 +261,7 @@ def _df_join(
         self,
         left: pl.DataFrame,
         right: pl.DataFrame,
+        index_cols: str | list[str] | None = None,
         on: str | list[str] | None = None,
         left_on: str | list[str] | None = None,
         right_on: str | list[str] | None = None,
@@ -346,14 +272,19 @@
         | Literal["cross"] = "left",
         suffix="_right",
     ) -> pl.DataFrame:
+        if how == "outer":
+            how = "full"
+        if how == "right":
+            left, right = right, left
+            left_on, right_on = right_on, left_on
+            how = "left"
         return left.join(
             right,
             on=on,
             left_on=left_on,
             right_on=right_on,
             how=how,
-            lsuffix="",
-            rsuffix=suffix,
+            suffix=suffix,
         )
 
     def _df_mul(
@@ -363,45 +294,13 @@
         axis: Literal["index", "columns"] = "index",
         index_cols: str | list[str] | None = None,
     ) -> pl.DataFrame:
-        if isinstance(other, pl.DataFrame):
-            if axis == "index":
-                if index_cols is None:
-                    raise ValueError(
-                        "index_cols must be specified when axis is 'index'"
-                    )
-                return (
-                    df.join(other.select(pl.all().suffix("_mul")), on=index_cols)
-                    .with_columns(
-                        [
-                            (pl.col(col) * pl.col(f"{col}_mul")).alias(col)
-                            for col in df.columns
-                            if col not in index_cols
-                        ]
-                    )
-                    .select(df.columns)
-                )
-            else:  # axis == "columns"
-                return df.select(
-                    [
-                        (pl.col(col) * pl.col(other.columns[i])).alias(col)
-                        for i, col in enumerate(df.columns)
-                    ]
-                )
-        elif isinstance(other, Sequence):
-            if axis == "index":
-                other_series = pl.Series("multiplier", other)
-                return df.with_columns(
-                    [(pl.col(col) * other_series).alias(col) for col in df.columns]
-                )
-            else:
-                return df.with_columns(
-                    [
-                        (pl.col(col) * other[i]).alias(col)
-                        for i, col in enumerate(df.columns)
-                    ]
-                )
-        else:
-            raise ValueError("other must be a DataFrame or a Sequence")
+        return self._df_operation(
+            df=df,
+            other=other,
+            operation=lambda x, y: x * y,
+            axis=axis,
+            index_cols=index_cols,
+        )
 
     @overload
     def _df_norm(
@@ -426,15 +325,58 @@
         include_cols: bool = False,
     ) -> pl.Series | pl.DataFrame:
         srs = (
-            df.with_columns(pl.col("*").pow(2).alias("*"))
-            .sum_horizontal()
-            .sqrt()
-            .rename(srs_name)
+            df.with_columns(pl.col("*").pow(2)).sum_horizontal().sqrt().rename(srs_name)
         )
         if include_cols:
-            return df.with_columns(srs_name=srs)
+            return df.with_columns(srs)
         return srs
 
+    def _df_operation(
+        self,
+        df: pl.DataFrame,
+        other: pl.DataFrame | Sequence[float | int],
+        operation: Callable[[pl.Expr, pl.Expr], pl.Expr],
+        axis: Literal["index", "columns"] = "index",
+        index_cols: str | list[str] | None = None,
+    ) -> pl.DataFrame:
+        if isinstance(other, pl.DataFrame):
+            if index_cols is not None:
+                op_df = df.join(other, how="left", on=index_cols, suffix="_op")
+            else:
+                assert len(df) == len(
+                    other
+                ), "DataFrames must have the same length if index_cols is not specified"
+                index_cols = []
+                other = other.rename(lambda col: col + "_op")
+                op_df = pl.concat([df, other], how="horizontal")
+            return op_df.with_columns(
+                operation(pl.col(col), pl.col(f"{col}_op")).alias(col)
+                for col in df.columns
+                if col not in index_cols
+            ).select(df.columns)
+        elif isinstance(
+            other, (Sequence, pl.Series)
+        ):  # Currently, pl.Series is not a Sequence
+            if axis == "index":
+                assert len(df) == len(
+                    other
+                ), "Sequence must have the same length as df if axis is 'index'"
+                other_series = pl.Series("operand", other)
+                return df.with_columns(
+                    operation(pl.col(col), other_series).alias(col)
+                    for col in df.columns
+                )
+            else:
+                assert (
+                    len(df.columns) == len(other)
+                ), "Sequence must have the same length as df.columns if axis is 'columns'"
+                return df.with_columns(
+                    operation(pl.col(col), pl.lit(other[i])).alias(col)
+                    for i, col in enumerate(df.columns)
+                )
+        else:
+            raise ValueError("other must be a DataFrame or a Sequence")
+
     def _df_rename_columns(
         self, df: pl.DataFrame, old_columns: list[str], new_columns: list[str]
     ) -> pl.DataFrame:
@@ -461,7 +403,11 @@
         seed: int | None = None,
     ) -> pl.DataFrame:
         return df.sample(
-            n=n, frac=frac, replace=with_replacement, shuffle=shuffle, seed=seed
+            n=n,
+            fraction=frac,
+            with_replacement=with_replacement,
+            shuffle=shuffle,
+            seed=seed,
         )
 
     def _df_set_index(
@@ -472,6 +418,7 @@
     ) -> pl.DataFrame:
         if new_index is None:
             return df
+        return df.with_columns(**{index_name: new_index})
 
     def _df_with_columns(
         self,
@@ -505,17 +452,21 @@ def _srs_constructor(
         self,
         data: Sequence[Any] | None = None,
         name: str | None = None,
-        dtype: Any | None = None,
+        dtype: str | None = None,
         index: Sequence[Any] | None = None,
     ) -> pl.Series:
-        return pl.Series(name=name, values=data, dtype=self._dtypes_mapping[dtype])
+        if dtype is not None:
+            dtype = self._dtypes_mapping[dtype]
+        return pl.Series(name=name, values=data, dtype=dtype)
 
     def _srs_contains(
         self,
-        srs: Sequence[Any],
+        srs: Collection[Any],
         values: Any | Sequence[Any],
     ) -> pl.Series:
-        return pl.Series(values, index=values).is_in(srs)
+        if not isinstance(values, Collection):
+            values = [values]
+        return pl.Series(values).is_in(srs)
 
     def _srs_range(
         self,
diff --git a/tests/polars/test_mixin_polars.py b/tests/polars/test_mixin_polars.py
new file mode 100644
index 00000000..2db79759
--- /dev/null
+++ b/tests/polars/test_mixin_polars.py
@@ -0,0 +1,627 @@
+import numpy as np
+import pandas as pd
+import polars as pl
+import pytest
+import typeguard as tg
+
+from mesa_frames.concrete.polars.mixin import PolarsMixin
+
+
+@tg.typechecked
+class TestPolarsMixin:
+    @pytest.fixture
+    def mixin(self):
+        return PolarsMixin()
+
+    @pytest.fixture
+    def df_0(self):
+        return pl.DataFrame(
+            {
+                "unique_id": ["x", "y", "z"],
+                "A": [1, 2, 3],
+                "B": ["a", "b", "c"],
+                "C": [True, False, True],
+                "D": [1, 2, 3],
+            },
+        )
+
+    @pytest.fixture
+    def df_1(self):
+        return pl.DataFrame(
+            {
+                "unique_id": ["z", "a", "b"],
+                "A": [4, 5, 6],
+                "B": ["d", "e", "f"],
+                "C": [False, True, False],
+                "E": [1, 2, 3],
+            },
+        )
+
+    def test_df_add(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame):
+        # Test adding a DataFrame and a sequence element-wise along the rows (axis='index')
+        result = mixin._df_add(df_0[["A", "D"]], df_1["A"], axis="index")
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [5, 7, 9]
+        assert result["D"].to_list() == [5, 7, 9]
+
+        # Test adding a DataFrame and a sequence element-wise along the column (axis='columns')
+        result = mixin._df_add(df_0[["A", "D"]], [1, 2], axis="columns")
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [2, 3, 4]
+        assert result["D"].to_list() == [3, 4, 5]
+
+        # Test adding DataFrames with index-column alignment
+        df_1 = df_1.with_columns(D=pl.col("E"))
+        result = mixin._df_add(
+            df_0[["unique_id", "A", "D"]],
+            df_1[["unique_id", "A", "D"]],
+            axis="index",
+            index_cols="unique_id",
+        )
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [None, None, 7]
+        assert result["D"].to_list() == [None, None, 4]
+
+    def test_df_all(self, mixin: PolarsMixin):
+        df = pl.DataFrame(
+            {
+                "A": [True, False, True],
+                "B": [True, True, True],
+            }
+        )
+
+        # Test with axis='columns'
+        result = mixin._df_all(df["A", "B"], axis="columns")
+        assert isinstance(result, pl.Series)
+        assert result.name == "all"
+        assert result.to_list() == [True, False, True]
+
+        # Test with axis='index'
+        result = mixin._df_all(df["A", "B"], axis="index")
+        assert isinstance(result, pl.Series)
+        assert result.name == "all"
+        assert result.to_list() == [False, True]
+
+    def test_df_column_names(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        cols = mixin._df_column_names(df_0)
+        assert isinstance(cols, list)
+        assert all(isinstance(c, str) for c in cols)
+        assert set(mixin._df_column_names(df_0)) == {"unique_id", "A", "B", "C", "D"}
+
+    def test_df_combine_first(
+        self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame
+    ):
+        # Test with df_0 and df_1
+        result = mixin._df_combine_first(df_0, df_1, "unique_id")
+        result = result.sort("A")
+        assert isinstance(result, pl.DataFrame)
+        assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"}
+        assert result["unique_id"].to_list() == ["x", "y", "z", "a", "b"]
+        assert result["A"].to_list() == [1, 2, 3, 5, 6]
+        assert result["B"].to_list() == ["a", "b", "c", "e", "f"]
+        assert result["C"].to_list() == [True, False, True, True, False]
+        assert result["D"].to_list() == [1, 2, 3, None, None]
+        assert result["E"].to_list() == [None, None, 1, 2, 3]
+
+        # Test with df_1 and df_0
+        result = mixin._df_combine_first(df_1, df_0, "unique_id")
+        result = result.sort("E", nulls_last=True)
+        assert isinstance(result, pl.DataFrame)
+        assert set(result.columns) == {"unique_id", "A", "B", "C", "D", "E"}
+        assert result["unique_id"].to_list() == ["z", "a", "b", "x", "y"]
+        assert result["A"].to_list() == [4, 5, 6, 1, 2]
+        assert result["B"].to_list() == ["d", "e", "f", "a", "b"]
+        assert result["C"].to_list() == [False, True, False, True, False]
+        assert result["D"].to_list() == [3, None, None, 1, 2]
+        assert result["E"].to_list() == [1, 2, 3, None, None]
+
+    def test_df_concat(
+        self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame
+    ):
+        ### Test vertical concatenation
+        ## With DataFrames
+        for ignore_index in [False, True]:
+            vertical = mixin._df_concat(
+                [df_0, df_1], how="vertical", ignore_index=ignore_index
+            )
+            assert isinstance(vertical, pl.DataFrame)
+            assert vertical.columns == ["unique_id", "A", "B", "C", "D", "E"]
+            assert len(vertical) == 6
+            assert vertical["unique_id"].to_list() == ["x", "y", "z", "z", "a", "b"]
+            assert vertical["A"].to_list() == [1, 2, 3, 4, 5, 6]
+            assert vertical["B"].to_list() == ["a", "b", "c", "d", "e", "f"]
+            assert vertical["C"].to_list() == [True, False, True, False, True, False]
+            assert vertical["D"].to_list() == [1, 2, 3, None, None, None]
+            assert vertical["E"].to_list() == [None, None, None, 1, 2, 3]
+
+        ## With Series
+        for ignore_index in [True, False]:
+            vertical = mixin._df_concat(
+                [df_0["A"], df_1["A"]], how="vertical", ignore_index=ignore_index
+            )
+            assert isinstance(vertical, pl.Series)
+            assert len(vertical) == 6
+            assert vertical.to_list() == [1, 2, 3, 4, 5, 6]
+            assert vertical.name == "A"
+
+        ## Test horizontal concatenation
+        ## With DataFrames
+        # Error With same column names
+        with pytest.raises(pl.exceptions.DuplicateError):
+            mixin._df_concat([df_0, df_1], how="horizontal")
+        # With ignore_index = False
+        df_1 = df_1.rename(lambda c: f"{c}_1")
+        horizontal = mixin._df_concat([df_0, df_1], how="horizontal")
+        assert isinstance(horizontal, pl.DataFrame)
+        assert horizontal.columns == [
+            "unique_id",
+            "A",
+            "B",
+            "C",
+            "D",
+            "unique_id_1",
+            "A_1",
+            "B_1",
+            "C_1",
+            "E_1",
+        ]
+        assert len(horizontal) == 3
+        assert horizontal["unique_id"].to_list() == ["x", "y", "z"]
+        assert horizontal["A"].to_list() == [1, 2, 3]
+        assert horizontal["B"].to_list() == ["a", "b", "c"]
+        assert horizontal["C"].to_list() == [True, False, True]
+        assert horizontal["D"].to_list() == [1, 2, 3]
+        assert horizontal["unique_id_1"].to_list() == ["z", "a", "b"]
+        assert horizontal["A_1"].to_list() == [4, 5, 6]
+        assert horizontal["B_1"].to_list() == ["d", "e", "f"]
+        assert horizontal["C_1"].to_list() == [False, True, False]
+        assert horizontal["E_1"].to_list() == [1, 2, 3]
+
+        # With ignore_index = True
+        horizontal_ignore_index = mixin._df_concat(
+            [df_0, df_1],
+            how="horizontal",
+            ignore_index=True,
+        )
+        assert isinstance(horizontal_ignore_index, pl.DataFrame)
+        assert horizontal_ignore_index.columns == [
+            "0",
+            "1",
+            "2",
+            "3",
+            "4",
+            "5",
+            "6",
+            "7",
+            "8",
+            "9",
+        ]
+        assert len(horizontal_ignore_index) == 3
+        assert horizontal_ignore_index["0"].to_list() == ["x", "y", "z"]
+        assert horizontal_ignore_index["1"].to_list() == [1, 2, 3]
+        assert horizontal_ignore_index["2"].to_list() == ["a", "b", "c"]
+        assert horizontal_ignore_index["3"].to_list() == [True, False, True]
+        assert horizontal_ignore_index["4"].to_list() == [1, 2, 3]
+        assert horizontal_ignore_index["5"].to_list() == ["z", "a", "b"]
+        assert horizontal_ignore_index["6"].to_list() == [4, 5, 6]
+        assert horizontal_ignore_index["7"].to_list() == ["d", "e", "f"]
+        assert horizontal_ignore_index["8"].to_list() == [False, True, False]
+        assert horizontal_ignore_index["9"].to_list() == [1, 2, 3]
+
+        ## With Series
+        # With ignore_index = False
+        horizontal = mixin._df_concat(
+            [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=False
+        )
+        assert isinstance(horizontal, pl.DataFrame)
+        assert horizontal.columns == ["A", "B_1"]
+        assert len(horizontal) == 3
+        assert horizontal["A"].to_list() == [1, 2, 3]
+        assert horizontal["B_1"].to_list() == ["d", "e", "f"]
+
+        # With ignore_index = True
+        horizontal = mixin._df_concat(
+            [df_0["A"], df_1["B_1"]], how="horizontal", ignore_index=True
+        )
+        assert isinstance(horizontal, pl.DataFrame)
+        assert horizontal.columns == ["0", "1"]
+        assert len(horizontal) == 3
+        assert horizontal["0"].to_list() == [1, 2, 3]
+        assert horizontal["1"].to_list() == ["d", "e", "f"]
+
+    def test_df_constructor(self, mixin: PolarsMixin):
+        # Test with dictionary
+        data = {"num": [1, 2, 3], "letter": ["a", "b", "c"]}
+        df = mixin._df_constructor(data)
+        assert isinstance(df, pl.DataFrame)
+        assert list(df.columns) == ["num", "letter"]
+        assert df["num"].to_list() == [1, 2, 3]
+        assert df["letter"].to_list() == ["a", "b", "c"]
+
+        # Test with list of lists
+        data = [[1, "a"], [2, "b"], [3, "c"]]
+        df = mixin._df_constructor(
+            data, columns=["num", "letter"], dtypes={"num": "int64"}
+        )
+        assert isinstance(df, pl.DataFrame)
+        assert list(df.columns) == ["num", "letter"]
+        assert df["num"].dtype == pl.Int64
+        assert df["num"].to_list() == [1, 2, 3]
+        assert df["letter"].to_list() == ["a", "b", "c"]
+
+        # Test with pandas DataFrame
+        data = pd.DataFrame({"num": [1, 2, 3], "letter": ["a", "b", "c"]})
+        df = mixin._df_constructor(data)
+        assert isinstance(df, pl.DataFrame)
+        assert list(df.columns) == ["num", "letter"]
+        assert df["num"].to_list() == [1, 2, 3]
+        assert df["letter"].to_list() == ["a", "b", "c"]
+
+        # Test with index > 1 and 1 value
+        data = {"a": 5}
+        df = mixin._df_constructor(
+            data, index=pl.int_range(5, eager=True), index_cols="index"
+        )
+        assert isinstance(df, pl.DataFrame)
+        assert list(df.columns) == ["index", "a"]
+        assert df["a"].to_list() == [5, 5, 5, 5, 5]
+        assert df["index"].to_list() == [0, 1, 2, 3, 4]
+
+    def test_df_contains(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # Test with list
+        result = mixin._df_contains(df_0, "A", [5, 2, 3])
+        assert isinstance(result, pl.Series)
+        assert result.name == "contains"
+        assert result.to_list() == [False, True, True]
+
+    def test_df_div(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame):
+        # Test dividing the DataFrame by a sequence element-wise along the rows (axis='index')
+        result = mixin._df_div(df_0[["A", "D"]], df_1["A"], axis="index")
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [0.25, 0.4, 0.5]
+        assert result["D"].to_list() == [0.25, 0.4, 0.5]
+
+        # Test dividing the DataFrame by a sequence element-wise along the columns (axis='columns')
+        result = mixin._df_div(df_0[["A", "D"]], [1, 2], axis="columns")
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [1, 2, 3]
+        assert result["D"].to_list() == [0.5, 1, 1.5]
+
+        # Test dividing DataFrames with index-column alignment
+        df_1 = df_1.with_columns(D=pl.col("E"))
+        result = mixin._df_div(
+            df_0[["unique_id", "A", "D"]],
+            df_1[["unique_id", "A", "D"]],
+            axis="index",
+            index_cols="unique_id",
+        )
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [None, None, 0.75]
+        assert result["D"].to_list() == [None, None, 3]
+
+    def test_df_drop_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # Test with str
+        dropped = mixin._df_drop_columns(df_0, "A")
+        assert isinstance(dropped, pl.DataFrame)
+        assert dropped.columns == ["unique_id", "B", "C", "D"]
+        # Test with list
+        dropped = mixin._df_drop_columns(df_0, ["A", "C"])
+        assert dropped.columns == ["unique_id", "B", "D"]
+
+    def test_df_drop_duplicates(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        new_df = pl.concat([df_0, df_0], how="vertical")
+        assert len(new_df) == 6
+
+        # Test with all columns
+        dropped = mixin._df_drop_duplicates(new_df)
+        assert isinstance(dropped, pl.DataFrame)
+        assert len(dropped) == 3
+        assert dropped.columns == ["unique_id", "A", "B", "C", "D"]
+
+        # Test with subset (str)
+        other_df = pl.DataFrame(
+            {
+                "unique_id": ["x", "y", "z"],
+                "A": [1, 2, 3],
+                "B": ["d", "e", "f"],
+                "C": [True, True, False],
+                "D": [1, 2, 3],
+            },
+        )
+        new_df = pl.concat([df_0, other_df], how="vertical")
+        dropped = mixin._df_drop_duplicates(new_df, subset="unique_id")
+        assert isinstance(dropped, pl.DataFrame)
+        assert len(dropped) == 3
+
+        # Test with subset (list)
+        dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"])
+        assert isinstance(dropped, pl.DataFrame)
+        assert len(dropped) == 5
+        assert dropped.columns == ["unique_id", "A", "B", "C", "D"]
+        assert dropped["B"].to_list() == ["a", "b", "c", "e", "f"]
+
+        # Test with subset (list) and keep='last'
+        dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep="last")
+        assert isinstance(dropped, pl.DataFrame)
+        assert len(dropped) == 5
+        assert dropped.columns == ["unique_id", "A", "B", "C", "D"]
+        assert dropped["B"].to_list() == ["d", "b", "c", "e", "f"]
+
+        # Test with subset (list) and keep=False
+        dropped = mixin._df_drop_duplicates(new_df, subset=["A", "C"], keep=False)
+        assert isinstance(dropped, pl.DataFrame)
+        assert len(dropped) == 4
+        assert dropped.columns == ["unique_id", "A", "B", "C", "D"]
+        assert dropped["B"].to_list() == ["b", "c", "e", "f"]
+
+    def test_df_get_bool_mask(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # Test with pl.Series[bool]
+        mask = mixin._df_get_bool_mask(df_0, "A", pl.Series([True, False, True]))
+        assert mask.to_list() == [True, False, True]
+
+        # Test with DataFrame
+        mask_df = pl.DataFrame({"A": [1, 3]})
+        mask = mixin._df_get_bool_mask(df_0, "A", mask_df)
+        assert mask.to_list() == [True, False, True]
+
+        # Test with single value
+        mask = mixin._df_get_bool_mask(df_0, "A", 1)
+        assert mask.to_list() == [True, False, False]
+
+        # Test with list of values
+        mask = mixin._df_get_bool_mask(df_0, "A", [1, 3])
+        assert mask.to_list() == [True, False, True]
+
+        # Test with negate=True
+        mask = mixin._df_get_bool_mask(df_0, "A", [1, 3], negate=True)
+        assert mask.to_list() == [False, True, False]
+
+    def test_df_get_masked_df(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # Test with pl.Series[bool]
+        masked_df = mixin._df_get_masked_df(df_0, "A", pl.Series([True, False, True]))
+        assert masked_df["A"].to_list() == [1, 3]
+        assert masked_df["unique_id"].to_list() == ["x", "z"]
+
+        # Test with DataFrame
+        mask_df = pl.DataFrame({"A": [1, 3]})
+        masked_df = mixin._df_get_masked_df(df_0, "A", mask_df)
+        assert masked_df["A"].to_list() == [1, 3]
+        assert masked_df["unique_id"].to_list() == ["x", "z"]
+
+        # Test with single value
+        masked_df = mixin._df_get_masked_df(df_0, "A", 1)
+        assert masked_df["A"].to_list() == [1]
+        assert masked_df["unique_id"].to_list() == ["x"]
+
+        # Test with list of values
+        masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3])
+        assert masked_df["A"].to_list() == [1, 3]
+        assert masked_df["unique_id"].to_list() == ["x", "z"]
+
+        # Test with columns
+        masked_df = mixin._df_get_masked_df(df_0, "A", [1, 3], columns=["B"])
+        assert list(masked_df.columns) == ["B"]
+        assert masked_df["B"].to_list() == ["a", "c"]
+
+        # Test with negate=True
+        masked = mixin._df_get_masked_df(df_0, "A", [1, 3], negate=True)
+        assert len(masked) == 1
+
+    def test_df_groupby_cumcount(self, df_0: pl.DataFrame, mixin: PolarsMixin):
+        result = mixin._df_groupby_cumcount(df_0, "C")
+        assert result.to_list() == [1, 1, 2]
+
+    def test_df_iterator(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        iterator = mixin._df_iterator(df_0)
+        first_item = next(iterator)
+        assert first_item == {"unique_id": "x", "A": 1, "B": "a", "C": True, "D": 1}
+
+    def test_df_join(self, mixin: PolarsMixin):
+        left = pl.DataFrame({"A": [1, 2], "B": ["a", "b"]})
+        right = pl.DataFrame({"A": [1, 3], "C": ["x", "y"]})
+
+        # Test with 'on' (left join)
+        joined = mixin._df_join(left, right, on="A")
+        assert set(joined.columns) == {"A", "B", "C"}
+        assert joined["A"].to_list() == [1, 2]
+
+        # Test with 'left_on' and 'right_on' (left join)
+        right_1 = pl.DataFrame({"D": [1, 2], "C": ["x", "y"]})
+        joined = mixin._df_join(left, right_1, left_on="A", right_on="D")
+        assert set(joined.columns) == {"A", "B", "C"}
+        assert joined["A"].to_list() == [1, 2]
+
+        # Test with 'right' join
+        joined = mixin._df_join(left, right, on="A", how="right")
+        assert set(joined.columns) == {"A", "B", "C"}
+        assert joined["A"].to_list() == [1, 3]
+
+        # Test with 'inner' join
+        joined = mixin._df_join(left, right, on="A", how="inner")
+        assert set(joined.columns) == {"A", "B", "C"}
+        assert joined["A"].to_list() == [1]
+
+        # Test with 'outer' join
+        joined = mixin._df_join(left, right, on="A", how="outer")
+        assert set(joined.columns) == {"A", "B", "A_right", "C"}
+        assert joined["A"].to_list() == [1, None, 2]
+        assert joined["A_right"].to_list() == [1, 3, None]
+
+        # Test with 'cross' join
+        joined = mixin._df_join(left, right, how="cross")
+        assert set(joined.columns) == {"A", "B", "A_right", "C"}
+        assert len(joined) == 4
+        assert joined.row(0) == (1, "a", 1, "x")
+        assert joined.row(1) == (1, "a", 3, "y")
+        assert joined.row(2) == (2, "b", 1, "x")
+        assert joined.row(3) == (2, "b", 3, "y")
+
+        # Test with different 'suffix'
+        joined = mixin._df_join(left, right, suffix="_r", how="cross")
+        assert set(joined.columns) == {"A", "B", "A_r", "C"}
+        assert len(joined) == 4
+        assert joined.row(0) == (1, "a", 1, "x")
+        assert joined.row(1) == (1, "a", 3, "y")
+        assert joined.row(2) == (2, "b", 1, "x")
+        assert joined.row(3) == (2, "b", 3, "y")
+
+    def test_df_mul(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame):
+        # Test multiplying the DataFrame by a sequence element-wise along the rows (axis='index')
+        result = mixin._df_mul(df_0[["A", "D"]], df_1["A"], axis="index")
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [4, 10, 18]
+        assert result["D"].to_list() == [4, 10, 18]
+
+        # Test multiplying the DataFrame by a sequence element-wise along the columns (axis='columns')
+        result = mixin._df_mul(df_0[["A", "D"]], [1, 2], axis="columns")
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [1, 2, 3]
+        assert result["D"].to_list() == [2, 4, 6]
+
+        # Test multiplying DataFrames with index-column alignment
+        df_1 = df_1.with_columns(D=pl.col("E"))
+        result = mixin._df_mul(
+            df_0[["unique_id", "A", "D"]],
+            df_1[["unique_id", "A", "D"]],
+            axis="index",
+            index_cols="unique_id",
+        )
+        assert isinstance(result, pl.DataFrame)
+        assert result["A"].to_list() == [None, None, 12]
+        assert result["D"].to_list() == [None, None, 3]
+
+    def test_df_norm(self, mixin: PolarsMixin):
+        df = pl.DataFrame({"A": [3, 4], "B": [4, 3]})
+        # If include_cols = False
+        norm = mixin._df_norm(df)
+        assert isinstance(norm, pl.Series)
+        assert len(norm) == 2
+        assert norm[0] == 5
+        assert norm[1] == 5
+
+        # If include_cols = True
+        norm = mixin._df_norm(df, include_cols=True)
+        assert isinstance(norm, pl.DataFrame)
+        assert len(norm) == 2
+        assert norm.columns == ["A", "B", "norm"]
+        assert norm.row(0, named=True)["norm"] == 5
+        assert norm.row(1, named=True)["norm"] == 5
+
+    def test_df_rename_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        renamed = mixin._df_rename_columns(df_0, ["A", "B"], ["X", "Y"])
+        assert renamed.columns == ["unique_id", "X", "Y", "C", "D"]
+
+    def test_df_reset_index(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # with drop = False
+        new_df = mixin._df_reset_index(df_0)
+        assert mixin._df_all(new_df == df_0).all()
+
+        # with drop = True
+        new_df = mixin._df_reset_index(df_0, index_cols="unique_id", drop=True)
+        assert new_df.columns == ["A", "B", "C", "D"]
+        assert len(new_df) == len(df_0)
+        for col in new_df.columns:
+            assert (new_df[col] == df_0[col]).all()
+
+    def test_df_remove(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # Test with list
+        removed = mixin._df_remove(df_0, [1, 3], "A")
+        assert len(removed) == 1
+        assert removed["unique_id"].to_list() == ["y"]
+
+    def test_df_sample(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # Test with n
+        sampled = mixin._df_sample(df_0, n=2, seed=42)
+        assert len(sampled) == 2
+
+        # Test with frac
+        sampled = mixin._df_sample(df_0, frac=2 / 3, seed=42)
+        assert len(sampled) == 2
+
+        # Test with replacement
+        sampled = mixin._df_sample(df_0, n=4, with_replacement=True, seed=42)
+        assert len(sampled) == 4
+        assert sampled.n_unique() < 4
+
+    def test_df_set_index(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        index = pl.int_range(len(df_0), eager=True)
+        new_df = mixin._df_set_index(df_0, "index", index)
+        assert (new_df["index"] == index).all()
+
+    def test_df_with_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame):
+        # Test with list
+        new_df = mixin._df_with_columns(
+            df_0,
+            data=[[4, "d"], [5, "e"], [6, "f"]],
+            new_columns=["D", "E"],
+        )
+        assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"]
+        assert new_df["D"].to_list() == [4, 5, 6]
+        assert new_df["E"].to_list() == ["d", "e", "f"]
+
+        # Test with pl.DataFrame
+        second_df = pl.DataFrame({"D": [4, 5, 6], "E": ["d", "e", "f"]})
+        new_df = mixin._df_with_columns(df_0, second_df)
+        assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"]
+        assert new_df["D"].to_list() == [4, 5, 6]
+        assert new_df["E"].to_list() == ["d", "e", "f"]
+
+        # Test with dictionary
+        new_df = mixin._df_with_columns(
+            df_0, data={"D": [4, 5, 6], "E": ["d", "e", "f"]}
+        )
+        assert list(new_df.columns) == ["unique_id", "A", "B", "C", "D", "E"]
+        assert new_df["D"].to_list() == [4, 5, 6]
+        assert new_df["E"].to_list() == ["d", "e", "f"]
+
+        # Test with numpy array
+        new_df = mixin._df_with_columns(df_0, data=np.array([4, 5, 6]), new_columns="D")
+        assert "D" in new_df.columns
+        assert new_df["D"].to_list() == [4, 5, 6]
+
+        # Test with pl.Series
+        new_df = mixin._df_with_columns(df_0, pl.Series([4, 5, 6]), new_columns="D")
+        assert "D" in new_df.columns
+        assert new_df["D"].to_list() == [4, 5, 6]
+
+    def test_srs_constructor(self, mixin: PolarsMixin):
+        # Test with list
+        srs = mixin._srs_constructor([1, 2, 3], name="test", dtype="int64")
+        assert srs.name == "test"
+        assert srs.dtype == pl.Int64
+
+        # Test with numpy array
+        srs = mixin._srs_constructor(np.array([1, 2, 3]), name="test")
+        assert srs.name == "test"
+        assert len(srs) == 3
+
+    def test_srs_contains(self, mixin: PolarsMixin):
+        srs = [1, 2, 3, 4, 5]
+
+        # Test with single value
+        result = mixin._srs_contains(srs, 3)
+        assert result.to_list() == [True]
+
+        # Test with list
+        result = mixin._srs_contains(srs, [1, 3, 6])
+        assert result.to_list() == [True, True, False]
+
+        # Test with numpy array
+        result = mixin._srs_contains(srs, np.array([1, 3, 6]))
+        assert result.to_list() == [True, True, False]
+
+    def test_srs_range(self, mixin: PolarsMixin):
+        # Test with default step
+        srs = mixin._srs_range("test", 0, 5)
+        assert srs.name == "test"
+        assert srs.to_list() == [0, 1, 2, 3, 4]
+
+        # Test with custom step
+        srs = mixin._srs_range("test", 0, 10, step=2)
+        assert srs.to_list() == [0, 2, 4, 6, 8]
+
+    def test_srs_to_df(self, mixin: PolarsMixin):
+        srs = pl.Series("test", [1, 2, 3])
+        df = mixin._srs_to_df(srs)
+        assert isinstance(df, pl.DataFrame)
+        assert df["test"].to_list() == [1, 2, 3]
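Usage sketch (illustrative only, not part of the patch): the arithmetic helpers _df_add, _df_div and _df_mul in this diff now all delegate to the new _df_operation method, which aligns rows on index_cols via a left join (suffix "_op") or broadcasts a sequence along the chosen axis. A minimal example, assuming PolarsMixin can be instantiated directly as in the new test module's fixture; the column names below are made up for illustration:

    import polars as pl

    from mesa_frames.concrete.polars.mixin import PolarsMixin

    mixin = PolarsMixin()
    df = pl.DataFrame({"unique_id": ["x", "y"], "A": [1, 2], "D": [3, 4]})
    other = pl.DataFrame({"unique_id": ["y", "x"], "A": [10, 20], "D": [30, 40]})

    # Rows are matched on the index column, then the x + y lambda is applied per column.
    added = mixin._df_add(df, other, axis="index", index_cols="unique_id")

    # A plain sequence is broadcast across columns when axis="columns".
    scaled = mixin._df_mul(df[["A", "D"]], [2, 10], axis="columns")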