-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate_pred.py
24 lines (18 loc) · 1002 Bytes
/
evaluate_pred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import argparse
import json
from ref_l4 import RefL4Dataset, RefL4Evaluator
def main(args):
    """Load the Ref-L4 dataset, read predictions from JSON, and evaluate them.

    Args:
        args: Parsed command-line namespace; this function reads its
            ``dataset_path``, ``split`` and ``pred_json_path`` attributes.

    Returns:
        Whatever ``RefL4Evaluator.evaluate`` returns. The original script
        discarded this value; it is now propagated so programmatic callers
        can use it. The script entry point ignores it, so CLI behavior is
        unchanged.
    """
    # Passed explicitly as None: no custom transforms are supplied to the dataset.
    custom_transforms = None
    ref_l4_dataset = RefL4Dataset(
        args.dataset_path, split=args.split, custom_transforms=custom_transforms
    )
    print("Dataset loaded. Length:", len(ref_l4_dataset))
    evaluator = RefL4Evaluator(dataset=ref_l4_dataset)
    # JSON text is UTF-8 (RFC 8259). Pin the encoding instead of relying on the
    # locale-dependent default of open(), which differs across platforms
    # (e.g. cp1252 on many Windows setups) and can garble non-ASCII text.
    with open(args.pred_json_path, "r", encoding="utf-8") as f:
        predictions = json.load(f)
    # The file is already closed here; evaluation does not need it open.
    return evaluator.evaluate(predictions)
if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, and hand the
    # resulting namespace to main().
    cli_parser = argparse.ArgumentParser(
        description='Evaluate the predictions on Ref-L4 dataset.'
    )
    # Where to load the dataset from (defaults to "JierunChen/Ref-L4").
    cli_parser.add_argument(
        '--dataset_path',
        type=str,
        default="JierunChen/Ref-L4",
        help='Path to the Ref-L4 dataset.',
    )
    # Which split to evaluate on; restricted to the three listed choices.
    cli_parser.add_argument(
        '--split',
        type=str,
        default='all',
        choices=['val', 'test', 'all'],
        help='Dataset split to use (val, test, all).',
    )
    # Mandatory: the JSON file holding the model's predictions.
    cli_parser.add_argument(
        '--pred_json_path',
        type=str,
        required=True,
        help='Path to the predictions JSON file.',
    )
    args = cli_parser.parse_args()
    main(args)