Skip to content

Commit

Permalink
pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amontanez24 committed Jun 29, 2021
1 parent 13df0cb commit 2f25cbd
Show file tree
Hide file tree
Showing 3 changed files with 274 additions and 74 deletions.
93 changes: 61 additions & 32 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class GreaterThan(Constraint):

_diff_column = None
_is_datetime = None
_column_to_reconstruct = None

def __init__(self, low, high, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None, high_is_scalar=None,
Expand All @@ -233,6 +234,44 @@ def __init__(self, low, high, strict=False, handling_strategy='transform',
super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model)

def _get_low_value(self, table_data):
if self._low_is_scalar:
return self._low
elif self._low in table_data.columns:
return table_data[self._low]
return None

def _get_high_value(self, table_data):
if self._high_is_scalar:
return self._high
elif self._high in table_data.columns:
return table_data[self._high]
return None

def _get_column_to_reconstruct(self):
if self._drop == 'high':
column = self._high
elif self._drop == 'low':
column = self._low
elif self._high_is_scalar:
column = self._low
else:
column = self._high

return column

def _get_diff_column_name(self, table_data):
token = '#'
if len(self.constraint_columns) == 1:
name = self.constraint_columns[0] + token
while name in table_data.columns:
name += '#'
return name

while token.join(self.constraint_columns) in table_data.columns:
token += '#'
return token.join(self.constraint_columns)

def _fit(self, table_data):
"""Learn the dtype of the high column.
Expand All @@ -245,23 +284,20 @@ def _fit(self, table_data):
if self._low_is_scalar is None:
self._low_is_scalar = self._low not in table_data.columns

if self._low_is_scalar:
if self._high_is_scalar and self._low_is_scalar:
raise TypeError('`low` and `high` cannot be both scalars at the same time')
elif self._low_is_scalar:
self.constraint_columns = (self._high,)
self._dtype = table_data[self._high].dtype

if self._high_is_scalar:
elif self._high_is_scalar:
self.constraint_columns = (self._low,)
self._dtype = table_data[self._low].dtype

if not self._low_is_scalar and not self._high_is_scalar:
else:
self._dtype = table_data[self._high].dtype

separator = '#'
self._diff_column = separator + separator.join(self.constraint_columns)
while self._diff_column in table_data.columns:
self._diff_column += separator

low = self._low if self._low_is_scalar else table_data[self._low]
self._column_to_reconstruct = self._get_column_to_reconstruct()
self._diff_column = self._get_diff_column_name(table_data)
low = self._get_low_value(table_data)
self._is_datetime = (pd.api.types.is_datetime64_ns_dtype(low)
or isinstance(low, pd.Timestamp)
or isinstance(low, datetime))
Expand All @@ -277,8 +313,8 @@ def is_valid(self, table_data):
pandas.Series:
Whether each row is valid.
"""
low = self._low if self._low_is_scalar else table_data[self._low]
high = self._high if self._high_is_scalar else table_data[self._high]
low = self._get_low_value(table_data)
high = self._get_high_value(table_data)
if self._strict:
return high > low

Expand All @@ -302,9 +338,7 @@ def _transform(self, table_data):
Transformed data.
"""
table_data = table_data.copy()
low = self._low if self._low_is_scalar else table_data[self._low]
high = self._high if self._high_is_scalar else table_data[self._high]
diff = high - low
diff = self._get_high_value(table_data) - self._get_low_value(table_data)

if self._is_datetime:
diff = pd.to_numeric(diff)
Expand Down Expand Up @@ -340,28 +374,23 @@ def reverse_transform(self, table_data):
if self._is_datetime:
diff = pd.to_timedelta(diff)

high = self._get_high_value(table_data)
low = self._get_low_value(table_data)

if self._drop == 'high':
low = self._low if self._low_is_scalar else table_data[self._low]
table_data[self._high] = (low + diff).astype(self._dtype)

elif self._drop == 'low':
high = self._high if self._high_is_scalar else table_data[self._high]
table_data[self._low] = (high - diff).astype(self._dtype)

else:
invalid = ~self.is_valid(table_data)
if self._high_is_scalar and not self._low_is_scalar:
new_low_values = self._high - diff.loc[invalid]
table_data[self._low].loc[invalid] = new_low_values.astype(self._dtype)

elif self._low_is_scalar and not self._high_is_scalar:
new_high_values = self._low + diff.loc[invalid]
table_data[self._high].loc[invalid] = new_high_values.astype(self._dtype)

elif not self._high_is_scalar and not self._low_is_scalar:
low_column = table_data[self._low]
new_high_values = low_column.loc[invalid] + diff.loc[invalid]
table_data[self._high].loc[invalid] = new_high_values.astype(self._dtype)
if not self._high_is_scalar and not self._low_is_scalar:
new_values = low.loc[invalid] + diff.loc[invalid]
elif self._high_is_scalar:
new_values = high - diff.loc[invalid]
else:
new_values = low + diff.loc[invalid]

table_data[self._column_to_reconstruct].loc[invalid] = new_values.astype(self._dtype)

table_data = table_data.drop(self._diff_column, axis=1)

Expand Down
20 changes: 10 additions & 10 deletions tests/integration/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,23 +273,23 @@ def test_conditional_sampling_constraint_uses_columns_model_reject_sampling(gm_m
sampled_numeric_data = [
pd.DataFrame({
'age_joined': [26.0],
'#age_joined#age': [np.log(5.0)]
'age_joined#age': [np.log(5.0)]
}),
pd.DataFrame({
'age_joined': [18.0],
'#age_joined#age': [np.log(13.0)]
'age_joined#age': [np.log(13.0)]
}),
pd.DataFrame({
'age_joined': [28.0],
'#age_joined#age': [np.log(3.0)]
'age_joined#age': [np.log(3.0)]
}),
pd.DataFrame({
'age_joined': [27.0],
'#age_joined#age': [np.log(4.0)]
'age_joined#age': [np.log(4.0)]
}),
pd.DataFrame({
'age_joined': [24.0],
'#age_joined#age': [np.log(7.0)]
'age_joined#age': [np.log(7.0)]
})
]

Expand All @@ -308,15 +308,15 @@ def test_conditional_sampling_constraint_uses_columns_model_reject_sampling(gm_m
})
assert len(model._model.sample.mock_calls) == 5
model._model.sample.assert_any_call(1, conditions={'age_joined': 18.0,
'#age_joined#age': np.log(13.0)})
'age_joined#age': np.log(13.0)})
model._model.sample.assert_any_call(1, conditions={'age_joined': 24.0,
'#age_joined#age': np.log(7.0)})
'age_joined#age': np.log(7.0)})
model._model.sample.assert_any_call(1, conditions={'age_joined': 26.0,
'#age_joined#age': np.log(5.0)})
'age_joined#age': np.log(5.0)})
model._model.sample.assert_any_call(1, conditions={'age_joined': 27.0,
'#age_joined#age': np.log(4.0)})
'age_joined#age': np.log(4.0)})
model._model.sample.assert_any_call(1, conditions={'age_joined': 28.0,
'#age_joined#age': np.log(3.0)})
'age_joined#age': np.log(3.0)})
pd.testing.assert_frame_equal(
sampled_data.sort_values(by='age_joined').reset_index(drop=True),
expected_result.sort_values(by='age_joined').reset_index(drop=True)
Expand Down
Loading

0 comments on commit 2f25cbd

Please sign in to comment.