This repository has been archived by the owner on Mar 19, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathclose_related_methods.py
69 lines (60 loc) · 3.11 KB
/
close_related_methods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
def get_closest_token_cosine_similarities(single_weight, all_weights, return_scores=False):
    """Rank the rows of ``all_weights`` by cosine similarity to ``single_weight``.

    Args:
        single_weight: Query vector; it is moved to the device of
            ``all_weights`` and given a leading axis before comparison.
        all_weights: Tensor whose rows (compared along ``dim=1``) are the
            candidates.
        return_scores: When true, also return the similarity values.

    Returns:
        A list of row indices ordered from most to least similar, or a
        ``(indices, scores)`` tuple of lists when ``return_scores`` is true.
    """
    query = single_weight.unsqueeze(0).to(all_weights.device)
    similarities = torch.nn.functional.cosine_similarity(
        all_weights, query, dim=1, eps=1e-6
    )
    # Highest similarity first.
    ranked = torch.sort(similarities, descending=True)
    ranked_ids = ranked.indices.tolist()
    if return_scores:
        return ranked_ids, ranked.values.tolist()
    return ranked_ids
def get_closest_token_euclidean(single_weight, all_weights):
    """Rank the rows of ``all_weights`` by Euclidean (L2) distance to ``single_weight``.

    Args:
        single_weight: Query vector; moved to the device of ``all_weights``.
        all_weights: Tensor whose rows (reduced along ``dim=1``) are the
            candidates.

    Returns:
        A list of row indices ordered from nearest to farthest.
    """
    query = single_weight.to(all_weights.device)
    offsets = all_weights - query[None]
    l2_distances = offsets.norm(p=2, dim=1)
    # Ascending sort: smallest distance first.
    return l2_distances.sort().indices.tolist()
def get_closest_token_manhattan(single_weight, all_weights):
    """Rank the rows of ``all_weights`` by Manhattan (L1) distance to ``single_weight``.

    Args:
        single_weight: Query vector; moved to the device of ``all_weights``.
        all_weights: Tensor whose rows (reduced along ``dim=1``) are the
            candidates.

    Returns:
        A list of row indices ordered from nearest to farthest.
    """
    query = single_weight.to(all_weights.device)
    offsets = all_weights - query[None]
    l1_distances = offsets.norm(p=1, dim=1)
    # Ascending sort: smallest distance first.
    return l1_distances.sort().indices.tolist()
def get_closest_token_jaccard_similarity(single_weight, all_weights, return_scores=False):
    """Rank the rows of ``all_weights`` by Jaccard similarity to ``single_weight``.

    Both inputs are binarised first: an entry counts as "set" when it is
    strictly greater than zero. The similarity of two binary vectors is the
    size of their intersection divided by the size of their union.

    Args:
        single_weight: Query vector; moved to the device of ``all_weights``.
        all_weights: Tensor whose rows (reduced along ``dim=1``) are the
            candidates.
        return_scores: When true, also return the similarity values.

    Returns:
        A list of row indices ordered from most to least similar, or a
        ``(indices, scores)`` tuple of lists when ``return_scores`` is true.
    """
    query_mask = (single_weight.to(all_weights.device) > 0).float()
    row_masks = (all_weights > 0).float()

    # On 0/1 masks the elementwise min is the intersection and the
    # elementwise max is the union.
    overlap = torch.min(query_mask, row_masks).sum(dim=1)
    combined = torch.max(query_mask, row_masks).sum(dim=1)
    # The small epsilon keeps the ratio finite when the union is empty.
    similarity = overlap / (combined + 1e-6)

    ranked = torch.sort(similarity, descending=True)
    ranked_ids = ranked.indices.tolist()
    if return_scores:
        return ranked_ids, ranked.values.tolist()
    return ranked_ids
def get_closest_token_mahalanobis(single_weight, all_weights):
    """Rank the rows of ``all_weights`` by Mahalanobis distance to ``single_weight``.

    The distance between the query ``q`` and a row ``x`` is
    ``sqrt((x - q)^T S^+ (x - q))``, where ``S`` is the sample covariance of
    the rows of ``all_weights`` and ``S^+`` is its pseudo-inverse. Unlike the
    plain Euclidean distance, differences along high-variance directions of
    the data are down-weighted.

    Args:
        single_weight: Query vector; moved to the device of ``all_weights``.
        all_weights: 2-D tensor whose rows are the candidates.

    Returns:
        A list of row indices ordered from nearest to farthest.
    """
    # Only the ranking is returned, so no autograd graph is needed.
    rows = all_weights.detach()
    query = single_weight.detach().to(rows.device)

    # The pseudo-inverse needs float32 or float64; promote anything else
    # (e.g. half precision) to float32.
    if rows.dtype not in (torch.float32, torch.float64):
        rows = rows.to(torch.float32)
    query = query.to(rows.dtype)

    deltas = rows - query.unsqueeze(0)
    num_rows = rows.shape[0]
    if num_rows < 2:
        # A sample covariance is undefined for fewer than two rows; fall
        # back to the Euclidean distance so the call still succeeds.
        distances = torch.norm(deltas, dim=1, p=2)
    else:
        centered = rows - rows.mean(dim=0, keepdim=True)
        covariance = centered.T @ centered / (num_rows - 1)
        # Pseudo-inverse tolerates a singular covariance (for example, when
        # there are fewer rows than feature dimensions).
        precision = torch.linalg.pinv(covariance, hermitian=True)
        squared = ((deltas @ precision) * deltas).sum(dim=1)
        # Clamp tiny negative values caused by floating-point round-off.
        distances = torch.sqrt(torch.clamp(squared, min=0.0))

    sorted_distances, sorted_ids = torch.sort(distances)
    return sorted_ids.tolist()
def get_closest_token_hamming(single_weight, all_weights):
    """Rank the rows of ``all_weights`` by Hamming distance to ``single_weight``.

    Both inputs are binarised first: an entry counts as "set" when it is
    strictly greater than zero. The distance is the number of positions at
    which the two binary vectors disagree.

    Args:
        single_weight: Query vector; moved to the device of ``all_weights``.
        all_weights: Tensor whose rows (reduced along ``dim=1``) are the
            candidates.

    Returns:
        A list of row indices ordered from nearest to farthest.
    """
    query_bits = single_weight.to(all_weights.device) > 0
    row_bits = all_weights > 0
    # Count the positions where the sign patterns disagree.
    mismatches = (query_bits != row_bits).sum(dim=1)
    return mismatches.sort().indices.tolist()
# Dispatch table: metric name -> function that ranks rows by that metric.
correlation_functions = dict(
    [
        ("cosine_similarities", get_closest_token_cosine_similarities),
        ("euclidean", get_closest_token_euclidean),
        ("manhattan", get_closest_token_manhattan),
        ("jaccard", get_closest_token_jaccard_similarity),
        ("mahalanobis", get_closest_token_mahalanobis),
        ("hamming", get_closest_token_hamming),
    ]
)