Skip to content

Commit

Permalink
Move data.copy to base class of constraints (#849)
Browse files Browse the repository at this point in the history
* changing logic for handling constraints

* fixing a lot of tests

* removing handling_strategy attribute

* removing handling_strategy from docs and code

* fixing tests, docs and tutorials

* adding unit tests

* only calling reverse_transform if custom constraint

* ignoring sre_parse during isort

* pr comments

* raising all errors except missing column erros

* adding integration test and tracking if reverse transform should use reject sampling

* refactoring how constraint transforms happen

* adding unit tests

* Rebase + move copy to base

* Fix test cases

* fix rebase

* Test resulting data is a copy for transform/reverse_transform

Co-authored-by: Andrew Montanez <amontanez2424@gmail.com>
Co-authored-by: Andrew Montanez <andrewmontanez@Andrews-MBP.hsd1.ma.comcast.net>
Co-authored-by: Andrew Montanez <andrew@datacebo.com>
Co-authored-by: Andrew Montanez <andrew@sdv.dev>
  • Loading branch information
5 people authored Jun 29, 2022
1 parent aa0c0ea commit 04d1549
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 21 deletions.
2 changes: 2 additions & 0 deletions sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
table_data = table_data.copy()
missing_columns = [col for col in self.constraint_columns if col not in table_data.columns]
if missing_columns:
raise MissingConstraintColumnError(missing_columns=missing_columns)
Expand Down Expand Up @@ -200,6 +201,7 @@ def reverse_transform(self, table_data):
pandas.DataFrame:
Input data unmodified.
"""
table_data = table_data.copy()
return self._reverse_transform(table_data)

def is_valid(self, table_data):
Expand Down
13 changes: 0 additions & 13 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
combinations = table_data[self._columns].itertuples(index=False, name=None)
uuids = map(self._combinations_to_uuids.get, combinations)
table_data[self._joint_column] = list(uuids)
Expand All @@ -240,7 +239,6 @@ def _reverse_transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
columns = table_data.pop(self._joint_column).map(self._uuids_to_combinations)

for index, column in enumerate(self._columns):
Expand Down Expand Up @@ -351,7 +349,6 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
low, high = self._get_data(table_data)
diff_column = high - low
if self._is_datetime:
Expand All @@ -376,7 +373,6 @@ def _reverse_transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
diff_column = np.exp(table_data[self._diff_column_name]) - 1
if self._dtype != np.dtype('float'):
diff_column = diff_column.round()
Expand Down Expand Up @@ -485,7 +481,6 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
column = table_data[self._column_name].to_numpy()
diff_column = abs(column - self._value)
if self._is_datetime:
Expand All @@ -509,7 +504,6 @@ def _reverse_transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
diff_column = np.exp(table_data[self._diff_column_name]) - 1
if self._dtype != np.dtype('float'):
diff_column = diff_column.round()
Expand Down Expand Up @@ -675,7 +669,6 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
low = table_data[self.low_column_name]
high = table_data[self.high_column_name]

Expand All @@ -700,7 +693,6 @@ def _reverse_transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
low = table_data[self.low_column_name]
high = table_data[self.high_column_name]
data = table_data[self._transformed_column]
Expand Down Expand Up @@ -820,7 +812,6 @@ def _transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
data = logit(table_data[self.column_name], self.low_value, self.high_value)
table_data[self._transformed_column] = data
table_data = table_data.drop(self.column_name, axis=1)
Expand All @@ -842,7 +833,6 @@ def _reverse_transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()
data = table_data[self._transformed_column]

data = sigmoid(data, self.low_value, self.high_value)
Expand Down Expand Up @@ -918,7 +908,6 @@ def _transform(self, table_data):
pandas.DataFrame:
Data divided by increment.
"""
table_data = table_data.copy()
table_data[self.column_name] = table_data[self.column_name] / self.increment_value
return table_data

Expand Down Expand Up @@ -988,8 +977,6 @@ def _reverse_transform(self, table_data):
pandas.DataFrame:
Transformed data.
"""
table_data = table_data.copy()

one_hot_data = table_data[self._column_names]
transformed_data = np.zeros_like(one_hot_data.values)
transformed_data[np.arange(len(one_hot_data)), np.argmax(one_hot_data.values, axis=1)] = 1
Expand Down
24 changes: 16 additions & 8 deletions tests/unit/constraints/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,20 +275,24 @@ def test_transform(self):
By default, it behaves like an identity method, to be optionally overwritten by subclasses.
The ``Constraint.transform`` method is expected to:
- Return the input data unmodified.
- Return a copy of the input data.
Input:
- a DataFrame
Output:
- Input
"""
# Run
# Setup
instance = Constraint()
output = instance.transform(pd.DataFrame({'col': ['input']}))
data = pd.DataFrame({'col': ['input']})

# Run
output = instance.transform(data)

# Assert
pd.testing.assert_frame_equal(output, pd.DataFrame({'col': ['input']}))
assert id(output) != id(data)

def test_transform_calls__transform(self):
"""Test that the ``Constraint.transform`` method calls ``_transform``.
Expand All @@ -308,7 +312,7 @@ def test_transform_calls__transform(self):
constraint_mock._transform.return_value = 'the_transformed_data'

# Run
output = Constraint.transform(constraint_mock, 'input')
output = Constraint.transform(constraint_mock, pd.DataFrame())

# Assert
assert output == 'the_transformed_data'
Expand Down Expand Up @@ -411,19 +415,23 @@ def test_reverse_transform(self):
for completion, to be optionally overwritten by subclasses.
The ``Constraint.reverse_transform`` method is expected to:
- Return the input data unmodified.
- Return a copy of the input data.
Input:
- Anything
Output:
- Input
"""
# Run
# Setup
instance = Constraint()
output = instance.reverse_transform('input')
data = pd.DataFrame()

# Run
output = instance.reverse_transform(data)

# Assert
assert output == 'input'
pd.testing.assert_frame_equal(output, pd.DataFrame())
assert id(output) != id(data)

def test_is_valid(self):
"""Test the ``Constraint.is_valid` method. This should be overwritten by all the
Expand Down

0 comments on commit 04d1549

Please sign in to comment.