Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow me to create a Custom Constraint Class and use it in the same file (filepath=None) #1223

Merged
merged 9 commits into from
Feb 3, 2023
53 changes: 35 additions & 18 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,22 @@ def get_sdtypes(self, primary_keys=False):

return sdtypes

def _validate_custom_constraints(self, filepath, class_names):
errors = []
def _validate_custom_constraint_name(self, class_name):
    """Validate that ``class_name`` does not collide with a built-in constraint.

    Args:
        class_name (str):
            Name of the custom constraint class to validate.

    Raises:
        InvalidConstraintsError:
            If ``class_name`` matches one of the reserved, built-in
            constraint class names.
    """
    # Every subclass of ``Constraint`` is a reserved (built-in) name.
    reserved_class_names = list(get_subclasses(Constraint))
    if class_name in reserved_class_names:
        error_message = (
            f"The name '{class_name}' is a reserved constraint name. "
            'Please use a different one for the custom constraint.'
        )
        raise InvalidConstraintsError(error_message)

def _validate_custom_constraints(self, filepath, class_names, module):
errors = []
for class_name in class_names:
if class_name in reserved_class_names:
errors.append((
f"The name '{class_name}' is a reserved constraint name. "
'Please use a different one for the custom constraint.'
))
try:
self._validate_custom_constraint_name(class_name)
except InvalidConstraintsError as err:
errors += err.errors

if not hasattr(module, class_name):
errors.append(f"The constraint '{class_name}' is not defined in '{filepath}'.")
Expand All @@ -171,7 +177,7 @@ def _validate_custom_constraints(self, filepath, class_names):
raise InvalidConstraintsError(errors)

def load_custom_constraint_classes(self, filepath, class_names):
"""Load a custom constraint class for the current model.
"""Load a custom constraint class for the current synthesizer.

Args:
filepath (str):
Expand All @@ -180,9 +186,25 @@ def load_custom_constraint_classes(self, filepath, class_names):
class_names (list):
A list of custom constraint classes to be imported.
"""
self._validate_custom_constraints(filepath, class_names)
path = Path(filepath)
module = load_module_from_path(path)
self._validate_custom_constraints(filepath, class_names, module)
for class_name in class_names:
self._custom_constraint_classes[class_name] = filepath
constraint_class = getattr(module, class_name)
self._custom_constraint_classes[class_name] = constraint_class

def add_custom_constraint_class(self, class_object, class_name):
    """Register a custom constraint class under the given name.

    Args:
        class_object (sdv.constraints.Constraint):
            A custom constraint class object.
        class_name (str):
            The name to assign this custom constraint class. This will be the name to use
            when writing a constraint dictionary for ``add_constraints``.
    """
    # Reject names that shadow the built-in constraints before registering.
    self._validate_custom_constraint_name(class_name)
    self._custom_constraint_classes[class_name] = class_object
pvk-developer marked this conversation as resolved.
Show resolved Hide resolved

def _validate_constraint_dict(self, constraint_dict):
"""Validate a constraint against the single table metadata.
Expand All @@ -197,9 +219,7 @@ def _validate_constraint_dict(self, constraint_dict):
constraint_parameters = constraint_dict.get('constraint_parameters', {})
try:
if constraint_class in self._custom_constraint_classes:
path = Path(self._custom_constraint_classes[constraint_class])
module = load_module_from_path(path)
constraint_class = getattr(module, constraint_class)
constraint_class = self._custom_constraint_classes[constraint_class]

else:
constraint_class = Constraint._get_class_from_dict(constraint_class)
Expand Down Expand Up @@ -506,10 +526,7 @@ def _load_constraints(self):
loaded_constraints.append(Constraint.from_dict(constraint))

else:
constraint_class = constraint['constraint_class']
path = Path(self._custom_constraint_classes[constraint_class])
module = load_module_from_path(path)
constraint_class = getattr(module, constraint_class)
constraint_class = self._custom_constraint_classes[constraint['constraint_class']]
loaded_constraints.append(
constraint_class(**constraint.get('constraint_parameters', {}))
)
Expand Down
28 changes: 28 additions & 0 deletions sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,34 @@ def get_constraints(self):

return constraints

def load_custom_constraint_classes(self, table_name, filepath, class_names):
    """Load a custom constraint class for the specified table's synthesizer.

    Args:
        table_name (str):
            Table to add constraint to.
        filepath (str):
            String representing the absolute or relative path to the python file where
            the custom constraints are declared.
        class_names (list):
            A list of custom constraint classes to be imported.
    """
    # Delegate to the per-table synthesizer, which owns its own constraint registry.
    table_synthesizer = self._table_synthesizers[table_name]
    table_synthesizer.load_custom_constraint_classes(filepath, class_names)

