Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added extra experiments - mainly around macro chunking #16

Merged
merged 32 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
7ce00ae
initial experiment setup
dannyjameswilliams Sep 26, 2024
ffed829
changed len tokens
dannyjameswilliams Sep 26, 2024
b467575
soft boundary correctly overlaps for later batches
dannyjameswilliams Sep 26, 2024
d0ef88a
removed incorrect soft
dannyjameswilliams Sep 26, 2024
7985c66
added loop over overlap sizes
dannyjameswilliams Sep 26, 2024
56a6a41
added results
dannyjameswilliams Sep 26, 2024
f3eb17d
experiment for WikimQA
dannyjameswilliams Sep 26, 2024
50b2b00
added narrativeQA task
dannyjameswilliams Sep 26, 2024
3f226a5
added remainder of longembed datasets
dannyjameswilliams Sep 26, 2024
19ab752
typo
dannyjameswilliams Sep 26, 2024
57cc8ba
more tasks for soft/hard
dannyjameswilliams Sep 26, 2024
9043a3f
experiments
dannyjameswilliams Sep 26, 2024
e1d66e1
fix merge
dannyjameswilliams Sep 26, 2024
0a4a27e
added macro chunk experiment file
dannyjameswilliams Sep 26, 2024
043eec6
for merge
dannyjameswilliams Sep 26, 2024
f960b15
merge def values
dannyjameswilliams Sep 26, 2024
8e0c80d
chunk size results
dannyjameswilliams Sep 26, 2024
753b603
added benchmark files for macro chunks
dannyjameswilliams Sep 26, 2024
6b52c8c
removed raw results
dannyjameswilliams Sep 26, 2024
9fc7d52
added plotting files for results, requires running them first
dannyjameswilliams Sep 26, 2024
7cff570
renamed file
dannyjameswilliams Sep 26, 2024
b831989
renamed file
dannyjameswilliams Sep 26, 2024
1eccfe5
added plt.show()
dannyjameswilliams Sep 26, 2024
4219ca6
renamed to macro chunking
dannyjameswilliams Sep 26, 2024
ad9f37c
moved file
dannyjameswilliams Sep 26, 2024
ada2a91
Update chunked_pooling/chunking.py according to comment
dannyjameswilliams Oct 1, 2024
d4f99ce
updated main experiment file with long late chunking
dannyjameswilliams Oct 2, 2024
936ae51
remove redundant macro chunking file
dannyjameswilliams Oct 2, 2024
1ed7fb6
updated default to truncation (8192)
dannyjameswilliams Oct 2, 2024
abab7fa
updated error message/print statement
dannyjameswilliams Oct 2, 2024
b46d469
changed how local llm is loaded
dannyjameswilliams Oct 2, 2024
b64c2a6
removed comment on pip and update default model to phi
dannyjameswilliams Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 161 additions & 1 deletion chunked_pooling/chunked_eval_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class LEMBWikimQARetrievalChunked(AbsTaskChunkedRetrieval):
name="LEMBWikimQARetrievalChunked",
dataset={
"path": "dwzhu/LongEmbed",
"revision": "6e346642246bfb4928c560ee08640dc84d074e8c",
"revision": "10039a580487dacecf79db69166e17ace3ede392",
"name": "LEMBWikimQARetrieval",
},
reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
Expand Down Expand Up @@ -297,6 +297,166 @@ def load_data(self, **kwargs):
self.data_loaded = True


class LEMBSummScreenFDRetrievalChunked(AbsTaskChunkedRetrieval):
    """Chunked retrieval task over the ``summ_screen_fd`` subset of dwzhu/LongEmbed.

    modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py
    """

    # Key under which corpus, queries and qrels are stored after loading.
    _EVAL_SPLIT = "test"

    metadata = TaskMetadata(
        name="LEMBSummScreenFDRetrievalChunked",
        dataset={
            "path": "dwzhu/LongEmbed",
            "revision": "10039a580487dacecf79db69166e17ace3ede392",
            # Overridden with 'summ_screen_fd' in load_data() before loading.
            "name": "LEMBSummScreenFDRetrieval",
        },
        reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
        description=("summ_screen_fd subset of dwzhu/LongEmbed dataset."),
        type="Retrieval",
        category="s2p",
        modalities=["text"],
        eval_splits=[_EVAL_SPLIT],
        eval_langs=["eng-Latn"],
        main_score="ndcg_at_10",
        date=("1950-01-01", "2019-12-31"),
        domains=None,
        socioeconomic_status=None,
        n_samples=None,
        avg_character_length=None,
        form=None,
        text_creation=None,
        task_subtypes=["Article retrieval"],
        license="not specified",
        annotations_creators="derived",
        dialect=[],
        sample_creation="found",
        # NOTE(review): the title of this entry describes a multi-hop QA
        # dataset, while this task is the summ_screen_fd subset -- confirm
        # that this is the intended citation.
        bibtex_citation="""
@inproceedings{ho2020constructing,
title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps},
author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko},
booktitle={Proceedings of the 28th International Conference on Computational Linguistics},
pages={6609--6625},
year={2020}
}
""",
        descriptive_stats={
            # NOTE(review): n_samples (500) differs from num_queries and
            # num_documents (300) below -- confirm which figure is correct.
            "n_samples": {_EVAL_SPLIT: 500},
            "avg_character_length": {
                "test": {
                    "average_document_length": 30854.327,
                    "average_query_length": 591.49,
                    "num_documents": 300,
                    "num_queries": 300,
                    "average_relevant_docs_per_query": 1.0,
                }
            },
        },
    )

    def load_data(self, **kwargs):
        """Populate ``corpus``, ``queries`` and ``relevant_docs`` for the eval split.

        Does nothing when ``self.data_loaded`` is already true. ``kwargs`` is
        accepted but not used by this method.
        """
        if self.data_loaded:
            return

        # The config name passed to load_dataset is 'summ_screen_fd', not the
        # task-style name stored in the metadata.
        dataset_dict = {**self.metadata.dataset, "name": "summ_screen_fd"}

        # Load once and index the three splits, instead of calling
        # load_dataset separately for each split.
        dataset = datasets.load_dataset(**dataset_dict)

        queries = {row["qid"]: row["text"] for row in dataset["queries"]}
        corpus = {
            row["doc_id"]: {"text": row["text"]} for row in dataset["corpus"]
        }
        # One relevant document is kept per qid; a later row with the same
        # qid replaces the earlier one.
        qrels = {row["qid"]: {row["doc_id"]: 1} for row in dataset["qrels"]}

        self.corpus = {self._EVAL_SPLIT: corpus}
        self.queries = {self._EVAL_SPLIT: queries}
        self.relevant_docs = {self._EVAL_SPLIT: qrels}

        self.data_loaded = True


