Skip to content

Commit

Permalink
Query: Fixes plumbing VectorEmbeddingPolicy to ServiceInterop to choo…
Browse files Browse the repository at this point in the history
…se correct default distance function (#4538)

* Plumb the collection VectorEmbeddingPolicy to ServiceInterop

* Add query plan baseline tests for vector search

* Correct typo in the query for baseline test

* Fix build errors

* fix runtime issue in mock setup due to the extra argument for vector embedding policy
  • Loading branch information
neildsh committed Jun 11, 2024
1 parent 5994b16 commit 8c8d3e9
Show file tree
Hide file tree
Showing 24 changed files with 414 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ private static async Task<PartitionedQueryExecutionInfo> GetPartitionedQueryExec
inputParameters.SqlQuerySpec,
cosmosQueryContext.ResourceTypeEnum,
partitionKeyDefinition,
containerQueryProperties.VectorEmbeddingPolicy,
inputParameters.PartitionKey != null,
containerQueryProperties.GeospatialType,
cosmosQueryContext.UseSystemPrefix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ internal readonly struct ContainerQueryProperties
public ContainerQueryProperties(
string resourceId,
IReadOnlyList<Range<string>> effectivePartitionKeyRanges,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
Cosmos.GeospatialType geospatialType)
{
this.ResourceId = resourceId;
this.EffectiveRangesForPartitionKey = effectivePartitionKeyRanges;
this.PartitionKeyDefinition = partitionKeyDefinition;
this.PartitionKeyDefinition = partitionKeyDefinition;
this.VectorEmbeddingPolicy = vectorEmbeddingPolicy;
this.GeospatialType = geospatialType;
}

Expand All @@ -27,7 +29,11 @@ public ContainerQueryProperties(
//A PartitionKey has one range when it is a full PartitionKey value.
//It can span many it is a prefix PartitionKey for a sub-partitioned container.
public IReadOnlyList<Range<string>> EffectiveRangesForPartitionKey { get; }
public PartitionKeyDefinition PartitionKeyDefinition { get; }

public PartitionKeyDefinition PartitionKeyDefinition { get; }

public Cosmos.VectorEmbeddingPolicy VectorEmbeddingPolicy { get; }

public Cosmos.GeospatialType GeospatialType { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public abstract Task<ContainerQueryProperties> GetCachedContainerQueryProperties
public abstract Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetPartitionedQueryExecutionInfoAsync(
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
Documents.PartitionKeyDefinition partitionKeyDefinition,
Documents.PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ public void Update(IDictionary<string, object> queryengineConfiguration)
public TryCatch<PartitionedQueryExecutionInfo> TryGetPartitionedQueryExecutionInfo(
string querySpecJsonString,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand All @@ -130,7 +131,8 @@ public TryCatch<PartitionedQueryExecutionInfo> TryGetPartitionedQueryExecutionIn
{
TryCatch<PartitionedQueryExecutionInfoInternal> tryGetInternalQueryInfo = this.TryGetPartitionedQueryExecutionInfoInternal(
querySpecJsonString: querySpecJsonString,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: requireFormattableOrderByQuery,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregateQuery,
Expand Down Expand Up @@ -179,7 +181,8 @@ internal PartitionedQueryExecutionInfo ConvertPartitionedQueryExecutionInfo(

internal TryCatch<PartitionedQueryExecutionInfoInternal> TryGetPartitionedQueryExecutionInfoInternal(
string querySpecJsonString,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand Down Expand Up @@ -222,8 +225,12 @@ internal TryCatch<PartitionedQueryExecutionInfoInternal> TryGetPartitionedQueryE

Span<byte> buffer = stackalloc byte[QueryPartitionProvider.InitialBufferSize];
uint errorCode;
uint serializedQueryExecutionInfoResultLength;

uint serializedQueryExecutionInfoResultLength;

string vectorEmbeddingPolicyString = vectorEmbeddingPolicy != null ?
JsonConvert.SerializeObject(vectorEmbeddingPolicy) :
null;

unsafe
{
ServiceInteropWrapper.PartitionKeyRangesApiOptions partitionKeyRangesApiOptions =
Expand All @@ -241,13 +248,15 @@ internal TryCatch<PartitionedQueryExecutionInfoInternal> TryGetPartitionedQueryE

fixed (byte* bytePtr = buffer)
{
errorCode = ServiceInteropWrapper.GetPartitionKeyRangesFromQuery3(
errorCode = ServiceInteropWrapper.GetPartitionKeyRangesFromQuery4(
this.serviceProvider,
querySpecJsonString,
partitionKeyRangesApiOptions,
allParts,
partsLengths,
(uint)partitionKeyDefinition.Paths.Count,
(uint)partitionKeyDefinition.Paths.Count,
vectorEmbeddingPolicyString,
vectorEmbeddingPolicyString?.Length ?? 0,
new IntPtr(bytePtr),
(uint)buffer.Length,
out serializedQueryExecutionInfoResultLength);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ public QueryPlanHandler(CosmosQueryClient queryClient)
public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
QueryFeatures supportedQueryFeatures,
bool hasLogicalPartitionKey,
bool useSystemPrefix,
Expand All @@ -47,7 +48,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryInfo = await this.TryGetQueryInfoAsync(
sqlQuerySpec,
resourceType,
partitionKeyDefinition,
partitionKeyDefinition,
vectorEmbeddingPolicy,
hasLogicalPartitionKey,
useSystemPrefix,
geospatialType,
Expand Down Expand Up @@ -75,7 +77,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
QueryFeatures supportedQueryFeatures,
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool hasLogicalPartitionKey,
bool useSystemPrefix,
GeospatialType geospatialType,
Expand All @@ -96,7 +99,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryInfo = await this.TryGetQueryInfoAsync(
sqlQuerySpec,
resourceType,
partitionKeyDefinition,
partitionKeyDefinition,
vectorEmbeddingPolicy,
hasLogicalPartitionKey,
useSystemPrefix,
geospatialType,
Expand All @@ -115,7 +119,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
private Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryInfoAsync(
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool hasLogicalPartitionKey,
bool useSystemPrefix,
Cosmos.GeospatialType geospatialType,
Expand All @@ -126,7 +131,8 @@ private Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryInfoAsync(
return this.queryClient.TryGetPartitionedQueryExecutionInfoAsync(
sqlQuerySpec: sqlQuerySpec,
resourceType: resourceType,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: true,
isContinuationExpected: false,
allowNonValueAggregateQuery: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ public static async Task<PartitionedQueryExecutionInfo> GetQueryPlanWithServiceI
CosmosQueryClient queryClient,
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool hasLogicalPartitionKey,
GeospatialType geospatialType,
bool useSystemPrefix,
Expand Down Expand Up @@ -81,7 +82,8 @@ public static async Task<PartitionedQueryExecutionInfo> GetQueryPlanWithServiceI
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryPlan = await queryPlanHandler.TryGetQueryPlanAsync(
sqlQuerySpec,
resourceType,
partitionKeyDefinition,
partitionKeyDefinition,
vectorEmbeddingPolicy,
QueryPlanRetriever.SupportedQueryFeatures,
hasLogicalPartitionKey,
useSystemPrefix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ private static bool ServiceInteropAvailable()
allowDCount: false,
allowNonValueAggregates: false,
useSystemPrefix: false,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: collection.VectorEmbeddingPolicy,
queryPartitionProvider: queryPartitionProvider,
clientApiVersion: version,
geospatialType: collection.GeospatialConfig.GeospatialType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ protected SqlQuerySpec QuerySpec
public Guid CorrelatedActivityId { get; }

public async Task<PartitionedQueryExecutionInfo> GetPartitionedQueryExecutionInfoAsync(
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand All @@ -179,7 +180,8 @@ public async Task<PartitionedQueryExecutionInfo> GetPartitionedQueryExecutionInf
QueryPartitionProvider queryPartitionProvider = await this.Client.GetQueryPartitionProviderAsync();
TryCatch<PartitionedQueryExecutionInfo> tryGetPartitionedQueryExecutionInfo = queryPartitionProvider.TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: JsonConvert.SerializeObject(this.QuerySpec),
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: requireFormattableOrderByQuery,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregateQuery,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public static async Task<IDocumentQueryExecutionContext> CreateDocumentQueryExec
//if collection is deleted/created with same name.
//need to make it not rely on information from collection cache.
PartitionedQueryExecutionInfo partitionedQueryExecutionInfo = await queryExecutionContext.GetPartitionedQueryExecutionInfoAsync(
partitionKeyDefinition: collection.PartitionKey,
partitionKeyDefinition: collection.PartitionKey,
vectorEmbeddingPolicy: collection.VectorEmbeddingPolicy,
requireFormattableOrderByQuery: true,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,16 @@ public override async Task<ContainerQueryProperties> GetCachedContainerQueryProp
return new ContainerQueryProperties(
containerProperties.ResourceId,
effectivePartitionKeyRange,
containerProperties.PartitionKey,
containerProperties.PartitionKey,
containerProperties.VectorEmbeddingPolicy,
containerProperties.GeospatialConfig.GeospatialType);
}

public override async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetPartitionedQueryExecutionInfoAsync(
SqlQuerySpec sqlQuerySpec,
ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand All @@ -116,7 +118,8 @@ public override async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetPartit

return (await this.documentClient.QueryPartitionProvider).TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: queryString,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: requireFormattableOrderByQuery,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregateQuery,
Expand Down
4 changes: 3 additions & 1 deletion Microsoft.Azure.Cosmos/src/Routing/PartitionRoutingHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public static IReadOnlyList<Range<string>> GetProvidedPartitionKeyRanges(
bool allowNonValueAggregates,
bool useSystemPrefix,
PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
QueryPartitionProvider queryPartitionProvider,
string clientApiVersion,
Cosmos.GeospatialType geospatialType,
Expand All @@ -57,7 +58,8 @@ public static IReadOnlyList<Range<string>> GetProvidedPartitionKeyRanges(

TryCatch<PartitionedQueryExecutionInfo> tryGetPartitionQueryExecutionInfo = queryPartitionProvider.TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: querySpecJsonString,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: VersionUtility.IsLaterThan(clientApiVersion, HttpConstants.VersionDates.v2016_11_14),
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregates,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1723,9 +1723,10 @@ public async Task ItemEpkQuerySingleKeyRangeValidation()

ContainerQueryProperties containerQueryProperties = new ContainerQueryProperties(
containerResponse.Resource.ResourceId,
null,
effectivePartitionKeyRanges: null,
//new List<Documents.Routing.Range<string>> { new Documents.Routing.Range<string>("AA", "AA", true, true) },
containerResponse.Resource.PartitionKey,
containerResponse.Resource.PartitionKey,
vectorEmbeddingPolicy: null,
containerResponse.Resource.GeospatialConfig.GeospatialType);

// There should only be one range since the EPK option is set.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ public async Task GetTargetPartitionKeyRangesAsyncWithFeedRange()

ContainerQueryProperties containerQueryProperties = new ContainerQueryProperties(
containerResponse.Resource.ResourceId,
null,
containerResponse.Resource.PartitionKey,
effectivePartitionKeyRanges: null,
containerResponse.Resource.PartitionKey,
vectorEmbeddingPolicy: null,
containerResponse.Resource.GeospatialConfig.GeospatialType);

IReadOnlyList<FeedRange> feedTokens = await container.GetFeedRangesAsync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ private static void ParseUsingNativeParser(SqlQuerySpec sqlQuerySpec)
{
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryPlan = QueryPartitionProvider.TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: JsonConvert.SerializeObject(sqlQuerySpec),
partitionKeyDefinition: PartitionKeyDefinition,
partitionKeyDefinition: PartitionKeyDefinition,
vectorEmbeddingPolicy: null,
requireFormattableOrderByQuery: true,
isContinuationExpected: false,
allowNonValueAggregateQuery: true,
Expand Down
Loading

0 comments on commit 8c8d3e9

Please sign in to comment.