Skip to content

Commit

Permalink
feat: Set column_names in fit methods of table transformers to be…
Browse files Browse the repository at this point in the history
… required (#225)

Closes #179.

### Summary of Changes

Every `fit` method of `TableTransformer`s now requires `column_names` to
be explicitly set.
  • Loading branch information
alex-senger authored Apr 21, 2023
1 parent 9509d3d commit 2856296
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/safeds/data/tabular/transformation/_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self, strategy: ImputerStrategy):
self._column_names: list[str] | None = None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None = None) -> Imputer:
def fit(self, table: Table, column_names: list[str] | None) -> Imputer:
"""
Learn a transformation for a set of columns in a table.
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/data/tabular/transformation/_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self) -> None:
self._wrapped_transformer: sk_OrdinalEncoder | None = None
self._column_names: list[str] | None = None

def fit(self, table: Table, column_names: list[str] | None = None) -> LabelEncoder:
def fit(self, table: Table, column_names: list[str] | None) -> LabelEncoder:
"""
Learn a transformation for a set of columns in a table.
Expand Down
2 changes: 1 addition & 1 deletion src/safeds/data/tabular/transformation/_one_hot_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self) -> None:
self._column_names: dict[str, list[str]] | None = None

# noinspection PyProtectedMember
def fit(self, table: Table, column_names: list[str] | None = None) -> OneHotEncoder:
def fit(self, table: Table, column_names: list[str] | None) -> OneHotEncoder:
"""
Learn a transformation for a set of columns in a table.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class TableTransformer(ABC):
"""Learn a transformation for a set of columns in a `Table` and transform another `Table` with the same columns."""

@abstractmethod
def fit(self, table: Table, column_names: list[str] | None = None) -> TableTransformer:
def fit(self, table: Table, column_names: list[str] | None) -> TableTransformer:
"""
Learn a transformation for a set of columns in a table.
Expand Down
6 changes: 3 additions & 3 deletions tests/safeds/data/tabular/transformation/test_imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_should_not_change_original_transformer(self) -> None:
)

transformer = Imputer(Imputer.Strategy.Constant(0))
transformer.fit(table)
transformer.fit(table, None)

assert transformer._wrapped_transformer is None
assert transformer._column_names is None
Expand All @@ -38,7 +38,7 @@ def test_should_raise_if_column_not_found(self) -> None:
},
)

transformer = Imputer(Imputer.Strategy.Constant(0)).fit(table_to_fit)
transformer = Imputer(Imputer.Strategy.Constant(0)).fit(table_to_fit, None)

table_to_transform = Table.from_dict(
{
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_should_return_true_after_fitting(self) -> None:
)

transformer = Imputer(Imputer.Strategy.Mean())
fitted_transformer = transformer.fit(table)
fitted_transformer = transformer.fit(table, None)
assert fitted_transformer.is_fitted()


Expand Down
10 changes: 5 additions & 5 deletions tests/safeds/data/tabular/transformation/test_label_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_should_not_change_original_transformer(self) -> None:
)

transformer = LabelEncoder()
transformer.fit(table)
transformer.fit(table, None)

assert transformer._wrapped_transformer is None
assert transformer._column_names is None
Expand All @@ -37,7 +37,7 @@ def test_should_raise_if_column_not_found(self) -> None:
},
)

transformer = LabelEncoder().fit(table_to_fit)
transformer = LabelEncoder().fit(table_to_fit, None)

table_to_transform = Table.from_dict(
{
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_should_return_true_after_fitting(self) -> None:
)

transformer = LabelEncoder()
fitted_transformer = transformer.fit(table)
fitted_transformer = transformer.fit(table, None)
assert fitted_transformer.is_fitted()


Expand Down Expand Up @@ -150,7 +150,7 @@ class TestInverseTransform:
],
)
def test_should_return_original_table(self, table: Table) -> None:
transformer = LabelEncoder().fit(table)
transformer = LabelEncoder().fit(table, None)

assert transformer.inverse_transform(transformer.transform(table)) == table

Expand All @@ -161,7 +161,7 @@ def test_should_not_change_transformed_table(self) -> None:
},
)

transformer = LabelEncoder().fit(table)
transformer = LabelEncoder().fit(table, None)
transformed_table = transformer.transform(table)
transformer.inverse_transform(transformed_table)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_should_not_change_original_transformer(self) -> None:
)

transformer = OneHotEncoder()
transformer.fit(table)
transformer.fit(table, None)

assert transformer._wrapped_transformer is None
assert transformer._column_names is None
Expand All @@ -37,7 +37,7 @@ def test_should_raise_if_column_not_found(self) -> None:
},
)

transformer = OneHotEncoder().fit(table_to_fit)
transformer = OneHotEncoder().fit(table_to_fit, None)

table_to_transform = Table.from_dict(
{
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_should_return_true_after_fitting(self) -> None:
)

transformer = OneHotEncoder()
fitted_transformer = transformer.fit(table)
fitted_transformer = transformer.fit(table, None)
assert fitted_transformer.is_fitted()


Expand Down Expand Up @@ -247,7 +247,7 @@ def test_should_not_change_transformed_table(self) -> None:
},
)

transformer = OneHotEncoder().fit(table)
transformer = OneHotEncoder().fit(table, None)
transformed_table = transformer.transform(table)
transformer.inverse_transform(transformed_table)

Expand Down

0 comments on commit 2856296

Please sign in to comment.