Skip to content

Commit

Permalink
Query: Adds LINQ Support for FirstOrDefault (#4286)
Browse files Browse the repository at this point in the history
* Initial commit

* Fixed failing aggregate tests.

* Added validation for unsupported overloads of FirstOrDefault.

* Addressed a couple of TODOs.

* Addressed comments.

* Addressed comments.

* Addressed remaining comment.
  • Loading branch information
adityasa authored Jan 30, 2024
1 parent 1bbe101 commit 2c7c7ad
Show file tree
Hide file tree
Showing 20 changed files with 1,924 additions and 1,001 deletions.
3 changes: 1 addition & 2 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ namespace Microsoft.Azure.Cosmos.Linq
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Diagnostics;
Expand Down Expand Up @@ -774,7 +773,7 @@ public static Task<Response<int>> SumAsync(
return ResponseHelperAsync(source.Sum());
}

return ((CosmosLinqQueryProvider)source.Provider).ExecuteAggregateAsync<int?>(
return cosmosLinqQueryProvider.ExecuteAggregateAsync<int?>(
Expression.Call(
GetMethodInfoOf<IQueryable<int?>, int?>(Queryable.Sum),
source.Expression),
Expand Down
95 changes: 77 additions & 18 deletions Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace Microsoft.Azure.Cosmos.Linq
using Microsoft.Azure.Cosmos.Serializer;
using Microsoft.Azure.Cosmos.Tracing;
using Newtonsoft.Json;
using Debug = System.Diagnostics.Debug;

/// <summary>
/// This is the entry point for LINQ query creation/execution, it generate query provider, implements IOrderedQueryable.
Expand Down Expand Up @@ -108,7 +109,12 @@ public IEnumerator<T> GetEnumerator()
" use GetItemQueryIterator to execute asynchronously");
}

FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false);
FeedIterator<T> localFeedIterator = this.CreateFeedIterator(false, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

while (localFeedIterator.HasMoreResults)
{
#pragma warning disable VSTHRD002 // Avoid problematic synchronous waits
Expand All @@ -133,7 +139,7 @@ IEnumerator IEnumerable.GetEnumerator()

public override string ToString()
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions).SqlQuerySpec;
if (querySpec != null)
{
return JsonConvert.SerializeObject(querySpec);
Expand All @@ -144,20 +150,36 @@ public override string ToString()

public QueryDefinition ToQueryDefinition(IDictionary<object, string> parameters = null)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters);
return QueryDefinition.CreateFromQuerySpec(querySpec);
LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters);
ScalarOperationKind scalarOperationKind = linqQueryOperation.ScalarOperationKind;
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

return QueryDefinition.CreateFromQuerySpec(linqQueryOperation.SqlQuerySpec);
}

public FeedIterator<T> ToFeedIterator()
{
return new FeedIteratorInlineCore<T>(this.CreateFeedIterator(true),
this.container.ClientContext);
FeedIterator<T> iterator = this.CreateFeedIterator(true, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

return new FeedIteratorInlineCore<T>(iterator, this.container.ClientContext);
}

public FeedIterator ToStreamIterator()
{
return new FeedIteratorInlineCore(this.CreateStreamIterator(true),
this.container.ClientContext);
FeedIterator iterator = this.CreateStreamIterator(true, out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

return new FeedIteratorInlineCore(iterator, this.container.ClientContext);
}

public void Dispose()
Expand All @@ -180,15 +202,18 @@ internal async Task<Response<T>> AggregateResultAsync(CancellationToken cancella
List<T> result = new List<T>();
Headers headers = new Headers();

FeedIterator<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false);
FeedIteratorInternal<T> localFeedIteratorInternal = (FeedIteratorInternal<T>)localFeedIterator;
FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, scalarOperationKind: out ScalarOperationKind scalarOperationKind);
Debug.Assert(
scalarOperationKind == ScalarOperationKind.None,
"CosmosLinqQuery Assert!",
$"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'");

ITrace rootTrace;
using (rootTrace = Trace.GetRootTrace("Aggregate LINQ Operation"))
{
while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = await localFeedIteratorInternal.ReadNextAsync(rootTrace, cancellationToken);
FeedResponse<T> response = await localFeedIterator.ReadNextAsync(rootTrace, cancellationToken);
headers.RequestCharge += response.RequestCharge;
result.AddRange(response);
}
Expand All @@ -202,23 +227,57 @@ internal async Task<Response<T>> AggregateResultAsync(CancellationToken cancella
null);
}

