Skip to content

Commit

Permalink
Improve fusion regression (#2627)
Browse files Browse the repository at this point in the history
This is okay to merge for now - I will circle back and add my patch on top.
  • Loading branch information
Stefan824 authored Jan 16, 2025
1 parent 97c71bd commit 4609d0b
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/main/python/run_fusion_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def construct_fusion_commands(yaml_data: dict) -> list:
return [
[
FUSE_COMMAND,
'-runs', ' '.join([run for run in yaml_data['runs']]),
'-runs', ' '.join(run['file'] for run in yaml_data['runs']),
'-output', method.get('output'),
'-method', method.get('name', 'average'),
'-k', str(method.get('k', 1000)),
Expand Down Expand Up @@ -166,6 +166,12 @@ def evaluate_and_verify(yaml_data: dict, dry_run: bool):
logger.error(f"Failed to load configuration file: {e}")
exit(1)

# Check existence of run files
for run in yaml_data['runs']:
if not os.path.exists(run['file']):
logger.error(f"Run file {run['file']} does not exist. Please run the dependent regressions first, recorded in the fusion yaml file.")
exit(1)

# Construct the fusion command
fusion_commands = construct_fusion_commands(yaml_data)

Expand All @@ -178,4 +184,4 @@ def evaluate_and_verify(yaml_data: dict, dry_run: bool):
# Evaluate and verify results
evaluate_and_verify(yaml_data, args.dry_run)

logger.info(f"Total execution time: {time.time() - start_time:.2f} seconds")
logger.info(f"Total execution time: {time.time() - start_time:.2f} seconds")
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
---
corpus: beir-v1.0.0-robust04
corpus_path: collections/beir-v1.0.0/corpus/robust04/

metrics:
- metric: nDCG@10
command: bin/trec_eval
params: -c -m ndcg_cut.10
separator: "\t"
parse_index: 2
metric_precision: 4
can_combine: false
- metric: R@100
command: bin/trec_eval
params: -c -m recall.100
separator: "\t"
parse_index: 2
metric_precision: 4
can_combine: false
- metric: R@1000
command: bin/trec_eval
params: -c -m recall.1000
separator: "\t"
parse_index: 2
metric_precision: 4
can_combine: false

topic_reader: TsvString
topics:
- name: "BEIR (v1.0.0): Robust04"
id: test
path: topics.beir-v1.0.0-robust04.test.tsv.gz
qrel: qrels.beir-v1.0.0-robust04.test.txt

# Run dependencies for fusion
runs:
- name: flat-bm25
dependency: beir-v1.0.0-robust04.flat.yaml
file: runs/run.beir-v1.0.0-robust04.flat.bm25.topics.beir-v1.0.0-robust04.test.txt
- name: bge-flat-onnx
dependency: beir-v1.0.0-robust04.bge-base-en-v1.5.flat.onnx.yaml
file: runs/run.beir-v1.0.0-robust04.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt

methods:
- name: rrf
k: 1000
depth: 1000
rrf_k: 60
output: runs/runs.fuse.rrf.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt
results:
nDCG@10:
- 0.5070
R@100:
- 0.4465
R@1000:
- 0.7219
- name: average
output: runs/runs.fuse.avg.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt
results:
nDCG@10:
- 0.4324
R@100:
- 0.3963
R@1000:
- 0.6345
- name: interpolation
alpha: 0.5
output: runs/runs.fuse.interp.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt
results:
nDCG@10:
- 0.4324
R@100:
- 0.3963
R@1000:
- 0.6345

0 comments on commit 4609d0b

Please sign in to comment.