Chore/opti eval #38

Merged: 21 commits from chore/opti-eval into eval_script_optimization, Sep 19, 2023
Conversation

Ben-Epstein (Contributor)
No description provided.

@@ -104,9 +104,11 @@ def main() -> None:
args = parse_args()
SELECTED_TORCH_DTYPE: Final[torch.dtype] = torch.float16 if args.torch_dtype == "float16" else torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path)
Ben-Epstein (Contributor, Author):
Like you did in the eval_rag script, can you just take the tokenizer from the retriever_model? Move this to after

retriever_model = AutoModelForSentenceEmbedding(
        args.retriever_model_name_or_path, tokenizer, get_peft=False, use_bnb=False
    )

then do

tokenizer = retriever_model.retriever_tokenizer

shamanez (Member):
Let's just use the End2End model's wrapper class. We could also do it as you mentioned, but the wrapper class does everything.

metric-space (Contributor), Sep 18, 2023:

I don't understand: there is no retriever_tokenizer attribute on retriever_model (unless there is some meta-level extraction I can't see in the model code), and the tokenizer would be the very same one I pass in, since initialization is done outside. End2End does the initializing, but AutoModelForSentenceEmbedding does not: https://github.com/arcee-ai/DALM/blob/main/dalm/models/retriever_only_base_model.py#L12

shamanez (Member):

Yeah, sorry, it should be just the tokenizer. You're right. Feel free to change it.
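
To summarize where the thread lands, a minimal sketch (the constructor call is taken from the diff above; the point is that AutoModelForSentenceEmbedding receives the tokenizer rather than creating one, so there is no retriever_tokenizer attribute to read back):

    from transformers import AutoTokenizer

    # Build the tokenizer once from the retriever checkpoint...
    tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path)

    # ...and pass it into the wrapper; the externally created tokenizer
    # stays the single source of truth.
    retriever_model = AutoModelForSentenceEmbedding(
        args.retriever_model_name_or_path, tokenizer, get_peft=False, use_bnb=False
    )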

rag_model.attach_pre_trained_peft_layers(
args.retriever_peft_model_path, args.generator_peft_model_path, args.device
)
# rag_model.attach_pre_trained_peft_layers(
Ben-Epstein (Contributor, Author):

why is this commented out?

shamanez (Member):

To test with, in case you can't find the PEFT layers.

metric-space (Contributor):

Left in during testing, my bad

Comment on lines 123 to 124
# TODO: ask if this is a mistake
# retriever_tokenizer = retriever_model.retriever_tokenizer
Ben-Epstein (Contributor, Author):

Ah, this is what I suggested above: https://github.com/arcee-ai/DALM/pull/38/files#r1328573360

Maybe I'm missing something; why is this a mistake?

shamanez (Member):

No, it is not. The retriever class initializes both the model and the tokenizer.

Ben-Epstein (Contributor, Author):

Right, but you can still use this as the tokenizer, like I said in https://github.com/arcee-ai/DALM/pull/38/files#r1328573360, right?

Wouldn't this tokenizer be the exact same one as
tokenizer = AutoTokenizer.from_pretrained(args.retriever_model_name_or_path)?

shamanez (Member):

Yes, we can.
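
In other words, the commented-out line can simply reuse the tokenizer built earlier. A one-line sketch (illustrative; the actual change lands in the PR's commits):

    # Creating it again via AutoTokenizer.from_pretrained(args.retriever_model_name_or_path)
    # would yield an identical tokenizer, so just reuse the existing one.
    retriever_tokenizer = tokenizer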

dalm/eval/eval_rag.py (outdated; resolved)

@@ -295,22 +301,29 @@ def get_passage_embeddings(

# this query comes without the answer
query = f"#query# {test_example[args.query_column_name]} #passage# {search_result_passage} #answer# "
queries_for_gen_eval.append(query)
shamanez (Member):

Will this try to send all the prompts through the generator at once? If so, it will easily run out of memory. We need to create batches.

shamanez (Member):

Can we also evaluate the retriever as a batch rather than a single query at a time?

Ben-Epstein (Contributor, Author), Sep 18, 2023:

> Will this try to send all the prompts through the generator at once? If so, it will easily run out of memory. We need to create batches.

Fixed in bcfb79c.

> Can we also evaluate the retriever as a batch rather than a single query at a time?

I'm not sure how to do that currently. Given the way get_nearest_neighbours is written, searching the hnsw index only takes one query at a time.
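
For illustration only (the actual fix is in bcfb79c): a hedged sketch of what batching the prompts through the generator typically looks like. The batch size, the generator_tokenizer/generator_model names, and the generate() settings are assumptions, not the PR's code:

    import torch

    batch_size = 8  # illustrative value
    answers: list[str] = []
    for i in range(0, len(queries_for_gen_eval), batch_size):
        # Tokenize one fixed-size chunk of prompts so only a single batch of
        # activations is resident on the GPU at a time.
        batch_prompts = queries_for_gen_eval[i : i + batch_size]
        inputs = generator_tokenizer(
            batch_prompts, return_tensors="pt", padding=True, truncation=True
        ).to(args.device)
        with torch.no_grad():
            outputs = generator_model.generate(**inputs, max_new_tokens=64)
        answers.extend(generator_tokenizer.batch_decode(outputs, skip_special_tokens=True))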

metric-space (Contributor):

I think I've gotten somewhere with this

shamanez (Member):

Yeah, at the moment we only take the top-1 passage and check the exact match (EM) of the answer. But it could be more as well.
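
On the retriever side, hnswlib's knn_query does accept a 2-D array of queries, so batched retrieval is possible in one call. A sketch assuming the index is an hnswlib.Index and query_embeddings is a (num_queries, dim) numpy array (names are illustrative, not the PR's code):

    # One call returns (num_queries, k) arrays of passage ids and distances,
    # replacing the current one-query-at-a-time loop in get_nearest_neighbours.
    labels, distances = hnsw_index.knn_query(query_embeddings, k=1)
    top_passage_ids = labels[:, 0]  # top-1 passage per query, as evaluated today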

dalm/eval/eval_retriever_only.py (outdated; resolved)
@@ -219,6 +219,25 @@ def get_passage_embeddings(

print("Evaluation start")

# ruff:noqa
def my_collate_fn(batch: List[Dict[str, torch.Tensor | str]]) -> Dict[str, torch.Tensor | List[str]]:
shamanez (Member):

Beautiful. If we use the default collate function, we will lose all the text outputs needed to compute the precision and the accuracy. Let's add this to the e2e eval as well.
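
A minimal sketch of such a collate function, matching the signature in the diff above (the body is an illustration, not necessarily the PR's exact implementation):

    from typing import Dict, List

    import torch

    def my_collate_fn(batch: List[Dict[str, torch.Tensor | str]]) -> Dict[str, torch.Tensor | List[str]]:
        # Stack tensor fields as the default collate would, but keep string
        # fields (e.g. passages and answers) as plain lists so the text needed
        # for precision/accuracy computation survives batching.
        collated: Dict[str, torch.Tensor | List[str]] = {}
        for key in batch[0]:
            values = [example[key] for example in batch]
            if isinstance(values[0], torch.Tensor):
                collated[key] = torch.stack(values)
            else:
                collated[key] = values
        return collated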

shamanez (Member) left a comment:

Looks good.

Ben-Epstein merged commit 8d3602c into eval_script_optimization on Sep 19, 2023.
Ben-Epstein deleted the chore/opti-eval branch on Sep 19, 2023 at 20:21.