Issue 86: Error merging extensions on modeler.CPA #88

Merged · 11 commits · Feb 13, 2019
4 changes: 2 additions & 2 deletions Makefile
@@ -51,8 +51,8 @@ clean-test: ## remove test and coverage artifacts
 	rm -fr .pytest_cache
 
 lint: ## check style with flake8 and isort
-	flake8 sdv tests
-	isort -c --recursive sdv tests
+	flake8 sdv tests examples
+	isort -c --recursive sdv tests examples
 
 fixlint: ## fix lint issues using autoflake, autopep8, and isort
 	find sdv -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
19 changes: 9 additions & 10 deletions examples/demo.py
@@ -3,8 +3,8 @@
 import sys
 from timeit import default_timer as timer
 
-from sdv.sdv import SDV
 from examples.utils import download_folder
+from sdv.sdv import SDV
 
 
 def get_logger():
@@ -51,15 +51,14 @@ def run_demo(folder_name):
         'demo', folder_name, folder_name.capitalize() + '_manual_meta.json')
     sdv = SDV(meta_file)
     sdv.fit()
-    sampled_rows = {}
-    LOGGER.info('Parent map: %s',
-                sdv.dn.parent_map)
-    LOGGER.info('Transformed data: %s',
-                sdv.dn.transformed_data)
-    table_list = table_dict[folder_name]
-    for table in table_list:
-        sampled_rows[table] = sdv.sample_rows(table, 1)
-        LOGGER.info('Sampled row from %s: %s', table, sampled_rows[table])
+
+    sampled = sdv.sample_all()
+
+    LOGGER.info('Parent map: %s', sdv.dn.parent_map)
+    LOGGER.info('Transformed data: %s', sdv.dn.transformed_data)
+
+    for name, table in sampled.items():
+        LOGGER.info('Sampled row from %s: %s', name, table.head(3).T)
 
     end = timer()
     LOGGER.info('Total time: %s seconds', round(end-start))
5 changes: 3 additions & 2 deletions examples/multiparent_example/multiparent_example.py
@@ -22,12 +22,13 @@ def run_example():
     # Setup
     vault = SDV('data/meta.json')
     vault.fit()
+
     # Run
     result = vault.sample_all()
 
     for name, table in result.items():
         print('Samples generated for table {}:\n{}\n'.format(name, table.head(5)))
 
 
 if __name__ == '__main__':
-    run_example()
\ No newline at end of file
+    run_example()
37 changes: 21 additions & 16 deletions examples/utils.py
@@ -1,11 +1,10 @@
+import io
 import logging
 import os
-import boto3
-import botocore
+import urllib.request
 import zipfile
 
-from botocore import UNSIGNED
-from botocore.client import Config
+LOGGER = logging.getLogger(__name__)
 
 BUCKET_NAME = 'hdi-demos'
 SDV_NAME = 'sdv-demo'
@@ -14,24 +13,30 @@
 
 def download_folder(folder_name):
     """Downloads and extracts demo folder from S3"""
-    s3 = boto3.resource('s3', region_name='us-east-1',
-                        config=Config(signature_version=UNSIGNED))
     zip_name = folder_name + SUFFIX
-    zip_destination = os.path.join('demo', zip_name)
     key = os.path.join(SDV_NAME, zip_name)
-    # make sure directory exists
+
+    # If the directory doesn't exist, we create it.
+    # If it exists, we check for folder_name and exit early.
     if not os.path.exists('demo'):
         os.makedirs('demo')
+
+    else:
+        if os.path.exists(os.path.join('demo', folder_name)):
+            return
+
     # try to download files from s3
     try:
-        s3.Bucket(BUCKET_NAME).download_file(key, zip_destination)
-    except botocore.exceptions.ClientError as e:
-        if e.response['Error']['Code'] == "404":
-            print("The object does not exist.")
-        else:
-            raise
+        url = 'https://{}.s3.amazonaws.com/{}'.format(BUCKET_NAME, key)
+        response = urllib.request.urlopen(url)
+        bytes_io = io.BytesIO(response.read())
+
+    except urllib.error.HTTPError as error:
+        if error.code == 404:
+            LOGGER.error('File %s not found.', key)
+        raise
 
     # unzip folder
-    zip_ref = zipfile.ZipFile(zip_destination, 'r')
+    zip_ref = zipfile.ZipFile(bytes_io, 'r')
     zip_ref.extractall('demo')
     zip_ref.close()
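The rewritten helper streams the archive straight from the HTTP response into `zipfile` through `io.BytesIO`, so no temporary zip file is written to disk. A self-contained sketch of the same pattern, with a placeholder URL:

```python
import io
import urllib.request
import zipfile


def download_and_extract(url, destination):
    """Fetch a zip archive over HTTP and extract it without touching disk."""
    response = urllib.request.urlopen(url)  # raises urllib.error.HTTPError on 4xx/5xx
    bytes_io = io.BytesIO(response.read())  # keep the whole archive in memory

    with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
        zip_ref.extractall(destination)


# Placeholder URL, for illustration only:
# download_and_extract('https://example.com/demo.zip', 'demo')
```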
83 changes: 61 additions & 22 deletions sdv/modeler.py
@@ -3,7 +3,7 @@
 
 import numpy as np
 import pandas as pd
-from copulas import get_qualified_name
+from copulas import EPSILON, get_qualified_name
 from copulas.multivariate import GaussianMultivariate, TreeTypes
 from copulas.univariate import GaussianUnivariate
 from rdt.transformers.positive_number import PositiveNumberTransformer
@@ -16,9 +16,9 @@
 IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type']
 
 MODELLING_ERROR_MESSAGE = (
-    'There was an error while trying to model the database. If you are using a custom'
-    'distribution or model, please try again using the default ones. If the problem persist,'
-    'please report it here: https://github.com/HDI-Project/SDV/issues'
+    'There was an error while trying to model the database. If you are using a custom '
+    'distribution or model, please try again using the default ones. If the problem persist, '
+    'please report it here:\nhttps://github.com/HDI-Project/SDV/issues.\n'
 )
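The bug this hunk fixes is easy to miss: adjacent Python string literals are concatenated with no separator, so the missing trailing spaces produced run-together words like "customdistribution" in the rendered message. A minimal sketch of the pitfall:

```python
# Adjacent string literals are joined with no separator, so a missing
# trailing space silently glues words together across lines.
message = (
    'If you are using a custom'
    'distribution or model, please try again.'
)
print(message)  # If you are using a customdistribution or model, please try again.
```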


@@ -204,15 +204,26 @@ def impute_table(table):
         """
         values = {}
 
-        for label in table:
-            value = table[label].mean()
+        for column in table.loc[:, table.isnull().any()].columns:
+            if table[column].dtype in [np.float64, np.int64]:
+                value = table[column].mean()
 
-            if not pd.isnull(value):
-                values[label] = value
+            if not pd.isnull(value or np.nan):
+                values[column] = value
             else:
-                values[label] = 0
+                values[column] = 0
 
-        return table.fillna(values)
+        table = table.fillna(values)
+
+        # There is an issue when using the KDEUnivariate modeler on tables with children:
+        # the extension columns can have constant values, which make it crash.
+        # This is a temporary fix while https://github.com/DAI-Lab/Copulas/issues/82 is solved.
+        first_index = table.index[0]
+        constant_columns = table.loc[:, (table == table.loc[first_index]).all()].columns
+        for column in constant_columns:
+            table.loc[first_index, column] = table.loc[first_index, column] + EPSILON
+
+        return table
 
     def fit_model(self, data):
         """Returns an instance of self.model fitted with the given data.
@@ -235,22 +246,26 @@ def _create_extension(self, foreign, transformed_child_table, table_info):
             foreign(pandas.DataFrame): Object with Index of elements from children table elements
                 of a given foreign_key.
             transformed_child_table(pandas.DataFrame): Table of data to fill.
-            table_info (tuple(str, str)): foreign_key and child table names.
+            table_info (tuple[str, str]): foreign_key and child table names.
 
         Returns:
-            pd.Series : Parameter extension
+            pd.Series or None: Parameter extension if it can be generated, None otherwise.
         """
 
         foreign_key, child_name = table_info
         try:
             conditional_data = transformed_child_table.loc[foreign.index].copy()
-            conditional_data = conditional_data.drop(foreign_key, axis=1)
+            if foreign_key in conditional_data:
+                conditional_data = conditional_data.drop(foreign_key, axis=1)
 
         except KeyError:
             return None
 
-        clean_df = self.impute_table(conditional_data)
-        return self.flatten_model(self.fit_model(clean_df), child_name)
+        if len(conditional_data):
+            clean_df = self.impute_table(conditional_data)
+            return self.flatten_model(self.fit_model(clean_df), child_name)
+
+        return None
 
     def _get_extensions(self, pk, children):
         """Generate list of extension for child tables.
@@ -301,7 +316,7 @@ def _get_extensions(self, pk, children):
                     parameters[foreign_key] = parameter.to_dict()
 
             extension = pd.DataFrame(parameters).T
-            extension.index.name = fk
+            extension.index.name = pk
 
             if len(extension):
                 extensions.append(extension)
@@ -313,7 +328,15 @@ def CPA(self, table):
 
         Conditional Parameter Aggregation. It will take the table's children and generate
         extensions (parameters from modelling the related children for each foreign key)
-        and merge them to the original `table`
+        and merge them into the original `table`.
+
+        After the extensions are created, `extended_table` is modified in order for the
+        extensions to be merged. As the extensions are returned with an index consisting of
+        values of the `primary_key` of the parent table, we need to make sure the same values
+        are present in `extended_table`. The values could be missing in two situations:
+
+        - They weren't numeric, and have been transformed.
+        - They weren't transformed, and therefore are not present in `extended_table`.
 
         Args:
             table (string): name of table.
@@ -335,9 +358,24 @@
         extended_table = self.dn.transformed_data[table]
         extensions = self._get_extensions(pk, children)
 
-        # add extensions
-        for extension in extensions:
-            extended_table = extended_table.merge(extension.reset_index(), how='left', on=pk)
+        if extensions:
+            original_pk = tables[table].data[pk]
+            transformed_pk = None
+
+            if pk in extended_table:
+                transformed_pk = extended_table[pk].copy()
+
+            if (pk not in extended_table) or (not extended_table[pk].equals(original_pk)):
+                extended_table[pk] = original_pk
+
+            # add extensions
+            for extension in extensions:
+                extended_table = extended_table.merge(extension.reset_index(), how='left', on=pk)
+
+            if transformed_pk is not None:
+                extended_table[pk] = transformed_pk
+            else:
+                extended_table = extended_table.drop(pk, axis=1)
 
         self.tables[table] = extended_table

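To see why CPA juggles the primary key around the merge, here is a toy illustration (made-up data, independent of SDV): the extension's index carries the parent's original key values, so merging against a transformed key column would match nothing.

```python
import pandas as pd

# Transformed parent table: suppose the pk was non-numeric, so the transform
# replaced its values with numeric codes (hypothetical data).
extended_table = pd.DataFrame({'user_id': [0.1, 0.5, 0.9], 'age': [0.3, 0.7, 0.2]})
original_pk = pd.Series(['a', 'b', 'c'], name='user_id')

# Extension indexed by the original pk values, as _get_extensions returns it.
extension = pd.DataFrame({'param': [1.0, 2.0]},
                         index=pd.Index(['a', 'c'], name='user_id'))

# Swap in the original values, merge, then restore the transformed ones.
transformed_pk = extended_table['user_id'].copy()
extended_table['user_id'] = original_pk
extended_table = extended_table.merge(extension.reset_index(), how='left', on='user_id')
extended_table['user_id'] = transformed_pk

print(extended_table)
#    user_id  age  param
# 0      0.1  0.3    1.0
# 1      0.5  0.7    NaN
# 2      0.9  0.2    2.0
```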
@@ -365,7 +403,8 @@ def model_database(self):
                 clean_table = self.impute_table(self.tables[table])
                 self.models[table] = self.fit_model(clean_table)
 
-            except (ValueError, np.linalg.linalg.LinAlgError):
-                ValueError(MODELLING_ERROR_MESSAGE)
+            except (ValueError, np.linalg.linalg.LinAlgError) as error:
+                raise ValueError(
+                    MODELLING_ERROR_MESSAGE).with_traceback(error.__traceback__) from None
 
         logger.info('Modeling Complete')
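The old `except` block constructed a ValueError but never raised it, so failures were silently swallowed. The replacement re-raises a friendlier error while preserving the original traceback; a standalone sketch of the idiom:

```python
import numpy as np

FRIENDLY_MESSAGE = 'Modeling failed. Try the default distribution and model.'  # stand-in text


def fit_or_explain(matrix):
    try:
        # A singular matrix makes this raise numpy.linalg.LinAlgError.
        return np.linalg.inv(matrix)
    except (ValueError, np.linalg.LinAlgError) as error:
        # with_traceback keeps the original raise site in the traceback;
        # `from None` suppresses the implicit exception-chaining noise.
        raise ValueError(FRIENDLY_MESSAGE).with_traceback(error.__traceback__) from None


# fit_or_explain(np.zeros((2, 2)))  # would raise the friendlier ValueError
```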
6 changes: 5 additions & 1 deletion sdv/sampler.py
@@ -152,11 +152,15 @@ def transform_synthesized_rows(self, synthesized, table_name, num_rows):
 
         # filter out parameters
         labels = list(self.dn.tables[table_name].data)
+        reverse_columns = [
+            transformer[1] for transformer in self.dn.ht.transformers
+            if table_name in transformer
+        ]
 
         text_filled = self._fill_text_columns(synthesized, labels, table_name)
 
         # reverse transform data
-        reversed_data = self.dn.ht.reverse_transform_table(text_filled, orig_meta)
+        reversed_data = self.dn.ht.reverse_transform_table(text_filled[reverse_columns], orig_meta)
 
         synthesized.update(reversed_data)
         return synthesized[labels]
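The new `reverse_columns` filter leans on the structure of `self.dn.ht.transformers`, assumed here to hold `(table_name, column_name)` tuples; a small sketch with stand-in data shows why only this table's columns get reverse transformed:

```python
# Stand-in for self.dn.ht.transformers, assumed here to be a collection of
# (table_name, column_name) tuples (hypothetical data for illustration).
transformers = [
    ('users', 'age'),
    ('users', 'country'),
    ('orders', 'amount'),
]

table_name = 'users'
reverse_columns = [
    transformer[1] for transformer in transformers
    if table_name in transformer
]
print(reverse_columns)  # ['age', 'country']: only this table's columns are reverse transformed
```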
2 changes: 1 addition & 1 deletion sdv/sdv.py
@@ -3,7 +3,7 @@
 """Main module."""
 import pickle
 
-from sklearn.exceptions import NotFittedError
+from copulas import NotFittedError
 
 from sdv.data_navigator import CSVDataLoader
 from sdv.modeler import Modeler
3 changes: 0 additions & 3 deletions setup.py
@@ -12,12 +12,9 @@
 history = history_file.read()
 
 install_requires = [
-    'boto3>=1.7.47',
     'exrex>=0.10.5',
     'numpy>=1.13.1',
     'pandas>=0.22.0',
-    'scipy>=0.19.1',
-    'scikit-learn>=0.19.1',
     'copulas>=0.2.1',
     'rdt>=0.1.2'
 ]