Skip to content

Commit

Permalink
Add Recall Tests (opensearch-project#251)
Browse files Browse the repository at this point in the history
* Add Recall Tests

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>

* Calculate Recall using document ids and other minor changes

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda authored Jan 15, 2022
1 parent cec015c commit 814932d
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 14 deletions.
83 changes: 83 additions & 0 deletions src/test/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.commons.lang.StringUtils;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.knn.index.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.indices.ModelDao;
Expand Down Expand Up @@ -782,6 +783,88 @@ public void bulkIngestRandomVectors(String indexName, String fieldName, int numV

}

//Method that adds multiple documents into the index using Bulk API
//
// index        - name of the target index
// fieldName    - name of the knn_vector field in each document
// indexVectors - one vector per document; indexVectors[i] is indexed with doc id (i + 1)
// docCount     - number of documents to ingest (must not exceed indexVectors.length)
public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException {
    Request request = new Request(
        "POST",
        "/_bulk"
    );

    // Refresh immediately so the documents are visible to subsequent searches.
    request.addParameter("refresh", "true");
    StringBuilder sb = new StringBuilder();

    // Build the NDJSON bulk body: one action line plus one source line per document.
    // Document ids are 1-based so they line up with the ids generated by
    // TestUtils.computeGroundTruthValues (String.valueOf(j + 1)).
    for (int i = 0; i < docCount; i++) {
        sb.append("{ \"index\" : { \"_index\" : \"")
            .append(index)
            .append("\", \"_id\" : \"")
            .append(i + 1)
            .append("\" } }\n")
            .append("{ \"")
            .append(fieldName)
            .append("\" : ")
            .append(Arrays.toString(indexVectors[i]))
            .append(" }\n");
    }

    request.setJsonEntity(sb.toString());

    Response response = client().performRequest(request);
    // Fix: JUnit assertEquals takes the EXPECTED value first, actual second;
    // the original had them reversed, which garbles the failure message.
    assertEquals(200, response.getStatusLine().getStatusCode());
}

//Method that returns index vectors of the documents that were added before into the index
public float[][] getIndexVectorsFromIndex(String testIndex, String testField, int docCount, int dimensions) throws IOException {
    float[][] vectors = new float[docCount][dimensions];

    // match_all retrieves every document; "size" caps the hits at docCount.
    QueryBuilder qb = new MatchAllQueryBuilder();

    Request request = new Request("POST", "/" + testIndex + "/_search");
    request.addParameter("size", Integer.toString(docCount));

    // Serialize the query into the request body.
    XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
    builder.field("query", qb);
    builder.endObject();
    request.setJsonEntity(Strings.toString(builder));

    Response response = client().performRequest(request);
    assertEquals(request.getEndpoint() + ": failed", RestStatus.OK,
        RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField);

    // Unbox each returned Float vector into a primitive float[] row.
    int row = 0;
    for (KNNResult result : results) {
        vectors[row++] = Floats.toArray(Arrays.stream(result.getVector()).collect(Collectors.toList()));
    }

    return vectors;
}

// Method that performs bulk search for multiple queries and stores the resulting documents ids into list
//
// Runs one k-NN query per row of queryVectors and returns, per query, the
// list of document ids of the k hits. Asserts every query returns exactly k results.
public List<List<String>> bulkSearch(String testIndex, String testField, float[][] queryVectors, int k) throws IOException {
    List<List<String>> searchResults = new ArrayList<>();
    List<String> kVectors;

    for (int i = 0; i < queryVectors.length; i++) {
        KNNQueryBuilder knnQueryBuilderRecall = new KNNQueryBuilder(testField, queryVectors[i], k);
        Response respRecall = searchKNNIndex(testIndex, knnQueryBuilderRecall, k);
        List<KNNResult> resultsRecall = parseSearchResponse(EntityUtils.toString(respRecall.getEntity()), testField);

        // Fix: JUnit assertEquals takes the EXPECTED value (k) first, actual second;
        // the original had them reversed, which garbles the failure message.
        assertEquals(k, resultsRecall.size());
        kVectors = new ArrayList<>();
        for (KNNResult result : resultsRecall) {
            kVectors.add(result.getDocId());
        }
        searchResults.add(kVectors);
    }

    return searchResults;
}

/**
* Method that call train api and produces a trained model
*
Expand Down
143 changes: 142 additions & 1 deletion src/test/java/org/opensearch/knn/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,51 @@
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.plugin.script.KNNScoringUtil;
import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier;
import java.util.Comparator;
import java.util.Random;
import java.util.Set;
import java.util.PriorityQueue;
import java.util.ArrayList;
import java.util.List;
import java.util.HashSet;
import java.util.Map;
import static org.apache.lucene.util.LuceneTestCase.random;

/**
 * Immutable (distance, document id) pair used while building ground-truth
 * nearest-neighbor results. Fields stay public because DistComparator reads
 * them directly, but they are now final: a pair is never mutated after
 * construction, so the compiler should enforce that.
 */
class DistVector {
    public final float dist;
    public final String docID;

    public DistVector(float dist, String docID) {
        this.dist = dist;
        this.docID = docID;
    }

    public String getDocID() {
        return docID;
    }

    public float getDist() {
        return dist;
    }
}

/**
 * Orders DistVector elements by DESCENDING distance (largest first), so a
 * PriorityQueue built with this comparator behaves as a max-heap: peek()/poll()
 * return the entry with the largest distance, letting callers keep only the
 * k smallest distances seen so far.
 */
class DistComparator implements Comparator<DistVector> {

    @Override
    public int compare(DistVector d1, DistVector d2) {
        // Float.compare provides a consistent total order (correct for NaN and
        // -0.0, which the original manual </> comparison mishandled); the
        // descending order is obtained by swapping the arguments.
        return Float.compare(d2.dist, d1.dist);
    }
}

public class TestUtils {
public static final String KNN_BWC_PREFIX = "knn-bwc-";
Expand All @@ -32,6 +70,109 @@ public class TestUtils {
public static final String BWCSUITE_ROUND = "tests.rest.bwcsuite_round";
public static final String TEST_CLUSTER_NAME = "tests.clustername";

// Generating vectors using random function with a seed which makes these vectors standard and generate same vectors for each run.
public static float[][] randomlyGenerateStandardVectors(int numVectors, int dimensions, int seed) {
    Random rand = new Random(seed);
    float[][] standardVectors = new float[numVectors][dimensions];

    // Fill row by row; the fixed seed makes the whole matrix reproducible across runs.
    for (int row = 0; row < numVectors; row++) {
        for (int col = 0; col < dimensions; col++) {
            standardVectors[row][col] = rand.nextFloat();
        }
    }
    return standardVectors;
}

// Generates numVectors x dimensions vectors from the shared test random() source,
// so values vary between runs but remain reproducible under the test seed.
public static float[][] generateRandomVectors(int numVectors, int dimensions) {
    float[][] randomVectors = new float[numVectors][dimensions];

    for (int row = 0; row < numVectors; row++) {
        for (int col = 0; col < dimensions; col++) {
            randomVectors[row][col] = random().nextFloat();
        }
    }
    return randomVectors;
}

/*
 * Here, for a given space type we will compute the 'k' shortest distances among all the index vectors for each and every query vector using a priority queue and
 * their document ids are stored. These document ids are later used while calculating Recall value to compare with the document ids of 'k' results obtained for
 * each and every search query performed.
 *
 * Document ids are 1-based (String.valueOf(j + 1)) to match the ids assigned by
 * KNNRestTestCase.bulkAddKnnDocs. Only the l2 space type is currently supported.
 */
public static List<Set<String>> computeGroundTruthValues(float[][] indexVectors, float[][] queryVectors, SpaceType spaceType, int k) {
    ArrayList<Set<String>> groundTruthValues = new ArrayList<>();

    for (int i = 0; i < queryVectors.length; i++) {
        // Max-heap on distance (DistComparator orders descending), capped at k
        // entries, so it always holds the k smallest distances seen so far.
        PriorityQueue<DistVector> pq = new PriorityQueue<>(k, new DistComparator());
        for (int j = 0; j < indexVectors.length; j++) {
            // Fix: compute the distance fresh for every (query, index) pair. The
            // original declared `dist` once outside both loops and only assigned
            // it on the l2 branch, so any other space type silently reused a
            // stale value and produced meaningless ground truth. Fail fast instead.
            float dist;
            if (spaceType != null && "l2".equals(spaceType.getValue())) {
                dist = KNNScoringUtil.l2Squared(queryVectors[i], indexVectors[j]);
            } else {
                throw new IllegalArgumentException("Unsupported space type for ground truth computation: " + spaceType);
            }

            if (pq.size() < k) {
                pq.add(new DistVector(dist, String.valueOf(j + 1)));
            } else if (pq.peek().getDist() > dist) {
                // New distance is smaller than the current worst of the k kept: swap it in.
                pq.poll();
                pq.add(new DistVector(dist, String.valueOf(j + 1)));
            }
        }

        Set<String> docIds = new HashSet<>();
        while (!pq.isEmpty()) {
            docIds.add(pq.poll().getDocID());
        }

        groundTruthValues.add(docIds);
    }

    return groundTruthValues;
}

// Query vectors. When standard (reproducible) vectors are requested, the seed
// is (docCount + 1) so it never collides with the seeds used for index vectors.
public static float[][] getQueryVectors(int queryCount, int dimensions, int docCount, boolean isStandard) {
    return isStandard
        ? randomlyGenerateStandardVectors(queryCount, dimensions, docCount + 1)
        : generateRandomVectors(queryCount, dimensions);
}

// Index vectors. The fixed seed of 1 keeps the "standard" data set identical
// across runs; otherwise vectors come from the shared test random source.
public static float[][] getIndexVectors(int docCount, int dimensions, boolean isStandard) {
    return isStandard
        ? randomlyGenerateStandardVectors(docCount, dimensions, 1)
        : generateRandomVectors(docCount, dimensions);
}

/*
 * Recall is the number of relevant documents retrieved by a search divided by the total number of existing relevant documents.
 * We are similarly calculating recall by verifying number of relevant documents obtained in the search results by comparing with
 * groundTruthValues and then dividing by 'k'
 *
 * searchResults     - per-query lists of document ids returned by the search
 * groundTruthValues - per-query sets of the true k-nearest document ids (parallel to searchResults)
 * k                 - number of neighbors requested per query
 * Returns the mean per-query recall in [0, 1]; 0.0 for empty searchResults
 * (the original reduce(...).get() threw NoSuchElementException in that case).
 */
public static double calculateRecallValue(List<List<String>> searchResults, List<Set<String>> groundTruthValues, int k) {
    // Fix: guard the empty case instead of calling Optional.get() on an empty
    // stream reduction (and dividing by zero).
    if (searchResults.isEmpty()) {
        return 0.0;
    }

    double totalRecall = 0.0;
    for (int i = 0; i < searchResults.size(); i++) {
        Set<String> truth = groundTruthValues.get(i);
        int hits = 0;
        for (String docId : searchResults.get(i)) {
            if (truth.contains(docId)) {
                hits++;
            }
        }
        // Per-query recall: fraction of the k requested neighbors that are true neighbors.
        totalRecall += (double) hits / k;
    }

    return totalRecall / searchResults.size();
}

/**
* Class to read in some test data from text files
*/
Expand Down
Loading

0 comments on commit 814932d

Please sign in to comment.