Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
[feature] support PubMedQA (#86)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikecovlee authored Jul 25, 2024
1 parent 4ade71d commit 440d33b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
6 changes: 1 addition & 5 deletions evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import json

import fire
import torch

Expand Down Expand Up @@ -45,9 +43,7 @@ def main(
router_profile=router_profile,
)

output = mlora.evaluate(model, tokenizer, [evaluate_paramas], save_file=save_file)

print(json.dumps(output, indent=4))
mlora.evaluate(model, tokenizer, [evaluate_paramas], save_file=save_file)


if __name__ == "__main__":
Expand Down
40 changes: 40 additions & 0 deletions mlora/tasks/qa_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,45 @@ def loading_data(
return ret


class PubMedQA(QuestionAnswerTask):
    """PubMedQA question-answering task with ``yes``/``no``/``maybe`` labels.

    Training samples come from the ``pqa_artificial`` configuration of the
    ``qiaojin/PubMedQA`` dataset and evaluation samples from ``pqa_labeled``;
    in both cases the ``train`` split is the one that is read.
    """

    def __init__(self) -> None:
        super().__init__(["yes", "no", "maybe"])

    def loading_data(
        self, tokenizer: Tokenizer, is_train: bool = True
    ) -> List[DataClass]:
        """Load PubMedQA and encode every sample into a ``DataClass``.

        Args:
            tokenizer: Object whose ``encode(data=...)`` turns the prompt text
                into tokens.
            is_train: When true, the long answer and the final decision are
                appended to the prompt and ``labels_`` is ``None``. Otherwise
                the prompt ends at ``"Answer:"`` and ``labels_`` holds the
                single label id of the final decision.

        Returns:
            One ``DataClass`` per dataset row, in dataset order.

        Raises:
            ValueError: If a row's ``final_decision`` is not a known label.
        """
        subset = "pqa_artificial" if is_train else "pqa_labeled"
        data = hf_datasets.load_dataset("qiaojin/PubMedQA", subset)["train"]
        logging.info("Preparing data for PubMedQA")
        total = len(data)
        ret: List[DataClass] = []
        for idx, data_point in enumerate(data):
            # NOTE(review): the instruction says "yes or no" although "maybe"
            # is also a valid label for this task; wording kept unchanged so
            # existing results stay reproducible -- confirm this is intended.
            pieces = [
                "Instruction:\nPlease answer the following question with yes or no "
                + "based on your medical knowledge and the following context.\n"
                + f"Question:\n{data_point['question']}\nContext:\n"
            ]
            context = data_point["context"]
            pieces.extend(
                f"({label}) {text}\n"
                for label, text in zip(context["labels"], context["contexts"])
            )
            answer = data_point["final_decision"]
            # Explicit check instead of ``assert``: assertions are removed
            # under ``python -O`` and this validates external dataset content.
            if answer not in self.labels2id_:
                raise ValueError(
                    f"PubMedQA sample {idx} has unexpected final_decision: "
                    f"{answer!r}"
                )
            if is_train:
                pieces.append(f"Long Answer:\n{data_point['long_answer']}\n")
                pieces.append("Answer:")
                pieces.append(f" {answer}")
                labels = None
            else:
                pieces.append("Answer:")
                labels = [self.labels2id_[answer]]
            tokens = tokenizer.encode(data="".join(pieces))
            ret.append(DataClass(tokens_=tokens, labels_=labels))
            if idx % 10000 == 0:
                logging.info("Encode text data: %s/%s", idx, total)

        return ret


def update_task_dict(task_dict):
task_dict.update(
{
Expand All @@ -298,5 +337,6 @@ def update_task_dict(task_dict):
"hellaswag": HellaSwag(),
"winogrande": WinoGrande(),
"csqa": CommonSenseQA(),
"pubmedqa": PubMedQA(),
}
)

0 comments on commit 440d33b

Please sign in to comment.