-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval_squad.py
34 lines (25 loc) · 1.02 KB
/
eval_squad.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import paddle
from utils.args import parse_args
from utils.data import get_dev_dataloader
from train_squad import MODEL_CLASSES, evaluate
import os
def main(args):
    """Run `evaluate` on the dev dataloader with a saved checkpoint.

    Steps, in order:
      1. Select the paddle device named by ``args.device``.
      2. Look up ``args.model_type`` in ``MODEL_CLASSES`` to obtain the model
         class and tokenizer class; the third element of that entry is stored
         on ``args.need_token_type_ids``.
      3. Build the tokenizer: the Hugging Face ``MobileBertTokenizerFast``
         when ``args.use_huggingface_tokenizer`` is set and the model type is
         ``"mobilebert"``, otherwise the tokenizer class from step 2 loaded
         from ``args.model_name_or_path``.
      4. Build the dev dataloader, load the model from a fixed checkpoint
         directory chosen by ``args.version_2_with_negative``, and call
         ``evaluate`` with ``output_dir="./"``.

    Args:
        args: Parsed argument namespace. Mutated in place:
            ``need_token_type_ids`` is overwritten.
    """
    paddle.set_device(args.device)

    model_entry = MODEL_CLASSES[args.model_type]
    model_class, tokenizer_class, args.need_token_type_ids = model_entry

    wants_hf_mobilebert = (
        args.use_huggingface_tokenizer and args.model_type == "mobilebert"
    )
    if not wants_hf_mobilebert:
        print("paddle tokenizer.")
        tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    else:
        # Imported lazily so `transformers` is only required on this path.
        from transformers import MobileBertTokenizerFast

        tokenizer = MobileBertTokenizerFast.from_pretrained(
            "google/mobilebert-uncased"
        )

    dev_data_loader = get_dev_dataloader(tokenizer, args)

    # Checkpoint directories are fixed in this script, one per SQuAD variant.
    checkpoint_dir = (
        "./task/squadv2/step-10320"
        if args.version_2_with_negative
        else "./task/squadv1/step-6720"
    )
    model = model_class.from_pretrained(checkpoint_dir)

    evaluate(model, dev_data_loader, args, output_dir="./")
if __name__ == "__main__":
    # Script entry point: parse the command-line arguments and hand them
    # straight to main().
    main(parse_args())