class LEMBQMSumRetrievalChunked(AbsTaskChunkedRetrieval):
    """Chunked retrieval task over the ``qmsum`` subset of dwzhu/LongEmbed.

    modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBWikimQARetrieval.py
    """

    # Key under which corpus, queries and qrels are stored after loading.
    _EVAL_SPLIT = "test"

    metadata = TaskMetadata(
        name="LEMBQMSumRetrievalChunked",
        dataset={
            "path": "dwzhu/LongEmbed",
            "revision": "10039a580487dacecf79db69166e17ace3ede392",
            # Overridden with 'qmsum' in load_data() before loading.
            "name": "LEMBQMSumRetrieval",
        },
        reference="https://huggingface.co/datasets/dwzhu/LongEmbed",
        description=("qmsum subset of dwzhu/LongEmbed dataset."),
        type="Retrieval",
        category="s2p",
        modalities=["text"],
        eval_splits=[_EVAL_SPLIT],
        eval_langs=["eng-Latn"],
        main_score="ndcg_at_10",
        date=("1950-01-01", "2019-12-31"),
        domains=None,
        socioeconomic_status=None,
        n_samples=None,
        avg_character_length=None,
        form=None,
        text_creation=None,
        task_subtypes=["Article retrieval"],
        license="not specified",
        annotations_creators="derived",
        dialect=[],
        sample_creation="found",
        # NOTE(review): the title of this entry describes a multi-hop QA
        # dataset, while this task is the qmsum subset -- confirm that this
        # is the intended citation.
        bibtex_citation="""
@inproceedings{ho2020constructing,
title={Constructing A Multi-hop QA Dataset for Comprehensive Evaluation of Reasoning Steps},
author={Ho, Xanh and Nguyen, Anh-Khoa Duong and Sugawara, Saku and Aizawa, Akiko},
booktitle={Proceedings of the 28th International Conference on Computational Linguistics},
pages={6609--6625},
year={2020}
}
""",
        descriptive_stats={
            # NOTE(review): n_samples (500) differs from num_queries and
            # num_documents (300) below -- confirm which figure is correct.
            "n_samples": {_EVAL_SPLIT: 500},
            "avg_character_length": {
                "test": {
                    "average_document_length": 53335.817,
                    "average_query_length": 433.50,
                    "num_documents": 300,
                    "num_queries": 300,
                    "average_relevant_docs_per_query": 1.0,
                }
            },
        },
    )

    def load_data(self, **kwargs):
        """Populate ``corpus``, ``queries`` and ``relevant_docs`` for the eval split.

        Does nothing when ``self.data_loaded`` is already true. ``kwargs`` is
        accepted but not used by this method.
        """
        if self.data_loaded:
            return

        # The config name passed to load_dataset is 'qmsum', not the
        # task-style name stored in the metadata.
        dataset_dict = {**self.metadata.dataset, "name": "qmsum"}

        # Load once and index the three splits, instead of calling
        # load_dataset separately for each split.
        dataset = datasets.load_dataset(**dataset_dict)

        queries = {row["qid"]: row["text"] for row in dataset["queries"]}
        corpus = {
            row["doc_id"]: {"text": row["text"]} for row in dataset["corpus"]
        }
        # One relevant document is kept per qid; a later row with the same
        # qid replaces the earlier one.
        qrels = {row["qid"]: {row["doc_id"]: 1} for row in dataset["qrels"]}

        self.corpus = {self._EVAL_SPLIT: corpus}
        self.queries = {self._EVAL_SPLIT: queries}
        self.relevant_docs = {self._EVAL_SPLIT: qrels}

        self.data_loaded = True


