Commit 6e71bbd (initial commit, no parents): 3 changed files, 423 additions, 0 deletions. The first two files are shown below; the third did not render.

First new file (+149 lines): tf.data pipeline utilities.
import multiprocessing

import tensorflow as tf

def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    # Build a batched image dataset: resize to `resize` x `resize` and scale pixels to [-1, 1].
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1  # [0, 255] -> [-1, 1]
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size

    return dataset, img_shape, len_dataset

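# A minimal usage sketch (not part of this commit); the glob pattern and
# directory are hypothetical placeholders for a folder of face images:
#
#   import glob
#   img_paths = glob.glob('data/faces/*.jpg')
#   dataset, img_shape, len_dataset = make_anime_dataset(img_paths, batch_size=64)
#   for batch in dataset:  # each batch: (64, 64, 64, 3), values in [-1, 1]
#       ...
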
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # use a shuffle buffer of at least 2048 samples

    # [*] shuffling before `map`/`filter` is usually more efficient, since `map`/`filter` can be costly
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)

        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

    else:  # [*] filtering after the map is slower
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

        if filter_fn:
            dataset = dataset.filter(filter_fn)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

    return dataset

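# A toy sketch (not in the commit) showing the resulting pipeline order
# (shuffle -> filter/map -> batch -> repeat -> prefetch):
#
#   toy = batch_dataset(tf.data.Dataset.range(10),
#                       batch_size=4,
#                       map_fn=lambda x: x * 2,
#                       shuffle=True,
#                       repeat=1)
#   for batch in toy:
#       print(batch.numpy())  # two shuffled batches of 4 doubled values; the
#                             # remaining 2 elements are dropped (drop_remainder=True)
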
def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Build a batched dataset from in-memory data.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists
    """
    dataset = tf.data.Dataset.from_tensor_slices(memory_data)
    dataset = batch_dataset(dataset,
                            batch_size,
                            drop_remainder=drop_remainder,
                            n_prefetch_batch=n_prefetch_batch,
                            filter_fn=filter_fn,
                            map_fn=map_fn,
                            n_map_threads=n_map_threads,
                            filter_after_map=filter_after_map,
                            shuffle=shuffle,
                            shuffle_buffer_size=shuffle_buffer_size,
                            repeat=repeat)
    return dataset

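# Sketch (not in the commit): batching a pair of in-memory arrays, assuming
# `import numpy as np` alongside the imports above. The array names are
# illustrative only.
#
#   features = np.random.rand(100, 8).astype('float32')
#   labels = np.random.randint(0, 2, size=(100,))
#   ds = memory_data_batch_dataset((features, labels), batch_size=32)
#   for x, y in ds:  # x: (32, 8), y: (32,)
#       ...
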
def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Build a batched dataset from PNG/JPEG images on disk.

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists
    """
    if labels is None:
        memory_data = img_paths
    else:
        memory_data = (img_paths, labels)

    def parse_fn(path, *label):
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, 3)  # fix channels to 3
        return (img,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn` into a single map call
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    dataset = memory_data_batch_dataset(memory_data,
                                        batch_size,
                                        drop_remainder=drop_remainder,
                                        n_prefetch_batch=n_prefetch_batch,
                                        filter_fn=filter_fn,
                                        map_fn=map_fn_,
                                        n_map_threads=n_map_threads,
                                        filter_after_map=filter_after_map,
                                        shuffle=shuffle,
                                        shuffle_buffer_size=shuffle_buffer_size,
                                        repeat=repeat)

    return dataset
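
# Sketch (not in the commit): when `labels` is supplied, each element becomes
# an (image, label) pair, and any user `map_fn` must accept and return both.
# The paths and labels below are hypothetical placeholders.
#
#   ds = disk_image_batch_dataset(['a.png', 'b.png', 'c.png', 'd.png'],
#                                 batch_size=2,
#                                 labels=[0, 1, 0, 1],
#                                 map_fn=lambda img, y: (tf.image.resize(img, [64, 64]) / 127.5 - 1, y))
#   for imgs, ys in ds:  # imgs: (2, 64, 64, 3), ys: (2,)
#       ...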
Second new file (+103 lines): GAN generator and discriminator models.
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # silence TF info/warning logs; set before importing tensorflow
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

class Generator(keras.Model):
    def __init__(self, is_training=True):  # note: `is_training` is currently unused; `training` is passed to call()
        super(Generator, self).__init__()
        # z: [b, 100] => [b, 64, 64, 3]
        # channels decrease while the spatial size increases: 3 -> 9 -> 21 -> 64
        self.fc = layers.Dense(3 * 3 * 512)

        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')  # [b, 3, 3, 512] -> [b, 9, 9, 256]
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')  # [b, 9, 9, 256] -> [b, 21, 21, 128]
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')  # [b, 21, 21, 128] -> [b, 64, 64, 3]

    def call(self, inputs, training=None, mask=None):
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        x = tf.nn.leaky_relu(x)

        x = self.bn1(self.conv1(x), training=training)
        x = tf.nn.leaky_relu(x)

        x = self.bn2(self.conv2(x), training=training)
        x = tf.nn.leaky_relu(x)

        x = self.conv3(x)
        x = tf.tanh(x)  # outputs in [-1, 1], matching the dataset normalization

        return x

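# Quick shape check (a sketch, not in the commit): a batch of latent vectors
# of shape [b, 100] maps to images of shape [b, 64, 64, 3].
#
#   g = Generator()
#   fake = g(tf.random.normal([2, 100]), training=False)
#   print(fake.shape)  # (2, 64, 64, 3)
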
class Discriminator(keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        # [b, 96, 96, 3] => [b, 1]
        # (three stride-4 'valid' convs with kernel 5 need inputs of at least 85 px per side,
        #  so 64 x 64 inputs would fail in conv3)
        self.conv1 = layers.Conv2D(64, 5, 4, 'valid')   # [b, 96, 96, 3] -> [b, 23, 23, 64]

        self.conv2 = layers.Conv2D(128, 5, 4, 'valid')  # [b, 23, 23, 64] -> [b, 5, 5, 128]
        self.bn2 = layers.BatchNormalization()

        self.conv3 = layers.Conv2D(256, 5, 4, 'valid')  # [b, 5, 5, 128] -> [b, 1, 1, 256]
        self.bn3 = layers.BatchNormalization()

        self.conv4 = layers.Conv2D(512, 5, 4, 'valid')  # defined but not yet used in call()
        self.bn4 = layers.BatchNormalization()

        self.flatten = layers.Flatten()

        self.fc = layers.Dense(1)

    def call(self, inputs, training=None, mask=None):
        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        print(x.shape)  # debug print of the flattened feature shape
        logits = self.fc(x)
        return logits

def main():
    img_dim = 96
    z_dim = 100
    num_layers = int(np.log2(img_dim)) - 3
    max_num_channels = img_dim * 8
    f_size = img_dim // 2 ** (num_layers + 1)
    batch_size = 256

    print(num_layers)
    print(max_num_channels)
    print(f_size)

    d = Discriminator()
    g = Generator()

    x = tf.random.normal([2, 96, 96, 3])
    z = tf.random.normal([2, 100])

    prob = d(x)
    print(prob.shape)  # (2, 1)

    # pic = g(z)
    # print(pic.shape)


if __name__ == '__main__':
    main()
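
As committed, main() only sanity-checks shapes. Below is a hedged sketch of how these pieces would typically be wired into a standard GAN training step; the helper names (gan_losses, train_step), the optimizer choice, and the loss formulation are assumptions, not code from this commit, and it reuses the tf/keras imports above. Note that the Discriminator's conv stack currently needs inputs of at least ~85 px per side while the Generator emits 64 x 64 images, so one of the two networks would have to be reconciled before this step runs end to end.

def gan_losses(d, g, z, real_images, training=True):
    # Vanilla GAN losses computed on logits.
    bce = keras.losses.BinaryCrossentropy(from_logits=True)
    fake_images = g(z, training=training)
    real_logits = d(real_images, training=training)
    fake_logits = d(fake_images, training=training)
    d_loss = bce(tf.ones_like(real_logits), real_logits) + \
             bce(tf.zeros_like(fake_logits), fake_logits)
    g_loss = bce(tf.ones_like(fake_logits), fake_logits)
    return d_loss, g_loss


def train_step(d, g, d_optimizer, g_optimizer, real_images, z_dim=100):
    # One alternating update: discriminator first, then generator.
    z = tf.random.normal([tf.shape(real_images)[0], z_dim])

    with tf.GradientTape() as d_tape:
        d_loss, _ = gan_losses(d, g, z, real_images)
    d_grads = d_tape.gradient(d_loss, d.trainable_variables)
    d_optimizer.apply_gradients(zip(d_grads, d.trainable_variables))

    with tf.GradientTape() as g_tape:
        _, g_loss = gan_losses(d, g, z, real_images)
    g_grads = g_tape.gradient(g_loss, g.trainable_variables)
    g_optimizer.apply_gradients(zip(g_grads, g.trainable_variables))

    return d_loss, g_loss

The optimizers would typically be two keras.optimizers.Adam instances (one per network), and real_images would come from the make_anime_dataset pipeline in the first file.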