-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
util.py
504 lines (393 loc) · 19.2 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
import requests
from torch import Tensor, device
from typing import List, Callable
from tqdm.autonotebook import tqdm
import sys
import importlib
import os
import torch
import numpy as np
import queue
import logging
from typing import Dict, Optional, Union
from pathlib import Path
import huggingface_hub
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub import HfApi, hf_hub_url, cached_download, HfFolder
import fnmatch
from packaging import version
import heapq
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def pytorch_cos_sim(a: Tensor, b: Tensor):
    """
    Thin wrapper that delegates to :func:`cos_sim`.

    :return: Matrix whose entry [i][j] is the cosine similarity of a[i] and b[j]
    """
    similarity_matrix = cos_sim(a, b)
    return similarity_matrix
def cos_sim(a: Tensor, b: Tensor):
    """
    Cosine similarity between every row of ``a`` and every row of ``b``.

    Inputs that are not tensors are converted with ``torch.tensor``; a 1-D input
    is treated as a matrix with a single row.

    :return: Matrix ``res`` with ``res[i][j] = cos_sim(a[i], b[j])``
    """
    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
    b = b if isinstance(b, torch.Tensor) else torch.tensor(b)

    # Promote single vectors to 1 x d so torch.mm below always sees matrices.
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)

    # After L2-normalising each row, the dot product equals the cosine similarity.
    a_unit, b_unit = (torch.nn.functional.normalize(t, p=2, dim=1) for t in (a, b))
    return torch.mm(a_unit, b_unit.transpose(0, 1))
def dot_score(a: Tensor, b: Tensor):
    """
    Dot product between every row of ``a`` and every row of ``b``.

    Non-tensor inputs are converted with ``torch.tensor`` and 1-D inputs are
    treated as a single row.

    :return: Matrix ``res`` with ``res[i][j] = dot_prod(a[i], b[j])``
    """
    def _as_matrix(x):
        # Convert to a tensor and give vectors an explicit leading row axis.
        x = x if isinstance(x, torch.Tensor) else torch.tensor(x)
        return x.unsqueeze(0) if x.dim() == 1 else x

    lhs = _as_matrix(a)
    rhs = _as_matrix(b)
    return torch.mm(lhs, rhs.transpose(0, 1))
def pairwise_dot_score(a: Tensor, b: Tensor):
    """
    Row-wise dot product: pairs a[i] with b[i] only.

    Non-tensor inputs are converted with ``torch.tensor``. The product is
    reduced over the last dimension.

    :return: Vector with res[i] = dot_prod(a[i], b[i])
    """
    left, right = [x if isinstance(x, torch.Tensor) else torch.tensor(x) for x in (a, b)]
    return torch.sum(left * right, dim=-1)
