Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/134822.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 134822
summary: Support querying multiple indices with the simplified RRF retriever
area: Relevance
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public Set<NodeFeature> getTestFeatures() {
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
RRFRetrieverBuilder.WEIGHTED_SUPPORT,
LINEAR_RETRIEVER_TOP_LEVEL_NORMALIZER,
LinearRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT
LinearRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT,
RRFRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support");

public static final NodeFeature MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT = new NodeFeature(
"rrf_retriever.multi_index_simplified_format_support"
);
public static final String NAME = "rrf";

public static final ParseField RETRIEVERS_FIELD = new ParseField("retrievers");
Expand Down Expand Up @@ -253,11 +255,7 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
// TODO: Refactor duplicate code
// Using the multi-fields query format
var localIndicesMetadata = resolvedIndices.getConcreteLocalIndicesMetadata();
if (localIndicesMetadata.size() > 1) {
throw new IllegalArgumentException(
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying multiple indices"
);
} else if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
if (resolvedIndices.getRemoteClusterIndices().isEmpty() == false) {
throw new IllegalArgumentException(
"[" + NAME + "] cannot specify [" + QUERY_FIELD.getPreferredName() + "] when querying remote indices"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.SearchModule;
import org.elasticsearch.search.builder.PointInTimeBuilder;
Expand Down Expand Up @@ -235,6 +236,270 @@ public void testMultiFieldsParamsRewrite() {
);
}

public void testMultiIndexMultiFieldsParamsRewrite() {
String indexName = "test-index";
String anotherIndexName = "test-another-index";
final ResolvedIndices resolvedIndices = createMockResolvedIndices(
Map.of(
indexName,
List.of("semantic_field_1", "semantic_field_2"),
anotherIndexName,
List.of("semantic_field_2", "semantic_field_3")
),
null,
Map.of() // use random and different inference IDs for semantic_text fields
);

final QueryRewriteContext queryRewriteContext = new QueryRewriteContext(
parserConfig(),
null,
null,
TransportVersion.current(),
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
resolvedIndices,
new PointInTimeBuilder(new BytesArray("pitid")),
null,
null
);

// No wildcards, no per-field boosting
RRFRetrieverBuilder retriever = new RRFRetrieverBuilder(
null,
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
"foo",
DEFAULT_RANK_WINDOW_SIZE,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(
Map.of("field_1", 1.0f, "field_2", 1.0f),
List.of(indexName),
Map.of("field_1", 1.0f, "field_2", 1.0f, "semantic_field_1", 1.0f),
List.of(anotherIndexName)
),
Map.of(
new Tuple<>("semantic_field_1", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of(indexName)), // field with different inference IDs, we filter on index name
1.0f,
new Tuple<>("semantic_field_2", List.of(anotherIndexName)),
1.0f
),
"foo",
null
);

// Glob matching on inference and non-inference fields
retriever = new RRFRetrieverBuilder(
null,
List.of("field_*", "field_1", "*_field_1", "semantic_*"),
"baz2",
DEFAULT_RANK_WINDOW_SIZE,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(Map.of("field_*", 1.0f, "field_1", 1.0f, "*_field_1", 1.0f, "semantic_*", 1.0f), List.of()),
Map.of(
new Tuple<>("semantic_field_1", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of(anotherIndexName)),
1.0f,
new Tuple<>("semantic_field_3", List.of(anotherIndexName)),
1.0f
),
"baz2",
null
);

// Non-default rank window size
retriever = new RRFRetrieverBuilder(
null,
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
"foo2",
DEFAULT_RANK_WINDOW_SIZE * 2,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(
Map.of("field_1", 1.0f, "field_2", 1.0f),
List.of(indexName),
Map.of("field_1", 1.0f, "field_2", 1.0f, "semantic_field_1", 1.0f),
List.of(anotherIndexName)
),
Map.of(
new Tuple<>("semantic_field_1", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of(anotherIndexName)),
1.0f
),
"foo2",
null
);
Comment on lines +322 to +350
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we use this to test non-default rank constant as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

promise to do it in a follow up - I don't want to keep @mridula-s109 waiting too much on this PR to continue the work on #132680 - this is the second PR that will introduce conflicts for her PR.


// All-fields wildcard
retriever = new RRFRetrieverBuilder(
null,
List.of("*"),
"qux",
DEFAULT_RANK_WINDOW_SIZE,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(Map.of("*", 1.0f), List.of()), // no index filter for the lexical retriever
Map.of(
new Tuple<>("semantic_field_1", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of(anotherIndexName)),
1.0f,
new Tuple<>("semantic_field_3", List.of(anotherIndexName)),
1.0f
),
"qux",
null
);
}

public void testMultiIndexMultiFieldsParamsRewriteWithSameInferenceIds() {
String indexName = "test-index";
String anotherIndexName = "test-another-index";
final ResolvedIndices resolvedIndices = createMockResolvedIndices(
Map.of(
indexName,
List.of("semantic_field_1", "semantic_field_2"),
anotherIndexName,
List.of("semantic_field_2", "semantic_field_3")
),
null,
Map.of("semantic_field_2", "common_inference_id") // use the same inference ID for semantic_field_2
);

final QueryRewriteContext queryRewriteContext = new QueryRewriteContext(
parserConfig(),
null,
null,
TransportVersion.current(),
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
resolvedIndices,
new PointInTimeBuilder(new BytesArray("pitid")),
null,
null
);

// No wildcards, no per-field boosting
RRFRetrieverBuilder retriever = new RRFRetrieverBuilder(
null,
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
"foo",
DEFAULT_RANK_WINDOW_SIZE,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(
Map.of("field_1", 1.0f, "field_2", 1.0f),
List.of(indexName),
Map.of("field_1", 1.0f, "field_2", 1.0f, "semantic_field_1", 1.0f),
List.of(anotherIndexName)
),
Map.of(new Tuple<>("semantic_field_1", List.of(indexName)), 1.0f, new Tuple<>("semantic_field_2", List.of()), 1.0f),
"foo",
null
);

// Non-default rank window size
retriever = new RRFRetrieverBuilder(
null,
List.of("field_1", "field_2", "semantic_field_1", "semantic_field_2"),
"foo2",
DEFAULT_RANK_WINDOW_SIZE * 2,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(
Map.of("field_1", 1.0f, "field_2", 1.0f),
List.of(indexName),
Map.of("field_1", 1.0f, "field_2", 1.0f, "semantic_field_1", 1.0f),
List.of(anotherIndexName)
),
Map.of(new Tuple<>("semantic_field_1", List.of(indexName)), 1.0f, new Tuple<>("semantic_field_2", List.of()), 1.0f),
"foo2",
null
);

// Glob matching on inference and non-inference fields
retriever = new RRFRetrieverBuilder(
null,
List.of("field_*", "field_1", "*_field_1", "semantic_*"),
"baz2",
DEFAULT_RANK_WINDOW_SIZE,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(Map.of("field_*", 1.0f, "field_1", 1.0f, "*_field_1", 1.0f, "semantic_*", 1.0f), List.of()),
Map.of(
new Tuple<>("semantic_field_1", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of()),
1.0f,
new Tuple<>("semantic_field_3", List.of(anotherIndexName)),
1.0f
),
"baz2",
null
);

// All-fields wildcard
retriever = new RRFRetrieverBuilder(
null,
List.of("*"),
"qux",
DEFAULT_RANK_WINDOW_SIZE,
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
new float[0]
);
assertMultiIndexMultiFieldsParamsRewrite(
retriever,
queryRewriteContext,
Map.of(Map.of("*", 1.0f), List.of()), // on index filter on the lexical query
Map.of(
new Tuple<>("semantic_field_1", List.of(indexName)),
1.0f,
new Tuple<>("semantic_field_2", List.of()), // no index filter since both indices have this field
1.0f,
new Tuple<>("semantic_field_3", List.of(anotherIndexName)),
1.0f
),
"qux",
null
);
}

public void testSearchRemoteIndex() {
final ResolvedIndices resolvedIndices = createMockResolvedIndices(
Map.of("local-index", List.of()),
Expand Down
Loading