Skip to content

Commit 7e0fb81

Browse files
committed
reduced typing inconsitencies
1 parent 8f0a473 commit 7e0fb81

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+546
-527
lines changed

.pre-commit-config.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
repos:
2+
#- repo: https://github.com/pre-commit/mirrors-mypy
3+
# rev: v0.991
4+
# hooks:
5+
# - id: mypy
6+
# args: [--ignore-missing-imports, --disallow-untyped-defs, --disallow-incomplete-defs, --disallow-untyped-calls]
27
- repo: https://github.com/timothycrosley/isort
38
rev: 5.10.1
49
hooks:

data/__init__.py

Whitespace-only changes.

data/mnist_data_handler.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import os
3-
from typing import Any, List, Tuple
3+
from typing import Any, List, Optional, Tuple
44

55
import numpy as np
66
from tensorflow import keras
@@ -31,7 +31,7 @@ def get_basic_data(categorical: bool = False) -> Tuple[Tuple[Any, Any], Tuple[An
3131
return (x_train, y_train), (x_test, y_test), input_shape, num_classes
3232

3333

34-
def get_prepared_data(class_selection: List[int] = None) -> Tuple[Tuple[Any, Any], Tuple[Any, Any], Any, Any]:
34+
def get_prepared_data(class_selection: Optional[List[int]] = None) -> Tuple[Tuple[Any, Any], Tuple[Any, Any], Any, Any]:
3535
(x_train, y_train), (x_test, y_test), input_shape, num_classes = get_basic_data()
3636

3737
if class_selection is not None:
@@ -61,7 +61,7 @@ def get_prepared_data(class_selection: List[int] = None) -> Tuple[Tuple[Any, Any
6161
return (x_train, y_train), (x_test, y_test), input_shape, num_classes
6262

6363

64-
def get_unbalance_data(main_class: int, other_class_percentage: float, class_selection: List[int] = None) \
64+
def get_unbalance_data(main_class: int, other_class_percentage: float, class_selection: Optional[List[int]] = None) \
6565
-> Tuple[Tuple[Any, Any], Tuple[Any, Any], Any, Any]:
6666
(x_train, y_train), (x_test, y_test), input_shape, num_classes = get_basic_data()
6767

@@ -108,7 +108,7 @@ def get_unbalance_data(main_class: int, other_class_percentage: float, class_sel
108108
return (x_train, y_train), (x_test, y_test), input_shape, num_classes
109109

110110

111-
def split_mnist_data(class_selection: List[int] = None):
111+
def split_mnist_data(class_selection: Optional[List[int]] = None) -> None:
112112
(x_train, y_train), (x_test, y_test), input_shape, num_classes = get_basic_data()
113113
logging.info('splitting %i train examples' % x_train.shape[0])
114114
logging.info('splitting %i test examples' % x_test.shape[0])

data/model_data.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from enum import Enum
3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Optional
44

55
from tensorflow import keras
66
from tensorflow.keras.models import Model
@@ -20,7 +20,7 @@ class ModelTrainType(Enum):
2020

2121

2222
class ModelData:
23-
def __init__(self, name: str, description: str = None, model: Model = None):
23+
def __init__(self, name: str, description: Optional[str] = None, model: Optional[Model] = None) -> None:
2424
self.name: str = name
2525
self.model: Model = model if model is not None else keras.models.load_model(
2626
self.get_model_path())
@@ -34,7 +34,7 @@ def __init__(self, name: str, description: str = None, model: Model = None):
3434
self.data_file.read_data()
3535

3636
def set_parameter(self, batch_size: int, epochs: int, layer_data: List[int], learning_rate: float,
37-
training_samples: int, test_samples: int):
37+
training_samples: int, test_samples: int) -> None:
3838
self.data['batch_size'] = batch_size
3939
self.data['epochs'] = epochs
4040
self.data['layer_data'] = layer_data
@@ -44,20 +44,20 @@ def set_parameter(self, batch_size: int, epochs: int, layer_data: List[int], lea
4444
self.data['test_samples'] = test_samples
4545

4646
def set_initial_performance(self, test_loss: float, test_accuracy: float, train_loss: float, train_accuracy: float,
47-
classification_report: Any):
47+
classification_report: Any) -> None:
4848
self.data['test_loss'] = str(test_loss)
4949
self.data['test_accuracy'] = str(test_accuracy)
5050
self.data['train_loss'] = str(train_loss)
5151
self.data['train_accuracy'] = str(train_accuracy)
5252
self.data['classification_report'] = classification_report
5353

54-
def set_class_selection(self, class_selection: List[int]):
54+
def set_class_selection(self, class_selection: List[int]) -> None:
5555
importance: dict = dict()
5656
importance['class_selection'] = class_selection
5757
self.data_file.append_main_data('processed', 'importance', importance)
5858
self.data_file.write_data()
5959

60-
def set_importance_type(self, importance_type: int):
60+
def set_importance_type(self, importance_type: int) -> None:
6161
importance: dict = dict()
6262
importance['importance_type'] = importance_type
6363
self.data_file.append_main_data('processed', 'importance', importance)
@@ -66,31 +66,31 @@ def set_importance_type(self, importance_type: int):
6666
def get_num_classes(self) -> int:
6767
return self.data_file.data_cache['overall']['basic_model']['num_classes']
6868

69-
def get_class_selection(self) -> List[int] or None:
69+
def get_class_selection(self) -> Optional[List[int]]:
7070
return self.data_file.data_cache['processed']['importance']['class_selection']
7171

7272
def get_importance_type(self) -> int:
7373
return self.data_file.data_cache['processed']['importance']['importance_type']
7474

75-
def store_model_data(self):
75+
def store_model_data(self) -> None:
7676
self.data_file.append_main_data('overall', 'basic_model', self.data)
7777
self.data_file.write_data()
7878

79-
def store_main_data(self, key: str, sub_key: str, data: Dict[Any, Any]):
79+
def store_main_data(self, key: str, sub_key: str, data: Dict[Any, Any]) -> None:
8080
self.data_file.append_main_data(key, sub_key, data)
8181
self.data_file.write_data()
8282

83-
def store_data(self, key: str, sub_key: str, sub_sub_key: str, data: Dict[Any, Any]):
83+
def store_data(self, key: str, sub_key: str, sub_sub_key: str, data: Dict[Any, Any]) -> None:
8484
self.data_file.append_data(key, sub_key, sub_sub_key, data)
8585
self.data_file.write_data()
8686

87-
def save_model(self):
87+
def save_model(self) -> None:
8888
path: str = DATA_PATH + 'model/' + self.name + '/tf_model'
8989
if not os.path.exists(path):
9090
os.makedirs(path)
9191
self.model.save(path)
9292

93-
def reload_model(self):
93+
def reload_model(self) -> None:
9494
self.model = keras.models.load_model(self.get_model_path())
9595
self.check_model_supported_layer()
9696

@@ -100,10 +100,10 @@ def get_model_path(self) -> str:
100100
def get_path(self) -> str:
101101
return DATA_PATH + 'model/' + self.name + '/'
102102

103-
def save_data(self):
103+
def save_data(self) -> None:
104104
self.data_file.write_data()
105105

106-
def check_model_supported_layer(self):
106+
def check_model_supported_layer(self) -> None:
107107
for index, layer in enumerate(self.model.layers):
108108
if layer.__class__.__name__ not in SUPPORTED_LAYER and layer.__class__.__name__ not in IGNORED_LAYER:
109109
raise Exception(

definitions.py

+36
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
import os
2+
from enum import IntEnum, IntFlag, auto
3+
from typing import List
4+
5+
from pyrr import Vector3
26

37
BASE_PATH = os.path.dirname(os.path.realpath(__file__))
48
DATA_PATH = BASE_PATH + '/storage/data/'
@@ -7,6 +11,38 @@
711
ADDITIONAL_EDGE_BUFFER_DATA: int = 8
812

913

14+
class ProcessRenderMode(IntFlag):
15+
FINAL = auto()
16+
NODE_ITERATIONS = auto()
17+
EDGE_ITERATIONS = auto()
18+
SMOOTHING = auto()
19+
20+
21+
class CameraPose(IntEnum):
22+
FRONT = 0
23+
RIGHT = 1
24+
LEFT = 2
25+
LOWER_BACK_RIGHT = 3
26+
BACK_RIGHT = 4
27+
UPPER_BACK_LEFT = 5
28+
UPPER_BACK_RIGHT = 6
29+
BACK = 7
30+
DEFAULT = 8
31+
32+
33+
CAMERA_POSE_POSITION: List[Vector3] = [
34+
Vector3([3.5, 0.0, 0.0]),
35+
Vector3([0.0, 0.0, 2.5]),
36+
Vector3([0.0, 0.0, -2.5]),
37+
Vector3([-2.75, -1.0, 1.25]),
38+
Vector3([-2.5, 0.0, 2.5]),
39+
Vector3([-2.0, 2.0, -2.0]),
40+
Vector3([-2.0, 2.0, 2.0]),
41+
Vector3([-4.0, 0.0, 0.0]),
42+
Vector3([-3.5, 0.0, 0.0])
43+
]
44+
45+
1046
def pairwise(it, size: int):
1147
it = iter(it)
1248
while True:

evaluation/__init__.py

Whitespace-only changes.

evaluation/create_plot.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from utility.file import EvaluationFile
1111

1212

13-
def setup_plot():
13+
def setup_plot() -> None:
1414
plt.rc('font', size=14)
1515
plt.rc('axes', titlesize=14)
1616
plt.rc('axes', labelsize=14)
@@ -26,7 +26,7 @@ def load_data(name: str, importance_name: str, timed_name: bool = False) -> Dict
2626
return evaluation_file.data_cache[importance_name]
2727

2828

29-
def save_plot(name: str):
29+
def save_plot(name: str) -> None:
3030
directory_path: str = os.path.join(
3131
BASE_PATH, os.path.join('storage', 'evaluation'))
3232
if not os.path.exists(directory_path):
@@ -37,7 +37,7 @@ def save_plot(name: str):
3737
plt.savefig(file_path)
3838

3939

40-
def create_importance_plot(filename: str, importance_name: str, timed_name: bool = False, show: bool = False):
40+
def create_importance_plot(filename: str, importance_name: str, timed_name: bool = False, show: bool = False) -> None:
4141
data: Dict[Any, Any] = load_data(filename, importance_name, timed_name)
4242

4343
converted_data: List[List[Any]] = []
@@ -76,7 +76,7 @@ def create_importance_plot(filename: str, importance_name: str, timed_name: bool
7676

7777

7878
def create_importance_plot_compare_regularizer(filename: str, importance_names: List[str], check_importance_type: str,
79-
timed_name: bool = False, show: bool = False):
79+
timed_name: bool = False, show: bool = False) -> None:
8080
plt.rcParams['legend.loc'] = 'lower left'
8181
converted_data: List[List[Any]] = []
8282
for importance_name in importance_names:
@@ -117,7 +117,7 @@ def create_importance_plot_compare_regularizer(filename: str, importance_names:
117117

118118

119119
def create_importance_plot_compare_bn_parameter(filename: str, importance_names: List[str], check_importance_type: str,
120-
timed_name: bool = False, show: bool = False):
120+
timed_name: bool = False, show: bool = False) -> None:
121121
converted_data: List[List[Any]] = []
122122
for importance_name in importance_names:
123123
data: Dict[Any, Any] = load_data(filename, importance_name, timed_name)
@@ -154,7 +154,7 @@ def create_importance_plot_compare_bn_parameter(filename: str, importance_names:
154154
def create_importance_plot_compare_class_vs_all(filename: str, importance_name: str, class_index: int,
155155
check_importance_type: str, class_specific_data: bool = True,
156156
timed_name: bool = False,
157-
show: bool = False):
157+
show: bool = False) -> None:
158158
converted_data: List[List[Any]] = []
159159

160160
importance_data_name: str = '%s_[%s]' % (
@@ -191,7 +191,7 @@ def create_importance_plot_compare_class_vs_all(filename: str, importance_name:
191191

192192
def create_importance_plot_compare_classes_vs_all(filename: str, importance_name: str, check_importance_type: str,
193193
class_specific_data: bool = True, timed_name: bool = False,
194-
show: bool = False):
194+
show: bool = False) -> None:
195195
converted_data: List[List[Any]] = []
196196

197197
for i in range(10):

evaluation/evaluator.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import time
3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Optional
44

55
import numpy as np
66
from tensorflow import keras
@@ -13,27 +13,27 @@
1313

1414

1515
class ImportanceEvaluator:
16-
def __init__(self, model_data: ModelData):
16+
def __init__(self, model_data: ModelData) -> None:
1717
self.model_data: ModelData = model_data
1818
self.importance_type: ImportanceType = ImportanceType(
1919
model_data.get_importance_type())
2020
self.importance_calculation: ImportanceCalculation = ImportanceCalculation.BNN_EDGE
21-
self.relevant_classes: List[int] or None = None
21+
self.relevant_classes: Optional[List[int]] = None
2222

2323
self.x_train = None
2424
self.y_train = None
2525
self.x_test = None
2626
self.y_test = None
2727

28-
def set_train_and_test_data(self, x_train, y_train, x_test, y_test):
28+
def set_train_and_test_data(self, x_train, y_train, x_test, y_test) -> None:
2929
self.x_train = x_train
3030
self.y_train = y_train
3131
self.x_test = x_test
3232
self.y_test = y_test
3333

34-
def setup(self, importance_type: ImportanceType,
34+
def setup(self, importance_type: ImportanceType = ImportanceType.L1,
3535
importance_calculation: ImportanceCalculation = ImportanceCalculation.BNN_EDGE,
36-
relevant_classes: List[int] = None):
36+
relevant_classes: Optional[List[int]] = None) -> None:
3737
self.importance_type = importance_type
3838
self.importance_calculation = importance_calculation
3939
self.relevant_classes = relevant_classes
@@ -62,7 +62,7 @@ def get_importance(self, edge_alpha: float, classes_importance: List[float]) ->
6262
def prune_model(self,
6363
importance_prune_percent: str,
6464
importance_data: ImportanceDataHandler,
65-
importance_threshold: float):
65+
importance_threshold: float) -> None:
6666
data: Dict[Any, Any] = dict()
6767

6868
pruned_edges: int = 0
@@ -127,7 +127,7 @@ def accuracy_report(self, truths: np.array, predictions: np.array) -> Dict[str,
127127
true_positive_rate + true_negative_rate) / 2.0
128128
return accuracy_report
129129

130-
def test_model(self, importance_prune_percent: str):
130+
def test_model(self, importance_prune_percent: str) -> None:
131131
self.model_data.model.compile(loss=keras.losses.categorical_crossentropy,
132132
optimizer=keras.optimizers.Adam(0.001),
133133
metrics=['accuracy'])
@@ -171,7 +171,7 @@ def test_model(self, importance_prune_percent: str):
171171
self.importance_calculation.name,
172172
importance_prune_data)
173173

174-
def create_evaluation_data(self, step_size: int = 1, start_percentage: int = 0, end_percentage: int = 100):
174+
def create_evaluation_data(self, step_size: int = 1, start_percentage: int = 0, end_percentage: int = 100) -> None:
175175
importance_data: ImportanceDataHandler = ImportanceDataHandler(
176176
self.model_data.get_path() + get_importance_type_name(self.importance_type) + '.imp.npz')
177177

examples/__init__.py

Whitespace-only changes.

examples/create_images.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33

44
sys.path.append(os.path.abspath(os.path.join(
5-
os.path.dirname(sys.modules[__name__].__file__), '..')))
5+
os.path.dirname(sys.modules[__name__].__file__), '..'))) # type: ignore
66

77
if True:
88
import matplotlib.pyplot as plt

examples/evaluate_importance_data.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33

44
sys.path.append(os.path.abspath(os.path.join(
5-
os.path.dirname(sys.modules[__name__].__file__), '..')))
5+
os.path.dirname(sys.modules[__name__].__file__), '..'))) # type: ignore
66