private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected)
internal T ExecuteScalar()
{
FeedIteratorInlineCore<T> localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, out ScalarOperationKind scalarOperationKind);
Headers headers = new Headers();

List<T> result = new List<T>();
ITrace rootTrace;
using (rootTrace = Trace.GetRootTrace("Scalar LINQ Operation"))
{
while (localFeedIterator.HasMoreResults)
{
FeedResponse<T> response = localFeedIterator.ReadNextAsync(rootTrace, cancellationToken: default).GetAwaiter().GetResult();
headers.RequestCharge += response.RequestCharge;
result.AddRange(response);
}
}

switch (scalarOperationKind)
{
case ScalarOperationKind.FirstOrDefault:
return result.FirstOrDefault();

// ExecuteScalar gets called when (sync) aggregates such as Max, Min, Sum are invoked on the IQueryable.
// Since query fully supprots these operations, there is no client operation involved.
// In these cases we return FirstOrDefault which handles empty/undefined/null result set from the backend.
case ScalarOperationKind.None:
return result.SingleOrDefault();

default:
throw new InvalidOperationException($"Unsupported scalar operation {scalarOperationKind}");
}
}

private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected, out ScalarOperationKind scalarOperationKind)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);
scalarOperationKind = linqQueryOperation.ScalarOperationKind;

return this.container.GetItemQueryStreamIteratorInternal(
sqlQuerySpec: querySpec,
sqlQuerySpec: linqQueryOperation.SqlQuerySpec,
isContinuationExcpected: isContinuationExcpected,
continuationToken: this.continuationToken,
feedRange: null,
requestOptions: this.cosmosQueryRequestOptions);
}

private FeedIterator<T> CreateFeedIterator(bool isContinuationExpected)
private FeedIteratorInlineCore<T> CreateFeedIterator(bool isContinuationExpected, out ScalarOperationKind scalarOperationKind)
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions);

