-
Notifications
You must be signed in to change notification settings - Fork 116
/
configs.py
60 lines (47 loc) · 2.06 KB
/
configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import json
import os
from api_trt.logger import logger
class Configs:
    """Model-repository configuration: directory layout plus the model registry.

    The models directory is taken from the ``MODELS_DIR`` environment variable
    when it is set; otherwise the ``models_dir`` argument is used. On
    construction the registry JSON is read from that directory.

    Args:
        models_dir: Models directory used when ``MODELS_DIR`` is not set.

    Raises:
        FileNotFoundError: If the selected registry file does not exist.
        json.JSONDecodeError: If the selected registry file is not valid JSON.
    """

    def __init__(self, models_dir: str = '/models'):
        # Note the precedence: the environment variable wins over the argument.
        self.models_dir = self.__get_param('MODELS_DIR', models_dir)
        self.onnx_models_dir = os.path.join(self.models_dir, 'onnx')
        self.trt_engines_dir = os.path.join(self.models_dir, 'trt-engines')
        # Parsed registry; presumably {model_name: {setting: value, ...}} judging
        # by the lookups in the getters below -- TODO confirm against models.json.
        self.models = self.__read_models_file()
        # Model file extension -> directory holding files of that type.
        # 'engine' and 'plan' files share the trt-engines directory.
        self.type2path = {
            'onnx': self.onnx_models_dir,
            'engine': self.trt_engines_dir,
            'plan': self.trt_engines_dir,
        }

    def __read_models_file(self) -> dict:
        """Load and return the model registry from ``self.models_dir``.

        ``models.override.json`` is used instead of ``models.json`` when it
        exists (a warning is logged in that case).

        Raises:
            FileNotFoundError: If the selected file is missing; its ``strerror``
                is rewritten to name the missing path.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        models_default_path = os.path.join(self.models_dir, 'models.json')
        models_override_path = os.path.join(self.models_dir, 'models.override.json')
        models_conf = models_default_path
        if os.path.exists(models_override_path):
            models_conf = models_override_path
            logger.warning(f"Found '{models_override_path}', using this instead of default.")
        try:
            # `with` closes the handle deterministically instead of leaving it
            # to garbage collection; JSON text is UTF-8 per RFC 8259, so don't
            # depend on the platform's locale encoding.
            with open(models_conf, mode='r', encoding='utf-8') as fh:
                return json.load(fh)
        except FileNotFoundError as e:
            # Keep the original exception (type, errno, filename) but make the
            # message name the registry file that was expected.
            e.strerror = f"The file `{models_conf}` doesn't exist"
            raise

    def __get_param(self, env_var: str, default=None):
        """Return the value of environment variable ``env_var``, or ``default`` if unset."""
        return os.environ.get(env_var, default)

    def build_model_paths(self, model_name: str, ext: str) -> tuple:
        """Return ``(parent_dir, file_path)`` for ``model_name`` stored as ``ext``.

        Layout: ``<type2path[ext]>/<model_name>/<model_name>.<ext>``.

        Raises:
            KeyError: If ``ext`` is not a key of ``type2path``
                ('onnx', 'engine' or 'plan' as constructed).
        """
        base = self.type2path[ext]
        parent = os.path.join(base, model_name)
        model_file = os.path.join(parent, f"{model_name}.{ext}")
        return parent, model_file

    def _model_field(self, model_name: str, field: str):
        """Return ``field`` from ``model_name``'s registry entry, or None.

        None is returned when the model is absent, its entry is JSON null,
        or the entry lacks ``field``.
        """
        # `or {}` also covers an explicit null entry, which would otherwise
        # raise AttributeError on `.get`.
        entry = self.models.get(model_name) or {}
        return entry.get(field)

    def get_outputs_order(self, model_name):
        """Return the model's 'outputs' setting (presumably ordered output names), or None."""
        return self._model_field(model_name, 'outputs')

    def get_shape(self, model_name):
        """Return the model's 'shape' setting, or None."""
        return self._model_field(model_name, 'shape')

    def get_dl_link(self, model_name):
        """Return the model's 'link' setting (presumably a download URL), or None."""
        return self._model_field(model_name, 'link')

    def get_dl_type(self, model_name):
        """Return the model's 'dl_type' setting (presumably the download method), or None."""
        return self._model_field(model_name, 'dl_type')

    def get_function(self, model_name):
        """Return the model's 'function' setting, or None."""
        return self._model_field(model_name, 'function')
# Shared instance created at import time: constructing it reads MODELS_DIR and
# loads the registry JSON, so importing this module fails if that file is missing.
config: Configs = Configs()