forked from cs230-stanford/cs230-code-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTriplet_Miner.py
47 lines (40 loc) · 1.55 KB
/
Triplet_Miner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
import random
def triplet_loss(A, P, N, margin=0.0):
    """Return the signed triplet gap ``||A - P|| - ||A - N|| + margin``.

    Distances come from ``np.linalg.norm`` with default arguments, i.e. the
    2-norm of the flattened difference (Frobenius norm for 2-D inputs), so
    batched inputs are NOT compared row by row.

    The result is not clamped at zero. The mining functions in this module
    treat a positive value as "the triplet violates the margin": the
    negative ``N`` is not far enough from the anchor ``A`` relative to the
    positive ``P``.

    Args:
        A: anchor embedding.
        P: positive embedding.
        N: negative embedding.
        margin: constant added to the distance gap.

    Returns:
        The (possibly negative) scalar gap.
    """
    gap = np.linalg.norm(A - P) - np.linalg.norm(A - N)
    return gap + margin
def mine_triplets_all(embedding_tuples, margin=0.0):
    """Collect every margin-violating triplet, in both directions of each pair.

    ``embedding_tuples`` is a sequence of pairs whose elements ``[0]`` and
    ``[1]`` can be passed to ``triplet_loss`` (presumably NumPy vectors of the
    same shape -- confirm with callers; the list names suggest element 0 is a
    caption embedding and element 1 a clip embedding).

    For every pair ``i`` and every *other* pair ``j``:

    * ``(x_i, y_i, x_j)`` goes into the first list when
      ``triplet_loss(x_i, y_i, x_j, margin) > 0``;
    * ``(y_i, x_i, y_j)`` goes into the second list under the same test.

    A pair is never used as its own negative: that would make the negative
    identical to the anchor, which is a meaningless triplet.

    Args:
        embedding_tuples: sequence of ``(x, y)`` embedding pairs.
        margin: forwarded to ``triplet_loss``. The default 0.0 matches the
            previous behaviour.

    Returns:
        ``(triplets_caption, triplets_clips)``: two lists of
        ``(anchor, positive, negative)`` tuples. The first is anchored on
        element 0 of each pair, the second on element 1.
    """
    triplets_caption = []
    triplets_clips = []
    for i, pair in enumerate(embedding_tuples):
        x_i, y_i = pair[0], pair[1]
        for j, other in enumerate(embedding_tuples):
            if j == i:
                continue  # skip self: negative would equal the anchor
            # Anchor on element 0; the negative comes from element 0 of pair j.
            if triplet_loss(x_i, y_i, other[0], margin) > 0:
                triplets_caption.append((x_i, y_i, other[0]))
            # Mirrored direction: anchor on element 1, negative from element 1.
            # Roles are fixed per direction (the old in-loop swap flipped them
            # on every inner iteration).
            if triplet_loss(y_i, x_i, other[1], margin) > 0:
                triplets_clips.append((y_i, x_i, other[1]))
    return triplets_caption, triplets_clips
def _random_violating_index(embedding_tuples, skip, anchor, positive, slot, margin):
    """Return a uniformly random index ``j != skip`` whose ``[slot]`` element is
    a margin-violating negative for ``(anchor, positive)``, or None if none is."""
    m = len(embedding_tuples)
    # Taking the first hit while scanning a uniform random permutation is
    # equivalent to choosing uniformly among the violating candidates. Unlike
    # unbounded rejection sampling, it always terminates.
    for j in random.sample(range(m), m):
        if j == skip:
            continue  # the pair itself is never a valid negative
        if triplet_loss(anchor, positive, embedding_tuples[j][slot], margin) > 0:
            return j
    return None


def mine_triplets_random(embedding_tuples, margin=0.0):
    """Pick one random margin-violating negative per pair, in both directions.

    ``embedding_tuples`` is a sequence of pairs whose elements ``[0]`` and
    ``[1]`` can be passed to ``triplet_loss`` (presumably NumPy vectors of the
    same shape -- confirm with callers).

    For each pair ``i`` with elements ``(x_i, y_i)``:

    * the first list receives ``(x_i, y_i, x_j)``, where ``j`` is chosen
      uniformly among the other pairs with
      ``triplet_loss(x_i, y_i, x_j, margin) > 0``;
    * the second list receives ``(y_i, x_i, y_k)``, where ``k`` is drawn
      independently in the same way using element 1.

    If no other pair violates the margin for a direction, that direction
    contributes no triplet for pair ``i``. The previous implementation looped
    forever in that case.

    Args:
        embedding_tuples: sequence of ``(x, y)`` embedding pairs.
        margin: forwarded to ``triplet_loss``. The default 0.0 matches the
            previous behaviour.

    Returns:
        ``(triplets_caption, triplets_clips)``: two lists of
        ``(anchor, positive, negative)`` tuples, matching the return shape of
        ``mine_triplets_all``.
    """
    triplets_caption = []
    triplets_clips = []
    for i, pair in enumerate(embedding_tuples):
        x_i, y_i = pair[0], pair[1]

        j = _random_violating_index(embedding_tuples, i, x_i, y_i, 0, margin)
        if j is not None:
            triplets_caption.append((x_i, y_i, embedding_tuples[j][0]))

        # Independent draw for the mirrored direction, not seeded by j.
        k = _random_violating_index(embedding_tuples, i, y_i, x_i, 1, margin)
        if k is not None:
            triplets_clips.append((y_i, x_i, embedding_tuples[k][1]))
    return triplets_caption, triplets_clips