77
if True:
88
from data.mnist_data_handler import get_prepared_data

examples/evaluation_plots.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import sys
55

66
sys.path.append(os.path.abspath(os.path.join(
7-
os.path.dirname(sys.modules[__name__].__file__), '..')))
7+
os.path.dirname(sys.modules[__name__].__file__), '..'))) # type: ignore
88

99
if True:
1010
from typing import List

examples/process_mnist_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os.path
22
import sys
3-
from typing import List
3+
from typing import List, Optional
44

55
sys.path.append(os.path.abspath(os.path.join(
6-
os.path.dirname(sys.modules[__name__].__file__), '..')))
6+
os.path.dirname(sys.modules[__name__].__file__), '..'))) # type: ignore
77

88
if True:
99
from data.mnist_data_handler import split_mnist_data
@@ -22,7 +22,7 @@
2222

2323
# -------------------------------------------------change these settings-----------------------------------------------#
2424
name: str = 'default'
25-
class_selection: List[int] or None = None # [0, 1, 2, 3, 4]
25+
class_selection: Optional[List[int]] = None # [0, 1, 2, 3, 4]
2626
importance_type: ImportanceType = ImportanceType(
2727
ImportanceType.GAMMA | ImportanceType.L1)
2828

gui/__init__.py

Whitespace-only changes.

gui/frame_building.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from rendering.shader_uniforms import EDGE_SHADER_UNIFORM, NODE_SHADER_UNIFORM
1111

1212

13-
def set_stat_frame(gui_root: Tk, settings: Dict[Any, Any]):
13+
def set_stat_frame(gui_root: Tk, settings: Dict[Any, Any]) -> None:
1414
stats_frame: LabelFrame = LabelFrame(
1515
gui_root, text='Statistics', width=60, padx=5, pady=5)
1616
stats_frame.grid(row=0, column=0, padx=5, pady=5)

0 commit comments

Comments
 (0)