Add SentenceTransformer to NLQ #26

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion NLQ/VSLNet/main.py
@@ -132,7 +132,7 @@ def main(configs, parser):
e_labels.to(device),
h_labels.to(device),
)
- if configs.predictor == "bert":
+ if configs.predictor in ("bert", "st"):
word_ids = {key: val.to(device) for key, val in word_ids.items()}
# generate mask
query_mask = (
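For context, a minimal standalone sketch of what the widened condition handles: for the `bert` and `st` predictors the query batch is a dict of padded tensors that is moved to the device entry by entry, while the GloVe path keeps a single `LongTensor`. The helper name and tensors below are illustrative only, not part of this diff.

```python
import torch

def move_query_to_device(word_ids, device, predictor):
    # Hypothetical helper mirroring the branch in main.py.
    if predictor in ("bert", "st"):
        # Tokenizer-style batch: dict of padded tensors (input_ids, attention_mask, ...).
        return {key: val.to(device) for key, val in word_ids.items()}
    # GloVe path: a single LongTensor of word indices.
    return word_ids.to(device)

# Usage sketch with a fake batch of two queries.
word_ids = {
    "input_ids": torch.ones(2, 8, dtype=torch.long),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}
word_ids = move_query_to_device(word_ids, torch.device("cpu"), "st")
```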
15 changes: 14 additions & 1 deletion NLQ/VSLNet/model/VSLNet.py
@@ -13,6 +13,7 @@
ConditionedPredictor,
HighLightLayer,
BertEmbedding,
+ STEmbedding,
)


@@ -87,6 +88,10 @@ def __init__(self, configs, word_vectors):
# init parameters
self.init_parameters()
self.embedding_net = BertEmbedding(configs.text_agnostic)
+ elif configs.predictor == "st":
+     self.query_affine = nn.Linear(768, configs.dim)
+     self.init_parameters()
+     self.embedding_net = STEmbedding(configs.text_agnostic)
else:
self.embedding_net = Embedding(
num_words=configs.word_size,
@@ -120,10 +125,18 @@ def forward(self, word_ids, char_ids, video_features, v_mask, q_mask):
if self.configs.predictor == "bert":
query_features = self.embedding_net(word_ids)
query_features = self.query_affine(query_features)
+ elif self.configs.predictor == "st":
+     query_features = self.embedding_net(word_ids, q_mask)
+     query_features = self.query_affine(query_features)
else:
query_features = self.embedding_net(word_ids, char_ids)

- query_features = self.feature_encoder(query_features, mask=q_mask)
+ if len(query_features.shape) == 2:
+     query_features.unsqueeze_(1)
+     q_mask = torch.ones((query_features.shape[0], 1), device=q_mask.device)
+ else:
+     query_features = self.feature_encoder(query_features, mask=q_mask)

video_features = self.feature_encoder(video_features, mask=v_mask)
features = self.cq_attention(video_features, query_features, v_mask, q_mask)
features = self.cq_concat(features, query_features, q_mask)
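A standalone shape sketch of the new `st` query path (made-up batch size and `dim`, mirroring `configs.dim`): the SentenceTransformer returns one pooled 768-dim vector per query, so after the affine projection the usual `(batch, seq_len, dim)` pipeline collapses to a single-step query sequence with an all-ones mask, and the query-side feature encoder is skipped.

```python
import torch
import torch.nn as nn

batch, dim = 4, 128                            # illustrative sizes
query_affine = nn.Linear(768, dim)             # same role as self.query_affine

sentence_emb = torch.randn(batch, 768)         # STEmbedding output: (batch, 768)
query_features = query_affine(sentence_emb)    # (batch, dim)
if query_features.dim() == 2:
    query_features = query_features.unsqueeze(1)       # (batch, 1, dim)
    q_mask = torch.ones(query_features.shape[0], 1)    # one valid "token" per query
print(query_features.shape, q_mask.shape)  # torch.Size([4, 1, 128]) torch.Size([4, 1])
# Downstream, the CQ attention block then attends over a length-1 query sequence.
```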
19 changes: 18 additions & 1 deletion NLQ/VSLNet/model/layers.py
@@ -145,6 +145,23 @@ def forward(self, word_ids, char_ids):
return emb


+ class STEmbedding(nn.Module):
+     def __init__(self, text_agnostic=False):
+         super().__init__()
+         from sentence_transformers import SentenceTransformer
+
+         assert not text_agnostic
+         self.embedder = SentenceTransformer("all-mpnet-base-v2")
+         # Freeze the model.
+         for param in self.embedder.parameters():
+             param.requires_grad = False
+
+     def forward(self, word_ids, q_mask):
+         word_ids["attention_mask"] = q_mask
+         outputs = self.embedder(word_ids)
+         return outputs["sentence_embedding"].detach()


class BertEmbedding(nn.Module):
def __init__(self, text_agnostic=False):
super(BertEmbedding, self).__init__()
@@ -334,7 +351,7 @@ def forward(self, x, mask=None):
features = x + self.pos_embedding(x) # (batch_size, seq_len, dim)
features = self.conv_block(features) # (batch_size, seq_len, dim)
features = self.attention_block(
- features, mask=mask
+ features.squeeze(1), mask=mask
) # (batch_size, seq_len, dim)
return features

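A hedged usage sketch of the new `STEmbedding` path (assumes `sentence-transformers` is installed and the `all-mpnet-base-v2` weights can be downloaded; the queries and variable names are illustrative): the model's own tokenizer builds the dict-style batch, and calling the frozen SentenceTransformer as a module returns pooled 768-dim vectors under the `sentence_embedding` key.

```python
import torch
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("all-mpnet-base-v2")
tokenizer = embedder.tokenizer  # the same tokenizer data_gen.py now hands out

queries = ["where did I put the drill?", "what did I pour into the bowl?"]
features = tokenizer(queries, padding=True, return_tensors="pt")

with torch.no_grad():
    out = embedder(dict(features))            # forward through the frozen modules
sentence_emb = out["sentence_embedding"]      # (batch, 768) pooled sentence vectors
print(sentence_emb.shape)
```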
13 changes: 10 additions & 3 deletions NLQ/VSLNet/utils/data_gen.py
@@ -7,6 +7,7 @@
import numpy as np
from nltk.tokenize import word_tokenize
from tqdm import tqdm
+ from sentence_transformers import SentenceTransformer

from utils.data_util import (
load_json,
@@ -50,7 +51,7 @@ def process_data_tan(self, data, scope):
for timestamp, exact_time, sentence, ann_uid, query_idx in zipper:
start_time = max(0.0, float(timestamp[0]) / fps)
end_time = min(float(timestamp[1]) / fps, duration)
- if self._predictor != "bert":
+ if self._predictor not in ("bert", "st"):
words = word_tokenize(sentence.strip().lower(), language="english")
else:
words = sentence
@@ -308,6 +309,7 @@ def gen_or_load_dataset(configs):
)
+ ".pkl",
)
+ # NOTE dataset cache
if os.path.exists(save_path):
dataset = load_pickle(save_path)
return dataset
@@ -329,10 +331,14 @@
if val_data is None
else [train_data, val_data, test_data]
)
- if configs.predictor == "bert":
+ if configs.predictor in ("bert", "st"):
from transformers import BertTokenizer, BertForPreTraining

- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ if configs.predictor == "bert":
+     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+ else:
+     # tokenizer = lambda x: x  # just keep the text as is
+     tokenizer = SentenceTransformer("all-mpnet-base-v2").tokenizer
train_set = dataset_gen_bert(
train_data,
vfeat_lens,
@@ -416,5 +422,6 @@ def gen_or_load_dataset(configs):
"n_words": len(word_dict),
"n_chars": len(char_dict),
}
print("Saving")
save_pickle(dataset, save_path)
return dataset
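A small sketch of the tokenizer switch (needs network access for the downloads; the sentence is made up): the BERT path keeps `BertTokenizer`, while the `st` path reuses the tokenizer bundled with `all-mpnet-base-v2`. The MPNet tokenizer normally emits only `input_ids` and `attention_mask`, with no `token_type_ids`, which is why the collate functions below treat that key as optional.

```python
from transformers import BertTokenizer
from sentence_transformers import SentenceTransformer

bert_tok = BertTokenizer.from_pretrained("bert-base-uncased")
st_tok = SentenceTransformer("all-mpnet-base-v2").tokenizer

sentence = "Where did I last place the screwdriver?"
print(bert_tok(sentence).keys())  # expect: input_ids, token_type_ids, attention_mask
print(st_tok(sentence).keys())    # expect: input_ids, attention_mask
```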
16 changes: 10 additions & 6 deletions NLQ/VSLNet/utils/data_loader.py
@@ -29,12 +29,14 @@ def train_collate_fn(data):
if not isinstance(word_ids[0], list):
pad_input_ids, _ = pad_seq([ii["input_ids"] for ii in word_ids])
pad_attention_mask, _ = pad_seq([ii["attention_mask"] for ii in word_ids])
- pad_token_type_ids, _ = pad_seq([ii["token_type_ids"] for ii in word_ids])
- word_ids = {
+ new_word_ids = {
"input_ids": torch.LongTensor(pad_input_ids),
"attention_mask": torch.LongTensor(pad_attention_mask),
- "token_type_ids": torch.LongTensor(pad_token_type_ids),
}
+ if "token_type_ids" in word_ids[0]:
+     pad_token_type_ids, _ = pad_seq([ii["token_type_ids"] for ii in word_ids])
+     new_word_ids["token_type_ids"] = torch.LongTensor(pad_token_type_ids)
+ word_ids = new_word_ids
char_ids = None
else:
# process word ids
@@ -83,12 +85,14 @@ def test_collate_fn(data):
if not isinstance(word_ids[0], list):
pad_input_ids, _ = pad_seq([ii["input_ids"] for ii in word_ids])
pad_attention_mask, _ = pad_seq([ii["attention_mask"] for ii in word_ids])
- pad_token_type_ids, _ = pad_seq([ii["token_type_ids"] for ii in word_ids])
- word_ids = {
+ new_word_ids = {
"input_ids": torch.LongTensor(pad_input_ids),
"attention_mask": torch.LongTensor(pad_attention_mask),
- "token_type_ids": torch.LongTensor(pad_token_type_ids),
}
+ if "token_type_ids" in word_ids[0]:
+     pad_token_type_ids, _ = pad_seq([ii["token_type_ids"] for ii in word_ids])
+     new_word_ids["token_type_ids"] = torch.LongTensor(pad_token_type_ids)
+ word_ids = new_word_ids
char_ids = None
else:
# process word ids
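A self-contained sketch of the collate change (`pad_seq` here is a simplified stand-in for the repo's utility; the mini-batches are fabricated): `input_ids` and `attention_mask` are always padded, while `token_type_ids` is padded only when the tokenizer produced it, so the same collate works for both BERT and SentenceTransformer batches.

```python
import torch

def pad_seq(sequences, pad_value=0):
    # Simplified stand-in for the repo's pad_seq.
    max_len = max(len(seq) for seq in sequences)
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in sequences]
    return padded, [len(seq) for seq in sequences]

def collate_word_ids(word_ids):
    pad_input_ids, _ = pad_seq([ii["input_ids"] for ii in word_ids])
    pad_attention_mask, _ = pad_seq([ii["attention_mask"] for ii in word_ids])
    new_word_ids = {
        "input_ids": torch.LongTensor(pad_input_ids),
        "attention_mask": torch.LongTensor(pad_attention_mask),
    }
    # BERT batches carry token_type_ids; the MPNet/SentenceTransformer
    # tokenizer does not, so pad them only when present.
    if "token_type_ids" in word_ids[0]:
        pad_token_type_ids, _ = pad_seq([ii["token_type_ids"] for ii in word_ids])
        new_word_ids["token_type_ids"] = torch.LongTensor(pad_token_type_ids)
    return new_word_ids

bert_batch = [
    {"input_ids": [101, 2054, 102], "attention_mask": [1, 1, 1], "token_type_ids": [0, 0, 0]},
    {"input_ids": [101, 102], "attention_mask": [1, 1], "token_type_ids": [0, 0]},
]
st_batch = [
    {"input_ids": [0, 2054, 2], "attention_mask": [1, 1, 1]},
    {"input_ids": [0, 2], "attention_mask": [1, 1]},
]
print(collate_word_ids(bert_batch).keys())  # includes token_type_ids
print(collate_word_ids(st_batch).keys())    # no token_type_ids
```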