-
Notifications
You must be signed in to change notification settings - Fork 321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enhance GreaterThan constraint to allow comparing to scalar values #485
Conversation
cac2289
to
5959803
Compare
sdv/constraints/tabular.py
Outdated
separator = '#' | ||
while not self._valid_separator(table_data, separator, self.constraint_columns): | ||
separator += '#' | ||
self._diff_column = separator + separator.join(self.constraint_columns) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@csala The reason that this doesn't use the _valid_separator
method is that the self.column_constraints
tuple might now only have one element. In this case, the _valid_separator
method gets stuck in an infinite loop, because there is nothing to join the only column to, so it finds the column name inside the table. Not sure what you think is a good approach. I didn't use the uuid
either because it will be good to know what the column name ends up being for integration tests. Also not sure how to keep the uuid
from being something in the column names of the table
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is OK to do it here, but I would still go for the {column}####{column}
format instead of the #{column}#{column}####}
one.
wrt to the problem of having one column, you can just validate against set(table_data.columns) - {self.constraint_columns}
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand this. If we only validate against the columns not in the self.constraint_columns
, it is possible to end up with a column name that is already present. If the self.constraint_columns
contains ('a',)
, then the diff column should be called 'a#'
. But if we just join the columns it will return 'a'
8b09909
to
0cb7e32
Compare
@@ -161,6 +161,8 @@ def fit(self, table_data): | |||
table_data (pandas.DataFrame): | |||
Table data. | |||
""" | |||
self._fit(table_data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain why this call was moved up?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In _fit
the _constraint_columns
might be changed. Id the part below runs before this, it will try to access the wrong _constraint_columns
. Before there were no constraints that changed _constraint_columns
in _fit
so this wasn't a problem
@@ -292,22 +325,31 @@ def reverse_transform(self, table_data): | |||
""" | |||
table_data = table_data.copy() | |||
diff = (np.exp(table_data[self._diff_column]).round() - 1).clip(0) | |||
if self._diff_is_datetime(table_data): | |||
if self._is_datetime: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could _is_datetime
ever be None
here? Since it depends on calling fit
first.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it could. In a similar way though, self._diff_column
could also be None
. So I am assuming we want users to fit the constraint before reverse transforming. Not sure what you think
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is fine. reverse_transform
should always be called after fit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. If it becomes a usability problem, we could add an explicit check and throw an informative error. Seems fine for now since most constraints probably assume this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if we did that, we should probably just have on in the base class that complains if you try and reverse transform or transform before having fit anything
sdv/constraints/tabular.py
Outdated
separator = '#' | ||
while not self._valid_separator(table_data, separator, self.constraint_columns): | ||
separator += '#' | ||
self._diff_column = separator + separator.join(self.constraint_columns) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is OK to do it here, but I would still go for the {column}####{column}
format instead of the #{column}#{column}####}
one.
wrt to the problem of having one column, you can just validate against set(table_data.columns) - {self.constraint_columns}
.
sdv/constraints/tabular.py
Outdated
if self._low_is_scalar is None: | ||
self._low_is_scalar = self._low not in table_data.columns | ||
|
||
if self._low_is_scalar: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would make this an elif
and before this check if both are scalar at the same time and raise an error:
if self._low_is_scalar and self._high_is_scalar:
raise TypeError('`low` and `high` cannot be both scalars at the same time')
elif self._low_is_scalar:
...
elif self._high_is_scalar:
...
else:
self._dtype = ...
@@ -292,22 +325,31 @@ def reverse_transform(self, table_data): | |||
""" | |||
table_data = table_data.copy() | |||
diff = (np.exp(table_data[self._diff_column]).round() - 1).clip(0) | |||
if self._diff_is_datetime(table_data): | |||
if self._is_datetime: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it is fine. reverse_transform
should always be called after fit
sdv/constraints/tabular.py
Outdated
invalid = ~self.is_valid(table_data) | ||
new_high_values = low_column.loc[invalid] + diff.loc[invalid] | ||
table_data[self._high].loc[invalid] = new_high_values.astype(self._dtype) | ||
if self._high_is_scalar and not self._low_is_scalar: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this if/else block could be simplified by only assigning new_values
and then doing:
if ...
new_values = ...
column = ...
elif ...
...
else:
...
table_data[column].loc[invalid] = new_values.astype(self._dtype)
column
should also be computed beforehand, in the __init__
, since it will never change.
7ec6981
to
2f25cbd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great @amontanez24 ! I love how thorough the tests are!
I added a couple of minor comments about the code itself, and one about the demo and the start_date
field that was added, that may end up being confusing.
return self._low | ||
elif self._low in table_data.columns: | ||
return table_data[self._low] | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the elif
and this return
can be removed, since when this is called _low_is_scalar
has already been set to the right value, which means that if it evaluates to false the column must exist.
So this becomes:
if self._low_is_scalar:
return self._low
return table_data[self._low]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No because if self._low_is_scalar
is false, the column could still have been dropped. That value can be false, but self._drop
can be set to low
name = self.constraint_columns[0] + token | ||
while name in table_data.columns: | ||
name += '#' | ||
return name |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would add blank lines above this return
(and also the others that come right after an indentation decrease)
sdv/demo.py
Outdated
@@ -325,6 +325,11 @@ def _load_tabular_dummy(): | |||
faker = Faker() | |||
names = [faker.name() for _ in range(12)] | |||
adresses = [faker.address() for _ in range(12)] | |||
start_date = datetime(1980, 1, 1) | |||
start_dates = [ | |||
start_date + timedelta(days=np.random.randint(0, 14600)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this field choice is delicate, because we will then need to make the years_in_the_company
match the start_date
. Maybe we could use a date which is not related to for how long the employee has been in the company?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good @amontanez24 !
Resolve #368