-
Notifications
You must be signed in to change notification settings - Fork 46
/
all_utils.py
163 lines (130 loc) · 4.95 KB
/
all_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# coding: utf-8
import os
import sys
import torch
import numpy as np
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
import scipy.stats
import pickle
import requests
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def build_model(name):
    """Load a pretrained BERT tokenizer/model pair by name.

    Args:
        name: Hugging Face model identifier or local checkpoint path.

    Returns:
        (tokenizer, model) tuple, with the model already moved to DEVICE.
    """
    tok = BertTokenizer.from_pretrained(name)
    bert = BertModel.from_pretrained(name).to(DEVICE)
    return tok, bert
def sent_to_vec(sent, tokenizer, model, pooling, max_length):
    """Encode a single sentence into a 1-D numpy embedding vector.

    Args:
        sent: input sentence (string).
        tokenizer: BERT tokenizer produced by build_model.
        model: BERT model producing hidden states.
        pooling: one of 'first_last_avg', 'last_avg', 'last2avg', 'cls'.
        max_length: truncation length for tokenization.

    Returns:
        numpy array of shape (hidden_dim,) for the single input sentence.

    Raises:
        Exception: if `pooling` is not one of the supported strategies.
    """
    with torch.no_grad():
        inputs = tokenizer(sent, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        inputs['input_ids'] = inputs['input_ids'].to(DEVICE)
        inputs['token_type_ids'] = inputs['token_type_ids'].to(DEVICE)
        inputs['attention_mask'] = inputs['attention_mask'].to(DEVICE)
        # hidden_states[0] is the embedding layer output; [1:] are the encoder layers.
        hidden_states = model(**inputs, return_dict=True, output_hidden_states=True).hidden_states
        if pooling == 'first_last_avg':
            # average of first encoder layer and last layer, mean-pooled over tokens
            output_hidden_state = (hidden_states[-1] + hidden_states[1]).mean(dim=1)
        elif pooling == 'last_avg':
            output_hidden_state = (hidden_states[-1]).mean(dim=1)
        elif pooling == 'last2avg':
            output_hidden_state = (hidden_states[-1] + hidden_states[-2]).mean(dim=1)
        elif pooling == 'cls':
            # [CLS] token representation from the last layer
            output_hidden_state = (hidden_states[-1])[:, 0, :]
        else:
            # FIX: was `POOLING` (undefined name), which raised NameError
            # instead of the intended descriptive Exception.
            raise Exception("unknown pooling {}".format(pooling))
    vec = output_hidden_state.cpu().numpy()[0]
    return vec
def sents_to_vecs(sents, tokenizer, model, pooling, max_length, verbose=True):
    """Encode a list of sentences into a 2-D numpy array of embeddings.

    Args:
        sents: iterable of sentence strings.
        tokenizer, model, pooling, max_length: forwarded to sent_to_vec.
        verbose: when True, wrap the iteration in a tqdm progress bar.

    Returns:
        numpy array of shape (len(sents), hidden_dim).
    """
    iterator = tqdm(sents) if verbose else sents
    vecs = [sent_to_vec(s, tokenizer, model, pooling, max_length) for s in iterator]
    assert len(iterator) == len(vecs)
    return np.array(vecs)
def calc_spearmanr_corr(x, y):
    """Return the Spearman rank correlation coefficient between x and y."""
    result = scipy.stats.spearmanr(x, y)
    return result.correlation
def compute_kernel_bias(vecs):
    """Compute the whitening kernel and bias.

    The final transform is: y = (x + bias).dot(kernel), which centers the
    vectors and maps them so their covariance becomes the identity.

    Args:
        vecs: list of (n_i, dim) arrays; they are concatenated along axis 0.

    Returns:
        (kernel, bias): kernel is (dim, dim); bias is (1, dim) — the
        negated mean, so adding it centers the data.
    """
    stacked = np.concatenate(vecs, axis=0)
    mean = stacked.mean(axis=0, keepdims=True)
    # Whiten via SVD of the covariance: C = U S U^T  =>  W = U S^{-1/2}
    u, s, _ = np.linalg.svd(np.cov(stacked.T))
    kernel = u @ np.diag(s ** -0.5)
    return kernel, -mean
def save_whiten(path, kernel, bias):
    """Pickle the whitening kernel and bias to `path`.

    Returns:
        The path the parameters were written to.
    """
    payload = {'kernel': kernel, 'bias': bias}
    with open(path, 'wb') as handle:
        pickle.dump(payload, handle)
    return path
def load_whiten(path):
    """Load whitening parameters previously written by save_whiten.

    Returns:
        (kernel, bias) as stored in the pickle at `path`.
    """
    with open(path, 'rb') as handle:
        payload = pickle.load(handle)
    return payload['kernel'], payload['bias']
def transform_and_normalize(vecs, kernel, bias):
    """Apply the whitening transform, then L2-normalize each row.

    When either kernel or bias is None the whitening step is skipped and
    the vectors are only normalized.
    """
    if kernel is not None and bias is not None:
        vecs = np.dot(vecs + bias, kernel)
    return normalize(vecs)
def normalize(vecs):
    """L2-normalize each row of a (n, dim) array."""
    norms = np.sqrt((vecs ** 2).sum(axis=1, keepdims=True))
    return vecs / norms
def http_get(url, path):
    """Download a URL to a given path on disk.

    Streams the body in 1 KiB chunks to `path + "_part"` and renames it to
    `path` only on success, so an interrupted download never leaves a
    truncated file at the final location.

    Args:
        url: URL to fetch.
        path: destination file path; parent directories are created.

    Raises:
        requests.HTTPError: via raise_for_status() on a non-200 response.
    """
    if os.path.dirname(path) != '':
        os.makedirs(os.path.dirname(path), exist_ok=True)
    # FIX: use the response as a context manager so the streaming
    # connection is always released (the original leaked it).
    with requests.get(url, stream=True) as req:
        if req.status_code != 200:
            print("Exception when trying to download {}. Response {}".format(url, req.status_code), file=sys.stderr)
            req.raise_for_status()
            return
        download_filepath = path + "_part"
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total, unit_scale=True)
        # FIX: close the progress bar even if writing fails, and before the
        # rename (the original closed it after, and not at all on error).
        try:
            with open(download_filepath, "wb") as file_binary:
                for chunk in req.iter_content(chunk_size=1024):
                    if chunk:  # filter out keep-alive new chunks
                        progress.update(len(chunk))
                        file_binary.write(chunk)
        finally:
            progress.close()
    os.rename(download_filepath, path)
