diff --git a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
index 5f00631ad6028..a9cc288933ce4 100644
--- a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
+++ b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java
@@ -243,36 +243,91 @@ public String toString(String field) {
         return builder.toString();
     }
 
-    private volatile Term[] equalTerms = null;
+    /**
+     * Pairs a term with its boost so that query equality, hashing and the canonical
+     * sort order all take the boost into account.
+     */
+    private static final class TermAndBoost implements Comparable<TermAndBoost> {
+        private final Term term;
+        private final float boost;
 
-    private Term[] equalsTerms() {
-        if (terms.length == 1) {
-            return terms;
+        private TermAndBoost(Term term, float boost) {
+            this.term = term;
+            this.boost = boost;
+        }
+
+        @Override
+        public int compareTo(TermAndBoost other) {
+            int compareTo = term.compareTo(other.term);
+            if (compareTo == 0) {
+                compareTo = Float.compare(boost, other.boost);
+            }
+            return compareTo;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o instanceof TermAndBoost == false) {
+                return false;
+            }
+
+            TermAndBoost that = (TermAndBoost) o;
+            // Float.compare (rather than ==) keeps equals consistent with compareTo and with
+            // Float.hashCode for NaN and for 0.0f versus -0.0f
+            return term.equals(that.term) && Float.compare(boost, that.boost) == 0;
+        }
+
+        @Override
+        public int hashCode() {
+            return 31 * term.hashCode() + Float.hashCode(boost);
         }
-        if (equalTerms == null) {
+    }
+
+    private volatile TermAndBoost[] equalTermsAndBoosts = null;
+
+    /**
+     * Returns the (term, boost) pairs in sorted order, computing and caching them on first use.
+     */
+    private TermAndBoost[] equalsTermsAndBoosts() {
+        final TermAndBoost[] cached = equalTermsAndBoosts;
+        if (cached != null) {
+            return cached;
+        }
+        // populate a local array and publish it to the volatile field only once it is complete,
+        // so that concurrent callers never observe a partially filled or unsorted array
+        final TermAndBoost[] result = new TermAndBoost[terms.length];
+        for (int i = 0; i < terms.length; i++) {
+            final float boost = (boosts != null ? boosts[i] : 1f);
+            result[i] = new TermAndBoost(terms[i], boost);
+        }
+        if (result.length > 1) {
             // sort the terms to make sure equals and hashCode are consistent
             // this should be a very small cost and equivalent to a HashSet but less object creation
-            final Term[] t = new Term[terms.length];
-            System.arraycopy(terms, 0, t, 0, terms.length);
-            ArrayUtil.timSort(t);
-            equalTerms = t;
+            ArrayUtil.timSort(result);
         }
-        return equalTerms;
-
+        equalTermsAndBoosts = result;
+        return result;
     }
 
     @Override
     public boolean equals(Object o) {
-        if (this == o) return true;
-        if (sameClassAs(o) == false) return false;
+        if (this == o) {
+            return true;
+        }
+        if (sameClassAs(o) == false) {
+            return false;
+        }
 
         BlendedTermQuery that = (BlendedTermQuery) o;
-        return Arrays.equals(equalsTerms(), that.equalsTerms());
+        return Arrays.equals(equalsTermsAndBoosts(), that.equalsTermsAndBoosts());
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(classHash(), Arrays.hashCode(equalsTerms()));
+        return Objects.hash(classHash(), Arrays.hashCode(equalsTermsAndBoosts()));
     }
 
     public static BlendedTermQuery dismaxBlendedQuery(Term[] terms, final float tieBreakerMultiplier) {
diff --git a/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java b/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
index b7b2107320e39..1513817e37a5f 100644
--- a/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
+++ b/server/src/test/java/org/apache/lucene/queries/BlendedTermQueryTests.java
@@ -42,6 +42,9 @@
 import org.apache.lucene.search.similarities.BM25Similarity;
 import org.apache.lucene.store.Directory;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.EqualsHashCodeTestUtils;
+import org.elasticsearch.test.EqualsHashCodeTestUtils.CopyFunction;
+import org.elasticsearch.test.EqualsHashCodeTestUtils.MutateFunction;
 
 import java.io.IOException;
 import java.util.Arrays;
@@ -254,4 +257,72 @@ public void testMinTTF() throws IOException {
         w.close();
         dir.close();
     }
+
+    public void testEqualsAndHash() {
+        String[] fields = new String[1 + random().nextInt(10)];
+        for (int i = 0; i < fields.length; i++) {
+            fields[i] = randomRealisticUnicodeOfLengthBetween(1, 10);
+        }
+        String term = randomRealisticUnicodeOfLengthBetween(1, 10);
+        Term[] terms = toTerms(fields, term);
+        float tieBreaker = randomFloat();
+        final float[] boosts;
+        if (randomBoolean()) {
+            boosts = new float[terms.length];
+            for (int i = 0; i < terms.length; i++) {
+                boosts[i] = randomFloat();
+            }
+        } else {
+            boosts = null;
+        }
+
+        BlendedTermQuery original = BlendedTermQuery.dismaxBlendedQuery(terms, boosts, tieBreaker);
+        CopyFunction<BlendedTermQuery> copyFunction = org -> {
+            Term[] termsCopy = new Term[terms.length];
+            System.arraycopy(terms, 0, termsCopy, 0, terms.length);
+
+            float[] boostsCopy = null;
+            if (boosts != null) {
+                boostsCopy = new float[boosts.length];
+                System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
+            }
+            if (randomBoolean() && terms.length > 1) {
+                // if we swap two elements, the resulting query should still be regarded as equal
+                int swapPos = randomIntBetween(1, terms.length - 1);
+
+                Term swpTerm = termsCopy[0];
+                termsCopy[0] = termsCopy[swapPos];
+                termsCopy[swapPos] = swpTerm;
+
+                if (boosts != null) {
+                    float swpBoost = boostsCopy[0];
+                    boostsCopy[0] = boostsCopy[swapPos];
+                    boostsCopy[swapPos] = swpBoost;
+                }
+            }
+            return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boostsCopy, tieBreaker);
+        };
+        MutateFunction<BlendedTermQuery> mutateFunction = org -> {
+            if (randomBoolean()) {
+                Term[] termsCopy = new Term[terms.length];
+                System.arraycopy(terms, 0, termsCopy, 0, terms.length);
+                termsCopy[randomIntBetween(0, terms.length - 1)] = new Term(randomAlphaOfLength(10), randomAlphaOfLength(10));
+                return BlendedTermQuery.dismaxBlendedQuery(termsCopy, boosts, tieBreaker);
+            } else {
+                float[] boostsCopy = null;
+                if (boosts != null) {
+                    boostsCopy = new float[boosts.length];
+                    System.arraycopy(boosts, 0, boostsCopy, 0, terms.length);
+                    boostsCopy[randomIntBetween(0, terms.length - 1)] = randomFloat();
+                } else {
+                    boostsCopy = new float[terms.length];
+                    for (int i = 0; i < terms.length; i++) {
+                        boostsCopy[i] = randomFloat();
+                    }
+                }
+                return BlendedTermQuery.dismaxBlendedQuery(terms, boostsCopy, tieBreaker);
+            }
+        };
+        EqualsHashCodeTestUtils.checkEqualsAndHashCode(original, copyFunction, mutateFunction);
+    }
 }