Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query: Adds LINQ Support for FirstOrDefault #4286

Merged
merged 10 commits into from
Jan 30, 2024
3 changes: 1 addition & 2 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ namespace Microsoft.Azure.Cosmos.Linq
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Diagnostics;
Expand Down Expand Up @@ -774,7 +773,7 @@ public static Task<Response<int>> SumAsync(
return ResponseHelperAsync(source.Sum());
}

return ((CosmosLinqQueryProvider)source.Provider).ExecuteAggregateAsync<int?>(
return cosmosLinqQueryProvider.ExecuteAggregateAsync<int?>(
Expression.Call(
GetMethodInfoOf<IQueryable<int?>, int?>(Queryable.Sum),
source.Expression),
Expand Down
95 changes: 77 additions & 18 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Azure.Cosmos.Linq
using Microsoft.Azure.Cosmos.Serializer;
using Microsoft.Azure.Cosmos.Tracing;
using Newtonsoft.Json;
using Debug = System.Diagnostics.Debug;

/// <summary>
/// This is the entry point for LINQ query creation/execution, it generate query provider, implements IOrderedQueryable.
Expand Down Expand Up @@ -108,7 +109,12 @@ public IEnumerator<T> GetEnumerator()
" use GetItemQueryIterator to execute asynchronously");
}

FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false);
FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

while (localFeedIterator.HasMoreResults)
{
#pragma warning disable VSTHRD002 // Avoid problematic synchronous waits
Expand All @@ -133,7 +139,7 @@ IEnumerator IEnumerable.GetEnumerator()

public override string ToString()
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions).SqlQuerySpec;
if (querySpec != null)
{
return JsonConvert.SerializeObject(querySpec);
Expand All @@ -144,20 +150,36 @@ public override string ToString()

public QueryDefinition ToQueryDefinition(IDictionary<object, string> parameters = null)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters);
return QueryDefinition.CreateFromQuerySpec(querySpec);
LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters);
ScalarOperationKind scalarOperationKind = linqQueryOperation.ScalarOperationKind;
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

return QueryDefinition.CreateFromQuerySpec(linqQueryOperation.SqlQuerySpec);
}

public FeedIterator<T> ToFeedIterator()
{
return new FeedIteratorInlineCore<T>(this.CreateFeedIterator(true),
this.container.ClientContext);
FeedIterator<T> iterator = this.CreateFeedIterator(true, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

return new FeedIteratorInlineCore<T>(iterator, this.container.ClientContext);
}

public FeedIterator ToStreamIterator()
{
return new FeedIteratorInlineCore(this.CreateStreamIterator(true),
this.container.ClientContext);
FeedIterator iterator = this.CreateStreamIterator(true, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

return new FeedIteratorInlineCore(iterator, this.container.ClientContext);
}

public void Dispose()
Expand All @@ -180,15 +202,18 @@ internal async Task<Response<T>> AggregateResultAsync(CancellationToken cancella
List<T> result = new List<T>();
Headers headers = new Headers();

FeedIterator<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false);
FeedIteratorInternal<T> localFeedIteratorInternal = (FeedIteratorInternal<T>)localFeedIterator;
FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, scalarOperationKind: out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

ITrace rootTrace;
using (rootTrace = Trace.GetRootTrace("Aggregate LINQ Operation"))
{
while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = await localFeedIteratorInternal.ReadNextAsync(rootTrace, cancellationToken);
FeedResponse<T> response = await localFeedIterator.ReadNextAsync(rootTrace, cancellationToken);
headers.RequestCharge += response.RequestCharge;
result.AddRange(response);
}
Expand All @@ -202,23 +227,57 @@ internal async Task<Response<T>> AggregateResultAsync(CancellationToken cancella
null);
}

private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected)
internal T ExecuteScalar()
{
FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, out ScalarOperationKind scalarOperationKind);
Headers headers = new Headers();

List<T> result = new List<T>();
ITrace rootTrace;
using (rootTrace = Trace.GetRootTrace("Scalar LINQ Operation"))
{
while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = localFeedIterator.ReadNextAsync(rootTrace, cancellationToken: default).GetAwaiter().GetResult();
headers.RequestCharge += response.RequestCharge;
result.AddRange(response);
}
}

switch (scalarOperationKind)
{
case ScalarOperationKind.FirstOrDefault:
return result.FirstOrDefault();

// ExecuteScalar gets called when (sync) aggregates such as Max, Min, Sum are invoked on the IQueryable.
// Since query fully supprots these operations, there is no client operation involved.
// In these cases we return FirstOrDefault which handles empty/undefined/null result set from the backend.
case ScalarOperationKind.None:
return result.SingleOrDefault();

default:
throw new InvalidOperationException($"Unsupported scalar operation {scalarOperationKind}");
}
}

private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected, out ScalarOperationKind scalarOperationKind)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
scalarOperationKind = linqQueryOperation.ScalarOperationKind;

return this.container.GetItemQueryStreamIteratorInternal(
sqlQuerySpec: querySpec,
sqlQuerySpec: linqQueryOperation.SqlQuerySpec,
isContinuationExcpected: isContinuationExcpected,
continuationToken: this.continuationToken,
feedRange: null,
requestOptions: this.cosmosQueryRequestOptions);
}

private FeedIterator<T> CreateFeedIterator(bool isContinuationExpected)
private FeedIteratorInlineCore<T> CreateFeedIterator(bool isContinuationExpected, out ScalarOperationKind scalarOperationKind)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);

