Skip to content

Commit

Permalink
Fix header
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Oct 15, 2024
1 parent f1a5762 commit c1ea67a
Show file tree
Hide file tree
Showing 2 changed files with 240 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.application.rules.retriever;
Expand All @@ -15,6 +13,8 @@
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
Expand All @@ -23,6 +23,7 @@
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -124,6 +125,7 @@ public String getName() {

@Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
Logger logger = LogManager.getLogger(QueryRuleRetrieverBuilder.class);
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
Expand All @@ -149,6 +151,9 @@ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit,
sourceBuilder.query(newQuery);
}

sourceBuilder.sort(new ScoreSortBuilder());

logger.info("sourceBuilder: " + sourceBuilder);
return sourceBuilder;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.xpack.application.rules.retriever;

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/


import org.apache.lucene.search.ScoreDoc;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.fetch.StoredFieldsContext;
import org.elasticsearch.search.rank.RankDoc;
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverBuilder;
import org.elasticsearch.search.retriever.RetrieverParserContext;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.sort.ScoreSortBuilder;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.application.EnterpriseSearch;
import org.elasticsearch.xpack.application.rules.RuleQueryBuilder;
import org.elasticsearch.xpack.core.XPackPlugin;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

/**
* A query rule retriever applies query rules defined in one or more rulesets to the underlying retriever.
*/
public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<QueryRuleRetrieverBuilder> {

public static final String NAME = "rule";
public static final NodeFeature QUERY_RULE_RETRIEVERS_SUPPORTED = new NodeFeature("query_rule_retriever_supported");

public static final ParseField RULESET_IDS_FIELD = new ParseField("ruleset_ids");
public static final ParseField MATCH_CRITERIA_FIELD = new ParseField("match_criteria");
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<QueryRuleRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
"rule",
args -> {
List<String> rulesetIds = (List<String>) args[0];
Map<String, Object> matchCriteria = (Map<String, Object>) args[1];
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[2];
int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, retrieverBuilder, rankWindowSize);
}
);

static {
PARSER.declareStringArray(constructorArg(), RULESET_IDS_FIELD);
PARSER.declareObject(constructorArg(), (p, c) -> p.map(), MATCH_CRITERIA_FIELD);
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> p.namedObject(RetrieverBuilder.class, n, c), RETRIEVER_FIELD);
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
}

public static QueryRuleRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
if (context.clusterSupportsFeature(QUERY_RULE_RETRIEVERS_SUPPORTED) == false) {
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
}
if (EnterpriseSearch.QUERY_RULES_RETRIEVER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
throw LicenseUtils.newComplianceException("Query Rules");
}
try {
return PARSER.apply(parser, context);
} catch (Exception e) {
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
}
}

private final List<String> rulesetIds;
private final Map<String, Object> matchCriteria;

public QueryRuleRetrieverBuilder(
List<String> rulesetIds,
Map<String, Object> matchCriteria,
RetrieverBuilder retrieverBuilder,
int rankWindowSize
) {
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria;
}

public QueryRuleRetrieverBuilder(
List<String> rulesetIds,
Map<String, Object> matchCriteria,
List<RetrieverSource> retrieverSource,
int rankWindowSize,
String retrieverName,
List<QueryBuilder> preFilterQueryBuilders
) {
super(retrieverSource, rankWindowSize);
this.rulesetIds = rulesetIds;
this.matchCriteria = matchCriteria;
this.retrieverName = retrieverName;
this.preFilterQueryBuilders = new ArrayList<>(preFilterQueryBuilders);
}

@Override
public String getName() {
return NAME;
}

@Override
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
Logger logger = LogManager.getLogger(QueryRuleRetrieverBuilder.class);
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
.trackTotalHits(false)
.storedFields(new StoredFieldsContext(false))
.size(rankWindowSize);
if (preFilterQueryBuilders.isEmpty() == false) {
retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
}
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);

QueryBuilder query = sourceBuilder.query();
if (query != null && query instanceof RuleQueryBuilder == false) {
QueryBuilder organicQuery = query;
query = new RuleQueryBuilder(organicQuery, matchCriteria, rulesetIds);
}

// apply the pre-filters
if (preFilterQueryBuilders.size() > 0) {
BoolQueryBuilder newQuery = new BoolQueryBuilder();
if (query != null) {
newQuery.must(query);
}
preFilterQueryBuilders.forEach(newQuery::filter);
sourceBuilder.query(newQuery);
}

sourceBuilder.sort(new ScoreSortBuilder());

logger.info("sourceBuilder: " + sourceBuilder);
return sourceBuilder;
}

@Override
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
builder.array(RULESET_IDS_FIELD.getPreferredName(), rulesetIds.toArray());
builder.startObject(MATCH_CRITERIA_FIELD.getPreferredName());
builder.mapContents(matchCriteria);
builder.endObject();
}

@Override
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
return new QueryRuleRetrieverBuilder(
rulesetIds,
matchCriteria,
newChildRetrievers,
rankWindowSize,
retrieverName,
preFilterQueryBuilders
);
}

@Override
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
assert rankResults.size() == 1;
ScoreDoc[] scoreDocs = rankResults.getFirst();
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
ScoreDoc scoreDoc = scoreDocs[i];
rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
}
return rankDocs;
}

@Override
public QueryBuilder explainQuery() {
// the original matching set of the QueryRuleRetriever retriever is specified by its nested retriever
return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true);
}

@Override
public boolean doEquals(Object o) {
QueryRuleRetrieverBuilder that = (QueryRuleRetrieverBuilder) o;
return super.doEquals(o) && Objects.equals(rulesetIds, that.rulesetIds) && Objects.equals(matchCriteria, that.matchCriteria);
}

@Override
public int doHashCode() {
return Objects.hash(super.doHashCode(), rulesetIds, matchCriteria);
}
}

0 comments on commit c1ea67a

Please sign in to comment.