forked from jaychempan/SWAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mytools.py
185 lines (141 loc) · 4.79 KB
/
mytools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# coding:utf-8
# Author:Zhiqiang Yuan
"""导入一些包"""
import os
import time, random
import json
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
""" 打印一些东西 """
"""----------------------------------------------------------------------"""
def print_list(list):
    """Print every element of ``list`` on its own line between banner lines.

    The parameter name shadows the builtin ``list``; it is kept unchanged so
    existing keyword callers keep working.
    """
    banner = "++++++++++++++++++++++++++++++++++++++++++++"
    print(banner)
    for element in list:
        print(element)
    print(banner)
def print_dict(dict):
    """Print each key/value pair of ``dict`` on its own line between banners.

    Each line has the form ``key: <k>  value: <v>``.
    """
    banner = "++++++++++++++++++++++++++++++++++++++++++++"
    print(banner)
    for name, entry in dict.items():
        print("key:", name, " value:", entry)
    print(banner)
def print_with_log(info):
    """Print ``info`` framed by banner lines so it stands out in console output."""
    banner = "++++++++++++++++++++++++++++++++++++++++++++"
    for chunk in (banner, info, banner):
        print(chunk)
def print_log():
    """Print a single banner line, used as a visual separator in console output."""
    banner = "++++++++++++++++++++++++++++++++++++++++++++"
    print(banner)
""" 文件存储 """
"""----------------------------------------------------------------------"""
def save_to_json(info, filename, encoding='UTF-8'):
    """Write ``info`` to ``filename`` as indented JSON, overwriting the file.

    Output uses 2-space indentation and ``,`` / ``:`` separators (no space
    after the colon). ``json.dump`` keeps its default ``ensure_ascii=True``,
    so non-ASCII characters are written as ``\\uXXXX`` escapes.
    """
    with open(filename, mode="w", encoding=encoding) as out_file:
        json.dump(
            info,
            out_file,
            indent=2,
            separators=(',', ':'),
        )
def load_from_json(filename, encoding='utf-8'):
    """Load and return the JSON document stored in ``filename``.

    Args:
        filename: Path of the JSON file to read.
        encoding: Text encoding used to decode the file. Defaults to
            ``'utf-8'`` (the value previously hard-coded), and mirrors the
            ``encoding`` parameter accepted by ``save_to_json``.

    Returns:
        The deserialized Python object (dict, list, str, number, ...).
    """
    with open(filename, encoding=encoding) as f:
        return json.load(f)
def save_to_npy(info, filename):
    """Store ``info`` in NumPy ``.npy`` format via ``np.save``.

    Pickling is allowed, so non-numeric objects (e.g. dicts) can be saved as
    object arrays. ``np.save`` appends ``.npy`` to ``filename`` when it does
    not already end with that suffix.
    """
    np.save(file=filename, arr=info, allow_pickle=True)
def load_from_npy(filename):
    """Read and return the contents of a ``.npy`` file written by ``np.save``.

    ``allow_pickle=True`` is needed to restore object arrays. However, it
    lets a crafted file execute arbitrary code while loading, so only use
    this function on files from a trusted source.
    """
    return np.load(file=filename, allow_pickle=True)
def log_to_txt(contexts=None, filename="save.txt", mark=False, encoding='UTF-8', add_n=False):
    """Append ``contexts`` to the text file ``filename``.

    What gets written depends on the arguments, checked in this order:

    * ``mark=True``: ``contexts`` is ignored and a dashed separator line is
      appended instead.
    * ``contexts`` is a dict: one ``"key | value"`` line per entry.
    * ``contexts`` is a list or tuple: items are converted with ``str`` and
      concatenated; with ``add_n=True`` each item is followed by a newline,
      otherwise no separator is inserted.
    * anything else: ``str(contexts)`` followed by a single newline.

    Args:
        contexts: Content to append (see above). Required unless ``mark``.
        filename: Target file, opened in append mode (created if missing).
        mark: Write a separator line instead of ``contexts``.
        encoding: Encoding used when writing the file.
        add_n: For list/tuple input, terminate every item with ``"\\n"``.

    Raises:
        TypeError: If ``contexts`` is None and ``mark`` is False.
    """
    if mark:
        text = "------------------------------------------------\n"
    elif contexts is None:
        # Validate before opening, so a bad call neither creates nor touches the file.
        raise TypeError("log_to_txt() needs 'contexts' unless mark=True")
    elif isinstance(contexts, dict):
        text = "".join(str(k) + " | " + str(v) + "\n" for k, v in contexts.items())
    elif isinstance(contexts, (list, tuple)):
        suffix = "\n" if add_n else ""
        text = "".join(str(c) + suffix for c in contexts)
    else:
        text = str(contexts) + "\n"
    # ``with`` guarantees the handle is closed even if the write fails.
    with open(filename, "a", encoding=encoding) as f:
        f.write(text)
