'''
This example shows how to train a Binary-Code based Bi-Encoder (Binary Passage Retriever) for the MS MARCO dataset (https://github.com/microsoft/MSMARCO-Passage-Ranking).
The model is trained using hard negatives that were specifically mined with different dense and lexical search methods for MSMARCO.
The idea for the Binary Passage Retriever originates from Yamada et al., 2021: "Efficient Passage Retrieval with Hashing for Open-domain Question Answering".
For more details, please refer to: https://arxiv.org/abs/2106.00882
This example is adapted, with a few modifications for training SBERT (MSMARCO-v3) models, from:
(https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/ms_marco/train_bi-encoder-v3.py)
Queries and passages are passed independently through the transformer network to produce fixed-size binary codes (hashes).
These codes can then be compared using the Hamming distance to find matching passages for a given query
(a small illustrative helper for this comparison follows the imports below).
For training, we use BPRLoss (MarginRankingLoss + MultipleNegativesRankingLoss), to which we pass triplets in the format:
(query, positive_passage, negative_passage)
Negative passages are hard negative examples that were mined using different dense embedding and lexical search methods.
Each positive and negative passage comes with a score from a Cross-Encoder. This allows denoising, i.e. removing false negative
passages that are actually relevant for the query.
Running this script:
python train_msmarco_v3_bpr.py
'''
from sentence_transformers import SentenceTransformer, models, InputExample
from beir import util, LoggingHandler
from beir.losses import BPRLoss
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.train import TrainRetriever
from torch.utils.data import Dataset
from tqdm.autonotebook import tqdm
import pathlib, os, gzip, json
import logging
import random
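#### Illustrative sketch (an assumption about typical BPR usage, not the training code below):
#### BPR binarizes the continuous embeddings (e.g. via their sign) and ranks candidate passages
#### by Hamming distance to the query code. `query_code` and `passage_codes` are hypothetical
#### names; the codes are assumed to hold values in {-1, +1}.
import torch

def hamming_distance(query_code: torch.Tensor, passage_codes: torch.Tensor) -> torch.Tensor:
    # query_code: shape (dim,); passage_codes: shape (num_passages, dim)
    # Counts differing positions per passage; a smaller distance means a better match.
    return (query_code != passage_codes).sum(dim=-1)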
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout
#### Download and unzip the msmarco.zip dataset
dataset = "msmarco"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)
#### Load the BEIR MSMARCO training dataset; it provides the training queries and the passage corpus.
corpus, queries, _ = GenericDataLoader(data_path).load(split="train")
#################################
#### Parameters for Training ####
#################################
train_batch_size = 75           # Increasing the train batch size improves the model performance but requires more GPU memory (O(n))
max_seq_length = 350            # Max length for passages. Increasing it requires more GPU memory (O(n^2))
ce_score_margin = 3 # Margin for the CrossEncoder score between negative and positive passages
num_negs_per_system = 5 # We used different systems to mine hard negatives. Number of hard negatives to add from each system
##################################################
#### Download MSMARCO Hard Negs Triplets File ####
##################################################
triplets_url = "https://sbert.net/datasets/msmarco-hard-negatives.jsonl.gz"
msmarco_triplets_filepath = os.path.join(data_path, "msmarco-hard-negatives.jsonl.gz")
if not os.path.isfile(msmarco_triplets_filepath):
    util.download_url(triplets_url, msmarco_triplets_filepath)
#### Load the hard negative MSMARCO jsonl triplets from SBERT
#### These contain a ce-score, which denotes the cross-encoder score for the query and passage.
#### We keep a negative as a hard negative only if its ce-score is at least ce_score_margin below the lowest positive ce-score for that query.
#### Finally, to limit the number of negatives per query, we take at most num_negs_per_system hard negatives from each mining system.
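#### For example (hypothetical numbers): if the lowest positive ce-score for a query is 8.2 and
#### ce_score_margin is 3, only negatives with a ce-score <= 5.2 are kept as hard negatives.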
logging.info("Loading MSMARCO hard-negatives...")
train_queries = {}
with gzip.open(msmarco_triplets_filepath, 'rt', encoding='utf8') as fIn:
    for line in tqdm(fIn, total=502939):
        data = json.loads(line)

        # Get the positive passage ids
        pos_pids = [item['pid'] for item in data['pos']]
        pos_min_ce_score = min([item['ce-score'] for item in data['pos']])
        ce_score_threshold = pos_min_ce_score - ce_score_margin

        # Get the hard negatives
        neg_pids = set()
        for system_negs in data['neg'].values():
            negs_added = 0
            for item in system_negs:
                if item['ce-score'] > ce_score_threshold:
                    continue

                pid = item['pid']
                if pid not in neg_pids:
                    neg_pids.add(pid)
                    negs_added += 1
                    if negs_added >= num_negs_per_system:
                        break

        if len(pos_pids) > 0 and len(neg_pids) > 0:
            train_queries[data['qid']] = {'query': queries[data['qid']], 'pos': pos_pids, 'hard_neg': list(neg_pids)}
logging.info("Train queries: {}".format(len(train_queries)))
# We create a custom MSMARCO dataset that returns triplets (query, positive, negative)
# on-the-fly based on the information from the mined-hard-negatives jsonl file.
class MSMARCODataset(Dataset):
    def __init__(self, queries, corpus):
        self.queries = queries
        self.queries_ids = list(queries.keys())
        self.corpus = corpus

        for qid in self.queries:
            self.queries[qid]['pos'] = list(self.queries[qid]['pos'])
            self.queries[qid]['hard_neg'] = list(self.queries[qid]['hard_neg'])
            random.shuffle(self.queries[qid]['hard_neg'])

    def __getitem__(self, item):
        query = self.queries[self.queries_ids[item]]
        query_text = query['query']

        pos_id = query['pos'].pop(0)    # Pop positive and add at end
        pos_text = self.corpus[pos_id]["text"]
        query['pos'].append(pos_id)

        neg_id = query['hard_neg'].pop(0)    # Pop negative and add at end
        neg_text = self.corpus[neg_id]["text"]
        query['hard_neg'].append(neg_id)

        return InputExample(texts=[query_text, pos_text, neg_text])

    def __len__(self):
        return len(self.queries)
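# Note: indexing the dataset (e.g. train_dataset[0], once it is created below) yields an
# InputExample with texts=[query, positive_passage, hard_negative_passage]. Because positives
# and hard negatives are popped from the front and re-appended, repeated passes over the same
# query can pair it with different passages when more than one is available.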
# We construct the SentenceTransformer bi-encoder from scratch with CLS token Pooling
model_name = "distilbert-base-uncased"
word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_cls_token=True,
                               pooling_mode_mean_tokens=False)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
#### Provide a high batch-size to train better with triplets!
retriever = TrainRetriever(model=model, batch_size=train_batch_size)
# To train the SentenceTransformer model, we need a dataset, a dataloader, and a loss.
train_dataset = MSMARCODataset(train_queries, corpus=corpus)
train_dataloader = retriever.prepare_train(train_dataset, shuffle=True, dataset_present=True)
#### Training SBERT with BPRLoss (MarginRankingLoss + MultipleNegativesRankingLoss)
train_loss = BPRLoss(model=retriever.model)
#### If no dev set is present from above, use a dummy evaluator
ir_evaluator = retriever.load_dummy_evaluator()
#### Provide model save path
model_save_path = os.path.join(pathlib.Path(__file__).parent.absolute(), "output", "{}-v3-{}".format(model_name, dataset))
os.makedirs(model_save_path, exist_ok=True)
#### Configure Train params
num_epochs = 10
evaluation_steps = 10000
warmup_steps = 1000
retriever.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=ir_evaluator,
              epochs=num_epochs,
              output_path=model_save_path,
              warmup_steps=warmup_steps,
              evaluation_steps=evaluation_steps,
              use_amp=True)
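#### (Optional) Illustrative sketch, an assumption about how the trained model could be used
#### rather than part of the training pipeline above: the saved bi-encoder can be reloaded and
#### its continuous embeddings binarized with a simple sign threshold before Hamming-distance
#### retrieval. `encode_binary` is a hypothetical helper name.
def encode_binary(model_path, texts):
    # Load the trained SBERT model and binarize its embeddings: dimension >= 0 -> 1, else 0
    trained_model = SentenceTransformer(model_path)
    embeddings = trained_model.encode(texts, convert_to_numpy=True)
    return (embeddings >= 0).astype("uint8")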