From 6e71bbd3ac4a7d82b1804f12463e2ce3603379af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=95=E4=B8=9C?= Date: Wed, 4 Mar 2020 21:10:24 +0800 Subject: [PATCH] 1.0 --- dataset.py | 149 +++++++++++++++++++++++++++++++++++++++++ wgan.py | 103 ++++++++++++++++++++++++++++ wgan_gp_train.py | 171 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 423 insertions(+) create mode 100644 dataset.py create mode 100644 wgan.py create mode 100644 wgan_gp_train.py diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..e686025 --- /dev/null +++ b/dataset.py @@ -0,0 +1,149 @@ +import multiprocessing + +import tensorflow as tf + + +def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1): + @tf.function + def _map_fn(img): + img = tf.image.resize(img, [resize, resize]) + img = tf.clip_by_value(img, 0, 255) + img = img / 127.5 - 1 + return img + + dataset = disk_image_batch_dataset(img_paths, + batch_size, + drop_remainder=drop_remainder, + map_fn=_map_fn, + shuffle=shuffle, + repeat=repeat) + img_shape = (resize, resize, 3) + len_dataset = len(img_paths) // batch_size + + return dataset, img_shape, len_dataset + + +def batch_dataset(dataset, + batch_size, + drop_remainder=True, + n_prefetch_batch=1, + filter_fn=None, + map_fn=None, + n_map_threads=None, + filter_after_map=False, + shuffle=True, + shuffle_buffer_size=None, + repeat=None): + # set defaults + if n_map_threads is None: + n_map_threads = multiprocessing.cpu_count() + if shuffle and shuffle_buffer_size is None: + shuffle_buffer_size = max(batch_size * 128, 2048) # set the minimum buffer size as 2048 + + # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly + if shuffle: + dataset = dataset.shuffle(shuffle_buffer_size) + + if not filter_after_map: + if filter_fn: + dataset = dataset.filter(filter_fn) + + if map_fn: + dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads) + + else: # 
[*] this is slower + if map_fn: + dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads) + + if filter_fn: + dataset = dataset.filter(filter_fn) + + dataset = dataset.batch(batch_size, drop_remainder=drop_remainder) + + dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch) + + return dataset + + +def memory_data_batch_dataset(memory_data, + batch_size, + drop_remainder=True, + n_prefetch_batch=1, + filter_fn=None, + map_fn=None, + n_map_threads=None, + filter_after_map=False, + shuffle=True, + shuffle_buffer_size=None, + repeat=None): + """Batch dataset of memory data. + + Parameters + ---------- + memory_data : nested structure of tensors/ndarrays/lists + + """ + dataset = tf.data.Dataset.from_tensor_slices(memory_data) + dataset = batch_dataset(dataset, + batch_size, + drop_remainder=drop_remainder, + n_prefetch_batch=n_prefetch_batch, + filter_fn=filter_fn, + map_fn=map_fn, + n_map_threads=n_map_threads, + filter_after_map=filter_after_map, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + repeat=repeat) + return dataset + + +def disk_image_batch_dataset(img_paths, + batch_size, + labels=None, + drop_remainder=True, + n_prefetch_batch=1, + filter_fn=None, + map_fn=None, + n_map_threads=None, + filter_after_map=False, + shuffle=True, + shuffle_buffer_size=None, + repeat=None): + """Batch dataset of disk image for PNG and JPEG. 
+ + Parameters + ---------- + img_paths : 1d-tensor/ndarray/list of str + labels : nested structure of tensors/ndarrays/lists + + """ + if labels is None: + memory_data = img_paths + else: + memory_data = (img_paths, labels) + + def parse_fn(path, *label): + img = tf.io.read_file(path) + img = tf.image.decode_png(img, 3) # fix channels to 3 + return (img,) + label + + if map_fn: # fuse `map_fn` and `parse_fn` + def map_fn_(*args): + return map_fn(*parse_fn(*args)) + else: + map_fn_ = parse_fn + + dataset = memory_data_batch_dataset(memory_data, + batch_size, + drop_remainder=drop_remainder, + n_prefetch_batch=n_prefetch_batch, + filter_fn=filter_fn, + map_fn=map_fn_, + n_map_threads=n_map_threads, + filter_after_map=filter_after_map, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + repeat=repeat) + + return dataset diff --git a/wgan.py b/wgan.py new file mode 100644 index 0000000..0268847 --- /dev/null +++ b/wgan.py @@ -0,0 +1,103 @@ +import os + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers + +gpus = tf.config.experimental.list_physical_devices(device_type='GPU') +for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + + +class Generator(keras.Model): + def __init__(self, is_training=True): + super(Generator, self).__init__() + # z: [b, 100] => [b, 64, 64, 3] + # channel decrease, image size increase + self.fc = layers.Dense(3 * 3 * 512) + + self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid') + self.bn1 = layers.BatchNormalization() + + self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid') + self.bn2 = layers.BatchNormalization() + + self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid') + + def call(self, inputs, training=None, mask=None): + x = self.fc(inputs) + x = tf.reshape(x, [-1, 3, 3, 512]) + x = tf.nn.leaky_relu(x) + # + x = self.bn1(self.conv1(x), training=training) + x = tf.nn.leaky_relu(x) + + x = 
self.bn2(self.conv2(x), training=training)
+        x = tf.nn.leaky_relu(x)
+
+        x = self.conv3(x)
+        x = tf.tanh(x)
+
+        return x
+
+
+class Discriminator(keras.Model):
+    def __init__(self):
+        super(Discriminator, self).__init__()
+        # [b, 64, 64, 3] => [b, 1]
+        self.conv1 = layers.Conv2D(64, 5, 4, 'valid')
+
+        self.conv2 = layers.Conv2D(128, 5, 4, 'valid')
+        self.bn2 = layers.BatchNormalization()
+
+        self.conv3 = layers.Conv2D(256, 5, 4, 'valid')
+        self.bn3 = layers.BatchNormalization()
+
+        self.conv4 = layers.Conv2D(512, 5, 4, 'valid')
+        self.bn4 = layers.BatchNormalization()
+
+        self.flatten = layers.Flatten()
+
+        self.fc = layers.Dense(1)
+
+    def call(self, inputs, training=None, mask=None):
+        x = tf.nn.leaky_relu(self.conv1(inputs))
+        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
+        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
+
+        # [b, h, w, c] => [b, -1]
+        x = self.flatten(x)
+        print(x.shape)
+        logits = self.fc(x)
+        return logits
+
+
+def main():
+    img_dim = 96
+    z_dim = 100
+    num_layers = int(np.log2(img_dim)) - 3
+    max_num_channels = img_dim * 8
+    f_size = img_dim // 2 ** (num_layers + 1)
+    batch_size = 256
+
+    print(num_layers)
+    print(max_num_channels)
+    print(f_size)
+
+    d = Discriminator()
+    g = Generator()
+
+    x = tf.random.normal([2, 96, 96, 3])
+    z = tf.random.normal([2, 100])
+
+    prob = d(x)
+    print(prob.shape)
+
+    # pic = g(z)
+    # print(pic.shape)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/wgan_gp_train.py b/wgan_gp_train.py
new file mode 100644
index 0000000..c61f2fa
--- /dev/null
+++ b/wgan_gp_train.py
@@ -0,0 +1,171 @@
+import os
+
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+
+import numpy as np
+import tensorflow as tf
+from PIL import Image
+import glob
+from wgan import Generator, Discriminator
+from tensorflow import keras
+from dataset import make_anime_dataset
+
+
+def save_result(val_out, val_block_size, image_path, color_mode):
+    def preprocess(img):
+        img = ((img + 1.0) * 127.5).astype(np.uint8)
+        #
img = img.astype(np.uint8) + return img + + preprocesed = preprocess(val_out) + final_image = np.array([]) + single_row = np.array([]) + for b in range(val_out.shape[0]): + # concat image into a row + if single_row.size == 0: + single_row = preprocesed[b, :, :, :] + else: + single_row = np.concatenate( + (single_row, preprocesed[b, :, :, :]), axis=1) + + # concat image row to final_image + if (b + 1) % val_block_size == 0: + if final_image.size == 0: + final_image = single_row + else: + final_image = np.concatenate((final_image, single_row), axis=0) + + # reset single row + single_row = np.array([]) + + if final_image.shape[2] == 1: + final_image = np.squeeze(final_image, axis=2) + Image.fromarray(final_image).save(image_path) + + +def gradient_penalty(discriminator, real_image, fake_image): + batchsz = real_image.shape[0] + # dtype caused disconvergence? + t = tf.random.uniform([batchsz, 1, 1, 1], minval=0., + maxval=1., dtype=tf.float32) + x_hat = t * real_image + (1. - t) * fake_image + with tf.GradientTape() as tape: + tape.watch(x_hat) + Dx = discriminator(x_hat, training=True) + grads = tape.gradient(Dx, x_hat) + slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3])) + gp = tf.reduce_mean((slopes - 1.) ** 2) + return gp + + +def d_loss_fn(generator, discriminator, batch_z, real_image): + fake_image = generator(batch_z, training=True) + d_fake_score = discriminator(fake_image, training=True) + d_real_score = discriminator(real_image, training=True) + + loss = tf.reduce_mean(d_fake_score - d_real_score) + + gp = gradient_penalty(discriminator, real_image, fake_image) * 10. 
+ + loss = loss + gp + return loss, gp + + +def g_loss_fn(generator, discriminator, batch_z): + fake_image = generator(batch_z, training=True) + d_fake_logits = discriminator(fake_image, training=True) + # loss = celoss_ones(d_fake_logits) + loss = -tf.reduce_mean(d_fake_logits) + return loss + + +def main(): + tf.random.set_seed(233) + np.random.seed(233) + + z_dim = 100 + epochs = 3000000 + batch_size = 512 + learning_rate = 2e-4 + # ratios = D steps:G steps + ratios = 2 + + img_path = glob.glob(os.path.join('faces', '*.jpg')) + dataset, img_shape, _ = make_anime_dataset(img_path, batch_size) + print(dataset, img_shape) + sample = next(iter(dataset)) + print(sample.shape, tf.reduce_max(sample).numpy(), + tf.reduce_min(sample).numpy()) + dataset = dataset.repeat() + db_iter = iter(dataset) + + generator = Generator() + generator.build(input_shape=(None, z_dim)) + # generator.load_weights(os.path.join('checkpoints', 'generator-5000')) + discriminator = Discriminator() + discriminator.build(input_shape=(None, 64, 64, 3)) + # discriminator.load_weights(os.path.join('checkpoints', 'discriminator-5000')) + + g_optimizer = tf.optimizers.Adam(learning_rate, beta_1=0.5) + d_optimizer = tf.optimizers.Adam(learning_rate, beta_1=0.5) + # a fixed noise for sampling + z_sample = tf.random.normal([100, z_dim]) + + g_loss_meter = keras.metrics.Mean() + d_loss_meter = keras.metrics.Mean() + gp_meter = keras.metrics.Mean() + + for epoch in range(epochs): + + # train D + for step in range(ratios): + batch_z = tf.random.normal([batch_size, z_dim]) + batch_x = next(db_iter) + with tf.GradientTape() as tape: + d_loss, gp = d_loss_fn( + generator, discriminator, batch_z, batch_x) + + d_loss_meter.update_state(d_loss) + gp_meter.update_state(gp) + + gradients = tape.gradient( + d_loss, discriminator.trainable_variables) + d_optimizer.apply_gradients( + zip(gradients, discriminator.trainable_variables)) + + # train G + batch_z = tf.random.normal([batch_size, z_dim]) + with 
tf.GradientTape() as tape:
+            g_loss = g_loss_fn(generator, discriminator, batch_z)
+
+            g_loss_meter.update_state(g_loss)
+
+        gradients = tape.gradient(g_loss, generator.trainable_variables)
+        g_optimizer.apply_gradients(
+            zip(gradients, generator.trainable_variables))
+
+        if epoch % 100 == 0:
+
+            fake_image = generator(z_sample, training=False)
+
+            print(epoch, 'd-loss:', d_loss_meter.result().numpy(),
+                  'g-loss', g_loss_meter.result().numpy(),
+                  'gp', gp_meter.result().numpy())
+
+            d_loss_meter.reset_states()
+            g_loss_meter.reset_states()
+            gp_meter.reset_states()
+
+            # save generated image samples
+            img_path = os.path.join('images_wgan_gp', 'wgan_gp-%d.png' % epoch)
+            save_result(fake_image.numpy(), 10, img_path, color_mode='P')
+
+        if (epoch + 1) % 2000 == 0:
+            generator.save_weights(os.path.join(
+                'checkpoints_gp', 'generator-%d' % epoch))
+            discriminator.save_weights(os.path.join(
+                'checkpoints_gp', 'discriminator-%d' % epoch))
+
+
+if __name__ == '__main__':
+    main()