forked from ok1zjf/VASNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsys_utils.py
152 lines (121 loc) · 4.53 KB
/
sys_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Module metadata: author, contact, version, status, date and license of this
# system/utility helper module.
__author__ = 'Jiri Fajtl'
__email__ = 'ok1zjf@gmail.com'
__version__= '3.6'
__status__ = "Research"
__date__ = "1/12/2018"
__license__= "MIT License"
import os
import numpy as np
import subprocess
import platform
import sys
import pkg_resources
import torch
import h5py
import json
import ortools
from torch.nn.modules.module import _addindent
# import PIL as Image
# import cv2
def list_files(path, extensions=(), sort=True, max_len=-1):
    """Return the entries of *path* whose names end with one of *extensions*.

    Matching is case-insensitive on both the entry name and the extension, so
    ``'.JPG'`` and ``'jpg'`` both match ``photo.jpg``. Only the direct entries
    of *path* are considered (no recursion); ``os.listdir`` does not separate
    files from directories, so a directory with a matching name is included.

    Args:
        path: Directory to scan.
        extensions: Iterable of name suffixes to keep, e.g. ``['jpg', 'png']``.
            An empty iterable matches nothing.
        sort: If True, sort the resulting paths lexicographically.
        max_len: If > -1, truncate the result to at most this many paths.

    Returns:
        A list of ``os.path.join(path, name)`` strings, or an empty list (after
        printing an error) when *path* is not a directory.
    """
    if not os.path.isdir(path):
        print("ERROR. ", path, ' is not a directory!')
        return []

    # Lower-case the suffixes once so matching is case-insensitive on both
    # sides; str.endswith accepts a tuple and returns False for an empty one.
    suffixes = tuple(ext.lower() for ext in extensions)
    filenames = [os.path.join(path, fn) for fn in os.listdir(path)
                 if fn.lower().endswith(suffixes)]

    if sort:
        filenames.sort()
    if max_len > -1:
        filenames = filenames[:max_len]
    return filenames
def del_file(filename):
    """Delete *filename* on a best-effort basis.

    Operating-system failures (file already missing, permission denied, path
    is a directory, ...) are ignored so callers can use this for cleanup
    without checking first. Only ``OSError`` is suppressed, so
    ``KeyboardInterrupt``/``SystemExit`` and non-OS errors still propagate,
    unlike the previous bare ``except:``.

    Args:
        filename: Path of the file to remove.
    """
    try:
        os.remove(filename)
    except OSError:
        # Best-effort cleanup: nothing to do if the file cannot be removed.
        pass
def get_video_list(video_path, max_len=-1, extensions=('avi', 'flv', 'mpg', 'mp4')):
    """Return the sorted video files directly inside *video_path*.

    Args:
        video_path: Directory to scan.
        max_len: If > -1, return at most this many paths.
        extensions: File-name suffixes treated as videos. The default is an
            immutable tuple rather than a list, so it cannot be modified and
            shared across calls.

    Returns:
        The list produced by ``list_files`` (empty if *video_path* is not a
        directory).
    """
    return list_files(video_path, extensions=extensions, sort=True, max_len=max_len)
def get_image_list(video_path, max_len=-1, extensions=('jpg', 'jpeg', 'png')):
    """Return the sorted image files directly inside *video_path*.

    Args:
        video_path: Directory to scan. The parameter name is kept for backward
            compatibility with keyword callers.
        max_len: If > -1, return at most this many paths.
        extensions: File-name suffixes treated as images. The default is the
            previously hard-coded jpg/jpeg/png set; it can now be overridden
            in the same way as ``get_video_list``.

    Returns:
        The list produced by ``list_files`` (empty if *video_path* is not a
        directory).
    """
    return list_files(video_path, extensions=extensions, sort=True, max_len=max_len)
