1
1
import os
2
2
from enum import Enum
3
- from typing import Any , Dict , List
3
+ from typing import Any , Dict , List , Optional
4
4
5
5
from tensorflow import keras
6
6
from tensorflow .keras .models import Model
@@ -20,7 +20,7 @@ class ModelTrainType(Enum):
20
20
21
21
22
22
class ModelData :
23
- def __init__ (self , name : str , description : str = None , model : Model = None ):
23
+ def __init__ (self , name : str , description : Optional [ str ] = None , model : Optional [ Model ] = None ) -> None :
24
24
self .name : str = name
25
25
self .model : Model = model if model is not None else keras .models .load_model (
26
26
self .get_model_path ())
@@ -34,7 +34,7 @@ def __init__(self, name: str, description: str = None, model: Model = None):
34
34
self .data_file .read_data ()
35
35
36
36
def set_parameter (self , batch_size : int , epochs : int , layer_data : List [int ], learning_rate : float ,
37
- training_samples : int , test_samples : int ):
37
+ training_samples : int , test_samples : int ) -> None :
38
38
self .data ['batch_size' ] = batch_size
39
39
self .data ['epochs' ] = epochs
40
40
self .data ['layer_data' ] = layer_data
@@ -44,20 +44,20 @@ def set_parameter(self, batch_size: int, epochs: int, layer_data: List[int], lea
44
44
self .data ['test_samples' ] = test_samples
45
45
46
46
def set_initial_performance (self , test_loss : float , test_accuracy : float , train_loss : float , train_accuracy : float ,
47
- classification_report : Any ):
47
+ classification_report : Any ) -> None :
48
48
self .data ['test_loss' ] = str (test_loss )
49
49
self .data ['test_accuracy' ] = str (test_accuracy )
50
50
self .data ['train_loss' ] = str (train_loss )
51
51
self .data ['train_accuracy' ] = str (train_accuracy )
52
52
self .data ['classification_report' ] = classification_report
53
53
54
- def set_class_selection (self , class_selection : List [int ]):
54
+ def set_class_selection (self , class_selection : List [int ]) -> None :
55
55
importance : dict = dict ()
56
56
importance ['class_selection' ] = class_selection
57
57
self .data_file .append_main_data ('processed' , 'importance' , importance )
58
58
self .data_file .write_data ()
59
59
60
- def set_importance_type (self , importance_type : int ):
60
+ def set_importance_type (self , importance_type : int ) -> None :
61
61
importance : dict = dict ()
62
62
importance ['importance_type' ] = importance_type
63
63
self .data_file .append_main_data ('processed' , 'importance' , importance )
@@ -66,31 +66,31 @@ def set_importance_type(self, importance_type: int):
66
66
def get_num_classes (self ) -> int :
67
67
return self .data_file .data_cache ['overall' ]['basic_model' ]['num_classes' ]
68
68
69
- def get_class_selection (self ) -> List [int ] or None :
69
+ def get_class_selection (self ) -> Optional [ List [int ]] :
70
70
return self .data_file .data_cache ['processed' ]['importance' ]['class_selection' ]
71
71
72
72
def get_importance_type (self ) -> int :
73
73
return self .data_file .data_cache ['processed' ]['importance' ]['importance_type' ]
74
74
75
- def store_model_data (self ):
75
+ def store_model_data (self ) -> None :
76
76
self .data_file .append_main_data ('overall' , 'basic_model' , self .data )
77
77
self .data_file .write_data ()
78
78
79
- def store_main_data (self , key : str , sub_key : str , data : Dict [Any , Any ]):
79
+ def store_main_data (self , key : str , sub_key : str , data : Dict [Any , Any ]) -> None :
80
80
self .data_file .append_main_data (key , sub_key , data )
81
81
self .data_file .write_data ()
82
82
83
- def store_data (self , key : str , sub_key : str , sub_sub_key : str , data : Dict [Any , Any ]):
83
+ def store_data (self , key : str , sub_key : str , sub_sub_key : str , data : Dict [Any , Any ]) -> None :
84
84
self .data_file .append_data (key , sub_key , sub_sub_key , data )
85
85
self .data_file .write_data ()
86
86
87
- def save_model (self ):
87
+ def save_model (self ) -> None :
88
88
path : str = DATA_PATH + 'model/' + self .name + '/tf_model'
89
89
if not os .path .exists (path ):
90
90
os .makedirs (path )
91
91
self .model .save (path )
92
92
93
- def reload_model (self ):
93
+ def reload_model (self ) -> None :
94
94
self .model = keras .models .load_model (self .get_model_path ())
95
95
self .check_model_supported_layer ()
96
96
@@ -100,10 +100,10 @@ def get_model_path(self) -> str:
100
100
def get_path (self ) -> str :
101
101
return DATA_PATH + 'model/' + self .name + '/'
102
102
103
- def save_data (self ):
103
+ def save_data (self ) -> None :
104
104
self .data_file .write_data ()
105
105
106
- def check_model_supported_layer (self ):
106
+ def check_model_supported_layer (self ) -> None :
107
107
for index , layer in enumerate (self .model .layers ):
108
108
if layer .__class__ .__name__ not in SUPPORTED_LAYER and layer .__class__ .__name__ not in IGNORED_LAYER :
109
109
raise Exception (
0 commit comments