diff --git a/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java b/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java index 6be5a17..8e96eed 100644 --- a/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java +++ b/src/main/java/org/opensearch/search/relevance/SearchRelevancePlugin.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Map; import java.util.function.Supplier; +import java.util.stream.Collectors; import org.opensearch.action.support.ActionFilter; import org.opensearch.client.Client; @@ -30,14 +31,15 @@ import org.opensearch.repositories.RepositoriesService; import org.opensearch.script.ScriptService; import org.opensearch.search.relevance.actionfilter.SearchActionFilter; -import org.opensearch.search.relevance.transformer.ResultTransformerType; +import org.opensearch.search.relevance.client.OpenSearchClient; +import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory; import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraClientSettings; import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraHttpClient; -import org.opensearch.search.relevance.client.OpenSearchClient; import org.opensearch.search.relevance.configuration.SearchConfigurationExtBuilder; import org.opensearch.search.relevance.transformer.kendraintelligentranking.KendraIntelligentRanker; import org.opensearch.search.relevance.transformer.ResultTransformer; import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfigurationFactory; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; @@ -47,9 +49,13 @@ public class SearchRelevancePlugin extends Plugin implements ActionPlugin, Searc private 
KendraHttpClient kendraClient; private KendraIntelligentRanker kendraIntelligentRanker; - private Map getAllResultTransformers() { + private Collection getAllResultTransformers() { // Initialize and add other transformers here - return Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, this.kendraIntelligentRanker); + return List.of(this.kendraIntelligentRanker); + } + + private Collection getResultTransformerConfigurationFactories() { + return List.of(KendraIntelligentRankingConfigurationFactory.INSTANCE); } @Override @@ -93,8 +99,12 @@ public Collection createComponents( @Override public List> getSearchExts() { + Map resultTransformerMap = getResultTransformerConfigurationFactories().stream() + .collect(Collectors.toMap(ResultTransformerConfigurationFactory::getName, i -> i)); return Collections.singletonList( - new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME, SearchConfigurationExtBuilder::new, SearchConfigurationExtBuilder::parse)); + new SearchExtSpec<>(SearchConfigurationExtBuilder.NAME, + input -> new SearchConfigurationExtBuilder(input, resultTransformerMap), + parser -> SearchConfigurationExtBuilder.parse(parser, resultTransformerMap))); } } diff --git a/src/main/java/org/opensearch/search/relevance/actionfilter/SearchActionFilter.java b/src/main/java/org/opensearch/search/relevance/actionfilter/SearchActionFilter.java index 1134b35..85bfb06 100644 --- a/src/main/java/org/opensearch/search/relevance/actionfilter/SearchActionFilter.java +++ b/src/main/java/org/opensearch/search/relevance/actionfilter/SearchActionFilter.java @@ -7,17 +7,6 @@ */ package org.opensearch.search.relevance.actionfilter; -import static org.opensearch.action.search.ShardSearchFailure.readShardSearchFailure; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - import 
org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -28,7 +17,6 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; -import org.opensearch.action.search.ShardSearchFailure; import org.opensearch.action.support.ActionFilter; import org.opensearch.action.support.ActionFilterChain; import org.opensearch.common.io.stream.BytesStreamOutput; @@ -43,27 +31,37 @@ import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.profile.SearchProfileShardResults; -import org.opensearch.search.relevance.configuration.ConfigurationUtils; import org.opensearch.search.relevance.client.OpenSearchClient; +import org.opensearch.search.relevance.configuration.ConfigurationUtils; import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration; import org.opensearch.search.relevance.transformer.ResultTransformer; -import org.opensearch.search.relevance.transformer.ResultTransformerType; -import org.opensearch.search.suggest.Suggest; import org.opensearch.tasks.Task; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + public class SearchActionFilter implements ActionFilter { private static final Logger logger = LogManager.getLogger(SearchActionFilter.class); private final int order; private final NamedWriteableRegistry namedWriteableRegistry; - private final Map supportedResultTransformers; + private final Map resultTransformerMap; private final OpenSearchClient openSearchClient; - public SearchActionFilter(Map supportedResultTransformers, OpenSearchClient openSearchClient) { + 
public SearchActionFilter(Collection supportedResultTransformers, + OpenSearchClient openSearchClient) { order = 10; // TODO: Finalize this value namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); - this.supportedResultTransformers = supportedResultTransformers; + resultTransformerMap = supportedResultTransformers.stream() + .collect(Collectors.toMap(t -> t.getConfigurationFactory().getName(), t -> t)); this.openSearchClient = openSearchClient; } @@ -109,12 +107,12 @@ public void app return; } - List resultTransformerConfigurations = getResultTransformerConfigurations(indices[0], - searchRequest); + List resultTransformerConfigurations = + getResultTransformerConfigurations(indices[0], searchRequest); LinkedHashMap orderedTransformersAndConfigs = new LinkedHashMap<>(); for (ResultTransformerConfiguration config : resultTransformerConfigurations) { - ResultTransformer resultTransformer = supportedResultTransformers.get(config.getType()); + ResultTransformer resultTransformer = resultTransformerMap.get(config.getTransformerName()); // TODO: Should transformers make a decision based on the original request or the request they receive in the chain if (resultTransformer.shouldTransform(searchRequest, config)) { searchRequest = resultTransformer.preprocessRequest(searchRequest, config); @@ -154,17 +152,15 @@ private List getResultTransformerConfigurations( } // Fetch all index settings for this plugin - String[] settingNames = supportedResultTransformers.values() + String[] settingNames = resultTransformerMap.values() .stream() - .map(t -> t.getTransformerSettings() + .flatMap(t -> t.getTransformerSettings() .stream() - .map(Setting::getKey) - .collect(Collectors.toList())) - .flatMap(Collection::stream) + .map(Setting::getKey)) .toArray(String[]::new); configs = ConfigurationUtils.getResultTransformersFromIndexConfiguration( - openSearchClient.getIndexSettings(indexName, settingNames)); + openSearchClient.getIndexSettings(indexName, 
settingNames), resultTransformerMap); return configs; } @@ -194,21 +190,26 @@ public void onResponse(final Response response) { final SearchResponse searchResponse = (SearchResponse) response; final long totalHits = searchResponse.getHits().getTotalHits().value; if (totalHits == 0) { - logger.info("TotalHits = 0. Returning search response without re-ranking."); + logger.info("TotalHits = 0. Returning search response without transforming."); listener.onResponse(response); return; } logger.debug("Starting re-ranking for search response: {}", searchResponse); try { - final BytesStreamOutput out = new BytesStreamOutput(); - searchResponse.writeTo(out); - - final StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), namedWriteableRegistry); + // Clone search hits (by serializing + deserializing) before transforming + final BytesStreamOutput out = new BytesStreamOutput(); + searchResponse.getHits().writeTo(out); + final StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), + namedWriteableRegistry); SearchHits hits = new SearchHits(in); + for (Map.Entry entry : orderedTransformersAndConfigs.entrySet()) { + long startTime = System.nanoTime(); hits = entry.getKey().transform(hits, searchRequest, entry.getValue()); + long timeTookMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime); + logger.info(entry.getValue().getTransformerName() + ": took " + timeTookMillis + " ms"); } List searchHitsList = Arrays.asList(hits.getHits()); @@ -230,52 +231,22 @@ public void onResponse(final Response response) { } } - // TODO: How to handle SearchHits.TotalHits when transformer modifies the hit count hits = new SearchHits( searchHitsList.toArray(new SearchHit[0]), hits.getTotalHits(), hits.getMaxScore()); - final InternalAggregations aggregations = - in.readBoolean() ? InternalAggregations.readFrom(in) : null; - final Suggest suggest = in.readBoolean() ? 
new Suggest(in) : null; - final boolean timedOut = in.readBoolean(); - final Boolean terminatedEarly = in.readOptionalBoolean(); - final SearchProfileShardResults profileResults = in.readOptionalWriteable( - SearchProfileShardResults::new); - final int numReducePhases = in.readVInt(); - final SearchResponseSections internalResponse = new InternalSearchResponse(hits, - aggregations, suggest, - profileResults, timedOut, terminatedEarly, numReducePhases); - - final int totalShards = in.readVInt(); - final int successfulShards = in.readVInt(); - final int shardSearchFailureSize = in.readVInt(); - final ShardSearchFailure[] shardFailures; - if (shardSearchFailureSize == 0) { - shardFailures = ShardSearchFailure.EMPTY_ARRAY; - } else { - shardFailures = new ShardSearchFailure[shardSearchFailureSize]; - for (int i = 0; i < shardFailures.length; i++) { - shardFailures[i] = readShardSearchFailure(in); - } - } - - final SearchResponse.Clusters clusters = new SearchResponse.Clusters(in.readVInt(), - in.readVInt(), in.readVInt()); - final String scrollId = in.readOptionalString(); - final int skippedShards = in.readVInt(); - - final long tookInMillis = (System.nanoTime() - startTime) / 1000000; - final SearchResponse newResponse = new SearchResponse(internalResponse, scrollId, - totalShards, successfulShards, - skippedShards, tookInMillis, shardFailures, clusters); + (InternalAggregations) searchResponse.getAggregations(), searchResponse.getSuggest(), + new SearchProfileShardResults(searchResponse.getProfileResults()), searchResponse.isTimedOut(), + searchResponse.isTerminatedEarly(), searchResponse.getNumReducePhases()); + + final long tookInMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime); + final SearchResponse newResponse = new SearchResponse(internalResponse, searchResponse.getScrollId(), + searchResponse.getTotalShards(), searchResponse.getSuccessfulShards(), + searchResponse.getSkippedShards(), tookInMillis, searchResponse.getShardFailures(), + 
searchResponse.getClusters()); listener.onResponse((Response) newResponse); - - // TODO: Change this to a metric - logger.info("Result transformer operations overhead time: {}ms", - tookInMillis - searchResponse.getTook().getMillis()); } catch (final Exception e) { logger.error("Result transformer operations failed.", e); throw new OpenSearchException("Result transformer operations failed.", e); diff --git a/src/main/java/org/opensearch/search/relevance/configuration/ConfigurationUtils.java b/src/main/java/org/opensearch/search/relevance/configuration/ConfigurationUtils.java index 328ab39..ed7b60d 100644 --- a/src/main/java/org/opensearch/search/relevance/configuration/ConfigurationUtils.java +++ b/src/main/java/org/opensearch/search/relevance/configuration/ConfigurationUtils.java @@ -7,18 +7,18 @@ */ package org.opensearch.search.relevance.configuration; -import static org.opensearch.search.relevance.configuration.Constants.RESULT_TRANSFORMER_SETTING_PREFIX; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.relevance.transformer.ResultTransformer; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.settings.Settings; -import org.opensearch.search.SearchExtBuilder; -import org.opensearch.search.relevance.transformer.ResultTransformerType; -import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration; + +import static org.opensearch.search.relevance.configuration.Constants.RESULT_TRANSFORMER_SETTING_PREFIX; public class ConfigurationUtils { @@ -27,16 +27,16 @@ public class ConfigurationUtils { * @param settings all index settings configured for this plugin * @return ordered and validated list of result 
transformers, empty list if not specified */ - public static List getResultTransformersFromIndexConfiguration( - Settings settings) { + public static List getResultTransformersFromIndexConfiguration(Settings settings, + Map resultTransformerMap) { List indexLevelConfigs = new ArrayList<>(); if (settings != null) { if (settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX) != null) { - for (Map.Entry resultTransformer : settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX).entrySet()) { - ResultTransformerType resultTransformerType = ResultTransformerType.fromString(resultTransformer.getKey()); - if (ResultTransformerType.KENDRA_INTELLIGENT_RANKING.equals(resultTransformerType)) { - indexLevelConfigs.add(new KendraIntelligentRankingConfiguration(resultTransformer.getValue())); + for (Map.Entry transformerSettings : settings.getGroups(RESULT_TRANSFORMER_SETTING_PREFIX).entrySet()) { + if (resultTransformerMap.containsKey(transformerSettings.getKey())) { + ResultTransformer transformer = resultTransformerMap.get(transformerSettings.getKey()); + indexLevelConfigs.add(transformer.getConfigurationFactory().configure(transformerSettings.getValue())); } } } @@ -86,7 +86,7 @@ public static List reorderAndValidateConfigs( for (int i = 0; i < configs.size(); ++i) { if (configs.get(i).getOrder() != (i + 1)) { throw new IllegalArgumentException("Expected order [" + (i + 1) + "] for transformer [" + - configs.get(i).getType() + "], but found [" + configs.get(i).getOrder() + "]"); + configs.get(i).getTransformerName() + "], but found [" + configs.get(i).getOrder() + "]"); } } diff --git a/src/main/java/org/opensearch/search/relevance/configuration/ResultTransformerConfiguration.java b/src/main/java/org/opensearch/search/relevance/configuration/ResultTransformerConfiguration.java index f63ec46..b52c2f1 100644 --- a/src/main/java/org/opensearch/search/relevance/configuration/ResultTransformerConfiguration.java +++ 
b/src/main/java/org/opensearch/search/relevance/configuration/ResultTransformerConfiguration.java @@ -7,10 +7,7 @@ */ package org.opensearch.search.relevance.configuration; -import org.opensearch.search.relevance.transformer.ResultTransformerType; - public abstract class ResultTransformerConfiguration extends TransformerConfiguration { - public abstract ResultTransformerType getType(); - + public abstract String getTransformerName(); } diff --git a/src/main/java/org/opensearch/search/relevance/configuration/ResultTransformerConfigurationFactory.java b/src/main/java/org/opensearch/search/relevance/configuration/ResultTransformerConfigurationFactory.java new file mode 100644 index 0000000..431efcc --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/configuration/ResultTransformerConfigurationFactory.java @@ -0,0 +1,42 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.configuration; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentParser; + +import java.io.IOException; + +public interface ResultTransformerConfigurationFactory { + String getName(); + + /** + * Build configuration based on index settings + * @param indexSettings a set of index settings under a group scoped based on this result transformer's name. + * @return a transformer configuration based on the passed settings. + */ + ResultTransformerConfiguration configure(Settings indexSettings); + + /** + * Build configuration from serialized XContent, e.g. as part of a serialized {@link SearchConfigurationExtBuilder}. + * @param parser an XContentParser pointing to a node serialized from a {@link ResultTransformerConfiguration} of + * this type. 
+ * @return a transformer configuration based on the parameters specified in the XContent. + */ + ResultTransformerConfiguration configure(XContentParser parser) throws IOException; + + /** + * Build configuration from a serialized stream. + * @param streamInput a {@link org.opensearch.common.io.stream.Writeable} serialized representation of transformer + * configuration. + * @return configuration the deserialized transformer configuration. + */ + ResultTransformerConfiguration configure(StreamInput streamInput) throws IOException; + +} diff --git a/src/main/java/org/opensearch/search/relevance/configuration/SearchConfigurationExtBuilder.java b/src/main/java/org/opensearch/search/relevance/configuration/SearchConfigurationExtBuilder.java index 170e934..a84f544 100644 --- a/src/main/java/org/opensearch/search/relevance/configuration/SearchConfigurationExtBuilder.java +++ b/src/main/java/org/opensearch/search/relevance/configuration/SearchConfigurationExtBuilder.java @@ -7,14 +7,6 @@ */ package org.opensearch.search.relevance.configuration; -import static org.opensearch.search.relevance.configuration.Constants.SEARCH_CONFIGURATION; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.common.ParseField; import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.StreamInput; @@ -22,118 +14,139 @@ import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.search.SearchExtBuilder; -import org.opensearch.search.relevance.transformer.ResultTransformerType; import org.opensearch.search.relevance.transformer.TransformerType; -import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration; + +import java.io.IOException; +import java.util.ArrayList; 
+import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.opensearch.search.relevance.configuration.Constants.SEARCH_CONFIGURATION; public class SearchConfigurationExtBuilder extends SearchExtBuilder { - public static final String NAME = SEARCH_CONFIGURATION; + public static final String NAME = SEARCH_CONFIGURATION; - private static final ParseField RESULT_TRANSFORMER = new ParseField(TransformerType.RESULT_TRANSFORMER.toString()); - private static final ParseField KENDRA_INTELLIGENT_RANKING = new ParseField(ResultTransformerType.KENDRA_INTELLIGENT_RANKING.toString()); + private static final ParseField RESULT_TRANSFORMER = new ParseField(TransformerType.RESULT_TRANSFORMER.toString()); - private List resultTransformerConfigurations = new ArrayList<>(); + private List resultTransformerConfigurations = new ArrayList<>(); - public SearchConfigurationExtBuilder() {} + public SearchConfigurationExtBuilder() { + } - public SearchConfigurationExtBuilder(StreamInput input) throws IOException { - ResultTransformerConfiguration cfg1 = input.readOptionalWriteable(KendraIntelligentRankingConfiguration::new); - resultTransformerConfigurations.add(cfg1); - } + public SearchConfigurationExtBuilder(StreamInput input, Map resultTransformerMap) throws IOException { + int numTransformers = input.readInt(); + for (int i = 0; i < numTransformers; i++) { + String transformerName = input.readString(); + ResultTransformerConfigurationFactory transformer = resultTransformerMap.get(transformerName); + if (transformer == null) { + throw new IllegalStateException("Unknown result transformer " + transformerName); + } + resultTransformerConfigurations.add(transformer.configure(input)); + } + } - @Override - public void writeTo(StreamOutput out) throws IOException { - for (ResultTransformerConfiguration config : resultTransformerConfigurations) { - out.writeOptionalWriteable(config); + @Override + public void writeTo(StreamOutput 
out) throws IOException { + out.writeInt(resultTransformerConfigurations.size()); + for (ResultTransformerConfiguration config : resultTransformerConfigurations) { + out.writeString(config.getTransformerName()); + config.writeTo(out); + } } - } - - @Override - public String getWriteableName() { - return NAME; - } - - public static SearchConfigurationExtBuilder parse(XContentParser parser) throws IOException { - SearchConfigurationExtBuilder extBuilder = new SearchConfigurationExtBuilder(); - XContentParser.Token token = parser.currentToken(); - String currentFieldName = null; - if (token != XContentParser.Token.START_OBJECT && (token = parser.nextToken()) != XContentParser.Token.START_OBJECT) { - throw new ParsingException( - parser.getTokenLocation(), - "Expected [" + XContentParser.Token.START_OBJECT + "] but found [" + token + "]", - parser.getTokenLocation() - ); + + @Override + public String getWriteableName() { + return NAME; } - while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { - if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = parser.currentName(); - } else if (token == XContentParser.Token.START_OBJECT) { - if (RESULT_TRANSFORMER.match(currentFieldName, parser.getDeprecationHandler())) { - while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + + public static SearchConfigurationExtBuilder parse(XContentParser parser, + Map resultTransformerMap) throws IOException { + SearchConfigurationExtBuilder extBuilder = new SearchConfigurationExtBuilder(); + XContentParser.Token token = parser.currentToken(); + String currentFieldName = null; + if (token != XContentParser.Token.START_OBJECT && (token = parser.nextToken()) != XContentParser.Token.START_OBJECT) { + throw new ParsingException( + parser.getTokenLocation(), + "Expected [" + XContentParser.Token.START_OBJECT + "] but found [" + token + "]", + parser.getTokenLocation() + ); + } + while ((token = parser.nextToken()) != 
XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { - currentFieldName = parser.currentName(); - } - if (KENDRA_INTELLIGENT_RANKING.match(currentFieldName, - parser.getDeprecationHandler())) { - extBuilder.addResultTransformer( - KendraIntelligentRankingConfiguration.parse(parser, null)); + currentFieldName = parser.currentName(); + } else if (token == XContentParser.Token.START_OBJECT) { + if (RESULT_TRANSFORMER.match(currentFieldName, parser.getDeprecationHandler())) { + currentFieldName = null; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (currentFieldName != null) { + if (resultTransformerMap.containsKey(currentFieldName)) { + ResultTransformerConfiguration configuration = + resultTransformerMap.get(currentFieldName).configure(parser); + extBuilder.addResultTransformer(configuration); + } else { + throw new IllegalArgumentException( + "Unrecognized Result Transformer type [" + currentFieldName + "]"); + } + } + } + } else { + throw new IllegalArgumentException("Unrecognized Transformer type [" + currentFieldName + "]"); + } } else { - throw new IllegalArgumentException( - "Unrecognized Result Transformer type [" + currentFieldName + "]"); + throw new ParsingException( + parser.getTokenLocation(), + "Unknown key for a " + token + " in [" + currentFieldName + "].", + parser.getTokenLocation() + ); } - } - } else { - throw new IllegalArgumentException("Unrecognized Transformer type [" + currentFieldName + "]"); } - } else { - throw new ParsingException( - parser.getTokenLocation(), - "Unknown key for a " + token + " in [" + currentFieldName + "].", - parser.getTokenLocation() - ); - } + return extBuilder; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(RESULT_TRANSFORMER.getPreferredName()); + for 
(ResultTransformerConfiguration config : resultTransformerConfigurations) { + builder.field(config.getTransformerName(), config); + } + return builder.endObject(); } - return extBuilder; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(NAME); - for (ResultTransformerConfiguration config : resultTransformerConfigurations) { - builder.field(config.getType().toString(), config); + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (!(obj instanceof SearchConfigurationExtBuilder)) { + return false; + } + SearchConfigurationExtBuilder o = (SearchConfigurationExtBuilder) obj; + HashSet myConfigurations = new HashSet<>(this.resultTransformerConfigurations); + HashSet otherConfigurations = new HashSet<>(o.resultTransformerConfigurations); + return (this.resultTransformerConfigurations.size() == o.resultTransformerConfigurations.size() && + myConfigurations.equals(otherConfigurations)); + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.resultTransformerConfigurations); + } + + public SearchConfigurationExtBuilder setResultTransformers(final List resultTransformerConfigurations) { + this.resultTransformerConfigurations = resultTransformerConfigurations; + return this; } - return builder.endObject(); - } - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; + public List getResultTransformers() { + return this.resultTransformerConfigurations; } - if (!(obj instanceof SearchConfigurationExtBuilder)) { - return false; + + public SearchConfigurationExtBuilder addResultTransformer(final ResultTransformerConfiguration resultTransformerConfiguration) { + this.resultTransformerConfigurations.add(resultTransformerConfiguration); + return this; } - SearchConfigurationExtBuilder o = (SearchConfigurationExtBuilder) obj; - return (this.resultTransformerConfigurations.size() == 
o.resultTransformerConfigurations.size() && - this.resultTransformerConfigurations.containsAll(o.resultTransformerConfigurations) && - o.resultTransformerConfigurations.containsAll(this.resultTransformerConfigurations)) ; - } - - @Override - public int hashCode() { - return Objects.hash(this.getClass(), this.resultTransformerConfigurations); - } - - public SearchConfigurationExtBuilder setResultTransformers(final List resultTransformerConfigurations) { - this.resultTransformerConfigurations = resultTransformerConfigurations; - return this; - } - - public List getResultTransformers() { - return this.resultTransformerConfigurations; - } - - public void addResultTransformer(final ResultTransformerConfiguration resultTransformerConfiguration) { - this.resultTransformerConfigurations.add(resultTransformerConfiguration); - } } diff --git a/src/main/java/org/opensearch/search/relevance/transformer/ResultTransformer.java b/src/main/java/org/opensearch/search/relevance/transformer/ResultTransformer.java index 6568828..72ac0c6 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/ResultTransformer.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/ResultTransformer.java @@ -7,20 +7,26 @@ */ package org.opensearch.search.relevance.transformer; -import java.util.List; import org.opensearch.action.search.SearchRequest; import org.opensearch.common.settings.Setting; import org.opensearch.search.SearchHits; import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration; +import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory; -public interface ResultTransformer { +import java.util.List; +public interface ResultTransformer { /** - * Get the list of settings required / supported by the transformer + * Get the list of settings supported by the transformer * @return list of transformer settings */ List> getTransformerSettings(); + /** + * @return a factory able to construct configurations for 
this transformer. + */ + ResultTransformerConfigurationFactory getConfigurationFactory(); + /** * Decide whether to apply the transformer on the input request * @param request input Search Request diff --git a/src/main/java/org/opensearch/search/relevance/transformer/ResultTransformerType.java b/src/main/java/org/opensearch/search/relevance/transformer/ResultTransformerType.java deleted file mode 100644 index 81880fc..0000000 --- a/src/main/java/org/opensearch/search/relevance/transformer/ResultTransformerType.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.search.relevance.transformer; - -public enum ResultTransformerType { - KENDRA_INTELLIGENT_RANKING("kendra_intelligent_ranking"); - - private final String type; - - ResultTransformerType(String type) { - this.type = type; - } - - @Override - public String toString() { - return type; - } - - public static ResultTransformerType fromString(String type) { - for (ResultTransformerType resultTransformerType : values()) { - if (resultTransformerType.type.equalsIgnoreCase(type)) { - return resultTransformerType; - } - } - throw new IllegalArgumentException("Unrecognized Result Transformer type [" + type + "]"); - } -} diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java index 458dc55..ef06d03 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRanker.java @@ -7,30 +7,20 @@ */ package 
org.opensearch.search.relevance.transformer.kendraintelligentranking; -import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants.BODY_FIELD; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.PriorityQueue; - import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - import org.opensearch.action.search.SearchRequest; import org.opensearch.common.settings.Setting; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.SearchService; import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration; +import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory; import org.opensearch.search.relevance.transformer.ResultTransformer; import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraHttpClient; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings; import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfigurationFactory; import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.KendraIntelligentRankingException; import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.PassageScore; import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.Document; @@ -42,17 +32,29 @@ import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.QueryParser; import 
org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.QueryParser.QueryParserResult; import org.opensearch.search.relevance.transformer.kendraintelligentranking.preprocess.TextTokenizer; -import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.PriorityQueue; + +import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants.BODY_FIELD; public class KendraIntelligentRanker implements ResultTransformer { - private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 35; - private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 100; - private static final int MAX_PASSAGE_COUNT = 10; - private static final int TITLE_TOKENS_TRIMMED = 15; - private static final int BODY_PASSAGE_TRIMMED = 200; - private static final double BM25_B_VALUE = 0.75; - private static final double BM25_K1_VALUE = 1.6; - private static final int TOP_K_PASSAGES = 3; + public static final String NAME = "kendra_intelligent_ranking"; + private static final int MAX_SENTENCE_LENGTH_IN_TOKENS = 35; + private static final int MIN_PASSAGE_LENGTH_IN_TOKENS = 100; + private static final int MAX_PASSAGE_COUNT = 10; + private static final int TITLE_TOKENS_TRIMMED = 15; + private static final int BODY_PASSAGE_TRIMMED = 200; + private static final double BM25_B_VALUE = 0.75; + private static final double BM25_K1_VALUE = 1.6; + private static final int TOP_K_PASSAGES = 3; private static final Logger logger = LogManager.getLogger(KendraIntelligentRanker.class); @@ -71,6 +73,13 @@ public List> getTransformerSettings() { return KendraIntelligentRankerSettings.getAllSettings(); } + + @Override + public ResultTransformerConfigurationFactory 
getConfigurationFactory() { + return KendraIntelligentRankingConfigurationFactory.INSTANCE; + } + + /** * Check if search request is eligible for rescore * @@ -90,6 +99,10 @@ public boolean shouldTransform(final SearchRequest request, final ResultTransfor request.source().from() >= kendraConfiguration.getProperties().getDocLimit()) { return false; } + if (!kendraClient.isValid()) { + logger.warn("Kendra ranking endpoint was not configured. Skipping reranking."); + return false; + } return true; } @@ -153,7 +166,7 @@ public SearchHits transform(final SearchHits hits, throw new KendraIntelligentRankingException(errorMessage); } List> passages = passageGenerator.generatePassages(docSourceMap.get(bodyFieldName).toString(), - MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, MAX_PASSAGE_COUNT); + MAX_SENTENCE_LENGTH_IN_TOKENS, MIN_PASSAGE_LENGTH_IN_TOKENS, MAX_PASSAGE_COUNT); List> topPassages = getTopPassages(queryParserResult.getQueryText(), passages); List tokenizedTitle = null; if (titleFieldName != null && docSourceMap.get(titleFieldName) != null) { @@ -172,10 +185,10 @@ public SearchHits transform(final SearchHits hits, } originalHitsAsDocuments.add(new Document( originalHits.get(j).getId() + "@" + (i + 1), - originalHits.get(j).getId(), - tokenizedTitle, - passageTokens, - originalHits.get(j).getScore())); + originalHits.get(j).getId(), + tokenizedTitle, + passageTokens, + originalHits.get(j).getScore())); } // Map search hits by their ID in order to map Kendra response documents back to hits later idToSearchHitMap.put(originalHits.get(j).getId(), originalHits.get(j)); @@ -194,16 +207,16 @@ public SearchHits transform(final SearchHits hits, rescoreResultItem.getDocumentId()); logger.error(errorMessage); throw new KendraIntelligentRankingException(errorMessage); - } - searchHit.score(rescoreResultItem.getScore()); - maxScore = Math.max(maxScore, rescoreResultItem.getScore()); - newSearchHits.add(searchHit); + } + 
searchHit.score(rescoreResultItem.getScore()); + maxScore = Math.max(maxScore, rescoreResultItem.getScore()); + newSearchHits.add(searchHit); } // Add remaining hits to response, which are already sorted by OpenSearch score for (int i = numberOfHitsToRerank; i < originalHits.size(); ++i) { - newSearchHits.add(originalHits.get(i)); + newSearchHits.add(originalHits.get(i)); } - return new SearchHits(newSearchHits.toArray(new SearchHit[newSearchHits.size()]), hits.getTotalHits(), maxScore); + return new SearchHits(newSearchHits.toArray(new SearchHit[0]), hits.getTotalHits(), maxScore); } catch (Exception ex) { logger.error("Failed to rescore. Returning original search results without rescore.", ex); return hits; @@ -213,7 +226,7 @@ public SearchHits transform(final SearchHits hits, private List> getTopPassages(final String queryText, final List> passages) { List query = textTokenizer.tokenize(queryText); BM25Scorer bm25Scorer = new BM25Scorer(BM25_B_VALUE, BM25_K1_VALUE, passages); - PriorityQueue pq = new PriorityQueue<>(Comparator.comparingDouble(x -> x.getScore())); + PriorityQueue pq = new PriorityQueue<>(Comparator.comparingDouble(PassageScore::getScore)); for (int i = 0; i < passages.size(); i++) { double score = bm25Scorer.score(query, passages.get(i)); diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/client/KendraHttpClient.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/client/KendraHttpClient.java index d30fb27..bb318bf 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/client/KendraHttpClient.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/client/KendraHttpClient.java @@ -34,6 +34,7 @@ import java.security.AccessController; import java.security.PrivilegedAction; +import org.apache.commons.lang3.StringUtils; import 
org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreRequest; import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreResult; @@ -50,45 +51,53 @@ public class KendraHttpClient { private final AWS4Signer aws4Signer; private final String serviceEndpoint; private final String executionPlanId; - private final ObjectMapper objectMapper; + private final ObjectMapper objectMapper = new ObjectMapper() + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); public KendraHttpClient(KendraClientSettings clientSettings) { - amazonHttpClient = AccessController.doPrivileged((PrivilegedAction) () -> new AmazonHttpClient(new ClientConfiguration())); - errorHandler = new SimpleAwsErrorHandler(); - responseHandler = new SimpleResponseHandler(); - aws4Signer = new AWS4Signer(); - aws4Signer.setServiceName(KENDRA_RANKING_SERVICE_NAME); - aws4Signer.setRegionName(clientSettings.getServiceRegion()); - objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); serviceEndpoint = clientSettings.getServiceEndpoint(); executionPlanId = clientSettings.getExecutionPlanId(); + if (isValid()) { + amazonHttpClient = AccessController.doPrivileged((PrivilegedAction) () -> new AmazonHttpClient(new ClientConfiguration())); + errorHandler = new SimpleAwsErrorHandler(); + responseHandler = new SimpleResponseHandler(); + aws4Signer = new AWS4Signer(); + aws4Signer.setServiceName(KENDRA_RANKING_SERVICE_NAME); + aws4Signer.setRegionName(clientSettings.getServiceRegion()); - final AWSCredentialsProvider credentialsProvider; - final AWSCredentials credentials = clientSettings.getCredentials(); - if (credentials == null) { - // Use environment variables, system properties or instance profile credentials. - credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); - } else { - // Use keystore credentials. 
- credentialsProvider = new AWSStaticCredentialsProvider(credentials); - } + final AWSCredentialsProvider credentialsProvider; + final AWSCredentials credentials = clientSettings.getCredentials(); + if (credentials == null) { + // Use environment variables, system properties or instance profile credentials. + credentialsProvider = DefaultAWSCredentialsProviderChain.getInstance(); + } else { + // Use keystore credentials. + credentialsProvider = new AWSStaticCredentialsProvider(credentials); + } - final String assumeRoleArn = clientSettings.getAssumeRoleArn(); - if (assumeRoleArn != null && !assumeRoleArn.isBlank()) { - // If AssumeRoleArn was provided in config, use auto-refreshed role credentials. - awsCredentialsProvider = AccessController.doPrivileged( - (PrivilegedAction) () -> { - AWSSecurityTokenService awsSecurityTokenService = AWSSecurityTokenServiceClientBuilder.standard() - .withCredentials(credentialsProvider) - .withRegion(clientSettings.getServiceRegion()) - .build(); + final String assumeRoleArn = clientSettings.getAssumeRoleArn(); + if (assumeRoleArn != null && !assumeRoleArn.isBlank()) { + // If AssumeRoleArn was provided in config, use auto-refreshed role credentials. 
+ awsCredentialsProvider = AccessController.doPrivileged( + (PrivilegedAction) () -> { + AWSSecurityTokenService awsSecurityTokenService = AWSSecurityTokenServiceClientBuilder.standard() + .withCredentials(credentialsProvider) + .withRegion(clientSettings.getServiceRegion()) + .build(); - return new STSAssumeRoleSessionCredentialsProvider.Builder(clientSettings.getAssumeRoleArn(), ASSUME_ROLE_SESSION_NAME) - .withStsClient(awsSecurityTokenService) - .build(); - }); + return new STSAssumeRoleSessionCredentialsProvider.Builder(clientSettings.getAssumeRoleArn(), ASSUME_ROLE_SESSION_NAME) + .withStsClient(awsSecurityTokenService) + .build(); + }); + } else { + awsCredentialsProvider = credentialsProvider; + } } else { - awsCredentialsProvider = credentialsProvider; + amazonHttpClient = null; + aws4Signer = null; + awsCredentialsProvider = null; + errorHandler = null; + responseHandler = null; } } @@ -119,4 +128,8 @@ public URI buildRescoreURI() { return URI.create(String.join("/", serviceEndpoint, KENDRA_RESCORE_EXECUTION_PLANS, executionPlanId, KENDRA_RESCORE_URI)); } + + public boolean isValid() { + return StringUtils.isNotEmpty(serviceEndpoint) && StringUtils.isNotEmpty(executionPlanId); + } } diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankerSettings.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankerSettings.java index 1648ad9..e5f4f23 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankerSettings.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankerSettings.java @@ -29,7 +29,7 @@ public class KendraIntelligentRankerSettings { */ static final class FieldValidator implements Setting.Validator> { - private String settingName; + private final String 
settingName; public FieldValidator(final String name) { this.settingName = name; @@ -48,7 +48,7 @@ public void validate(List value) { */ static final class DocLimitValidator implements Setting.Validator { - private String settingName; + private final String settingName; public DocLimitValidator(final String name) { this.settingName = name; @@ -115,7 +115,7 @@ public void validate(Integer value) { public static final Setting ASSUME_ROLE_ARN_SETTING = Setting.simpleString("kendra_intelligent_ranking.service.assume_role_arn", Setting.Property.NodeScope); - public static final List> getAllSettings() { + public static List> getAllSettings() { return Arrays.asList( KENDRA_ORDER_SETTING, KENDRA_BODY_FIELD_SETTING, diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfiguration.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfiguration.java index 6bafb3d..25a0c11 100644 --- a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfiguration.java +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfiguration.java @@ -7,17 +7,6 @@ */ package org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration; -import static org.opensearch.search.relevance.configuration.Constants.ORDER; -import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants.DOC_LIMIT_SETTING_NAME; -import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants.KENDRA_DEFAULT_DOC_LIMIT; -import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings.BODY_FIELD_VALIDATOR; -import static 
org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings.DOC_LIMIT_VALIDATOR; -import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings.TITLE_FIELD_VALIDATOR; - -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Objects; import org.opensearch.common.ParseField; import org.opensearch.common.ParsingException; import org.opensearch.common.io.stream.StreamInput; @@ -30,16 +19,27 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration; import org.opensearch.search.relevance.configuration.TransformerConfiguration; -import org.opensearch.search.relevance.transformer.ResultTransformerType; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.KendraIntelligentRanker; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import static org.opensearch.search.relevance.configuration.Constants.ORDER; +import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.Constants.KENDRA_DEFAULT_DOC_LIMIT; +import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings.BODY_FIELD_VALIDATOR; +import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings.DOC_LIMIT_VALIDATOR; +import static org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings.TITLE_FIELD_VALIDATOR; public class KendraIntelligentRankingConfiguration extends ResultTransformerConfiguration { private static final ObjectParser PARSER; static { - PARSER = new ObjectParser("kendra_intelligent_ranking_configuration", KendraIntelligentRankingConfiguration::new); 
+ PARSER = new ObjectParser<>("kendra_intelligent_ranking_configuration", KendraIntelligentRankingConfiguration::new); PARSER.declareInt(TransformerConfiguration::setOrder, TRANSFORMER_ORDER); PARSER.declareObject(KendraIntelligentRankingConfiguration::setProperties, - (p, c) -> KendraIntelligentRankingProperties.parse(p, c), + KendraIntelligentRankingProperties::parse, TRANSFORMER_PROPERTIES); } @@ -66,8 +66,8 @@ public KendraIntelligentRankingConfiguration(Settings settings) { } @Override - public ResultTransformerType getType() { - return ResultTransformerType.KENDRA_INTELLIGENT_RANKING; + public String getTransformerName() { + return KendraIntelligentRanker.NAME; } @Override @@ -76,7 +76,7 @@ public void writeTo(StreamOutput out) throws IOException { this.properties.writeTo(out); } - public static ResultTransformerConfiguration parse(XContentParser parser, Void context) throws IOException { + public static ResultTransformerConfiguration parse(XContentParser parser) throws IOException { try { KendraIntelligentRankingConfiguration configuration = PARSER.parse(parser, null); if (configuration != null && configuration.getOrder() <= 0) { diff --git a/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfigurationFactory.java b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfigurationFactory.java new file mode 100644 index 0000000..bb9d9b3 --- /dev/null +++ b/src/main/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfigurationFactory.java @@ -0,0 +1,45 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ +package org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration; + +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration; +import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.KendraIntelligentRanker; + +import java.io.IOException; + +public class KendraIntelligentRankingConfigurationFactory implements ResultTransformerConfigurationFactory { + private KendraIntelligentRankingConfigurationFactory() { + } + + public static final KendraIntelligentRankingConfigurationFactory INSTANCE = + new KendraIntelligentRankingConfigurationFactory(); + + @Override + public String getName() { + return KendraIntelligentRanker.NAME; + } + + @Override + public ResultTransformerConfiguration configure(Settings indexSettings) { + return new KendraIntelligentRankingConfiguration(indexSettings); + } + + @Override + public ResultTransformerConfiguration configure(XContentParser parser) throws IOException { + return KendraIntelligentRankingConfiguration.parse(parser); + } + + @Override + public ResultTransformerConfiguration configure(StreamInput streamInput) throws IOException { + return new KendraIntelligentRankingConfiguration(streamInput); + } +} diff --git a/src/test/java/org/opensearch/search/relevance/actionfilter/SearchActionFilterTests.java b/src/test/java/org/opensearch/search/relevance/actionfilter/SearchActionFilterTests.java index 7d748f5..e73d9ce 100644 --- a/src/test/java/org/opensearch/search/relevance/actionfilter/SearchActionFilterTests.java +++ b/src/test/java/org/opensearch/search/relevance/actionfilter/SearchActionFilterTests.java @@ -28,10 +28,12 @@ import org.opensearch.common.bytes.BytesReference; import 
org.opensearch.common.collect.ImmutableOpenMap; import org.opensearch.common.document.DocumentField; +import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -39,9 +41,9 @@ import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.relevance.client.OpenSearchClient; import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration; +import org.opensearch.search.relevance.configuration.ResultTransformerConfigurationFactory; import org.opensearch.search.relevance.configuration.SearchConfigurationExtBuilder; import org.opensearch.search.relevance.transformer.ResultTransformer; -import org.opensearch.search.relevance.transformer.ResultTransformerType; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; @@ -65,7 +67,7 @@ public class SearchActionFilterTests extends OpenSearchTestCase { public void testIgnoresDelete() { Client client = Mockito.mock(Client.class); OpenSearchClient openSearchClient = new OpenSearchClient(client); - SearchActionFilter searchActionFilter = new SearchActionFilter(Collections.emptyMap(), openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(Collections.emptyList(), openSearchClient); Task task = Mockito.mock(Task.class); DeleteRequest deleteRequest = new DeleteRequestBuilder(null, DeleteAction.INSTANCE).request(); @@ -82,7 +84,7 @@ public void testIgnoresDelete() { public void testIgnoresSearchRequestOnZeroIndices() { Client client = Mockito.mock(Client.class); OpenSearchClient openSearchClient = new OpenSearchClient(client); - SearchActionFilter 
searchActionFilter = new SearchActionFilter(Collections.emptyMap(), openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(Collections.emptyList(), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE).request(); @@ -99,7 +101,7 @@ public void testIgnoresSearchRequestOnZeroIndices() { public void testIgnoresSearchRequestOnMultipleIndices() { Client client = Mockito.mock(Client.class); OpenSearchClient openSearchClient = new OpenSearchClient(client); - SearchActionFilter searchActionFilter = new SearchActionFilter(Collections.emptyMap(), openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(Collections.emptyList(), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE) @@ -138,7 +140,7 @@ private static Client buildMockClient(String indexName, Settings... settings) { public void testOperatesOnSingleIndexWithNoTransformers() { Client client = buildMockClient("index"); OpenSearchClient openSearchClient = new OpenSearchClient(client); - SearchActionFilter searchActionFilter = new SearchActionFilter(Collections.emptyMap(), openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(Collections.emptyList(), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE) @@ -153,6 +155,9 @@ public void testOperatesOnSingleIndexWithNoTransformers() { private static class MockTransformer implements ResultTransformer { + + public static final String NAME = "mock_transformer"; + public MockTransformer() { requestTransformer = i -> {}; } @@ -167,13 +172,18 @@ public MockTransformer(Consumer requestTransformer) { private boolean transformWasCalled = false; private boolean preproccessRequestWasCalled = false; - @Override public List> 
getTransformerSettings() { getTransformerSettingsWasCalled = true; return Collections.emptyList(); } + @Override + public ResultTransformerConfigurationFactory getConfigurationFactory() { + return MOCK_CONFIGURATION_FACTORY; + } + + @Override public boolean shouldTransform(SearchRequest request, ResultTransformerConfiguration configuration) { shouldTransformWasCalled = true; @@ -202,9 +212,8 @@ public int getOrder() { } @Override - public ResultTransformerType getType() { - // For now, the only supported type - return ResultTransformerType.KENDRA_INTELLIGENT_RANKING; + public String getTransformerName() { + return MockTransformer.NAME; } @Override @@ -217,6 +226,28 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } }; + private static ResultTransformerConfigurationFactory MOCK_CONFIGURATION_FACTORY = new ResultTransformerConfigurationFactory() { + @Override + public String getName() { + return MockTransformer.NAME; + } + + @Override + public ResultTransformerConfiguration configure(Settings indexSettings) { + return MOCK_TRANSFORMER_CONFIGURATION; + } + + @Override + public ResultTransformerConfiguration configure(XContentParser parser) { + return MOCK_TRANSFORMER_CONFIGURATION; + } + + @Override + public ResultTransformerConfiguration configure(StreamInput streamInput) { + return MOCK_TRANSFORMER_CONFIGURATION; + } + }; + /** * Even if a transformer is wired into the SearchActionFilter, if it's not enabled by search request or * index setting, the transformer will not be called. 
@@ -227,10 +258,7 @@ public void testTransformerDoesNotRunWhenNotEnabled() { MockTransformer mockTransformer = new MockTransformer(); - Map transformerMap = - Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, mockTransformer); - - SearchActionFilter searchActionFilter = new SearchActionFilter(transformerMap, openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(List.of(mockTransformer), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE) @@ -257,10 +285,7 @@ public void testTransformEnabledInRequest() throws IOException { MockTransformer mockTransformer = new MockTransformer(); - Map transformerMap = - Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, mockTransformer); - - SearchActionFilter searchActionFilter = new SearchActionFilter(transformerMap, openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(List.of(mockTransformer), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE) @@ -332,7 +357,7 @@ private static SearchResponse buildMockSearchResponse(int numHits) throws IOExce */ public void testTransformEnabledByIndexSetting() throws IOException { String prefix = "index.plugin.searchrelevance.result_transformer." 
+ - ResultTransformerType.KENDRA_INTELLIGENT_RANKING; + MockTransformer.NAME; Settings enablePluginSettings = Settings.builder() .put(prefix + ".order", 1) .build(); @@ -341,10 +366,7 @@ public void testTransformEnabledByIndexSetting() throws IOException { MockTransformer mockTransformer = new MockTransformer(); - Map transformerMap = - Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, mockTransformer); - - SearchActionFilter searchActionFilter = new SearchActionFilter(transformerMap, openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(List.of(mockTransformer), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE) @@ -398,10 +420,7 @@ public void testOutputUsesOriginalSourceParameters() throws IOException { .fetchSource(true); }); - Map transformerMap = - Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, mockTransformer); - - SearchActionFilter searchActionFilter = new SearchActionFilter(transformerMap, openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(List.of(mockTransformer), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE) @@ -476,10 +495,7 @@ public void testReturnEmptyWhenOriginalFromExceedsHitCount() throws IOException .fetchSource(true); }); - Map transformerMap = - Map.of(ResultTransformerType.KENDRA_INTELLIGENT_RANKING, mockTransformer); - - SearchActionFilter searchActionFilter = new SearchActionFilter(transformerMap, openSearchClient); + SearchActionFilter searchActionFilter = new SearchActionFilter(List.of(mockTransformer), openSearchClient); Task task = Mockito.mock(Task.class); SearchRequest searchRequest = new SearchRequestBuilder(null, SearchAction.INSTANCE) diff --git a/src/test/java/org/opensearch/search/relevance/configuration/SearchConfigurationExtBuilderTests.java 
b/src/test/java/org/opensearch/search/relevance/configuration/SearchConfigurationExtBuilderTests.java new file mode 100644 index 0000000..d16cb11 --- /dev/null +++ b/src/test/java/org/opensearch/search/relevance/configuration/SearchConfigurationExtBuilderTests.java @@ -0,0 +1,128 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.search.relevance.configuration; + +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class SearchConfigurationExtBuilderTests extends OpenSearchTestCase { + private static final String TRANSFORMER_NAME = "mock_transformer"; + + private static class MockResultTransformerConfiguration extends ResultTransformerConfiguration { + private final String configuredValue; + + public MockResultTransformerConfiguration(String configuredValue) { + this.configuredValue = configuredValue; + } + + @Override + public String getTransformerName() { + return TRANSFORMER_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(configuredValue); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("configuredValue", configuredValue); + return builder.endObject(); + } + 
+ @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MockResultTransformerConfiguration that = (MockResultTransformerConfiguration) o; + return configuredValue.equals(that.configuredValue); + } + + @Override + public int hashCode() { + return Objects.hash(configuredValue); + } + + @Override + public String toString() { + return "MockResultTransformerConfiguration{" + + "configuredValue='" + configuredValue + '\'' + + '}'; + } + } + + private static ResultTransformerConfigurationFactory MOCK_RESULT_TRANSFORMER_CONFIGURATION_FACTORY = new ResultTransformerConfigurationFactory() { + @Override + public String getName() { + return TRANSFORMER_NAME; + } + + @Override + public ResultTransformerConfiguration configure(Settings indexSettings) { + return null; + } + + @Override + public ResultTransformerConfiguration configure(XContentParser parser) throws IOException { + XContentParser.Token token = parser.nextToken(); + assertSame(XContentParser.Token.FIELD_NAME, token); + assertEquals("configuredValue", parser.currentName()); + token = parser.nextToken(); + assertSame(XContentParser.Token.VALUE_STRING, token); + String configuredValue = parser.text(); + return new MockResultTransformerConfiguration(configuredValue); + } + + @Override + public ResultTransformerConfiguration configure(StreamInput streamInput) throws IOException { + return new MockResultTransformerConfiguration(streamInput.readString()); + } + }; + public static final Map RESULT_TRANSFORMER_CONFIGURATION_FACTORY_MAP = Map.of(TRANSFORMER_NAME, MOCK_RESULT_TRANSFORMER_CONFIGURATION_FACTORY); + + + public void testXContentRoundTrip() throws IOException { + MockResultTransformerConfiguration configuration = new MockResultTransformerConfiguration(randomUnicodeOfLength(10)); + SearchConfigurationExtBuilder searchConfigurationExtBuilder = new SearchConfigurationExtBuilder() + .addResultTransformer(configuration); + XContentType 
xContentType = randomFrom(XContentType.values()); + BytesReference serialized = XContentHelper.toXContent(searchConfigurationExtBuilder, xContentType, true); + + XContentParser parser = createParser(xContentType.xContent(), serialized); + + SearchConfigurationExtBuilder deserialized = + SearchConfigurationExtBuilder.parse(parser, RESULT_TRANSFORMER_CONFIGURATION_FACTORY_MAP); + assertEquals(searchConfigurationExtBuilder, deserialized); + } + + public void testStreamRoundTrip() throws IOException { + MockResultTransformerConfiguration configuration = new MockResultTransformerConfiguration(randomUnicodeOfLength(10)); + SearchConfigurationExtBuilder searchConfigurationExtBuilder = new SearchConfigurationExtBuilder() + .addResultTransformer(configuration); + BytesStreamOutput bytesStreamOutput = new BytesStreamOutput(); + searchConfigurationExtBuilder.writeTo(bytesStreamOutput); + + SearchConfigurationExtBuilder deserialized = new SearchConfigurationExtBuilder(bytesStreamOutput.bytes().streamInput(), + RESULT_TRANSFORMER_CONFIGURATION_FACTORY_MAP); + assertEquals(searchConfigurationExtBuilder, deserialized); + } +} \ No newline at end of file diff --git a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRankerTests.java b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRankerTests.java index d975f07..bafe6b7 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRankerTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/KendraIntelligentRankerTests.java @@ -12,6 +12,7 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.settings.Setting; +import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentBuilder; import 
org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.index.query.MatchAllQueryBuilder; @@ -20,7 +21,9 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.relevance.configuration.ResultTransformerConfiguration; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraClientSettings; import org.opensearch.search.relevance.transformer.kendraintelligentranking.client.KendraHttpClient; +import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankerSettings; import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration; import org.opensearch.search.relevance.transformer.kendraintelligentranking.configuration.KendraIntelligentRankingConfiguration.KendraIntelligentRankingProperties; import org.opensearch.search.relevance.transformer.kendraintelligentranking.model.dto.RescoreRequest; @@ -39,6 +42,7 @@ public class KendraIntelligentRankerTests extends OpenSearchTestCase { private static KendraHttpClient buildMockHttpClient(Function mockRescoreImpl) { KendraHttpClient kendraHttpClient = Mockito.mock(KendraHttpClient.class); + Mockito.when(kendraHttpClient.isValid()).thenReturn(true); Mockito.doAnswer(invocation -> { RescoreRequest rescoreRequest = invocation.getArgument(0); return mockRescoreImpl.apply(rescoreRequest); @@ -105,6 +109,38 @@ public void testShouldNotTransformWithSort() { assertFalse(shouldTransform); } + public void testShouldNotTransformWithInvalidClient() { /* covers clients built from empty settings, from only an execution plan id, and from only a service endpoint */ + Settings emptySettings = Settings.builder().build(); + KendraHttpClient invalidClient = new KendraHttpClient(KendraClientSettings.getClientSettings(emptySettings)); + testWithInvalidClient(invalidClient); /* reuse the client built above instead of constructing a second identical one */ + + Settings settingsWithExecutionPlan = Settings.builder() + 
.put(KendraIntelligentRankerSettings.EXECUTION_PLAN_ID_SETTING.getKey(), "foo-plan") + .build(); + testWithInvalidClient(new KendraHttpClient(KendraClientSettings.getClientSettings(settingsWithExecutionPlan))); + + Settings settingsWithEndpoint = Settings.builder() + .put(KendraIntelligentRankerSettings.SERVICE_ENDPOINT_SETTING.getKey(), + "https://kendra-ranking.us-west-2.api.aws") + .build(); + testWithInvalidClient(new KendraHttpClient(KendraClientSettings.getClientSettings(settingsWithEndpoint))); + } + + private void testWithInvalidClient(KendraHttpClient invalidClient) { /* asserts shouldTransform is false for a ranker built on the given client */ + KendraIntelligentRanker ranker = new KendraIntelligentRanker(invalidClient); + + // Otherwise valid search request: + SearchRequest originalRequest = new SearchRequest() + .source(new SearchSourceBuilder() + .query(new MatchAllQueryBuilder())); + KendraIntelligentRankingProperties properties = + new KendraIntelligentRankingProperties(List.of("body"), List.of("title"), 10); + + ResultTransformerConfiguration configuration = new KendraIntelligentRankingConfiguration(1, properties); + boolean shouldTransform = ranker.shouldTransform(originalRequest, configuration); + assertFalse(shouldTransform); + } + public void testShouldNotTransformIfFromExceedsDocLimit() { KendraIntelligentRanker ranker = new KendraIntelligentRanker(buildMockHttpClient()); SearchRequest originalRequest = new SearchRequest() @@ -185,12 +221,12 @@ public void testTransformHits() throws IOException { rescoreRequestRef.set(req); // Return the top N results in reverse order. 
List resultItems = req.getDocuments().stream() - .map(d -> { - RescoreResultItem item = new RescoreResultItem(); - item.setDocumentId(d.getGroupId()); - item.setScore(randomFloat()); - return item; - }).collect(Collectors.toList()); + .map(d -> { + RescoreResultItem item = new RescoreResultItem(); + item.setDocumentId(d.getGroupId()); + item.setScore(randomFloat()); + return item; + }).collect(Collectors.toList()); Collections.reverse(resultItems); RescoreResult result = new RescoreResult(); result.setResultItems(resultItems); diff --git a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfigurationTests.java b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfigurationTests.java index 3093cc5..0a35176 100644 --- a/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfigurationTests.java +++ b/src/test/java/org/opensearch/search/relevance/transformer/kendraintelligentranking/configuration/KendraIntelligentRankingConfigurationTests.java @@ -15,7 +15,7 @@ public class KendraIntelligentRankingConfigurationTests extends OpenSearchTestCase { public void testParseWithNullParserAndContext() { try { - KendraIntelligentRankingConfiguration.parse(null, null); + KendraIntelligentRankingConfiguration.parse(null); fail(); } catch (NullPointerException | IOException e) { }