Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Fixes plumbing VectorEmbeddingPolicy to ServiceInterop to choose correct default distance function #4538

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ private static async Task<PartitionedQueryExecutionInfo> GetPartitionedQueryExec
inputParameters.SqlQuerySpec,
cosmosQueryContext.ResourceTypeEnum,
partitionKeyDefinition,
containerQueryProperties.VectorEmbeddingPolicy,
inputParameters.PartitionKey != null,
containerQueryProperties.GeospatialType,
cosmosQueryContext.UseSystemPrefix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ internal readonly struct ContainerQueryProperties
public ContainerQueryProperties(
string resourceId,
IReadOnlyList<Range<string>> effectivePartitionKeyRanges,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
Cosmos.GeospatialType geospatialType)
{
this.ResourceId = resourceId;
this.EffectiveRangesForPartitionKey = effectivePartitionKeyRanges;
this.PartitionKeyDefinition = partitionKeyDefinition;
this.PartitionKeyDefinition = partitionKeyDefinition;
this.VectorEmbeddingPolicy = vectorEmbeddingPolicy;
this.GeospatialType = geospatialType;
}

Expand All @@ -27,7 +29,11 @@ public ContainerQueryProperties(
//A PartitionKey has one range when it is a full PartitionKey value.
//It can span many it is a prefix PartitionKey for a sub-partitioned container.
public IReadOnlyList<Range<string>> EffectiveRangesForPartitionKey { get; }
public PartitionKeyDefinition PartitionKeyDefinition { get; }

public PartitionKeyDefinition PartitionKeyDefinition { get; }

public Cosmos.VectorEmbeddingPolicy VectorEmbeddingPolicy { get; }

public Cosmos.GeospatialType GeospatialType { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public abstract Task<ContainerQueryProperties> GetCachedContainerQueryProperties
public abstract Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetPartitionedQueryExecutionInfoAsync(
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
Documents.PartitionKeyDefinition partitionKeyDefinition,
Documents.PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ public void Update(IDictionary<string, object> queryengineConfiguration)
public TryCatch<PartitionedQueryExecutionInfo> TryGetPartitionedQueryExecutionInfo(
string querySpecJsonString,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand All @@ -130,7 +131,8 @@ public TryCatch<PartitionedQueryExecutionInfo> TryGetPartitionedQueryExecutionIn
{
TryCatch<PartitionedQueryExecutionInfoInternal> tryGetInternalQueryInfo = this.TryGetPartitionedQueryExecutionInfoInternal(
querySpecJsonString: querySpecJsonString,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: requireFormattableOrderByQuery,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregateQuery,
Expand Down Expand Up @@ -179,7 +181,8 @@ internal PartitionedQueryExecutionInfo ConvertPartitionedQueryExecutionInfo(

internal TryCatch<PartitionedQueryExecutionInfoInternal> TryGetPartitionedQueryExecutionInfoInternal(
string querySpecJsonString,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand Down Expand Up @@ -222,8 +225,12 @@ internal TryCatch<PartitionedQueryExecutionInfoInternal> TryGetPartitionedQueryE

Span<byte> buffer = stackalloc byte[QueryPartitionProvider.InitialBufferSize];
uint errorCode;
uint serializedQueryExecutionInfoResultLength;

uint serializedQueryExecutionInfoResultLength;

string vectorEmbeddingPolicyString = vectorEmbeddingPolicy != null ?
JsonConvert.SerializeObject(vectorEmbeddingPolicy) :
null;

unsafe
{
ServiceInteropWrapper.PartitionKeyRangesApiOptions partitionKeyRangesApiOptions =
Expand All @@ -241,13 +248,15 @@ internal TryCatch<PartitionedQueryExecutionInfoInternal> TryGetPartitionedQueryE

fixed (byte* bytePtr = buffer)
{
errorCode = ServiceInteropWrapper.GetPartitionKeyRangesFromQuery3(
errorCode = ServiceInteropWrapper.GetPartitionKeyRangesFromQuery4(
this.serviceProvider,
querySpecJsonString,
partitionKeyRangesApiOptions,
allParts,
partsLengths,
(uint)partitionKeyDefinition.Paths.Count,
(uint)partitionKeyDefinition.Paths.Count,
vectorEmbeddingPolicyString,
vectorEmbeddingPolicyString?.Length ?? 0,
new IntPtr(bytePtr),
(uint)buffer.Length,
out serializedQueryExecutionInfoResultLength);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ public QueryPlanHandler(CosmosQueryClient queryClient)
public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
QueryFeatures supportedQueryFeatures,
bool hasLogicalPartitionKey,
bool useSystemPrefix,
Expand All @@ -47,7 +48,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryInfo = await this.TryGetQueryInfoAsync(
sqlQuerySpec,
resourceType,
partitionKeyDefinition,
partitionKeyDefinition,
vectorEmbeddingPolicy,
hasLogicalPartitionKey,
useSystemPrefix,
geospatialType,
Expand Down Expand Up @@ -75,7 +77,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
QueryFeatures supportedQueryFeatures,
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool hasLogicalPartitionKey,
bool useSystemPrefix,
GeospatialType geospatialType,
Expand All @@ -96,7 +99,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryInfo = await this.TryGetQueryInfoAsync(
sqlQuerySpec,
resourceType,
partitionKeyDefinition,
partitionKeyDefinition,
vectorEmbeddingPolicy,
hasLogicalPartitionKey,
useSystemPrefix,
geospatialType,
Expand All @@ -115,7 +119,8 @@ public async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryPlanAsync(
private Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryInfoAsync(
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool hasLogicalPartitionKey,
bool useSystemPrefix,
Cosmos.GeospatialType geospatialType,
Expand All @@ -126,7 +131,8 @@ private Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetQueryInfoAsync(
return this.queryClient.TryGetPartitionedQueryExecutionInfoAsync(
sqlQuerySpec: sqlQuerySpec,
resourceType: resourceType,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: true,
isContinuationExpected: false,
allowNonValueAggregateQuery: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ public static async Task<PartitionedQueryExecutionInfo> GetQueryPlanWithServiceI
CosmosQueryClient queryClient,
SqlQuerySpec sqlQuerySpec,
Documents.ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool hasLogicalPartitionKey,
GeospatialType geospatialType,
bool useSystemPrefix,
Expand Down Expand Up @@ -81,7 +82,8 @@ public static async Task<PartitionedQueryExecutionInfo> GetQueryPlanWithServiceI
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryPlan = await queryPlanHandler.TryGetQueryPlanAsync(
sqlQuerySpec,
resourceType,
partitionKeyDefinition,
partitionKeyDefinition,
vectorEmbeddingPolicy,
QueryPlanRetriever.SupportedQueryFeatures,
hasLogicalPartitionKey,
useSystemPrefix,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ private static bool ServiceInteropAvailable()
allowDCount: false,
allowNonValueAggregates: false,
useSystemPrefix: false,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: collection.VectorEmbeddingPolicy,
queryPartitionProvider: queryPartitionProvider,
clientApiVersion: version,
geospatialType: collection.GeospatialConfig.GeospatialType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ protected SqlQuerySpec QuerySpec
public Guid CorrelatedActivityId { get; }

public async Task<PartitionedQueryExecutionInfo> GetPartitionedQueryExecutionInfoAsync(
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand All @@ -179,7 +180,8 @@ public async Task<PartitionedQueryExecutionInfo> GetPartitionedQueryExecutionInf
QueryPartitionProvider queryPartitionProvider = await this.Client.GetQueryPartitionProviderAsync();
TryCatch<PartitionedQueryExecutionInfo> tryGetPartitionedQueryExecutionInfo = queryPartitionProvider.TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: JsonConvert.SerializeObject(this.QuerySpec),
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: requireFormattableOrderByQuery,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregateQuery,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ public static async Task<IDocumentQueryExecutionContext> CreateDocumentQueryExec
//if collection is deleted/created with same name.
//need to make it not rely on information from collection cache.
PartitionedQueryExecutionInfo partitionedQueryExecutionInfo = await queryExecutionContext.GetPartitionedQueryExecutionInfoAsync(
partitionKeyDefinition: collection.PartitionKey,
partitionKeyDefinition: collection.PartitionKey,
vectorEmbeddingPolicy: collection.VectorEmbeddingPolicy,
requireFormattableOrderByQuery: true,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,16 @@ public override async Task<ContainerQueryProperties> GetCachedContainerQueryProp
return new ContainerQueryProperties(
containerProperties.ResourceId,
effectivePartitionKeyRange,
containerProperties.PartitionKey,
containerProperties.PartitionKey,
containerProperties.VectorEmbeddingPolicy,
containerProperties.GeospatialConfig.GeospatialType);
}

public override async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetPartitionedQueryExecutionInfoAsync(
SqlQuerySpec sqlQuerySpec,
ResourceType resourceType,
PartitionKeyDefinition partitionKeyDefinition,
PartitionKeyDefinition partitionKeyDefinition,
VectorEmbeddingPolicy vectorEmbeddingPolicy,
bool requireFormattableOrderByQuery,
bool isContinuationExpected,
bool allowNonValueAggregateQuery,
Expand All @@ -116,7 +118,8 @@ public override async Task<TryCatch<PartitionedQueryExecutionInfo>> TryGetPartit

return (await this.documentClient.QueryPartitionProvider).TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: queryString,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: requireFormattableOrderByQuery,
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregateQuery,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public static IReadOnlyList<Range<string>> GetProvidedPartitionKeyRanges(
bool allowNonValueAggregates,
bool useSystemPrefix,
PartitionKeyDefinition partitionKeyDefinition,
Cosmos.VectorEmbeddingPolicy vectorEmbeddingPolicy,
QueryPartitionProvider queryPartitionProvider,
string clientApiVersion,
Cosmos.GeospatialType geospatialType,
Expand All @@ -57,7 +58,8 @@ public static IReadOnlyList<Range<string>> GetProvidedPartitionKeyRanges(

TryCatch<PartitionedQueryExecutionInfo> tryGetPartitionQueryExecutionInfo = queryPartitionProvider.TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: querySpecJsonString,
partitionKeyDefinition: partitionKeyDefinition,
partitionKeyDefinition: partitionKeyDefinition,
vectorEmbeddingPolicy: vectorEmbeddingPolicy,
requireFormattableOrderByQuery: VersionUtility.IsLaterThan(clientApiVersion, HttpConstants.VersionDates.v2016_11_14),
isContinuationExpected: isContinuationExpected,
allowNonValueAggregateQuery: allowNonValueAggregates,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1723,9 +1723,10 @@ public async Task ItemEpkQuerySingleKeyRangeValidation()

ContainerQueryProperties containerQueryProperties = new ContainerQueryProperties(
containerResponse.Resource.ResourceId,
null,
effectivePartitionKeyRanges: null,
//new List<Documents.Routing.Range<string>> { new Documents.Routing.Range<string>("AA", "AA", true, true) },
containerResponse.Resource.PartitionKey,
containerResponse.Resource.PartitionKey,
vectorEmbeddingPolicy: null,
containerResponse.Resource.GeospatialConfig.GeospatialType);

// There should only be one range since the EPK option is set.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ public async Task GetTargetPartitionKeyRangesAsyncWithFeedRange()

ContainerQueryProperties containerQueryProperties = new ContainerQueryProperties(
containerResponse.Resource.ResourceId,
null,
containerResponse.Resource.PartitionKey,
effectivePartitionKeyRanges: null,
containerResponse.Resource.PartitionKey,
vectorEmbeddingPolicy: null,
containerResponse.Resource.GeospatialConfig.GeospatialType);

IReadOnlyList<FeedRange> feedTokens = await container.GetFeedRangesAsync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ private static void ParseUsingNativeParser(SqlQuerySpec sqlQuerySpec)
{
TryCatch<PartitionedQueryExecutionInfo> tryGetQueryPlan = QueryPartitionProvider.TryGetPartitionedQueryExecutionInfo(
querySpecJsonString: JsonConvert.SerializeObject(sqlQuerySpec),
partitionKeyDefinition: PartitionKeyDefinition,
partitionKeyDefinition: PartitionKeyDefinition,
vectorEmbeddingPolicy: null,
requireFormattableOrderByQuery: true,
isContinuationExpected: false,
allowNonValueAggregateQuery: true,
Expand Down
Loading