Skip to content

Commit

Permalink
add RawTFSimilarity class (#13749)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpoerschke authored Sep 17, 2024
1 parent a4c79c8 commit a817426
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.similarities;

import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.TermStatistics;

/** Similarity that returns the raw TF as score. */
public class RawTFSimilarity extends Similarity {

/** Default constructor: parameter-free */
public RawTFSimilarity() {
super();
}

/** Primary constructor. */
public RawTFSimilarity(boolean discountOverlaps) {
super(discountOverlaps);
}

@Override
public SimScorer scorer(
float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return boost * freq;
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.similarities.ClassicSimilarity;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.RawTFSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
Expand Down Expand Up @@ -75,7 +74,7 @@ public void setUp() throws Exception {
searcher = newSearcher(reader, true, false);
searcher.setSimilarity(new ClassicSimilarity());
scorerSearcher = new ScorerIndexSearcher(reader);
scorerSearcher.setSimilarity(new CountingSimilarity());
scorerSearcher.setSimilarity(new RawTFSimilarity());
}

@Override
Expand Down Expand Up @@ -345,24 +344,4 @@ private static StringBuilder indent(final StringBuilder builder, final int inden
return builder;
}
}

// Similarity that just returns the frequency as the score
private static class CountingSimilarity extends Similarity {

@Override
public long computeNorm(FieldInvertState state) {
return 1;
}

@Override
public SimScorer scorer(
float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return freq;
}
};
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.RawTFSimilarity;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
Expand Down Expand Up @@ -67,7 +66,7 @@ public void setUp() throws Exception {
reader = writer.getReader();
writer.close();
searcher = newSearcher(reader);
searcher.setSimilarity(new TFSimilarity());
searcher.setSimilarity(new RawTFSimilarity());
}

static Document doc(String v1, String v2) {
Expand All @@ -93,26 +92,6 @@ public void tearDown() throws Exception {
super.tearDown();
}

// Similarity that returns the TF as score
private static class TFSimilarity extends Similarity {

@Override
public long computeNorm(FieldInvertState state) {
return 1; // we dont care
}

@Override
public SimScorer scorer(
float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
return new SimScorer() {
@Override
public float score(float freq, long norm) {
return freq;
}
};
}
}

public void testScorerGetChildren() throws Exception {
Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search.similarities;

import java.io.IOException;
import java.util.Random;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.search.similarities.BaseSimilarityTestCase;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.IOUtils;

public class TestRawTFSimilarity extends BaseSimilarityTestCase {

private Directory directory;
private IndexReader indexReader;
private IndexSearcher indexSearcher;

@Override
protected Similarity getSimilarity(Random random) {
return new RawTFSimilarity();
}

@Override
public void setUp() throws Exception {
super.setUp();
directory = newDirectory();
try (IndexWriter indexWriter = new IndexWriter(directory, newIndexWriterConfig())) {
final Document document1 = new Document();
final Document document2 = new Document();
final Document document3 = new Document();
document1.add(LuceneTestCase.newTextField("test", "one", Field.Store.YES));
document2.add(LuceneTestCase.newTextField("test", "two two", Field.Store.YES));
document3.add(LuceneTestCase.newTextField("test", "three three three", Field.Store.YES));
indexWriter.addDocument(document1);
indexWriter.addDocument(document2);
indexWriter.addDocument(document3);
indexWriter.commit();
}
indexReader = DirectoryReader.open(directory);
indexSearcher = newSearcher(indexReader);
indexSearcher.setSimilarity(new RawTFSimilarity());
}

@Override
public void tearDown() throws Exception {
IOUtils.close(indexReader, directory);
super.tearDown();
}

public void testOne() throws IOException {
implTest("one", 1f);
}

public void testTwo() throws IOException {
implTest("two", 2f);
}

public void testThree() throws IOException {
implTest("three", 3f);
}

private void implTest(String text, float expectedScore) throws IOException {
Query query = new TermQuery(new Term("test", text));
TopDocs topDocs = indexSearcher.search(query, 1);
assertEquals(1, topDocs.totalHits.value());
assertEquals(1, topDocs.scoreDocs.length);
assertEquals(expectedScore, topDocs.scoreDocs[0].score, 0.0);
}

public void testBoostQuery() throws IOException {
Query query = new TermQuery(new Term("test", "three"));
float boost = 14f;
TopDocs topDocs = indexSearcher.search(new BoostQuery(query, boost), 1);
assertEquals(1, topDocs.totalHits.value());
assertEquals(1, topDocs.scoreDocs.length);
assertEquals(42f, topDocs.scoreDocs[0].score, 0.0);
}
}

0 comments on commit a817426

Please sign in to comment.