Skip to content

Commit

Permalink
Validate discrete column (sdv-dev#118)
Browse files Browse the repository at this point in the history
* Bump version: 0.3.1.dev0 → 0.3.1.dev1

* Validates discrete columns

* Fix lint
  • Loading branch information
fealho authored Dec 31, 2020
1 parent cd56f03 commit 5ac5161
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 4 deletions.
2 changes: 1 addition & 1 deletion conda/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% set name = 'ctgan' %}
{% set version = '0.3.1.dev0' %}
{% set version = '0.3.1.dev1' %}

package:
name: "{{ name|lower }}"
Expand Down
2 changes: 1 addition & 1 deletion ctgan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

__author__ = 'MIT Data To AI Lab'
__email__ = 'dailabmit@gmail.com'
__version__ = '0.3.1.dev0'
__version__ = '0.3.1.dev1'

from ctgan.demo import load_demo
from ctgan.synthesizers.ctgan import CTGANSynthesizer
Expand Down
28 changes: 28 additions & 0 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import warnings

import numpy as np
import pandas as pd
import torch
from packaging import version
from torch import optim
Expand Down Expand Up @@ -222,6 +223,31 @@ def _cond_loss(self, data, c, m):

return (loss * m).sum() / data.size()[0]

def _validate_discrete_columns(self, train_data, discrete_columns):
"""Check whether ``discrete_columns`` exists in ``train_data``.
Args:
train_data (numpy.ndarray or pandas.DataFrame):
Training Data. It must be a 2-dimensional numpy array or a pandas.DataFrame.
discrete_columns (list-like):
List of discrete columns to be used to generate the Conditional
Vector. If ``train_data`` is a Numpy array, this list should
contain the integer indices of the columns. Otherwise, if it is
a ``pandas.DataFrame``, this list should contain the column names.
"""
if isinstance(train_data, pd.DataFrame):
invalid_columns = set(discrete_columns) - set(train_data.columns)
elif isinstance(train_data, np.ndarray):
invalid_columns = []
for column in discrete_columns:
if column < 0 or column >= train_data.shape[1]:
invalid_columns.append(column)
else:
raise TypeError('``train_data`` should be either pd.DataFrame or np.array.')

if invalid_columns:
raise ValueError('Invalid columns found: {}'.format(invalid_columns))

def fit(self, train_data, discrete_columns=tuple(), epochs=None):
"""Fit the CTGAN Synthesizer models to the training data.
Expand All @@ -234,6 +260,8 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=None):
contain the integer indices of the columns. Otherwise, if it is
a ``pandas.DataFrame``, this list should contain the column names.
"""
self._validate_discrete_columns(train_data, discrete_columns)

if epochs is None:
epochs = self._epochs
else:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.3.1.dev0
current_version = 0.3.1.dev1
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<candidate>\d+))?
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,6 @@
test_suite='tests',
tests_require=tests_require,
url='https://github.com/sdv-dev/CTGAN',
version='0.3.1.dev0',
version='0.3.1.dev1',
zip_safe=False,
)
23 changes: 23 additions & 0 deletions tests/integration/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import numpy as np
import pandas as pd
import pytest

from ctgan.synthesizers.ctgan import CTGANSynthesizer

Expand Down Expand Up @@ -145,3 +146,25 @@ def test_save_load():
sampled = ctgan.sample(1000)
assert set(sampled.columns) == {'continuous', 'discrete'}
assert set(sampled['discrete'].unique()) == {'a', 'b', 'c'}


def test_wrong_discrete_columns_dataframe():
data = pd.DataFrame({
'discrete': ['a', 'b']
})
discrete_columns = ['b', 'c']

ctgan = CTGANSynthesizer(epochs=1)
with pytest.raises(ValueError):
ctgan.fit(data, discrete_columns)


def test_wrong_discrete_columns_numpy():
data = pd.DataFrame({
'discrete': ['a', 'b']
})
discrete_columns = [0, 1]

ctgan = CTGANSynthesizer(epochs=1)
with pytest.raises(ValueError):
ctgan.fit(data.to_numpy(), discrete_columns)

0 comments on commit 5ac5161

Please sign in to comment.