Skip to content

Commit

Permalink
Add documentation for and invert default
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Aug 6, 2021
1 parent cc6a9e0 commit 234b826
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
11 changes: 10 additions & 1 deletion sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib
import inspect
import logging
import warnings

import pandas as pd
from copulas.multivariate.gaussian import GaussianMultivariate
Expand Down Expand Up @@ -111,7 +112,7 @@ class Constraint(metaclass=ConstraintMeta):
def _identity(self, table_data):
return table_data

def __init__(self, handling_strategy, fit_columns_model=True):
def __init__(self, handling_strategy, fit_columns_model=False):
self.fit_columns_model = fit_columns_model
if handling_strategy == 'transform':
self.filter_valid = self._identity
Expand Down Expand Up @@ -226,6 +227,14 @@ def _validate_constraint_columns(self, table_data):
"""
missing_columns = [col for col in self.constraint_columns if col not in table_data.columns]
if missing_columns:
if not self._columns_model:
warning_message = (
'The following constraint columns are missing from the table data '
f'- `{missing_columns}`. When `fit_columns_model` is `False` and one or more '
'constraint columns are missing, reject sampling can become slow.'
)
warnings.warn(warning_message, UserWarning)

all_columns_missing = len(missing_columns) == len(self.constraint_columns)
if self._columns_model is None or all_columns_missing:
raise MissingConstraintColumnError()
Expand Down
30 changes: 25 additions & 5 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,18 @@ class UniqueCombinations(Constraint):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``,
``reject_sampling`` or ``all``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
"""

_separator = None
_joint_column = None
_combinations_to_uuids = None
_uuids_to_combinations = None

def __init__(self, columns, handling_strategy='transform', fit_columns_model=True):
def __init__(self, columns, handling_strategy='transform', fit_columns_model=False):
if len(columns) < 2:
raise ValueError('UniqueCombinations requires at least two constraint columns.')

Expand Down Expand Up @@ -214,6 +218,10 @@ class GreaterThan(Constraint):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
drop (str):
Which column to drop during transformation. Can be ``'high'``,
``'low'`` or ``None``.
Expand Down Expand Up @@ -290,7 +298,7 @@ def _get_columns_to_reconstruct(self):
return column

def __init__(self, low, high, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None, scalar=None):
fit_columns_model=False, drop=None, scalar=None):
self._strict = strict
self._drop = drop
self._scalar = scalar
Expand Down Expand Up @@ -485,12 +493,16 @@ class Positive(GreaterThan):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
drop (bool):
Whether to drop columns during transformation.
"""

def __init__(self, columns, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=False):
fit_columns_model=False, drop=False):
drop = 'high' if drop else None
super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model,
Expand All @@ -515,12 +527,16 @@ class Negative(GreaterThan):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
drop (bool):
Whether to drop columns during transformation.
"""

def __init__(self, columns, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=False):
fit_columns_model=False, drop=False):
drop = 'low' if drop else None
super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model,
Expand Down Expand Up @@ -633,6 +649,10 @@ class Between(Constraint):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
high_is_scalar(bool or None):
Whether or not the value for high is a scalar or a column name.
If ``None``, this will be determined during the ``fit`` method
Expand All @@ -646,7 +666,7 @@ class Between(Constraint):
_transformed_column = None

def __init__(self, column, low, high, strict=False, handling_strategy='transform',
fit_columns_model=True, high_is_scalar=None, low_is_scalar=None):
fit_columns_model=False, high_is_scalar=None, low_is_scalar=None):
self.constraint_column = column
self.constraint_columns = (column,)
self._low = low
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/tabular/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def test_conditional_sampling_constraint_uses_columns_model(gm_mock):
# Setup
constraint = UniqueCombinations(
columns=['city', 'state'],
handling_strategy='transform'
handling_strategy='transform',
fit_columns_model=True,
)
data = pd.DataFrame({
'city': ['LA', 'SF', 'CHI', 'LA', 'LA'],
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/constraints/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def test_fit_trains_column_model(self, ht_mock, gm_mock):
'a': [1, 2, 3],
'b': [4, 5, 6]
})
instance = Constraint(handling_strategy='transform')
instance = Constraint(handling_strategy='transform', fit_columns_model=True)
instance.constraint_columns = ('a', 'b')

# Run
Expand Down

0 comments on commit 234b826

Please sign in to comment.