forked from burakcan-izmirli/DeepResponse
-
Notifications
You must be signed in to change notification settings - Fork 2
/
deep_response.py
43 lines (30 loc) · 1.36 KB
/
deep_response.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
""" Main file of deep response """
from comet_ml import Experiment
import logging
import tensorflow as tf
from helper.argument_parser import argument_parser
from helper.seed_setter import set_seed
from src.strategy_creator import StrategyCreator
# Process-wide switch: every tf.function invocation runs eagerly instead of
# as a traced graph. TensorFlow documents this as a debugging aid; graph-mode
# optimizations are skipped while it is on.
tf.config.run_functions_eagerly(run_eagerly=True)
class DeepResponse(StrategyCreator):
    """Top-level driver for a single DeepResponse run.

    The work itself is delegated to the strategy objects returned by the
    ``get_*_strategy`` accessors inherited from ``StrategyCreator``. Run
    configuration (``random_state``, ``split_type``, ``batch_size``,
    ``learning_rate``, ``epoch``) is read from instance attributes, presumably
    populated by ``StrategyCreator`` from the CLI arguments. TODO: confirm.
    """

    def main(self):
        """Execute the pipeline once.

        Steps, in order:

        1. Integrate Comet.
        2. Call ``set_seed`` with ``self.random_state``.
        3. Fetch the dataset and training strategies from the split strategy.
        4. Read, shuffle and prepare the dataset for ``self.split_type`` and
           the selected learning task.
        5. Hand the model-creation strategy, the prepared data and the
           hyper-parameters to the training strategy's
           ``train_and_evaluate_model``.
        """
        logging.info("DeepResponse was started.")

        # Whatever integrate_comet() returns is forwarded to training
        # unchanged (presumably a comet_ml Experiment or None; TODO confirm).
        comet_handle = self.get_comet_strategy().integrate_comet()
        set_seed(self.random_state)

        # The split strategy is indexed by role name: one entry handles the
        # data, the other drives training.
        strategies_by_role = self.get_split_strategy()
        data_strategy = strategies_by_role['dataset']
        trainer = strategies_by_role['training']
        task_strategy = self.get_learning_task_strategy()

        shuffled_data = data_strategy.read_and_shuffle_dataset(self.random_state)
        prepared_data = data_strategy.prepare_dataset(
            shuffled_data,
            self.split_type,
            self.batch_size,
            self.random_state,
            task_strategy,
        )

        # The model-creation strategy is passed through; model construction
        # happens inside the training strategy.
        model_factory = self.get_model_creation_strategy()
        trainer.train_and_evaluate_model(
            model_factory,
            prepared_data,
            self.batch_size,
            self.learning_rate,
            self.epoch,
            comet_handle,
            task_strategy,
        )
if __name__ == '__main__':
    # argument_parser() yields the positional constructor arguments for
    # DeepResponse; they are star-unpacked into the constructor.
    cli_arguments = argument_parser()
    app = DeepResponse(*cli_arguments)
    app.main()