Skip to content

Commit

Permalink
Add documentation for `fit_columns_model` and invert default
Browse files Browse the repository at this point in the history
  • Loading branch information
katxiao committed Aug 6, 2021
1 parent 920a88e commit 4bfa5d0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
11 changes: 10 additions & 1 deletion sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib
import inspect
import logging
import warnings

import pandas as pd
from copulas.multivariate.gaussian import GaussianMultivariate
Expand Down Expand Up @@ -111,7 +112,7 @@ class Constraint(metaclass=ConstraintMeta):
def _identity(self, table_data):
return table_data

def __init__(self, handling_strategy, fit_columns_model=True):
def __init__(self, handling_strategy, fit_columns_model=False):
self.fit_columns_model = fit_columns_model
if handling_strategy == 'transform':
self.filter_valid = self._identity
Expand Down Expand Up @@ -226,6 +227,14 @@ def _validate_constraint_columns(self, table_data):
"""
missing_columns = [col for col in self.constraint_columns if col not in table_data.columns]
if missing_columns:
if not self._columns_model:
warning_message = (
'The following constraint columns are missing from the table data '
f'- `{missing_columns}`. When `fit_columns_model` is `False` and one or more '
'constraint columns are missing, reject sampling can become slow.'
)
warnings.warn(warning_message, UserWarning)

all_columns_missing = len(missing_columns) == len(self.constraint_columns)
if self._columns_model is None or all_columns_missing:
raise MissingConstraintColumnError()
Expand Down
30 changes: 25 additions & 5 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,18 @@ class UniqueCombinations(Constraint):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``,
``reject_sampling`` or ``all``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
"""

_separator = None
_joint_column = None
_combinations_to_uuids = None
_uuids_to_combinations = None

def __init__(self, columns, handling_strategy='transform', fit_columns_model=True):
def __init__(self, columns, handling_strategy='transform', fit_columns_model=False):
if len(columns) < 2:
raise ValueError('UniqueCombinations requires at least two constraint columns.')

Expand Down Expand Up @@ -212,6 +216,10 @@ class GreaterThan(Constraint):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
drop (str):
Which column to drop during transformation. Can be ``'high'``,
``'low'`` or ``None``.
Expand All @@ -230,7 +238,7 @@ class GreaterThan(Constraint):
_column_to_reconstruct = None

def __init__(self, low, high, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None, high_is_scalar=None,
fit_columns_model=False, drop=None, high_is_scalar=None,
low_is_scalar=None):
self._low = low
self._high = high
Expand Down Expand Up @@ -437,13 +445,17 @@ class Positive(GreaterThan):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
drop (str):
Which column to drop during transformation. Can be ``'high'``
or ``None``.
"""

def __init__(self, high, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None):
fit_columns_model=False, drop=None):
super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model,
high=high, low=0, high_is_scalar=False,
Expand All @@ -468,13 +480,17 @@ class Negative(GreaterThan):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
drop (str):
Which column to drop during transformation. Can be ``'low'``
or ``None``.
"""

def __init__(self, low, strict=False, handling_strategy='transform',
fit_columns_model=True, drop=None):
fit_columns_model=False, drop=None):
super().__init__(handling_strategy=handling_strategy,
fit_columns_model=fit_columns_model,
high=0, low=low, high_is_scalar=True,
Expand Down Expand Up @@ -586,6 +602,10 @@ class Between(Constraint):
handling_strategy (str):
How this Constraint should be handled, which can be ``transform``
or ``reject_sampling``. Defaults to ``transform``.
fit_columns_model (bool):
If False, reject sampling will be used to handle conditional sampling.
Otherwise, a model will be trained and used to sample other columns
based on the conditioned column. Defaults to False.
high_is_scalar(bool or None):
Whether or not the value for high is a scalar or a column name.
If ``None``, this will be determined during the ``fit`` method
Expand All @@ -599,7 +619,7 @@ class Between(Constraint):
_transformed_column = None

def __init__(self, column, low, high, strict=False, handling_strategy='transform',
fit_columns_model=True, high_is_scalar=None, low_is_scalar=None):
fit_columns_model=False, high_is_scalar=None, low_is_scalar=None):
self.constraint_column = column
self.constraint_columns = (column,)
self._low = low
Expand Down

0 comments on commit 4bfa5d0

Please sign in to comment.