Skip to content

Commit

Permalink
Add argument to drop column in ColumnFormula
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Jul 29, 2021
1 parent c58f174 commit 6361425
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,14 @@ class ColumnFormula(Constraint):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
drop_column (bool):
Whether or not to drop the constraint column during ``transform``.
Defaults to ``True``.
"""

def __init__(self, column, formula, handling_strategy='transform'):
def __init__(self, column, formula, handling_strategy='transform', drop_column=True):
self._column = column
self._formula = import_object(formula)
self._drop_column = drop_column
super().__init__(handling_strategy, fit_columns_model=False)

def is_valid(self, table_data):
Expand Down Expand Up @@ -519,7 +522,8 @@ def transform(self, table_data):
Transformed data.
"""
table_data = table_data.copy()
del table_data[self._column]
if self._drop_column:
del table_data[self._column]

return table_data

Expand Down
30 changes: 30 additions & 0 deletions tests/unit/constraints/test_tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -1802,6 +1802,36 @@ def test_transform(self):
})
pd.testing.assert_frame_equal(expected_out, out)

def test_transform_without_dropping_column(self):
    """Test ``ColumnFormula.transform`` with ``drop_column=False``.

    When ``drop_column`` is ``False`` the constraint column is expected
    to be kept in the transformed table.

    Input:
    - Table data (pandas.DataFrame)
    Output:
    - Table data with the indicated column (pandas.DataFrame)
    """
    # Setup
    columns = {
        'a': [1, 2, 3],
        'b': [4, 5, 6],
        'c': [5, 7, 9],
    }
    constraint = ColumnFormula(column='c', formula=new_column, drop_column=False)

    # Run
    out = constraint.transform(pd.DataFrame(columns))

    # Assert
    # Built as a separate frame so the expectation never shares state
    # with the object handed to ``transform``.
    expected_out = pd.DataFrame(columns)
    pd.testing.assert_frame_equal(expected_out, out)

def test_reverse_transform(self):
"""Test the ``ColumnFormula.reverse_transform`` method.
Expand Down

0 comments on commit 6361425

Please sign in to comment.