From 1cb5ea13344196a6b7e14553df2db36993c0db2d Mon Sep 17 00:00:00 2001 From: Carlos Date: Fri, 26 Jan 2024 10:57:16 +0000 Subject: [PATCH 01/33] add fusion --- pom.xml | 4 + .../java/io/anserini/search/FuseRuns.java | 249 ++++++++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 src/main/java/io/anserini/search/FuseRuns.java diff --git a/pom.xml b/pom.xml index 34ca37707c..72fa25a2cb 100644 --- a/pom.xml +++ b/pom.xml @@ -127,6 +127,10 @@ io.anserini.search.SearchCollection SearchCollection + + io.anserini.search.FuseRuns + FuseRuns + io.anserini.search.SearchHnswDenseVectors SearchHnswDenseVectors diff --git a/src/main/java/io/anserini/search/FuseRuns.java b/src/main/java/io/anserini/search/FuseRuns.java new file mode 100644 index 0000000000..10a2ef85d4 --- /dev/null +++ b/src/main/java/io/anserini/search/FuseRuns.java @@ -0,0 +1,249 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.anserini.search; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.HashMap; +import java.util.TreeMap; +import java.util.Map; +import java.util.Collections; +import java.util.Set; +import java.util.HashSet; +import java.io.Closeable; +import java.io.IOException; +import java.io.FileNotFoundException; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Locale; + +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.ParserProperties; +import org.kohsuke.args4j.Option; + +class FusedRunOutputWriter implements Closeable { + private final PrintWriter out; + private final String format; + private final String runtag; + + public FusedRunOutputWriter(String output, String format, String runtag) throws IOException { + this.out = new PrintWriter(Files.newBufferedWriter(Paths.get(output), StandardCharsets.UTF_8)); + this.format = format; + this.runtag = runtag; + } + + private class Document implements Comparable{ + String docid; + double score; + + public Document(String docid, double score) + { + this.docid = docid; + this.score = score; + } + + @Override public int compareTo(Document a) + { + return Double.compare(a.score,this.score); + } + + } + + public void writeTopic(String qid, HashMap results) { + int rank = 1; + ArrayList documents = new ArrayList<>(); + + for (Map.Entry entry : results.entrySet()) { + documents.add(new Document(entry.getKey(), entry.getValue())); + } + Collections.sort(documents); + for (Document r : documents) { + if ("msmarco".equals(format)) { + // MS MARCO output format: + out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, r.docid, rank)); + } else { + // Standard TREC format: + // + the first column is the topic number. + // + the second column is currently unused and should always be "Q0". + // + the third column is the official document identifier of the retrieved document. + // + the fourth column is the rank the document is retrieved. + // + the fifth column shows the score (integer or floating point) that generated the ranking. + // + the sixth column is called the "run tag" and should be a unique identifier for your + out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", qid, r.docid, rank, r.score, runtag)); + } + rank++; + } + } + + @Override + public void close() { + out.flush(); + out.close(); + } +} + + +public class FuseRuns { + + public static class Args { + @Option(name = "-options", usage = "Print information about options.") + public Boolean options = false; + + @Option(name = "-filename_a", metaVar = "[filename_a]", required = true, usage = "Path to the first run to fuse") + public String filename_a; + + @Option(name = "-filename_b", metaVar = "[filename_b]", required = true, usage = "Path to the second run to fuse") + public String filename_b; + + @Option(name = "-filename_output", metaVar = "[filename_output]", required = true, usage = "Path to save the output") + public String filename_output; + + @Option(name = "-runtag", metaVar = "[runtag]", usage = "Run tag for the fusion") + public String runtag="fused"; + + } + + public static TreeMap> createRunMap(String filename){ + + TreeMap> twoTierHashMap = new TreeMap<>(); + try (BufferedReader br = new BufferedReader(new FileReader(filename))) { + String line; + while ((line = br.readLine()) != null) { + String[] data = line.split(" "); + HashMap innerHashMap = new HashMap(); + HashMap innerHashMap_ = twoTierHashMap.putIfAbsent(data[0], innerHashMap); + if (innerHashMap_ != null){ + innerHashMap = innerHashMap_; + } + innerHashMap.put(data[2], Double.valueOf(data[4])); + } + } catch (FileNotFoundException ex){ + System.out.println(ex); + } catch (IOException ex){ + System.out.println(ex); + } + return twoTierHashMap; + } + + public static void normalize_min_max(TreeMap> hashMap) { + for (String outerKey : hashMap.keySet()) { + Map innerHashMap = hashMap.get(outerKey); + Double min = Double.MAX_VALUE; + Double max = -1.0; + for (String innerKey : innerHashMap.keySet()) { + Double innerValue = innerHashMap.get(innerKey); + if (innerValue < min) { + min = innerValue; + } + if (innerValue > max) { + max = innerValue; + } + } + for (String innerKey : innerHashMap.keySet()) { + Double innerValue = innerHashMap.get(innerKey); + Double newValue = (innerValue - min) / (max-min); + innerHashMap.replace(innerKey,innerValue,newValue); + } + } + } + + public static HashMap aggregateQuery(HashMap hashMap1, HashMap hashMap2) { + HashMap mergedHashMap = new HashMap(); + for (String key : hashMap1.keySet()) { + mergedHashMap.put(key, hashMap1.get(key)); + } + for (String key : hashMap2.keySet()) { + Double existingValue = mergedHashMap.getOrDefault(key,0.0); + mergedHashMap.put(key, hashMap2.get(key) + existingValue); + } + return mergedHashMap; + } + + public static TreeMap> aggregateHashMap(TreeMap> hashMap1, TreeMap> hashMap2) { + Set queries = new HashSet(); + TreeMap> finalHashMap = new TreeMap>(); + for (String key : hashMap1.keySet()) { + queries.add(key); + } + for (String key : hashMap2.keySet()) { + queries.add(key); + } + Iterator queryIterator = queries.iterator(); + while(queryIterator.hasNext()) { + String query = queryIterator.next(); + HashMap aggregated = aggregateQuery(hashMap1.getOrDefault(query,new HashMap()), hashMap2.getOrDefault(query,new HashMap())); + finalHashMap.put(query,aggregated); + } + return finalHashMap; + } + + + public static void main(String[] args) { + + Args fuseArgs = new Args(); + CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + if (fuseArgs.options) { + System.err.printf("Options for %s:\n\n", FuseRuns.class.getSimpleName()); + parser.printUsage(System.err); + + ArrayList required = new ArrayList(); + parser.getOptions().forEach((option) -> { + if (option.option.required()) { + required.add(option.option.toString()); + } + }); + + System.err.printf("\nRequired options are %s\n", required); + } else { + System.err.printf("Error: %s. For help, use \"-options\" to print out information about options.\n", + e.getMessage()); + } + + return; + } + + try { + TreeMap> runA = createRunMap(fuseArgs.filename_a); + TreeMap> runB = createRunMap(fuseArgs.filename_b); + normalize_min_max(runA); + normalize_min_max(runB); + TreeMap> finalRun = aggregateHashMap(runA,runB); + + // Merge and output + FusedRunOutputWriter out = new FusedRunOutputWriter(fuseArgs.filename_output, "trec", fuseArgs.runtag); + for (String key : finalRun.keySet()) { + out.writeTopic(key, finalRun.get(key)); + } + out.close(); + System.out.println("File " + fuseArgs.filename_output + " successfully created!"); + + } catch (IOException e) { + System.out.println("Error occurred: " + e.getMessage()); + } + } +} + From c626a7caaa033821853715d17b3c5ced5a9df94d Mon Sep 17 00:00:00 2001 From: DanielKohn1208 Date: Wed, 1 May 2024 10:38:19 -0400 Subject: [PATCH 02/33] added cadurosar's code via a copy + paste and made changes to match pyserini convention --- .../java/io/anserini/fusion/FuseRuns.java | 268 ++++++++++++++++++ 1 file changed, 268 insertions(+) create mode 100644 src/main/java/io/anserini/fusion/FuseRuns.java diff --git a/src/main/java/io/anserini/fusion/FuseRuns.java b/src/main/java/io/anserini/fusion/FuseRuns.java new file mode 100644 index 0000000000..3a09e62b9c --- /dev/null +++ b/src/main/java/io/anserini/fusion/FuseRuns.java @@ -0,0 +1,268 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.fusion; + +import java.io.BufferedReader; +import java.io.Closeable; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +import org.apache.commons.lang3.LocaleUtils; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.kohsuke.args4j.ParserProperties; +import org.kohsuke.args4j.spi.StringArrayOptionHandler; + +class FusedRunOutputWriter implements Closeable { + private final PrintWriter out; + private final String format; + private final String runtag; + + public FusedRunOutputWriter(String output, String format, String runtag) throws IOException { + this.out = new PrintWriter(Files.newBufferedWriter(Paths.get(output), StandardCharsets.UTF_8)); + this.format = format; + this.runtag = runtag; + } + +private class Document implements Comparable{ +String docid; +double score; + +public Document(String docid, double score) +{ +this.docid = docid; +this.score = score; +} + +@Override public int compareTo(Document a) +{ +return Double.compare(a.score,this.score); +} + +} + + public void writeTopic(String qid, HashMap results) { + int rank = 1; + ArrayList documents = new ArrayList<>(); + + for (Map.Entry entry : results.entrySet()) { + documents.add(new Document(entry.getKey(), entry.getValue())); + } + Collections.sort(documents); + for (Document r : documents) { + if ("msmarco".equals(format)) { + // MS MARCO output format: + out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, r.docid, rank)); + } else { + // Standard TREC format: + // + the first column is the topic number. + // + the second column is currently unused and should always be "Q0". + // + the third column is the official document identifier of the retrieved document. + // + the fourth column is the rank the document is retrieved. + // + the fifth column shows the score (integer or floating point) that generated the ranking. + // + the sixth column is called the "run tag" and should be a unique identifier for your + out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", qid, r.docid, rank, r.score, runtag)); + } + rank++; + } + } + + @Override + public void close() { + out.flush(); + out.close(); + } +} + + + +public class FuseRuns { + + public static class Args { + @Option(name = "-options", usage = "Print information about options.") + public Boolean options = false; + + @Option(name = "-runs", handler = StringArrayOptionHandler.class, metaVar = "", required = true, + usage = "Path to both run files to fuse") + public String[] runs = new String[]{}; + @Option(name = "-output", metaVar = "[output]", required = true, usage = "Path to save the output") + public String output; + + @Option(name = "-runtag", metaVar = "[runtag]", usage = "Run tag for the fusion") + public String runtag = "fused"; + + // Currently useless, will eventually be needed to match pyserini's fusion implementation + // @Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusionm method") + // public String method = "default"; + // + // @Option(name = "-rrf_k", metaVar = "[rrf_k]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") + // public int rrf_k = 60; + // + // @Option(name = "-alpha", required = false, usage = "Alpha value used for interpolation.") + // public double alpha = 0.5; + // + // @Option(name = "-k", required = false, usage = "Alpha value used for interpolation") + // public int k = 1000; + // + // @Option (name = "-resort", usage="We Resort the Trec run files or not") + // public boolean resort = false; + // + + } + + public static TreeMap> createRunMap(String filename){ + TreeMap> twoTierHashMap = new TreeMap<>(); + try (BufferedReader br = new BufferedReader(new FileReader(filename))) { + String line; + while ((line = br.readLine()) != null) { + String[] data = line.split(" "); + HashMap innerHashMap = new HashMap(); + HashMap innerHashMap_ = twoTierHashMap.putIfAbsent(data[0], innerHashMap); + if (innerHashMap_ != null){ + innerHashMap = innerHashMap_; + } + innerHashMap.put(data[2], Double.valueOf(data[4])); + } + } catch (FileNotFoundException ex){ + System.out.println(ex); + } catch (IOException ex){ + System.out.println(ex); + } + return twoTierHashMap; + } + + public static void normalize_min_max(TreeMap> hashMap) { + for (String outerKey : hashMap.keySet()) { + Map innerHashMap = hashMap.get(outerKey); + Double min = Double.MAX_VALUE; + Double max = -1.0; + for (String innerKey : innerHashMap.keySet()) { + Double innerValue = innerHashMap.get(innerKey); + if (innerValue < min) { + min = innerValue; + } + if (innerValue > max) { + max = innerValue; + } + } + for (String innerKey : innerHashMap.keySet()) { + Double innerValue = innerHashMap.get(innerKey); + Double newValue = (innerValue - min) / (max-min); + innerHashMap.replace(innerKey,innerValue,newValue); + } + } + } + + public static HashMap aggregateQuery(HashMap hashMap1, HashMap hashMap2) { + HashMap mergedHashMap = new HashMap(); + for (String key : hashMap1.keySet()) { + mergedHashMap.put(key, hashMap1.get(key)); + } + for (String key : hashMap2.keySet()) { + Double existingValue = mergedHashMap.getOrDefault(key,0.0); + mergedHashMap.put(key, hashMap2.get(key) + existingValue); + } + return mergedHashMap; + } + + public static TreeMap> aggregateHashMap(TreeMap> hashMap1, TreeMap> hashMap2) { + Set queries = new HashSet(); + TreeMap> finalHashMap = new TreeMap>(); + for (String key : hashMap1.keySet()) { + queries.add(key); + } + for (String key : hashMap2.keySet()) { + queries.add(key); + } + Iterator queryIterator = queries.iterator(); + while(queryIterator.hasNext()) { + String query = queryIterator.next(); + HashMap aggregated = aggregateQuery(hashMap1.getOrDefault(query,new HashMap()), hashMap2.getOrDefault(query,new HashMap())); + finalHashMap.put(query,aggregated); + } + return finalHashMap; + } + + public static void main(String[] args) { + Args fuseArgs = new Args(); + CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); + + // parse argumens + try { + parser.parseArgument(args); + if(fuseArgs.runs.length != 2) { + // TODO: THIS CONSTRUCTOR IS DEPRECATED + throw new CmdLineException(parser, "Expects exactly 2 run files"); + } + } catch (CmdLineException e) { + if (fuseArgs.options) { + System.err.printf("Options for %s:\n\n", FuseRuns.class.getSimpleName()); + parser.printUsage(System.err); + + ArrayList required = new ArrayList(); + parser.getOptions().forEach((option) -> { + if (option.option.required()) { + required.add(option.option.toString()); + } + }); + + System.err.printf("\nRequired options are %s\n", required); + } else { + System.err.printf("Error: %s. For help, use \"-options\" to print out information about options.\n", + e.getMessage()); + } + return; + } + + + try { + TreeMap> runA = createRunMap(fuseArgs.runs[0]); + TreeMap> runB = createRunMap(fuseArgs.runs[1]); + normalize_min_max(runA); + normalize_min_max(runB); + + TreeMap> finalRun = aggregateHashMap(runA, runB); + + FusedRunOutputWriter out = new FusedRunOutputWriter(fuseArgs.output, "trec", fuseArgs.runtag); + for (String key : finalRun.keySet()) { + out.writeTopic(key, finalRun.get(key)); + } + out.close(); + System.out.println("File " + fuseArgs.output + " was succesfully created"); + + } catch (IOException e) { + System.out.println("Error occured: " + e.getMessage()); + } + + } +} + From 7db24a9648bc0f64882e0e23ebd4b11d996b708d Mon Sep 17 00:00:00 2001 From: DanielKohn1208 Date: Wed, 1 May 2024 11:00:29 -0400 Subject: [PATCH 03/33] moved FuseRuns --- src/main/java/io/anserini/{search => fusion}/FuseRuns.java | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/main/java/io/anserini/{search => fusion}/FuseRuns.java (100%) diff --git a/src/main/java/io/anserini/search/FuseRuns.java b/src/main/java/io/anserini/fusion/FuseRuns.java similarity index 100% rename from src/main/java/io/anserini/search/FuseRuns.java rename to src/main/java/io/anserini/fusion/FuseRuns.java From ebd8ed4b55a395ea588839fac8ac9ea7b763dd2e Mon Sep 17 00:00:00 2001 From: DanielKohn1208 Date: Tue, 7 May 2024 21:17:09 -0400 Subject: [PATCH 04/33] added run fusion to match pyserini implementation --- .../java/io/anserini/fusion/FuseRuns.java | 148 ++++++------- .../java/io/anserini/fusion/FuseUtils.java | 198 ++++++++++++++++++ .../java/io/anserini/fusion/FuseRunsTest.java | 184 ++++++++++++++++ .../resources/simple_trec_run_fusion_1.txt | 3 + .../resources/simple_trec_run_fusion_2.txt | 3 + .../resources/simple_trec_run_fusion_3.txt | 3 + .../resources/simple_trec_run_fusion_4.txt | 3 + 7 files changed, 454 insertions(+), 88 deletions(-) create mode 100644 src/main/java/io/anserini/fusion/FuseUtils.java create mode 100644 src/test/java/io/anserini/fusion/FuseRunsTest.java create mode 100644 src/test/resources/simple_trec_run_fusion_1.txt create mode 100644 src/test/resources/simple_trec_run_fusion_2.txt create mode 100644 src/test/resources/simple_trec_run_fusion_3.txt create mode 100644 src/test/resources/simple_trec_run_fusion_4.txt diff --git a/src/main/java/io/anserini/fusion/FuseRuns.java b/src/main/java/io/anserini/fusion/FuseRuns.java index e6c09c3f48..a89000b770 100644 --- a/src/main/java/io/anserini/fusion/FuseRuns.java +++ b/src/main/java/io/anserini/fusion/FuseRuns.java @@ -27,14 +27,11 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; import java.util.Locale; import java.util.Map; -import java.util.Set; import java.util.TreeMap; -import org.apache.commons.lang3.LocaleUtils; +import org.apache.commons.lang3.NotImplementedException; import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; import org.kohsuke.args4j.Option; @@ -102,6 +99,17 @@ public void close() { } } +// Used to hold the score and the rank of a document +class DocScore{ + public Double score; + public int initialRank; + + public DocScore(Double score, int initialRank) { + this.score = score; + this.initialRank = initialRank; + } +} + public class FuseRuns { public static class Args { @@ -117,98 +125,48 @@ public static class Args { @Option(name = "-runtag", metaVar = "[runtag]", usage = "Run tag for the fusion") public String runtag = "fused"; - // Currently useless, will eventually be needed to match pyserini's fusion implementation - // @Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusionm method") - // public String method = "default"; - // - // @Option(name = "-rrf_k", metaVar = "[rrf_k]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") - // public int rrf_k = 60; - // - // @Option(name = "-alpha", required = false, usage = "Alpha value used for interpolation.") - // public double alpha = 0.5; - // - // @Option(name = "-k", required = false, usage = "Alpha value used for interpolation") - // public int k = 1000; - // - // @Option (name = "-resort", usage="We Resort the Trec run files or not") - // public boolean resort = false; - // + @Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusion method") + public String method = "default"; + + @Option(name = "-rrf_k", metaVar = "[rrf_k]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") + public double rrf_k = 60; + + @Option(name = "-alpha", required = false, usage = "Alpha value used for interpolation.") + public double alpha = 0.5; + + @Option(name = "-k", required = false, usage = "number of documents to output for topic") + public int k = 1000; + + @Option(name = "-depth", required = false, usage = "Pool depth per topic.") + public int depth = 1000; + + @Option (name = "-resort", usage="We Resort the Trec run files or not") + public boolean resort = false; + } - public static TreeMap> createRunMap(String filename){ - TreeMap> twoTierHashMap = new TreeMap<>(); + public static TreeMap> createRunMap(String filename) throws FileNotFoundException, IOException { + TreeMap> twoTierHashMap = new TreeMap<>(); try (BufferedReader br = new BufferedReader(new FileReader(filename))) { String line; while ((line = br.readLine()) != null) { String[] data = line.split(" "); - HashMap innerHashMap = new HashMap(); - HashMap innerHashMap_ = twoTierHashMap.putIfAbsent(data[0], innerHashMap); + HashMap innerHashMap = new HashMap(); + HashMap innerHashMap_ = twoTierHashMap.putIfAbsent(data[0], innerHashMap); if (innerHashMap_ != null){ innerHashMap = innerHashMap_; } - innerHashMap.put(data[2], Double.valueOf(data[4])); + innerHashMap.put(data[2], new DocScore(Double.valueOf(data[4]), Integer.parseInt(data[3]))); } } catch (FileNotFoundException ex){ - System.out.println(ex); + throw ex; } catch (IOException ex){ - System.out.println(ex); + throw ex; } return twoTierHashMap; } - public static void normalize_min_max(TreeMap> hashMap) { - for (String outerKey : hashMap.keySet()) { - Map innerHashMap = hashMap.get(outerKey); - Double min = Double.MAX_VALUE; - Double max = -1.0; - for (String innerKey : innerHashMap.keySet()) { - Double innerValue = innerHashMap.get(innerKey); - if (innerValue < min) { - min = innerValue; - } - if (innerValue > max) { - max = innerValue; - } - } - for (String innerKey : innerHashMap.keySet()) { - Double innerValue = innerHashMap.get(innerKey); - Double newValue = (innerValue - min) / (max-min); - innerHashMap.replace(innerKey,innerValue,newValue); - } - } - } - - public static HashMap aggregateQuery(HashMap hashMap1, HashMap hashMap2) { - HashMap mergedHashMap = new HashMap(); - for (String key : hashMap1.keySet()) { - mergedHashMap.put(key, hashMap1.get(key)); - } - for (String key : hashMap2.keySet()) { - Double existingValue = mergedHashMap.getOrDefault(key,0.0); - mergedHashMap.put(key, hashMap2.get(key) + existingValue); - } - return mergedHashMap; - } - - public static TreeMap> aggregateHashMap(TreeMap> hashMap1, TreeMap> hashMap2) { - Set queries = new HashSet(); - TreeMap> finalHashMap = new TreeMap>(); - for (String key : hashMap1.keySet()) { - queries.add(key); - } - for (String key : hashMap2.keySet()) { - queries.add(key); - } - Iterator queryIterator = queries.iterator(); - while(queryIterator.hasNext()) { - String query = queryIterator.next(); - HashMap aggregated = aggregateQuery(hashMap1.getOrDefault(query,new HashMap()), hashMap2.getOrDefault(query,new HashMap())); - finalHashMap.put(query,aggregated); - } - return finalHashMap; - } - public static void main(String[] args) { Args fuseArgs = new Args(); CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); @@ -216,9 +174,13 @@ public static void main(String[] args) { // parse argumens try { parser.parseArgument(args); + // TODO: THESE CONSTRUCTORS ARE DEPRECATED if(fuseArgs.runs.length != 2) { - // TODO: THIS CONSTRUCTOR IS DEPRECATED - throw new CmdLineException(parser, "Expects exactly 2 run files"); + throw new CmdLineException(parser, "Option run expects exactly 2 files"); + } else if (fuseArgs.depth <= 0) { + throw new CmdLineException(parser, "Option depth must be greater than 0"); + } else if (fuseArgs.k <= 0) { + throw new CmdLineException(parser, "Option k must be greater than 0"); } } catch (CmdLineException e) { if (fuseArgs.options) { @@ -242,12 +204,21 @@ public static void main(String[] args) { try { - TreeMap> runA = createRunMap(fuseArgs.runs[0]); - TreeMap> runB = createRunMap(fuseArgs.runs[1]); - normalize_min_max(runA); - normalize_min_max(runB); + // TreeMap> + TreeMap> runA = createRunMap(fuseArgs.runs[0]); + TreeMap> runB = createRunMap(fuseArgs.runs[1]); - TreeMap> finalRun = aggregateHashMap(runA, runB); + TreeMap> finalRun; + + if (fuseArgs.method.equals(FusionMethods.AVERAGE)) { + finalRun = FusionMethods.average(runA, runB, fuseArgs.depth, fuseArgs.k); + } else if (fuseArgs.method.equals(FusionMethods.RRF)) { + finalRun = FusionMethods.reciprocal_rank_fusion(runA, runB, fuseArgs.rrf_k, fuseArgs.depth, fuseArgs.k); + } else if (fuseArgs.method.equals(FusionMethods.INTERPOLATION)) { + finalRun = FusionMethods.interpolation(runA, runB, fuseArgs.alpha, fuseArgs.depth, fuseArgs.k); + } else { + throw new NotImplementedException("This method has not yet been implemented: " + fuseArgs.method); + } FusedRunOutputWriter out = new FusedRunOutputWriter(fuseArgs.output, "trec", fuseArgs.runtag); for (String key : finalRun.keySet()) { @@ -257,9 +228,10 @@ public static void main(String[] args) { System.out.println("File " + fuseArgs.output + " was succesfully created"); } catch (IOException e) { - System.out.println("Error occured: " + e.getMessage()); + System.err.println("Error occured: " + e.getMessage()); + } catch (NotImplementedException e) { + System.err.println("Error occured: " + e.getMessage()); } - } } diff --git a/src/main/java/io/anserini/fusion/FuseUtils.java b/src/main/java/io/anserini/fusion/FuseUtils.java new file mode 100644 index 0000000000..cc6a0340d8 --- /dev/null +++ b/src/main/java/io/anserini/fusion/FuseUtils.java @@ -0,0 +1,198 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.anserini.fusion; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.Map.Entry; + +class FusionMethods { + static final String AVERAGE = "average"; + static final String RRF = "rrf"; + static final String INTERPOLATION = "interpolation"; + + public static TreeMap> average(TreeMap> runA, + TreeMap> runB, int depth, int k) { + + RescoreMethods.scale(runA, 1 / (double) (runA.size())); + RescoreMethods.scale(runB, 1 / (double) (runB.size())); + + return AggregationMethods.sum(runA, runB, depth, k); + } + + public static TreeMap> reciprocal_rank_fusion(TreeMap> runA, + TreeMap> runB, double rrf_k, int depth, int k) { + + RescoreMethods.rrf(runA, rrf_k); + RescoreMethods.rrf(runB, rrf_k); + + return AggregationMethods.sum(runA, runB, depth, k); + } + + public static TreeMap> interpolation(TreeMap> runA, + TreeMap> runB, double alpha, int depth, int k) { + + RescoreMethods.scale(runA, alpha); + RescoreMethods.scale(runB, 1 - alpha); + + return AggregationMethods.sum(runA, runB, depth, k); + } + +} + +class AggregationMethods { + public static TreeMap> sum(TreeMap> runA, + TreeMap> runB, int depth, int k) { + Set queries = new HashSet(); + TreeMap> finalHashMap = new TreeMap>(); + + // add all keys into set of queries + for (String key : runA.keySet()) { + queries.add(key); + } + for (String key : runB.keySet()) { + queries.add(key); + } + Iterator queryIterator = queries.iterator(); + while (queryIterator.hasNext()) { + String query = queryIterator.next(); + HashMap aggregated = sumIndividualTopic( + runA.getOrDefault(query, new HashMap()), + runB.getOrDefault(query, new HashMap()), depth, k); + finalHashMap.put(query, aggregated); + } + + return finalHashMap; + } + + private static HashMap sumIndividualTopic(HashMap docDataA, + HashMap docDataB, int depth, int k) { + HashMap mergedHashMap = new HashMap(); + + // shrink entries + shrinkToNEntriesDepth(docDataA, depth); + shrinkToNEntriesDepth(docDataB, depth); + for (String key : docDataA.keySet()) { + mergedHashMap.put(key, docDataA.get(key).score); + } + for (String key : docDataB.keySet()) { + Double existingValue = mergedHashMap.getOrDefault(key, 0.0); + mergedHashMap.put(key, docDataB.get(key).score + existingValue); + } + shrinkToNEntriesOutput(mergedHashMap, k); + return mergedHashMap; + } + + private static void shrinkToNEntriesDepth(Map map, int n) { + // we keep the entires with highest scores + int amountToRemove = map.size() - n; + if (amountToRemove <= 0) { + return; + } + + ArrayList> asList = new ArrayList>(); + for (Entry entry : map.entrySet()) { + asList.add(entry); + } + + Collections.sort(asList, new Comparator>() { + @Override + public int compare(Entry o1, Entry o2) { + return o1.getValue().score.compareTo(o2.getValue().score); + } + }); + + for (int i = 0; i < n; i++) { + map.remove(asList.get(i).getKey()); + } + } + + private static void shrinkToNEntriesOutput(HashMap hashMap, int n) { + int amountToRemove = hashMap.size() - n; + if (amountToRemove <= 0) { + return; + } + + ArrayList> asList = new ArrayList>(); + for (Entry entry : hashMap.entrySet()) { + asList.add(entry); + } + Collections.sort(asList, new Comparator>() { + @Override + public int compare(Entry o1, Entry o2) { + return o1.getValue().compareTo(o2.getValue()); + } + }); + + for (int i = 0; i < n; i++) { + hashMap.remove(asList.get(i).getKey()); + } + } + +} + +class RescoreMethods { + public static void normalize(Map> hashMap) { + for (String outerKey : hashMap.keySet()) { + Map innerHashMap = hashMap.get(outerKey); + Double min = Double.MAX_VALUE; + Double max = -1.0; + for (String innerKey : innerHashMap.keySet()) { + Double innerValue = innerHashMap.get(innerKey).score; + if (innerValue < min) { + min = innerValue; + } + if (innerValue > max) { + max = innerValue; + } + } + for (String innerKey : innerHashMap.keySet()) { + // Double innerValue = innerHashMap.get(innerKey).score; + // Double newValue = (innerValue - min) / (max - min); + DocScore innerValue = innerHashMap.get(innerKey); + innerValue.score = (innerValue.score - min) / (max - min); + } + } + } + + public static void scale(Map> hashMap, double scale) { + for (String outerKey : hashMap.keySet()) { + Map innerHashMap = hashMap.get(outerKey); + for (String innerKey : innerHashMap.keySet()) { + DocScore innerValue = innerHashMap.get(innerKey); + innerValue.score *= scale; + } + } + } + + public static void rrf(Map> hashMap, double rrf_k) { + for (String outerKey : hashMap.keySet()) { + Map innerHashMap = hashMap.get(outerKey); + for (String innerKey : innerHashMap.keySet()) { + DocScore innerValue = innerHashMap.get(innerKey); + innerValue.score = 1 /((double)innerValue.initialRank + rrf_k); + } + } + } + +} diff --git a/src/test/java/io/anserini/fusion/FuseRunsTest.java b/src/test/java/io/anserini/fusion/FuseRunsTest.java new file mode 100644 index 0000000000..5ee5495e2e --- /dev/null +++ b/src/test/java/io/anserini/fusion/FuseRunsTest.java @@ -0,0 +1,184 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.anserini.fusion; + +import io.anserini.TestUtils; +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.core.config.Configurator; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.PrintStream; + +import static org.junit.Assert.assertTrue; + + +public class FuseRunsTest { + private final ByteArrayOutputStream err = new ByteArrayOutputStream(); + private PrintStream save; + + @BeforeClass + public static void setupClass() { + Configurator.setLevel(FuseRuns.class.getName(), Level.ERROR); + } + + private void redirectStderr() { + save = System.err; + err.reset(); + System.setErr(new PrintStream(err)); + } + + private void restoreStderr() { + System.setErr(save); + } + + @Test + public void testReciprocalRankFusionSimple() throws Exception { + redirectStderr(); + FuseRuns.main(new String[] { + "-method", "rrf" , + "-runs", "src/test/resources/simple_trec_run_fusion_1.txt", "src/test/resources/simple_trec_run_fusion_2.txt", + "-output", "fuse.test", + "-runtag", "test" + }); + + TestUtils.checkFile("fuse.test", new String[]{ + "1 Q0 pyeb86on 1 0.032787 test", + "1 Q0 2054tkb7 2 0.032002 test", + "2 Q0 hanxiao2 1 0.016393 test", + "3 Q0 hanxiao2 1 0.016393 test"}); + assertTrue(new File("fuse.test").delete()); + restoreStderr(); + } + + + @Test + public void testAverageFusionSimple() throws Exception { + redirectStderr(); + FuseRuns.main(new String[] { + "-method", "average" , + "-runs", "src/test/resources/simple_trec_run_fusion_1.txt", "src/test/resources/simple_trec_run_fusion_2.txt", + "-output", "fuse.test", + "-runtag", "test" + }); + + TestUtils.checkFile("fuse.test", new String[]{ + "1 Q0 pyeb86on 1 13.200000 test", + "1 Q0 2054tkb7 2 7.150000 test", + "2 Q0 hanxiao2 1 49.500000 test", + "3 Q0 hanxiao2 1 1.650000 test"}); + assertTrue(new File("fuse.test").delete()); + restoreStderr(); + } + + @Test + public void testInterpolationFusionSimple() throws Exception { + redirectStderr(); + FuseRuns.main(new String[] { + "-method", "interpolation" , + "-runs", "src/test/resources/simple_trec_run_fusion_1.txt", "src/test/resources/simple_trec_run_fusion_2.txt", + "-output", "fuse.test", + "-alpha", "0.4", + "-runtag", "test" + }); + + TestUtils.checkFile("fuse.test", new String[]{ + "1 Q0 pyeb86on 1 11.040000 test", + "1 Q0 2054tkb7 2 5.980000 test", + "2 Q0 hanxiao2 1 39.600000 test", + "3 Q0 hanxiao2 1 1.980000 test"}); + assertTrue(new File("fuse.test").delete()); + restoreStderr(); + } + + @Test + public void testDepthAndKVariance() throws Exception { + redirectStderr(); + FuseRuns.main(new String[] { + "-method", "rrf", + "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", + "-output", "fuse.test", + "-runtag", "test", + "-k", "1", + "-depth", "2" + }); + + TestUtils.checkFile("fuse.test", new String[] { + "1 Q0 hanxiao2 1 0.032787 test" + }); + + assertTrue(new File("fuse.test").delete()); + restoreStderr(); + } + + @Test + public void testInvalidArguments() throws Exception { + redirectStderr(); + + FuseRuns.main(new String[] { + "-method", "nonexistentmethod", + "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", + "-output", "fuse.test", + "-runtag", "test", + }); + assertTrue(err.toString().contains("This method has not yet been implemented")); + err.reset(); + + FuseRuns.main(new String[] { + "-method", "rrf", + "-runs", "src/test/resources/nonexistentfilethatwillneverexist.txt", "src/test/resources/simple_trec_run_fusion_4.txt", + "-output", "fuse.test", + "-runtag", "test", + }); + assertTrue(err.toString().contains("Error occured: src/test/resources/nonexistentfilethatwillneverexist.txt (No such file or directory)")); + err.reset(); + + FuseRuns.main(new String[] { + "-method", "rrf", + "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", + "-output", "fuse.test", + "-runtag", "test", + "-k", "0", + }); + assertTrue(err.toString().contains("Option k must be greater than 0")); + err.reset(); + + FuseRuns.main(new String[] { + "-method", "rrf", + "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", + "-output", "fuse.test", + "-runtag", "test", + "-depth", "0", + }); + assertTrue(err.toString().contains("Option depth must be greater than 0")); + err.reset(); + + FuseRuns.main(new String[] { + "-method", "rrf", + "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", + "-output", "fuse.test", + "-runtag", "test", + }); + assertTrue(err.toString().contains("Option run expects exactly 2 files")); + err.reset(); + + restoreStderr(); + } + +} + diff --git a/src/test/resources/simple_trec_run_fusion_1.txt b/src/test/resources/simple_trec_run_fusion_1.txt new file mode 100644 index 0000000000..b6037a91df --- /dev/null +++ b/src/test/resources/simple_trec_run_fusion_1.txt @@ -0,0 +1,3 @@ +1 Q0 pyeb86on 1 24 reciprocal_rank_fusion_k=60 +1 Q0 2054tkb7 2 13 reciprocal_rank_fusion_k=60 +2 Q0 hanxiao2 1 99 reciprocal_rank_fusion_k=60 diff --git a/src/test/resources/simple_trec_run_fusion_2.txt b/src/test/resources/simple_trec_run_fusion_2.txt new file mode 100644 index 0000000000..bcc65c7f1a --- /dev/null +++ b/src/test/resources/simple_trec_run_fusion_2.txt @@ -0,0 +1,3 @@ +1 Q0 pyeb86on 1 2.4 reciprocal_rank_fusion_k=60 +1 Q0 2054tkb7 3 1.3 reciprocal_rank_fusion_k=60 +3 Q0 hanxiao2 1 3.3 reciprocal_rank_fusion_k=60 diff --git a/src/test/resources/simple_trec_run_fusion_3.txt b/src/test/resources/simple_trec_run_fusion_3.txt new file mode 100644 index 0000000000..ab00d8104a --- /dev/null +++ b/src/test/resources/simple_trec_run_fusion_3.txt @@ -0,0 +1,3 @@ +1 Q0 hanxiao2 1 99 reciprocal_rank_fusion_k=60 +1 Q0 pyeb86on 2 24 reciprocal_rank_fusion_k=60 +1 Q0 2054tkb7 3 13 reciprocal_rank_fusion_k=60 diff --git a/src/test/resources/simple_trec_run_fusion_4.txt b/src/test/resources/simple_trec_run_fusion_4.txt new file mode 100644 index 0000000000..c6767d2271 --- /dev/null +++ b/src/test/resources/simple_trec_run_fusion_4.txt @@ -0,0 +1,3 @@ +1 Q0 hanxiao2 1 3.3 reciprocal_rank_fusion_k=60 +1 Q0 pyeb86on 2 2.4 reciprocal_rank_fusion_k=60 +1 Q0 2054tkb7 3 1.3 reciprocal_rank_fusion_k=60 From 1b398af774d50decef16b15ddce71e0a6efed265 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Fri, 6 Sep 2024 13:52:40 -0400 Subject: [PATCH 05/33] added fusion feature --- .../java/io/anserini/fusion/FuseRuns.java | 237 ---------------- .../java/io/anserini/fusion/FuseTrecRuns.java | 140 +++++++++ .../java/io/anserini/fusion/FuseUtils.java | 198 ------------- .../java/io/anserini/fusion/TrecRunFuser.java | 134 +++++++++ .../io/anserini/trectools/RescoreMethod.java | 7 + .../java/io/anserini/trectools/TrecRun.java | 265 ++++++++++++++++++ 6 files changed, 546 insertions(+), 435 deletions(-) delete mode 100644 src/main/java/io/anserini/fusion/FuseRuns.java create mode 100644 src/main/java/io/anserini/fusion/FuseTrecRuns.java delete mode 100644 src/main/java/io/anserini/fusion/FuseUtils.java create mode 100644 src/main/java/io/anserini/fusion/TrecRunFuser.java create mode 100644 src/main/java/io/anserini/trectools/RescoreMethod.java create mode 100644 src/main/java/io/anserini/trectools/TrecRun.java diff --git a/src/main/java/io/anserini/fusion/FuseRuns.java b/src/main/java/io/anserini/fusion/FuseRuns.java deleted file mode 100644 index a89000b770..0000000000 --- a/src/main/java/io/anserini/fusion/FuseRuns.java +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Anserini: A Lucene toolkit for reproducible information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.anserini.fusion; - -import java.io.BufferedReader; -import java.io.Closeable; -import java.io.FileNotFoundException; -import java.io.FileReader; -import java.io.IOException; -import java.io.PrintWriter; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.TreeMap; - -import org.apache.commons.lang3.NotImplementedException; -import org.kohsuke.args4j.CmdLineException; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.Option; -import org.kohsuke.args4j.ParserProperties; -import org.kohsuke.args4j.spi.StringArrayOptionHandler; - -class FusedRunOutputWriter implements Closeable { - private final PrintWriter out; - private final String format; - private final String runtag; - - public FusedRunOutputWriter(String output, String format, String runtag) throws IOException { - this.out = new PrintWriter(Files.newBufferedWriter(Paths.get(output), StandardCharsets.UTF_8)); - this.format = format; - this.runtag = runtag; - } - - private class Document implements Comparable{ - String docid; - double score; - - public Document(String docid, double score) - { - this.docid = docid; - this.score = score; - } - - @Override public int compareTo(Document a) - { - return Double.compare(a.score,this.score); - } - - } - - public void writeTopic(String qid, HashMap results) { - int rank = 1; - ArrayList documents = new ArrayList<>(); - - for (Map.Entry entry : results.entrySet()) { - documents.add(new Document(entry.getKey(), entry.getValue())); - } - Collections.sort(documents); - for (Document r : documents) { - if ("msmarco".equals(format)) { - // MS MARCO output format: - out.append(String.format(Locale.US, "%s\t%s\t%d\n", qid, r.docid, rank)); - } else { - // Standard TREC format: - // + the first column is the topic number. - // + the second column is currently unused and should always be "Q0". - // + the third column is the official document identifier of the retrieved document. - // + the fourth column is the rank the document is retrieved. - // + the fifth column shows the score (integer or floating point) that generated the ranking. - // + the sixth column is called the "run tag" and should be a unique identifier for your - out.append(String.format(Locale.US, "%s Q0 %s %d %f %s\n", qid, r.docid, rank, r.score, runtag)); - } - rank++; - } - } - - @Override - public void close() { - out.flush(); - out.close(); - } -} - -// Used to hold the score and the rank of a document -class DocScore{ - public Double score; - public int initialRank; - - public DocScore(Double score, int initialRank) { - this.score = score; - this.initialRank = initialRank; - } -} - -public class FuseRuns { - - public static class Args { - @Option(name = "-options", usage = "Print information about options.") - public Boolean options = false; - - @Option(name = "-runs", handler = StringArrayOptionHandler.class, metaVar = "", required = true, - usage = "Path to both run files to fuse") - public String[] runs = new String[]{}; - @Option(name = "-output", metaVar = "[output]", required = true, usage = "Path to save the output") - public String output; - - @Option(name = "-runtag", metaVar = "[runtag]", usage = "Run tag for the fusion") - public String runtag = "fused"; - - @Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusion method") - public String method = "default"; - - @Option(name = "-rrf_k", metaVar = "[rrf_k]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") - public double rrf_k = 60; - - @Option(name = "-alpha", required = false, usage = "Alpha value used for interpolation.") - public double alpha = 0.5; - - @Option(name = "-k", required = false, usage = "number of documents to output for topic") - public int k = 1000; - - @Option(name = "-depth", required = false, usage = "Pool depth per topic.") - public int depth = 1000; - - @Option (name = "-resort", usage="We Resort the Trec run files or not") - public boolean resort = false; - - - } - - public static TreeMap> createRunMap(String filename) throws FileNotFoundException, IOException { - TreeMap> twoTierHashMap = new TreeMap<>(); - try (BufferedReader br = new BufferedReader(new FileReader(filename))) { - String line; - while ((line = br.readLine()) != null) { - String[] data = line.split(" "); - HashMap innerHashMap = new HashMap(); - HashMap innerHashMap_ = twoTierHashMap.putIfAbsent(data[0], innerHashMap); - if (innerHashMap_ != null){ - innerHashMap = innerHashMap_; - } - innerHashMap.put(data[2], new DocScore(Double.valueOf(data[4]), Integer.parseInt(data[3]))); - } - } catch (FileNotFoundException ex){ - throw ex; - } catch (IOException ex){ - throw ex; - } - return twoTierHashMap; - } - - public static void main(String[] args) { - Args fuseArgs = new Args(); - CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); - - // parse argumens - try { - parser.parseArgument(args); - // TODO: THESE CONSTRUCTORS ARE DEPRECATED - if(fuseArgs.runs.length != 2) { - throw new CmdLineException(parser, "Option run expects exactly 2 files"); - } else if (fuseArgs.depth <= 0) { - throw new CmdLineException(parser, "Option depth must be greater than 0"); - } else if (fuseArgs.k <= 0) { - throw new CmdLineException(parser, "Option k must be greater than 0"); - } - } catch (CmdLineException e) { - if (fuseArgs.options) { - System.err.printf("Options for %s:\n\n", FuseRuns.class.getSimpleName()); - parser.printUsage(System.err); - - ArrayList required = new ArrayList(); - parser.getOptions().forEach((option) -> { - if (option.option.required()) { - required.add(option.option.toString()); - } - }); - - System.err.printf("\nRequired options are %s\n", required); - } else { - System.err.printf("Error: %s. For help, use \"-options\" to print out information about options.\n", - e.getMessage()); - } - return; - } - - - try { - // TreeMap> - TreeMap> runA = createRunMap(fuseArgs.runs[0]); - TreeMap> runB = createRunMap(fuseArgs.runs[1]); - - TreeMap> finalRun; - - if (fuseArgs.method.equals(FusionMethods.AVERAGE)) { - finalRun = FusionMethods.average(runA, runB, fuseArgs.depth, fuseArgs.k); - } else if (fuseArgs.method.equals(FusionMethods.RRF)) { - finalRun = FusionMethods.reciprocal_rank_fusion(runA, runB, fuseArgs.rrf_k, fuseArgs.depth, fuseArgs.k); - } else if (fuseArgs.method.equals(FusionMethods.INTERPOLATION)) { - finalRun = FusionMethods.interpolation(runA, runB, fuseArgs.alpha, fuseArgs.depth, fuseArgs.k); - } else { - throw new NotImplementedException("This method has not yet been implemented: " + fuseArgs.method); - } - - FusedRunOutputWriter out = new FusedRunOutputWriter(fuseArgs.output, "trec", fuseArgs.runtag); - for (String key : finalRun.keySet()) { - out.writeTopic(key, finalRun.get(key)); - } - out.close(); - System.out.println("File " + fuseArgs.output + " was succesfully created"); - - } catch (IOException e) { - System.err.println("Error occured: " + e.getMessage()); - } catch (NotImplementedException e) { - System.err.println("Error occured: " + e.getMessage()); - } - } -} - diff --git a/src/main/java/io/anserini/fusion/FuseTrecRuns.java b/src/main/java/io/anserini/fusion/FuseTrecRuns.java new file mode 100644 index 0000000000..59edf2f479 --- /dev/null +++ b/src/main/java/io/anserini/fusion/FuseTrecRuns.java @@ -0,0 +1,140 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.fusion; + +import org.apache.commons.lang3.NotImplementedException; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.Option; +import org.kohsuke.args4j.ParserProperties; +import org.kohsuke.args4j.spi.StringArrayOptionHandler; + +import java.io.IOException; +import java.util.Arrays; +import java.util.ArrayList; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +import java.util.List; +import java.util.Map; +import java.util.SortedMap; +import java.util.TreeMap; + +import io.anserini.trectools.TrecRun; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +/** + * Main entry point for Fusion. + */ +public class FuseTrecRuns { + private static final Logger LOG = LogManager.getLogger(FuseTrecRuns.class); + + public static class Args extends TrecRunFuser.Args { + @Option(name = "-options", usage = "Print information about options.") + public Boolean options = false; + + @Option(name = "-runs", handler = StringArrayOptionHandler.class, metaVar = "[file]", required = true, + usage = "Path to both run files to fuse") + public String[] runs; + + @Option (name = "-resort", usage="We Resort the Trec run files or not") + public boolean resort = false; + } + + private final Args args; + private final TrecRunFuser fuser; + private final List runs = new ArrayList(); + + public FuseTrecRuns(Args args) throws IOException { + this.args = args; + this.fuser = new TrecRunFuser(args); + + LOG.info(String.format("============ Initializing %s ============", FuseTrecRuns.class.getSimpleName())); + LOG.info("Runs: " + Arrays.toString(args.runs)); + LOG.info("Run tag: " + args.runtag); + LOG.info("Fusion method: " + args.method); + LOG.info("Reciprocal Rank Fusion K value (rrf_k): " + args.rrf_k); + LOG.info("Alpha value for interpolation: " + args.alpha); + LOG.info("Max documents to output (k): " + args.k); + LOG.info("Pool depth: " + args.depth); + LOG.info("Resort TREC run files: " + args.resort); + + try { + // Ensure positive depth and k values + if (args.depth <= 0) { + throw new IllegalArgumentException("Option depth must be greater than 0"); + } + if (args.k <= 0) { + throw new IllegalArgumentException("Option k must be greater than 0"); + } + } catch (Exception e) { + throw new IllegalArgumentException(String.format("Error: %s. Please check the provided arguments. Use the \"-options\" flag to print out detailed information about available options and their usage.\n", + e.getMessage())); + } + + for (String runFile : args.runs) { + try { + Path path = Paths.get(runFile); + TrecRun run = new TrecRun(path, args.resort); + runs.add(run); + } catch (Exception e) { + throw new IllegalArgumentException(String.format("Error: %s. Please check the provided arguments. Use the \"-options\" flag to print out detailed information about available options and their usage.\n", + e.getMessage())); + } + } + } + + public void run() throws IOException { + LOG.info("============ Launching Fusion ============"); + fuser.fuse(runs); + } + + public static void main(String[] args) throws Exception { + Args fuseArgs = new Args(); + CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + if (fuseArgs.options) { + System.err.printf("Options for %s:\n\n", FuseTrecRuns.class.getSimpleName()); + parser.printUsage(System.err); + ArrayList required = new ArrayList<>(); + parser.getOptions().forEach(option -> { + if (option.option.required()) { + required.add(option.option.toString()); + } + }); + System.err.printf("\nRequired options are %s\n", required); + } else { + System.err.printf("Error: %s. For help, use \"-options\" to print out information about options.\n", + e.getMessage()); + } + return; + } + + try { + FuseTrecRuns fuser = new FuseTrecRuns(fuseArgs); + fuser.run(); + } catch (Exception e) { + System.err.println(e.getMessage()); + } + } +} diff --git a/src/main/java/io/anserini/fusion/FuseUtils.java b/src/main/java/io/anserini/fusion/FuseUtils.java deleted file mode 100644 index cc6a0340d8..0000000000 --- a/src/main/java/io/anserini/fusion/FuseUtils.java +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Anserini: A Lucene toolkit for reproducible information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.anserini.fusion; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.Map; -import java.util.Set; -import java.util.TreeMap; -import java.util.Map.Entry; - -class FusionMethods { - static final String AVERAGE = "average"; - static final String RRF = "rrf"; - static final String INTERPOLATION = "interpolation"; - - public static TreeMap> average(TreeMap> runA, - TreeMap> runB, int depth, int k) { - - RescoreMethods.scale(runA, 1 / (double) (runA.size())); - RescoreMethods.scale(runB, 1 / (double) (runB.size())); - - return AggregationMethods.sum(runA, runB, depth, k); - } - - public static TreeMap> reciprocal_rank_fusion(TreeMap> runA, - TreeMap> runB, double rrf_k, int depth, int k) { - - RescoreMethods.rrf(runA, rrf_k); - RescoreMethods.rrf(runB, rrf_k); - - return AggregationMethods.sum(runA, runB, depth, k); - } - - public static TreeMap> interpolation(TreeMap> runA, - TreeMap> runB, double alpha, int depth, int k) { - - RescoreMethods.scale(runA, alpha); - RescoreMethods.scale(runB, 1 - alpha); - - return AggregationMethods.sum(runA, runB, depth, k); - } - -} - -class AggregationMethods { - public static TreeMap> sum(TreeMap> runA, - TreeMap> runB, int depth, int k) { - Set queries = new HashSet(); - TreeMap> finalHashMap = new TreeMap>(); - - // add all keys into set of queries - for (String key : runA.keySet()) { - queries.add(key); - } - for (String key : runB.keySet()) { - queries.add(key); - } - Iterator queryIterator = queries.iterator(); - while (queryIterator.hasNext()) { - String query = queryIterator.next(); - HashMap aggregated = sumIndividualTopic( - runA.getOrDefault(query, new HashMap()), - runB.getOrDefault(query, new HashMap()), depth, k); - finalHashMap.put(query, aggregated); - } - - return finalHashMap; - } - - private static HashMap sumIndividualTopic(HashMap docDataA, - HashMap docDataB, int depth, int k) { - HashMap mergedHashMap = new HashMap(); - - // shrink entries - shrinkToNEntriesDepth(docDataA, depth); - shrinkToNEntriesDepth(docDataB, depth); - for (String key : docDataA.keySet()) { - mergedHashMap.put(key, docDataA.get(key).score); - } - for (String key : docDataB.keySet()) { - Double existingValue = mergedHashMap.getOrDefault(key, 0.0); - mergedHashMap.put(key, docDataB.get(key).score + existingValue); - } - shrinkToNEntriesOutput(mergedHashMap, k); - return mergedHashMap; - } - - private static void shrinkToNEntriesDepth(Map map, int n) { - // we keep the entires with highest scores - int amountToRemove = map.size() - n; - if (amountToRemove <= 0) { - return; - } - - ArrayList> asList = new ArrayList>(); - for (Entry entry : map.entrySet()) { - asList.add(entry); - } - - Collections.sort(asList, new Comparator>() { - @Override - public int compare(Entry o1, Entry o2) { - return o1.getValue().score.compareTo(o2.getValue().score); - } - }); - - for (int i = 0; i < n; i++) { - map.remove(asList.get(i).getKey()); - } - } - - private static void shrinkToNEntriesOutput(HashMap hashMap, int n) { - int amountToRemove = hashMap.size() - n; - if (amountToRemove <= 0) { - return; - } - - ArrayList> asList = new ArrayList>(); - for (Entry entry : hashMap.entrySet()) { - asList.add(entry); - } - Collections.sort(asList, new Comparator>() { - @Override - public int compare(Entry o1, Entry o2) { - return o1.getValue().compareTo(o2.getValue()); - } - }); - - for (int i = 0; i < n; i++) { - hashMap.remove(asList.get(i).getKey()); - } - } - -} - -class RescoreMethods { - public static void normalize(Map> hashMap) { - for (String outerKey : hashMap.keySet()) { - Map innerHashMap = hashMap.get(outerKey); - Double min = Double.MAX_VALUE; - Double max = -1.0; - for (String innerKey : innerHashMap.keySet()) { - Double innerValue = innerHashMap.get(innerKey).score; - if (innerValue < min) { - min = innerValue; - } - if (innerValue > max) { - max = innerValue; - } - } - for (String innerKey : innerHashMap.keySet()) { - // Double innerValue = innerHashMap.get(innerKey).score; - // Double newValue = (innerValue - min) / (max - min); - DocScore innerValue = innerHashMap.get(innerKey); - innerValue.score = (innerValue.score - min) / (max - min); - } - } - } - - public static void scale(Map> hashMap, double scale) { - for (String outerKey : hashMap.keySet()) { - Map innerHashMap = hashMap.get(outerKey); - for (String innerKey : innerHashMap.keySet()) { - DocScore innerValue = innerHashMap.get(innerKey); - innerValue.score *= scale; - } - } - } - - public static void rrf(Map> hashMap, double rrf_k) { - for (String outerKey : hashMap.keySet()) { - Map innerHashMap = hashMap.get(outerKey); - for (String innerKey : innerHashMap.keySet()) { - DocScore innerValue = innerHashMap.get(innerKey); - innerValue.score = 1 /((double)innerValue.initialRank + rrf_k); - } - } - } - -} diff --git a/src/main/java/io/anserini/fusion/TrecRunFuser.java b/src/main/java/io/anserini/fusion/TrecRunFuser.java new file mode 100644 index 0000000000..163ad28663 --- /dev/null +++ b/src/main/java/io/anserini/fusion/TrecRunFuser.java @@ -0,0 +1,134 @@ +package io.anserini.fusion; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.List; + +import org.kohsuke.args4j.Option; + +import io.anserini.trectools.RescoreMethod; +import io.anserini.trectools.TrecRun; + + +public class TrecRunFuser { + private final Args args; + + public static class Args { + @Option(name = "-output", metaVar = "[output]", required = true, usage = "Path to save the output") + public String output; + + @Option(name = "-runtag", metaVar = "[runtag]", required = false, usage = "Run tag for the fusion") + public String runtag = "anserini.fusion"; + + @Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusion method") + public String method = "rrf"; + + @Option(name = "-rrf_k", metaVar = "[rrf_k]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") + public int rrf_k = 60; + + @Option(name = "-alpha", required = false, usage = "Alpha value used for interpolation.") + public double alpha = 0.5; + + @Option(name = "-k", required = false, usage = "number of documents to output for topic") + public int k = 1000; + + @Option(name = "-depth", required = false, usage = "Pool depth per topic.") + public int depth = 1000; + } + + public TrecRunFuser(Args args) { + this.args = args; + } + + /** + * Perform fusion by averaging on a list of TrecRun objects. + * + * @param runs List of TrecRun objects. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via averaging. + */ + public static TrecRun average(List runs, int depth, int k) { + + for (TrecRun run : runs) { + run.rescore(RescoreMethod.SCALE, 0, (1/(double)runs.size())); + } + + return TrecRun.merge(runs, depth, k); + } + + /** + * Perform reciprocal rank fusion on a list of TrecRun objects. Implementation follows Cormack et al. + * (SIGIR 2009) paper titled "Reciprocal Rank Fusion Outperforms Condorcet and Individual Rank Learning Methods." + * + * @param runs List of TrecRun objects. + * @param rrf_k Parameter to avoid vanishing importance of lower-ranked documents. Note that this is different from the *k* in top *k* retrieval; set to 60 by default, per Cormack et al. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via reciprocal rank fusion. + */ + public static TrecRun reciprocalRankFusion(List runs, int rrf_k, int depth, int k) { + + for (TrecRun run : runs) { + run.rescore(RescoreMethod.RRF, rrf_k, 0); + } + + return TrecRun.merge(runs, depth, k); + } + +/** + * Perform fusion by interpolation on a list of exactly two TrecRun objects. + * new_score = first_run_score * alpha + (1 - alpha) * second_run_score. + * + * @param runs List of TrecRun objects. Exactly two runs. + * @param alpha Parameter alpha will be applied on the first run and (1 - alpha) will be applied on the second run. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via interpolation. + */ + public static TrecRun interpolation(List runs, double alpha, int depth, int k) { + // Ensure exactly 2 runs are provided, as interpolation requires 2 runs + if (runs.size() != 2) { + throw new IllegalArgumentException("Interpolation requires exactly 2 runs"); + } + + runs.get(0).rescore(RescoreMethod.SCALE, 0, alpha); + runs.get(1).rescore(RescoreMethod.SCALE, 0, 1 - alpha); + + return TrecRun.merge(runs, depth, k); + } + + private void saveToTxt(TrecRun fusedRun) throws IOException { + Path outputPath = Paths.get(args.output); + fusedRun.saveToTxt(outputPath, args.runtag); + } + + /** + * Process the fusion of TrecRun objects based on the specified method. + * + * @param runs List of TrecRun objects to be fused. + * @throws IOException If an I/O error occurs while saving the output. + */ + public void fuse(List runs) throws IOException { + TrecRun fusedRun; + + // Select fusion method + switch (args.method.toLowerCase()) { + case "rrf": + fusedRun = reciprocalRankFusion(runs, args.rrf_k, args.depth, args.k); + break; + case "interpolation": + fusedRun = interpolation(runs, args.alpha, args.depth, args.k); + break; + case "average": + fusedRun = average(runs, args.depth, args.k); + break; + default: + throw new IllegalArgumentException("Unknown fusion method: " + args.method + + ". Supported methods are: average, rrf, interpolation."); + } + + saveToTxt(fusedRun); + } +} diff --git a/src/main/java/io/anserini/trectools/RescoreMethod.java b/src/main/java/io/anserini/trectools/RescoreMethod.java new file mode 100644 index 0000000000..7a8981ff37 --- /dev/null +++ b/src/main/java/io/anserini/trectools/RescoreMethod.java @@ -0,0 +1,7 @@ +package io.anserini.trectools; + +public enum RescoreMethod { + RRF, + SCALE, + NORMALIZE; +} diff --git a/src/main/java/io/anserini/trectools/TrecRun.java b/src/main/java/io/anserini/trectools/TrecRun.java new file mode 100644 index 0000000000..5894e997ae --- /dev/null +++ b/src/main/java/io/anserini/trectools/TrecRun.java @@ -0,0 +1,265 @@ +/* +* Anserini: A Lucene toolkit for reproducible information retrieval research +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package io.anserini.trectools; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.commons.io.FileUtils; + +/** + * Wrapper class for a TREC run. +*/ +public class TrecRun { + // Enum representing the columns in the TREC run file + public enum Column { + TOPIC, Q0, DOCID, RANK, SCORE, TAG + } + + private List> runData; + private Path filepath = null; + private Boolean reSort = false; + + // Constructor without reSort parameter + public TrecRun(Path filepath) throws IOException { + this(filepath, false); + } + + // Constructor with reSort parameter + public TrecRun(Path filepath, Boolean reSort) throws IOException { + this.resetData(); + this.filepath = filepath; + this.reSort = reSort; + this.readRun(filepath); + } + + // Constructor without parameters + public TrecRun() { + this.resetData(); + } + + private void resetData() { + runData = new ArrayList<>(); + } + + /** + * Reads a TREC run file and loads its data into the runData list. + * + * @param filepath Path to the TREC run file. + * @throws IOException If the file cannot be read. + */ + public void readRun(Path filepath) throws IOException { + try (BufferedReader br = new BufferedReader(new FileReader(filepath.toFile()))) { + String line; + while ((line = br.readLine()) != null) { + String[] data = line.split("\\s+"); + Map record = new EnumMap<>(Column.class); + + // Populate the record map with the parsed data + record.put(Column.TOPIC, data[0]); + record.put(Column.Q0, data[1]); + record.put(Column.DOCID, data[2]); + record.put(Column.TAG, data[5]); + + // Parse RANK as integer + int rankInt = Integer.parseInt(data[3]); + record.put(Column.RANK, rankInt); + + // Parse SCORE as double + double scoreFloat = Double.parseDouble(data[4]); + record.put(Column.SCORE, scoreFloat); + + // Add the record to runData + runData.add(record); + } + } + + if (reSort) { + runData.sort((record1, record2) -> { + int topicComparison = ((String)record1.get(Column.TOPIC)).compareTo((String)(record2.get(Column.TOPIC))); + if (topicComparison != 0) { + return topicComparison; + } + return Double.compare((Double)(record2.get(Column.SCORE)), (Double)record1.get(Column.SCORE)); + }); + String currentTopic = ""; + int rank = 1; + for (Map record : runData) { + String topic = (String) record.get(Column.TOPIC); + if (!topic.equals(currentTopic)) { + currentTopic = topic; + rank = 1; + } + record.put(Column.RANK, rank); + rank++; + } + } + } + + public Set getTopics() { + return runData.stream().map(record -> (String) record.get(Column.TOPIC)).collect(Collectors.toSet()); + } + + public TrecRun cloneRun() throws IOException { + TrecRun clone = new TrecRun(); + clone.runData = new ArrayList<>(this.runData); + clone.filepath = this.filepath; + clone.reSort = this.reSort; + return clone; + } + + /** + * Saves the TREC run data to a text file in the TREC run format. + * + * @param outputPath Path to the output file. + * @param tag Tag to be added to each record in the TREC run file. If null, the existing tags are retained. + * @throws IOException If an I/O error occurs while writing to the file. + * @throws IllegalStateException If the runData list is empty. + */ + public void saveToTxt(Path outputPath, String tag) throws IOException { + if (runData.isEmpty()) { + throw new IllegalStateException("Nothing to save. TrecRun is empty"); + } + if (tag != null) { + runData.forEach(record -> record.put(Column.TAG, tag)); + } + runData.sort(Comparator.comparing((Map r) -> Integer.parseInt((String) r.get(Column.TOPIC))) + .thenComparing(r -> (Double) r.get(Column.SCORE), Comparator.reverseOrder())); + FileUtils.writeLines(outputPath.toFile(), runData.stream() + .map(record -> record.entrySet().stream() + .map(entry -> { + if (entry.getKey() == Column.SCORE) { + return String.format("%.6f", entry.getValue()); + } else { + return entry.getValue().toString(); + } + }) + .collect(Collectors.joining(" "))) + .collect(Collectors.toList())); + } + + public List> getDocsByTopic(String topic, int maxDocs) { + return runData.stream() + .filter(record -> record.get(Column.TOPIC).equals(topic)) // Filter by topic + .limit(maxDocs > 0 ? maxDocs : Integer.MAX_VALUE) // Limit the number of docs if maxDocs > 0 + .collect(Collectors.toList()); // Collect as List> + } + + public TrecRun rescore(RescoreMethod method, int rrfK, double scale) { + switch (method) { + case RRF -> rescoreRRF(rrfK); + case SCALE -> rescoreScale(scale); + case NORMALIZE -> normalizeScores(); + default -> throw new UnsupportedOperationException("Unknown rescore method: " + method); + } + return this; + } + + private void rescoreRRF(int rrfK) { + runData.forEach(record -> { + double score = 1.0 / (rrfK + (Integer)(record.get(Column.RANK))); + record.put(Column.SCORE, score); + }); + } + + private void rescoreScale(double scale) { + runData.forEach(record -> { + double score = (Double) record.get(Column.SCORE) * scale; + record.put(Column.SCORE, score); + }); + } + + private void normalizeScores() { + for (String topic : getTopics()) { + List> topicRecords = runData.stream() + .filter(record -> record.get(Column.TOPIC).equals(topic)) + .collect(Collectors.toList()); + + double minScore = topicRecords.stream() + .mapToDouble(record -> (Double) record.get(Column.SCORE)) + .min().orElse(0.0); + double maxScore = topicRecords.stream() + .mapToDouble(record -> (Double) record.get(Column.SCORE)) + .max().orElse(1.0); + + for (Map record : topicRecords) { + double normalizedScore = ((Double) record.get(Column.SCORE) - minScore) / (maxScore - minScore); + record.put(Column.SCORE, normalizedScore); + } + } + } + + /** + * Merges multiple TrecRun instances into a single TrecRun instance. + * The merged run will contain the top documents for each topic, with scores summed across the input runs. + * + * @param runs List of TrecRun instances to merge. + * @param depth Maximum number of documents to consider from each run for each topic (null for no limit). + * @param k Maximum number of top documents to include in the merged run for each topic (null for no limit). + * @return A new TrecRun instance containing the merged results. + * @throws IllegalArgumentException if less than 2 runs are provided. + */ + public static TrecRun merge(List runs, Integer depth, Integer k) { + if (runs.size() < 2) { + throw new IllegalArgumentException("Merge requires at least 2 runs."); + } + + TrecRun mergedRun = new TrecRun(); + + Set topics = runs.stream().flatMap(run -> run.getTopics().stream()).collect(Collectors.toSet()); + + topics.forEach(topic -> { + Map docScores = new HashMap<>(); + for (TrecRun run : runs) { + run.getDocsByTopic(topic, depth != null ? depth : Integer.MAX_VALUE).forEach(record -> { + String docId = (String) record.get(Column.DOCID); + double score = (Double) record.get(Column.SCORE); + docScores.put(docId, docScores.getOrDefault(docId, 0.0) + score); + }); + } + List> sortedDocScores = docScores.entrySet().stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .limit(k != null ? k : Integer.MAX_VALUE) + .collect(Collectors.toList()); + + for (int rank = 0; rank < sortedDocScores.size(); rank++) { + Map.Entry entry = sortedDocScores.get(rank); + Map record = new EnumMap<>(Column.class); + record.put(Column.TOPIC, topic); + record.put(Column.Q0, "Q0"); + record.put(Column.DOCID, entry.getKey()); + record.put(Column.RANK, rank + 1); + record.put(Column.SCORE, entry.getValue()); + record.put(Column.TAG, "merge_sum"); + mergedRun.runData.add(record); + } + }); + + return mergedRun; + } + +} \ No newline at end of file From 27b44dfa988642ab7b3827f797fd4b2bb0db1907 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Fri, 6 Sep 2024 17:13:55 -0400 Subject: [PATCH 06/33] modified arguments; added test cases --- .../java/io/anserini/fusion/FuseTrecRuns.java | 17 ++--- .../java/io/anserini/fusion/TrecRunFuser.java | 72 ++++++++++++------ .../io/anserini/fusion/FuseTrecRunsTest.java | 42 ++++++++++ .../io/anserini/trectools/TrecRunTest.java | 76 +++++++++++++++++++ 4 files changed, 170 insertions(+), 37 deletions(-) create mode 100644 src/test/java/io/anserini/fusion/FuseTrecRunsTest.java create mode 100644 src/test/java/io/anserini/trectools/TrecRunTest.java diff --git a/src/main/java/io/anserini/fusion/FuseTrecRuns.java b/src/main/java/io/anserini/fusion/FuseTrecRuns.java index 59edf2f479..3178ba67d7 100644 --- a/src/main/java/io/anserini/fusion/FuseTrecRuns.java +++ b/src/main/java/io/anserini/fusion/FuseTrecRuns.java @@ -16,30 +16,23 @@ package io.anserini.fusion; -import org.apache.commons.lang3.NotImplementedException; import org.kohsuke.args4j.CmdLineException; import org.kohsuke.args4j.CmdLineParser; import org.kohsuke.args4j.Option; import org.kohsuke.args4j.ParserProperties; import org.kohsuke.args4j.spi.StringArrayOptionHandler; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.io.IOException; import java.util.Arrays; import java.util.ArrayList; -import java.nio.file.Files; +import java.util.List; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.List; -import java.util.Map; -import java.util.SortedMap; -import java.util.TreeMap; - import io.anserini.trectools.TrecRun; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - /** * Main entry point for Fusion. */ @@ -47,14 +40,14 @@ public class FuseTrecRuns { private static final Logger LOG = LogManager.getLogger(FuseTrecRuns.class); public static class Args extends TrecRunFuser.Args { - @Option(name = "-options", usage = "Print information about options.") + @Option(name = "-options", required = false, usage = "Print information about options.") public Boolean options = false; @Option(name = "-runs", handler = StringArrayOptionHandler.class, metaVar = "[file]", required = true, usage = "Path to both run files to fuse") public String[] runs; - @Option (name = "-resort", usage="We Resort the Trec run files or not") + @Option (name = "-resort", required = false, metaVar = "[flag]", usage="We Resort the Trec run files or not") public boolean resort = false; } diff --git a/src/main/java/io/anserini/fusion/TrecRunFuser.java b/src/main/java/io/anserini/fusion/TrecRunFuser.java index 163ad28663..8e082ad0b3 100644 --- a/src/main/java/io/anserini/fusion/TrecRunFuser.java +++ b/src/main/java/io/anserini/fusion/TrecRunFuser.java @@ -1,3 +1,19 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.anserini.fusion; import java.io.IOException; @@ -10,10 +26,16 @@ import io.anserini.trectools.RescoreMethod; import io.anserini.trectools.TrecRun; - +/** + * Main logic class for Fusion + */ public class TrecRunFuser { private final Args args; + private static final String METHOD_RRF = "rrf"; + private static final String METHOD_INTERPOLATION = "interpolation"; + private static final String METHOD_AVERAGE = "average"; + public static class Args { @Option(name = "-output", metaVar = "[output]", required = true, usage = "Path to save the output") public String output; @@ -24,16 +46,16 @@ public static class Args { @Option(name = "-method", metaVar = "[method]", required = false, usage = "Specify fusion method") public String method = "rrf"; - @Option(name = "-rrf_k", metaVar = "[rrf_k]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") + @Option(name = "-rrf_k", metaVar = "[number]", required = false, usage = "Parameter k needed for reciprocal rank fusion.") public int rrf_k = 60; - @Option(name = "-alpha", required = false, usage = "Alpha value used for interpolation.") + @Option(name = "-alpha", metaVar = "[value]", required = false, usage = "Alpha value used for interpolation.") public double alpha = 0.5; - @Option(name = "-k", required = false, usage = "number of documents to output for topic") + @Option(name = "-k", metaVar = "[number]", required = false, usage = "number of documents to output for topic") public int k = 1000; - @Option(name = "-depth", required = false, usage = "Pool depth per topic.") + @Option(name = "-depth", metaVar = "[number]", required = false, usage = "Pool depth per topic.") public int depth = 1000; } @@ -42,13 +64,13 @@ public TrecRunFuser(Args args) { } /** - * Perform fusion by averaging on a list of TrecRun objects. - * - * @param runs List of TrecRun objects. - * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. - * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. - * @return Output TrecRun that combines input runs via averaging. - */ + * Perform fusion by averaging on a list of TrecRun objects. + * + * @param runs List of TrecRun objects. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via averaging. + */ public static TrecRun average(List runs, int depth, int k) { for (TrecRun run : runs) { @@ -77,16 +99,16 @@ public static TrecRun reciprocalRankFusion(List runs, int rrf_k, int de return TrecRun.merge(runs, depth, k); } -/** - * Perform fusion by interpolation on a list of exactly two TrecRun objects. - * new_score = first_run_score * alpha + (1 - alpha) * second_run_score. - * - * @param runs List of TrecRun objects. Exactly two runs. - * @param alpha Parameter alpha will be applied on the first run and (1 - alpha) will be applied on the second run. - * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. - * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. - * @return Output TrecRun that combines input runs via interpolation. - */ + /** + * Perform fusion by interpolation on a list of exactly two TrecRun objects. + * new_score = first_run_score * alpha + (1 - alpha) * second_run_score. + * + * @param runs List of TrecRun objects. Exactly two runs. + * @param alpha Parameter alpha will be applied on the first run and (1 - alpha) will be applied on the second run. + * @param depth Maximum number of results from each input run to consider. Set to Integer.MAX_VALUE by default, which indicates that the complete list of results is considered. + * @param k Length of final results list. Set to Integer.MAX_VALUE by default, which indicates that the union of all input documents are ranked. + * @return Output TrecRun that combines input runs via interpolation. + */ public static TrecRun interpolation(List runs, double alpha, int depth, int k) { // Ensure exactly 2 runs are provided, as interpolation requires 2 runs if (runs.size() != 2) { @@ -115,13 +137,13 @@ public void fuse(List runs) throws IOException { // Select fusion method switch (args.method.toLowerCase()) { - case "rrf": + case METHOD_RRF: fusedRun = reciprocalRankFusion(runs, args.rrf_k, args.depth, args.k); break; - case "interpolation": + case METHOD_INTERPOLATION: fusedRun = interpolation(runs, args.alpha, args.depth, args.k); break; - case "average": + case METHOD_AVERAGE: fusedRun = average(runs, args.depth, args.k); break; default: diff --git a/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java b/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java new file mode 100644 index 0000000000..b8751a1167 --- /dev/null +++ b/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java @@ -0,0 +1,42 @@ +package io.anserini.fusion; + +import java.io.IOException; + +import static org.junit.Assert.fail; +import org.junit.Test; +import org.kohsuke.args4j.CmdLineException; +import org.kohsuke.args4j.CmdLineParser; +import org.kohsuke.args4j.ParserProperties; + +public class FuseTrecRunsTest { + + @Test + public void testFuseTrecRunsRRF() throws IOException { + String[] args = { + "-runs", "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc_title.txt", "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc.txt", + "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-title.txt", + "-output", "runs/testsrc/test/resources/fused_output.txt", + "-rrf_k", "60", + "-k", "1000", + "-depth", "1000", + "-resort" + }; + + FuseTrecRuns.Args fuseArgs = new FuseTrecRuns.Args(); + CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); + + try { + parser.parseArgument(args); + } catch (CmdLineException e) { + fail("Argument parsing failed: " + e.getMessage()); + } + + FuseTrecRuns fuseTrecRuns = new FuseTrecRuns(fuseArgs); + fuseTrecRuns.run(); + + // Assert the existence of the output file + // assertTrue("Output file should exist", Paths.get("runs/testsrc/test/resources/fused_output.txt").toFile().exists()); + + // Further assertions on the output can be made by reading and validating the contents. + } +} diff --git a/src/test/java/io/anserini/trectools/TrecRunTest.java b/src/test/java/io/anserini/trectools/TrecRunTest.java new file mode 100644 index 0000000000..7c08b53951 --- /dev/null +++ b/src/test/java/io/anserini/trectools/TrecRunTest.java @@ -0,0 +1,76 @@ +package io.anserini.trectools; + +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import org.junit.Before; +import org.junit.Test; + +public class TrecRunTest { + private TrecRun trecRun; + private Path sampleFilePath; + + @Before + public void setUp() throws IOException { + sampleFilePath = Paths.get("runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc_title.txt"); + trecRun = new TrecRun(sampleFilePath, false); + } + + @Test + public void testReadRun() throws IOException { + assertEquals(114, trecRun.getTopics().size()); // Assuming sample file has 3 topics + } + + @Test + public void testGetDocsByTopic() { + List> docs = trecRun.getDocsByTopic("101", 0); + // System.out.println(docs); + assertNotNull(docs); + assertEquals(1000, docs.size()); // Assuming there are at least 10 documents for topic 101 + } + + @Test + public void testRescoreRRF() { + trecRun.rescore(RescoreMethod.RRF, 60, 1.0); + List> docs = trecRun.getDocsByTopic("101", 1); + System.out.println(docs.get(0).get(TrecRun.Column.SCORE)); + assertEquals(1.0 / 61, docs.get(0).get(TrecRun.Column.SCORE)); + } + + @Test + public void testNormalizeScores() { + trecRun.rescore(RescoreMethod.NORMALIZE, 0, 0); + List> docs = trecRun.getDocsByTopic("101", 0); + double maxScore = (Double) docs.get(0).get(TrecRun.Column.SCORE); + double minScore = (Double) docs.get(docs.size() - 1).get(TrecRun.Column.SCORE); + assertEquals(1.0, maxScore, 0.01); + assertEquals(0.0, minScore, 0.01); + } + + @Test + public void testMergeRuns() throws IOException { + TrecRun trecRun1 = new TrecRun(sampleFilePath); + TrecRun trecRun2 = new TrecRun(sampleFilePath); + TrecRun mergedRun = TrecRun.merge(Arrays.asList(trecRun1, trecRun2), null, 10); + Path outputPath = Paths.get("runs/testsrc/test/resources/output-merge.trec"); + mergedRun.saveToTxt(outputPath, "test_tag"); + + // assertEquals(mergedRun.getDocsByTopic("101", 1).get(0).get(TrecRun.Column.SCORE), 2 * (double) trecRun1.getDocsByTopic("101", 1).get(0).get(TrecRun.Column.SCORE)); + } + + @Test + public void testSaveToTxt() throws IOException { + Path outputPath = Paths.get("runs/testsrc/test/resources/output.trec"); + // trecRun.rescore(RescoreMethod.SCALE, 0, 2.0); + trecRun.saveToTxt(outputPath, "Anserini"); + // Re-load the saved run + TrecRun savedRun = new TrecRun(outputPath); + assertEquals(trecRun.getTopics().size(), savedRun.getTopics().size()); + } +} From 72e6e0637ccb8c656696311f4d723bfbd90e0072 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Fri, 6 Sep 2024 17:24:43 -0400 Subject: [PATCH 07/33] modified TrecRun class code style --- src/main/java/io/anserini/trectools/TrecRun.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/io/anserini/trectools/TrecRun.java b/src/main/java/io/anserini/trectools/TrecRun.java index 5894e997ae..9d441cf512 100644 --- a/src/main/java/io/anserini/trectools/TrecRun.java +++ b/src/main/java/io/anserini/trectools/TrecRun.java @@ -261,5 +261,4 @@ public static TrecRun merge(List runs, Integer depth, Integer k) { return mergedRun; } - } \ No newline at end of file From 5f7ec357aa1ddb06abfc6be913b1f09efbd60070 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Fri, 6 Sep 2024 17:25:58 -0400 Subject: [PATCH 08/33] added comment --- .../io/anserini/trectools/RescoreMethod.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/main/java/io/anserini/trectools/RescoreMethod.java b/src/main/java/io/anserini/trectools/RescoreMethod.java index 7a8981ff37..ca9a835f61 100644 --- a/src/main/java/io/anserini/trectools/RescoreMethod.java +++ b/src/main/java/io/anserini/trectools/RescoreMethod.java @@ -1,3 +1,19 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.anserini.trectools; public enum RescoreMethod { From 509049cd4a48a9cf240c500faf25d79091018827 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Sat, 7 Sep 2024 03:45:52 +0000 Subject: [PATCH 09/33] deleted test file from previous version --- .../java/io/anserini/fusion/FuseRunsTest.java | 184 ------------------ 1 file changed, 184 deletions(-) delete mode 100644 src/test/java/io/anserini/fusion/FuseRunsTest.java diff --git a/src/test/java/io/anserini/fusion/FuseRunsTest.java b/src/test/java/io/anserini/fusion/FuseRunsTest.java deleted file mode 100644 index 5ee5495e2e..0000000000 --- a/src/test/java/io/anserini/fusion/FuseRunsTest.java +++ /dev/null @@ -1,184 +0,0 @@ -/* - * Anserini: A Lucene toolkit for reproducible information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.anserini.fusion; - -import io.anserini.TestUtils; -import org.apache.logging.log4j.Level; -import org.apache.logging.log4j.core.config.Configurator; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.ByteArrayOutputStream; -import java.io.File; -import java.io.PrintStream; - -import static org.junit.Assert.assertTrue; - - -public class FuseRunsTest { - private final ByteArrayOutputStream err = new ByteArrayOutputStream(); - private PrintStream save; - - @BeforeClass - public static void setupClass() { - Configurator.setLevel(FuseRuns.class.getName(), Level.ERROR); - } - - private void redirectStderr() { - save = System.err; - err.reset(); - System.setErr(new PrintStream(err)); - } - - private void restoreStderr() { - System.setErr(save); - } - - @Test - public void testReciprocalRankFusionSimple() throws Exception { - redirectStderr(); - FuseRuns.main(new String[] { - "-method", "rrf" , - "-runs", "src/test/resources/simple_trec_run_fusion_1.txt", "src/test/resources/simple_trec_run_fusion_2.txt", - "-output", "fuse.test", - "-runtag", "test" - }); - - TestUtils.checkFile("fuse.test", new String[]{ - "1 Q0 pyeb86on 1 0.032787 test", - "1 Q0 2054tkb7 2 0.032002 test", - "2 Q0 hanxiao2 1 0.016393 test", - "3 Q0 hanxiao2 1 0.016393 test"}); - assertTrue(new File("fuse.test").delete()); - restoreStderr(); - } - - - @Test - public void testAverageFusionSimple() throws Exception { - redirectStderr(); - FuseRuns.main(new String[] { - "-method", "average" , - "-runs", "src/test/resources/simple_trec_run_fusion_1.txt", "src/test/resources/simple_trec_run_fusion_2.txt", - "-output", "fuse.test", - "-runtag", "test" - }); - - TestUtils.checkFile("fuse.test", new String[]{ - "1 Q0 pyeb86on 1 13.200000 test", - "1 Q0 2054tkb7 2 7.150000 test", - "2 Q0 hanxiao2 1 49.500000 test", - "3 Q0 hanxiao2 1 1.650000 test"}); - assertTrue(new File("fuse.test").delete()); - restoreStderr(); - } - - @Test - public void testInterpolationFusionSimple() throws Exception { - redirectStderr(); - FuseRuns.main(new String[] { - "-method", "interpolation" , - "-runs", "src/test/resources/simple_trec_run_fusion_1.txt", "src/test/resources/simple_trec_run_fusion_2.txt", - "-output", "fuse.test", - "-alpha", "0.4", - "-runtag", "test" - }); - - TestUtils.checkFile("fuse.test", new String[]{ - "1 Q0 pyeb86on 1 11.040000 test", - "1 Q0 2054tkb7 2 5.980000 test", - "2 Q0 hanxiao2 1 39.600000 test", - "3 Q0 hanxiao2 1 1.980000 test"}); - assertTrue(new File("fuse.test").delete()); - restoreStderr(); - } - - @Test - public void testDepthAndKVariance() throws Exception { - redirectStderr(); - FuseRuns.main(new String[] { - "-method", "rrf", - "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", - "-output", "fuse.test", - "-runtag", "test", - "-k", "1", - "-depth", "2" - }); - - TestUtils.checkFile("fuse.test", new String[] { - "1 Q0 hanxiao2 1 0.032787 test" - }); - - assertTrue(new File("fuse.test").delete()); - restoreStderr(); - } - - @Test - public void testInvalidArguments() throws Exception { - redirectStderr(); - - FuseRuns.main(new String[] { - "-method", "nonexistentmethod", - "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", - "-output", "fuse.test", - "-runtag", "test", - }); - assertTrue(err.toString().contains("This method has not yet been implemented")); - err.reset(); - - FuseRuns.main(new String[] { - "-method", "rrf", - "-runs", "src/test/resources/nonexistentfilethatwillneverexist.txt", "src/test/resources/simple_trec_run_fusion_4.txt", - "-output", "fuse.test", - "-runtag", "test", - }); - assertTrue(err.toString().contains("Error occured: src/test/resources/nonexistentfilethatwillneverexist.txt (No such file or directory)")); - err.reset(); - - FuseRuns.main(new String[] { - "-method", "rrf", - "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", - "-output", "fuse.test", - "-runtag", "test", - "-k", "0", - }); - assertTrue(err.toString().contains("Option k must be greater than 0")); - err.reset(); - - FuseRuns.main(new String[] { - "-method", "rrf", - "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", "src/test/resources/simple_trec_run_fusion_4.txt", - "-output", "fuse.test", - "-runtag", "test", - "-depth", "0", - }); - assertTrue(err.toString().contains("Option depth must be greater than 0")); - err.reset(); - - FuseRuns.main(new String[] { - "-method", "rrf", - "-runs", "src/test/resources/simple_trec_run_fusion_3.txt", - "-output", "fuse.test", - "-runtag", "test", - }); - assertTrue(err.toString().contains("Option run expects exactly 2 files")); - err.reset(); - - restoreStderr(); - } - -} - From 39f62a943f802dee1261cf712ad170f232da4c1c Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Sat, 7 Sep 2024 03:59:21 +0000 Subject: [PATCH 10/33] Added dependency for junit test --- pom.xml | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pom.xml b/pom.xml index 844afc6b1f..d706bd5f3b 100644 --- a/pom.xml +++ b/pom.xml @@ -516,5 +516,17 @@ + + junit + junit + 4.13.2 + test + + + org.junit.jupiter + junit-jupiter-engine + 5.8.2 + test + From 37e89fa200a83bed783cdb2cc735bceba58e3ce0 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Sat, 7 Sep 2024 21:40:17 +0000 Subject: [PATCH 11/33] resolved formatting; merged trectools module to fusion --- pom.xml | 8 ++++---- .../java/io/anserini/fusion/FuseTrecRuns.java | 2 -- .../{trectools => fusion}/RescoreMethod.java | 2 +- .../{trectools => fusion}/TrecRun.java | 16 ++++++++-------- .../java/io/anserini/fusion/TrecRunFuser.java | 3 --- .../io/anserini/fusion/FuseTrecRunsTest.java | 19 ++++++++++++++++--- .../{trectools => fusion}/TrecRunTest.java | 18 +++++++++++++++++- 7 files changed, 46 insertions(+), 22 deletions(-) rename src/main/java/io/anserini/{trectools => fusion}/RescoreMethod.java (95%) rename src/main/java/io/anserini/{trectools => fusion}/TrecRun.java (97%) rename src/test/java/io/anserini/{trectools => fusion}/TrecRunTest.java (80%) diff --git a/pom.xml b/pom.xml index d706bd5f3b..884e0d7a4c 100644 --- a/pom.xml +++ b/pom.xml @@ -523,10 +523,10 @@ test - org.junit.jupiter - junit-jupiter-engine - 5.8.2 - test + org.junit.jupiter + junit-jupiter-engine + 5.8.2 + test diff --git a/src/main/java/io/anserini/fusion/FuseTrecRuns.java b/src/main/java/io/anserini/fusion/FuseTrecRuns.java index 3178ba67d7..5f2d45a11f 100644 --- a/src/main/java/io/anserini/fusion/FuseTrecRuns.java +++ b/src/main/java/io/anserini/fusion/FuseTrecRuns.java @@ -31,8 +31,6 @@ import java.nio.file.Path; import java.nio.file.Paths; -import io.anserini.trectools.TrecRun; - /** * Main entry point for Fusion. */ diff --git a/src/main/java/io/anserini/trectools/RescoreMethod.java b/src/main/java/io/anserini/fusion/RescoreMethod.java similarity index 95% rename from src/main/java/io/anserini/trectools/RescoreMethod.java rename to src/main/java/io/anserini/fusion/RescoreMethod.java index ca9a835f61..e07e9f0827 100644 --- a/src/main/java/io/anserini/trectools/RescoreMethod.java +++ b/src/main/java/io/anserini/fusion/RescoreMethod.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.anserini.trectools; +package io.anserini.fusion; public enum RescoreMethod { RRF, diff --git a/src/main/java/io/anserini/trectools/TrecRun.java b/src/main/java/io/anserini/fusion/TrecRun.java similarity index 97% rename from src/main/java/io/anserini/trectools/TrecRun.java rename to src/main/java/io/anserini/fusion/TrecRun.java index 9d441cf512..14c63c7c21 100644 --- a/src/main/java/io/anserini/trectools/TrecRun.java +++ b/src/main/java/io/anserini/fusion/TrecRun.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.anserini.trectools; +package io.anserini.fusion; import java.io.BufferedReader; import java.io.FileReader; @@ -37,7 +37,7 @@ public class TrecRun { // Enum representing the columns in the TREC run file public enum Column { - TOPIC, Q0, DOCID, RANK, SCORE, TAG + TOPIC, Q0, DOCID, RANK, SCORE, TAG } private List> runData; @@ -51,19 +51,19 @@ public TrecRun(Path filepath) throws IOException { // Constructor with reSort parameter public TrecRun(Path filepath, Boolean reSort) throws IOException { - this.resetData(); - this.filepath = filepath; - this.reSort = reSort; - this.readRun(filepath); + this.resetData(); + this.filepath = filepath; + this.reSort = reSort; + this.readRun(filepath); } // Constructor without parameters public TrecRun() { - this.resetData(); + this.resetData(); } private void resetData() { - runData = new ArrayList<>(); + runData = new ArrayList<>(); } /** diff --git a/src/main/java/io/anserini/fusion/TrecRunFuser.java b/src/main/java/io/anserini/fusion/TrecRunFuser.java index 8e082ad0b3..4323e8b10a 100644 --- a/src/main/java/io/anserini/fusion/TrecRunFuser.java +++ b/src/main/java/io/anserini/fusion/TrecRunFuser.java @@ -23,9 +23,6 @@ import org.kohsuke.args4j.Option; -import io.anserini.trectools.RescoreMethod; -import io.anserini.trectools.TrecRun; - /** * Main logic class for Fusion */ diff --git a/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java b/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java index b8751a1167..12ed0465a9 100644 --- a/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java +++ b/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java @@ -1,3 +1,19 @@ +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package io.anserini.fusion; import java.io.IOException; @@ -34,9 +50,6 @@ public void testFuseTrecRunsRRF() throws IOException { FuseTrecRuns fuseTrecRuns = new FuseTrecRuns(fuseArgs); fuseTrecRuns.run(); - // Assert the existence of the output file - // assertTrue("Output file should exist", Paths.get("runs/testsrc/test/resources/fused_output.txt").toFile().exists()); - // Further assertions on the output can be made by reading and validating the contents. } } diff --git a/src/test/java/io/anserini/trectools/TrecRunTest.java b/src/test/java/io/anserini/fusion/TrecRunTest.java similarity index 80% rename from src/test/java/io/anserini/trectools/TrecRunTest.java rename to src/test/java/io/anserini/fusion/TrecRunTest.java index 7c08b53951..b548a4cd88 100644 --- a/src/test/java/io/anserini/trectools/TrecRunTest.java +++ b/src/test/java/io/anserini/fusion/TrecRunTest.java @@ -1,4 +1,20 @@ -package io.anserini.trectools; +/* + * Anserini: A Lucene toolkit for reproducible information retrieval research + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.anserini.fusion; import java.io.IOException; import java.nio.file.Path; From 54c74b4007a40f9cd5e70633a05e753367d6531b Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Sun, 8 Sep 2024 16:12:09 +0000 Subject: [PATCH 12/33] remove unused test cases --- .../io/anserini/fusion/FuseTrecRunsTest.java | 55 ----------- .../java/io/anserini/fusion/TrecRunTest.java | 92 ------------------- 2 files changed, 147 deletions(-) delete mode 100644 src/test/java/io/anserini/fusion/FuseTrecRunsTest.java delete mode 100644 src/test/java/io/anserini/fusion/TrecRunTest.java diff --git a/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java b/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java deleted file mode 100644 index 12ed0465a9..0000000000 --- a/src/test/java/io/anserini/fusion/FuseTrecRunsTest.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Anserini: A Lucene toolkit for reproducible information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.fusion; - -import java.io.IOException; - -import static org.junit.Assert.fail; -import org.junit.Test; -import org.kohsuke.args4j.CmdLineException; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.ParserProperties; - -public class FuseTrecRunsTest { - - @Test - public void testFuseTrecRunsRRF() throws IOException { - String[] args = { - "-runs", "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc_title.txt", "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc.txt", - "runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-title.txt", - "-output", "runs/testsrc/test/resources/fused_output.txt", - "-rrf_k", "60", - "-k", "1000", - "-depth", "1000", - "-resort" - }; - - FuseTrecRuns.Args fuseArgs = new FuseTrecRuns.Args(); - CmdLineParser parser = new CmdLineParser(fuseArgs, ParserProperties.defaults().withUsageWidth(120)); - - try { - parser.parseArgument(args); - } catch (CmdLineException e) { - fail("Argument parsing failed: " + e.getMessage()); - } - - FuseTrecRuns fuseTrecRuns = new FuseTrecRuns(fuseArgs); - fuseTrecRuns.run(); - - // Further assertions on the output can be made by reading and validating the contents. - } -} diff --git a/src/test/java/io/anserini/fusion/TrecRunTest.java b/src/test/java/io/anserini/fusion/TrecRunTest.java deleted file mode 100644 index b548a4cd88..0000000000 --- a/src/test/java/io/anserini/fusion/TrecRunTest.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Anserini: A Lucene toolkit for reproducible information retrieval research - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.anserini.fusion; - -import java.io.IOException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.Arrays; -import java.util.List; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import org.junit.Before; -import org.junit.Test; - -public class TrecRunTest { - private TrecRun trecRun; - private Path sampleFilePath; - - @Before - public void setUp() throws IOException { - sampleFilePath = Paths.get("runs/testlong/run.neuclir22-zh-en-splade.splade.topics.neuclir22-en.splade.original-desc_title.txt"); - trecRun = new TrecRun(sampleFilePath, false); - } - - @Test - public void testReadRun() throws IOException { - assertEquals(114, trecRun.getTopics().size()); // Assuming sample file has 3 topics - } - - @Test - public void testGetDocsByTopic() { - List> docs = trecRun.getDocsByTopic("101", 0); - // System.out.println(docs); - assertNotNull(docs); - assertEquals(1000, docs.size()); // Assuming there are at least 10 documents for topic 101 - } - - @Test - public void testRescoreRRF() { - trecRun.rescore(RescoreMethod.RRF, 60, 1.0); - List> docs = trecRun.getDocsByTopic("101", 1); - System.out.println(docs.get(0).get(TrecRun.Column.SCORE)); - assertEquals(1.0 / 61, docs.get(0).get(TrecRun.Column.SCORE)); - } - - @Test - public void testNormalizeScores() { - trecRun.rescore(RescoreMethod.NORMALIZE, 0, 0); - List> docs = trecRun.getDocsByTopic("101", 0); - double maxScore = (Double) docs.get(0).get(TrecRun.Column.SCORE); - double minScore = (Double) docs.get(docs.size() - 1).get(TrecRun.Column.SCORE); - assertEquals(1.0, maxScore, 0.01); - assertEquals(0.0, minScore, 0.01); - } - - @Test - public void testMergeRuns() throws IOException { - TrecRun trecRun1 = new TrecRun(sampleFilePath); - TrecRun trecRun2 = new TrecRun(sampleFilePath); - TrecRun mergedRun = TrecRun.merge(Arrays.asList(trecRun1, trecRun2), null, 10); - Path outputPath = Paths.get("runs/testsrc/test/resources/output-merge.trec"); - mergedRun.saveToTxt(outputPath, "test_tag"); - - // assertEquals(mergedRun.getDocsByTopic("101", 1).get(0).get(TrecRun.Column.SCORE), 2 * (double) trecRun1.getDocsByTopic("101", 1).get(0).get(TrecRun.Column.SCORE)); - } - - @Test - public void testSaveToTxt() throws IOException { - Path outputPath = Paths.get("runs/testsrc/test/resources/output.trec"); - // trecRun.rescore(RescoreMethod.SCALE, 0, 2.0); - trecRun.saveToTxt(outputPath, "Anserini"); - // Re-load the saved run - TrecRun savedRun = new TrecRun(outputPath); - assertEquals(trecRun.getTopics().size(), savedRun.getTopics().size()); - } -} From 32e13c2e2b6d765f915a623ac4f3fa0af0c60366 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Sun, 8 Sep 2024 16:26:07 +0000 Subject: [PATCH 13/33] removed unused test files --- src/test/resources/simple_trec_run_fusion_1.txt | 3 --- src/test/resources/simple_trec_run_fusion_2.txt | 3 --- src/test/resources/simple_trec_run_fusion_3.txt | 3 --- src/test/resources/simple_trec_run_fusion_4.txt | 3 --- 4 files changed, 12 deletions(-) delete mode 100644 src/test/resources/simple_trec_run_fusion_1.txt delete mode 100644 src/test/resources/simple_trec_run_fusion_2.txt delete mode 100644 src/test/resources/simple_trec_run_fusion_3.txt delete mode 100644 src/test/resources/simple_trec_run_fusion_4.txt diff --git a/src/test/resources/simple_trec_run_fusion_1.txt b/src/test/resources/simple_trec_run_fusion_1.txt deleted file mode 100644 index b6037a91df..0000000000 --- a/src/test/resources/simple_trec_run_fusion_1.txt +++ /dev/null @@ -1,3 +0,0 @@ -1 Q0 pyeb86on 1 24 reciprocal_rank_fusion_k=60 -1 Q0 2054tkb7 2 13 reciprocal_rank_fusion_k=60 -2 Q0 hanxiao2 1 99 reciprocal_rank_fusion_k=60 diff --git a/src/test/resources/simple_trec_run_fusion_2.txt b/src/test/resources/simple_trec_run_fusion_2.txt deleted file mode 100644 index bcc65c7f1a..0000000000 --- a/src/test/resources/simple_trec_run_fusion_2.txt +++ /dev/null @@ -1,3 +0,0 @@ -1 Q0 pyeb86on 1 2.4 reciprocal_rank_fusion_k=60 -1 Q0 2054tkb7 3 1.3 reciprocal_rank_fusion_k=60 -3 Q0 hanxiao2 1 3.3 reciprocal_rank_fusion_k=60 diff --git a/src/test/resources/simple_trec_run_fusion_3.txt b/src/test/resources/simple_trec_run_fusion_3.txt deleted file mode 100644 index ab00d8104a..0000000000 --- a/src/test/resources/simple_trec_run_fusion_3.txt +++ /dev/null @@ -1,3 +0,0 @@ -1 Q0 hanxiao2 1 99 reciprocal_rank_fusion_k=60 -1 Q0 pyeb86on 2 24 reciprocal_rank_fusion_k=60 -1 Q0 2054tkb7 3 13 reciprocal_rank_fusion_k=60 diff --git a/src/test/resources/simple_trec_run_fusion_4.txt b/src/test/resources/simple_trec_run_fusion_4.txt deleted file mode 100644 index c6767d2271..0000000000 --- a/src/test/resources/simple_trec_run_fusion_4.txt +++ /dev/null @@ -1,3 +0,0 @@ -1 Q0 hanxiao2 1 3.3 reciprocal_rank_fusion_k=60 -1 Q0 pyeb86on 2 2.4 reciprocal_rank_fusion_k=60 -1 Q0 2054tkb7 3 1.3 reciprocal_rank_fusion_k=60 From a9d78041e9e5254746c49d4ce1132945413e3220 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 05:06:49 +0000 Subject: [PATCH 14/33] added fusion regression script paired with two yaml test files --- src/main/python/run_fusion_regression.py | 180 ++++++++++++++++++ .../resources/fuse_regression/test_0.yaml | 65 +++++++ .../resources/fuse_regression/test_1.yaml | 74 +++++++ .../regression/beir-v1.0.0-robust04.flat.yaml | 1 + 4 files changed, 320 insertions(+) create mode 100644 src/main/python/run_fusion_regression.py create mode 100644 src/main/resources/fuse_regression/test_0.yaml create mode 100644 src/main/resources/fuse_regression/test_1.yaml diff --git a/src/main/python/run_fusion_regression.py b/src/main/python/run_fusion_regression.py new file mode 100644 index 0000000000..12ce16a2e8 --- /dev/null +++ b/src/main/python/run_fusion_regression.py @@ -0,0 +1,180 @@ +# +# Anserini: A Lucene toolkit for reproducible information retrieval research +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import argparse +import logging +import time +import yaml +from subprocess import call, Popen, PIPE + +# Constants +FUSE_COMMAND = 'bin/run.sh io.anserini.fusion.FuseTrecRuns' + +# Set up logging +logger = logging.getLogger('fusion_regression_test') +logger.setLevel(logging.INFO) +ch = logging.StreamHandler() +ch.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s %(levelname)s [python] %(message)s') +ch.setFormatter(formatter) +logger.addHandler(ch) + +def is_close(a: float, b: float, rel_tol: float = 1e-9, abs_tol: float = 0.0) -> bool: + """Check if two numbers are close within a given tolerance.""" + return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) + +def check_output(command: str) -> str: + """Run a shell command and return its output. Raise an error if the command fails.""" + process = Popen(command, shell=True, stdout=PIPE) + output, err = process.communicate() + if process.returncode == 0: + return output + else: + raise RuntimeError(f"Command {command} failed with error: {err}") + +def construct_fusion_commands(yaml_data: dict) -> list: + """ + Constructs the fusion commands from the YAML configuration. + + Args: + yaml_data (dict): The loaded YAML configuration. + + Returns: + list: A list of commands to be executed. + """ + return [ + [ + FUSE_COMMAND, + '-runs', ' '.join([run for run in yaml_data['runs']]), + '-output', method.get('output'), + '-method', method.get('name', 'average'), + '-k', str(method.get('k', 1000)), + '-depth', str(method.get('depth', 1000)), + '-rrf_k', str(method.get('rrf_k', 60)), + '-alpha', str(method.get('alpha', 0.5)) + ] + for method in yaml_data['methods'] + ] + +def run_fusion_commands(cmds: list): + """ + Run the fusion commands and log the results. + + Args: + cmds (list): List of fusion commands to run. + """ + for cmd_list in cmds: + cmd = ' '.join(cmd_list) + logger.info(f'Running command: {cmd}') + try: + return_code = call(cmd, shell=True) + if return_code != 0: + logger.error(f"Command failed with return code {return_code}: {cmd}") + except Exception as e: + logger.error(f"Error executing command {cmd}: {str(e)}") + +def evaluate_and_verify(yaml_data: dict, dry_run: bool): + """ + Runs the evaluation and verification of the fusion results. + + Args: + yaml_data (dict): The loaded YAML configuration. + dry_run (bool): If True, output commands without executing them. + """ + fail_str = '\033[91m[FAIL]\033[0m ' + ok_str = ' [OK] ' + failures = False + + logger.info('=' * 10 + ' Verifying Fusion Results ' + '=' * 10) + + for method in yaml_data['methods']: + for i, topic_set in enumerate(yaml_data['topics']): + for metric in yaml_data['metrics']: + output_runfile = str(method.get('output')) + + # Build evaluation command + eval_cmd = [ + os.path.join(metric['command']), + metric['params'] if 'params' in metric and metric['params'] else '', + os.path.join('tools/topics-and-qrels', topic_set['qrel']) if 'qrel' in topic_set and topic_set['qrel'] else '', + output_runfile + ] + + if dry_run: + logger.info(' '.join(eval_cmd)) + continue + + try: + out = check_output(' '.join(eval_cmd)).decode('utf-8').split('\n')[-1] + if not out.strip(): + continue + except Exception as e: + logger.error(f"Failed to execute evaluation command: {str(e)}") + continue + + eval_out = out.strip().split(metric['separator'])[metric['parse_index']] + expected = round(method['results'][metric['metric']][i], metric['metric_precision']) + actual = round(float(eval_out), metric['metric_precision']) + result_str = ( + f'expected: {expected:.4f} actual: {actual:.4f} (delta={abs(expected-actual):.4f}) - ' + f'metric: {metric["metric"]:<8} method: {method["name"]} topics: {topic_set["id"]}' + ) + + if is_close(expected, actual) or actual > expected: + logger.info(ok_str + result_str) + else: + logger.error(fail_str + result_str) + failures = True + + end_time = time.time() + logger.info(f"Total execution time: {end_time - start_time:.2f} seconds") + if failures: + logger.error(f'{fail_str}Some tests failed.') + else: + logger.info(f'All tests passed successfully!') + +if __name__ == '__main__': + start_time = time.time() + + # Command-line argument parsing + parser = argparse.ArgumentParser(description='Run Fusion regression tests.') + parser.add_argument('--regression', required=True, help='Name of the regression test configuration.') + parser.add_argument('--dry-run', dest='dry_run', action='store_true', + help='Output commands without actual execution.') + args = parser.parse_args() + + # Load YAML configuration + try: + with open(f'src/main/resources/fuse_regression/{args.regression}.yaml') as f: + yaml_data = yaml.safe_load(f) + except FileNotFoundError as e: + logger.error(f"Failed to load configuration file: {e}") + exit(1) + + # Construct the fusion command + fusion_commands = construct_fusion_commands(yaml_data) + + # Run the fusion process + if args.dry_run: + logger.info(' '.join([cmd for cmd_list in fusion_commands for cmd in cmd_list])) + else: + run_fusion_commands(fusion_commands) + + # Evaluate and verify results + evaluate_and_verify(yaml_data, args.dry_run) + + logger.info(f"Total execution time: {time.time() - start_time:.2f} seconds") diff --git a/src/main/resources/fuse_regression/test_0.yaml b/src/main/resources/fuse_regression/test_0.yaml new file mode 100644 index 0000000000..39cf774e9f --- /dev/null +++ b/src/main/resources/fuse_regression/test_0.yaml @@ -0,0 +1,65 @@ +--- +corpus: beir-v1.0.0-robust04.flat +corpus_path: collections/beir-v1.0.0/corpus/robust04/ + +metrics: + - metric: nDCG@10 + command: bin/trec_eval + params: -c -m ndcg_cut.10 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@100 + command: bin/trec_eval + params: -c -m recall.100 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@1000 + command: bin/trec_eval + params: -c -m recall.1000 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + +topic_reader: TsvString +topics: + - name: "BEIR (v1.0.0): Robust04" + id: test + path: topics.beir-v1.0.0-robust04.test.tsv.gz + qrel: qrels.beir-v1.0.0-robust04.test.txt + +# Fusion Regression Test Configuration +runs: + - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 + - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.multifield.test.bm25 + - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached + +methods: + - name: rrf + k: 1000 + depth: 1000 + rrf_k: 60 + output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.multifield.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.rrf + results: + nDCG@10: + - 0.4636 + R@100: + - 0.4243 + R@1000: + - 0.7349 + - name: average + output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.multifield.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.average + results: + nDCG@10: + - 0.4 + R@100: + - 0.38 + R@1000: + - 0.62 + + + diff --git a/src/main/resources/fuse_regression/test_1.yaml b/src/main/resources/fuse_regression/test_1.yaml new file mode 100644 index 0000000000..c513775381 --- /dev/null +++ b/src/main/resources/fuse_regression/test_1.yaml @@ -0,0 +1,74 @@ +--- +corpus: beir-v1.0.0-robust04.flat +corpus_path: collections/beir-v1.0.0/corpus/robust04/ + +metrics: + - metric: nDCG@10 + command: bin/trec_eval + params: -c -m ndcg_cut.10 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@100 + command: bin/trec_eval + params: -c -m recall.100 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@1000 + command: bin/trec_eval + params: -c -m recall.1000 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + +topic_reader: TsvString +topics: + - name: "BEIR (v1.0.0): Robust04" + id: test + path: topics.beir-v1.0.0-robust04.test.tsv.gz + qrel: qrels.beir-v1.0.0-robust04.test.txt + +# Fusion Regression Test Configuration +runs: + - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 + - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached + +methods: + - name: rrf + k: 1000 + depth: 1000 + rrf_k: 60 + output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.rrf + results: + nDCG@10: + - 0.4636 + R@100: + - 0.4243 + R@1000: + - 0.7349 + - name: average + output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.average + results: + nDCG@10: + - 0.4 + R@100: + - 0.4 + R@1000: + - 0.7 + - name: interpolation + alpha: 0.5 + output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.interpolation + results: + nDCG@10: + - 0.4 + R@100: + - 0.4 + R@1000: + - 0.7 + + + diff --git a/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml b/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml index 5de794aaff..210782236f 100644 --- a/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml +++ b/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml @@ -53,3 +53,4 @@ models: - 0.3746 R@1000: - 0.6345 + \ No newline at end of file From e049e4864149567550ea65231d34c5b3d0448070 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 05:24:51 +0000 Subject: [PATCH 15/33] added md for test --- docs/fuse-regressions/test-setup.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 docs/fuse-regressions/test-setup.md diff --git a/docs/fuse-regressions/test-setup.md b/docs/fuse-regressions/test-setup.md new file mode 100644 index 0000000000..cc21e08afe --- /dev/null +++ b/docs/fuse-regressions/test-setup.md @@ -0,0 +1,23 @@ +# Fusion Regression Test Setup + +This document provides instructions for setting up and downloading the necessary run files to perform fusion regression tests. + +## Prerequisites +You will need the following: +- A working installation of `wget`. +- Enough disk space to store the downloaded files. + +## Automatic Download Using Script +To automatically download the required files, you can use the following shell script. The script will download and extract the files in the `runs/runs.beir` folder with the correct filenames. + +```bash +#!/bin/bash + +# Create the target directory if it doesn't exist +mkdir -p runs/runs.beir + +# Download the run files from Google Drive using their file IDs +wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1XVlVCDYQe3YjRzxplaeGbmW_0EFQCgm8' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.multifield.test.bm25 +wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1Z4rWlNgmXebMf1ardfiDg_4KIZImjqxt' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached +wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1fExxJHkPPNCdtptKqWTbcsH0Ql0PnPqS' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 + From bd0ce76605763d23435777383063d119644455da Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 05:38:40 +0000 Subject: [PATCH 16/33] add cmd on test instruction --- docs/fuse-regressions/test-setup.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/fuse-regressions/test-setup.md b/docs/fuse-regressions/test-setup.md index cc21e08afe..d9180e5083 100644 --- a/docs/fuse-regressions/test-setup.md +++ b/docs/fuse-regressions/test-setup.md @@ -20,4 +20,11 @@ mkdir -p runs/runs.beir wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1XVlVCDYQe3YjRzxplaeGbmW_0EFQCgm8' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.multifield.test.bm25 wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1Z4rWlNgmXebMf1ardfiDg_4KIZImjqxt' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1fExxJHkPPNCdtptKqWTbcsH0Ql0PnPqS' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 +``` +## Run fuse-regression script with two yaml tests +```bash +python src/main/python/run_fusion_regression.py --regression test_0 + +python src/main/python/run_fusion_regression.py --regression test_1 +``` \ No newline at end of file From 17ceb49df487f7fef615122fc50791380c6522ca Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 18:01:07 +0000 Subject: [PATCH 17/33] removed abundant dependency --- pom.xml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pom.xml b/pom.xml index 94d19a6a4a..f036bb76b1 100644 --- a/pom.xml +++ b/pom.xml @@ -704,17 +704,5 @@ - - junit - junit - 4.13.2 - test - - - org.junit.jupiter - junit-jupiter-engine - 5.8.2 - test - From 6f550b1afb8b2bbd30df6e3e8663f16f4044102a Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 18:46:35 +0000 Subject: [PATCH 18/33] revert unecessary change --- src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml b/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml index 210782236f..1a5a118f1a 100644 --- a/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml +++ b/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml @@ -52,5 +52,4 @@ models: R@100: - 0.3746 R@1000: - - 0.6345 - \ No newline at end of file + - 0.6345 \ No newline at end of file From f4644e1e18889ae60b9ea27c117a2ee33195842b Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 20:40:31 +0000 Subject: [PATCH 19/33] resolved a minor decoding issue --- src/main/python/run_fusion_regression.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/python/run_fusion_regression.py b/src/main/python/run_fusion_regression.py index 12ce16a2e8..bdca48a6c9 100644 --- a/src/main/python/run_fusion_regression.py +++ b/src/main/python/run_fusion_regression.py @@ -119,7 +119,8 @@ def evaluate_and_verify(yaml_data: dict, dry_run: bool): continue try: - out = check_output(' '.join(eval_cmd)).decode('utf-8').split('\n')[-1] + out = [line for line in + check_output(' '.join(eval_cmd)).decode('utf-8').split('\n') if line.strip()][-1] if not out.strip(): continue except Exception as e: From 0ea83691c44e4dfba3514b79ff5701ce94bf08d7 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 20:41:01 +0000 Subject: [PATCH 20/33] added a yaml that is based on regression test run results --- .../resources/fuse_regression/test_3.yaml | 73 +++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 src/main/resources/fuse_regression/test_3.yaml diff --git a/src/main/resources/fuse_regression/test_3.yaml b/src/main/resources/fuse_regression/test_3.yaml new file mode 100644 index 0000000000..33ca2b74bc --- /dev/null +++ b/src/main/resources/fuse_regression/test_3.yaml @@ -0,0 +1,73 @@ +--- +corpus: beir-v1.0.0-robust04.bge-base-en-v1.5 +corpus_path: collections/beir-v1.0.0/bge-base-en-v1.5/robust04 + +metrics: + - metric: nDCG@10 + command: bin/trec_eval + params: -c -m ndcg_cut.10 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@100 + command: bin/trec_eval + params: -c -m recall.100 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@1000 + command: bin/trec_eval + params: -c -m recall.1000 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + +topic_reader: JsonStringVector +topics: + - name: "BEIR (v1.0.0): Robust04" + id: test + path: topics.beir-v1.0.0-robust04.test.bge-base-en-v1.5.jsonl.gz + qrel: qrels.beir-v1.0.0-robust04.test.txt + +# Fusion Regression Test Configuration +runs: + - runs/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached + - runs/run.flat.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-cached +methods: + - name: rrf + k: 1000 + depth: 1000 + rrf_k: 60 + output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.rrf + results: + nDCG@10: + - 0.3 + R@100: + - 0.3 + R@1000: + - 0.5 + - name: average + output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.average + results: + nDCG@10: + - 0.3 + R@100: + - 0.3 + R@1000: + - 0.5 + - name: interpolation + alpha: 0.5 + output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.interpolation + results: + nDCG@10: + - 0.3 + R@100: + - 0.3 + R@1000: + - 0.5 + + + From ec57e96ee59ba9e05ade7c27e774d07a775ddb25 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 20:59:09 +0000 Subject: [PATCH 21/33] added doc for test2 --- docs/fuse-regressions/test-setup.md | 8 +++++++- .../fuse_regression/{test_3.yaml => test_2.yaml} | 8 ++------ 2 files changed, 9 insertions(+), 7 deletions(-) rename src/main/resources/fuse_regression/{test_3.yaml => test_2.yaml} (88%) diff --git a/docs/fuse-regressions/test-setup.md b/docs/fuse-regressions/test-setup.md index d9180e5083..b4ee2862f5 100644 --- a/docs/fuse-regressions/test-setup.md +++ b/docs/fuse-regressions/test-setup.md @@ -7,7 +7,7 @@ You will need the following: - A working installation of `wget`. - Enough disk space to store the downloaded files. -## Automatic Download Using Script +## Automatic Download Using Script for test0/1 To automatically download the required files, you can use the following shell script. The script will download and extract the files in the `runs/runs.beir` folder with the correct filenames. ```bash @@ -21,10 +21,16 @@ wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1XVl wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1Z4rWlNgmXebMf1ardfiDg_4KIZImjqxt' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1fExxJHkPPNCdtptKqWTbcsH0Ql0PnPqS' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 ``` +## Perform two regression runs for test2 +One could generate the runs necessary for test 3 following +- https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat-int8.cached.md +- https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat.cached.md ## Run fuse-regression script with two yaml tests ```bash python src/main/python/run_fusion_regression.py --regression test_0 python src/main/python/run_fusion_regression.py --regression test_1 + +python src/main/python/run_fusion_regression.py --regression test_2 ``` \ No newline at end of file diff --git a/src/main/resources/fuse_regression/test_3.yaml b/src/main/resources/fuse_regression/test_2.yaml similarity index 88% rename from src/main/resources/fuse_regression/test_3.yaml rename to src/main/resources/fuse_regression/test_2.yaml index 33ca2b74bc..f19acd4969 100644 --- a/src/main/resources/fuse_regression/test_3.yaml +++ b/src/main/resources/fuse_regression/test_2.yaml @@ -25,12 +25,8 @@ metrics: metric_precision: 4 can_combine: false -topic_reader: JsonStringVector -topics: - - name: "BEIR (v1.0.0): Robust04" - id: test - path: topics.beir-v1.0.0-robust04.test.bge-base-en-v1.5.jsonl.gz - qrel: qrels.beir-v1.0.0-robust04.test.txt + + # Fusion Regression Test Configuration runs: From 042b678f6926baefeb3e56eae6cb74d38695f020 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 21:02:11 +0000 Subject: [PATCH 22/33] typo --- docs/fuse-regressions/test-setup.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/fuse-regressions/test-setup.md b/docs/fuse-regressions/test-setup.md index b4ee2862f5..e07488a892 100644 --- a/docs/fuse-regressions/test-setup.md +++ b/docs/fuse-regressions/test-setup.md @@ -22,7 +22,7 @@ wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1Z4r wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1fExxJHkPPNCdtptKqWTbcsH0Ql0PnPqS' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 ``` ## Perform two regression runs for test2 -One could generate the runs necessary for test 3 following +One could generate the runs necessary for test 2 following - https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat-int8.cached.md - https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat.cached.md From f2b6f4c18da8f6155f541633e0e3e2314d494a5b Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 23 Sep 2024 22:03:47 +0000 Subject: [PATCH 23/33] changed name for test yamls --- docs/fuse-regressions/test-setup.md | 14 ++++++++------ ...usion-regression-bge-flat-int8-robust04-2.yaml} | 8 ++++++-- ... => fusion-regression-bge-flat-robust04-3.yaml} | 0 ...usion-regression-bge-flat-robust04.yaml-2.yaml} | 0 4 files changed, 14 insertions(+), 8 deletions(-) rename src/main/resources/fuse_regression/{test_2.yaml => fusion-regression-bge-flat-int8-robust04-2.yaml} (88%) rename src/main/resources/fuse_regression/{test_0.yaml => fusion-regression-bge-flat-robust04-3.yaml} (100%) rename src/main/resources/fuse_regression/{test_1.yaml => fusion-regression-bge-flat-robust04.yaml-2.yaml} (100%) diff --git a/docs/fuse-regressions/test-setup.md b/docs/fuse-regressions/test-setup.md index e07488a892..1b3bb5acf9 100644 --- a/docs/fuse-regressions/test-setup.md +++ b/docs/fuse-regressions/test-setup.md @@ -7,7 +7,8 @@ You will need the following: - A working installation of `wget`. - Enough disk space to store the downloaded files. -## Automatic Download Using Script for test0/1 +## Automatic Download Using Script for first two tests + To automatically download the required files, you can use the following shell script. The script will download and extract the files in the `runs/runs.beir` folder with the correct filenames. ```bash @@ -21,16 +22,17 @@ wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1XVl wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1Z4rWlNgmXebMf1ardfiDg_4KIZImjqxt' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1fExxJHkPPNCdtptKqWTbcsH0Ql0PnPqS' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 ``` -## Perform two regression runs for test2 -One could generate the runs necessary for test 2 following +## Perform two regression runs for test fusion-regression-bge-flat-int8-robust04-2 + +One could generate the runs necessary for test fusion-regression-bge-flat-int8-robust04-2 following - https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat-int8.cached.md - https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat.cached.md ## Run fuse-regression script with two yaml tests ```bash -python src/main/python/run_fusion_regression.py --regression test_0 +python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-robust04-3 -python src/main/python/run_fusion_regression.py --regression test_1 +python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-robust04.yaml-2 -python src/main/python/run_fusion_regression.py --regression test_2 +python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-int8-robust04-2 ``` \ No newline at end of file diff --git a/src/main/resources/fuse_regression/test_2.yaml b/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml similarity index 88% rename from src/main/resources/fuse_regression/test_2.yaml rename to src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml index f19acd4969..33ca2b74bc 100644 --- a/src/main/resources/fuse_regression/test_2.yaml +++ b/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml @@ -25,8 +25,12 @@ metrics: metric_precision: 4 can_combine: false - - +topic_reader: JsonStringVector +topics: + - name: "BEIR (v1.0.0): Robust04" + id: test + path: topics.beir-v1.0.0-robust04.test.bge-base-en-v1.5.jsonl.gz + qrel: qrels.beir-v1.0.0-robust04.test.txt # Fusion Regression Test Configuration runs: diff --git a/src/main/resources/fuse_regression/test_0.yaml b/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04-3.yaml similarity index 100% rename from src/main/resources/fuse_regression/test_0.yaml rename to src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04-3.yaml diff --git a/src/main/resources/fuse_regression/test_1.yaml b/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04.yaml-2.yaml similarity index 100% rename from src/main/resources/fuse_regression/test_1.yaml rename to src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04.yaml-2.yaml From d94c0f90258b43814d4f24fbad0b2d543db1c465 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Tue, 24 Sep 2024 21:03:16 +0000 Subject: [PATCH 24/33] second attempt to revert src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml back --- src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml b/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml index 1a5a118f1a..5de794aaff 100644 --- a/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml +++ b/src/main/resources/regression/beir-v1.0.0-robust04.flat.yaml @@ -52,4 +52,4 @@ models: R@100: - 0.3746 R@1000: - - 0.6345 \ No newline at end of file + - 0.6345 From f5871b9b3f03c245f6aa77664ac841794d081f2b Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Wed, 25 Sep 2024 07:02:11 +0000 Subject: [PATCH 25/33] fixed precision and added run_origins for fusion yaml --- ...n-regression-bge-flat-int8-robust04-2.yaml | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml b/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml index 33ca2b74bc..9c9513ee2b 100644 --- a/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml +++ b/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml @@ -36,6 +36,12 @@ topics: runs: - runs/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached - runs/run.flat.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-cached + +# yaml files that generated the runs +run_origins: + - src/main/resources/regression/beir-v1.0.0-robust04.bge-base-en-v1.5.flat-int8.cached.yaml + - src/main/resources/regression/beir-v1.0.0-robust04.bge-base-en-v1.5.flat.cached.yaml + methods: - name: rrf k: 1000 @@ -44,30 +50,30 @@ methods: output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.rrf results: nDCG@10: - - 0.3 + - 0.4467 R@100: - - 0.3 + - 0.3497 R@1000: - - 0.5 + - 0.5982 - name: average output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.average results: nDCG@10: - - 0.3 + - 0.4463 R@100: - - 0.3 + - 0.3498 R@1000: - - 0.5 + - 0.5984 - name: interpolation alpha: 0.5 output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.interpolation results: nDCG@10: - - 0.3 + - 0.4463 R@100: - - 0.3 + - 0.3498 R@1000: - - 0.5 + - 0.5984 From b7961f3580aa84483e98a5c41ae897634d715b9d Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Sun, 29 Sep 2024 02:33:37 +0000 Subject: [PATCH 26/33] removed two yamls that use runs not from current regression experiments --- ...fusion-regression-bge-flat-robust04-3.yaml | 65 ---------------- ...n-regression-bge-flat-robust04.yaml-2.yaml | 74 ------------------- 2 files changed, 139 deletions(-) delete mode 100644 src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04-3.yaml delete mode 100644 src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04.yaml-2.yaml diff --git a/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04-3.yaml b/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04-3.yaml deleted file mode 100644 index 39cf774e9f..0000000000 --- a/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04-3.yaml +++ /dev/null @@ -1,65 +0,0 @@ ---- -corpus: beir-v1.0.0-robust04.flat -corpus_path: collections/beir-v1.0.0/corpus/robust04/ - -metrics: - - metric: nDCG@10 - command: bin/trec_eval - params: -c -m ndcg_cut.10 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - - metric: R@100 - command: bin/trec_eval - params: -c -m recall.100 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - - metric: R@1000 - command: bin/trec_eval - params: -c -m recall.1000 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - -topic_reader: TsvString -topics: - - name: "BEIR (v1.0.0): Robust04" - id: test - path: topics.beir-v1.0.0-robust04.test.tsv.gz - qrel: qrels.beir-v1.0.0-robust04.test.txt - -# Fusion Regression Test Configuration -runs: - - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 - - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.multifield.test.bm25 - - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached - -methods: - - name: rrf - k: 1000 - depth: 1000 - rrf_k: 60 - output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.multifield.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.rrf - results: - nDCG@10: - - 0.4636 - R@100: - - 0.4243 - R@1000: - - 0.7349 - - name: average - output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.multifield.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.average - results: - nDCG@10: - - 0.4 - R@100: - - 0.38 - R@1000: - - 0.62 - - - diff --git a/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04.yaml-2.yaml b/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04.yaml-2.yaml deleted file mode 100644 index c513775381..0000000000 --- a/src/main/resources/fuse_regression/fusion-regression-bge-flat-robust04.yaml-2.yaml +++ /dev/null @@ -1,74 +0,0 @@ ---- -corpus: beir-v1.0.0-robust04.flat -corpus_path: collections/beir-v1.0.0/corpus/robust04/ - -metrics: - - metric: nDCG@10 - command: bin/trec_eval - params: -c -m ndcg_cut.10 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - - metric: R@100 - command: bin/trec_eval - params: -c -m recall.100 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - - metric: R@1000 - command: bin/trec_eval - params: -c -m recall.1000 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - -topic_reader: TsvString -topics: - - name: "BEIR (v1.0.0): Robust04" - id: test - path: topics.beir-v1.0.0-robust04.test.tsv.gz - qrel: qrels.beir-v1.0.0-robust04.test.txt - -# Fusion Regression Test Configuration -runs: - - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 - - runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached - -methods: - - name: rrf - k: 1000 - depth: 1000 - rrf_k: 60 - output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.rrf - results: - nDCG@10: - - 0.4636 - R@100: - - 0.4243 - R@1000: - - 0.7349 - - name: average - output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.average - results: - nDCG@10: - - 0.4 - R@100: - - 0.4 - R@1000: - - 0.7 - - name: interpolation - alpha: 0.5 - output: runs/fuse/run.inverted.beir-v1.0.0-robust04.flat.test.bm25.splade-pp-ed.test.splade-pp-ed-cached.fusion.interpolation - results: - nDCG@10: - - 0.4 - R@100: - - 0.4 - R@1000: - - 0.7 - - - From ab33853a3b4ca43e4b6b0f7c1ebe686232c56d7a Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Sun, 29 Sep 2024 02:38:58 +0000 Subject: [PATCH 27/33] modified test instructions according to last commit --- docs/fuse-regressions/test-setup.md | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/docs/fuse-regressions/test-setup.md b/docs/fuse-regressions/test-setup.md index 1b3bb5acf9..9b42b89fe4 100644 --- a/docs/fuse-regressions/test-setup.md +++ b/docs/fuse-regressions/test-setup.md @@ -2,26 +2,7 @@ This document provides instructions for setting up and downloading the necessary run files to perform fusion regression tests. -## Prerequisites -You will need the following: -- A working installation of `wget`. -- Enough disk space to store the downloaded files. -## Automatic Download Using Script for first two tests - -To automatically download the required files, you can use the following shell script. The script will download and extract the files in the `runs/runs.beir` folder with the correct filenames. - -```bash -#!/bin/bash - -# Create the target directory if it doesn't exist -mkdir -p runs/runs.beir - -# Download the run files from Google Drive using their file IDs -wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1XVlVCDYQe3YjRzxplaeGbmW_0EFQCgm8' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.multifield.test.bm25 -wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1Z4rWlNgmXebMf1ardfiDg_4KIZImjqxt' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.splade-pp-ed.test.splade-pp-ed-cached -wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1fExxJHkPPNCdtptKqWTbcsH0Ql0PnPqS' -O runs/runs.beir/run.inverted.beir-v1.0.0-robust04.flat.test.bm25 -``` ## Perform two regression runs for test fusion-regression-bge-flat-int8-robust04-2 One could generate the runs necessary for test fusion-regression-bge-flat-int8-robust04-2 following @@ -30,9 +11,5 @@ One could generate the runs necessary for test fusion-regression-bge-flat-int8-r ## Run fuse-regression script with two yaml tests ```bash -python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-robust04-3 - -python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-robust04.yaml-2 - python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-int8-robust04-2 ``` \ No newline at end of file From db12c7912778afdb40181fd8a663d1762258c941 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 30 Sep 2024 23:33:31 +0000 Subject: [PATCH 28/33] add yaml file --- ...5.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml diff --git a/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml new file mode 100644 index 0000000000..4dc6ac5e78 --- /dev/null +++ b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml @@ -0,0 +1,71 @@ +--- +corpus: beir-v1.0.0-robust04 +corpus_path: collections/beir-v1.0.0/corpus/robust04/ + +metrics: + - metric: nDCG@10 + command: bin/trec_eval + params: -c -m ndcg_cut.10 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@100 + command: bin/trec_eval + params: -c -m recall.100 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + - metric: R@1000 + command: bin/trec_eval + params: -c -m recall.1000 + separator: "\t" + parse_index: 2 + metric_precision: 4 + can_combine: false + +topic_reader: TsvString +topics: + - name: "BEIR (v1.0.0): Robust04" + id: test + path: topics.beir-v1.0.0-robust04.test.tsv.gz + qrel: qrels.beir-v1.0.0-robust04.test.txt + +# Fusion Regression Test Configuration +runs: + - runs/run.beir-v1.0.0-robust04.flat.bm25.topics.beir-v1.0.0-robust04.test.txt + - runs/run.beir-v1.0.0-robust04.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + +methods: + - name: rrf + k: 1000 + depth: 1000 + rrf_k: 60 + output: runs/fuse/run.beir-v1.0.0-robust04.flat.bm25.fuse-rrf.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + results: + nDCG@10: + - 0.5070 + R@100: + - 0.4465 + R@1000: + - 0.7219 + - name: average + output: runs/fuse/run.beir-v1.0.0-robust04.flat.bm25.fuse-average.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + results: + nDCG@10: + - 0.4324 + R@100: + - 0.3963 + R@1000: + - 0.6345 + - name: interpolation + alpha: 0.5 + output: runs/fuse/run.beir-v1.0.0-robust04.flat.bm25.fuse-interpolation.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + results: + nDCG@10: + - 0.4324 + R@100: + - 0.3963 + R@1000: + - 0.6345 \ No newline at end of file From 9a419aa57c82e955646227cd9bcbc4b1e12e5b48 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Mon, 30 Sep 2024 23:34:48 +0000 Subject: [PATCH 29/33] removed old yaml --- ...n-regression-bge-flat-int8-robust04-2.yaml | 79 ------------------- 1 file changed, 79 deletions(-) delete mode 100644 src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml diff --git a/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml b/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml deleted file mode 100644 index 9c9513ee2b..0000000000 --- a/src/main/resources/fuse_regression/fusion-regression-bge-flat-int8-robust04-2.yaml +++ /dev/null @@ -1,79 +0,0 @@ ---- -corpus: beir-v1.0.0-robust04.bge-base-en-v1.5 -corpus_path: collections/beir-v1.0.0/bge-base-en-v1.5/robust04 - -metrics: - - metric: nDCG@10 - command: bin/trec_eval - params: -c -m ndcg_cut.10 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - - metric: R@100 - command: bin/trec_eval - params: -c -m recall.100 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - - metric: R@1000 - command: bin/trec_eval - params: -c -m recall.1000 - separator: "\t" - parse_index: 2 - metric_precision: 4 - can_combine: false - -topic_reader: JsonStringVector -topics: - - name: "BEIR (v1.0.0): Robust04" - id: test - path: topics.beir-v1.0.0-robust04.test.bge-base-en-v1.5.jsonl.gz - qrel: qrels.beir-v1.0.0-robust04.test.txt - -# Fusion Regression Test Configuration -runs: - - runs/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached - - runs/run.flat.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-cached - -# yaml files that generated the runs -run_origins: - - src/main/resources/regression/beir-v1.0.0-robust04.bge-base-en-v1.5.flat-int8.cached.yaml - - src/main/resources/regression/beir-v1.0.0-robust04.bge-base-en-v1.5.flat.cached.yaml - -methods: - - name: rrf - k: 1000 - depth: 1000 - rrf_k: 60 - output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.rrf - results: - nDCG@10: - - 0.4467 - R@100: - - 0.3497 - R@1000: - - 0.5982 - - name: average - output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.average - results: - nDCG@10: - - 0.4463 - R@100: - - 0.3498 - R@1000: - - 0.5984 - - name: interpolation - alpha: 0.5 - output: runs/fuse/run.flat-int8.beir-v1.0.0-robust04.bge-base-en-v1.5.test.bge-flat-int8-cached.bge-flat-cached.fusion.interpolation - results: - nDCG@10: - - 0.4463 - R@100: - - 0.3498 - R@1000: - - 0.5984 - - - From d9cff545698cec53016b2b85c219e5859f0ec92d Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Tue, 1 Oct 2024 00:17:17 +0000 Subject: [PATCH 30/33] changed output naming --- ...ust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml index 4dc6ac5e78..e1ee5ac8ff 100644 --- a/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml +++ b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml @@ -42,7 +42,7 @@ methods: k: 1000 depth: 1000 rrf_k: 60 - output: runs/fuse/run.beir-v1.0.0-robust04.flat.bm25.fuse-rrf.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + output: runs/runs.fuse.rrf.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt results: nDCG@10: - 0.5070 @@ -51,7 +51,7 @@ methods: R@1000: - 0.7219 - name: average - output: runs/fuse/run.beir-v1.0.0-robust04.flat.bm25.fuse-average.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + output: runs/runs.fuse.avg.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt results: nDCG@10: - 0.4324 @@ -61,7 +61,7 @@ methods: - 0.6345 - name: interpolation alpha: 0.5 - output: runs/fuse/run.beir-v1.0.0-robust04.flat.bm25.fuse-interpolation.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + output: runs/runs.fuse.interp.beir-v1.0.0-robust04.flat.bm25.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt results: nDCG@10: - 0.4324 From 6491adb9cce6bb632dadfde20092134f61296308 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Fri, 8 Nov 2024 01:28:43 +0000 Subject: [PATCH 31/33] improved regression script and yaml information --- src/main/python/run_fusion_regression.py | 10 +++++++++- ....flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml | 9 +++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/main/python/run_fusion_regression.py b/src/main/python/run_fusion_regression.py index bdca48a6c9..2e9e8fd124 100644 --- a/src/main/python/run_fusion_regression.py +++ b/src/main/python/run_fusion_regression.py @@ -59,7 +59,7 @@ def construct_fusion_commands(yaml_data: dict) -> list: return [ [ FUSE_COMMAND, - '-runs', ' '.join([run for run in yaml_data['runs']]), + '-runs', ' '.join(run['file'] for run in yaml_data['runs']), '-output', method.get('output'), '-method', method.get('name', 'average'), '-k', str(method.get('k', 1000)), @@ -166,6 +166,14 @@ def evaluate_and_verify(yaml_data: dict, dry_run: bool): logger.error(f"Failed to load configuration file: {e}") exit(1) + # Check existence of run files + for run in yaml_data['runs']: + if not os.path.exists(run['file']): + logger.error(f"Run file {run['file']} does not exist. Please run the dependent regressions first, recorded in the fusion yaml file.") + exit(1) + + + # Construct the fusion command fusion_commands = construct_fusion_commands(yaml_data) diff --git a/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml index e1ee5ac8ff..683a0a9465 100644 --- a/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml +++ b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml @@ -34,8 +34,13 @@ topics: # Fusion Regression Test Configuration runs: - - runs/run.beir-v1.0.0-robust04.flat.bm25.topics.beir-v1.0.0-robust04.test.txt - - runs/run.beir-v1.0.0-robust04.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + - name: flat-bm25 + dependency: beir-v1.0.0-robust04.flat.yaml + file: runs/run.beir-v1.0.0-robust04.flat.bm25.topics.beir-v1.0.0-robust04.test.txt + - name: bge-flat-onnx + dependency: beir-v1.0.0-robust04.bge-base-en-v1.5.flat.onnx.yaml + file: runs/run.beir-v1.0.0-robust04.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt + methods: - name: rrf From af1b5c820466baf225635bec8299c2fca5b809da Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Fri, 8 Nov 2024 01:32:37 +0000 Subject: [PATCH 32/33] removed abundant file --- docs/fuse-regressions/test-setup.md | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 docs/fuse-regressions/test-setup.md diff --git a/docs/fuse-regressions/test-setup.md b/docs/fuse-regressions/test-setup.md deleted file mode 100644 index 9b42b89fe4..0000000000 --- a/docs/fuse-regressions/test-setup.md +++ /dev/null @@ -1,15 +0,0 @@ -# Fusion Regression Test Setup - -This document provides instructions for setting up and downloading the necessary run files to perform fusion regression tests. - - -## Perform two regression runs for test fusion-regression-bge-flat-int8-robust04-2 - -One could generate the runs necessary for test fusion-regression-bge-flat-int8-robust04-2 following -- https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat-int8.cached.md -- https://github.com/castorini/anserini/blob/master/docs/regressions/regressions-beir-v1.0.0-robust04.bge-base-en-v1.5.flat.cached.md - -## Run fuse-regression script with two yaml tests -```bash -python src/main/python/run_fusion_regression.py --regression fusion-regression-bge-flat-int8-robust04-2 -``` \ No newline at end of file From 149b2983008792ab2718d44a5eaf1071eb9812a7 Mon Sep 17 00:00:00 2001 From: Stefan Min Date: Fri, 8 Nov 2024 01:34:37 +0000 Subject: [PATCH 33/33] resolved formatting issues --- src/main/python/run_fusion_regression.py | 4 +--- ...obust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/main/python/run_fusion_regression.py b/src/main/python/run_fusion_regression.py index 2e9e8fd124..9c5a3c7ecd 100644 --- a/src/main/python/run_fusion_regression.py +++ b/src/main/python/run_fusion_regression.py @@ -172,8 +172,6 @@ def evaluate_and_verify(yaml_data: dict, dry_run: bool): logger.error(f"Run file {run['file']} does not exist. Please run the dependent regressions first, recorded in the fusion yaml file.") exit(1) - - # Construct the fusion command fusion_commands = construct_fusion_commands(yaml_data) @@ -186,4 +184,4 @@ def evaluate_and_verify(yaml_data: dict, dry_run: bool): # Evaluate and verify results evaluate_and_verify(yaml_data, args.dry_run) - logger.info(f"Total execution time: {time.time() - start_time:.2f} seconds") + logger.info(f"Total execution time: {time.time() - start_time:.2f} seconds") \ No newline at end of file diff --git a/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml index 683a0a9465..9d7c68398e 100644 --- a/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml +++ b/src/main/resources/fuse_regression/beir-v1.0.0-robust04.flat.bm25.fuse.bge-base-en-v1.5.bge-flat-onnx.yaml @@ -32,7 +32,7 @@ topics: path: topics.beir-v1.0.0-robust04.test.tsv.gz qrel: qrels.beir-v1.0.0-robust04.test.txt -# Fusion Regression Test Configuration +# Run dependencies for fusion runs: - name: flat-bm25 dependency: beir-v1.0.0-robust04.flat.yaml @@ -41,7 +41,6 @@ runs: dependency: beir-v1.0.0-robust04.bge-base-en-v1.5.flat.onnx.yaml file: runs/run.beir-v1.0.0-robust04.bge-base-en-v1.5.bge-flat-onnx.topics.beir-v1.0.0-robust04.test.txt - methods: - name: rrf k: 1000