Skip to content

Commit

Permalink
Add access to dense_vector values (#71847)
Browse files Browse the repository at this point in the history
Allow direct access to a dense_vector' values in script
through the following functions:

- getVectorValue – returns a vector's value as an array of floats
- getMagnitude – returns a vector's magnitude

Closes #51964
Backport for #71313
  • Loading branch information
mayya-sharipova authored Apr 19, 2021
1 parent d190d58 commit d53e83c
Show file tree
Hide file tree
Showing 23 changed files with 436 additions and 139 deletions.
7 changes: 3 additions & 4 deletions docs/reference/mapping/types/dense-vector.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ A `dense_vector` field stores dense vectors of float values.
The maximum number of dimensions that can be in a vector should
not exceed 2048. A `dense_vector` field is a single-valued field.

These vectors can be used for <<vector-functions,document scoring>>.
For example, a document score can represent a distance between
a given query vector and the indexed document vector.
`dense_vector` fields do not support querying, sorting or aggregating. They can
only be accessed in scripts through the dedicated <<vector-functions,vector functions>>.

You index a dense vector as an array of floats.

Expand Down Expand Up @@ -47,4 +46,4 @@ PUT my-index-000001/_doc/2
--------------------------------------------------

<1> dimsthe number of dimensions in the vector, required parameter.
<1> dimsthe number of dimensions in the vector, required parameter.
57 changes: 57 additions & 0 deletions docs/reference/vectors/vector-functions.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ to limit the number of matched documents with a `query` parameter.

====== `dense_vector` functions

This is the list of available vector functions and vector access methods:

1. `cosineSimilarity` – calculates cosine similarity
2. `dotProduct` – calculates dot product
3. `l1norm` – calculates L^1^ distance
4. `l2norm` - calculates L^2^ distance
5. `doc[<field>].vectorValue` – returns a vector's value as an array of floats
6. `doc[<field>].magnitude` – returns a vector's magnitude

Let's create an index with a `dense_vector` mapping and index a couple
of documents into it.

Expand Down Expand Up @@ -198,6 +207,54 @@ You can check if a document has a value for the field `my_vector` by
--------------------------------------------------
// NOTCONSOLE

The recommended way to access dense vectors is through `cosineSimilarity`,
`dotProduct`, `l1norm` or `l2norm` functions. But for custom use cases,
you can access dense vectors's values directly through the following functions:

- `doc[<field>].vectorValue` – returns a vector's value as an array of floats

- `doc[<field>].magnitude` – returns a vector's magnitude as a float
(for vectors created prior to version 7.5 the magnitude is not stored.
So this function calculates it anew every time it is called).

For example, the script below implements a cosine similarity using these
two functions:

[source,console]
--------------------------------------------------
GET my-index-000001/_search
{
"query": {
"script_score": {
"query" : {
"bool" : {
"filter" : {
"term" : {
"status" : "published"
}
}
}
},
"script": {
"source": """
float[] v = doc['my_dense_vector'].vectorValue;
float vm = doc['my_dense_vector'].magnitude;
float dotProduct = 0;
for (int i = 0; i < v.length; i++) {
dotProduct += v[i] * params.queryVector[i];
}
return dotProduct / (vm * (float) params.queryVectorMag);
""",
"params": {
"queryVector": [4, 3.4, -0.2],
"queryVectorMag": 5.25357
}
}
}
}
}
--------------------------------------------------

====== `sparse_vector` functions

deprecated[7.6, The `sparse_vector` type is deprecated and will be removed in 8.0.]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.script.ExplainableScoreScript;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.Version;

import java.io.IOException;
import java.util.Objects;
Expand Down Expand Up @@ -42,15 +41,13 @@ public float score() {

private final int shardId;
private final String indexName;
private final Version indexVersion;

public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
super(CombineFunction.REPLACE);
this.sScript = sScript;
this.script = script;
this.indexName = indexName;
this.shardId = shardId;
this.indexVersion = indexVersion;
}

@Override
Expand All @@ -60,7 +57,6 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx
leafScript.setScorer(scorer);
leafScript._setIndexName(indexName);
leafScript._setShard(shardId);
leafScript._setIndexVersion(indexVersion);
return new LeafScoreFunction() {
@Override
public double score(int docId, float subQueryScore) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ private ScoreScript makeScoreScript(LeafReaderContext context) throws IOExceptio
final ScoreScript scoreScript = scriptBuilder.newInstance(context);
scoreScript._setIndexName(indexName);
scoreScript._setShard(shardId);
scoreScript._setIndexVersion(indexVersion);
return scoreScript;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ protected ScoreFunction doToFunction(SearchExecutionContext context) {
try {
ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
return new ScriptScoreFunction(script, searchScript,
context.index().getName(), context.getShardId(), context.indexVersionCreated());
return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
} catch (Exception e) {
throw new QueryShardException(context, "script_score: the script could not be loaded", e);
}
Expand Down
22 changes: 0 additions & 22 deletions server/src/main/java/org/elasticsearch/script/ScoreScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorable;
import org.elasticsearch.Version;
import org.elasticsearch.common.logging.DeprecationCategory;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.index.fielddata.ScriptDocValues;
Expand Down Expand Up @@ -85,7 +84,6 @@ public Explanation get(double score, Explanation subQueryExplanation) {
private int docId;
private int shardId = -1;
private String indexName = null;
private Version indexVersion = null;

public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
// null check needed b/c of expression engine subclass
Expand Down Expand Up @@ -185,19 +183,6 @@ public String _getIndex() {
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
* It is only used within predefined painless functions.
* @return index version or throws an exception if the index version is not set up for this script instance
*/
public Version _getIndexVersion() {
if (indexVersion != null) {
return indexVersion;
} else {
throw new IllegalArgumentException("index version can not be looked up!");
}
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
Expand All @@ -212,13 +197,6 @@ public void _setIndexName(String indexName) {
this.indexName = indexName;
}

/**
* Starting a name with underscore, so that the user cannot access this function directly through a script
*/
public void _setIndexVersion(Version indexVersion) {
this.indexVersion = indexVersion;
}


/** A factory to construct {@link ScoreScript} instances. */
public interface LeafFactory {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
---
"Access to values of dense_vector in script":
- skip:
version: " - 7.12.99"
reason: "Access to values of dense_vector in script was added in 7.13"
- do:
indices.create:
index: test-index
body:
mappings:
properties:
v:
type: dense_vector
dims: 3

- do:
bulk:
index: test-index
refresh: true
body:
- '{"index": {"_id": "1"}}'
- '{"v": [1, 1, 1]}'
- '{"index": {"_id": "2"}}'
- '{"v": [1, 1, 2]}'
- '{"index": {"_id": "3"}}'
- '{"v": [1, 1, 3]}'
- '{"index": {"_id": "missing_vector"}}'
- '{}'

# vector functions in loop – return the index of the closest parameter vector based on cosine similarity
- do:
search:
body:
query:
script_score:
query: { "exists": { "field": "v" } }
script:
source: |
float[] v = doc['v'].vectorValue;
float vm = doc['v'].magnitude;
int closestPv = 0;
float maxCosSim = -1;
for (int i = 0; i < params.pvs.length; i++) {
float dotProduct = 0;
for (int j = 0; j < v.length; j++) {
dotProduct += v[j] * params.pvs[i][j];
}
float cosSim = dotProduct / (vm * (float) params.pvs_magnts[i]);
if (maxCosSim < cosSim) {
maxCosSim = cosSim;
closestPv = i;
}
}
closestPv;
params:
pvs: [ [ 1, 1, 1 ], [ 1, 1, 2 ], [ 1, 1, 3 ] ]
pvs_magnts: [1.7320, 2.4495, 3.3166]

- match: { hits.hits.0._id: "3" }
- match: { hits.hits.0._score: 2 }
- match: { hits.hits.1._id: "2" }
- match: { hits.hits.1._score: 1 }
- match: { hits.hits.2._id: "1" }
- match: { hits.hits.2._score: 0 }
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ protected List<Parameter<?>> getParameters() {
public DenseVectorFieldMapper build(ContentPath contentPath) {
return new DenseVectorFieldMapper(
name,
new DenseVectorFieldType(buildFullName(contentPath), dims.getValue(), meta.getValue()),
new DenseVectorFieldType(buildFullName(contentPath), indexVersionCreated, dims.getValue(), meta.getValue()),
dims.getValue(),
indexVersionCreated,
multiFieldsBuilder.build(this, contentPath),
Expand All @@ -95,10 +95,12 @@ public DenseVectorFieldMapper build(ContentPath contentPath) {

public static final class DenseVectorFieldType extends MappedFieldType {
private final int dims;
private final Version indexVersionCreated;

public DenseVectorFieldType(String name, int dims, Map<String, String> meta) {
public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, Map<String, String> meta) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.dims = dims;
this.indexVersionCreated = indexVersionCreated;
}

int dims() {
Expand All @@ -125,7 +127,7 @@ protected Object parseSourceValue(Object value) {

@Override
public DocValueFormat docValueFormat(String format, ZoneId timeZone) {
throw new UnsupportedOperationException(
throw new IllegalArgumentException(
"Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations");
}

Expand All @@ -136,7 +138,7 @@ public boolean isAggregatable() {

@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
return new VectorIndexFieldData.Builder(name(), true, CoreValuesSourceType.KEYWORD);
return new VectorIndexFieldData.Builder(name(), true, CoreValuesSourceType.KEYWORD, indexVersionCreated, dims);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ protected List<Parameter<?>> getParameters() {
@Override
public SparseVectorFieldMapper build(ContentPath contentPath) {
return new SparseVectorFieldMapper(
name, new SparseVectorFieldType(buildFullName(contentPath), meta.getValue()),
name, new SparseVectorFieldType(buildFullName(contentPath), indexCreatedVersion, meta.getValue()),
multiFieldsBuilder.build(this, contentPath), copyTo.build(), indexCreatedVersion);
}
}
Expand All @@ -83,8 +83,10 @@ name, new SparseVectorFieldType(buildFullName(contentPath), meta.getValue()),

public static final class SparseVectorFieldType extends MappedFieldType {

public SparseVectorFieldType(String name, Map<String, String> meta) {
private final Version indexVersionCreated;
public SparseVectorFieldType(String name, Version indexVersionCreated, Map<String, String> meta) {
super(name, false, false, true, TextSearchInfo.NONE, meta);
this.indexVersionCreated = indexVersionCreated;
}

@Override
Expand All @@ -94,7 +96,7 @@ public String typeName() {

@Override
public DocValueFormat docValueFormat(String format, ZoneId timeZone) {
throw new UnsupportedOperationException(
throw new IllegalArgumentException(
"Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations");
}

Expand All @@ -118,7 +120,7 @@ public Query existsQuery(SearchExecutionContext context) {

@Override
public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
return new VectorIndexFieldData.Builder(name(), false, CoreValuesSourceType.KEYWORD);
return new VectorIndexFieldData.Builder(name(), false, CoreValuesSourceType.KEYWORD, indexVersionCreated, -1);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import java.nio.ByteBuffer;

// static utility functions for encoding and decoding dense_vector and sparse_vector fields
public final class VectorEncoderDecoder {
static final byte INT_BYTES = 4;
static final byte SHORT_BYTES = 2;
Expand Down Expand Up @@ -168,9 +167,51 @@ public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
* NOTE: this function can only be called on vectors from an index version greater than or
* equal to 7.5.0, since vectors created prior to that do not store the magnitude.
*/
public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) {
public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
assert indexVersion.onOrAfter(Version.V_7_5_0);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - INT_BYTES);
}

/**
* Calculates vector magnitude
*/
private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
final int length = denseVectorLength(indexVersion, vectorBR);
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
double magnitude = 0.0f;
for (int i = 0; i < length; i++) {
float value = byteBuffer.getFloat();
magnitude += value * value;
}
magnitude = Math.sqrt(magnitude);
return (float) magnitude;
}

public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
if (indexVersion.onOrAfter(Version.V_7_5_0)) {
return decodeMagnitude(indexVersion, vectorBR);
} else {
return calculateMagnitude(indexVersion, vectorBR);
}
}

/**
* Decodes a BytesRef into the provided array of floats
* @param vectorBR - dense vector encoded in BytesRef
* @param vector - array of floats where the decoded vector should be stored
*/
public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
if (vectorBR == null) {
throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
}
ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
for (int dim = 0; dim < vector.length; dim++) {
vector[dim] = byteBuffer.getFloat();
}
}

}
Loading

0 comments on commit d53e83c

Please sign in to comment.