Skip to content

Commit

Permalink
1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
donpromax committed Mar 4, 2020
0 parents commit 6e71bbd
Show file tree
Hide file tree
Showing 3 changed files with 423 additions and 0 deletions.
149 changes: 149 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import multiprocessing

import tensorflow as tf


def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
@tf.function
def _map_fn(img):
img = tf.image.resize(img, [resize, resize])
img = tf.clip_by_value(img, 0, 255)
img = img / 127.5 - 1
return img

dataset = disk_image_batch_dataset(img_paths,
batch_size,
drop_remainder=drop_remainder,
map_fn=_map_fn,
shuffle=shuffle,
repeat=repeat)
img_shape = (resize, resize, 3)
len_dataset = len(img_paths) // batch_size

return dataset, img_shape, len_dataset


def batch_dataset(dataset,
batch_size,
drop_remainder=True,
n_prefetch_batch=1,
filter_fn=None,
map_fn=None,
n_map_threads=None,
filter_after_map=False,
shuffle=True,
shuffle_buffer_size=None,
repeat=None):
# set defaults
if n_map_threads is None:
n_map_threads = multiprocessing.cpu_count()
if shuffle and shuffle_buffer_size is None:
shuffle_buffer_size = max(batch_size * 128, 2048) # set the minimum buffer size as 2048

# [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
if shuffle:
dataset = dataset.shuffle(shuffle_buffer_size)

if not filter_after_map:
if filter_fn:
dataset = dataset.filter(filter_fn)

if map_fn:
dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

else: # [*] this is slower
if map_fn:
dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

if filter_fn:
dataset = dataset.filter(filter_fn)

dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

return dataset


def memory_data_batch_dataset(memory_data,
batch_size,
drop_remainder=True,
n_prefetch_batch=1,
filter_fn=None,
map_fn=None,
n_map_threads=None,
filter_after_map=False,
shuffle=True,
shuffle_buffer_size=None,
repeat=None):
"""Batch dataset of memory data.
Parameters
----------
memory_data : nested structure of tensors/ndarrays/lists
"""
dataset = tf.data.Dataset.from_tensor_slices(memory_data)
dataset = batch_dataset(dataset,
batch_size,
drop_remainder=drop_remainder,
n_prefetch_batch=n_prefetch_batch,
filter_fn=filter_fn,
map_fn=map_fn,
n_map_threads=n_map_threads,
filter_after_map=filter_after_map,
shuffle=shuffle,
shuffle_buffer_size=shuffle_buffer_size,
repeat=repeat)
return dataset


def disk_image_batch_dataset(img_paths,
batch_size,
labels=None,
drop_remainder=True,
n_prefetch_batch=1,
filter_fn=None,
map_fn=None,
n_map_threads=None,
filter_after_map=False,
shuffle=True,
shuffle_buffer_size=None,
repeat=None):
"""Batch dataset of disk image for PNG and JPEG.
Parameters
----------
img_paths : 1d-tensor/ndarray/list of str
labels : nested structure of tensors/ndarrays/lists
"""
if labels is None:
memory_data = img_paths
else:
memory_data = (img_paths, labels)

def parse_fn(path, *label):
img = tf.io.read_file(path)
img = tf.image.decode_png(img, 3) # fix channels to 3
return (img,) + label

if map_fn: # fuse `map_fn` and `parse_fn`
def map_fn_(*args):
return map_fn(*parse_fn(*args))
else:
map_fn_ = parse_fn

dataset = memory_data_batch_dataset(memory_data,
batch_size,
drop_remainder=drop_remainder,
n_prefetch_batch=n_prefetch_batch,
filter_fn=filter_fn,
map_fn=map_fn_,
n_map_threads=n_map_threads,
filter_after_map=filter_after_map,
shuffle=shuffle,
shuffle_buffer_size=shuffle_buffer_size,
repeat=repeat)

return dataset
103 changes: 103 additions & 0 deletions wgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)


class Generator(keras.Model):
def __init__(self, is_training=True):
super(Generator, self).__init__()
# z: [b, 100] => [b, 64, 64, 3]
# channel decrease, image size increase
self.fc = layers.Dense(3 * 3 * 512)

self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
self.bn1 = layers.BatchNormalization()

self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
self.bn2 = layers.BatchNormalization()

self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

def call(self, inputs, training=None, mask=None):
x = self.fc(inputs)
x = tf.reshape(x, [-1, 3, 3, 512])
x = tf.nn.leaky_relu(x)
#
x = self.bn1(self.conv1(x), training=training)
x = tf.nn.leaky_relu(x)

x = self.bn2(self.conv2(x), training=training)
x = tf.nn.leaky_relu(x)

x = self.conv3(x)
x = tf.tanh(x)

return x


class Discriminator(keras.Model):
def __init__(self):
super(Discriminator, self).__init__()
# [b, 64, 64, 3] => [b, 1]
self.conv1 = layers.Conv2D(64, 5, 4, 'valid')

self.conv2 = layers.Conv2D(128, 5, 4, 'valid')
self.bn2 = layers.BatchNormalization()

self.conv3 = layers.Conv2D(256, 5, 4, 'valid')
self.bn3 = layers.BatchNormalization()

self.conv4 = layers.Conv2D(512, 5, 4, 'valid')
self.bn4 = layers.BatchNormalization()

self.flatten = layers.Flatten()

self.fc = layers.Dense(1)

def call(self, inputs, training=None, mask=None):
x = tf.nn.leaky_relu(self.conv1(inputs))
x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

# [b, h, w, c] => [b, -1]
x = self.flatten(x)
print(x.shape)
logits = self.fc(x)
return logits


def main():
img_dim = 96
z_dim = 100
num_layers = int(np.log2(img_dim)) - 3
max_num_channels = img_dim * 8
f_size = img_dim // 2 ** (num_layers + 1)
batch_size = 256

print(num_layers)
print(max_num_channels)
print(f_size)

d = Discriminator()
g = Generator()

x = tf.random.normal([2, 96, 96, 3])
z = tf.random.normal([2, 100])

prob = d(x)
print(prob.shape)

# pic = g(z)
# print(pic.shape)


if __name__ == '__main__':
main()
Loading

0 comments on commit 6e71bbd

Please sign in to comment.