Skip to content

Commit

Permalink
V0.3.0 (sdv-dev#112)
Browse files Browse the repository at this point in the history
* relative import; warning filter; data sampler file

* merge cond sampler and sampler to one data sampler

* change code to 2 space indent

* define util classes and rename variables.

* refactoring transformer and fix bugs.

* rename file

* rename func

* add hyper parameters to args.

* add doc strings.

* fix line length

* fix bug

* fix indent

* change indent to 4

* we should allow breaking lines before binary operators.

* Bump version: 0.2.2.dev1 → 0.3.0.dev0

* Code refactoring

* Removes attr and load/save, and fix lint

* Fix typo

* Fix conda version

* Add TVAE (sdv-dev#111)

* Adds tvae

* Correctly adds tvae

* Restructure files

* Simplify tvae test

* Fix lint/add verbose to ctgan

* Fix lint

* Move epochs from fit to __init__

* Fix epochs relocation

* General refactoring

* Fix readme

* Adds save/laod to base class

* Fixes save/load, adds test case

* Fix lint

* Fix lint

* Improved testing

* Added FutureWarning

* Updated warning/fix lint

* Fixes tvae bug/fix lint

* Empty commit

* Fix lint

* Update readme

* Updates readme

* Updates readme

Co-authored-by: Lei Xu <leix@mit.edu>
Co-authored-by: Carles Sala <carles@pythiac.com>
  • Loading branch information
3 people authored Dec 18, 2020
1 parent 333aa9d commit 0359bb4
Show file tree
Hide file tree
Showing 20 changed files with 1,173 additions and 969 deletions.
189 changes: 38 additions & 151 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
<p align="left">
<img width=15% src="https://dai.lids.mit.edu/wp-content/uploads/2018/06/Logo_DAI_highres.png" alt=“sdv-dev” />
<i>An open source project from Data to AI Lab at MIT.</i>
<a href="https://dai.lids.mit.edu">
<img width=15% src="https://dai.lids.mit.edu/wp-content/uploads/2018/06/Logo_DAI_highres.png" alt="DAI-Lab" />
</a>
<i>An Open Source Project from the <a href="https://dai.lids.mit.edu">Data to AI Lab, at MIT</a></i>
</p>

[![Development Status](https://img.shields.io/badge/Development%20Status-2%20--%20Pre--Alpha-yellow)](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
Expand All @@ -9,29 +11,22 @@
[![Downloads](https://pepy.tech/badge/ctgan)](https://pepy.tech/project/ctgan)
[![Coverage Status](https://codecov.io/gh/sdv-dev/CTGAN/branch/master/graph/badge.svg)](https://codecov.io/gh/sdv-dev/CTGAN)

# CTGAN

Implementation of our NeurIPS paper [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503).

CTGAN is a GAN-based data synthesizer that can generate synthetic tabular data with high fidelity.
<img align="center" width=30% src="docs/images/ctgan.png">

* Website: https://sdv.dev
* Documentation: https://sdv.dev/SDV
* Repository: https://github.com/sdv-dev/CTGAN
* License: [MIT](https://github.com/sdv-dev/CTGAN/blob/master/LICENSE)
* Development Status: [Pre-Alpha](https://pypi.org/search/?c=Development+Status+%3A%3A+2+-+Pre-Alpha)
* Homepage: https://github.com/sdv-dev/CTGAN

## Overview

Based on previous work ([TGAN](https://github.com/sdv-dev/TGAN)) on synthetic data generation,
we develop a new model called CTGAN. Several major differences make CTGAN outperform TGAN.

- **Preprocessing**: CTGAN uses more sophisticated Variational Gaussian Mixture Model to detect
modes of continuous columns.
- **Network structure**: TGAN uses LSTM to generate synthetic data column by column. CTGAN uses
Fully-connected networks which is more efficient.
- **Features to prevent mode collapse**: We design a conditional generator and resample the
training data to prevent model collapse on discrete columns. We use WGANGP and PacGAN to
stabilize the training of GAN.
CTGAN is a collection of Deep Learning based Synthetic Data Generators for single table data, which are able to learn from real data and generate synthetic clones with high fidelity.

Currently, this library implements the **CTGAN** and **TVAE** models proposed in the [Modeling Tabular data using Conditional GAN](https://arxiv.org/abs/1907.00503) paper. For more information about these models, please check out the respective user guides:
* [CTGAN User Guide](https://sdv.dev/SDV/user_guides/single_table/ctgan.html).
* [TVAE User Guide](https://sdv.dev/SDV/user_guides/single_table/tvae.html).

# Install

Expand All @@ -49,9 +44,6 @@ pip install ctgan

This will pull and install the latest stable release from [PyPI](https://pypi.org/).

If you want to install from source or contribute to the project please read the
[Contributing Guide](CONTRIBUTING.rst).

## Install with conda

**CTGAN** can also be installed using [conda](https://docs.conda.io/en/latest/):
Expand All @@ -63,72 +55,25 @@ conda install -c sdv-dev -c pytorch -c conda-forge ctgan
This will pull and install the latest stable release from [Anaconda](https://anaconda.org/).


# Data Format

**CTGAN** expects the input data to be a table given as either a `numpy.ndarray` or a
`pandas.DataFrame` object with two types of columns:

* **Continuous Columns**: Columns that contain numerical values and which can take any value.
* **Discrete columns**: Columns that only contain a finite number of possible values, wether
these are string values or not.

This is an example of a table with 4 columns:

* A continuous column with float values
* A continuous column with integer values
* A discrete column with string values
* A discrete column with integer values

| | A | B | C | D |
|---|------|-----|-----|---|
| 0 | 0.1 | 100 | 'a' | 1 |
| 1 | -1.3 | 28 | 'b' | 2 |
| 2 | 0.3 | 14 | 'a' | 2 |
| 3 | 1.4 | 87 | 'a' | 3 |
| 4 | -0.1 | 69 | 'b' | 2 |
# Usage Example

> :warning: **WARNING**: If you're just getting started with synthetic data, we recommend using the SDV library which provides user-friendly APIs for interacting with CTGAN. To learn more about using CTGAN through SDV, check out the user guide [here](https://sdv.dev/SDV/user_guides/single_table/ctgan.html).
**NOTE**: CTGAN does not distinguish between float and integer columns, which means that it will
sample float values in all cases. If integer values are required, the outputted float values
must be rounded to integers in a later step, outside of CTGAN.
To get started with CTGAN, you should prepare your data as either a `numpy.ndarray` or a `pandas.DataFrame` object with two types of columns:

# Python Quickstart
* **Continuous Columns**: can contain any numerical value.
* **Discrete Columns**: contain a finite number values, whether these are string values or not.

In this short tutorial we will guide you through a series of steps that will help you
getting started with **CTGAN**.
In this example we load the [Adult Census Dataset](https://archive.ics.uci.edu/ml/datasets/adult) which is a built-in demo dataset. We then model it using the **CTGANSynthesizer** and generate a synthetic copy of it.

## 1. Model the data

### Step 1: Prepare your data

Before being able to use CTGAN you will need to prepare your data as specified above.

For this example, we will be loading some data using the `ctgan.load_demo` function.

```python3
from ctgan import CTGANSynthesizer
from ctgan import load_demo

data = load_demo()
```

This will download a copy of the [Adult Census Dataset](https://archive.ics.uci.edu/ml/datasets/adult) as a dataframe:

| age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
|-------|------------------|----------|-----|------------------|------------------|----------|
| 39 | State-gov | 77516 | ... | 40 | United-States | <=50K |
| 50 | Self-emp-not-inc | 83311 | ... | 13 | United-States | <=50K |
| 38 | Private | 215646 | ... | 40 | United-States | <=50K |
| 53 | Private | 234721 | ... | 40 | United-States | <=50K |
| 28 | Private | 338409 | ... | 40 | Cuba | <=50K |
| ... | ... | ... | ... | ... | ... | ... |


Aside from the table itself, you will need to create a list with the names of the discrete
variables.

For this example:

```python3
# Names of the columns that are discrete
discrete_columns = [
'workclass',
'education',
Expand All @@ -140,93 +85,23 @@ discrete_columns = [
'native-country',
'income'
]
```

### Step 2: Fit CTGAN to your data

Once you have the data ready, you need to import and create an instance of the `CTGANSynthesizer`
class.

```python3
from ctgan import CTGANSynthesizer

ctgan = CTGANSynthesizer()
```

And then call its `fit` method passing your data and the list of discrete columns

```python
ctgan = CTGANSynthesizer(epochs=10)
ctgan.fit(data, discrete_columns)
```

**NOTE**: This process is likely to take a long time to run.

If you want to make the process shorter, or longer, you can control the number of training epochs
that the model will be performing by adding it to the `fit` call:

```python3
ctgan.fit(data, discrete_columns, epochs=5)
```

## 2. Generate synthetic data

Once the process has finished, all you need to do is call the `sample` method of your
`CTGANSynthesizer` instance indicating the number of rows that you want to generate.

```python3
# Synthetic copy
samples = ctgan.sample(1000)
```

The output will be a table with the exact same format as the input and filled with the synthetic
data generated by the model.

| age | workclass | fnlwgt | ... | hours-per-week | native-country | income |
|---------|--------------|-----------|-----|------------------|------------------|----------|
| 26.3191 | Private | 124079 | ... | 40.1557 | United-States | <=50K |
| 39.8558 | Private | 133996 | ... | 40.2507 | United-States | <=50K |
| 38.2477 | Self-emp-inc | 135955 | ... | 40.1124 | Ecuador | <=50K |
| 29.6468 | Private | 3331.86 | ... | 27.012 | United-States | <=50K |
| 20.9853 | Private | 120637 | ... | 40.0238 | United-States | <=50K |
| ... | ... | ... | ... | ... | ... | ... |

## 3. Generate synthetic data conditioning on one column

In the CTGAN model, we have a conditional vector. By setting the conditional vector, we increase
the probability of getting one value in one discrete column.

For example, the following code **increase the probability** of workclass = " Private".

```python3
samples = ctgan.sample(1000, 'workclass', ' Private')
```

**Note that this code does not guarante workclass=" Private"**

## 4. Save and load the synthesizer

To save a trained ctgan synthesizer, you can call the `save` method passing a path to the file
in which the model will be saved:

```python3
ctgan.save('ctgan.pkl')
```

Later on, you can restore the saved synthetsizer by passing the path to the `load`
model of the `CTGANSynthetizer` method:
# Join our community

```python3
ctgan = CTGANSynthesizer.load('ctgan.pkl')
```

# Join our community
1. Please have a look at the [Contributing Guide](https://sdv.dev/SDV/developer_guides/contributing.html) to see how you can contribute to the project.
2. If you have any doubts, feature requests or detect an error, please [open an issue on github](https://github.com/sdv-dev/CTGAN/issues) or [join our Slack Workspace](https://sdv-space.slack.com/join/shared_invite/zt-gdsfcb5w-0QQpFMVoyB2Yd6SRiMplcw#/).
3. Also, do not forget to check the [project documentation site](https://sdv.dev/SDV/)!

1. If you would like to try more dataset examples, please have a look at the [examples folder](
https://github.com/sdv-dev/CTGAN/tree/master/examples) of the repository. Please contact us
if you have a usage example that you would want to share with the community.
2. If you want to contribute to the project code, please head to the [Contributing Guide](
CONTRIBUTING.rst) for more details about how to do it.
3. If you have any doubts, feature requests or detect an error, please [open an issue on github](
https://github.com/sdv-dev/CTGAN/issues)

# Citing TGAN

Expand Down Expand Up @@ -260,3 +135,15 @@ A package to easily deploy **CTGAN** onto a remote server. This package is devel

More details can be found in the corresponding repository: https://github.com/oregonpillow/ctgan-server-cli


# The Synthetic Data Vault

<p>
<a href="https://sdv.dev">
<img width=30% src="https://github.com/sdv-dev/SDV/blob/master/docs/images/SDV-Logo-Color-Tagline.png?raw=true">
</a>
<p><i>This repository is part of <a href="https://sdv.dev">The Synthetic Data Vault Project</a></i></p>
</p>

* Website: https://sdv.dev
* Documentation: https://sdv.dev/SDV
2 changes: 1 addition & 1 deletion conda/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{% set name = 'ctgan' %}
{% set version = '0.2.3.dev0' %}
{% set version = '0.3.0.dev0' %}

package:
name: "{{ name|lower }}"
Expand Down
6 changes: 4 additions & 2 deletions ctgan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

__author__ = 'MIT Data To AI Lab'
__email__ = 'dailabmit@gmail.com'
__version__ = '0.2.3.dev0'
__version__ = '0.3.0.dev0'

from ctgan.demo import load_demo
from ctgan.synthesizer import CTGANSynthesizer
from ctgan.synthesizers.ctgan import CTGANSynthesizer
from ctgan.synthesizers.tvae import TVAESynthesizer

__all__ = (
'CTGANSynthesizer',
'TVAESynthesizer',
'load_demo'
)
38 changes: 32 additions & 6 deletions ctgan/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse

from ctgan.data import read_csv, read_tsv, write_tsv
from ctgan.synthesizer import CTGANSynthesizer
from ctgan.synthesizers.ctgan import CTGANSynthesizer


def _parse_args():
Expand All @@ -15,11 +15,31 @@ def _parse_args():

parser.add_argument('-m', '--metadata', help='Path to the metadata')
parser.add_argument('-d', '--discrete',
help='Comma separated list of discrete columns, no whitespaces')

help='Comma separated list of discrete columns without whitespaces.')
parser.add_argument('-n', '--num-samples', type=int,
help='Number of rows to sample. Defaults to the training data size')

parser.add_argument('--generator_lr', type=float, default=2e-4,
help='Learning rate for the generator.')
parser.add_argument('--discriminator_lr', type=float, default=2e-4,
help='Learning rate for the discriminator.')

parser.add_argument('--generator_decay', type=float, default=1e-6,
help='Weight decay for the generator.')
parser.add_argument('--discriminator_decay', type=float, default=0,
help='Weight decay for the discriminator.')

parser.add_argument('--embedding_dim', type=int, default=128,
help='Dimension of input z to the generator.')
parser.add_argument('--generator_dim', type=str, default='256,256',
help='Dimension of each generator layer. '
'Comma separated integers with no whitespaces.')
parser.add_argument('--discriminator_dim', type=str, default='256,256',
help='Dimension of each discriminator layer. '
'Comma separated integers with no whitespaces.')

parser.add_argument('--batch_size', type=int, default=500,
help='Batch size. Must be an even number.')
parser.add_argument('--save', default=None, type=str,
help='A filename to save the trained synthesizer.')
parser.add_argument('--load', default=None, type=str,
Expand All @@ -38,7 +58,6 @@ def _parse_args():

def main():
args = _parse_args()

if args.tsv:
data, discrete_columns = read_tsv(args.data, args.metadata)
else:
Expand All @@ -47,8 +66,15 @@ def main():
if args.load:
model = CTGANSynthesizer.load(args.load)
else:
model = CTGANSynthesizer()
model.fit(data, discrete_columns, args.epochs)
generator_dims = [int(x) for x in args.generator_dims.split(',')]
discriminator_dims = [int(x) for x in args.discriminator_dims.split(',')]
model = CTGANSynthesizer(
embedding_dim=args.embedding_dim, generator_dims=generator_dims,
discriminator_dims=discriminator_dims, generator_lr=args.generator_lr,
generator_decay=args.generator_decay, discriminator_lr=args.discriminator_lr,
discriminator_decay=args.discriminator_decay, batch_size=args.batch_size,
epochs=args.epochs)
model.fit(data, discrete_columns)

if args.save is not None:
model.save(args.save)
Expand Down
Loading

0 comments on commit 0359bb4

Please sign in to comment.