Skip to content

Commit

Permalink
Query: Adds hybrid search query pipeline stage (#4794)
Browse files Browse the repository at this point in the history
## Description

Adds hybrid search query pipeline stage. This requires the new Direct
package and gateway to be available in order to light up.

Given an input SQL such as:
```sql
      SELECT TOP 100 c.text, c.abstract
      FROM c
      ORDER BY RANK RRF(FullTextScore(c.text, ['swim', 'run']), FullTextScore(c.abstract, ['energy']))
```

The new query plan (encoded below as XML instead of JSON to help
readability) is as follows:

```
        <queryRanges>
          <Item>{"min":[],"max":"Infinity","isMinInclusive":true,"isMaxInclusive":false}</Item>
        </queryRanges>
        <hybridSearchQueryInfo>
          <globalStatisticsQuery><![CDATA[
SELECT 
    COUNT(1) AS documentCount,
    [
        {
            totalWordCount: SUM(_FullTextWordCount(c.text)),
            hitCounts: [
                COUNTIF(FullTextContains(c.text, "swim")),
                COUNTIF(FullTextContains(c.text, "run"))
            ]
        },
        {
            totalWordCount: SUM(_FullTextWordCount(c.abstract)),
            hitCounts: [
                COUNTIF(FullTextContains(c.abstract, "energy"))
            ]
        }
    ] AS fullTextStatistics
FROM c
]]></globalStatisticsQuery>
          <componentQueryInfos>
            <Item>
              <distinctType>None</distinctType>
              <top>200</top>
              <orderBy>
                <Item>Descending</Item>
              </orderBy>
              <orderByExpressions>
                <Item>_FullTextScore(c.text, ["swim", "run"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-0}, {documentdb-formattablehybridsearchquery-hitcountsarray-0})</Item>
              </orderByExpressions>
              <hasSelectValue>false</hasSelectValue>
              <rewrittenQuery><![CDATA[
SELECT TOP 200 
    c._rid,
    [
        {
            item: _FullTextScore(c.text, ["swim", "run"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-0}, {documentdb-formattablehybridsearchquery-hitcountsarray-0})
        }
    ] AS orderByItems,
    {
        payload: {
            text: c.text,
            abstract: c.abstract
        },
        componentScores: [
            _FullTextScore(c.text, ["swim", "run"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-0}, {documentdb-formattablehybridsearchquery-hitcountsarray-0}),
            _FullTextScore(c.abstract, ["energy"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-1}, {documentdb-formattablehybridsearchquery-hitcountsarray-1})
        ]
    } AS payload
FROM c
WHERE {documentdb-formattableorderbyquery-filter}
ORDER BY _FullTextScore(c.text, ["swim", "run"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-0}, {documentdb-formattablehybridsearchquery-hitcountsarray-0}) DESC
]]></rewrittenQuery>
              <hasNonStreamingOrderBy>true</hasNonStreamingOrderBy>
            </Item>
            <Item>
              <distinctType>None</distinctType>
              <top>200</top>
              <orderBy>
                <Item>Descending</Item>
              </orderBy>
              <orderByExpressions>
                <Item>_FullTextScore(c.abstract, ["energy"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-1}, {documentdb-formattablehybridsearchquery-hitcountsarray-1})</Item>
              </orderByExpressions>
              <hasSelectValue>false</hasSelectValue>
              <rewrittenQuery><![CDATA[
SELECT TOP 200 
    c._rid,
    [
        {
            item: _FullTextScore(c.abstract, ["energy"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-1}, {documentdb-formattablehybridsearchquery-hitcountsarray-1})
        }
    ] AS orderByItems,
    {
        payload: {
            text: c.text,
            abstract: c.abstract
        },
        componentScores: [
            _FullTextScore(c.text, ["swim", "run"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-0}, {documentdb-formattablehybridsearchquery-hitcountsarray-0}),
            _FullTextScore(c.abstract, ["energy"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-1}, {documentdb-formattablehybridsearchquery-hitcountsarray-1})
        ]
    } AS payload
FROM c
WHERE {documentdb-formattableorderbyquery-filter}
ORDER BY _FullTextScore(c.abstract, ["energy"], {documentdb-formattablehybridsearchquery-totaldocumentcount}, {documentdb-formattablehybridsearchquery-totalwordcount-1}, {documentdb-formattablehybridsearchquery-hitcountsarray-1}) DESC
]]></rewrittenQuery>
              <hasNonStreamingOrderBy>true</hasNonStreamingOrderBy>
            </Item>
          </componentQueryInfos>
          <take>100</take>
          <requiresGlobalStatistics>true</requiresGlobalStatistics>
        </hybridSearchQueryInfo>
```

We have a custom implementation for the global statistics inside the
`HybridSearchCrossPartitionQueryPipelineStage` because it uses nested
aggregates. Each of the component queries in the hybrid search query
plan is cross partition, and we run them using the existing cross
partition query pipelines.

Note the use of placeholders such as
`{documentdb-formattablehybridsearchquery-totaldocumentcount}` in the
query plan. These need to be replaced by the global statistics.

## Type of change

- [x] New feature (non-breaking change which adds functionality)
  • Loading branch information
neildsh authored Oct 18, 2024
1 parent 57c681f commit 4e1c033
Show file tree
Hide file tree
Showing 21 changed files with 1,721 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ public static TryCatch<AggregateValue> TryCreate(
tryCreateAggregator = AverageAggregator.TryCreate(continuationToken);
break;

case AggregateOperator.Count:
case AggregateOperator.Count:
case AggregateOperator.CountIf:
tryCreateAggregator = CountAggregator.TryCreate(continuationToken);
break;

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------

namespace Microsoft.Azure.Cosmos.Query.Core.Pipeline.CrossPartition.HybridSearch
{
using System;
using System.Collections.Generic;
using Microsoft.Azure.Cosmos.CosmosElements;

internal sealed class FullTextStatistics
{
private readonly long[] hitCounts;

public long TotalWordCount { get; }

public ReadOnlyMemory<long> HitCounts => this.hitCounts;

public FullTextStatistics(long totalWordCount, long[] hitCounts)
{
this.TotalWordCount = totalWordCount;
this.hitCounts = hitCounts;
}

public FullTextStatistics(CosmosObject cosmosObject)
{
if (cosmosObject == null)
{
throw new System.ArgumentNullException($"{nameof(cosmosObject)} must not be null.");
}

if (!cosmosObject.TryGetValue(FieldNames.TotalWordCount, out CosmosNumber totalWordCount))
{
throw new System.ArgumentException($"{FieldNames.TotalWordCount} must exist and be a number");
}

if (!cosmosObject.TryGetValue(FieldNames.HitCounts, out CosmosArray hitCountsArray))
{
throw new System.ArgumentException($"{FieldNames.HitCounts} must exist and be an array");
}

long[] hitCounts = new long[hitCountsArray.Count];
for (int index = 0; index < hitCountsArray.Count; ++index)
{
if (!(hitCountsArray[index] is CosmosNumber cosmosNumber))
{
throw new System.ArgumentException($"{FieldNames.HitCounts} must be an array of numbers");
}

hitCounts[index] = Number64.ToLong(cosmosNumber.Value);
}

this.TotalWordCount = Number64.ToLong(totalWordCount.Value);
this.hitCounts = hitCounts;
}

private static class FieldNames
{
public const string TotalWordCount = "totalWordCount";

public const string HitCounts = "hitCounts";
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------

namespace Microsoft.Azure.Cosmos.Query.Core.Pipeline.CrossPartition.HybridSearch
{
using System.Collections.Generic;
using Microsoft.Azure.Cosmos.CosmosElements;

internal sealed class GlobalFullTextSearchStatistics
{
public long DocumentCount { get; }

public IReadOnlyList<FullTextStatistics> FullTextStatistics { get; }

public GlobalFullTextSearchStatistics(long documentCount, IReadOnlyList<FullTextStatistics> fullTextStatistics)
{
this.DocumentCount = documentCount;
this.FullTextStatistics = fullTextStatistics ?? throw new System.ArgumentNullException($"{nameof(fullTextStatistics)} must not be null.");
}

public GlobalFullTextSearchStatistics(CosmosElement cosmosElement)
{
if (cosmosElement == null)
{
throw new System.ArgumentNullException($"{nameof(cosmosElement)} must not be null.");
}

if (!(cosmosElement is CosmosObject cosmosObject))
{
throw new System.ArgumentException($"{nameof(cosmosElement)} must be an object.");
}

if (!cosmosObject.TryGetValue(FieldNames.DocumentCount, out CosmosNumber cosmosNumber))
{
throw new System.ArgumentException($"{FieldNames.DocumentCount} must exist and be a number");
}

if (!cosmosObject.TryGetValue(FieldNames.Statistics, out CosmosArray statisticsArray))
{
throw new System.ArgumentException($"{FieldNames.Statistics} must exist and be an array");
}

List<FullTextStatistics> fullTextStatisticsList = new List<FullTextStatistics>(statisticsArray.Count);
foreach (CosmosElement statisticsElement in statisticsArray)
{
if (!(statisticsElement is CosmosObject))
{
throw new System.ArgumentException($"{FieldNames.Statistics} must be an array of objects");
}

FullTextStatistics fullTextStatistics = new FullTextStatistics(statisticsElement as CosmosObject);
fullTextStatisticsList.Add(fullTextStatistics);
}

this.DocumentCount = Number64.ToLong(cosmosNumber.Value);
this.FullTextStatistics = fullTextStatisticsList;
}

private static class FieldNames
{
public const string DocumentCount = "documentCount";

public const string Statistics = "fullTextStatistics";
}
}
}
Loading

0 comments on commit 4e1c033

Please sign in to comment.