def run_command(command):
    """Run *command* and return its combined stdout/stderr, tab-indented.

    The command string is split on whitespace and executed directly (no
    shell), so quoting, pipes and globbing are not supported. Each output line
    is stripped and prefixed with a tab, and the lines are joined with
    newlines.

    ``subprocess.run`` waits for the child and closes its pipe. The previous
    un-awaited ``Popen`` left zombie processes and open file descriptors
    behind.

    Args:
        command: Command line, e.g. ``'cat /proc/driver/nvidia/version'``.

    Returns:
        The formatted output, or an empty string if nothing was printed.

    Raises:
        OSError: if the executable cannot be found or started.
    """
    result = subprocess.run(command.split(), stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT)
    # errors='replace' keeps non-UTF-8 output from raising UnicodeDecodeError.
    text = result.stdout.decode("utf-8", errors="replace")
    return '\n'.join('\t' + line.strip() for line in text.splitlines())
def _read_version_file(path):
    """Return the stripped, tab-prefixed lines of *path*, or 'NA' if unreadable.

    Produces the same layout that ``run_command('cat <path>')`` used to, but
    reads the file directly. It spawns no subprocess and needs no ``cat``
    binary, and a missing file no longer puts ``cat``'s error text into the
    report.
    """
    if not os.path.isfile(path):
        return 'NA'
    try:
        with open(path, encoding='utf-8', errors='replace') as f:
            return '\n'.join('\t' + line.strip() for line in f)
    except OSError:
        return 'NA'


def _read_cuda_json_version(path):
    """Return the CUDA version recorded in a toolkit ``version.json``, or 'NA'.

    NOTE(review): assumes the ``{"cuda": {"version": ...}}`` layout. Any
    other layout, or an unreadable file, yields 'NA'.
    """
    if not os.path.isfile(path):
        return 'NA'
    try:
        with open(path, encoding='utf-8') as f:
            return str(json.load(f)['cuda']['version'])
    except (OSError, ValueError, KeyError, TypeError):
        return 'NA'


def ge_pkg_versions():
    """Collect driver, CUDA, platform and package versions for diagnostics.

    Returns:
        dict: component name -> version. Keys, in insertion order: display,
        cuda, cudnn, platform, python, torch, numpy, h5py, json, ortools,
        torchvision. 'display', 'cuda' and 'torchvision' are 'NA' when they
        cannot be determined on this machine.
    """
    dep_versions = {}
    # Contents of /proc/driver/nvidia/version (NVIDIA driver info, if present).
    dep_versions['display'] = _read_version_file('/proc/driver/nvidia/version')

    cuda_home = os.environ.get('CUDA_HOME', '/usr/local/cuda/')
    dep_versions['cuda'] = _read_version_file(os.path.join(cuda_home, 'version.txt'))
    if dep_versions['cuda'] == 'NA':
        # Newer CUDA toolkits ship version.json instead of version.txt.
        dep_versions['cuda'] = _read_cuda_json_version(os.path.join(cuda_home, 'version.json'))

    dep_versions['cudnn'] = torch.backends.cudnn.version()
    dep_versions['platform'] = platform.platform()
    dep_versions['python'] = sys.version_info[:3]
    dep_versions['torch'] = torch.__version__
    dep_versions['numpy'] = np.__version__
    dep_versions['h5py'] = h5py.__version__
    dep_versions['json'] = json.__version__
    dep_versions['ortools'] = ortools.__version__
    try:
        dep_versions['torchvision'] = pkg_resources.get_distribution("torchvision").version
    except pkg_resources.DistributionNotFound:
        # torchvision is optional for this report; do not crash without it.
        dep_versions['torchvision'] = 'NA'
    return dep_versions
def print_pkg_versions():
    """Print a titled report of the versions returned by ge_pkg_versions().

    Emits a title line and a dashed rule, then one ``name :  version`` line
    per component, then a trailing empty line.
    """
    banner = (
        "Packages & system versions:",
        "----------------------------------------------------------------------",
    )
    for text in banner:
        print(text)
    for name, value in ge_pkg_versions().items():
        print(name, ": ", value)
    print("")
