Skip to content

Commit

Permalink
Revert "Implemented Brute Force Search class"
Browse files Browse the repository at this point in the history
This reverts commit b6ec064.
  • Loading branch information
punAhuja committed Nov 14, 2024
1 parent b6ec064 commit b57de36
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 175 deletions.
4 changes: 2 additions & 2 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/ExampleApp.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@ public static void main(String[] args) throws Throwable {

// Search
SearchResult rslt = index.search(query);
System.out.println(rslt.getAllResults());
System.out.println(rslt.getResults());

// Search from de-serialized index
SearchResult rslt2 = index2.search(query);
System.out.println(rslt2.getAllResults());
System.out.println(rslt2.getResults());

}
}

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,19 @@ private void init() throws Throwable {
File wd = new File(System.getProperty("user.dir"));
bridge = SymbolLookup.libraryLookup(wd.getParent() + "/internal/libcuvs_java.so", arena);

indexMH = linker.downcallHandle(bridge.find("build_index").get(),
indexMH = linker.downcallHandle(bridge.findOrThrow("build_index"),
FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, linker.canonicalLayouts().get("long"),
linker.canonicalLayouts().get("long"), ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS));

searchMH = linker.downcallHandle(bridge.find("search_index").get(),
searchMH = linker.downcallHandle(bridge.findOrThrow("search_index"),
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, linker.canonicalLayouts().get("int"),
linker.canonicalLayouts().get("long"), linker.canonicalLayouts().get("long"), ValueLayout.ADDRESS,
ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS));

serializeMH = linker.downcallHandle(bridge.find("serialize_index").get(),
serializeMH = linker.downcallHandle(bridge.findOrThrow("serialize_index"),
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS));

deserializeMH = linker.downcallHandle(bridge.find("deserialize_index").get(),
deserializeMH = linker.downcallHandle(bridge.findOrThrow("deserialize_index"),
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public CuVSResources() throws Throwable {
File wd = new File(System.getProperty("user.dir"));
bridge = SymbolLookup.libraryLookup(wd.getParent() + "/internal/libcuvs_java.so", arena);

cresMH = linker.downcallHandle(bridge.find("create_resource").get(),
cresMH = linker.downcallHandle(bridge.findOrThrow("create_resource"),
FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS));

MemoryLayout rvML = linker.canonicalLayouts().get("int");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,66 +9,39 @@

public class SearchResult {

private Map<Integer, Map<Integer, Float>> results; // Stores results for multiple queries
private Map<Integer, Integer> mapping;
SequenceLayout neighboursSL;
SequenceLayout distancesSL;
MemorySegment neighboursMS;
MemorySegment distancesMS;
int topK;

/**
* Constructor for search results loaded from memory.
*/
public SearchResult(SequenceLayout neighboursSL, SequenceLayout distancesSL, MemorySegment neighboursMS,
MemorySegment distancesMS, int topK, Map<Integer, Integer> mapping) {
super();
this.topK = topK;
this.neighboursSL = neighboursSL;
this.distancesSL = distancesSL;
this.neighboursMS = neighboursMS;
this.distancesMS = distancesMS;
this.mapping = mapping;
this.results = new HashMap<>();
this.load();
}

/**
* Constructor for brute force or precomputed results.
*/
public SearchResult(Map<Integer, Map<Integer, Float>> results) {
this.results = results;
}

/**
* Load results from memory segments for multiple queries.
*/
private void load() {
VarHandle neighboursVH = neighboursSL.varHandle(PathElement.sequenceElement());
VarHandle distancesVH = distancesSL.varHandle(PathElement.sequenceElement());

for (long queryIndex = 0; queryIndex < topK; queryIndex++) {
Map<Integer, Float> queryResults = new HashMap<>();
for (long i = 0; i < topK; i++) {
int id = (int) neighboursVH.get(neighboursMS, queryIndex, i);
float distance = (float) distancesVH.get(distancesMS, queryIndex, i);
queryResults.put(mapping != null ? mapping.get(id) : id, distance);
}
results.put((int) queryIndex, queryResults);
}
private Map<Integer, Float> results;
private Map<Integer, Integer> mapping;
SequenceLayout neighboursSL;
SequenceLayout distancesSL;
MemorySegment neighboursMS;
MemorySegment distancesMS;
int topK;

public SearchResult(SequenceLayout neighboursSL, SequenceLayout distancesSL, MemorySegment neighboursMS,
MemorySegment distancesMS, int topK, Map<Integer, Integer> mapping) {
super();
this.topK = topK;
this.neighboursSL = neighboursSL;
this.distancesSL = distancesSL;
this.neighboursMS = neighboursMS;
this.distancesMS = distancesMS;
this.mapping = mapping;
results = new HashMap<Integer, Float>();
this.load();
}

private void load() {
VarHandle neighboursVH = neighboursSL.varHandle(PathElement.sequenceElement());
VarHandle distancesVH = distancesSL.varHandle(PathElement.sequenceElement());

for (long i = 0; i < topK; i++) {
int id = (int) neighboursVH.get(neighboursMS, 0L, i);
results.put(mapping != null ? mapping.get(id) : id, (float) distancesVH.get(distancesMS, 0L, i));
}
}

/**
* Retrieve results for a specific query.
*/
public Map<Integer, Float> getResults(int queryIndex) {
return results.get(queryIndex);
}
public Map<Integer, Float> getResults() {
return results;
}

/**
* Retrieve results for all queries.
*/
public Map<Integer, Map<Integer, Float>> getAllResults() {
return results;
}
}

0 comments on commit b57de36

Please sign in to comment.