Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Recall Tests #251

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/test/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.commons.lang.StringUtils;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.knn.index.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.indices.ModelDao;
Expand Down Expand Up @@ -782,6 +783,88 @@ public void bulkIngestRandomVectors(String indexName, String fieldName, int numV

}

/**
 * Adds {@code docCount} documents to {@code index} in a single Bulk API call.
 *
 * @param index        name of the target index
 * @param fieldName    name of the knn_vector field to populate
 * @param indexVectors vectors to ingest; {@code indexVectors[i]} becomes document id {@code i + 1}
 * @param docCount     number of vectors from {@code indexVectors} to ingest
 * @throws IOException if the bulk request fails
 */
public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVectors, int docCount) throws IOException {
    Request request = new Request("POST", "/_bulk");

    // Refresh so the documents are searchable as soon as this method returns.
    request.addParameter("refresh", "true");
    StringBuilder sb = new StringBuilder();

    for (int i = 0; i < docCount; i++) {
        // Document ids are 1-based, matching the ids produced by TestUtils.computeGroundTruthValues.
        sb.append("{ \"index\" : { \"_index\" : \"")
            .append(index)
            .append("\", \"_id\" : \"")
            .append(i + 1)
            .append("\" } }\n")
            .append("{ \"")
            .append(fieldName)
            .append("\" : ")
            .append(Arrays.toString(indexVectors[i]))
            .append(" }\n");
    }

    request.setJsonEntity(sb.toString());

    Response response = client().performRequest(request);
    // assertEquals takes (expected, actual); use RestStatus for consistency with the other
    // request helpers in this class. The original call had the arguments swapped.
    assertEquals(RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));
}

/**
 * Fetches back the vectors stored in {@code testIndex} by running a match-all search
 * and parsing the {@code testField} values out of the hits.
 *
 * @param testIndex  index to read from
 * @param testField  knn_vector field whose values are extracted
 * @param docCount   number of documents expected in the index (also used as the search size)
 * @param dimensions dimensionality of each stored vector
 * @return a {@code docCount x dimensions} array of the indexed vectors, in hit order
 * @throws IOException if the search request fails
 */
