Skip to content
4 changes: 2 additions & 2 deletions mesa_frames/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from mesa_frames.concrete.agents import AgentsDF
from mesa_frames.concrete.agentset_pandas import AgentSetPandas
from mesa_frames.concrete.agentset_polars import AgentSetPolars
from mesa_frames.concrete.pandas.agentset import AgentSetPandas
from mesa_frames.concrete.polars.agentset import AgentSetPolars
from mesa_frames.concrete.model import ModelDF

__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion mesa_frames/abstract/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing_extensions import Any, Self, overload

from mesa_frames.abstract.mixin import CopyMixin
from mesa_frames.types import BoolSeries, DataFrame, IdsLike, Index, MaskLike, Series
from mesa_frames.types_ import BoolSeries, DataFrame, IdsLike, Index, MaskLike, Series

if TYPE_CHECKING:
from mesa_frames.concrete.agents import AgentSetDF
Expand Down
82 changes: 81 additions & 1 deletion mesa_frames/abstract/mixin.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from abc import ABC, abstractmethod
from copy import copy, deepcopy

from typing_extensions import Self
from typing_extensions import Any, Self
from typing import Literal
from collections.abc import Collection, Iterator, Sequence

from mesa_frames.types_ import BoolSeries, DataFrame, MaskLike, Series


