
Merge pull request #190 from sdv-dev/tabular-docs
Tabular docs
csala authored Sep 7, 2020
2 parents 11d9fd1 + fbfef7a commit daaaa2f
Showing 18 changed files with 4,553 additions and 1,077 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -119,7 +119,7 @@ test-readme: ## run the readme snippets

.PHONY: test-tutorials
test-tutorials: ## run the tutorial notebooks
-	find tutorials -maxdepth 1 -name "*.ipynb" -exec \
+	find tutorials -path "*/.ipynb_checkpoints" -prune -false -o -name "*.ipynb" -exec \
	jupyter nbconvert --execute --ExecutePreprocessor.timeout=600 --stdout {} > /dev/null \;

.PHONY: test
2 changes: 1 addition & 1 deletion docs/getting_started/index.rst
@@ -7,4 +7,4 @@ Getting Started
   :maxdepth: 2

   install
-   ../tutorials/01_Quickstart
+   ../tutorials/Quickstart
8 changes: 3 additions & 5 deletions docs/user_guides/index.rst
@@ -3,12 +3,10 @@
User Guides
===========

-The User Guides section covers different topics about SDV usage.
+The User Guides section covers different topics about SDV usage for different types of data:

.. toctree::
   :maxdepth: 2

-   ../tutorials/02_Single_Table_Modeling
-   ../tutorials/03_Relational_Data_Modeling
-   ../tutorials/04_Working_with_Metadata
-   ../tutorials/05_Handling_Constraints
+   single_table
+   relational
8 changes: 8 additions & 0 deletions docs/user_guides/relational.rst
@@ -0,0 +1,8 @@
+Relational Data
+===============
+
+.. toctree::
+   :maxdepth: 2
+
+   ../tutorials/relational_data/01_Relational_Data_Models
+   ../tutorials/relational_data/02_Working_with_Metadata
10 changes: 10 additions & 0 deletions docs/user_guides/single_table.rst
@@ -0,0 +1,10 @@
+Single Table Data
+=================
+
+.. toctree::
+   :maxdepth: 2
+
+   ../tutorials/single_table_data/01_Tabular_Models
+   ../tutorials/single_table_data/02_GaussianCopula_Model
+   ../tutorials/single_table_data/03_CTGAN_Model
+   ../tutorials/single_table_data/04_Handling_Constraints
4 changes: 0 additions & 4 deletions sdv/metadata/__init__.py
@@ -407,10 +407,6 @@ def get_dtypes(self, table_name, ids=False):
                if name == self.get_foreign_key(table_name, child_table):
                    break

-            else:
-                raise MetadataError(
-                    'id field `{}` is neither a primary or a foreign key'.format(name))
-
            if ids or (field_type != 'id'):
                dtypes[name] = dtype

19 changes: 12 additions & 7 deletions sdv/metadata/table.py
@@ -224,7 +224,7 @@ def get_dtypes(self, ids=False):
        Args:
            ids (bool):
-                Whether or not include the id fields. Defaults to ``False``.
+                Whether or not to include the id fields. Defaults to ``False``.

        Returns:
            dict:
@@ -409,7 +409,7 @@ def set_primary_key(self, field_name):
    def _make_anonymization_mappings(self, data):
        mappings = {}
        for name, field_metadata in self._fields_metadata.items():
-            if field_metadata.get('pii'):
+            if field_metadata['type'] != 'id' and field_metadata.get('pii'):
                faker = self._get_faker(field_metadata['pii_category'])

                uniques = data[name].unique()
@@ -463,7 +463,8 @@ def transform(self, data):
        if not self.fitted:
            raise MetadataNotFittedError()

-        data = self._anonymize(data[self._field_names])
+        fields = self.get_dtypes(ids=False)
+        data = self._anonymize(data[fields])

        for constraint in self._constraints:
            data = constraint.transform(data)
@@ -490,11 +491,15 @@ def reverse_transform(self, data):

        fields = self._fields_metadata
        for name, dtype in self.get_dtypes(ids=True).items():
-            field_type = fields[name]['type']
-            if field_type == 'id':
-                field_data = pd.Series(np.arange(len(reversed_data)))
-            else:
+            field_metadata = fields[name]
+            field_type = field_metadata['type']
+            if field_type != 'id':
                field_data = reversed_data[name]
+            elif field_metadata.get('pii', False):
+                faker = self._get_faker(field_metadata['pii_category'])
+                field_data = pd.Series([faker() for _ in range(len(reversed_data))])
+            else:
+                field_data = pd.Series(np.arange(len(reversed_data)))

            reversed_data[name] = field_data.dropna().astype(dtype)

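Taken together, the ``transform`` and ``reverse_transform`` changes mean id fields are dropped before modeling and, when flagged as ``pii``, regenerated with Faker at sampling time instead of being filled with sequential integers. A minimal sketch of the behavior this appears to enable, assuming the tabular wrappers wire ``primary_key`` and ``anonymize_fields`` into this metadata (column names and values are hypothetical):

    import pandas as pd

    from sdv.tabular import GaussianCopula

    # Hypothetical table whose primary key is sensitive.
    data = pd.DataFrame({
        'ssn': ['111-11-1111', '222-22-2222', '333-33-3333', '444-44-4444'],
        'age': [34, 52, 29, 41],
    })

    # 'ssn' is both the primary key (type 'id') and a pii field, so it is
    # excluded from modeling and re-created with faker.ssn() when sampling.
    model = GaussianCopula(primary_key='ssn', anonymize_fields={'ssn': 'ssn'})
    model.fit(data)

    print(model.sample(4))  # fresh fake SSNs rather than 0, 1, 2, 3
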
78 changes: 54 additions & 24 deletions sdv/tabular/copulas.py
@@ -77,6 +77,10 @@ class GaussianCopula(BaseTabularModel):
              so using this will make ``get_parameters`` unusable.
            * ``truncated_gaussian``: Use a Truncated Gaussian distribution.
+        default_distribution (copulas.univariate.Univariate or str):
+            Copulas univariate distribution to use by default, chosen from the
+            list of possible ``distribution`` values. Defaults to ``parametric``.
        categorical_transformer (str):
            Type of transformer to use for the categorical variables, to choose
            from:
@@ -153,25 +157,36 @@ class GaussianCopula(BaseTabularModel):
    }
    _DEFAULT_TRANSFORMER = 'one_hot_encoding'

-    @classmethod
-    def _get_distribution(cls, distribution):
-        if not distribution:
-            return cls._DISTRIBUTIONS['parametric']
-
-        if isinstance(distribution, str):
-            return cls._DISTRIBUTIONS.get(distribution, distribution)
+    def _get_distribution(self, table_data):
+        default = self._DISTRIBUTIONS.get(self._default_distribution, self._default_distribution)
+        if self._distribution is None:
+            return {
+                name: default
+                for name in table_data.columns
+            }

-        if isinstance(distribution, dict):
+        if not isinstance(self._distribution, dict):
+            distribution = self._DISTRIBUTIONS.get(self._distribution, self._distribution)
            return {
-                name: cls._get_distribution(distribution)
-                for name, distribution in distribution.items()
+                name: distribution
+                for name in table_data.columns
            }

-        return distribution
+        distributions = {}
+        for column in table_data.columns:
+            distribution = self._distribution.get(column)
+            if not distribution:
+                distribution = default
+            else:
+                distribution = self._DISTRIBUTIONS.get(distribution, distribution)
+
+            distributions[column] = distribution
+
+        return distributions

    def __init__(self, field_names=None, field_types=None, field_transformers=None,
-                 anonymize_fields=None, primary_key=None, constraints=None,
-                 table_metadata=None, distribution=None, categorical_transformer=None):
+                 anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
+                 distribution=None, default_distribution=None, categorical_transformer=None):

        if isinstance(table_metadata, dict):
            table_metadata = Table.from_dict(table_metadata)
@@ -185,7 +200,8 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None,
        if categorical_transformer is None:
            categorical_transformer = model_kwargs['categorical_transformer']

-        self._distribution = self._get_distribution(distribution)
+        self._distribution = distribution
+        self._default_distribution = default_distribution or 'parametric'

        categorical_transformer = categorical_transformer or self._DEFAULT_TRANSFORMER
        self._categorical_transformer = categorical_transformer
@@ -200,6 +216,28 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None,
            table_metadata=table_metadata
        )

+    def get_distributions(self):
+        """Get the distribution used or detected for each column.
+
+        Returns:
+            dict:
+                Mapping of each column name to the distribution that was
+                used, or detected during ``fit`` when ``parametric`` was used.
+        """
+        parameters = self._model.to_dict()
+        univariates = parameters['univariates']
+        columns = parameters['columns']
+
+        distributions = {}
+        for column, univariate in zip(columns, univariates):
+            distributions[column] = univariate['type']
+
+        return distributions

    def _update_metadata(self):
        """Add arguments needed to reproduce this model to the Metadata.
@@ -210,14 +248,7 @@ def _update_metadata(self):
        class_name = self.__class__.__name__
        model_kwargs = self._metadata.get_model_kwargs(class_name)
        if not model_kwargs:
-            parameters = self._model.to_dict()
-            univariates = parameters['univariates']
-            columns = parameters['columns']
-
-            distributions = {}
-            for column, univariate in zip(columns, univariates):
-                distributions[column] = univariate['type']
-
+            distributions = self.get_distributions()
            self._metadata.set_model_kwargs(class_name, {
                'distribution': distributions,
                'categorical_transformer': self._categorical_transformer,
@@ -230,6 +261,7 @@ def _fit(self, table_data):
            table_data (pandas.DataFrame):
                Data to be fitted.
        """
+        self._distribution = self._get_distribution(table_data)
        self._model = copulas.multivariate.GaussianMultivariate(distribution=self._distribution)
        self._model.fit(table_data)
        self._update_metadata()
@@ -267,8 +299,6 @@ def get_parameters(self):
        for index, row in enumerate(params['covariance']):
            covariance.append(row[:index + 1])

-        # self._model.covariance = np.array(values)
-        # params = self._model.to_dict()
        univariates = dict()
        for name, univariate in zip(params.pop('columns'), params['univariates']):
            univariates[name] = univariate
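
With this refactor, distributions are resolved per column at ``fit`` time: ``distribution`` may be a single value or a per-column dict, anything not listed falls back to ``default_distribution``, and the new ``get_distributions`` reports what was actually used. A rough usage sketch, assuming ``'gamma'`` and ``'gaussian'`` are among the ``_DISTRIBUTIONS`` aliases (the data is made up):

    import pandas as pd

    from sdv.tabular import GaussianCopula

    data = pd.DataFrame({
        'age': [23, 45, 31, 62, 50, 38],
        'income': [40.0, 82.5, 55.0, 91.0, 67.5, 58.0],
    })

    # Explicit choice for 'age'; every other column falls back to the default.
    model = GaussianCopula(
        distribution={'age': 'gamma'},
        default_distribution='gaussian',
    )
    model.fit(data)

    # Maps each column to the distribution used (or the one detected,
    # when the choice was 'parametric').
    print(model.get_distributions())
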
30 changes: 23 additions & 7 deletions sdv/tabular/ctgan.py
@@ -1,5 +1,8 @@
"""Wrapper around CTGAN model."""

+import contextlib
+import io

from sdv.tabular.base import BaseTabularModel


@@ -60,6 +63,8 @@ class CTGAN(BaseTabularModel):
            Weight decay for the Adam Optimizer. Defaults to 1e-6.
        batch_size (int):
            Number of data samples to process in each step.
+        verbose (bool):
+            Whether to print fit progress on stdout. Defaults to ``False``.
    """

_CTGAN_CLASS = None
@@ -72,7 +77,7 @@ class CTGAN(BaseTabularModel):
    def __init__(self, field_names=None, field_types=None, field_transformers=None,
                 anonymize_fields=None, primary_key=None, constraints=None, table_metadata=None,
                 epochs=300, log_frequency=True, embedding_dim=128, gen_dim=(256, 256),
-                 dis_dim=(256, 256), l2scale=1e-6, batch_size=500):
+                 dis_dim=(256, 256), l2scale=1e-6, batch_size=500, verbose=False):
        super().__init__(
            field_names=field_names,
            primary_key=primary_key,
@@ -99,6 +104,7 @@ def __init__(self, field_names=None, field_types=None, field_transformers=None,
        self._batch_size = batch_size
        self._epochs = epochs
        self._log_frequency = log_frequency
+        self._verbose = verbose

    def _fit(self, table_data):
        """Fit the model to the table.
@@ -119,12 +125,22 @@ def _fit(self, table_data):
            for field, meta in self._metadata.get_fields().items()
            if meta['type'] == 'categorical'
        ]
-        self._model.fit(
-            table_data,
-            epochs=self._epochs,
-            discrete_columns=categoricals,
-            log_frequency=self._log_frequency,
-        )
+
+        if self._verbose:
+            self._model.fit(
+                table_data,
+                epochs=self._epochs,
+                discrete_columns=categoricals,
+                log_frequency=self._log_frequency,
+            )
+        else:
+            with contextlib.redirect_stdout(io.StringIO()):
+                self._model.fit(
+                    table_data,
+                    epochs=self._epochs,
+                    discrete_columns=categoricals,
+                    log_frequency=self._log_frequency,
+                )

    def _sample(self, num_rows):
        """Sample the indicated number of rows from the model.
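
A short sketch of the new flag, assuming the usual ``fit``/``sample`` flow of the tabular API (the data is made up). With ``verbose=True`` CTGAN's per-epoch loss lines stay on stdout; the new default silently discards them:

    import pandas as pd

    from sdv.tabular import CTGAN

    data = pd.DataFrame({
        'category': ['a', 'b', 'c', 'a'] * 25,
        'value': list(range(100)),
    })

    # verbose=True preserves the old behavior of printing losses every
    # epoch; verbose=False routes that output into an io.StringIO sink.
    model = CTGAN(epochs=5, verbose=True)
    model.fit(data)

    samples = model.sample(10)
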
17 changes: 0 additions & 17 deletions tests/metadata/test___init__.py
@@ -370,23 +370,6 @@ def test_get_dtypes_error_invalid_type(self):
        with pytest.raises(MetadataError):
            Metadata.get_dtypes(metadata, 'test')

-    def test_get_dtypes_error_id(self):
-        """Test get data types with an id that is not a primary or foreign key."""
-        # Setup
-        table_meta = {
-            'fields': {
-                'item': {'type': 'id'}
-            }
-        }
-        metadata = Mock(spec_set=Metadata)
-        metadata.get_table_meta.return_value = table_meta
-        metadata.get_children.return_value = []
-        metadata._DTYPES = Metadata._DTYPES
-
-        # Run
-        with pytest.raises(MetadataError):
-            Metadata.get_dtypes(metadata, 'test', ids=True)
-
    def test_get_dtypes_error_subtype_numerical(self):
        """Test get data types with an invalid numerical subtype."""
        # Setup
