-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
124 lines (105 loc) · 2.61 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#
# coding=utf-8
import os
import numpy as np
import matplotlib.pyplot as plt
import torch as th
import argparse
from config import PICS_PATH
def norm(x):
    """
    Min-max normalize a tensor into the range [0, 1].

    :param x: tensor
    :return: ret: tensor of the same shape, computed as (x - min) / (max - min)

    A constant tensor (max == min) maps to all zeros instead of NaN.
    """
    x_max = th.max(x)
    x_min = th.min(x)
    span = x_max - x_min
    if span == 0:
        # Avoid 0/0 -> NaN: every element equals x_min, so x - x_min is already
        # zero; dividing by one keeps the same result dtype as the normal path.
        span = th.ones_like(span)
    ret = (x - x_min) / span
    return ret
def norm_np(x):
    """
    Min-max normalize an ndarray into the range [0, 1].

    :param x: ndarray
    :return: ret: ndarray of the same shape, computed as (x - min) / (max - min)

    A constant array (max == min) maps to all zeros instead of NaN.
    """
    x_max = np.max(x)
    x_min = np.min(x)
    span = x_max - x_min
    if span == 0:
        # Avoid 0/0 -> NaN (and the RuntimeWarning): x - x_min is already zero.
        span = 1
    ret = (x - x_min) / span
    return ret
def std(x):
    """
    Standardize a tensor to zero mean and unit standard deviation (z-score).

    :param x: tensor
    :return: ret: tensor of the same shape, computed as (x - mean) / std

    Uses torch's default standard deviation estimator. When that value is
    zero or NaN (constant input, or a single element), the result is
    (x - mean), i.e. zeros for finite input, instead of NaN/inf.
    """
    mu = th.mean(x)
    # Named `sigma` so the local does not shadow this module's `std` function.
    sigma = th.std(x)
    if not sigma > 0:
        # Covers sigma == 0 and sigma == NaN (NaN > 0 is False).
        sigma = th.ones_like(sigma)
    ret = (x - mu) / sigma
    return ret
def file_w(lst, file_name, file_path='results/'):
    """
    Write a sequence of values to a text file as one comma-separated line.

    The target path is ``file_path + file_name`` (plain concatenation, so
    ``file_path`` should end with a path separator). The parent directory is
    created if it does not exist yet; an existing file is overwritten.

    :param lst: iterable of values; each one is written with ``str()``
    :param file_name: string, name of the file inside ``file_path``
    :param file_path: string, directory prefix (default ``'results/'``)
    :return: None
    """
    target = file_path + file_name
    parent = os.path.dirname(target)
    if parent:
        # open(..., 'w') raises FileNotFoundError when the directory is missing.
        os.makedirs(parent, exist_ok=True)
    with open(target, 'w', encoding='utf-8') as f:
        f.write(','.join(str(ele) for ele in lst))
def file_r(file_name, file_path='results/'):
    """
    Read a comma-separated file (as written by ``file_w``) back into floats.

    :param file_name: string
    :param file_path: string, directory prefix joined by plain concatenation
    :return: list(float); empty when the file contains no values
    :raises ValueError: if a non-blank field is not a valid float
    """
    with open(file_path + file_name, 'r', encoding='utf-8') as f:
        s = f.read()
    # Skip blank fields so an empty file (what file_w([]) produces), a trailing
    # comma, or a trailing newline do not hit float('') -> ValueError.
    return [float(ele) for ele in s.split(',') if ele.strip()]
def make_pic(lst, title_name):
    """
    Plot one series against 1-based step indices and display the figure.

    The series length is printed before plotting.

    :param lst: list(float), values to plot
    :param title_name: string, figure title
    :return: None
    """
    count = len(lst)
    steps = np.arange(1, count + 1)
    print("length: {}".format(count))
    values = np.array(lst)
    plt.title(title_name)
    plt.plot(steps, values)
    plt.show()
def make_pics(title_name):
    """
    Overlay every series in ``results/`` whose file-name prefix is ``title_name``.

    File names are split on '@': the part before the first '@' must equal
    ``title_name``, and the part after it is used as the line's legend label.
    Each matching file is read with ``file_r`` and drawn on one shared figure.

    :param title_name: string, prefix to match, e.g. 'reward' or 'loss'
    :return: None
    """
    # os.walk yields nothing for a missing directory; the default avoids a
    # bare StopIteration and simply means "no files".
    _, _, files = next(os.walk('results/'), (None, None, []))
    plotted = False
    # Sorted so the plotting/legend order is deterministic (os.walk order is not).
    for file in sorted(files):
        file_lst = file.split('@')
        if file_lst[0] != title_name:
            continue
        lst = file_r(file)
        x = np.arange(1, len(lst) + 1)
        y = np.array(lst)
        # Label from the text after the first '@'; fall back to the whole file
        # name when there is no '@' (indexing [1] would raise IndexError).
        label = file_lst[1] if len(file_lst) > 1 else file
        plt.plot(x, y, label=label)
        plotted = True
    if not plotted:
        print("no result files found for '{}' in results/".format(title_name))
        return
    plt.title(title_name)
    plt.legend()  # show the per-file labels
    plt.show()
if __name__ == '__main__':
    # Command-line interface: --pics selects which result series to plot.
    parser = argparse.ArgumentParser()
    parser.add_argument('--pics', type=str, default='', help='input name need to make pics')
    args = parser.parse_args()
    # Names that are forwarded to make_pics unchanged.
    plottable = (
        'reward', 'moving_average_reward',  # reward
        'loss', 'a_loss', 'c_loss',         # loss
        'grad',                             # gradient
    )
    if args.pics in plottable:
        make_pics(args.pics)
    else:
        # Any other value (including the empty default) writes a small sample
        # file via file_w with its default directory.
        lst = [1, 2]
        file_w(lst, 'hahah.txt')