import inspect
def get_size(obj, seen=None):
    """Recursively compute the total size of `obj` in bytes.

    Follows references into __dict__, dict keys/values, iterables, and
    __slots__, counting each distinct object exactly once.

    Args:
        obj: any Python object.
        seen: set of already-visited object ids (internal recursion state).

    Returns:
        Total size in bytes as an int.
    """
    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    # Mark as seen *before* recursing so self-referential objects terminate.
    seen.add(obj_id)
    total = sys.getsizeof(obj)
    if hasattr(obj, '__dict__'):
        # Walk the MRO to find the class that actually owns the __dict__
        # descriptor; only count the instance dict when it is a real
        # getset/member descriptor (i.e. a genuine per-instance dict).
        for klass in obj.__class__.__mro__:
            if '__dict__' in klass.__dict__:
                descriptor = klass.__dict__['__dict__']
                if inspect.isgetsetdescriptor(descriptor) or inspect.ismemberdescriptor(descriptor):
                    total += get_size(obj.__dict__, seen)
                break
    if isinstance(obj, dict):
        total += sum(get_size(value, seen) for value in obj.values())
        total += sum(get_size(key, seen) for key in obj.keys())
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        total += sum(get_size(element, seen) for element in obj)
    if hasattr(obj, '__slots__'):  # a class can have __slots__ alongside __dict__
        total += sum(get_size(getattr(obj, name), seen)
                     for name in obj.__slots__ if hasattr(obj, name))
    return total