class CopyMixin(ABC):
Expand Down Expand Up @@ -142,3 +146,79 @@ def __deepcopy__(self, memo: dict) -> Self:
A deep copy of the AgentContainer.
"""
return self.copy(deep=True, memo=memo)


class DataFrameMixin(ABC):
@abstractmethod
def _df_add_columns(
self, original_df: DataFrame, new_columns: list[str], data: Any
) -> DataFrame: ...

@abstractmethod
def _df_combine_first(
self, original_df: DataFrame, new_df: DataFrame, index_cols: list[str]
) -> DataFrame: ...

@abstractmethod
def _df_concat(
self,
dfs: Collection[DataFrame],
how: Literal["horizontal"] | Literal["vertical"] = "vertical",
ignore_index: bool = False,
) -> DataFrame: ...

@abstractmethod
def _df_constructor(
self,
data: Sequence[Sequence] | dict[str | Any] | None = None,
columns: list[str] | None = None,
index_col: str | list[str] | None = None,
dtypes: dict[str, Any] | None = None,
) -> DataFrame: ...

@abstractmethod
def _df_get_bool_mask(
self,
df: DataFrame,
index_col: str,
mask: MaskLike | None = None,
negate: bool = False,
) -> BoolSeries: ...

@abstractmethod
def _df_get_masked_df(
self,
df: DataFrame,
index_col: str,
mask: MaskLike | None = None,
columns: list[str] | None = None,
negate: bool = False,
) -> DataFrame: ...

@abstractmethod
def _df_iterator(self, df: DataFrame) -> Iterator[dict[str, Any]]: ...

@abstractmethod
def _df_remove(
self, df: DataFrame, ids: Sequence[Any], index_col: str | None = None
) -> DataFrame: ...

@abstractmethod
def _df_sample(
self,
df: DataFrame,
n: int | None = None,
frac: float | None = None,
with_replacement: bool = False,
shuffle: bool = False,
seed: int | None = None,
) -> DataFrame: ...

@abstractmethod
def _srs_constructor(
self,
data: Sequence[Any] | None = None,
name: str | None = None,
dtype: Any | None = None,
index: Sequence[Any] | None = None,
) -> Series: ...
6 changes: 2 additions & 4 deletions mesa_frames/concrete/agents.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from collections import defaultdict
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from typing import Literal, cast
from typing import TYPE_CHECKING, Literal, cast

import polars as pl
from typing_extensions import Any, Self, overload

from typing import TYPE_CHECKING

from mesa_frames.abstract.agents import AgentContainer, AgentSetDF
from mesa_frames.types import (
from mesa_frames.types_ import (
AgnosticMask,
BoolSeries,
DataFrame,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from typing_extensions import Any, Self, overload

from mesa_frames.abstract.agents import AgentSetDF
from mesa_frames.concrete.agentset_polars import AgentSetPolars
from mesa_frames.types import PandasIdsLike, PandasMaskLike
from mesa_frames.concrete.pandas.mixin import PandasMixin
from mesa_frames.concrete.polars.agentset import AgentSetPolars
from mesa_frames.types_ import PandasIdsLike, PandasMaskLike

if TYPE_CHECKING:
from mesa_frames.concrete.model import ModelDF


class AgentSetPandas(AgentSetDF):
class AgentSetPandas(AgentSetDF, PandasMixin):
_agents: pd.DataFrame
_mask: pd.Series
_copy_with_method: dict[str, tuple[str, list[str]]] = {
Expand Down
121 changes: 121 additions & 0 deletions mesa_frames/concrete/pandas/mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pandas as pd
from typing_extensions import Any
from typing import Literal
from collections.abc import Collection, Iterator, Sequence

from mesa_frames.abstract.mixin import DataFrameMixin
from mesa_frames.types_ import PandasMaskLike


class PandasMixin(DataFrameMixin):
def _df_add_columns(
self, original_df: pd.DataFrame, new_columns: list[str], data: Any
) -> pd.DataFrame:
original_df[new_columns] = data
return original_df

def _df_combine_first(
self, original_df: pd.DataFrame, new_df: pd.DataFrame, index_cols: list[str]
) -> pd.DataFrame:
return original_df.combine_first(new_df)

def _df_concat(
self,
dfs: Collection[pd.DataFrame],
how: Literal["horizontal"] | Literal["vertical"] = "vertical",
ignore_index: bool = False,
) -> pd.DataFrame:
return pd.concat(
dfs, axis=0 if how == "vertical" else 1, ignore_index=ignore_index
)

def _df_constructor(
self,
data: Sequence[Sequence] | dict[str | Any] | None = None,
columns: list[str] | None = None,
index_col: str | list[str] | None = None,
dtypes: dict[str, Any] | None = None,
) -> pd.DataFrame:
df = pd.DataFrame(data=data, columns=columns).astype(dtypes)
if index_col:
df.set_index(index_col)
return df

def _df_get_bool_mask(
self,
df: pd.DataFrame,
index_col: str,
mask: PandasMaskLike = None,
negate: bool = False,
) -> pd.Series:
if isinstance(mask, pd.Series) and mask.dtype == bool and len(mask) == len(df):
result = mask
elif isinstance(mask, pd.DataFrame):
if mask.index.name == index_col:
result = pd.Series(df.index.isin(mask.index), index=df.index)
elif index_col in mask.columns:
result = pd.Series(df.index.isin(mask[index_col]), index=df.index)
else:
raise ValueError(
f"A DataFrame mask must have a column/index with name {index_col}"
)
elif mask is None or mask == "all":
result = pd.Series(True, index=df.index)
elif isinstance(mask, Sequence):
result = pd.Series(df.index.isin(mask), index=df.index)
else:
result = pd.Series(df.index.isin([mask]), index=df.index)

if negate:
result = ~result

return result

def _df_get_masked_df(
self,
df: pd.DataFrame,
index_col: str,
mask: PandasMaskLike | None = None,
columns: list[str] | None = None,
negate: bool = False,
) -> pd.DataFrame:
b_mask = self._df_get_bool_mask(df, index_col, mask, negate)
if columns:
return df.loc[b_mask, columns]
return df.loc[b_mask]

def _df_iterator(self, df: pd.DataFrame) -> Iterator[dict[str, Any]]:
for index, row in df.iterrows():
row_dict = row.to_dict()
row_dict["unique_id"] = index
yield row_dict

def _df_remove(
self,
df: pd.DataFrame,
ids: Sequence[Any],
index_col: str | None = None,
) -> pd.DataFrame:
return df[~df.index.isin(ids)]

def _df_sample(
self,
df: pd.DataFrame,
n: int | None = None,
frac: float | None = None,
with_replacement: bool = False,
shuffle: bool = False,
seed: int | None = None,
) -> pd.DataFrame:
return df.sample(
n=n, frac=frac, replace=with_replacement, shuffle=shuffle, random_state=seed
)

def _srs_constructor(
self,
data: Sequence[Sequence] | None = None,
name: str | None = None,
dtype: Any | None = None,
index: Sequence[Any] | None = None,
) -> pd.Series:
return pd.Series(data, name=name, dtype=dtype, index=index)
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from typing_extensions import Any, Self, overload

from mesa_frames.concrete.agents import AgentSetDF
from mesa_frames.types import PolarsIdsLike, PolarsMaskLike
from mesa_frames.concrete.polars.mixin import PolarsMixin
from mesa_frames.types_ import PolarsIdsLike, PolarsMaskLike

if TYPE_CHECKING:
from mesa_frames.concrete.agentset_pandas import AgentSetPandas
from mesa_frames.concrete.model import ModelDF
from mesa_frames.concrete.pandas.agentset import AgentSetPandas


class AgentSetPolars(AgentSetDF):
class AgentSetPolars(AgentSetDF, PolarsMixin):
_agents: pl.DataFrame
_copy_with_method: dict[str, tuple[str, list[str]]] = {
"_agents": ("clone", []),
Expand Down Expand Up @@ -309,7 +310,7 @@ def sort(
return obj

def to_pandas(self) -> "AgentSetPandas":
from mesa_frames.concrete.agentset_pandas import AgentSetPandas
from mesa_frames.concrete.pandas.agentset import AgentSetPandas

new_obj = AgentSetPandas(self._model)
new_obj._agents = self._agents.to_pandas()
Expand Down
Loading