Skip to content

Commit

Permalink
[0.9.1] Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pwwang committed Oct 13, 2022
1 parent 3d263af commit 6b1d9cc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 9 additions & 7 deletions datar/dplyr/group_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from ..core.backends.pandas import DataFrame
from ..core.backends.pandas.core.groupby import GroupBy

from ..core.contexts import Context

from ..core.tibble import Tibble, TibbleGrouped, TibbleRowwise
from ..core.utils import dict_get


@register_verb(DataFrame)
@register_verb(DataFrame, context=Context.EVAL)
def group_data(_data: DataFrame) -> Tibble:
"""Returns a data frame that defines the grouping structure.
Expand All @@ -34,7 +36,7 @@ def _(_data: Union[TibbleGrouped, GroupBy]) -> Tibble:
return gpdata


@register_verb(DataFrame)
@register_verb(DataFrame, context=Context.EVAL)
def group_keys(_data: DataFrame) -> Tibble:
"""Just grouping data without the `_rows` columns
Expand Down Expand Up @@ -68,7 +70,7 @@ def _(_data: TibbleRowwise) -> Tibble:
return Tibble(_data.loc[:, _data.group_vars])


@register_verb(DataFrame)
@register_verb(DataFrame, context=Context.EVAL)
def group_rows(_data: DataFrame) -> List[List[int]]:
"""The locations of grouping structure, always 0-based."""
rows = list(range(_data.shape[0]))
Expand All @@ -91,7 +93,7 @@ def _(_data: GroupBy) -> List[List[int]]:
]


@register_verb(DataFrame)
@register_verb(DataFrame, context=Context.EVAL)
def group_indices(_data: DataFrame) -> List[int]:
"""Returns an integer vector the same length as `_data`.
Expand All @@ -109,7 +111,7 @@ def _(_data: TibbleGrouped) -> List[int]:
return [ret[key] for key in sorted(ret)]


@register_verb(DataFrame)
@register_verb(DataFrame, context=Context.EVAL)
def group_vars(_data: DataFrame) -> Sequence[str]:
"""Gives names of grouping variables as character vector"""
return getattr(_data, "group_vars", [])
Expand All @@ -120,7 +122,7 @@ def group_vars(_data: DataFrame) -> Sequence[str]:
group_cols = group_vars


@register_verb(DataFrame)
@register_verb(DataFrame, context=Context.EVAL)
def group_size(_data: DataFrame) -> Sequence[int]:
"""Gives the size of each group"""
return [_data.shape[0]]
Expand All @@ -131,7 +133,7 @@ def _(_data: TibbleGrouped) -> Sequence[int]:
return list(map(len, group_rows(_data, __ast_fallback="normal")))


@register_verb(DataFrame)
@register_verb(DataFrame, context=Context.EVAL)
def n_groups(_data: DataFrame) -> int:
"""Gives the total number of groups."""
return 1
Expand Down
2 changes: 1 addition & 1 deletion tests/dplyr/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def test_errors():
# named inputs
with pytest.raises(TypeError):
mtcars >> filter(x=1)
with pytest.raises(TypeError):
with pytest.raises(KeyError):
mtcars >> filter(f.y > 2, z=3)
with pytest.raises(TypeError):
mtcars >> filter(True, x=1)
Expand Down

0 comments on commit 6b1d9cc

Please sign in to comment.