-
Notifications
You must be signed in to change notification settings - Fork 82
/
options.py
49 lines (45 loc) · 1.79 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class TrainingOptions:
    """
    Bundle of settings that drive a training run.

    Every constructor argument is stored unchanged as an instance attribute
    of the same name; no validation or type conversion is performed.
    """

    def __init__(self,
                 batch_size: int,
                 number_of_epochs: int,
                 train_folder: str, validation_folder: str, runs_folder: str,
                 start_epoch: int, experiment_name: str):
        # Tuple assignment binds targets left to right, so attributes are
        # created in the same order as the parameter list.
        self.batch_size, self.number_of_epochs = batch_size, number_of_epochs
        # Folder locations, kept as the given strings.
        self.train_folder, self.validation_folder, self.runs_folder = (
            train_folder, validation_folder, runs_folder)
        # Remaining run bookkeeping values.
        self.start_epoch, self.experiment_name = start_epoch, experiment_name
class HiDDenConfiguration:
    """
    Hyper-parameters describing a HiDDeN network.

    Each constructor argument is kept verbatim as an attribute of the same
    name. All arguments are required except ``enable_fp16``, which defaults
    to ``False``. No validation is performed here.
    """

    def __init__(self, H: int, W: int, message_length: int,
                 encoder_blocks: int, encoder_channels: int,
                 decoder_blocks: int, decoder_channels: int,
                 use_discriminator: bool,
                 use_vgg: bool,
                 discriminator_blocks: int, discriminator_channels: int,
                 decoder_loss: float,
                 encoder_loss: float,
                 adversarial_loss: float,
                 enable_fp16: bool = False):
        # Tuple assignment binds left to right, so attributes appear in the
        # instance __dict__ in the same order as before.
        # H / W presumably image height / width -- confirm against callers.
        self.H, self.W, self.message_length = H, W, message_length
        # Encoder block and channel counts.
        self.encoder_blocks, self.encoder_channels = (
            encoder_blocks, encoder_channels)
        # Boolean feature switches (stored before the decoder settings).
        self.use_discriminator, self.use_vgg = use_discriminator, use_vgg
        # Decoder and discriminator block / channel counts.
        self.decoder_blocks, self.decoder_channels = (
            decoder_blocks, decoder_channels)
        self.discriminator_blocks, self.discriminator_channels = (
            discriminator_blocks, discriminator_channels)
        # Per-component loss floats (presumably weighting factors -- confirm).
        self.decoder_loss, self.encoder_loss, self.adversarial_loss = (
            decoder_loss, encoder_loss, adversarial_loss)
        self.enable_fp16 = enable_fp16