def load_from_txt(filename, encoding="utf-8"):
    """Return all lines of ``filename`` as a list of strings.

    Lines keep their line terminators, exactly as ``readlines()`` returns
    them. The file is opened in a ``with`` block so the handle is always
    closed; the previous version never closed it.

    Args:
        filename: Path of the text file to read.
        encoding: Text encoding used to decode the file.

    Returns:
        list[str]: The file's lines, in order.
    """
    with open(filename, 'r', encoding=encoding) as f:
        return f.readlines()
""" 字典变换 """
"""----------------------------------------------------------------------"""
def dict_k_v_exchange(dict):
    """Return a new dict that maps each value of ``dict`` back to its key.

    Values must be hashable. When several keys share a value, the key seen
    last in iteration order wins.
    """
    return {value: key for key, value in dict.items()}
def d2array_to_dict(d2array):
    """Group the second element of each pair under its first element.

    Args:
        d2array: Iterable of indexable pairs (e.g. an N x 2 list), where
            ``pair[0]`` is the key and ``pair[1]`` the value.

    Returns:
        dict: Each distinct key (in first-seen order) mapped to the list of
        its values, in input order.
    """
    grouped = {}
    for pair in d2array:
        grouped.setdefault(pair[0], []).append(pair[1])
    return grouped
""" 绘图 """
"""----------------------------------------------------------------------"""
def visual_3d_points(list, color=True):
    """Reduce points to 3-D with whitened PCA and show a 3-D scatter plot.

    Args:
        list: Array-like point matrix. When ``color`` is True its shape is
            N x (dim + 1): N points, ``dim`` feature columns, and a last
            column holding each point's class label (used as its colour).
            When ``color`` is False it is N x dim with no label column.
            PCA to 3 components needs at least 3 features and 3 points.
        color: Whether the last column contains per-point labels.
    """
    points = np.array(list)
    if color:
        # Every column except the last is a feature; the last is the label.
        # (Previously ``[:, :4]`` was used: for dim == 3 that fed the label
        # into PCA as a feature, and for dim > 4 it silently dropped features.)
        data = points[:, :-1]
        label = points[:, -1]
    else:
        data = points
        label = None
    # Project to 3 whitened principal components for plotting.
    pca = PCA(n_components=3, whiten=True).fit(data)
    data = pca.transform(data)
    # Create a new figure and a 3-D axes on it.
    plt.figure()
    ax1 = plt.axes(projection='3d')
    point_color = label if label is not None else "blue"
    ax1.scatter3D(data[:, 0], data[:, 1], data[:, 2], c=point_color)
    plt.show()
""" 实用工具 """
"""----------------------------------------------------------------------"""
def count_list(lens):
    """Count how often each element occurs in ``lens``, print and return it.

    Returns:
        list of (element, count) tuples, sorted by count with the highest
        first. Elements with equal counts keep their first-appearance order,
        because Python's sort is stable even with ``reverse=True``.
    """
    tally = {}
    for element in lens:
        if element in tally:
            tally[element] += 1
        else:
            tally[element] = 1
    ranked = sorted(tally.items(), key=lambda pair: pair[1], reverse=True)
    print_list(ranked)
    return ranked
def list_add(list1, list2, w1=1, w2=1):
    """Return the element-wise weighted sum ``list1[i] * w1 + list2[i] * w2``.

    Elements are paired with zip semantics, so the result is as long as the
    shorter input.
    """
    return list(map(lambda a, b: a * w1 + b * w2, list1, list2))