FeedIteratorInternal streamIterator = this.CreateStreamIterator(isContinuationExpected);
FeedIteratorInternal streamIterator = this.CreateStreamIterator(
isContinuationExpected,
out scalarOperationKind);
return new FeedIteratorInlineCore<T>(new FeedIteratorCore<T>(
streamIterator,
this.responseFactory.CreateQueryFeedUserTypeResponse<T>),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Microsoft.Azure.Cosmos.Linq
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
Expand Down Expand Up @@ -60,6 +61,7 @@ public IQueryable<TElement> CreateQuery<TElement>(Expression expression)

public IQueryable CreateQuery(Expression expression)
{
// ISSUE-TODO-adityasa-2024/1/26 - Investigate if reflection usage can be removed.
Type expressionType = TypeSystem.GetElementType(expression.Type);
Type documentQueryType = typeof(CosmosLinqQuery<bool>).GetGenericTypeDefinition().MakeGenericType(expressionType);
return (IQueryable)Activator.CreateInstance(
Expand All @@ -76,6 +78,7 @@ public IQueryable CreateQuery(Expression expression)

public TResult Execute<TResult>(Expression expression)
{
// ISSUE-TODO-adityasa-2024/1/26 - We should be able to delegate the implementation to ExecuteAggregateAsync method below by providing an Async implementation of ExecuteScalar.
Type cosmosQueryType = typeof(CosmosLinqQuery<bool>).GetGenericTypeDefinition().MakeGenericType(typeof(TResult));
CosmosLinqQuery<TResult> cosmosLINQQuery = (CosmosLinqQuery<TResult>)Activator.CreateInstance(
cosmosQueryType,
Expand All @@ -88,7 +91,7 @@ public TResult Execute<TResult>(Expression expression)
this.allowSynchronousQueryExecution,
this.linqSerializerOptions);
this.onExecuteScalarQueryCallback?.Invoke(cosmosLINQQuery);
return cosmosLINQQuery.ToList().FirstOrDefault();
return cosmosLINQQuery.ExecuteScalar();
}

//Sync execution of query via direct invoke on IQueryProvider.
Expand Down
6 changes: 3 additions & 3 deletions Microsoft.Azure.Cosmos/src/Linq/DocumentQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,10 @@ IEnumerator IEnumerable.GetEnumerator()

public override string ToString()
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression);
if (querySpec != null)
LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression);
if (linqQueryOperation.SqlQuerySpec != null)
{
return JsonConvert.SerializeObject(querySpec);
return JsonConvert.SerializeObject(linqQueryOperation.SqlQuerySpec);
}

return new Uri(this.client.ServiceEndpoint, this.documentsFeedOrDatabaseLink).ToString();
Expand Down
17 changes: 9 additions & 8 deletions Microsoft.Azure.Cosmos/src/Linq/DocumentQueryEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal static class DocumentQueryEvaluator
{
private const string SQLMethod = "AsSQL";

public static SqlQuerySpec Evaluate(
public static LinqQueryOperation Evaluate(
Expression expression,
CosmosLinqSerializerOptionsInternal linqSerializerOptions = null,
IDictionary<object, string> parameters = null)
Expand Down Expand Up @@ -51,7 +51,7 @@ public static bool IsTransformExpression(Expression expression)
/// foreach(Database db in client.CreateDatabaseQuery()) {}
/// </summary>
/// <param name="expression"></param>
private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression)
private static LinqQueryOperation HandleEmptyQuery(ConstantExpression expression)
{
if (expression.Value == null)
{
Expand All @@ -69,11 +69,12 @@ private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression)
ClientResources.BadQuery_InvalidExpression,
expression.ToString()));
}

//No query specified.
return null;
return new LinqQueryOperation(sqlQuerySpec: null, scalarOperationKind: ScalarOperationKind.None);
}

private static SqlQuerySpec HandleMethodCallExpression(
private static LinqQueryOperation HandleMethodCallExpression(
MethodCallExpression expression,
IDictionary<object, string> parameters,
CosmosLinqSerializerOptionsInternal linqSerializerOptions = null)
Expand All @@ -100,7 +101,7 @@ private static SqlQuerySpec HandleMethodCallExpression(
/// foreach(string record in client.CreateDocumentQuery().Navigate("Raw JQuery"))
/// </summary>
/// <param name="expression"></param>
private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression expression)
private static LinqQueryOperation HandleAsSqlTransformExpression(MethodCallExpression expression)
{
Expression paramExpression = expression.Arguments[1];

Expand All @@ -122,7 +123,7 @@ private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression
}
}

private static SqlQuerySpec GetSqlQuerySpec(object value)
private static LinqQueryOperation GetSqlQuerySpec(object value)
{
if (value == null)
{
Expand All @@ -133,11 +134,11 @@ private static SqlQuerySpec GetSqlQuerySpec(object value)
}
else if (value.GetType() == typeof(SqlQuerySpec))
{
return (SqlQuerySpec)value;
return new LinqQueryOperation((SqlQuerySpec)value, ScalarOperationKind.None);
}
else if (value.GetType() == typeof(string))
{
return new SqlQuerySpec((string)value);
return new LinqQueryOperation(new SqlQuerySpec((string)value), ScalarOperationKind.None);
}
else
{
Expand Down
Loading
Loading