def pairwise_cos_sim(a: Tensor, b: Tensor):
    """
    Computes the pairwise cosine similarity cos_sim(a[i], b[i]).

    Vectors are L2-normalised along the last dimension, which is the same
    dimension the dot product is summed over. This makes batched inputs and a
    single pair of 1-D vectors both work.

    :param a: Tensor (or value accepted by ``torch.tensor``) of vectors
    :param b: Tensor (or value accepted by ``torch.tensor``) of vectors, paired element-wise with ``a``
    :return: Vector with res[i] = cos_sim(a[i], b[i]); a 0-d tensor for 1-D inputs
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    # Normalise over the last axis so it matches the reduction axis below;
    # normalising over dim=1 would fail for 1-D inputs and use the wrong axis
    # for inputs with more than two dimensions.
    a_norm = torch.nn.functional.normalize(a, p=2, dim=-1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=-1)
    return (a_norm * b_norm).sum(dim=-1)
def normalize_embeddings(embeddings: Tensor):
    """
    Rescale each row of ``embeddings`` to unit L2 length (normalisation over dim=1).
    """
    unit_rows = torch.nn.functional.normalize(embeddings, dim=1, p=2)
    return unit_rows
def paraphrase_mining(model,
                      sentences: List[str],
                      show_progress_bar: bool = False,
                      batch_size:int = 32,
                      *args,
                      **kwargs):
    """
    Mine paraphrases: compare every sentence against every other sentence and
    return the pairs with the highest similarity scores.

    The sentences are embedded with ``model.encode`` and the resulting tensor is
    handed to ``paraphrase_mining_embeddings``; all extra positional and keyword
    arguments are forwarded to it unchanged.

    :param model: SentenceTransformer model for embedding computation
    :param sentences: A list of strings (texts or sentences)
    :param show_progress_bar: Plotting of a progress bar
    :param batch_size: Number of texts that are encoded simultaneously by the model
    :param query_chunk_size: (forwarded) Search for most similar pairs for #query_chunk_size at the same time. Decrease, to lower memory footprint (increases run-time).
    :param corpus_chunk_size: (forwarded) Compare a sentence simultaneously against #corpus_chunk_size other sentences. Decrease, to lower memory footprint (increases run-time).
    :param max_pairs: (forwarded) Maximal number of text pairs returned.
    :param top_k: (forwarded) For each sentence, we retrieve up to top_k other sentences
    :param score_function: (forwarded) Function for computing scores. By default, cosine similarity.
    :return: Returns a list of triplets with the format [score, id1, id2]
    """
    encode_options = dict(show_progress_bar=show_progress_bar,
                          batch_size=batch_size,
                          convert_to_tensor=True)
    sentence_embeddings = model.encode(sentences, **encode_options)
    return paraphrase_mining_embeddings(sentence_embeddings, *args, **kwargs)
def paraphrase_mining_embeddings(embeddings: Tensor,
                                 query_chunk_size: int = 5000,
                                 corpus_chunk_size: int = 100000,
                                 max_pairs: int = 500000,
                                 top_k: int = 100,
                                 score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim):
    """
    Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all
    other sentences and returns a list with the pairs that have the highest cosine similarity score.

    Both orientations (i, j) and (j, i) can enter the candidate heap, but only one entry per
    unordered pair is returned. Because both orientations may occupy heap slots, fewer than
    max_pairs unique pairs can be returned.

    :param embeddings: A tensor with the embeddings
    :param query_chunk_size: Search for most similar pairs for #query_chunk_size at the same time. Decrease, to lower memory footprint (increases run-time).
    :param corpus_chunk_size: Compare a sentence simultaneously against #corpus_chunk_size other sentences. Decrease, to lower memory footprint (increases run-time).
    :param max_pairs: Maximal number of text pairs returned.
    :param top_k: For each sentence, we retrieve up to top_k other sentences
    :param score_function: Function for computing scores. By default, cosine similarity.
    :return: Returns a list of triplets with the format [score, id1, id2], highest score first
    """
    top_k += 1  # A sentence has the highest similarity to itself. Increase +1 as we are interested in distinct pairs

    # Min-heap of (score, i, j). Once max_pairs entries have been added, every further
    # insertion evicts the weakest entry, whose score becomes the admission threshold.
    pairs = []
    # Start at -inf: score functions such as dot_score can yield values <= -1, and those
    # pairs must not be discarded while the heap still has room.
    min_score = float("-inf")
    num_added = 0

    for corpus_start_idx in range(0, len(embeddings), corpus_chunk_size):
        for query_start_idx in range(0, len(embeddings), query_chunk_size):
            scores = score_function(embeddings[query_start_idx:query_start_idx + query_chunk_size],
                                    embeddings[corpus_start_idx:corpus_start_idx + corpus_chunk_size])

            scores_top_k_values, scores_top_k_idx = torch.topk(scores, min(top_k, len(scores[0])), dim=1, largest=True, sorted=False)
            scores_top_k_values = scores_top_k_values.cpu().tolist()
            scores_top_k_idx = scores_top_k_idx.cpu().tolist()

            for query_itr in range(len(scores)):
                i = query_start_idx + query_itr
                for score, corpus_itr in zip(scores_top_k_values[query_itr], scores_top_k_idx[query_itr]):
                    j = corpus_start_idx + corpus_itr
                    if i != j and score > min_score:
                        heapq.heappush(pairs, (score, i, j))
                        num_added += 1

                        if num_added >= max_pairs:
                            min_score = heapq.heappop(pairs)[0]

    # Drain the heap (ascending score) and keep one entry per unordered pair.
    added_pairs = set()  # Used for duplicate detection
    pairs_list = []
    while pairs:
        score, i, j = heapq.heappop(pairs)
        key = (min(i, j), max(i, j))
        if key not in added_pairs:
            added_pairs.add(key)
            pairs_list.append([score, i, j])

    # Highest scores first
    pairs_list.sort(key=lambda x: x[0], reverse=True)
    return pairs_list
def information_retrieval(*args, **kwargs):
    """Deprecated alias of semantic_search; call semantic_search directly instead."""
    # Every positional and keyword argument is forwarded unchanged.
    results = semantic_search(*args, **kwargs)
    return results
def semantic_search(query_embeddings: Tensor,
                    corpus_embeddings: Tensor,
                    query_chunk_size: int = 100,
                    corpus_chunk_size: int = 500000,
                    top_k: int = 10,
                    score_function: Callable[[Tensor, Tensor], Tensor] = cos_sim):
    """
    This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings.
    It can be used for Information Retrieval / Semantic Search for corpora up to about 1 Million entries.

    :param query_embeddings: A 2 dimensional tensor with the query embeddings. A numpy array is converted with
        torch.from_numpy, a list of tensors is stacked, and a single 1-D embedding is treated as one query.
    :param corpus_embeddings: A 2 dimensional tensor with the corpus embeddings. The same conversions as for the
        queries apply; a single 1-D embedding is treated as a corpus with one entry.
    :param query_chunk_size: Process 100 queries simultaneously. Increasing that value increases the speed, but requires more memory.
    :param corpus_chunk_size: Scans the corpus 500k entries at a time by default. Increasing that value increases the speed, but requires more memory.
    :param top_k: Retrieve top k matching entries.
    :param score_function: Function for computing scores. By default, cosine similarity.
    :return: Returns a list with one entry for each query. Each entry is a list of dictionaries with the keys 'corpus_id' and 'score', sorted by decreasing cosine similarity scores.
    """
    if isinstance(query_embeddings, (np.ndarray, np.generic)):
        query_embeddings = torch.from_numpy(query_embeddings)
    elif isinstance(query_embeddings, list):
        query_embeddings = torch.stack(query_embeddings)

    if len(query_embeddings.shape) == 1:
        query_embeddings = query_embeddings.unsqueeze(0)

    if isinstance(corpus_embeddings, (np.ndarray, np.generic)):
        corpus_embeddings = torch.from_numpy(corpus_embeddings)
    elif isinstance(corpus_embeddings, list):
        corpus_embeddings = torch.stack(corpus_embeddings)

    # Mirror the query handling: without this, len() and the chunk slicing below
    # would run over the feature axis of a single 1-D corpus embedding.
    if len(corpus_embeddings.shape) == 1:
        corpus_embeddings = corpus_embeddings.unsqueeze(0)

    # Check that corpus and queries are on the same device
    if corpus_embeddings.device != query_embeddings.device:
        query_embeddings = query_embeddings.to(corpus_embeddings.device)

    # One min-heap of (score, corpus_id) per query; it keeps the top_k largest entries seen so far.
    queries_result_list = [[] for _ in range(len(query_embeddings))]

    for query_start_idx in range(0, len(query_embeddings), query_chunk_size):
        # Iterate over chunks of the corpus
        for corpus_start_idx in range(0, len(corpus_embeddings), corpus_chunk_size):
            # Compute cosine similarities
            cos_scores = score_function(query_embeddings[query_start_idx:query_start_idx + query_chunk_size],
                                        corpus_embeddings[corpus_start_idx:corpus_start_idx + corpus_chunk_size])

            # Get top-k scores
            cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(cos_scores, min(top_k, len(cos_scores[0])), dim=1, largest=True, sorted=False)
            cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
            cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()

            for query_itr in range(len(cos_scores)):
                query_id = query_start_idx + query_itr
                hits_heap = queries_result_list[query_id]
                for sub_corpus_id, score in zip(cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]):
                    corpus_id = corpus_start_idx + sub_corpus_id
                    if len(hits_heap) < top_k:
                        heapq.heappush(hits_heap, (score, corpus_id))  # heapq orders by the first tuple element (score)
                    else:
                        heapq.heappushpop(hits_heap, (score, corpus_id))

    # Change the data format and sort by decreasing score
    for query_id, hits_heap in enumerate(queries_result_list):
        hits = [{'corpus_id': corpus_id, 'score': score} for score, corpus_id in hits_heap]
        queries_result_list[query_id] = sorted(hits, key=lambda x: x['score'], reverse=True)

    return queries_result_list
def http_get(url, path):
    """
    Downloads a URL to a given path on disc.

    The body is streamed into ``path + "_part"`` and moved to ``path`` only after
    the transfer has finished, so ``path`` is not written while the download is
    still in progress.

    A non-200 response is reported on stderr. For 4xx/5xx statuses
    ``raise_for_status`` raises; any other non-200 status returns without
    downloading anything.

    :param url: URL to download
    :param path: Destination file path; its parent directory is created if needed
    """
    if os.path.dirname(path) != '':
        os.makedirs(os.path.dirname(path), exist_ok=True)

    # Use the response as a context manager so the streamed connection is released
    # on every exit path (early return, exception, success).
    with requests.get(url, stream=True) as req:
        if req.status_code != 200:
            print("Exception when trying to download {}. Response {}".format(url, req.status_code), file=sys.stderr)
            req.raise_for_status()
            return

        download_filepath = path + "_part"
        content_length = req.headers.get('Content-Length')
        total = int(content_length) if content_length is not None else None
        # The progress bar is used as a context manager so it is closed even if the transfer fails.
        with open(download_filepath, "wb") as file_binary, tqdm(unit="B", total=total, unit_scale=True) as progress:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    file_binary.write(chunk)

    # os.replace overwrites an existing destination on every platform; os.rename fails on Windows if it exists.
    os.replace(download_filepath, path)
def batch_to_device(batch, target_device: device):
    """
    Move every tensor value of ``batch`` to ``target_device`` (CPU/GPU).

    Values that are not tensors are left untouched. ``batch`` is updated in
    place and also returned.
    """
    tensor_keys = [key for key in batch if isinstance(batch[key], Tensor)]
    for key in tensor_keys:
        batch[key] = batch[key].to(target_device)
    return batch
def fullname(o):
    """
    Return the qualified name ``module.ClassName`` of the class of ``o``.

    For objects whose class lives in the same module as ``str``'s type (the
    builtins), only the bare class name is returned. The result is meant for
    locating the class again when loading from JSON files.
    """
    cls = o.__class__
    module = cls.__module__
    if module is not None and module != str.__class__.__module__:
        return module + '.' + cls.__name__
    # Avoid reporting __builtin__ / builtins
    return cls.__name__
def import_from_string(dotted_path):
    """
    Import a dotted module path and return the attribute/class designated by the
    last name in the path.

    ``dotted_path`` is first tried as a module in its own right. If that raises
    ImportError (the usual case when the last component is a class), the part
    before the last dot is imported instead. The last component is then looked up
    as an attribute of the imported module.

    :param dotted_path: e.g. ``"package.module.ClassName"``
    :return: The attribute named by the last path component
    :raises ImportError: if the path contains no dot, the module cannot be imported,
        or the module has no such attribute
    """
    try:
        module_path, class_name = dotted_path.rsplit('.', 1)
    except ValueError as err:
        msg = "%s doesn't look like a module path" % dotted_path
        raise ImportError(msg) from err

    try:
        module = importlib.import_module(dotted_path)
    except ImportError:
        # Only fall back on import failures; other errors (and KeyboardInterrupt)
        # must not be swallowed by a bare except.
        module = importlib.import_module(module_path)

    try:
        return getattr(module, class_name)
    except AttributeError as err:
        msg = 'Module "%s" does not define a "%s" attribute/class' % (module_path, class_name)
        raise ImportError(msg) from err
def community_detection(embeddings, threshold=0.75, min_community_size=10, batch_size=1024):
    """
    Function for Fast Community Detection

    Finds in the embeddings all communities, i.e. groups of embeddings whose cosine
    similarity to a seed embedding is >= threshold. Only communities with at least
    min_community_size members are returned, and overlapping communities are
    resolved so that every index appears in at most one community.

    :param embeddings: Tensor (or value accepted by ``torch.tensor``) of embeddings
    :param threshold: Minimum cosine similarity for two embeddings to be considered close
    :param min_community_size: Minimum number of members a returned community must have
        (capped at the number of embeddings)
    :param batch_size: Number of rows whose similarities are computed at once
    :return: List of communities (lists of embedding indices), largest community first.
        The indices within each community are in ascending order.
    """
    if not isinstance(embeddings, torch.Tensor):
        embeddings = torch.tensor(embeddings)

    threshold = torch.tensor(threshold, device=embeddings.device)

    extracted_communities = []

    # Minimum size for a community (cannot exceed the number of embeddings)
    min_community_size = min(min_community_size, len(embeddings))
    # Initial number of neighbours inspected per seed; widened on demand below
    sort_max_size = min(max(2 * min_community_size, 50), len(embeddings))

    for start_idx in range(0, len(embeddings), batch_size):
        # Compute cosine similarity scores of this batch against all embeddings
        cos_scores = cos_sim(embeddings[start_idx:start_idx + batch_size], embeddings)

        # The min_community_size-th largest score of a row decides whether that row
        # has enough neighbours above the threshold to seed a community.
        top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True)

        # Filter for rows >= threshold
        for i in range(len(top_k_values)):
            if top_k_values[i][-1] >= threshold:
                # Only check top sort_max_size most similar entries
                top_val_large, top_idx_large = cos_scores[i].topk(k=sort_max_size, largest=True)

                # If even the smallest retrieved score is above the threshold the community may be
                # larger: double the window until it is not, or until all embeddings are covered.
                while top_val_large[-1] > threshold and sort_max_size < len(embeddings):
                    sort_max_size = min(2 * sort_max_size, len(embeddings))
                    top_val_large, top_idx_large = cos_scores[i].topk(k=sort_max_size, largest=True)

                # topk returns scores in descending order, so the mask selects exactly the leading
                # run with score >= threshold, without a per-element tensor comparison in Python.
                extracted_communities.append(top_idx_large[top_val_large >= threshold].tolist())

        del cos_scores

    # Largest cluster first
    extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True)

    # Step 2) Remove overlapping communities: larger communities claim their members first.
    unique_communities = []
    extracted_ids = set()

    for community in extracted_communities:
        non_overlapped_community = [idx for idx in sorted(community) if idx not in extracted_ids]

        if len(non_overlapped_community) >= min_community_size:
            unique_communities.append(non_overlapped_community)
            extracted_ids.update(non_overlapped_community)

    unique_communities = sorted(unique_communities, key=lambda x: len(x), reverse=True)

    return unique_communities
##################
# Hugging Face Hub download helpers
######################
def snapshot_download(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    ignore_files: Optional[List[str]] = None,
    use_auth_token: Union[bool, str, None] = None
) -> str:
    """
    Method derived from huggingface_hub.
    Adds a new parameter 'ignore_files', which allows ignoring certain files / file patterns.

    Every file of the repository (except those matching a pattern in ``ignore_files``,
    checked with ``fnmatch``) is downloaded via ``cached_download`` into
    ``<cache_dir>/<repo_id with "/" replaced by "_">``. ``modules.json`` is downloaded last.

    :param repo_id: Repository id on the Hugging Face Hub
    :param revision: Revision passed to ``HfApi.model_info``; files are fetched at the resolved commit sha
    :param cache_dir: Base directory; defaults to ``HUGGINGFACE_HUB_CACHE``
    :param library_name: Forwarded to ``cached_download``
    :param library_version: Forwarded to ``cached_download``
    :param user_agent: Forwarded to ``cached_download``
    :param ignore_files: fnmatch-style patterns of repository files to skip
    :param use_auth_token: A token string, or True to use the token stored by ``HfFolder``
    :return: Path of the local storage folder
    """
    if cache_dir is None:
        cache_dir = HUGGINGFACE_HUB_CACHE

    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    _api = HfApi()

    token = None
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token:
        token = HfFolder.get_token()

    model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token)

    storage_folder = os.path.join(cache_dir, repo_id.replace("/", "_"))

    # Download modules.json as the last file. The stable sort keeps the order of all
    # other files and builds a new list instead of mutating model_info.siblings.
    all_files = sorted(model_info.siblings, key=lambda repofile: repofile.rfilename == "modules.json")

    # huggingface_hub v0.8.1 introduces a new cache layout. We still use a manual layout
    # and need to pass legacy_cache_layout=True to avoid that a warning will be printed.
    # The installed version does not change during the loop, so check it once.
    use_legacy_cache_layout = version.parse(huggingface_hub.__version__) >= version.parse("0.8.1")

    for model_file in all_files:
        if ignore_files is not None and any(fnmatch.fnmatch(model_file.rfilename, pattern) for pattern in ignore_files):
            continue

        url = hf_hub_url(repo_id, filename=model_file.rfilename, revision=model_info.sha)
        relative_filepath = os.path.join(*model_file.rfilename.split("/"))

        # Create potential nested dir
        nested_dirname = os.path.dirname(os.path.join(storage_folder, relative_filepath))
        os.makedirs(nested_dirname, exist_ok=True)

        cached_download_args = {'url': url,
                                'cache_dir': storage_folder,
                                'force_filename': relative_filepath,
                                'library_name': library_name,
                                'library_version': library_version,
                                'user_agent': user_agent,
                                'use_auth_token': use_auth_token}

        if use_legacy_cache_layout:
            cached_download_args['legacy_cache_layout'] = True

        path = cached_download(**cached_download_args)

        # Remove a leftover lock file; EAFP avoids the race between an exists() check and remove().
        try:
            os.remove(path + ".lock")
        except FileNotFoundError:
            pass

    return storage_folder