def add_custom_constraint_class(self, table_name, class_object, class_name):
    """Add a custom constraint class for the synthesizer to use.

    Args:
        table_name (str):
            Table to add constraint to.
        class_object (sdv.constraints.Constraint):
            A custom constraint class object.
        class_name (str):
            The name to assign this custom constraint class. This will be the name to use
            when writing a constraint dictionary for ``add_constraints``.
    """
    # Delegate registration to the per-table synthesizer.
    table_synthesizer = self._table_synthesizers[table_name]
    table_synthesizer.add_custom_constraint_class(class_object, class_name)

def get_info(self):
"""Get dictionary with information regarding the synthesizer.

Expand Down
14 changes: 13 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def get_metadata(self):
return self.metadata

def load_custom_constraint_classes(self, filepath, class_names):
"""Load a custom constraint class for the current model.
"""Load a custom constraint class for the current synthesizer.

Args:
filepath (str):
Expand All @@ -286,6 +286,18 @@ def load_custom_constraint_classes(self, filepath, class_names):
"""
self._data_processor.load_custom_constraint_classes(filepath, class_names)

def add_custom_constraint_class(self, class_object, class_name):
    """Add a custom constraint class for the synthesizer to use.

    Args:
        class_object (sdv.constraints.Constraint):
            A custom constraint class object.
        class_name (str):
            The name to assign this custom constraint class. This will be the name to use
            when writing a constraint dictionary for ``add_constraints``.
    """
    # The data processor owns constraint handling; forward the registration.
    self._data_processor.add_custom_constraint_class(class_object, class_name)

