From 63614255b0e959c6ab6da37fb5124a59be4a346d Mon Sep 17 00:00:00 2001 From: Katharine Xiao <2405771+katxiao@users.noreply.github.com> Date: Wed, 28 Jul 2021 19:08:20 -0400 Subject: [PATCH] Add argument to drop column in ColumnFormula --- sdv/constraints/tabular.py | 8 +++++-- tests/unit/constraints/test_tabular.py | 30 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index c646fcc6f..03a7070f2 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -483,11 +483,14 @@ class ColumnFormula(Constraint): handling_strategy (str): How this Constraint should be handled, which can be ``transform`` or ``reject_sampling``. Defaults to ``transform``. + drop_column (bool): + Whether or not to drop the constraint column. Defaults to ``True``. """ - def __init__(self, column, formula, handling_strategy='transform'): + def __init__(self, column, formula, handling_strategy='transform', drop_column=True): self._column = column self._formula = import_object(formula) + self._drop_column = drop_column super().__init__(handling_strategy, fit_columns_model=False) def is_valid(self, table_data): @@ -519,7 +522,8 @@ def transform(self, table_data): Transformed data. """ table_data = table_data.copy() - del table_data[self._column] + if self._drop_column: + del table_data[self._column] return table_data diff --git a/tests/unit/constraints/test_tabular.py b/tests/unit/constraints/test_tabular.py index 7744b5213..a8c5392a2 100644 --- a/tests/unit/constraints/test_tabular.py +++ b/tests/unit/constraints/test_tabular.py @@ -1802,6 +1802,36 @@ def test_transform(self): }) pd.testing.assert_frame_equal(expected_out, out) + def test_transform_without_dropping_column(self): + """Test the ``ColumnFormula.transform`` method without dropping the column. + + If ``drop_column`` is ``False``, expect to not drop the constraint column.
+ + Input: + - Table data (pandas.DataFrame) + Output: + - Table data with the indicated column (pandas.DataFrame) + """ + # Setup + column = 'c' + instance = ColumnFormula(column=column, formula=new_column, drop_column=False) + + # Run + table_data = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [4, 5, 6], + 'c': [5, 7, 9] + }) + out = instance.transform(table_data) + + # Assert + expected_out = pd.DataFrame({ + 'a': [1, 2, 3], + 'b': [4, 5, 6], + 'c': [5, 7, 9] + }) + pd.testing.assert_frame_equal(expected_out, out) + def test_reverse_transform(self): """Test the ``ColumnFormula.reverse_transform`` method.