Skip to content

Commit

Permalink
Merge pull request #341 from lvgig/feature/update_weight_mixin
Browse files Browse the repository at this point in the history
added separate tests for WeightColumnMixin
  • Loading branch information
davidhopkinson26 authored Nov 14, 2024
2 parents 10ced7a + 587f63a commit 1792bbe
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 21 deletions.
14 changes: 0 additions & 14 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,20 +633,6 @@ def test_bad_values_in_weights_error(
with pytest.raises(ValueError, match=expected_message):
transformer.fit(df, df["a"])

def get_df_error_combos():
return [
(
pd.DataFrame({"a": [1, 2], "b": [3, 4]}),
r"weight col \(c\) is not present in columns of data",
"c",
),
(
pd.DataFrame({"a": [1, 2], "b": ["a", "b"]}),
r"weight column must be numeric.",
"b",
),
]

@pytest.mark.parametrize(
"minimal_dataframe_lookup",
["pandas", "polars"],
Expand Down
150 changes: 150 additions & 0 deletions tests/mixins/test_WeightColumnMixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import narwhals as nw
import numpy as np
import pytest

from tests.test_data import create_df_2
from tubular.mixins import WeightColumnMixin


class TestCheckAndSetWeight:
@pytest.mark.parametrize("weights_column", (0, ["a"], {"a": 10}))
def test_weight_arg_errors(
self,
weights_column,
):
"""Test that appropriate errors are throw for bad weight arg."""

obj = WeightColumnMixin()

with pytest.raises(
TypeError,
match="weights_column should be str or None",
):
obj.check_and_set_weight(weights_column)


class TestCheckWeightsColumn:
@pytest.mark.parametrize(
"library",
[
"pandas",
"polars",
],
)
@pytest.mark.parametrize(
"bad_weight_value, expected_message",
[
(None, "weight column must be non-null"),
(np.inf, "weight column must not contain infinite values."),
(-np.inf, "weight column must be positive"),
(-1, "weight column must be positive"),
],
)
def test_bad_values_in_weights_error(
self,
bad_weight_value,
expected_message,
library,
):
"""Test that an exception is raised if there are negative/nan/inf values in sample_weight."""

df = create_df_2(library=library)

obj = WeightColumnMixin()

df = nw.from_native(df)
native_namespace = nw.get_native_namespace(df)

weight_column = "weight_column"

df = df.with_columns(
nw.new_series(
weight_column,
[*[bad_weight_value], *np.arange(2, len(df) + 1)],
native_namespace=native_namespace,
),
)

df = nw.to_native(df)

with pytest.raises(ValueError, match=expected_message):
obj.check_weights_column(df, weight_column)

@pytest.mark.parametrize(
"library",
["pandas", "polars"],
)
def test_weight_col_non_numeric(
self,
library,
):
"""Test an error is raised if weight is not numeric."""

obj = WeightColumnMixin()

df = create_df_2(library=library)
df = nw.from_native(df)

weight_column = "weight_column"
error = r"weight column must be numeric."
df = df.with_columns(nw.lit("a").alias(weight_column))
df = nw.to_native(df)

with pytest.raises(
ValueError,
match=error,
):
# using check_weights_column method to test correct error is raised for transformers that use weights

obj.check_weights_column(df, weight_column)

@pytest.mark.parametrize(
"library",
["pandas", "polars"],
)
def test_weight_not_in_X_error(
self,
library,
):
"""Test an error is raised if weight is not in X"""

obj = WeightColumnMixin()

df = create_df_2(library=library)

weight_column = "weight_column"
error = rf"weight col \({weight_column}\) is not present in columns of data"

with pytest.raises(
ValueError,
match=error,
):
# using check_weights_column method to test correct error is raised for transformers that use weights

obj.check_weights_column(df, weight_column)

@pytest.mark.parametrize(
"library",
["pandas", "polars"],
)
def test_zero_total_weight_error(
self,
library,
):
"""Test that an exception is raised if the total sample weights are 0."""

obj = WeightColumnMixin()

weight_column = "weight_column"

df = create_df_2(library=library)

df = nw.from_native(df)
df = df.with_columns(nw.lit(0).alias(weight_column))
df = nw.to_native(df)

with pytest.raises(
ValueError,
match="total sample weights are not greater than 0",
):
obj.check_weights_column(df, weight_column)
11 changes: 4 additions & 7 deletions tubular/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,11 @@ class WeightColumnMixin:
"""
Mixin class with weights functionality
Attributes
----------
polars_compatible : bool
class attribute, indicates whether transformer has been converted to polars/pandas agnostic narwhals framework
"""

polars_compatible = True
def classname(self) -> str:
"""Method that returns the name of the current class when called."""
return type(self).__name__

@nw.narwhalify
def check_weights_column(self, X: FrameT, weights_column: str) -> None:
Expand Down Expand Up @@ -177,7 +174,7 @@ def check_weights_column(self, X: FrameT, weights_column: str) -> None:
msg = f"{self.classname()}: weight column must be non-null"
raise ValueError(msg)

# check weight not inf, not currently a narwhals efficient way to do this
# check weight not inf, currently no polars-y way to do this in narwhals
if np.isinf(X[weights_column].to_numpy()).any():
msg = f"{self.classname()}: weight column must not contain infinite values."
raise ValueError(msg)
Expand Down

0 comments on commit 1792bbe

Please sign in to comment.