def add_constraints(self, constraints):
"""Add constraints to the synthesizer.

Expand Down
95 changes: 95 additions & 0 deletions tests/integration/multi_table/test_hma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime

import numpy as np
import pandas as pd
import pkg_resources
import pytest
Expand All @@ -8,6 +9,7 @@
from sdv.datasets.demo import download_demo
from sdv.metadata.multi_table import MultiTableMetadata
from sdv.multi_table import HMASynthesizer
from tests.integration.single_table.custom_constraints import MyConstraint


def test_hma(tmpdir):
Expand Down Expand Up @@ -140,3 +142,96 @@ def test_hma_set_parameters():
assert hmasynthesizer._table_synthesizers['characters'].default_distribution == 'gamma'
assert hmasynthesizer._table_synthesizers['families'].default_distribution == 'uniform'
assert hmasynthesizer._table_synthesizers['character_families'].default_distribution == 'norm'


def get_custom_constraint_data_and_metadata():
    """Build a small parent/child dataset plus matching multi-table metadata."""
    parent_df = pd.DataFrame({
        'primary_key': [1000, 1001, 1002],
        'numerical_col': [2, 3, 4],
        'categorical_col': ['a', 'b', 'a'],
    })
    child_df = pd.DataFrame({
        'user_id': [1000, 1001, 1000],
        'id': [1, 2, 3],
        'random': ['a', 'b', 'c'],
    })

    multi_metadata = MultiTableMetadata()
    multi_metadata.detect_table_from_dataframe('parent', parent_df)
    multi_metadata.detect_table_from_dataframe('child', child_df)
    multi_metadata.set_primary_key('parent', 'primary_key')
    multi_metadata.set_primary_key('child', 'id')
    multi_metadata.add_relationship(
        parent_table_name='parent',
        parent_primary_key='primary_key',
        child_table_name='child',
        child_foreign_key='user_id'
    )

    return parent_df, child_df, multi_metadata


def test_hma_custom_constraint():
    """Test an example of using a custom constraint."""
    # Setup
    parent_data, child_data, metadata = get_custom_constraint_data_and_metadata()
    synthesizer = HMASynthesizer(metadata)
    constraint = {
        'table_name': 'parent',
        'constraint_class': 'MyConstraint',
        'constraint_parameters': {
            'column_names': ['numerical_col']
        }
    }
    synthesizer.add_custom_constraint_class('parent', MyConstraint, 'MyConstraint')

    # Run
    synthesizer.add_constraints(constraints=[constraint])
    processed_data = synthesizer.preprocess({'parent': parent_data, 'child': child_data})

    # Assert - preprocessing applied the custom transform (squaring the column).
    expected_column = (parent_data['numerical_col'] ** 2.0).array
    np.testing.assert_equal(
        processed_data['parent']['numerical_col'].array,
        expected_column
    )

    # Run - fit on the processed data, then sample.
    synthesizer.fit_processed_data(processed_data)
    sampled = synthesizer.sample(10)

    # Assert - sampled values respect the constraint.
    assert all(sampled['parent']['numerical_col'] > 1)


def test_hma_custom_constraint_loaded_from_file():
    """Test an example of using a custom constraint loaded from a file."""
    # Setup
    parent_data, child_data, metadata = get_custom_constraint_data_and_metadata()
    synthesizer = HMASynthesizer(metadata)
    constraint = {
        'table_name': 'parent',
        'constraint_class': 'MyConstraint',
        'constraint_parameters': {
            'column_names': ['numerical_col']
        }
    }
    synthesizer.load_custom_constraint_classes(
        'parent',
        'tests/integration/single_table/custom_constraints.py',
        ['MyConstraint']
    )

    # Run
    synthesizer.add_constraints(constraints=[constraint])
    processed_data = synthesizer.preprocess({'parent': parent_data, 'child': child_data})

    # Assert - preprocessing applied the custom transform (squaring the column).
    np.testing.assert_equal(
        processed_data['parent']['numerical_col'].array,
        (parent_data['numerical_col'] ** 2.0).array
    )

    # Run - fit on the processed data, then sample.
    synthesizer.fit_processed_data(processed_data)
    sampled = synthesizer.sample(10)

    # Assert - sampled values respect the constraint.
    assert all(sampled['parent']['numerical_col'] > 1)
6 changes: 4 additions & 2 deletions tests/integration/single_table/custom_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@ def is_valid(column_names, data):

def transform(column_names, data):
    """Transform the constraint.

    Squares the given columns in place and returns the whole table, so
    every other column is preserved for downstream processing.

    Args:
        column_names (list):
            Columns of ``data`` to transform.
        data (pandas.DataFrame):
            Table to transform; mutated in place.

    Returns:
        pandas.DataFrame:
            The full table with the transformed columns.
    """
    data[column_names] = data[column_names] ** 2
    return data


def reverse_transform(column_names, data):
    """Reverse transform the constraint.

    Floor-divides the given columns by 2 in place and returns the whole table.

    NOTE(review): this is not the algebraic inverse of ``transform`` (which
    squares) — presumably intentional for this test fixture; confirm.

    Args:
        column_names (list):
            Columns of ``data`` to reverse transform.
        data (pandas.DataFrame):
            Table to reverse transform; mutated in place.

    Returns:
        pandas.DataFrame:
            The full table with the reverse-transformed columns.
    """
    data[column_names] = data[column_names] // 2
    return data


MyConstraint = create_custom_constraint_class(
Expand Down
49 changes: 48 additions & 1 deletion tests/integration/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sdv.single_table.copulagan import CopulaGANSynthesizer
from sdv.single_table.copulas import GaussianCopulaSynthesizer
from sdv.single_table.ctgan import CTGANSynthesizer, TVAESynthesizer
from tests.integration.single_table.custom_constraints import MyConstraint

METADATA = SingleTableMetadata._load_from_dict({
'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1',
Expand Down Expand Up @@ -532,7 +533,7 @@ def test_transformers_correctly_auto_assigned():
assert transformers['categorical_col'].add_noise is True


def test_custom_constraints(tmpdir):
def test_custom_constraints_from_file(tmpdir):
"""Ensure the correct loading for a custom constraint class defined in another file."""
data = pd.DataFrame({
'primary_key': ['user-000', 'user-001', 'user-002'],
Expand Down Expand Up @@ -581,6 +582,52 @@ def test_custom_constraints(tmpdir):
assert all(loaded_sampled['numerical_col'] > 1)


def test_custom_constraints_from_object(tmpdir):
    """Ensure the correct loading for a custom constraint class passed as an object."""
    # Setup
    real_data = pd.DataFrame({
        'primary_key': ['user-000', 'user-001', 'user-002'],
        'pii_col': ['223 Williams Rd', '75 Waltham St', '77 Mass Ave'],
        'numerical_col': [2, 3, 4],
        'categorical_col': ['a', 'b', 'a'],
    })

    metadata = SingleTableMetadata()
    metadata.detect_from_dataframe(real_data)
    metadata.update_column(column_name='pii_col', sdtype='address', pii=True)
    synthesizer = GaussianCopulaSynthesizer(
        metadata,
        enforce_min_max_values=False,
        enforce_rounding=False
    )
    synthesizer.add_custom_constraint_class(MyConstraint, 'MyConstraint')
    constraint = {
        'constraint_class': 'MyConstraint',
        'constraint_parameters': {
            'column_names': ['numerical_col']
        }
    }

    # Run - add the constraint and preprocess.
    synthesizer.add_constraints([constraint])
    processed_data = synthesizer.preprocess(real_data)

    # Assert - the custom transform squared the numerical column.
    assert all(processed_data['numerical_col'] == real_data['numerical_col'] ** 2)

    # Run - fit the model on the processed data, then sample.
    synthesizer.fit_processed_data(processed_data)
    sampled = synthesizer.sample(10)
    assert all(sampled['numerical_col'] > 1)

    # Run - save, reload and sample again.
    synthesizer.save(tmpdir / 'test.pkl')
    loaded_instance = synthesizer.load(tmpdir / 'test.pkl')
    loaded_sampled = loaded_instance.sample(10)
    assert all(loaded_sampled['numerical_col'] > 1)


def test_auto_assign_transformers_and_update_with_pii():
"""Ensure the ability to update a transformer with any given ``pii`` sdtype.

Expand Down
Loading