FeedIteratorInternal streamIterator = this.CreateStreamIterator(isContinuationExpected);
FeedIteratorInternal streamIterator = this.CreateStreamIterator(
isContinuationExpected,
out scalarOperationKind);
return new FeedIteratorInlineCore<T>(new FeedIteratorCore<T>(
streamIterator,
this.responseFactory.CreateQueryFeedUserTypeResponse<T>),
Expand Down
5 changes: 4 additions & 1 deletion Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQueryProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Microsoft.Azure.Cosmos.Linq
{
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Threading;
Expand Down Expand Up @@ -60,6 +61,7 @@ public IQueryable<TElement> CreateQuery<TElement>(Expression expression)

public IQueryable CreateQuery(Expression expression)
{
// ISSUE-TODO-adityasa-2024/1/26 - Investigate if reflection usage can be removed.
Type expressionType = TypeSystem.GetElementType(expression.Type);
Type documentQueryType = typeof(CosmosLinqQuery<bool>).GetGenericTypeDefinition().MakeGenericType(expressionType);
return (IQueryable)Activator.CreateInstance(
Expand All @@ -76,6 +78,7 @@ public IQueryable CreateQuery(Expression expression)

public TResult Execute<TResult>(Expression expression)
{
// ISSUE-TODO-adityasa-2024/1/26 - We should be able to delegate the implementation to ExecuteAggregateAsync method below by providing an Async implementation of ExecuteScalar.
Type cosmosQueryType = typeof(CosmosLinqQuery<bool>).GetGenericTypeDefinition().MakeGenericType(typeof(TResult));
CosmosLinqQuery<TResult> cosmosLINQQuery = (CosmosLinqQuery<TResult>)Activator.CreateInstance(
cosmosQueryType,
Expand All @@ -88,7 +91,7 @@ public TResult Execute<TResult>(Expression expression)
this.allowSynchronousQueryExecution,
this.linqSerializerOptions);
this.onExecuteScalarQueryCallback?.Invoke(cosmosLINQQuery);
return cosmosLINQQuery.ToList().FirstOrDefault();
return cosmosLINQQuery.ExecuteScalar();
}

//Sync execution of query via direct invoke on IQueryProvider.
Expand Down
6 changes: 3 additions & 3 deletions Microsoft.Azure.Cosmos/src/Linq/DocumentQuery.cs
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,10 @@ IEnumerator IEnumerable.GetEnumerator()

public override string ToString()
{
SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression);
if (querySpec != null)
LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression);
if (linqQueryOperation.SqlQuerySpec != null)
{
return JsonConvert.SerializeObject(querySpec);
return JsonConvert.SerializeObject(linqQueryOperation.SqlQuerySpec);
}

return new Uri(this.client.ServiceEndpoint, this.documentsFeedOrDatabaseLink).ToString();
Expand Down
17 changes: 9 additions & 8 deletions Microsoft.Azure.Cosmos/src/Linq/DocumentQueryEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal static class DocumentQueryEvaluator
{
private const string SQLMethod = "AsSQL";

public static SqlQuerySpec Evaluate(
public static LinqQueryOperation Evaluate(
Expression expression,
CosmosLinqSerializerOptionsInternal linqSerializerOptions = null,
IDictionary<object, string> parameters = null)
Expand Down Expand Up @@ -51,7 +51,7 @@ public static bool IsTransformExpression(Expression expression)
/// foreach(Database db in client.CreateDatabaseQuery()) {}
/// </summary>
/// <param name="expression"></param>
private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression)
private static LinqQueryOperation HandleEmptyQuery(ConstantExpression expression)
{
if (expression.Value == null)
{
Expand All @@ -69,11 +69,12 @@ private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression)
ClientResources.BadQuery_InvalidExpression,
expression.ToString()));
}

//No query specified.
return null;
return new LinqQueryOperation(sqlQuerySpec: null, scalarOperationKind: ScalarOperationKind.None);
}

private static SqlQuerySpec HandleMethodCallExpression(
private static LinqQueryOperation HandleMethodCallExpression(
MethodCallExpression expression,
IDictionary<object, string> parameters,
CosmosLinqSerializerOptionsInternal linqSerializerOptions = null)
Expand All @@ -100,7 +101,7 @@ private static SqlQuerySpec HandleMethodCallExpression(
/// foreach(string record in client.CreateDocumentQuery().Navigate("Raw JQuery"))
/// </summary>
/// <param name="expression"></param>
private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression expression)
private static LinqQueryOperation HandleAsSqlTransformExpression(MethodCallExpression expression)
{
Expression paramExpression = expression.Arguments[1];

Expand All @@ -122,7 +123,7 @@ private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression
}
}

private static SqlQuerySpec GetSqlQuerySpec(object value)
private static LinqQueryOperation GetSqlQuerySpec(object value)
{
if (value == null)
{
Expand All @@ -133,11 +134,11 @@ private static SqlQuerySpec GetSqlQuerySpec(object value)
}
else if (value.GetType() == typeof(SqlQuerySpec))
{
return (SqlQuerySpec)value;
return new LinqQueryOperation((SqlQuerySpec)value, ScalarOperationKind.None);
}
else if (value.GetType() == typeof(string))
{
return new SqlQuerySpec((string)value);
return new LinqQueryOperation(new SqlQuerySpec((string)value), ScalarOperationKind.None);
}
else
{
Expand Down
Loading

0 comments on commit 2c7c7ad

Please sign in to comment.