-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_GAN.py
68 lines (47 loc) · 2 KB
/
test_GAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""
Builds and tests a forward pass through a model.
"""
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model
def test_generator(generator: Model, example_count: int = 1):
    """
    Generate and display a simulated spectrogram from a generator model.

    Samples `example_count` latent vectors from a standard normal
    distribution, runs them through `generator` in inference mode
    (`training=False`), prints the shape of the output batch, and shows the
    FIRST generated example as a greyscale image.

    generator: The generator of a GAN. Its `input_shape` must have the form
        (batch_size, latent_dimensions), i.e. it takes a flat latent vector.
    example_count: The number of simulated spectrograms to generate; must be
        at least 1. Only the first one is displayed.

    Returns the full batch of generated outputs, so callers can inspect them
    further (e.g. feed them to a discriminator).

    Raises ValueError if `example_count` < 1 or if `generator.input_shape`
    is not two-dimensional.
    """
    if example_count < 1:
        # Without this check an empty batch is generated and indexing [0]
        # below fails with a much less obvious error.
        raise ValueError(
            f'example_count must be at least 1, got {example_count}')

    input_shape = generator.input_shape
    if len(input_shape) != 2:
        raise ValueError(
            'Expected generator.input_shape of the form '
            f'(batch_size, latent_dimensions), got {input_shape}')
    latent_dimensions = input_shape[-1]

    # Test the forward pass through the generator:
    noise = tf.random.normal([example_count, latent_dimensions])
    simulated_spectrogram = generator(noise, training=False)
    print('Simulated spectrogram shape:', simulated_spectrogram.shape)

    plt.imshow(simulated_spectrogram[0], cmap='gray')
    # TODO: optionally invert the spectrogram back to a waveform with
    # tf.signal.inverse_stft and plot it. That needs the STFT frame length
    # and frame step used to build the training data, which this module
    # does not currently have access to.
    plt.show()
    return simulated_spectrogram


# This is a helper that takes arguments, not a pytest test. Because the
# module's file name matches pytest's `test_*.py` pattern, stop pytest from
# collecting it (it would otherwise fail with "fixture 'generator' not found").
test_generator.__test__ = False
def test_discriminator(discriminator):
    """
    Placeholder for a forward-pass check of a GAN discriminator.

    discriminator: The discriminator of a GAN (currently unused).

    Raises NotImplementedError unconditionally; the check is not written yet.
    """
    # TODO: test the forward pass through the discriminator, roughly:
    #     evaluation = discriminator(simulated_spectrogram)
    #     print('Evaluation score:', evaluation.numpy().item())
    # where `simulated_spectrogram` is a generated (or real) sample whose
    # shape matches discriminator.input_shape.
    raise NotImplementedError


# This is a helper that takes arguments, not a pytest test. Because the
# module's file name matches pytest's `test_*.py` pattern, stop pytest from
# collecting it (it would otherwise fail with a missing-fixture error).
test_discriminator.__test__ = False
def test_gan(gan):
    """
    Run the forward-pass checks on both halves of a GAN.

    gan: An object exposing `generator` and `discriminator` attributes
        (presumably Keras models).

    The generator check runs first. While `test_discriminator` is still a
    stub, it raises NotImplementedError, so this function does too.
    """
    # The original raised NotImplementedError before these calls, leaving
    # them unreachable. test_discriminator already signals the missing piece,
    # so the early raise was redundant.
    test_generator(gan.generator)
    test_discriminator(gan.discriminator)


# This is a helper that takes arguments, not a pytest test. Because the
# module's file name matches pytest's `test_*.py` pattern, stop pytest from
# collecting it (it would otherwise fail with "fixture 'gan' not found").
test_gan.__test__ = False
if __name__ == '__main__':
    # This is a workaround while I migrate the data pipeline section.
    # Ideally you could `from MockingBot import SAVED_MODELS_PATH`.
    import pathlib
    import sys
    SAVED_MODELS_PATH = pathlib.Path('./saved_models')

    # Saved model folder, relative to SAVED_MODELS_PATH. It can be chosen on
    # the command line, e.g. `python test_GAN.py Dense4CentNet-ReLU/v1`;
    # without an argument the previously hard-coded folder is used.
    model_name = sys.argv[1] if len(sys.argv) > 1 else 'Dense4CentNet-ReLU/v1'
    model_dir = SAVED_MODELS_PATH / model_name

    # Hand the loader plain string paths: older TF 2.x `load_model`
    # implementations did not accept pathlib.Path objects.
    generator = tf.keras.models.load_model(str(model_dir / 'generator'))
    # Loaded but not used below yet, since test_discriminator is still a
    # stub. Presumably kept so a missing discriminator is noticed early.
    discriminator = tf.keras.models.load_model(
        str(model_dir / 'discriminator'))

    test_generator(generator)