Skip to content

Commit

Permalink
Refactor/rename QueryVector back to WeightedToken. No logic changes
Browse files: browse the repository at this point in the history
  • Loading branch information
kderusso committed Apr 2, 2024
1 parent 670172d commit a325de4
Show file tree
Hide file tree
Showing 17 changed files with 119 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static List<ChunkedInferenceServiceResults> of(List<String> inputs, Spars
public static ChunkedSparseEmbeddingResults of(String input, SparseEmbeddingResults.Embedding embedding) {
var weightedTokens = embedding.tokens()
.stream()
.map(weightedToken -> new TextExpansionResults.QueryVector(weightedToken.token(), weightedToken.weight()))
.map(weightedToken -> new TextExpansionResults.WeightedToken(weightedToken.token(), weightedToken.weight()))
.toList();

return new ChunkedSparseEmbeddingResults(List.of(new ChunkedTextExpansionResults.ChunkedResult(input, weightedTokens)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public List<? extends InferenceResults> transformToLegacyFormat() {
DEFAULT_RESULTS_FIELD,
embedding.tokens()
.stream()
.map(weightedToken -> new TextExpansionResults.QueryVector(weightedToken.token, weightedToken.weight))
.map(weightedToken -> new TextExpansionResults.WeightedToken(weightedToken.token, weightedToken.weight))
.toList(),
embedding.isTruncated
)
Expand All @@ -111,9 +111,9 @@ public Embedding(StreamInput in) throws IOException {
this(in.readCollectionAsList(SparseEmbeddingResults.WeightedToken::new), in.readBoolean());
}

public static Embedding create(List<TextExpansionResults.QueryVector> queryVectors, boolean isTruncated) {
public static Embedding create(List<TextExpansionResults.WeightedToken> weightedTokens, boolean isTruncated) {
return new Embedding(
queryVectors.stream().map(token -> new WeightedToken(token.token(), token.weight())).toList(),
weightedTokens.stream().map(token -> new WeightedToken(token.token(), token.weight())).toList(),
isTruncated
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,27 @@
public class ChunkedTextExpansionResults extends ChunkedNlpInferenceResults {
public static final String NAME = "chunked_text_expansion_result";

public record ChunkedResult(String matchedText, List<TextExpansionResults.QueryVector> queryVectors)
public record ChunkedResult(String matchedText, List<TextExpansionResults.WeightedToken> weightedTokens)
implements
Writeable,
ToXContentObject {

public ChunkedResult(StreamInput in) throws IOException {
this(in.readString(), in.readCollectionAsList(TextExpansionResults.QueryVector::new));
this(in.readString(), in.readCollectionAsList(TextExpansionResults.WeightedToken::new));
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(matchedText);
out.writeCollection(queryVectors);
out.writeCollection(weightedTokens);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(TEXT, matchedText);
builder.startObject(INFERENCE);
for (var weightedToken : queryVectors) {
for (var weightedToken : weightedTokens) {
weightedToken.toXContent(builder, params);
}
builder.endObject();
Expand All @@ -56,8 +56,8 @@ public Map<String, Object> asMap() {
map.put(TEXT, matchedText);
map.put(
INFERENCE,
queryVectors.stream()
.collect(Collectors.toMap(TextExpansionResults.QueryVector::token, TextExpansionResults.QueryVector::weight))
weightedTokens.stream()
.collect(Collectors.toMap(TextExpansionResults.WeightedToken::token, TextExpansionResults.WeightedToken::weight))
);
return map;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ public class TextExpansionResults extends NlpInferenceResults {

public static final String NAME = "text_expansion_result";

public record QueryVector(String token, float weight) implements Writeable, ToXContentFragment {
public record WeightedToken(String token, float weight) implements Writeable, ToXContentFragment {

public QueryVector(StreamInput in) throws IOException {
public WeightedToken(StreamInput in) throws IOException {
this(in.readString(), in.readFloat());
}

Expand All @@ -53,22 +53,22 @@ public String toString() {
}

private final String resultsField;
private final List<QueryVector> queryVectors;
private final List<WeightedToken> weightedTokens;

public TextExpansionResults(String resultField, List<QueryVector> queryVectors, boolean isTruncated) {
public TextExpansionResults(String resultField, List<WeightedToken> weightedTokens, boolean isTruncated) {
super(isTruncated);
this.resultsField = resultField;
this.queryVectors = queryVectors;
this.weightedTokens = weightedTokens;
}

public TextExpansionResults(StreamInput in) throws IOException {
super(in);
this.resultsField = in.readString();
this.queryVectors = in.readCollectionAsList(QueryVector::new);
this.weightedTokens = in.readCollectionAsList(WeightedToken::new);
}

public List<QueryVector> getVectorDimensions() {
return queryVectors;
public List<WeightedToken> getVectorDimensions() {
return weightedTokens;
}

@Override
Expand All @@ -89,7 +89,7 @@ public Object predictedValue() {
@Override
void doXContentBody(XContentBuilder builder, Params params) throws IOException {
builder.startObject(resultsField);
for (var vectorDimension : queryVectors) {
for (var vectorDimension : weightedTokens) {
vectorDimension.toXContent(builder, params);
}
builder.endObject();
Expand All @@ -101,29 +101,29 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
if (super.equals(o) == false) return false;
TextExpansionResults that = (TextExpansionResults) o;
return Objects.equals(resultsField, that.resultsField) && Objects.equals(queryVectors, that.queryVectors);
return Objects.equals(resultsField, that.resultsField) && Objects.equals(weightedTokens, that.weightedTokens);
}

@Override
public int hashCode() {
return Objects.hash(super.hashCode(), resultsField, queryVectors);
return Objects.hash(super.hashCode(), resultsField, weightedTokens);
}

@Override
void doWriteTo(StreamOutput out) throws IOException {
out.writeString(resultsField);
out.writeCollection(queryVectors);
out.writeCollection(weightedTokens);
}

@Override
void addMapFields(Map<String, Object> map) {
map.put(resultsField, queryVectors.stream().collect(Collectors.toMap(QueryVector::token, QueryVector::weight)));
map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
}

@Override
public Map<String, Object> asMap(String outputField) {
var map = super.asMap(outputField);
map.put(outputField, queryVectors.stream().collect(Collectors.toMap(QueryVector::token, QueryVector::weight)));
map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight)));
return map;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ public static ChunkedTextExpansionResults createRandomResults() {
int numChunks = randomIntBetween(1, 5);

for (int i = 0; i < numChunks; i++) {
var tokenWeights = new ArrayList<TextExpansionResults.QueryVector>();
var tokenWeights = new ArrayList<TextExpansionResults.WeightedToken>();
int numTokens = randomIntBetween(1, 8);
for (int j = 0; j < numTokens; j++) {
tokenWeights.add(new TextExpansionResults.QueryVector(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false)));
tokenWeights.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false)));
}
chunks.add(new ChunkedTextExpansionResults.ChunkedResult(randomAlphaOfLength(6), tokenWeights));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ public static TextExpansionResults createRandomResults() {

public static TextExpansionResults createRandomResults(int min, int max) {
int numTokens = randomIntBetween(min, max);
List<TextExpansionResults.QueryVector> tokenList = new ArrayList<>();
List<TextExpansionResults.WeightedToken> tokenList = new ArrayList<>();
for (int i = 0; i < numTokens; i++) {
tokenList.add(new TextExpansionResults.QueryVector(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false)));
tokenList.add(new TextExpansionResults.WeightedToken(Integer.toString(i), (float) randomDoubleBetween(0.0, 5.0, false)));
}
return new TextExpansionResults(randomAlphaOfLength(4), tokenList, randomBoolean());
}
Expand Down Expand Up @@ -55,7 +55,7 @@ void assertFieldValues(TextExpansionResults createdInstance, IngestDocument docu
var ingestedTokens = (Map<String, Object>) document.getFieldValue(parentField + resultsField, Map.class);
var tokenMap = createdInstance.getVectorDimensions()
.stream()
.collect(Collectors.toMap(TextExpansionResults.QueryVector::token, TextExpansionResults.QueryVector::weight));
.collect(Collectors.toMap(TextExpansionResults.WeightedToken::token, TextExpansionResults.WeightedToken::weight));
assertEquals(tokenMap.size(), ingestedTokens.size());

assertEquals(tokenMap, ingestedTokens);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,9 @@ private SparseEmbeddingResults makeResults(List<String> input) {
private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input) {
var chunks = new ArrayList<ChunkedTextExpansionResults.ChunkedResult>();
for (int i = 0; i < input.size(); i++) {
var tokens = new ArrayList<TextExpansionResults.QueryVector>();
var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
for (int j = 0; j < 5; j++) {
tokens.add(new TextExpansionResults.QueryVector(Integer.toString(j), (float) j));
tokens.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) j));
}
chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ public static ChunkedSparseEmbeddingResults createRandomResults() {
int numChunks = randomIntBetween(1, 5);

for (int i = 0; i < numChunks; i++) {
var tokenWeights = new ArrayList<TextExpansionResults.QueryVector>();
var tokenWeights = new ArrayList<TextExpansionResults.WeightedToken>();
int numTokens = randomIntBetween(1, 8);
for (int j = 0; j < numTokens; j++) {
tokenWeights.add(new TextExpansionResults.QueryVector(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false)));
tokenWeights.add(new TextExpansionResults.WeightedToken(Integer.toString(j), (float) randomDoubleBetween(0.0, 5.0, false)));
}
chunks.add(new ChunkedTextExpansionResults.ChunkedResult(randomAlphaOfLength(6), tokenWeights));
}
Expand All @@ -43,7 +43,7 @@ public static ChunkedSparseEmbeddingResults createRandomResults() {

public void testToXContent_CreatesTheRightJsonForASingleChunk() {
var entity = new ChunkedSparseEmbeddingResults(
List.of(new ChunkedTextExpansionResults.ChunkedResult("text", List.of(new TextExpansionResults.QueryVector("token", 0.1f))))
List.of(new ChunkedTextExpansionResults.ChunkedResult("text", List.of(new TextExpansionResults.WeightedToken("token", 0.1f))))
);

assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ public void testTransformToCoordinationFormat() {
results,
is(
List.of(
new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.QueryVector("token", 0.1F)), false),
new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.QueryVector("token2", 0.2F)), true)
new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token", 0.1F)), false),
new TextExpansionResults(DEFAULT_RESULTS_FIELD, List.of(new TextExpansionResults.WeightedToken("token2", 0.2F)), true)
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,21 @@ static InferenceResults processResult(
}
}

static List<TextExpansionResults.QueryVector> sparseVectorToTokenWeights(
static List<TextExpansionResults.WeightedToken> sparseVectorToTokenWeights(
double[] vector,
TokenizationResult tokenization,
Map<Integer, String> replacementVocab
) {
// Anything with a score > 0.0 is retained.
List<TextExpansionResults.QueryVector> queryVectors = new ArrayList<>();
List<TextExpansionResults.WeightedToken> weightedTokens = new ArrayList<>();
for (int i = 0; i < vector.length; i++) {
if (vector[i] > 0.0) {
queryVectors.add(new TextExpansionResults.QueryVector(tokenForId(i, tokenization, replacementVocab), (float) vector[i]));
weightedTokens.add(
new TextExpansionResults.WeightedToken(tokenForId(i, tokenization, replacementVocab), (float) vector[i])
);
}
}
return queryVectors;
return weightedTokens;
}

static String tokenForId(int id, TokenizationResult tokenization, Map<Integer, String> replacementVocab) {
Expand Down
Loading

0 comments on commit a325de4

Please sign in to comment.