Skip to content

Commit

Permalink
sdgym-gretel (#1) (#87)
Browse files Browse the repository at this point in the history
* sdgym-gretel: adding gretel synthesizer

* pr comments and changes discussed in OH

* getting rid of error messages

* moving static method out

* Curate dependencies to avoid conflicts

Co-authored-by: Carles Sala <carles@pythiac.com>
  • Loading branch information
amontanez24 and csala authored May 27, 2021
1 parent 0afa9b8 commit f53b6b2
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 4 deletions.
4 changes: 4 additions & 0 deletions conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ build:
requirements:
host:
- ctgan >=0.2.2.dev1,<0.3
- gretel-synthetics >=0.15.4,<0.16
- humanfriendly >=8.2,<9
- numpy >=1.15.4,<2
- pandas >=0.23.4,<2
Expand All @@ -30,12 +31,14 @@ requirements:
- sdv >=0.4.4.dev0,<0.6
- tabulate >=0.8.3,<0.9
- pytorch >=1.1.0,<2
- tensorflow ==2.4.0rc1
- torchvision >=0.3.0
- tqdm >=4,<5
- xlsxwriter >=1.2.8,<1.3
- pytest-runner
run:
- ctgan >=0.2.2.dev1,<0.3
- gretel-synthetics >=0.15.4,<0.16
- humanfriendly >=8.2,<9
- numpy >=1.15.4,<2
- pandas >=0.23.4,<2
Expand All @@ -47,6 +50,7 @@ requirements:
- sdv >=0.4.4.dev0,<0.6
- tabulate >=0.8.3,<0.9
- pytorch >=1.1.0,<2
- tensorflow ==2.4.0rc1
- torchvision >=0.3.0
- tqdm >=4,<5
- xlsxwriter >=1.2.8,<1.3
Expand Down
3 changes: 3 additions & 0 deletions sdgym/synthesizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from sdgym.synthesizers.clbn import CLBN
from sdgym.synthesizers.gretel import Gretel, PreprocessedGretel
from sdgym.synthesizers.identity import Identity
from sdgym.synthesizers.independent import Independent
from sdgym.synthesizers.medgan import MedGAN
Expand All @@ -25,4 +26,6 @@
'GaussianCopulaCategorical',
'GaussianCopulaCategoricalFuzzy',
'GaussianCopulaOneHot',
'Gretel',
'PreprocessedGretel',
)
26 changes: 23 additions & 3 deletions sdgym/synthesizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,23 +38,43 @@ def fit_sample(self, real_data, metadata):
class SingleTableBaseline(Baseline):
"""Base class for all the SingleTable Baselines.
Sublcasses can choose to implement ``_fit_sample``, which will
Subclasses can choose to implement ``_fit_sample``, which will
always be called with DataFrames and Table metadata dicts, or
to overwrite the ``fit_sample`` method, which may be called with
either DataFrames and Table dicts, or with dicts of tables and
dataset metadata dicts.
"""

MODALITIES = ('single-table', )
CONVERT_TO_NUMERIC = False

def _transform_fit_sample(self, real_data, metadata):
ht = rdt.HyperTransformer()
columns_to_transform = list()
fields_metadata = metadata['fields']
id_fields = list()
for field in fields_metadata:
if fields_metadata.get(field).get('type') != 'id':
columns_to_transform.append(field)
else:
id_fields.append(field)

ht.fit(real_data[columns_to_transform])
transformed_data = ht.transform(real_data)
synthetic_data = self._fit_sample(transformed_data, metadata)
reverse_transformed_synthetic_data = ht.reverse_transform(synthetic_data)
reverse_transformed_synthetic_data[id_fields] = real_data[id_fields]
return reverse_transformed_synthetic_data

def fit_sample(self, real_data, metadata):
_fit_sample = self._transform_fit_sample if self.CONVERT_TO_NUMERIC else self._fit_sample
if isinstance(real_data, dict):
return {
table_name: self._fit_sample(table, metadata.get_table_meta(table_name))
table_name: _fit_sample(table, metadata.get_table_meta(table_name))
for table_name, table in real_data.items()
}

return self._fit_sample(real_data, metadata)
return _fit_sample(real_data, metadata)


class MultiSingleTableBaseline(Baseline):
Expand Down
75 changes: 75 additions & 0 deletions sdgym/synthesizers/gretel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import os

import numpy as np
from gretel_synthetics.batch import DataFrameBatch

from sdgym.synthesizers.base import SingleTableBaseline


class Gretel(SingleTableBaseline):
"""Class to represent Gretel's neural network model."""

DEFAULT_CHECKPOINT_DIR = os.path.join(os.getcwd(), 'checkpoints')

def __init__(self, max_lines=0, max_line_len=2048, epochs=None, vocab_size=20000,
gen_lines=None, dp=False, field_delimiter=",", overwrite=True,
checkpoint_dir=DEFAULT_CHECKPOINT_DIR):
self.max_lines = max_lines
self.max_line_len = max_line_len
self.epochs = epochs
self.vocab_size = vocab_size
self.gen_lines = gen_lines
self.dp = dp
self.field_delimiter = field_delimiter
self.overwrite = overwrite
self.checkpoint_dir = checkpoint_dir

def _fit_sample(self, data, metadata):
config = {
'max_lines': self.max_lines,
'max_line_len': self.max_line_len,
'epochs': self.epochs or data.shape[1] * 3, # value recommended by Gretel
'vocab_size': self.vocab_size,
'gen_lines': self.gen_lines or data.shape[0],
'dp': self.dp,
'field_delimiter': self.field_delimiter,
'overwrite': self.overwrite,
'checkpoint_dir': self.checkpoint_dir
}
batcher = DataFrameBatch(df=data, config=config)
batcher.create_training_data()
batcher.train_all_batches()
batcher.generate_all_batch_lines()
synth_data = batcher.batches_to_df()
return synth_data


class PreprocessedGretel(Gretel):
"""Class that uses RDT to make all columns numeric before using Gretel's model."""

CONVERT_TO_NUMERIC = True

@staticmethod
def make_numeric(val):
if type(val) in [float, int]:
return val

if isinstance(val, str) and val.isnumeric():
return float(val)

return np.nan

def _fix_numeric_columns(self, data, metadata):
fields_metadata = metadata['fields']
for field in data:
if field in fields_metadata and fields_metadata.get(field).get('type') == 'id':
continue

data[field] = data[field].apply(self.make_numeric)
avg = data[field].mean() if not np.isnan(data[field].mean()) else 0
data[field] = data[field].fillna(round(avg))

def _fit_sample(self, data, metadata):
synth_data = super()._fit_sample(data, metadata)
self._fix_numeric_columns(synth_data, metadata)
return synth_data
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
'appdirs>1.1.4,<2',
'boto3>=1.15.0,<2',
'compress-pickle>=1.2.0,<2',
'gretel-synthetics>=0.15.4,<0.16',
'humanfriendly>=8.2,<9',
'numpy>=1.15.4,<2',
'numpy>=1.15.4,<1.20',
'pandas<1.1.5,>=1.1',
'pomegranate>=0.13.0,<0.13.5',
'psutil>=5.7,<6',
Expand All @@ -28,6 +29,8 @@
'rdt>=0.4.1',
'sdmetrics>=0.3.0',
'sdv>=0.9.0',
'tensorflow==2.4.0rc1',
'wheel~=0.35',
]

setup_requires = [
Expand Down

0 comments on commit f53b6b2

Please sign in to comment.