public float[][] getIndexVectorsFromIndex(String testIndex, String testField, int docCount, int dimensions) throws IOException {
    Request searchRequest = new Request("POST", "/" + testIndex + "/_search");
    // Ask for every document so all indexed vectors come back in one response.
    searchRequest.addParameter("size", Integer.toString(docCount));

    QueryBuilder matchAll = new MatchAllQueryBuilder();
    XContentBuilder body = XContentFactory.jsonBuilder()
        .startObject()
        .field("query", matchAll)
        .endObject();
    searchRequest.setJsonEntity(Strings.toString(body));

    Response response = client().performRequest(searchRequest);
    assertEquals(searchRequest.getEndpoint() + ": failed", RestStatus.OK,
        RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    List<KNNResult> parsedHits = parseSearchResponse(EntityUtils.toString(response.getEntity()), testField);

    float[][] vectors = new float[docCount][dimensions];
    int row = 0;
    for (KNNResult hit : parsedHits) {
        vectors[row++] = Floats.toArray(Arrays.stream(hit.getVector()).collect(Collectors.toList()));
    }

    return vectors;
}

/**
 * Runs one k-NN search per query vector and collects the returned document ids.
 *
 * @param testIndex    index to search
 * @param testField    knn_vector field to query against
 * @param queryVectors one query vector per search to perform
 * @param k            number of nearest neighbors requested per query
 * @return one list of document ids (size k, in result order) per query vector
 * @throws IOException if any search request fails
 */
public List<List<String>> bulkSearch(String testIndex, String testField, float[][] queryVectors, int k) throws IOException {
    List<List<String>> searchResults = new ArrayList<>();

    for (int i = 0; i < queryVectors.length; i++) {
        KNNQueryBuilder knnQueryBuilderRecall = new KNNQueryBuilder(testField, queryVectors[i], k);
        Response respRecall = searchKNNIndex(testIndex, knnQueryBuilderRecall, k);
        List<KNNResult> resultsRecall = parseSearchResponse(EntityUtils.toString(respRecall.getEntity()), testField);

        // assertEquals takes (expected, actual); the original call had the arguments swapped,
        // which would produce a misleading failure message.
        assertEquals(k, resultsRecall.size());

        List<String> kVectors = new ArrayList<>();
        for (KNNResult result : resultsRecall) {
            kVectors.add(result.getDocId());
        }
        searchResults.add(kVectors);
    }

    return searchResults;
}

/**
* Method that call train api and produces a trained model
*
Expand Down
143 changes: 142 additions & 1 deletion src/test/java/org/opensearch/knn/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,51 @@
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.plugin.script.KNNScoringUtil;
import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier;
import java.util.Comparator;
import java.util.Random;
import java.util.Set;
import java.util.PriorityQueue;
import java.util.ArrayList;
import java.util.List;
import java.util.HashSet;
import java.util.Map;
import static org.apache.lucene.util.LuceneTestCase.random;

/**
 * Immutable pair of a distance value and the id of the document it was computed for.
 * Used as the element type of the priority queue in TestUtils.computeGroundTruthValues.
 * Fields stay public because DistComparator reads them directly, but they are final
 * since nothing ever mutates them after construction.
 */
class DistVector {
    public final float dist;
    public final String docID;

    public DistVector(float dist, String docID) {
        this.dist = dist;
        this.docID = docID;
    }

    public String getDocID() {
        return docID;
    }

    public float getDist() {
        return dist;
    }
}

/**
 * Orders DistVectors so that the LARGEST distance sorts first. A PriorityQueue built
 * with this comparator therefore keeps the maximum distance at its head, which is what
 * computeGroundTruthValues relies on to evict the worst candidate while tracking the
 * k smallest distances.
 */
class DistComparator implements Comparator<DistVector> {

    @Override
    public int compare(DistVector d1, DistVector d2) {
        // Float.compare (with the operands reversed for descending order) gives a proper
        // total order. The original hand-rolled <,> comparison returned 0 whenever either
        // value was NaN, violating the Comparator contract.
        return Float.compare(d2.dist, d1.dist);
    }
}

public class TestUtils {
public static final String KNN_BWC_PREFIX = "knn-bwc-";
Expand All @@ -32,6 +70,109 @@ public class TestUtils {
public static final String BWCSUITE_ROUND = "tests.rest.bwcsuite_round";
public static final String TEST_CLUSTER_NAME = "tests.clustername";

/**
 * Deterministically generates vectors from a seeded {@link Random}, so every run with
 * the same arguments produces exactly the same vectors.
 *
 * @param numVectors number of vectors to generate
 * @param dimensions dimensionality of each vector
 * @param seed       seed for the random source; identical seeds yield identical output
 * @return a {@code numVectors x dimensions} array of pseudo-random floats in [0, 1)
 */
public static float[][] randomlyGenerateStandardVectors(int numVectors, int dimensions, int seed) {
    final Random rng = new Random(seed);
    final float[][] result = new float[numVectors][dimensions];

    for (int row = 0; row < numVectors; row++) {
        for (int col = 0; col < dimensions; col++) {
            result[row][col] = rng.nextFloat();
        }
    }
    return result;
}

/**
 * Generates vectors from the test framework's shared random source; unlike
 * {@code randomlyGenerateStandardVectors}, the output is not reproducible across runs.
 *
 * @param numVectors number of vectors to generate
 * @param dimensions dimensionality of each vector
 * @return a {@code numVectors x dimensions} array of random floats in [0, 1)
 */
public static float[][] generateRandomVectors(int numVectors, int dimensions) {
    float[][] vectors = new float[numVectors][dimensions];

    for (int row = 0; row < numVectors; row++) {
        for (int col = 0; col < dimensions; col++) {
            vectors[row][col] = random().nextFloat();
        }
    }
    return vectors;
}

/*
 * For each query vector, computes the 'k' smallest distances against all index vectors
 * using a bounded priority queue, and records the corresponding document ids. These ids
 * are the ground truth later compared against the ids returned by actual k-NN searches
 * when computing recall.
 *
 * Only the l2 space type is currently supported; any other value (including null) is
 * rejected. The original code silently left the distance at 0.0f (or a stale value)
 * for unsupported space types, producing meaningless ground truth.
 */
public static List<Set<String>> computeGroundTruthValues(float[][] indexVectors, float[][] queryVectors, SpaceType spaceType, int k) {
    if (spaceType == null || !"l2".equals(spaceType.getValue())) {
        throw new IllegalArgumentException("Unsupported space type for ground truth computation: " + spaceType);
    }

    List<Set<String>> groundTruthValues = new ArrayList<>();

    for (float[] queryVector : queryVectors) {
        // DistComparator keeps the LARGEST distance at the queue head, so the head is the
        // current worst of the k best candidates and can be evicted cheaply.
        PriorityQueue<DistVector> pq = new PriorityQueue<>(k, new DistComparator());

        for (int j = 0; j < indexVectors.length; j++) {
            float dist = KNNScoringUtil.l2Squared(queryVector, indexVectors[j]);

            // Document ids are 1-based, matching the ids assigned by bulkAddKnnDocs.
            if (pq.size() < k) {
                pq.add(new DistVector(dist, String.valueOf(j + 1)));
            } else if (pq.peek().getDist() > dist) {
                pq.poll();
                pq.add(new DistVector(dist, String.valueOf(j + 1)));
            }
        }

        Set<String> docIds = new HashSet<>();
        while (!pq.isEmpty()) {
            docIds.add(pq.poll().getDocID());
        }
        groundTruthValues.add(docIds);
    }

    return groundTruthValues;
}

/**
 * Produces query vectors, either deterministically (seeded with {@code docCount + 1} so
 * the seed differs from the one used for index vectors) or fully at random.
 *
 * @param queryCount number of query vectors to generate
 * @param dimensions dimensionality of each vector
 * @param docCount   number of indexed documents; offsets the deterministic seed
 * @param isStandard true for reproducible (seeded) vectors, false for random ones
 * @return a {@code queryCount x dimensions} array of query vectors
 */
public static float[][] getQueryVectors(int queryCount, int dimensions, int docCount, boolean isStandard) {
    return isStandard
        ? randomlyGenerateStandardVectors(queryCount, dimensions, docCount + 1)
        : generateRandomVectors(queryCount, dimensions);
}

/**
 * Produces vectors to be ingested into the index, either deterministically (fixed seed 1,
 * distinct from the query-vector seed) or fully at random.
 *
 * @param docCount   number of vectors to generate
 * @param dimensions dimensionality of each vector
 * @param isStandard true for reproducible (seeded) vectors, false for random ones
 * @return a {@code docCount x dimensions} array of index vectors
 */
public static float[][] getIndexVectors(int docCount, int dimensions, boolean isStandard) {
    return isStandard
        ? randomlyGenerateStandardVectors(docCount, dimensions, 1)
        : generateRandomVectors(docCount, dimensions);
}

/*
 * Recall is the number of relevant documents retrieved by a search divided by the total
 * number of existing relevant documents. Here each query's recall is the fraction of its
 * 'k' search results that also appear in that query's ground-truth set; the returned
 * value is the average recall over all queries.
 */
public static double calculateRecallValue(List<List<String>> searchResults, List<Set<String>> groundTruthValues, int k) {
    if (searchResults.isEmpty()) {
        // Fail fast with a clear message; the original reduce(...).get() would have thrown
        // an unhelpful NoSuchElementException here (unchecked Optional.get()).
        throw new IllegalArgumentException("searchResults must not be empty");
    }

    double recallSum = 0.0;
    for (int i = 0; i < searchResults.size(); i++) {
        Set<String> groundTruth = groundTruthValues.get(i);
        // Count how many of the returned doc ids are true nearest neighbors for this query.
        long relevantCount = searchResults.get(i).stream().filter(groundTruth::contains).count();
        recallSum += (double) relevantCount / k;
    }

    return recallSum / searchResults.size();
}

/**
* Class to read in some test data from text files
*/
Expand Down
Loading