class LEMBNeedleRetrievalChunked(AbsTaskChunkedRetrieval):
"""
modified from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/tasks/Retrieval/eng/LEMBNeedleRetrieval.py
Expand Down
4 changes: 2 additions & 2 deletions chunked_pooling/chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def chunk(
tokenizer=tokenizer,
)
elif chunking_strategy == "fixed":
if chunk_size < 10:
raise ValueError("Chunk size must be greater than 10.")
if chunk_size < 4:
raise ValueError("Chunk size must be >= 4.")
return self.chunk_by_tokens(text, chunk_size, tokenizer)
elif chunking_strategy == "sentences":
return self.chunk_by_sentences(text, n_sentences, tokenizer)
Expand Down
52 changes: 46 additions & 6 deletions chunked_pooling/mteb_chunked_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def __init__(
model_has_instructions: bool = False,
embedding_model_name: Optional[str] = None, # for semantic chunking
truncate_max_length: Optional[int] = 8192,
long_late_chunking_embed_size: Optional[int] = 0,
long_late_chunking_overlap_size: Optional[int] = 512,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -51,6 +53,9 @@ def __init__(
}
self.truncate_max_length = truncate_max_length

self.long_late_chunking_embed_size = long_late_chunking_embed_size
self.long_late_chunking_overlap_size = long_late_chunking_overlap_size

def load_data(self, **kwargs):
self.retrieval_task.load_data(**kwargs)
self.corpus = self.retrieval_task.corpus
Expand Down Expand Up @@ -114,6 +119,34 @@ def _truncate_documents(self, corpus):
v['text'] = v['text'][: last_token_span[1]]
return corpus

def _embed_with_overlap(self, model, model_inputs):
    """Run ``model`` over overlapping windows of the input and stitch the outputs.

    The sequence is cut along dimension 1 into windows of
    ``self.long_late_chunking_embed_size`` positions. Each window after the
    first starts ``self.long_late_chunking_overlap_size`` positions before
    the previous window ended; those leading positions are dropped from its
    output, so the concatenated result has one entry per input position.

    Args:
        model: Callable invoked as ``model(**inputs)``; element ``[0]`` of
            its return value is used and sliced along dimension 1
            (presumably the token-level hidden states -- confirm). Must
            expose a ``device`` attribute.
        model_inputs: Mapping of input name to a tensor indexable as
            ``v[:, start:end]``. Must contain ``"input_ids"``.

    Returns:
        The per-window outputs concatenated along dimension 1 and moved to
        ``model.device``.

    Raises:
        ValueError: If the input is longer than the embed size and the
            overlap is negative or not smaller than the embed size, since
            the windows could then never advance.
    """
    window = self.long_late_chunking_embed_size
    overlap = self.long_late_chunking_overlap_size
    len_tokens = len(model_inputs["input_ids"][0])

    if len_tokens <= window:
        # Everything fits in a single forward pass.
        spans = [(0, len_tokens)]
    else:
        if not 0 <= overlap < window:
            raise ValueError(
                "long_late_chunking_overlap_size must be >= 0 and smaller "
                "than long_late_chunking_embed_size "
                f"(got overlap={overlap}, embed_size={window})."
            )
        spans = []
        for start in range(0, len_tokens, window - overlap):
            end = min(start + window, len_tokens)
            spans.append((start, end))
            if end == len_tokens:
                # Any later window would only cover positions that are
                # already embedded; its output would be sliced away
                # entirely, so skip the wasted forward passes.
                break

    pieces = []
    for start, end in spans:
        window_inputs = {k: v[:, start:end] for k, v in model_inputs.items()}

        with torch.no_grad():
            hidden = model(**window_inputs)[0]

        # The leading `overlap` positions of every window after the first
        # were already produced by the previous window.
        pieces.append(hidden[:, overlap:] if start > 0 else hidden)

    return torch.cat(pieces, dim=1).to(model.device)

def _evaluate_monolingual(
self,
model,
Expand Down Expand Up @@ -181,17 +214,24 @@ def _evaluate_monolingual(
text_inputs,
return_tensors='pt',
padding=True,
truncation=True,
max_length=8192,
truncation=self.truncate_max_length is not None,
max_length=self.truncate_max_length,
)
if model.device.type == 'cuda':
model_inputs = {
k: v.to(model.device) for k, v in model_inputs.items()
}
model_outputs = model(**model_inputs)
output_embs = chunked_pooling(
model_outputs, annotations, max_length=8192
)

if self.long_late_chunking_embed_size > 0:
model_outputs = self._embed_with_overlap(model, model_inputs)
output_embs = chunked_pooling(
[model_outputs], annotations, max_length=None
)
else: # truncation
model_outputs = model(**model_inputs)
output_embs = chunked_pooling(
model_outputs, annotations, max_length=self.truncate_max_length
)
corpus_embs.extend(output_embs)

max_chunks = max([len(x) for x in corpus_embs])
Expand Down
Loading