Skip to content

Commit

Permalink
Get some more renames that weren't handled by IDE
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Mar 29, 2024
1 parent 9c46626 commit d3f3d52
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static SparseEmbeddingResults of(List<? extends InferenceResults> results

for (InferenceResults result : results) {
if (result instanceof TextExpansionResults expansionResults) {
embeddings.add(Embedding.create(expansionResults.getWeightedTokens(), expansionResults.isTruncated()));
embeddings.add(Embedding.create(expansionResults.getVectorDimensions(), expansionResults.isTruncated()));
} else {
throw new IllegalArgumentException("Received invalid legacy inference result");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public TextExpansionResults(StreamInput in) throws IOException {
this.vectorDimensions = in.readCollectionAsList(VectorDimension::new);
}

public List<VectorDimension> getWeightedTokens() {
public List<VectorDimension> getVectorDimensions() {
return vectorDimensions;
}

Expand All @@ -89,8 +89,8 @@ public Object predictedValue() {
@Override
void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.startObject(resultsField);
for (var weightedToken : vectorDimensions) {
weightedToken.toXContent(builder, params);
for (var vectorDimension : vectorDimensions) {
vectorDimension.toXContent(builder, params);
}
builder.endObject();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,18 @@ protected TextExpansionResults createTestInstance() {

@Override
protected TextExpansionResults mutateInstance(TextExpansionResults instance) {
return new TextExpansionResults(instance.getResultsField() + "-FOO", instance.getWeightedTokens(), instance.isTruncated() == false);
return new TextExpansionResults(
instance.getResultsField() + "-FOO",
instance.getVectorDimensions(),
instance.isTruncated() == false
);
}

@Override
@SuppressWarnings("unchecked")
void assertFieldValues(TextExpansionResults createdInstance, IngestDocument document, String parentField, String resultsField) {
var ingestedTokens = (Map<String, Object>) document.getFieldValue(parentField + resultsField, Map.class);
var tokenMap = createdInstance.getWeightedTokens()
var tokenMap = createdInstance.getVectorDimensions()
.stream()
.collect(Collectors.toMap(TextExpansionResults.VectorDimension::token, TextExpansionResults.VectorDimension::weight));
assertEquals(tokenMap.size(), ingestedTokens.size());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This test tests cases covered by ML's text_expansion.yml
# This test tests cases covered by ML's sparse_vector.yml
---
setup:
- skip:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# This test tests cases covered by ML's text_expansion.yml
# This test tests cases covered by ML's sparse_vector.yml
---
setup:
- skip:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class SparseVectorQueryBuilder extends AbstractQueryBuilder<SparseVectorQ
private final String modelText;
private final String modelId;
private final List<VectorDimension> vectorDimensions;
private SetOnce<TextExpansionResults> weightedTokensSupplier;
private SetOnce<TextExpansionResults> vectorDimensionsSupplier;
private final TokenPruningConfig tokenPruningConfig;

public SparseVectorQueryBuilder(
Expand Down Expand Up @@ -101,15 +101,15 @@ public SparseVectorQueryBuilder(StreamInput in) throws IOException {
this.vectorDimensions = in.readOptionalCollectionAsList(VectorDimension::new);
}

private SparseVectorQueryBuilder(SparseVectorQueryBuilder other, SetOnce<TextExpansionResults> weightedTokensSupplier) {
private SparseVectorQueryBuilder(SparseVectorQueryBuilder other, SetOnce<TextExpansionResults> vectorDimensionsSupplier) {
this.fieldName = other.fieldName;
this.modelText = other.modelText;
this.modelId = other.modelId;
this.vectorDimensions = other.vectorDimensions;
this.tokenPruningConfig = other.tokenPruningConfig;
this.boost = other.boost;
this.queryName = other.queryName;
this.weightedTokensSupplier = weightedTokensSupplier;
this.vectorDimensionsSupplier = vectorDimensionsSupplier;
}

@Override
Expand All @@ -124,7 +124,7 @@ public TransportVersion getMinimalSupportedVersion() {

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
if (weightedTokensSupplier != null) {
if (vectorDimensionsSupplier != null) {
throw new IllegalStateException("token supplier must be null, can't serialize suppliers, missing a rewriteAndFetch?");
}
out.writeString(fieldName);
Expand Down Expand Up @@ -159,11 +159,11 @@ protected void doXContent(XContentBuilder builder, Params params) throws IOExcep

@Override
protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
if (weightedTokensSupplier != null) {
if (weightedTokensSupplier.get() == null) {
if (vectorDimensionsSupplier != null) {
if (vectorDimensionsSupplier.get() == null) {
return this;
}
return textExpansionResultsToQuery(fieldName, weightedTokensSupplier.get());
return textExpansionResultsToQuery(fieldName, vectorDimensionsSupplier.get());
}

if (modelId != null) {
Expand Down Expand Up @@ -246,7 +246,7 @@ private QueryBuilder vectorDimensionsToQuery(String fieldName, List<VectorDimens
}

private QueryBuilder textExpansionResultsToQuery(String fieldName, TextExpansionResults textExpansionResults) {
return vectorDimensionsToQuery(fieldName, textExpansionResults.getWeightedTokens());
return vectorDimensionsToQuery(fieldName, textExpansionResults.getVectorDimensions());
}

@Override
Expand All @@ -261,12 +261,12 @@ protected boolean doEquals(SparseVectorQueryBuilder other) {
&& Objects.equals(modelId, other.modelId)
&& Objects.equals(tokenPruningConfig, other.tokenPruningConfig)
&& Objects.equals(vectorDimensions, other.vectorDimensions)
&& Objects.equals(weightedTokensSupplier, other.weightedTokensSupplier);
&& Objects.equals(vectorDimensionsSupplier, other.vectorDimensionsSupplier);
}

@Override
protected int doHashCode() {
return Objects.hash(fieldName, modelText, modelId, tokenPruningConfig, vectorDimensions, weightedTokensSupplier);
return Objects.hash(fieldName, modelText, modelId, tokenPruningConfig, vectorDimensions, vectorDimensionsSupplier);
}

public static SparseVectorQueryBuilder fromXContent(XContentParser parser) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResult
if (tokenPruningConfig != null) {
VectorDimensionsQueryBuilder vectorDimensionsQueryBuilder = new VectorDimensionsQueryBuilder(
fieldName,
textExpansionResults.getWeightedTokens(),
textExpansionResults.getVectorDimensions(),
tokenPruningConfig
);
vectorDimensionsQueryBuilder.queryName(queryName);
Expand All @@ -236,7 +236,7 @@ private QueryBuilder weightedTokensToQuery(String fieldName, TextExpansionResult
// if no token pruning configuration is specified we fall back to a boolean query.
// TODO this should be updated to always use a VectorDimensionsQueryBuilder once it's in all supported versions.
var boolQuery = QueryBuilders.boolQuery();
for (var weightedToken : textExpansionResults.getWeightedTokens()) {
for (var weightedToken : textExpansionResults.getVectorDimensions()) {
boolQuery.should(QueryBuilders.termQuery(fieldName, weightedToken.token()).boost(weightedToken.weight()));
}
boolQuery.minimumShouldMatch(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import static org.elasticsearch.xpack.ml.queries.TextExpansionQueryBuilder.PRUNING_CONFIG;

public class VectorDimensionsQueryBuilder extends AbstractQueryBuilder<VectorDimensionsQueryBuilder> {
// TODO Decide if we want to allow this to remain a query, but deprecated, or remove entirely since it's preview.
public static final String NAME = "weighted_tokens";

public static final ParseField TOKENS_FIELD = new ParseField("tokens");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -566,20 +566,20 @@ public void testMutateDocumentWithInputFieldsNested() {
assertEquals(modelId, document.getFieldValue("ml.results.model_id", String.class));

var bodyTokens = document.getFieldValue("ml.results.body_tokens", HashMap.class);
assertEquals(teResult1.getWeightedTokens().size(), bodyTokens.entrySet().size());
if (teResult1.getWeightedTokens().isEmpty() == false) {
assertEquals(teResult1.getVectorDimensions().size(), bodyTokens.entrySet().size());
if (teResult1.getVectorDimensions().isEmpty() == false) {
assertEquals(
(float) bodyTokens.get(teResult1.getWeightedTokens().get(0).token()),
teResult1.getWeightedTokens().get(0).weight(),
(float) bodyTokens.get(teResult1.getVectorDimensions().get(0).token()),
teResult1.getVectorDimensions().get(0).weight(),
0.001
);
}
var contentTokens = document.getFieldValue("ml.results.content_tokens", HashMap.class);
assertEquals(teResult2.getWeightedTokens().size(), contentTokens.entrySet().size());
if (teResult2.getWeightedTokens().isEmpty() == false) {
assertEquals(teResult2.getVectorDimensions().size(), contentTokens.entrySet().size());
if (teResult2.getVectorDimensions().isEmpty() == false) {
assertEquals(
(float) contentTokens.get(teResult2.getWeightedTokens().get(0).token()),
teResult2.getWeightedTokens().get(0).weight(),
(float) contentTokens.get(teResult2.getVectorDimensions().get(0).token()),
teResult2.getVectorDimensions().get(0).weight(),
0.001
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public void testProcessResult() {
var results = (TextExpansionResults) inferenceResult;
assertEquals(results.getResultsField(), "foo");

var weightedTokens = results.getWeightedTokens();
var weightedTokens = results.getVectorDimensions();
assertThat(weightedTokens, hasSize(3));
assertEquals(new TextExpansionResults.VectorDimension("e", 4.0f), weightedTokens.get(0));
assertEquals(new TextExpansionResults.VectorDimension("d", 3.0f), weightedTokens.get(1));
Expand All @@ -84,7 +84,7 @@ public void testSanitiseVocab() {
var results = (TextExpansionResults) inferenceResult;
assertEquals(results.getResultsField(), "foo");

var weightedTokens = results.getWeightedTokens();
var weightedTokens = results.getVectorDimensions();
assertThat(weightedTokens, hasSize(6));
assertEquals(new TextExpansionResults.VectorDimension("fff", 6.0f), weightedTokens.get(0));
assertEquals(new TextExpansionResults.VectorDimension("XXX", 5.0f), weightedTokens.get(1));
Expand Down Expand Up @@ -112,7 +112,7 @@ public void testSanitizeOutputTokens() {
TokenizationResult tokenizationResult = new BertTokenizationResult(vocab, List.of(), 0);

TextExpansionResults results = (TextExpansionResults) resultProcessor.processResult(tokenizationResult, pytorchResult, false);
var weightedTokens = results.getWeightedTokens();
var weightedTokens = results.getVectorDimensions();
assertThat(weightedTokens, hasSize(5));
assertEquals(new TextExpansionResults.VectorDimension("##__", 5.0f), weightedTokens.get(0));
assertEquals(new TextExpansionResults.VectorDimension("__", 4.0f), weightedTokens.get(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@

import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.MODEL_ID;
import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.MODEL_TEXT;
import static org.elasticsearch.xpack.ml.queries.SparseVectorQueryBuilder.VECTOR_DIMENSIONS;
import static org.elasticsearch.xpack.ml.queries.VectorDimensionsQueryBuilder.TOKENS_FIELD;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.Matchers.either;
Expand All @@ -52,10 +51,10 @@
public class SparseVectorQueryBuilderTests extends AbstractQueryTestCase<SparseVectorQueryBuilder> {

private static final String RANK_FEATURES_FIELD = "rank";
private static final List<TextExpansionResults.VectorDimension> WEIGHTED_TOKENS = List.of(
private static final List<TextExpansionResults.VectorDimension> VECTOR_DIMENSIONS = List.of(
new TextExpansionResults.VectorDimension("foo", .42f)
);
private static final int NUM_TOKENS = WEIGHTED_TOKENS.size();
private static final int NUM_TOKENS = VECTOR_DIMENSIONS.size();

@Override
protected SparseVectorQueryBuilder doCreateTestQueryBuilder() {
Expand All @@ -64,7 +63,7 @@ protected SparseVectorQueryBuilder doCreateTestQueryBuilder() {
: null;
String modelText = randomAlphaOfLength(4);
String modelId = randomBoolean() ? randomAlphaOfLength(4) : null;
List<TextExpansionResults.VectorDimension> vectorDimensions = modelId == null ? WEIGHTED_TOKENS : null;
List<TextExpansionResults.VectorDimension> vectorDimensions = modelId == null ? VECTOR_DIMENSIONS : null;

var builder = new SparseVectorQueryBuilder(RANK_FEATURES_FIELD, modelText, modelId, vectorDimensions, tokenPruningConfig);
if (randomBoolean()) {
Expand Down Expand Up @@ -204,7 +203,11 @@ public void testIllegalValues() {
() -> new SparseVectorQueryBuilder("field name", "model text", null, null)
);
assertEquals(
"[sparse_vector] requires one of [" + MODEL_ID.getPreferredName() + "], or [" + VECTOR_DIMENSIONS.getPreferredName() + "]",
"[sparse_vector] requires one of ["
+ MODEL_ID.getPreferredName()
+ "], or ["
+ SparseVectorQueryBuilder.VECTOR_DIMENSIONS.getPreferredName()
+ "]",
e.getMessage()
);
}
Expand All @@ -224,7 +227,7 @@ public void testToXContentWithModelId() throws IOException {
}

public void testToXContentWithVectorDimensions() throws IOException {
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, WEIGHTED_TOKENS);
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, VECTOR_DIMENSIONS);
checkGeneratedJson("""
{
"sparse_vector": {
Expand Down Expand Up @@ -256,7 +259,7 @@ public void testToXContentWithThresholds() throws IOException {
}

public void testToXContentWithThresholdsAndVectorDimensions() throws IOException {
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, WEIGHTED_TOKENS, new TokenPruningConfig(4, 0.3f, false));
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, VECTOR_DIMENSIONS, new TokenPruningConfig(4, 0.3f, false));
checkGeneratedJson("""
{
"sparse_vector": {
Expand Down Expand Up @@ -293,7 +296,7 @@ public void testToXContentWithThresholdsAndOnlyScorePrunedTokens() throws IOExce
}

public void testToXContentWithThresholdsAndOnlyScorePrunedTokensAndVectorDimensions() throws IOException {
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, WEIGHTED_TOKENS, new TokenPruningConfig(4, 0.3f, true));
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, VECTOR_DIMENSIONS, new TokenPruningConfig(4, 0.3f, true));
checkGeneratedJson("""
{
"sparse_vector": {
Expand All @@ -314,7 +317,7 @@ public void testToXContentWithThresholdsAndOnlyScorePrunedTokensAndVectorDimensi

@Override
public void testValidOutput() {
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, WEIGHTED_TOKENS, new TokenPruningConfig(4, 0.3f, true));
QueryBuilder query = new SparseVectorQueryBuilder("foo", "bar", null, VECTOR_DIMENSIONS, new TokenPruningConfig(4, 0.3f, true));
assertEquals("""
{
"sparse_vector" : {
Expand All @@ -335,7 +338,10 @@ public void testValidOutput() {

@Override
protected String[] shuffleProtectedFields() {
return new String[] { TOKENS_FIELD.getPreferredName(), MODEL_ID.getPreferredName(), VECTOR_DIMENSIONS.getPreferredName() };
return new String[] {
TOKENS_FIELD.getPreferredName(),
MODEL_ID.getPreferredName(),
SparseVectorQueryBuilder.VECTOR_DIMENSIONS.getPreferredName() };
}

public void testThatTokensAreCorrectlyPruned() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
public class VectorDimensionsQueryBuilderTests extends AbstractQueryTestCase<VectorDimensionsQueryBuilder> {

private static final String RANK_FEATURES_FIELD = "rank";
private static final List<VectorDimension> WEIGHTED_TOKENS = List.of(new VectorDimension("foo", .42f));
private static final int NUM_TOKENS = WEIGHTED_TOKENS.size();
private static final List<VectorDimension> VECTOR_DIMENSIONS = List.of(new VectorDimension("foo", .42f));
private static final int NUM_TOKENS = VECTOR_DIMENSIONS.size();

@Override
protected VectorDimensionsQueryBuilder doCreateTestQueryBuilder() {
Expand All @@ -64,7 +64,7 @@ private VectorDimensionsQueryBuilder createTestQueryBuilder(boolean onlyScorePru
? new TokenPruningConfig(randomIntBetween(1, 100), randomFloat(), onlyScorePrunedTokens)
: null;

var builder = new VectorDimensionsQueryBuilder(RANK_FEATURES_FIELD, WEIGHTED_TOKENS, tokenPruningConfig);
var builder = new VectorDimensionsQueryBuilder(RANK_FEATURES_FIELD, VECTOR_DIMENSIONS, tokenPruningConfig);
if (randomBoolean()) {
builder.boost((float) randomDoubleBetween(0.1, 10.0, true));
}
Expand Down Expand Up @@ -95,7 +95,7 @@ protected Object simulateMethod(Method method, Object[] args) {
// asserts that 2 rewritten queries are the same
var response = InferModelAction.Response.builder()
.setId(request.getId())
.addInferenceResults(List.of(new TextExpansionResults("foo", WEIGHTED_TOKENS.stream().toList(), randomBoolean())))
.addInferenceResults(List.of(new TextExpansionResults("foo", VECTOR_DIMENSIONS.stream().toList(), randomBoolean())))
.build();
@SuppressWarnings("unchecked") // We matched the method above.
ActionListener<InferModelAction.Response> listener = (ActionListener<InferModelAction.Response>) args[2];
Expand Down Expand Up @@ -382,7 +382,7 @@ public void testIllegalValues() {
}

public void testToXContent() throws Exception {
QueryBuilder query = new VectorDimensionsQueryBuilder("foo", WEIGHTED_TOKENS, null);
QueryBuilder query = new VectorDimensionsQueryBuilder("foo", VECTOR_DIMENSIONS, null);
checkGeneratedJson("""
{
"weighted_tokens": {
Expand All @@ -396,7 +396,7 @@ public void testToXContent() throws Exception {
}

public void testToXContentWithThresholds() throws Exception {
QueryBuilder query = new VectorDimensionsQueryBuilder("foo", WEIGHTED_TOKENS, new TokenPruningConfig(4, 0.4f, false));
QueryBuilder query = new VectorDimensionsQueryBuilder("foo", VECTOR_DIMENSIONS, new TokenPruningConfig(4, 0.4f, false));
checkGeneratedJson("""
{
"weighted_tokens": {
Expand All @@ -414,7 +414,7 @@ public void testToXContentWithThresholds() throws Exception {
}

public void testToXContentWithThresholdsAndOnlyScorePrunedTokens() throws Exception {
QueryBuilder query = new VectorDimensionsQueryBuilder("foo", WEIGHTED_TOKENS, new TokenPruningConfig(4, 0.4f, true));
QueryBuilder query = new VectorDimensionsQueryBuilder("foo", VECTOR_DIMENSIONS, new TokenPruningConfig(4, 0.4f, true));
checkGeneratedJson("""
{
"weighted_tokens": {
Expand Down

0 comments on commit d3f3d52

Please sign in to comment.