def torch_summarize(model, show_weights=True, show_parameters=True):
    """Summarize a torch model: per-child repr, weight shapes and parameter counts.

    Walks ``model``'s direct children. A child whose exact type is
    ``nn.Sequential`` (or the legacy ``nn.Container``) is summarized
    recursively with the same display flags. Any other child is rendered with
    its ``repr``. Every parameter is counted, whether or not it is trainable.

    Args:
        model: The ``torch.nn.Module`` to describe.
        show_weights: Append ``weights=<tuple of parameter shapes>`` per child.
        show_parameters: Append ``parameters=<child count> / <running total>``
            per child.

    Returns:
        ``(summary, parameters, convs)``: the multi-line summary string, the
        total number of parameter elements, and the number of ``Conv2d``
        occurrences found in the children's reprs.
    """
    # nn.Container is deprecated; look it up defensively so the exact-type
    # check below never fails on a PyTorch build that no longer provides it.
    container_types = tuple(
        t for t in (getattr(torch.nn.modules.container, 'Container', None),
                    torch.nn.modules.container.Sequential)
        if t is not None)

    tmpstr = model.__class__.__name__ + ' (\n'
    parameters = 0
    convs = 0
    for key, module in model._modules.items():
        if module is None:
            # _modules can hold None placeholders: no repr worth showing and
            # no parameters (calling .parameters() on None would crash).
            continue
        if type(module) in container_types:
            # Recurse with the caller's display flags. The nested parameter
            # total is deliberately not added: module.parameters() below
            # already includes every nested parameter, so adding it too would
            # double-count.
            modstr, _, cnvs = torch_summarize(module, show_weights, show_parameters)
            convs += cnvs
        else:
            modstr = module.__repr__()
            # Count Conv2d layers by occurrences of the class name in the repr.
            convs += len(modstr.split('Conv2d')) - 1
        modstr = _addindent(modstr, 2)

        module_params = list(module.parameters())
        # numel() is exact for scalar parameters too (np.prod(()) gives 1.0).
        params = sum(p.numel() for p in module_params)
        parameters += params
        weights = tuple(tuple(p.size()) for p in module_params)

        tmpstr += ' (' + key + '): ' + modstr
        if show_weights:
            tmpstr += ', weights={}'.format(weights)
        if show_parameters:
            tmpstr += ', parameters={} / {}'.format(params, parameters)
        tmpstr += ', convs={}'.format(convs)
        tmpstr += '\n'

    tmpstr = tmpstr + ')'
    return tmpstr, parameters, convs
def print_table(table, cell_width=(3, 35, 8)):
    """Print *table* as a left-aligned, fixed-width text table.

    Row 0 is the header. It is followed by a ``=`` rule, and the whole table
    is framed by ``-`` rules. Each cell is printed as one space followed by the
    value left-aligned to its column width. The caller's *table* is not
    modified; previously the header row was popped off it.

    Args:
        table: Sequence of rows; row 0 is the header. No row may have more
            cells than *cell_width* has entries.
        cell_width: Minimum width of each column, in characters.

    Raises:
        IndexError: if *table* is empty or a row has more cells than
            *cell_width* has entries.
    """
    slen = sum(cell_width) + len(cell_width) * 2 + 2

    def _print_row(cells):
        # Build the whole line first, then print it once.
        print(''.join(' {val: <{alignment}}'.format(val=val, alignment=cell_width[i])
                      for i, val in enumerate(cells)))

    print('-' * slen)
    _print_row(table[0])
    print('=' * slen)
    for row in table[1:]:
        _print_row(row)
    print('-' * slen)
def _main():
    """Script entry point: print the package/system version report."""
    print_pkg_versions()


if __name__ == "__main__":
    _main()