Merge pull request sdv-dev#68 from sdv-dev/feature_lx_load_save_condsampling

Feature lx load save condsampling
leix28 authored Sep 11, 2020
2 parents ea1d392 + ab21e59 commit 68a0661
Showing 5 changed files with 176 additions and 38 deletions.
29 changes: 29 additions & 0 deletions README.md
@@ -172,6 +172,35 @@ data generated by the model.
| 20.9853 | Private | 120637 | ... | 40.0238 | United-States | <=50K |
| ... | ... | ... | ... | ... | ... | ... |

## 3. Generate synthetic data conditioned on one column

The CTGAN model takes a conditional vector as input. By setting the conditional vector, we can
increase the probability of sampling a chosen value in a chosen discrete column.

For example, the following code **increases the probability** of workclass = "Private".

```python
samples = ctgan.sample(1000, "workclass", "Private")
```

**Note that this code does not guarantee that workclass="Private" in every sampled row.**
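As a quick sanity check, you can measure how often the condition actually holds in the output. A minimal sketch, assuming the training data was loaded as a pandas DataFrame so that `samples` is one as well:

```python
# Fraction of sampled rows where the condition holds; expect a high value,
# but not necessarily 1.0.
print((samples["workclass"] == "Private").mean())
```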

## 4. Save and load the synthesizer
To save a trained CTGAN synthesizer, use

```python
ctgan.save(path)
```

To restore a saved synthesizer, use

```python
ctgan = CTGANSynthesizer.load(path)
```

**If you continue training a restored synthesizer, please make sure you use the same
dataset it was originally fitted on.**
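Because `fit` only builds the data transformer and the networks when they do not already exist, a restored synthesizer can also be trained further. A minimal sketch, assuming `data` and `discrete_columns` are the same objects used in the original run:

```python
ctgan = CTGANSynthesizer.load(path)
ctgan.fit(data, discrete_columns, epochs=10)  # resumes from the saved weights
```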

# Join our community

31 changes: 29 additions & 2 deletions ctgan/__main__.py
@@ -20,6 +20,16 @@ def _parse_args():
parser.add_argument('-n', '--num-samples', type=int,
help='Number of rows to sample. Defaults to the training data size')

parser.add_argument('--save', default=None, type=str,
help='A filename to save the trained synthesizer.')
parser.add_argument('--load', default=None, type=str,
help='A filename to load a trained synthesizer.')

parser.add_argument("--sample_condition_column", default=None, type=str,
help="Select a discrete column name.")
parser.add_argument("--sample_condition_column_value", default=None, type=str,
help="Specify the value of the selected discrete column.")

parser.add_argument('data', help='Path to training data')
parser.add_argument('output', help='Path of the output file')

@@ -34,13 +44,30 @@ def main():
else:
data, discrete_columns = read_csv(args.data, args.metadata, args.header, args.discrete)

    if args.load:
        model = CTGANSynthesizer.load(args.load)
    else:
        model = CTGANSynthesizer()
    model.fit(data, discrete_columns, args.epochs)

if args.save is not None:
model.save(args.save)

    num_samples = args.num_samples or len(data)

    if args.sample_condition_column is not None:
        assert args.sample_condition_column_value is not None
        sampled = model.sample(
            num_samples,
            args.sample_condition_column,
            args.sample_condition_column_value)
    else:
        sampled = model.sample(num_samples)

if args.tsv:
write_tsv(sampled, args.metadata, args.output)
else:
sampled.to_csv(args.output, index=False)


if __name__ == "__main__":
main()
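With these flags, a typical invocation might look like `python -m ctgan --load model.pt --sample_condition_column workclass --sample_condition_column_value Private data.csv out.csv` (the filenames here are illustrative; `python -m ctgan` assumes the package exposes this `__main__.py` as its entry point).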
6 changes: 6 additions & 0 deletions ctgan/conditional.py
@@ -96,3 +96,9 @@ def sample_zero(self, batch):
vec[i, pick + self.interval[col, 0]] = 1

return vec

    def generate_cond_from_condition_column_info(self, condition_info, batch):
        # Build a batch of identical one-hot conditional vectors that fix the
        # requested discrete column to the requested value.
        vec = np.zeros((batch, self.n_opt), dtype='float32')
        idx = self.interval[condition_info["discrete_column_id"]][0] + condition_info["value_id"]
        vec[:, idx] = 1
        return vec
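For intuition, here is what the new helper computes, reproduced standalone on a toy layout; the two discrete columns (3 and 2 categories) and all numbers are illustrative, not from the source:

```python
import numpy as np

# Toy stand-ins for ConditionalGenerator state: interval[i, 0] is the offset
# of discrete column i inside the conditional vector, n_opt its total width.
interval = np.array([[0, 3], [3, 2]])
n_opt = 5

condition_info = {"discrete_column_id": 1, "value_id": 0}
batch = 4

vec = np.zeros((batch, n_opt), dtype="float32")
idx = interval[condition_info["discrete_column_id"], 0] + condition_info["value_id"]
vec[:, idx] = 1  # every row one-hot selects value 0 of discrete column 1
print(vec)       # column 3 is 1.0 in all four rows
```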
121 changes: 85 additions & 36 deletions ctgan/synthesizer.py
@@ -43,6 +43,7 @@ def __init__(self, embedding_dim=128, gen_dim=(256, 256), dis_dim=(256, 256),
self.l2scale = l2scale
self.batch_size = batch_size
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.trained_epochs = 0  # cumulative over repeated fit() calls

def _apply_activate(self, data):
data_t = []
@@ -114,42 +115,52 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
sampling. Defaults to ``True``.
"""

        if not hasattr(self, "transformer"):
            self.transformer = DataTransformer()
            self.transformer.fit(train_data, discrete_columns)
        train_data = self.transformer.transform(train_data)

data_sampler = Sampler(train_data, self.transformer.output_info)

data_dim = self.transformer.output_dimensions

if not hasattr(self, "cond_generator"):
self.cond_generator = ConditionalGenerator(
train_data,
self.transformer.output_info,
log_frequency
)

if not hasattr(self, "generator"):
self.generator = Generator(
self.embedding_dim + self.cond_generator.n_opt,
self.gen_dim,
data_dim
).to(self.device)

if not hasattr(self, "discriminator"):
self.discriminator = Discriminator(
data_dim + self.cond_generator.n_opt,
self.dis_dim
).to(self.device)

if not hasattr(self, "optimizerG"):
self.optimizerG = optim.Adam(
self.generator.parameters(), lr=2e-4, betas=(0.5, 0.9),
weight_decay=self.l2scale
)

if not hasattr(self, "optimizerD"):
self.optimizerD = optim.Adam(
self.discriminator.parameters(), lr=2e-4, betas=(0.5, 0.9))

assert self.batch_size % 2 == 0
mean = torch.zeros(self.batch_size, self.embedding_dim, device=self.device)
std = mean + 1

steps_per_epoch = max(len(train_data) // self.batch_size, 1)
for i in range(epochs):
            self.trained_epochs += 1
for id_ in range(steps_per_epoch):
fakez = torch.normal(mean=mean, std=std)

@@ -180,16 +191,17 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
real_cat = real
fake_cat = fake

                y_fake = self.discriminator(fake_cat)
                y_real = self.discriminator(real_cat)

                pen = self.discriminator.calc_gradient_penalty(
                    real_cat, fake_cat, self.device)
loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                self.optimizerD.zero_grad()
                pen.backward(retain_graph=True)
                loss_d.backward()
                self.optimizerD.step()

fakez = torch.normal(mean=mean, std=std)
condvec = self.cond_generator.sample(self.batch_size)
@@ -206,9 +218,9 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):
fakeact = self._apply_activate(fake)

                if c1 is not None:
                    y_fake = self.discriminator(torch.cat([fakeact, c1], dim=1))
                else:
                    y_fake = self.discriminator(fakeact)

if condvec is None:
cross_entropy = 0
@@ -217,15 +229,15 @@ def fit(self, train_data, discrete_columns=tuple(), epochs=300, log_frequency=True):

loss_g = -torch.mean(y_fake) + cross_entropy

                self.optimizerG.zero_grad()
                loss_g.backward()
                self.optimizerG.step()

print("Epoch %d, Loss G: %.4f, Loss D: %.4f" %
(i + 1, loss_g.detach().cpu(), loss_d.detach().cpu()),
(self.trained_epoches, loss_g.detach().cpu(), loss_d.detach().cpu()),
flush=True)

    def sample(self, n, condition_column=None, condition_value=None):
"""Sample data similar to the training data.
Args:
@@ -236,14 +248,26 @@ def sample(self, n):
numpy.ndarray or pandas.DataFrame
"""

if condition_column is not None and condition_value is not None:
            condition_info = self.transformer.convert_column_name_value_to_id(
                condition_column, condition_value)
global_condition_vec = self.cond_generator.generate_cond_from_condition_column_info(
condition_info, self.batch_size)
else:
global_condition_vec = None

steps = n // self.batch_size + 1
data = []
for i in range(steps):
mean = torch.zeros(self.batch_size, self.embedding_dim)
std = mean + 1
fakez = torch.normal(mean=mean, std=std).to(self.device)

            if global_condition_vec is not None:
                condvec = global_condition_vec.copy()
            else:
                condvec = self.cond_generator.sample_zero(self.batch_size)

if condvec is None:
pass
else:
@@ -259,3 +283,28 @@ def sample(self, n):
data = data[:n]

return self.transformer.inverse_transform(data, None)

def save(self, path):
assert hasattr(self, "generator")
assert hasattr(self, "discriminator")
assert hasattr(self, "transformer")

        # Always save a CPU model so it can be restored on machines without CUDA.
device_bak = self.device
self.device = torch.device("cpu")
self.generator.to(self.device)
self.discriminator.to(self.device)

torch.save(self, path)

self.device = device_bak
self.generator.to(self.device)
self.discriminator.to(self.device)

@classmethod
def load(cls, path):
model = torch.load(path)
model.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.generator.to(model.device)
model.discriminator.to(model.device)
return model
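A note on the save path: the networks are moved to CPU before `torch.save` so that a GPU-trained synthesizer can be restored on a CPU-only machine, then moved back afterwards. The same pattern in isolation, with a toy module and an illustrative filename:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the generator/discriminator
device_bak = next(model.parameters()).device

model.to("cpu")                # pickle CPU tensors only
torch.save(model, "model.pt")
model.to(device_bak)           # restore the working device

restored = torch.load("model.pt")  # loads even where CUDA is unavailable
```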
27 changes: 27 additions & 0 deletions ctgan/transformer.py
@@ -1,3 +1,5 @@
import pickle

import numpy as np
import pandas as pd
from sklearn.exceptions import ConvergenceWarning
@@ -176,3 +178,28 @@ def inverse_transform(self, data, sigmas):
output = output.values

return output

def save(self, path):
with open(path + "/data_transform.pl", "wb") as f:
pickle.dump(self, f)

    def convert_column_name_value_to_id(self, column_name, value):
        discrete_counter = 0
        column_id = 0
        for info in self.meta:
            if info["name"] == column_name:
                break

            if len(info["output_info"]) == 1:  # is discrete column
                discrete_counter += 1

            column_id += 1
        else:
            # No break: the requested column was never found.
            raise ValueError("Column `%s` does not exist in the data." % column_name)

        return {
            "discrete_column_id": discrete_counter,
            "column_id": column_id,
            "value_id": np.argmax(info["encoder"].transform([[value]])[0])
        }

@classmethod
def load(cls, path):
with open(path + "/data_transform.pl", "rb") as f:
return pickle.load(f)
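For intuition about `convert_column_name_value_to_id`, consider a hypothetical dataset where `age` is continuous and `workclass` and `sex` are discrete; the names and indices below are illustrative only:

```python
# meta order: age (continuous), workclass (discrete), sex (discrete)
# convert_column_name_value_to_id("sex", "Female") would return e.g.:
# {
#     "discrete_column_id": 1,  # `sex` is the second discrete column
#     "column_id": 2,           # and the third column overall
#     "value_id": 0,            # index of "Female" in the one-hot encoding
# }
```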
