diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs index 619c1c1ecf..7f7a22b466 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs @@ -10,7 +10,6 @@ namespace Microsoft.Azure.Cosmos.Linq using System.Linq; using System.Linq.Expressions; using System.Reflection; - using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Cosmos.Diagnostics; @@ -774,7 +773,7 @@ public static Task> SumAsync( return ResponseHelperAsync(source.Sum()); } - return ((CosmosLinqQueryProvider)source.Provider).ExecuteAggregateAsync( + return cosmosLinqQueryProvider.ExecuteAggregateAsync( Expression.Call( GetMethodInfoOf, int?>(Queryable.Sum), source.Expression), diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs index a4c4c5e938..6676d096c9 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQuery.cs @@ -16,6 +16,7 @@ namespace Microsoft.Azure.Cosmos.Linq using Microsoft.Azure.Cosmos.Serializer; using Microsoft.Azure.Cosmos.Tracing; using Newtonsoft.Json; + using Debug = System.Diagnostics.Debug; /// /// This is the entry point for LINQ query creation/execution, it generate query provider, implements IOrderedQueryable. @@ -108,7 +109,12 @@ public IEnumerator GetEnumerator() " use GetItemQueryIterator to execute asynchronously"); } - FeedIterator localFeedIterator = this.CreateFeedIterator(false); + FeedIterator localFeedIterator = this.CreateFeedIterator(false, out ScalarOperationKind scalarOperationKind); + Debug.Assert( + scalarOperationKind == ScalarOperationKind.None, + "CosmosLinqQuery Assert!", + $"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'"); + while (localFeedIterator.HasMoreResults) { #pragma warning disable VSTHRD002 // Avoid problematic synchronous waits @@ -133,7 +139,7 @@ IEnumerator IEnumerable.GetEnumerator() public override string ToString() { - SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions); + SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions).SqlQuerySpec; if (querySpec != null) { return JsonConvert.SerializeObject(querySpec); @@ -144,20 +150,36 @@ public override string ToString() public QueryDefinition ToQueryDefinition(IDictionary parameters = null) { - SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters); - return QueryDefinition.CreateFromQuerySpec(querySpec); + LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions, parameters); + ScalarOperationKind scalarOperationKind = linqQueryOperation.ScalarOperationKind; + Debug.Assert( + scalarOperationKind == ScalarOperationKind.None, + "CosmosLinqQuery Assert!", + $"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'"); + + return QueryDefinition.CreateFromQuerySpec(linqQueryOperation.SqlQuerySpec); } public FeedIterator ToFeedIterator() { - return new FeedIteratorInlineCore(this.CreateFeedIterator(true), - this.container.ClientContext); + FeedIterator iterator = this.CreateFeedIterator(true, out ScalarOperationKind scalarOperationKind); + Debug.Assert( + scalarOperationKind == ScalarOperationKind.None, + "CosmosLinqQuery Assert!", + $"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'"); + + return new FeedIteratorInlineCore(iterator, this.container.ClientContext); } public FeedIterator ToStreamIterator() { - return new FeedIteratorInlineCore(this.CreateStreamIterator(true), - this.container.ClientContext); + FeedIterator iterator = this.CreateStreamIterator(true, out ScalarOperationKind scalarOperationKind); + Debug.Assert( + scalarOperationKind == ScalarOperationKind.None, + "CosmosLinqQuery Assert!", + $"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'"); + + return new FeedIteratorInlineCore(iterator, this.container.ClientContext); } public void Dispose() @@ -180,15 +202,18 @@ internal async Task> AggregateResultAsync(CancellationToken cancella List result = new List(); Headers headers = new Headers(); - FeedIterator localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false); - FeedIteratorInternal localFeedIteratorInternal = (FeedIteratorInternal)localFeedIterator; + FeedIteratorInlineCore localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, scalarOperationKind: out ScalarOperationKind scalarOperationKind); + Debug.Assert( + scalarOperationKind == ScalarOperationKind.None, + "CosmosLinqQuery Assert!", + $"Unexpected client operation. Expected 'None', Received '{scalarOperationKind}'"); ITrace rootTrace; using (rootTrace = Trace.GetRootTrace("Aggregate LINQ Operation")) { while (localFeedIterator.HasMoreResults) { - FeedResponse response = await localFeedIteratorInternal.ReadNextAsync(rootTrace, cancellationToken); + FeedResponse response = await localFeedIterator.ReadNextAsync(rootTrace, cancellationToken); headers.RequestCharge += response.RequestCharge; result.AddRange(response); } @@ -202,23 +227,57 @@ internal async Task> AggregateResultAsync(CancellationToken cancella null); } - private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected) + internal T ExecuteScalar() + { + FeedIteratorInlineCore localFeedIterator = this.CreateFeedIterator(isContinuationExpected: false, out ScalarOperationKind scalarOperationKind); + Headers headers = new Headers(); + + List result = new List(); + ITrace rootTrace; + using (rootTrace = Trace.GetRootTrace("Scalar LINQ Operation")) + { + while (localFeedIterator.HasMoreResults) + { + FeedResponse response = localFeedIterator.ReadNextAsync(rootTrace, cancellationToken: default).GetAwaiter().GetResult(); + headers.RequestCharge += response.RequestCharge; + result.AddRange(response); + } + } + + switch (scalarOperationKind) + { + case ScalarOperationKind.FirstOrDefault: + return result.FirstOrDefault(); + + // ExecuteScalar gets called when (sync) aggregates such as Max, Min, Sum are invoked on the IQueryable. + // Since query fully supprots these operations, there is no client operation involved. + // In these cases we return FirstOrDefault which handles empty/undefined/null result set from the backend. + case ScalarOperationKind.None: + return result.SingleOrDefault(); + + default: + throw new InvalidOperationException($"Unsupported scalar operation {scalarOperationKind}"); + } + } + + private FeedIteratorInternal CreateStreamIterator(bool isContinuationExcpected, out ScalarOperationKind scalarOperationKind) { - SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions); + LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions); + scalarOperationKind = linqQueryOperation.ScalarOperationKind; return this.container.GetItemQueryStreamIteratorInternal( - sqlQuerySpec: querySpec, + sqlQuerySpec: linqQueryOperation.SqlQuerySpec, isContinuationExcpected: isContinuationExcpected, continuationToken: this.continuationToken, feedRange: null, requestOptions: this.cosmosQueryRequestOptions); } - private FeedIterator CreateFeedIterator(bool isContinuationExpected) + private FeedIteratorInlineCore CreateFeedIterator(bool isContinuationExpected, out ScalarOperationKind scalarOperationKind) { - SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression, this.linqSerializationOptions); - - FeedIteratorInternal streamIterator = this.CreateStreamIterator(isContinuationExpected); + FeedIteratorInternal streamIterator = this.CreateStreamIterator( + isContinuationExpected, + out scalarOperationKind); return new FeedIteratorInlineCore(new FeedIteratorCore( streamIterator, this.responseFactory.CreateQueryFeedUserTypeResponse), diff --git a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQueryProvider.cs b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQueryProvider.cs index 2c695c6477..d681caaef5 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQueryProvider.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/CosmosLinqQueryProvider.cs @@ -5,6 +5,7 @@ namespace Microsoft.Azure.Cosmos.Linq { using System; + using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; using System.Threading; @@ -60,6 +61,7 @@ public IQueryable CreateQuery(Expression expression) public IQueryable CreateQuery(Expression expression) { + // ISSUE-TODO-adityasa-2024/1/26 - Investigate if reflection usage can be removed. Type expressionType = TypeSystem.GetElementType(expression.Type); Type documentQueryType = typeof(CosmosLinqQuery).GetGenericTypeDefinition().MakeGenericType(expressionType); return (IQueryable)Activator.CreateInstance( @@ -76,6 +78,7 @@ public IQueryable CreateQuery(Expression expression) public TResult Execute(Expression expression) { + // ISSUE-TODO-adityasa-2024/1/26 - We should be able to delegate the implementation to ExecuteAggregateAsync method below by providing an Async implementation of ExecuteScalar. Type cosmosQueryType = typeof(CosmosLinqQuery).GetGenericTypeDefinition().MakeGenericType(typeof(TResult)); CosmosLinqQuery cosmosLINQQuery = (CosmosLinqQuery)Activator.CreateInstance( cosmosQueryType, @@ -88,7 +91,7 @@ public TResult Execute(Expression expression) this.allowSynchronousQueryExecution, this.linqSerializerOptions); this.onExecuteScalarQueryCallback?.Invoke(cosmosLINQQuery); - return cosmosLINQQuery.ToList().FirstOrDefault(); + return cosmosLINQQuery.ExecuteScalar(); } //Sync execution of query via direct invoke on IQueryProvider. diff --git a/Microsoft.Azure.Cosmos/src/Linq/DocumentQuery.cs b/Microsoft.Azure.Cosmos/src/Linq/DocumentQuery.cs index 6d0b7f82b0..be86284dab 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/DocumentQuery.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/DocumentQuery.cs @@ -273,10 +273,10 @@ IEnumerator IEnumerable.GetEnumerator() public override string ToString() { - SqlQuerySpec querySpec = DocumentQueryEvaluator.Evaluate(this.Expression); - if (querySpec != null) + LinqQueryOperation linqQueryOperation = DocumentQueryEvaluator.Evaluate(this.Expression); + if (linqQueryOperation.SqlQuerySpec != null) { - return JsonConvert.SerializeObject(querySpec); + return JsonConvert.SerializeObject(linqQueryOperation.SqlQuerySpec); } return new Uri(this.client.ServiceEndpoint, this.documentsFeedOrDatabaseLink).ToString(); diff --git a/Microsoft.Azure.Cosmos/src/Linq/DocumentQueryEvaluator.cs b/Microsoft.Azure.Cosmos/src/Linq/DocumentQueryEvaluator.cs index da25188bc2..dec354bd19 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/DocumentQueryEvaluator.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/DocumentQueryEvaluator.cs @@ -15,7 +15,7 @@ internal static class DocumentQueryEvaluator { private const string SQLMethod = "AsSQL"; - public static SqlQuerySpec Evaluate( + public static LinqQueryOperation Evaluate( Expression expression, CosmosLinqSerializerOptionsInternal linqSerializerOptions = null, IDictionary parameters = null) @@ -51,7 +51,7 @@ public static bool IsTransformExpression(Expression expression) /// foreach(Database db in client.CreateDatabaseQuery()) {} /// /// - private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression) + private static LinqQueryOperation HandleEmptyQuery(ConstantExpression expression) { if (expression.Value == null) { @@ -69,11 +69,12 @@ private static SqlQuerySpec HandleEmptyQuery(ConstantExpression expression) ClientResources.BadQuery_InvalidExpression, expression.ToString())); } + //No query specified. - return null; + return new LinqQueryOperation(sqlQuerySpec: null, scalarOperationKind: ScalarOperationKind.None); } - private static SqlQuerySpec HandleMethodCallExpression( + private static LinqQueryOperation HandleMethodCallExpression( MethodCallExpression expression, IDictionary parameters, CosmosLinqSerializerOptionsInternal linqSerializerOptions = null) @@ -100,7 +101,7 @@ private static SqlQuerySpec HandleMethodCallExpression( /// foreach(string record in client.CreateDocumentQuery().Navigate("Raw JQuery")) /// /// - private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression expression) + private static LinqQueryOperation HandleAsSqlTransformExpression(MethodCallExpression expression) { Expression paramExpression = expression.Arguments[1]; @@ -122,7 +123,7 @@ private static SqlQuerySpec HandleAsSqlTransformExpression(MethodCallExpression } } - private static SqlQuerySpec GetSqlQuerySpec(object value) + private static LinqQueryOperation GetSqlQuerySpec(object value) { if (value == null) { @@ -133,11 +134,11 @@ private static SqlQuerySpec GetSqlQuerySpec(object value) } else if (value.GetType() == typeof(SqlQuerySpec)) { - return (SqlQuerySpec)value; + return new LinqQueryOperation((SqlQuerySpec)value, ScalarOperationKind.None); } else if (value.GetType() == typeof(string)) { - return new SqlQuerySpec((string)value); + return new LinqQueryOperation(new SqlQuerySpec((string)value), ScalarOperationKind.None); } else { diff --git a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs index ad9d02cab4..62b9132a8e 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/ExpressionToSQL.cs @@ -59,18 +59,22 @@ public static class LinqMethods public const string Any = "Any"; public const string Average = "Average"; public const string Count = "Count"; + public const string Distinct = "Distinct"; + public const string First = "First"; + public const string FirstOrDefault = "FirstOrDefault"; public const string Max = "Max"; public const string Min = "Min"; public const string OrderBy = "OrderBy"; - public const string ThenBy = "ThenBy"; public const string OrderByDescending = "OrderByDescending"; - public const string ThenByDescending = "ThenByDescending"; public const string Select = "Select"; public const string SelectMany = "SelectMany"; - public const string Sum = "Sum"; + public const string Single = "Single"; + public const string SingleOrDefault = "SingleOrDefault"; public const string Skip = "Skip"; + public const string Sum = "Sum"; + public const string ThenBy = "ThenBy"; + public const string ThenByDescending = "ThenByDescending"; public const string Take = "Take"; - public const string Distinct = "Distinct"; public const string Where = "Where"; } @@ -84,11 +88,13 @@ public static class LinqMethods /// An Expression representing a Query on a IDocumentQuery object. /// Optional dictionary for parameter name and value /// Optional serializer options. + /// Indicates the client operation that needs to be performed on the results of SqlQuery. /// The corresponding SQL query. public static SqlQuery TranslateQuery( Expression inputExpression, IDictionary parameters, - CosmosLinqSerializerOptionsInternal linqSerializerOptions) + CosmosLinqSerializerOptionsInternal linqSerializerOptions, + out ScalarOperationKind clientOperation) { TranslationContext context = new TranslationContext(linqSerializerOptions, parameters); ExpressionToSql.Translate(inputExpression, context); // ignore result here @@ -96,6 +102,7 @@ public static SqlQuery TranslateQuery( QueryUnderConstruction query = context.CurrentQuery; query = query.FlattenAsPossible(); SqlQuery result = query.GetSqlQuery(); + clientOperation = context.ClientOperation; return result; } @@ -1149,22 +1156,67 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression, context.PushSubqueryBinding(shouldBeOnNewQuery); switch (inputExpression.Method.Name) { - case LinqMethods.Select: + case LinqMethods.Any: { - SqlSelectClause select = ExpressionToSql.VisitSelect(inputExpression.Arguments, context); + result = new Collection(string.Empty); + + if (inputExpression.Arguments.Count == 2) + { + // Any is translated to an SELECT VALUE EXISTS() where Any operation itself is treated as a Where. + SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); + } + break; + } + case LinqMethods.Average: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Avg); context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } - case LinqMethods.Where: + case LinqMethods.Count: { - SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); + SqlSelectClause select = ExpressionToSql.VisitCount(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } - case LinqMethods.SelectMany: + case LinqMethods.Distinct: { - context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); - result = ExpressionToSql.VisitSelectMany(inputExpression.Arguments, context); + SqlSelectClause select = ExpressionToSql.VisitDistinct(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.FirstOrDefault: + { + if (inputExpression.Arguments.Count == 1) + { + // TOP is not allowed when OFFSET ... LIMIT is present. + if (!context.CurrentQuery.HasOffsetSpec()) + { + SqlNumberLiteral sqlNumberLiteral = SqlNumberLiteral.Create(1); + SqlTopSpec topSpec = SqlTopSpec.Create(sqlNumberLiteral); + context.CurrentQuery = context.CurrentQuery.AddTopSpec(topSpec); + } + + context.SetClientOperation(ScalarOperationKind.FirstOrDefault); + } + else + { + throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, ClientResources.InvalidArgumentsCount, inputExpression.Method.Name, 0, inputExpression.Arguments.Count - 1)); + } + + break; + } + case LinqMethods.Max: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Max); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } + case LinqMethods.Min: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Min); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } case LinqMethods.OrderBy: @@ -1179,16 +1231,16 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression, context.CurrentQuery = context.CurrentQuery.AddOrderByClause(orderBy, context); break; } - case LinqMethods.ThenBy: + case LinqMethods.Select: { - SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); - context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); + SqlSelectClause select = ExpressionToSql.VisitSelect(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); break; } - case LinqMethods.ThenByDescending: + case LinqMethods.SelectMany: { - SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); - context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); + context.CurrentQuery = context.PackageCurrentQueryIfNeccessary(); + result = ExpressionToSql.VisitSelectMany(inputExpression.Arguments, context); break; } case LinqMethods.Skip: @@ -1197,6 +1249,12 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression, context.CurrentQuery = context.CurrentQuery.AddOffsetSpec(offsetSpec, context); break; } + case LinqMethods.Sum: + { + SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Sum); + context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + break; + } case LinqMethods.Take: { if (context.CurrentQuery.HasOffsetSpec()) @@ -1211,51 +1269,22 @@ private static Collection VisitMethodCall(MethodCallExpression inputExpression, } break; } - case LinqMethods.Distinct: - { - SqlSelectClause select = ExpressionToSql.VisitDistinct(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Max: - { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Max); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Min: - { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Min); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Average: - { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Avg); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); - break; - } - case LinqMethods.Count: + case LinqMethods.ThenBy: { - SqlSelectClause select = ExpressionToSql.VisitCount(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, false, context); + context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); break; } - case LinqMethods.Sum: + case LinqMethods.ThenByDescending: { - SqlSelectClause select = ExpressionToSql.VisitAggregateFunction(inputExpression.Arguments, context, SqlFunctionCallScalarExpression.Names.Sum); - context.CurrentQuery = context.CurrentQuery.AddSelectClause(select, context); + SqlOrderByClause thenBy = ExpressionToSql.VisitOrderBy(inputExpression.Arguments, true, context); + context.CurrentQuery = context.CurrentQuery.UpdateOrderByClause(thenBy, context); break; } - case LinqMethods.Any: + case LinqMethods.Where: { - result = new Collection(string.Empty); - if (inputExpression.Arguments.Count == 2) - { - // Any is translated to an SELECT VALUE EXISTS() where Any operation itself is treated as a Where. - SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); - context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); - } + SqlWhereClause where = ExpressionToSql.VisitWhere(inputExpression.Arguments, context); + context.CurrentQuery = context.CurrentQuery.AddWhereClause(where, context); break; } default: diff --git a/Microsoft.Azure.Cosmos/src/Linq/LinqQueryOperation.cs b/Microsoft.Azure.Cosmos/src/Linq/LinqQueryOperation.cs new file mode 100644 index 0000000000..b1f07bc561 --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Linq/LinqQueryOperation.cs @@ -0,0 +1,24 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Linq +{ + using Microsoft.Azure.Cosmos.Query.Core; + + /// + /// Represents a linq expression as a combination of sql query and client operation. + /// + internal class LinqQueryOperation + { + public LinqQueryOperation(SqlQuerySpec sqlQuerySpec, ScalarOperationKind scalarOperationKind) + { + this.SqlQuerySpec = sqlQuerySpec; + this.ScalarOperationKind = scalarOperationKind; + } + + public SqlQuerySpec SqlQuerySpec { get; } + + public ScalarOperationKind ScalarOperationKind { get; } + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs b/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs index 8491d9cf0c..9cf44e1af0 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/SQLTranslator.cs @@ -42,14 +42,13 @@ internal static string TranslateExpressionOld( return scalarExpression.ToString(); } - internal static SqlQuerySpec TranslateQuery( + internal static LinqQueryOperation TranslateQuery( Expression inputExpression, CosmosLinqSerializerOptionsInternal linqSerializerOptions, IDictionary parameters) { inputExpression = ConstantEvaluator.PartialEval(inputExpression); - SqlQuery query = ExpressionToSql.TranslateQuery(inputExpression, parameters, linqSerializerOptions); - string queryText = null; + SqlQuery query = ExpressionToSql.TranslateQuery(inputExpression, parameters, linqSerializerOptions, out ScalarOperationKind clientOperation); SqlParameterCollection sqlParameters = new SqlParameterCollection(); if (parameters != null && parameters.Count > 0) { @@ -58,10 +57,11 @@ internal static SqlQuerySpec TranslateQuery( sqlParameters.Add(new Microsoft.Azure.Cosmos.Query.Core.SqlParameter(keyValuePair.Value, keyValuePair.Key)); } } - queryText = query.ToString(); + + string queryText = query.ToString(); SqlQuerySpec sqlQuerySpec = new SqlQuerySpec(queryText, sqlParameters); - return sqlQuerySpec; + return new LinqQueryOperation(sqlQuerySpec, clientOperation); } } } diff --git a/Microsoft.Azure.Cosmos/src/Linq/ScalarOperationKind.cs b/Microsoft.Azure.Cosmos/src/Linq/ScalarOperationKind.cs new file mode 100644 index 0000000000..a145af3e6d --- /dev/null +++ b/Microsoft.Azure.Cosmos/src/Linq/ScalarOperationKind.cs @@ -0,0 +1,27 @@ +//------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +//------------------------------------------------------------ + +namespace Microsoft.Azure.Cosmos.Linq +{ + /// + /// Represents the operation that needs to be performed on the client side. + /// + /// + /// This enum represents scalar operations such as FirstOrDefault. Scalar operations are disallowed in sub-expressions/sub-queries. + /// With these restrictations, enum is sufficient, but in future for a larger surface area we may need + /// to use an object model like ClientQL to represent these operations better. + /// + internal enum ScalarOperationKind + { + /// + /// Indicates that client does not need to perform any operation on query results. + /// + None, + + /// + /// Indicates that the client needs to perform FirstOrDefault on the query results returned by the backend. + /// + FirstOrDefault + } +} diff --git a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs index 3b1d9cd20b..82150b6d4a 100644 --- a/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs +++ b/Microsoft.Azure.Cosmos/src/Linq/TranslationContext.cs @@ -6,6 +6,7 @@ namespace Microsoft.Azure.Cosmos.Linq { using System; using System.Collections.Generic; + using System.Diagnostics; using System.Linq.Expressions; using Microsoft.Azure.Cosmos.Serializer; using Microsoft.Azure.Cosmos.SqlObjects; @@ -72,6 +73,8 @@ internal sealed class TranslationContext private static readonly MemberNames DefaultMemberNames = new MemberNames(new CosmosLinqSerializerOptions()); + private ScalarOperationKind? clientOperation; + public TranslationContext(CosmosLinqSerializerOptionsInternal linqSerializerOptionsInternal, IDictionary parameters = null) { this.InScope = new HashSet(); @@ -82,6 +85,7 @@ public TranslationContext(CosmosLinqSerializerOptionsInternal linqSerializerOpti this.CurrentQuery = new QueryUnderConstruction(this.GetGenFreshParameterFunc()); this.subqueryBindingStack = new Stack(); this.Parameters = parameters; + this.clientOperation = null; if (linqSerializerOptionsInternal?.CustomCosmosLinqSerializer != null) { @@ -102,6 +106,18 @@ public TranslationContext(CosmosLinqSerializerOptionsInternal linqSerializerOpti } } + public ScalarOperationKind ClientOperation => this.clientOperation ?? ScalarOperationKind.None; + + public void SetClientOperation(ScalarOperationKind clientOperation) + { + // CosmosLinqQuery which is the only indirect sole consumer of this class can only see at most one scalar operation at the top level, since the return type of scalar operation is no longer IQueryable. + // Furthermore, any nested scalar operations (on nested properties of type IEnumerable) are not handled in the same way as the top level operations. + // As a result clientOperation can only be set at most once. + Debug.Assert(this.clientOperation == null, "TranslationContext Assert!", "ClientOperation can be set at most once!"); + + this.clientOperation = clientOperation; + } + public Expression LookupSubstitution(ParameterExpression parameter) { return this.substitutions.Lookup(parameter); diff --git a/Microsoft.Azure.Cosmos/src/Query/v2Query/DocumentQueryExecutionContextBase.cs b/Microsoft.Azure.Cosmos/src/Query/v2Query/DocumentQueryExecutionContextBase.cs index 053c951f6c..5bd47c8a21 100644 --- a/Microsoft.Azure.Cosmos/src/Query/v2Query/DocumentQueryExecutionContextBase.cs +++ b/Microsoft.Azure.Cosmos/src/Query/v2Query/DocumentQueryExecutionContextBase.cs @@ -134,7 +134,14 @@ protected SqlQuerySpec QuerySpec { if (!this.isExpressionEvaluated) { - this.querySpec = DocumentQueryEvaluator.Evaluate(this.expression); + LinqQueryOperation linqQuery = DocumentQueryEvaluator.Evaluate(this.expression); + + if (linqQuery.ScalarOperationKind != ScalarOperationKind.None) + { + throw new NotSupportedException($"This operation does not support the supplied LINQ expression since it involves client side operation : {linqQuery.ScalarOperationKind}"); + } + + this.querySpec = linqQuery.SqlQuerySpec; this.isExpressionEvaluated = true; } diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqScalarFunctionBaselineTests.TestFirstOrDefault.xml b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqScalarFunctionBaselineTests.TestFirstOrDefault.xml new file mode 100644 index 0000000000..58f80c538b --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/BaselineTest/TestBaseline/LinqScalarFunctionBaselineTests.TestFirstOrDefault.xml @@ -0,0 +1,231 @@ + + + + + + + + + + + + + + FirstOrDefault 1]]> + data.Flag).FirstOrDefault(), Object)]]> + + + + + + + + + FirstOrDefault 2]]> + data.Multiples).FirstOrDefault()]]> + + + + + + + + + FirstOrDefault 1]]> + (data.Id == "1")).FirstOrDefault()]]> + + + + + + + + + FirstOrDefault 2]]> + data.Flag).FirstOrDefault()]]> + + + + + + + + + Where -> FirstOrDefault]]> + data.Flag).Where(flag => flag).FirstOrDefault(), Object)]]> + + + + + + + + + Select -> FirstOrDefault]]> + data.Id).Select(data => data.Flag).FirstOrDefault(), Object)]]> + + + + + + + + + FirstOrDefault]]> + data.Multiples).FirstOrDefault(), Object)]]> + + + + + + + + + FirstOrDefault]]> + + + + + + + + + + Skip -> Take -> FirstOrDefault]]> + data).Skip(5).Take(5).FirstOrDefault()]]> + + + + + + + + + FirstOrDefault]]> + + + + + + + + + + + (data.Flag AndAlso Not(data.Flag))).FirstOrDefault()]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + True)]]> + + + + + + + + + + + True, new Data())]]> + + + + + + + + + + + True, value(Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests.LinqScalarFunctionBaselineTests).GetDefaultData())]]> + + + + + + + + + + + data.Multiples.FirstOrDefault()).Min(), Object)]]> + + + + + + + + + + + new List`1() {Void Add(Int32)(1), Void Add(Int32)(2), Void Add(Int32)(3)}.FirstOrDefault()).Min(), Object)]]> + + + + + + + + \ No newline at end of file diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/ILinqTestDataGenerator.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/ILinqTestDataGenerator.cs new file mode 100644 index 0000000000..5d39585531 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/ILinqTestDataGenerator.cs @@ -0,0 +1,14 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Services.Management.Tests +{ + using System.Collections.Generic; + + internal interface ILinqTestDataGenerator + { + IEnumerable GenerateData(); + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqCleanupTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqCleanupTests.cs new file mode 100644 index 0000000000..6e16380f90 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqCleanupTests.cs @@ -0,0 +1,71 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Text.RegularExpressions; + using System.Threading.Tasks; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using TestCommon = Microsoft.Azure.Cosmos.SDK.EmulatorTests.TestCommon; + + /// + /// Contains test that cleans up databases left over during debugging of LINQ tests. + /// This test does not run by default, but automates the process of deleting the databases left over during debugging session. + /// + [TestClass] + public class LinqCleanupTests + { + [Ignore] + [TestMethod] + public async Task CleanupLinqTestDatabases() + { + CosmosClient client = TestCommon.CreateCosmosClient(true); + Uri uri = client.ClientContext.Client.Endpoint; + if (uri.ToString().StartsWith(@"https://localhost:") || + uri.ToString().StartsWith(@"https://127.0.0.1:")) + { + Debug.WriteLine($"Executing against local endpoint '{uri}', continuing."); + FeedIterator feedIterator = client + .GetDatabaseQueryIterator( + queryDefinition: null, + continuationToken: null, + requestOptions: new QueryRequestOptions() { MaxItemCount = 2 }); + + Regex linqTestDatabaseRegex = new Regex("^Linq.*Baseline(Tests)?-[0-9A-Fa-f]{32}$"); + List databasesToDelete = new List(); + while (feedIterator.HasMoreResults) + { + FeedResponse databasePropertiesResponse = await feedIterator.ReadNextAsync(); + foreach (DatabaseProperties database in databasePropertiesResponse) + { + if (linqTestDatabaseRegex.IsMatch(database.Id)) + { + Debug.WriteLine($"Recognized database for deletion : '{database.Id}'"); + databasesToDelete.Add(database.Id); + } + else + { + Debug.WriteLine($"Database not recognized for deletion : '{database.Id}'"); + } + } + } + + foreach (string databaseToDelete in databasesToDelete) + { + Debug.WriteLine($"Deleting database '{databaseToDelete}'"); + Database database = client.GetDatabase(databaseToDelete); + DatabaseResponse response = await database.DeleteAsync(); + } + } + else + { + Debug.WriteLine($"Executing against non-local endpoint '{uri}', aborting."); + } + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqScalarFunctionBaselineTests.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqScalarFunctionBaselineTests.cs new file mode 100644 index 0000000000..74e52ab3b6 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqScalarFunctionBaselineTests.cs @@ -0,0 +1,332 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Services.Management.Tests.LinqProviderTests +{ + using System; + using System.Collections.Generic; + using System.Linq; + using System.Linq.Expressions; + using System.Threading.Tasks; + using System.Xml; + using TestCommon = Microsoft.Azure.Cosmos.SDK.EmulatorTests.TestCommon; + using Microsoft.Azure.Cosmos.Services.Management.Tests.BaselineTest; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Newtonsoft.Json; + using Newtonsoft.Json.Linq; + using Microsoft.Azure.Cosmos.Linq; + + /// + /// LINQ tests for Non aggregate scalar functions such as FirstOrDefault + /// + [TestClass] + public class LinqScalarFunctionBaselineTests : BaselineTests + { + private static CosmosClient client; + private static Cosmos.Database testDb; + private static Func> getQuery; + private static Func> getQueryFamily; + private static IQueryable lastExecutedScalarQuery; + + [ClassInitialize] + public async static Task Initialize(TestContext textContext) + { + client = TestCommon.CreateCosmosClient(true); + + // Set a callback to get the handle of the last executed query to do the verification + // This is neede because aggregate queries return type is a scalar so it can't be used + // to verify the translated LINQ directly as other queries type. + client.DocumentClient.OnExecuteScalarQueryCallback = q => lastExecutedScalarQuery = q; + + string dbName = $"{nameof(LinqAggregateFunctionBaselineTests)}-{Guid.NewGuid().ToString("N")}"; + testDb = await client.CreateDatabaseAsync(dbName); + + getQuery = LinqTestsCommon.GenerateSimpleCosmosData(testDb, useRandomData: false); + getQueryFamily = LinqTestsCommon.GenerateFamilyCosmosData(testDb, out _); + } + + [TestMethod] + [Owner("adityasa")] + public void TestFirstOrDefault() + { + List inputs = new List(); + + /////////////////////////////////////////////////// + // Positive cases - With at least one result + /////////////////////////////////////////////////// + + inputs.Add(new LinqScalarFunctionInput( + "FirstOrDefault", + b => getQuery(b) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Select -> FirstOrDefault 1", + b => getQuery(b) + .Select(data => data.Flag) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Select -> FirstOrDefault 2", + b => getQuery(b) + .Select(data => data.Multiples) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Where -> FirstOrDefault 1", + b => getQuery(b) + .Where(data => data.Id == "1") + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Where -> FirstOrDefault 2", + b => getQuery(b) + .Where(data => data.Flag) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Select -> Where -> FirstOrDefault", + b => getQuery(b) + .Select(data => data.Flag) + .Where(flag => flag) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "OrderBy -> Select -> FirstOrDefault", + b => getQuery(b) + .OrderBy(data => data.Id) + .Select(data => data.Flag) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "SelectMany -> FirstOrDefault", + b => getQuery(b) + .SelectMany(data => data.Multiples) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Take -> FirstOrDefault", + b => getQuery(b) + .Take(10) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Select -> Skip -> Take -> FirstOrDefault", + b => getQuery(b) + .Select(data => data) + .Skip(5) + .Take(5) + .FirstOrDefault())); + + inputs.Add(new LinqScalarFunctionInput( + "Skip -> FirstOrDefault", + b => getQuery(b) + .Skip(3) + .FirstOrDefault())); + + /////////////////////////////////////////////////// + // Positive cases - With no results + /////////////////////////////////////////////////// + + inputs.Add(new LinqScalarFunctionInput( + "FirstOrDefault (default)", + b => getQuery(b) + .Where(data => data.Flag && !data.Flag) + .FirstOrDefault())); + + ///////////////// + // Negative cases + ///////////////// + + // ISSUE-TODO-adityasa-2024/1/26 - Support FirstOrDefault overloads. + // Please note, this requires potential support for user code invocation in context of rest of the client code (except maybe some simple cases). + // We do not currently do this for any other scenarios. + + // Unsupported + inputs.Add(new LinqScalarFunctionInput( + "FirstOrDefault with explicit (inline) default", + b => getQuery(b) + .FirstOrDefault(new Data()))); + + // Unsupported + inputs.Add(new LinqScalarFunctionInput( + "FirstOrDefault with explicit default from function invocation", + b => getQuery(b) + .FirstOrDefault(this.GetDefaultData()))); + + // Unsupported + inputs.Add(new LinqScalarFunctionInput( + "FirstOrDefault with predicate", + b => getQuery(b) + .FirstOrDefault(_ => true))); + + // Unsupported + inputs.Add(new LinqScalarFunctionInput( + "FirstOrDefault with explicit (inline) default and predicate", + b => getQuery(b) + .FirstOrDefault(_ => true, new Data()))); + + // Unsupported + inputs.Add(new LinqScalarFunctionInput( + "FirstOrDefault with explicit default from function invocation and predicate", + b => getQuery(b) + .FirstOrDefault(_ => true, this.GetDefaultData()))); + + // Unsupported + inputs.Add(new LinqScalarFunctionInput( + "Nested FirstOrDefault 1", + b => getQuery(b) + .Select(data => data.Multiples.FirstOrDefault()) + .Min())); + + // Unsupported + inputs.Add(new LinqScalarFunctionInput( + "Nested FirstOrDefault 2", + b => getQuery(b) + .Select(data => new List { 1, 2, 3 }.FirstOrDefault()) + .Min())); + + this.ExecuteTestSuite(inputs); + } + + private Data GetDefaultData() + { + return new Data(); + } + + public override LinqScalarFunctionOutput ExecuteTest(LinqScalarFunctionInput input) + { + lastExecutedScalarQuery = null; + Func compiledQuery = input.Expression.Compile(); + + string errorMessage = null; + string query = string.Empty; + object queryResult = null; + try + { + try + { + queryResult = compiledQuery(true); + } + finally + { + Assert.IsNotNull(lastExecutedScalarQuery, "lastExecutedScalarQuery is not set"); + + query = JObject + .Parse(lastExecutedScalarQuery.ToString()) + .GetValue("query", StringComparison.Ordinal) + .ToString(); + } + + try + { + object dataResult = compiledQuery(false); + Assert.IsTrue(AreEqual(dataResult, queryResult)); + } + catch (ArgumentException) + { + // Min and Max operations cannot be done on Document type + // In this case, the queryResult should be null + Assert.AreEqual(null, queryResult); + } + } + catch (Exception e) + { + errorMessage = LinqTestsCommon.BuildExceptionMessageForTest(e); + } + + string serializedResults = JsonConvert.SerializeObject( + queryResult, + new JsonSerializerSettings { Formatting = Newtonsoft.Json.Formatting.Indented }); + + return new LinqScalarFunctionOutput(query, errorMessage, serializedResults); + } + + private static bool AreEqual(object obj1, object obj2) + { + bool equals = obj1 == obj2; + if (equals) + { + return true; + } + + if (obj1 is int[] intArray1 && obj2 is int[] intArray2) + { + return intArray1.SequenceEqual(intArray2); + } + + return obj1.Equals(obj2); + } + } + + public sealed class LinqScalarFunctionOutput : BaselineTestOutput + { + public string SqlQuery { get; } + + public string ErrorMessage { get; } + + public string SerializedResults { get; } + + internal LinqScalarFunctionOutput(string sqlQuery, string errorMessage, string serializedResults) + { + this.SqlQuery = sqlQuery; + this.ErrorMessage = errorMessage; + this.SerializedResults = serializedResults; + } + + public override void SerializeAsXml(XmlWriter xmlWriter) + { + xmlWriter.WriteStartElement(nameof(this.SqlQuery)); + xmlWriter.WriteCData(LinqTestOutput.FormatSql(this.SqlQuery)); + xmlWriter.WriteEndElement(); + if (this.ErrorMessage != null) + { + xmlWriter.WriteStartElement(nameof(this.ErrorMessage)); + xmlWriter.WriteCData(LinqTestOutput.FormatErrorMessage(this.ErrorMessage)); + xmlWriter.WriteEndElement(); + } + + if (this.SerializedResults != null) + { + xmlWriter.WriteStartElement(nameof(this.SerializedResults)); + xmlWriter.WriteCData(LinqTestOutput.FormatErrorMessage(this.SerializedResults)); + xmlWriter.WriteEndElement(); + } + } + } + + public sealed class LinqScalarFunctionInput : BaselineTestInput + { + internal LinqScalarFunctionInput(string description, Expression> expression) + : base(description) + { + if (expression == null) + { + throw new ArgumentNullException($"{nameof(expression)} must not be null."); + } + + this.Expression = expression; + } + + public Expression> Expression { get; } + + public override void SerializeAsXml(XmlWriter xmlWriter) + { + if (xmlWriter == null) + { + throw new ArgumentNullException($"{nameof(xmlWriter)} cannot be null."); + } + + string expressionString = LinqTestInput.FilterInputExpression(this.Expression.Body.ToString()); + + xmlWriter.WriteStartElement("Description"); + xmlWriter.WriteCData(this.Description); + xmlWriter.WriteEndElement(); + xmlWriter.WriteStartElement("Expression"); + xmlWriter.WriteCData(expressionString); + xmlWriter.WriteEndElement(); + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestData.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestData.cs index a543a00418..4044b623fc 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestData.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestData.cs @@ -5,9 +5,10 @@ //----------------------------------------------------------------------- namespace Microsoft.Azure.Cosmos.Services.Management.Tests { - using Newtonsoft.Json; using System; using System.Collections.Generic; + using System.Linq; + using Newtonsoft.Json; #region Family classes @@ -95,6 +96,48 @@ public class Data public bool Flag { get; set; } public int[] Multiples { get; set; } + + public override bool Equals(object obj) + { + Data other = obj as Data; + + if(other == null) + { + return false; + } + + bool equals = this.Id == other.Id && + this.Number == other.Number && + this.Pk == other.Pk && + this.Flag == other.Flag && + (this.Multiples?.Length == other.Multiples?.Length); + + if (equals && + this.Multiples != null) + { + equals &= this.Multiples.SequenceEqual(other.Multiples); + } + + return equals; + } + + public override int GetHashCode() + { + int hashCode = this.Id.GetHashCode() ^ + this.Number.GetHashCode() ^ + this.Pk.GetHashCode() ^ + this.Flag.GetHashCode(); + + if (this.Multiples != null) + { + foreach (int value in this.Multiples) + { + hashCode ^= value.GetHashCode(); + } + } + + return hashCode; + } } #endregion diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestDataGenerator.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestDataGenerator.cs new file mode 100644 index 0000000000..ef1ba2430c --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestDataGenerator.cs @@ -0,0 +1,34 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Services.Management.Tests +{ + using System.Collections.Generic; + + internal class LinqTestDataGenerator : ILinqTestDataGenerator + { + private readonly int count; + + public LinqTestDataGenerator(int count) + { + this.count = count; + } + + public IEnumerable GenerateData() + { + for (int index = 0; index < this.count; index++) + { + yield return new Data() + { + Id = index.ToString(), + Number = index * 1000, + Flag = index % 2 == 0, + Multiples = new int[] { index, index * 2, index * 3, index * 4 }, + Pk = "Test" + }; + } + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestRandomDataGenerator.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestRandomDataGenerator.cs new file mode 100644 index 0000000000..7e45310746 --- /dev/null +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestRandomDataGenerator.cs @@ -0,0 +1,41 @@ +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Services.Management.Tests +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + + internal class LinqTestRandomDataGenerator : ILinqTestDataGenerator + { + private readonly int count; + private readonly Random random; + + public LinqTestRandomDataGenerator(int count) + { + this.count = count; + int seed = DateTime.Now.Millisecond; + this.random = new Random(seed); + + Debug.WriteLine("Random seed: {0}", seed); + } + + public IEnumerable GenerateData() + { + for (int index = 0; index < this.count; index++) + { + yield return new Data() + { + Id = Guid.NewGuid().ToString(), + Number = this.random.Next(-10000, 10000), + Flag = index % 2 == 0, + Multiples = new int[] { index, index * 2, index * 3, index * 4 }, + Pk = "Test" + }; + } + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs index 815de90537..7d703061a4 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTestsCommon.cs @@ -1,586 +1,574 @@ -//----------------------------------------------------------------------- -// -// Copyright (c) Microsoft Corporation. All rights reserved. -// -//----------------------------------------------------------------------- -namespace Microsoft.Azure.Cosmos.Services.Management.Tests -{ - using System; - using System.Collections; - using System.Collections.Generic; - using System.Collections.ObjectModel; - using System.Diagnostics; - using System.IO; - using System.Linq; - using System.Linq.Expressions; - using System.Reflection; - using System.Runtime.CompilerServices; - using System.Text; - using System.Text.Json.Serialization; - using System.Text.Json; - using System.Text.RegularExpressions; - using System.Xml; - using global::Azure.Core.Serialization; - using Microsoft.Azure.Cosmos.Services.Management.Tests.BaselineTest; - using Microsoft.Azure.Documents; - using Microsoft.VisualStudio.TestTools.UnitTesting; - using Newtonsoft.Json; +//----------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------- +namespace Microsoft.Azure.Cosmos.Services.Management.Tests +{ + using System; + using System.Collections; + using System.Collections.Generic; + using System.Collections.ObjectModel; + using System.Diagnostics; + using System.IO; + using System.Linq; + using System.Linq.Expressions; + using System.Reflection; + using System.Runtime.CompilerServices; + using System.Text; + using System.Text.Json.Serialization; + using System.Text.Json; + using System.Text.RegularExpressions; + using System.Xml; + using global::Azure.Core.Serialization; + using Microsoft.Azure.Cosmos.Services.Management.Tests.BaselineTest; + using Microsoft.Azure.Documents; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Newtonsoft.Json; using Newtonsoft.Json.Linq; - internal class LinqTestsCommon - { - /// - /// Compare two list of anonymous objects - /// - /// - /// - /// - private static bool CompareListOfAnonymousType(List queryResults, List dataResults) - { - return queryResults.SequenceEqual(dataResults); - } - - /// - /// Compare 2 IEnumerable which may contain IEnumerable themselves. - /// - /// The query results from Cosmos DB - /// The query results from actual data - /// True if the two IEbumerable equal - private static bool NestedListsSequenceEqual(IEnumerable queryResults, IEnumerable dataResults) - { - IEnumerator queryIter, dataIter; - for (queryIter = queryResults.GetEnumerator(), dataIter = dataResults.GetEnumerator(); - queryIter.MoveNext() && dataIter.MoveNext();) - { - IEnumerable queryEnumerable = queryIter.Current as IEnumerable; - IEnumerable dataEnumerable = dataIter.Current as IEnumerable; - if (queryEnumerable == null && dataEnumerable == null) - { - if (!queryIter.Current.Equals(dataIter.Current)) return false; - - } - - else if (queryEnumerable == null || dataEnumerable == null) - { - return false; - } - - else - { - if (!LinqTestsCommon.NestedListsSequenceEqual(queryEnumerable, dataEnumerable)) return false; - } - } - - return !(queryIter.MoveNext() || dataIter.MoveNext()); - } - - /// - /// Compare the list of results from CosmosDB query and the list of results from LinQ query on the original data - /// Similar to Collections.SequenceEqual with the assumption that these lists are non-empty - /// - /// A list representing the query restuls from CosmosDB - /// A list representing the linQ query results from the original data - /// true if the two - private static bool CompareListOfArrays(List queryResults, List dataResults) - { - if (NestedListsSequenceEqual(queryResults, dataResults)) return true; - - bool resultMatched = true; - - // dataResults contains type ConcatIterator whereas queryResults may contain IEnumerable - // therefore it's simpler to just cast them into List> manually for simplify the verification - List> l1 = new List>(); - foreach (IEnumerable list in dataResults) - { - List l = new List(); - IEnumerator iterator = list.GetEnumerator(); - while (iterator.MoveNext()) - { - l.Add(iterator.Current); - } - - l1.Add(l); - } - - List> l2 = new List>(); - foreach (IEnumerable list in queryResults) - { - List l = new List(); - IEnumerator iterator = list.GetEnumerator(); - while (iterator.MoveNext()) - { - l.Add(iterator.Current); - } - - l2.Add(l); - } - - foreach (IEnumerable list in l1) - { - if (!l2.Any(a => a.SequenceEqual(list))) - { - resultMatched = false; - return false; - } - } - - foreach (IEnumerable list in l2) - { - if (!l1.Any(a => a.SequenceEqual(list))) - { - resultMatched = false; - break; - } - } - - return resultMatched; - } - - private static bool IsNumber(dynamic value) - { - return value is sbyte - || value is byte - || value is short - || value is ushort - || value is int - || value is uint - || value is long - || value is ulong - || value is float - || value is double - || value is decimal; - } - - public static Boolean IsAnonymousType(Type type) - { - Boolean hasCompilerGeneratedAttribute = type.GetCustomAttributes(typeof(CompilerGeneratedAttribute), false).Count() > 0; - Boolean nameContainsAnonymousType = type.FullName.Contains("AnonymousType"); - Boolean isAnonymousType = hasCompilerGeneratedAttribute && nameContainsAnonymousType; - - return isAnonymousType; - } - - /// - /// Gets the results of CosmosDB query and the results of LINQ query on the original data - /// - /// - /// - public static (List queryResults, List dataResults) GetResults(IQueryable queryResults, IQueryable dataResults) - { - // execution validation - IEnumerator queryEnumerator = queryResults.GetEnumerator(); - List queryResultsList = new List(); - while (queryEnumerator.MoveNext()) - { - queryResultsList.Add(queryEnumerator.Current); - } - - List dataResultsList = dataResults?.Cast()?.ToList(); - - return (queryResultsList, dataResultsList); - } - - /// - /// Validates the results of CosmosDB query and the results of LINQ query on the original data - /// Using Assert, will fail the unit test if the two results list are not SequenceEqual - /// - /// - /// - private static void ValidateResults(List queryResultsList, List dataResultsList) - { - bool resultMatched = true; - string actualStr = null; - string expectedStr = null; - if (dataResultsList.Count == 0 || queryResultsList.Count == 0) - { - resultMatched &= dataResultsList.Count == queryResultsList.Count; - } - else - { - dynamic firstElem = dataResultsList.FirstOrDefault(); - if (firstElem is IEnumerable) - { - resultMatched &= CompareListOfArrays(queryResultsList, dataResultsList); - } - else if (LinqTestsCommon.IsAnonymousType(firstElem.GetType())) - { - resultMatched &= CompareListOfAnonymousType(queryResultsList, dataResultsList); - } - else if (LinqTestsCommon.IsNumber(firstElem)) - { - const double Epsilon = 1E-6; - Type dataType = firstElem.GetType(); - List dataSortedList = dataResultsList.OrderBy(x => x).ToList(); - List querySortedList = queryResultsList.OrderBy(x => x).ToList(); - if (dataSortedList.Count != querySortedList.Count) - { - resultMatched = false; - } - else - { - for (int i = 0; i < dataSortedList.Count; ++i) - { - if (Math.Abs(dataSortedList[i] - (dynamic)querySortedList[i]) > (dynamic)Convert.ChangeType(Epsilon, dataType)) - { - resultMatched = false; - break; - } - } - } - - if (!resultMatched) - { - actualStr = JsonConvert.SerializeObject(querySortedList); - expectedStr = JsonConvert.SerializeObject(dataSortedList); - } - } - else - { - List dataNotQuery = dataResultsList.Except(queryResultsList).ToList(); - List queryNotData = queryResultsList.Except(dataResultsList).ToList(); - resultMatched &= !dataNotQuery.Any() && !queryNotData.Any(); - } - } - - string assertMsg = string.Empty; - if (!resultMatched) - { - actualStr ??= JsonConvert.SerializeObject(queryResultsList); - expectedStr ??= JsonConvert.SerializeObject(dataResultsList); - - resultMatched |= actualStr.Equals(expectedStr); - if (!resultMatched) - { - assertMsg = $"Expected: {expectedStr}, Actual: {actualStr}, RandomSeed: {LinqTestInput.RandomSeed}"; - } - } - - Assert.IsTrue(resultMatched, assertMsg); - } - - /// - /// Generate a random string containing alphabetical characters - /// - /// - /// - /// a random string - public static string RandomString(Random random, int length) - { - const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefghijklmnopqrstuvwxyz "; - return new string(Enumerable.Repeat(chars, length).Select(s => s[random.Next(s.Length)]).ToArray()); - } - - /// - /// Generate a random DateTime object from a DateTime, - /// with the variance of the time span between the provided DateTime to the current time - /// - /// - /// - /// - public static DateTime RandomDateTime(Random random, DateTime midDateTime) - { - TimeSpan timeSpan = DateTime.Now - midDateTime; - TimeSpan newSpan = new TimeSpan(0, random.Next(0, (int)timeSpan.TotalMinutes * 2) - (int)timeSpan.TotalMinutes, 0); - DateTime newDate = midDateTime + newSpan; - return newDate; - } - - /// - /// Generate test data for most LINQ tests - /// - /// the object type - /// the lamda to create an instance of test data - /// number of test data to be created - /// the target container - /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable - public static Func> GenerateTestCosmosData(Func func, int count, Container container) - { - List data = new List(); - int seed = DateTime.Now.Millisecond; - Random random = new Random(seed); - Debug.WriteLine("Random seed: {0}", seed); - LinqTestInput.RandomSeed = seed; - for (int i = 0; i < count; ++i) - { - data.Add(func(random)); - } - - foreach (T obj in data) - { - ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; - } - - FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; - QueryRequestOptions requestOptions = new QueryRequestOptions(); - - IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); - - // To cover both query against backend and queries on the original data using LINQ nicely, - // the LINQ expression should be written once and they should be compiled and executed against the two sources. - // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query - // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values - // to this getQuery method. - IQueryable getQuery(bool useQuery) => useQuery ? query : data.AsQueryable(); - - return getQuery; - } - - /// - /// Generate a non-random payload for serializer LINQ tests. - /// - /// the object type - /// the lamda to create an instance of test data - /// number of test data to be created - /// the target container - /// if theCosmosLinqSerializerOption of camelCaseSerialization should be applied - /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable. - public static Func> GenerateSerializationTestCosmosData(Func func, int count, Container container, CosmosLinqSerializerOptions linqSerializerOptions) - { - List data = new List(); - for (int i = 0; i < count; i++) - { - data.Add(func(i, linqSerializerOptions.PropertyNamingPolicy == CosmosPropertyNamingPolicy.CamelCase)); - } - - foreach (T obj in data) - { - ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; - } - - FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; - QueryRequestOptions requestOptions = new QueryRequestOptions(); - - IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions, linqSerializerOptions: linqSerializerOptions); - - IQueryable getQuery(bool useQuery) => useQuery ? query : data.AsQueryable(); - - return getQuery; - } - - public static Func> GenerateFamilyCosmosData( - Cosmos.Database cosmosDatabase, out Container container) - { - // The test collection should have range index on string properties - // for the orderby tests - PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; - ContainerProperties newCol = new ContainerProperties() - { - Id = Guid.NewGuid().ToString(), - PartitionKey = partitionKeyDefinition, - IndexingPolicy = new Microsoft.Azure.Cosmos.IndexingPolicy() - { - IncludedPaths = new Collection() - { - new Cosmos.IncludedPath() - { - Path = "/*", - Indexes = new System.Collections.ObjectModel.Collection() - { - Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.Number, -1), - Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.String, -1) - } - } - }, - CompositeIndexes = new Collection>() - { - new Collection() - { - new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Descending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Descending } - }, - new Collection() - { - new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Ascending }, - new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } - } - } - } - }; - container = cosmosDatabase.CreateContainerAsync(newCol).Result; - const int Records = 100; - const int MaxNameLength = 100; - const int MaxThingStringLength = 50; - const int MaxChild = 5; - const int MaxPets = MaxChild; - const int MaxThings = MaxChild; - const int MaxGrade = 101; - const int MaxTransaction = 20; - const int MaxTransactionMinuteRange = 200; - int MaxTransactionType = Enum.GetValues(typeof(TransactionType)).Length; - Family createDataObj(Random random) - { - Family obj = new Family - { - FamilyId = random.NextDouble() < 0.05 ? "some id" : Guid.NewGuid().ToString(), - IsRegistered = random.NextDouble() < 0.5, - NullableInt = random.NextDouble() < 0.5 ? (int?)random.Next() : null, - Int = random.NextDouble() < 0.5 ? 5 : random.Next(), - Id = Guid.NewGuid().ToString(), - Pk = "Test", - Parents = new Parent[random.Next(2) + 1] - }; - for (int i = 0; i < obj.Parents.Length; ++i) - { - obj.Parents[i] = new Parent() - { - FamilyName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), - GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) - }; - } - - obj.Tags = new string[random.Next(MaxChild)]; - for (int i = 0; i < obj.Tags.Length; ++i) - { - obj.Tags[i] = (i + random.Next(30, 36)).ToString(); - } - - obj.Children = new Child[random.Next(MaxChild)]; - for (int i = 0; i < obj.Children.Length; ++i) - { - obj.Children[i] = new Child() - { - Gender = random.NextDouble() < 0.5 ? "male" : "female", - FamilyName = obj.Parents[random.Next(obj.Parents.Length)].FamilyName, - GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), - Grade = random.Next(MaxGrade) - }; - - obj.Children[i].Pets = new List(); - for (int j = 0; j < random.Next(MaxPets); ++j) - { - obj.Children[i].Pets.Add(new Pet() - { - GivenName = random.NextDouble() < 0.5 ? - LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) : - "Fluffy" - }); - } - - obj.Children[i].Things = new Dictionary(); - for (int j = 0; j < random.Next(MaxThings) + 1; ++j) - { - obj.Children[i].Things.Add( - j == 0 ? "A" : $"{j}-{random.Next()}", - LinqTestsCommon.RandomString(random, random.Next(MaxThingStringLength))); - } - } - - obj.Records = new Logs - { - LogId = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), - Transactions = new Transaction[random.Next(MaxTransaction)] - }; - for (int i = 0; i < obj.Records.Transactions.Length; ++i) - { - Transaction transaction = new Transaction() - { - Amount = random.Next(), - Date = DateTime.Now.AddMinutes(random.Next(MaxTransactionMinuteRange)), - Type = (TransactionType)random.Next(MaxTransactionType) - }; - obj.Records.Transactions[i] = transaction; - } - - return obj; - } - - Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, container); - return getQuery; - } - - public static Func> GenerateSimpleCosmosData(Cosmos.Database cosmosDatabase) - { - const int DocumentCount = 10; - PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; - Container container = cosmosDatabase.CreateContainerAsync(new ContainerProperties { Id = Guid.NewGuid().ToString(), PartitionKey = partitionKeyDefinition }).Result; - - int seed = DateTime.Now.Millisecond; - Random random = new Random(seed); - Debug.WriteLine("Random seed: {0}", seed); - List testData = new List(); - for (int index = 0; index < DocumentCount; index++) - { - Data dataEntry = new Data() - { - Id = Guid.NewGuid().ToString(), - Number = random.Next(-10000, 10000), - Flag = index % 2 == 0, - Multiples = new int[] { index, index * 2, index * 3, index * 4 }, - Pk = "Test" - }; - - Data response = container.CreateItemAsync(dataEntry, new Cosmos.PartitionKey(dataEntry.Pk)).Result; - testData.Add(dataEntry); - } - - FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; - QueryRequestOptions requestOptions = new QueryRequestOptions(); - - IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); - - // To cover both query against backend and queries on the original data using LINQ nicely, - // the LINQ expression should be written once and they should be compiled and executed against the two sources. - // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query - // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values - // to this getQuery method. - IQueryable getQuery(bool useQuery) => useQuery ? query : testData.AsQueryable(); - return getQuery; - } - - public static LinqTestOutput ExecuteTest(LinqTestInput input, bool serializeResultsInBaseline = false) - { - string querySqlStr = string.Empty; - try - { - Func compiledQuery = input.Expression.Compile(); - - IQueryable query = compiledQuery(true); - querySqlStr = JObject.Parse(query.ToString()).GetValue("query", StringComparison.Ordinal).ToString(); - - IQueryable dataQuery = input.skipVerification ? null : compiledQuery(false); - - (List queryResults, List dataResults) = GetResults(query, dataQuery); - - // we skip unordered query because the LINQ results vs actual query results are non-deterministic - if (!input.skipVerification) - { - LinqTestsCommon.ValidateResults(queryResults, dataResults); - } - - string serializedResults = serializeResultsInBaseline ? - JsonConvert.SerializeObject(queryResults.Select(item => item is LinqTestObject ? item.ToString() : item), new JsonSerializerSettings { Formatting = Newtonsoft.Json.Formatting.Indented}) : - null; - - return new LinqTestOutput(querySqlStr, serializedResults, errorMsg: null, input.inputData); - } - catch (Exception e) - { - return new LinqTestOutput(querySqlStr, serializedResults: null, errorMsg: LinqTestsCommon.BuildExceptionMessageForTest(e), inputData: input.inputData); - } - } - - public static string BuildExceptionMessageForTest(Exception ex) + internal class LinqTestsCommon + { + /// + /// Compare two list of anonymous objects + /// + /// + /// + /// + private static bool CompareListOfAnonymousType(List queryResults, List dataResults) + { + return queryResults.SequenceEqual(dataResults); + } + + /// + /// Compare 2 IEnumerable which may contain IEnumerable themselves. + /// + /// The query results from Cosmos DB + /// The query results from actual data + /// True if the two IEbumerable equal + private static bool NestedListsSequenceEqual(IEnumerable queryResults, IEnumerable dataResults) + { + IEnumerator queryIter, dataIter; + for (queryIter = queryResults.GetEnumerator(), dataIter = dataResults.GetEnumerator(); + queryIter.MoveNext() && dataIter.MoveNext();) + { + IEnumerable queryEnumerable = queryIter.Current as IEnumerable; + IEnumerable dataEnumerable = dataIter.Current as IEnumerable; + if (queryEnumerable == null && dataEnumerable == null) + { + if (!queryIter.Current.Equals(dataIter.Current)) return false; + + } + + else if (queryEnumerable == null || dataEnumerable == null) + { + return false; + } + + else + { + if (!LinqTestsCommon.NestedListsSequenceEqual(queryEnumerable, dataEnumerable)) return false; + } + } + + return !(queryIter.MoveNext() || dataIter.MoveNext()); + } + + /// + /// Compare the list of results from CosmosDB query and the list of results from LinQ query on the original data + /// Similar to Collections.SequenceEqual with the assumption that these lists are non-empty + /// + /// A list representing the query restuls from CosmosDB + /// A list representing the linQ query results from the original data + /// true if the two + private static bool CompareListOfArrays(List queryResults, List dataResults) + { + if (NestedListsSequenceEqual(queryResults, dataResults)) return true; + + bool resultMatched = true; + + // dataResults contains type ConcatIterator whereas queryResults may contain IEnumerable + // therefore it's simpler to just cast them into List> manually for simplify the verification + List> l1 = new List>(); + foreach (IEnumerable list in dataResults) + { + List l = new List(); + IEnumerator iterator = list.GetEnumerator(); + while (iterator.MoveNext()) + { + l.Add(iterator.Current); + } + + l1.Add(l); + } + + List> l2 = new List>(); + foreach (IEnumerable list in queryResults) + { + List l = new List(); + IEnumerator iterator = list.GetEnumerator(); + while (iterator.MoveNext()) + { + l.Add(iterator.Current); + } + + l2.Add(l); + } + + foreach (IEnumerable list in l1) + { + if (!l2.Any(a => a.SequenceEqual(list))) + { + resultMatched = false; + return false; + } + } + + foreach (IEnumerable list in l2) + { + if (!l1.Any(a => a.SequenceEqual(list))) + { + resultMatched = false; + break; + } + } + + return resultMatched; + } + + private static bool IsNumber(dynamic value) + { + return value is sbyte + || value is byte + || value is short + || value is ushort + || value is int + || value is uint + || value is long + || value is ulong + || value is float + || value is double + || value is decimal; + } + + public static Boolean IsAnonymousType(Type type) + { + Boolean hasCompilerGeneratedAttribute = type.GetCustomAttributes(typeof(CompilerGeneratedAttribute), false).Count() > 0; + Boolean nameContainsAnonymousType = type.FullName.Contains("AnonymousType"); + Boolean isAnonymousType = hasCompilerGeneratedAttribute && nameContainsAnonymousType; + + return isAnonymousType; + } + + /// + /// Gets the results of CosmosDB query and the results of LINQ query on the original data + /// + /// + /// + public static (List queryResults, List dataResults) GetResults(IQueryable queryResults, IQueryable dataResults) + { + // execution validation + IEnumerator queryEnumerator = queryResults.GetEnumerator(); + List queryResultsList = new List(); + while (queryEnumerator.MoveNext()) + { + queryResultsList.Add(queryEnumerator.Current); + } + + List dataResultsList = dataResults?.Cast()?.ToList(); + + return (queryResultsList, dataResultsList); + } + + /// + /// Validates the results of CosmosDB query and the results of LINQ query on the original data + /// Using Assert, will fail the unit test if the two results list are not SequenceEqual + /// + /// + /// + private static void ValidateResults(List queryResultsList, List dataResultsList) + { + bool resultMatched = true; + string actualStr = null; + string expectedStr = null; + if (dataResultsList.Count == 0 || queryResultsList.Count == 0) + { + resultMatched &= dataResultsList.Count == queryResultsList.Count; + } + else + { + dynamic firstElem = dataResultsList.FirstOrDefault(); + if (firstElem is IEnumerable) + { + resultMatched &= CompareListOfArrays(queryResultsList, dataResultsList); + } + else if (LinqTestsCommon.IsAnonymousType(firstElem.GetType())) + { + resultMatched &= CompareListOfAnonymousType(queryResultsList, dataResultsList); + } + else if (LinqTestsCommon.IsNumber(firstElem)) + { + const double Epsilon = 1E-6; + Type dataType = firstElem.GetType(); + List dataSortedList = dataResultsList.OrderBy(x => x).ToList(); + List querySortedList = queryResultsList.OrderBy(x => x).ToList(); + if (dataSortedList.Count != querySortedList.Count) + { + resultMatched = false; + } + else + { + for (int i = 0; i < dataSortedList.Count; ++i) + { + if (Math.Abs(dataSortedList[i] - (dynamic)querySortedList[i]) > (dynamic)Convert.ChangeType(Epsilon, dataType)) + { + resultMatched = false; + break; + } + } + } + + if (!resultMatched) + { + actualStr = JsonConvert.SerializeObject(querySortedList); + expectedStr = JsonConvert.SerializeObject(dataSortedList); + } + } + else + { + List dataNotQuery = dataResultsList.Except(queryResultsList).ToList(); + List queryNotData = queryResultsList.Except(dataResultsList).ToList(); + resultMatched &= !dataNotQuery.Any() && !queryNotData.Any(); + } + } + + string assertMsg = string.Empty; + if (!resultMatched) + { + actualStr ??= JsonConvert.SerializeObject(queryResultsList); + expectedStr ??= JsonConvert.SerializeObject(dataResultsList); + + resultMatched |= actualStr.Equals(expectedStr); + if (!resultMatched) + { + assertMsg = $"Expected: {expectedStr}, Actual: {actualStr}, RandomSeed: {LinqTestInput.RandomSeed}"; + } + } + + Assert.IsTrue(resultMatched, assertMsg); + } + + /// + /// Generate a random string containing alphabetical characters + /// + /// + /// + /// a random string + public static string RandomString(Random random, int length) + { + const string chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789abcdefghijklmnopqrstuvwxyz "; + return new string(Enumerable.Repeat(chars, length).Select(s => s[random.Next(s.Length)]).ToArray()); + } + + /// + /// Generate a random DateTime object from a DateTime, + /// with the variance of the time span between the provided DateTime to the current time + /// + /// + /// + /// + public static DateTime RandomDateTime(Random random, DateTime midDateTime) + { + TimeSpan timeSpan = DateTime.Now - midDateTime; + TimeSpan newSpan = new TimeSpan(0, random.Next(0, (int)timeSpan.TotalMinutes * 2) - (int)timeSpan.TotalMinutes, 0); + DateTime newDate = midDateTime + newSpan; + return newDate; + } + + /// + /// Generate test data for most LINQ tests + /// + /// the object type + /// the lamda to create an instance of test data + /// number of test data to be created + /// the target container + /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable + public static Func> GenerateTestCosmosData(Func func, int count, Container container) + { + List data = new List(); + int seed = DateTime.Now.Millisecond; + Random random = new Random(seed); + Debug.WriteLine("Random seed: {0}", seed); + LinqTestInput.RandomSeed = seed; + for (int i = 0; i < count; ++i) + { + data.Add(func(random)); + } + + foreach (T obj in data) + { + ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; + } + + FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; + QueryRequestOptions requestOptions = new QueryRequestOptions(); + + IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); + + // To cover both query against backend and queries on the original data using LINQ nicely, + // the LINQ expression should be written once and they should be compiled and executed against the two sources. + // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query + // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values + // to this getQuery method. + IQueryable getQuery(bool useQuery) => useQuery ? query : data.AsQueryable(); + + return getQuery; + } + + /// + /// Generate a non-random payload for serializer LINQ tests. + /// + /// the object type + /// the lamda to create an instance of test data + /// number of test data to be created + /// the target container + /// if theCosmosLinqSerializerOption of camelCaseSerialization should be applied + /// a lambda that takes a boolean which indicate where the query should run against CosmosDB or against original data, and return a query results as IQueryable. + public static Func> GenerateSerializationTestCosmosData(Func func, int count, Container container, CosmosLinqSerializerOptions linqSerializerOptions) + { + List data = new List(); + for (int i = 0; i < count; i++) + { + data.Add(func(i, linqSerializerOptions.PropertyNamingPolicy == CosmosPropertyNamingPolicy.CamelCase)); + } + + foreach (T obj in data) + { + ItemResponse response = container.CreateItemAsync(obj, new Cosmos.PartitionKey("Test")).Result; + } + + FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; + QueryRequestOptions requestOptions = new QueryRequestOptions(); + + IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions, linqSerializerOptions: linqSerializerOptions); + + IQueryable getQuery(bool useQuery) => useQuery ? query : data.AsQueryable(); + + return getQuery; + } + + public static Func> GenerateFamilyCosmosData( + Cosmos.Database cosmosDatabase, out Container container) + { + // The test collection should have range index on string properties + // for the orderby tests + PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; + ContainerProperties newCol = new ContainerProperties() + { + Id = Guid.NewGuid().ToString(), + PartitionKey = partitionKeyDefinition, + IndexingPolicy = new Microsoft.Azure.Cosmos.IndexingPolicy() + { + IncludedPaths = new Collection() + { + new Cosmos.IncludedPath() + { + Path = "/*", + Indexes = new System.Collections.ObjectModel.Collection() + { + Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.Number, -1), + Microsoft.Azure.Cosmos.Index.Range(Microsoft.Azure.Cosmos.DataType.String, -1) + } + } + }, + CompositeIndexes = new Collection>() + { + new Collection() + { + new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/FamilyId", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Descending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Descending } + }, + new Collection() + { + new Cosmos.CompositePath() { Path = "/IsRegistered", Order = Cosmos.CompositePathSortOrder.Ascending }, + new Cosmos.CompositePath() { Path = "/Int", Order = Cosmos.CompositePathSortOrder.Descending } + } + } + } + }; + container = cosmosDatabase.CreateContainerAsync(newCol).Result; + const int Records = 100; + const int MaxNameLength = 100; + const int MaxThingStringLength = 50; + const int MaxChild = 5; + const int MaxPets = MaxChild; + const int MaxThings = MaxChild; + const int MaxGrade = 101; + const int MaxTransaction = 20; + const int MaxTransactionMinuteRange = 200; + int MaxTransactionType = Enum.GetValues(typeof(TransactionType)).Length; + Family createDataObj(Random random) + { + Family obj = new Family + { + FamilyId = random.NextDouble() < 0.05 ? "some id" : Guid.NewGuid().ToString(), + IsRegistered = random.NextDouble() < 0.5, + NullableInt = random.NextDouble() < 0.5 ? (int?)random.Next() : null, + Int = random.NextDouble() < 0.5 ? 5 : random.Next(), + Id = Guid.NewGuid().ToString(), + Pk = "Test", + Parents = new Parent[random.Next(2) + 1] + }; + for (int i = 0; i < obj.Parents.Length; ++i) + { + obj.Parents[i] = new Parent() + { + FamilyName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), + GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) + }; + } + + obj.Tags = new string[random.Next(MaxChild)]; + for (int i = 0; i < obj.Tags.Length; ++i) + { + obj.Tags[i] = (i + random.Next(30, 36)).ToString(); + } + + obj.Children = new Child[random.Next(MaxChild)]; + for (int i = 0; i < obj.Children.Length; ++i) + { + obj.Children[i] = new Child() + { + Gender = random.NextDouble() < 0.5 ? "male" : "female", + FamilyName = obj.Parents[random.Next(obj.Parents.Length)].FamilyName, + GivenName = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), + Grade = random.Next(MaxGrade) + }; + + obj.Children[i].Pets = new List(); + for (int j = 0; j < random.Next(MaxPets); ++j) + { + obj.Children[i].Pets.Add(new Pet() + { + GivenName = random.NextDouble() < 0.5 ? + LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)) : + "Fluffy" + }); + } + + obj.Children[i].Things = new Dictionary(); + for (int j = 0; j < random.Next(MaxThings) + 1; ++j) + { + obj.Children[i].Things.Add( + j == 0 ? "A" : $"{j}-{random.Next()}", + LinqTestsCommon.RandomString(random, random.Next(MaxThingStringLength))); + } + } + + obj.Records = new Logs + { + LogId = LinqTestsCommon.RandomString(random, random.Next(MaxNameLength)), + Transactions = new Transaction[random.Next(MaxTransaction)] + }; + for (int i = 0; i < obj.Records.Transactions.Length; ++i) + { + Transaction transaction = new Transaction() + { + Amount = random.Next(), + Date = DateTime.Now.AddMinutes(random.Next(MaxTransactionMinuteRange)), + Type = (TransactionType)random.Next(MaxTransactionType) + }; + obj.Records.Transactions[i] = transaction; + } + + return obj; + } + + Func> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, container); + return getQuery; + } + + public static Func> GenerateSimpleCosmosData(Cosmos.Database cosmosDatabase, bool useRandomData = true) + { + const int DocumentCount = 10; + PartitionKeyDefinition partitionKeyDefinition = new PartitionKeyDefinition { Paths = new System.Collections.ObjectModel.Collection(new[] { "/Pk" }), Kind = PartitionKind.Hash }; + Container container = cosmosDatabase.CreateContainerAsync(new ContainerProperties { Id = Guid.NewGuid().ToString(), PartitionKey = partitionKeyDefinition }).Result; + + ILinqTestDataGenerator dataGenerator = useRandomData ? new LinqTestRandomDataGenerator(DocumentCount) : new LinqTestDataGenerator(DocumentCount); + List testData = new List(dataGenerator.GenerateData()); + foreach (Data dataEntry in testData) + { + Data response = container.CreateItemAsync(dataEntry, new Cosmos.PartitionKey(dataEntry.Pk)).Result; + } + + FeedOptions feedOptions = new FeedOptions() { EnableScanInQuery = true, EnableCrossPartitionQuery = true }; + QueryRequestOptions requestOptions = new QueryRequestOptions(); + + IOrderedQueryable query = container.GetItemLinqQueryable(allowSynchronousQueryExecution: true, requestOptions: requestOptions); + + // To cover both query against backend and queries on the original data using LINQ nicely, + // the LINQ expression should be written once and they should be compiled and executed against the two sources. + // That is done by using Func that take a boolean Func. The parameter of the Func indicate whether the Cosmos DB query + // or the data list should be used. When a test is executed, the compiled LINQ expression would pass different values + // to this getQuery method. + IQueryable getQuery(bool useQuery) => useQuery ? query : testData.AsQueryable(); + return getQuery; + } + + public static LinqTestOutput ExecuteTest(LinqTestInput input, bool serializeResultsInBaseline = false) + { + string querySqlStr = string.Empty; + try + { + Func compiledQuery = input.Expression.Compile(); + + IQueryable query = compiledQuery(true); + querySqlStr = JObject.Parse(query.ToString()).GetValue("query", StringComparison.Ordinal).ToString(); + + IQueryable dataQuery = input.skipVerification ? null : compiledQuery(false); + + (List queryResults, List dataResults) = GetResults(query, dataQuery); + + // we skip unordered query because the LINQ results vs actual query results are non-deterministic + if (!input.skipVerification) + { + LinqTestsCommon.ValidateResults(queryResults, dataResults); + } + + string serializedResults = serializeResultsInBaseline ? + JsonConvert.SerializeObject(queryResults.Select(item => item is LinqTestObject ? item.ToString() : item), new JsonSerializerSettings { Formatting = Newtonsoft.Json.Formatting.Indented}) : + null; + + return new LinqTestOutput(querySqlStr, serializedResults, errorMsg: null, input.inputData); + } + catch (Exception e) + { + return new LinqTestOutput(querySqlStr, serializedResults: null, errorMsg: LinqTestsCommon.BuildExceptionMessageForTest(e), inputData: input.inputData); + } + } + + public static string BuildExceptionMessageForTest(Exception ex) { StringBuilder message = new StringBuilder(); - do - { - if (ex is CosmosException cosmosException) + do + { + if (ex is CosmosException cosmosException) { // ODE scenario: The backend generates an error response message with significant variations when compared to the Service Interop which gets called in the Non ODE scenario. // The objective is to standardize and normalize the backend response for consistency. @@ -602,336 +590,336 @@ public static string BuildExceptionMessageForTest(Exception ex) else { message.Append($"Status Code: {cosmosException.StatusCode}"); - } - } - else if (ex is DocumentClientException documentClientException) - { - message.Append(documentClientException.RawErrorMessage); - } - else + } + } + else if (ex is DocumentClientException documentClientException) + { + message.Append(documentClientException.RawErrorMessage); + } + else { message.Append(ex.Message); - } - - ex = ex.InnerException; - if (ex != null) - { - message.Append(","); - } - } + } + + ex = ex.InnerException; + if (ex != null) + { + message.Append(","); + } + } while (ex != null); return message.ToString(); - } - } - - /// - /// A base class that determines equality based on its json representation - /// - public class LinqTestObject - { - private string json; - - protected virtual string SerializeForTestBaseline() - { - return JsonConvert.SerializeObject(this); - } - - public override string ToString() - { - // simple cached serialization - this.json ??= this.SerializeForTestBaseline(); - return this.json; - } - - public override bool Equals(object obj) - { - if (!(obj is LinqTestObject && - obj.GetType().IsAssignableFrom(this.GetType()) && - this.GetType().IsAssignableFrom(obj.GetType()))) return false; - if (obj == null) return false; - - return this.ToString().Equals(obj.ToString()); - } - - public override int GetHashCode() - { - return this.ToString().GetHashCode(); - } - } - - public class LinqTestInput : BaselineTestInput - { - internal static Regex classNameRegex = new Regex("(value\\(.+?\\+)?\\<\\>.+?__([A-Za-z]+)((\\d+_\\d+(`\\d+\\[.+?\\])?\\)(\\.value)?)|\\d+`\\d+)"); - internal static Regex invokeCompileRegex = new Regex("(Convert\\()?Invoke\\([^.]+\\.[^.,]+(\\.Compile\\(\\))?, b\\)(\\.Cast\\(\\))?(\\))?"); - - // As the tests are executed sequentially - // We can store the random seed in a static variable for diagnostics - internal static int RandomSeed = -1; - - internal int randomSeed = -1; - internal Expression> Expression { get; } - internal string expressionStr; - internal string inputData; - - // We skip the verification between Cosmos DB and actual query restuls in the following cases - // - unordered query since the results are not deterministics for LinQ results and actual query results - // - scenarios not supported in LINQ, e.g. sequence doesn't contain element. - internal bool skipVerification; - - internal LinqTestInput( - string description, - Expression> expr, - bool skipVerification = false, - string expressionStr = null, - string inputData = null) - : base(description) - { - this.Expression = expr ?? throw new ArgumentNullException($"{nameof(expr)} must not be null."); - this.skipVerification = skipVerification; - this.expressionStr = expressionStr; - this.inputData = inputData; - } - - public static string FilterInputExpression(string input) - { - StringBuilder expressionSb = new StringBuilder(input); - // simplify full qualified class name - // e.g. before: value(Microsoft.Azure.Documents.Services.Management.Tests.LinqSQLTranslationTest+<>c__DisplayClass7_0), after: DisplayClass - // before: <>f__AnonymousType14`2(, after: AnonymousType( - // value(Microsoft.Azure.Documents.Services.Management.Tests.LinqProviderTests.LinqTranslationBaselineTests +<> c__DisplayClass24_0`1[System.String]).value - Match match = classNameRegex.Match(expressionSb.ToString()); - while (match.Success) - { - expressionSb = expressionSb.Replace(match.Groups[0].Value, match.Groups[2].Value); - match = match.NextMatch(); - } - - // remove the Invoke().Compile() string from the Linq scanning tests - match = invokeCompileRegex.Match(expressionSb.ToString()); - while (match.Success) - { - expressionSb = expressionSb.Replace(match.Groups[0].Value, string.Empty); - match = match.NextMatch(); - } - - expressionSb.Insert(0, "query"); - - return expressionSb.ToString(); - } - - public override void SerializeAsXml(XmlWriter xmlWriter) - { - if (xmlWriter == null) - { - throw new ArgumentNullException($"{nameof(xmlWriter)} cannot be null."); - } - - this.expressionStr ??= LinqTestInput.FilterInputExpression(this.Expression.Body.ToString()); - - xmlWriter.WriteStartElement("Description"); - xmlWriter.WriteCData(this.Description); - xmlWriter.WriteEndElement(); - xmlWriter.WriteStartElement("Expression"); - xmlWriter.WriteCData(this.expressionStr); - xmlWriter.WriteEndElement(); - } - } - - public class LinqTestOutput : BaselineTestOutput - { - internal static Regex sdkVersion = new Regex("(,\\W*)?documentdb-dotnet-sdk[^]]+"); - internal static Regex activityId = new Regex("(,\\W*)?ActivityId:.+", RegexOptions.Multiline); - internal static Regex newLine = new Regex("(\r\n|\r|\n)"); - - internal string SqlQuery { get; } - internal string ErrorMessage { get; } - internal string Results { get; } - internal string InputData { get; } - - private static readonly Dictionary newlineKeywords = new Dictionary() { - { "SELECT", "\nSELECT" }, - { "FROM", "\nFROM" }, - { "WHERE", "\nWHERE" }, - { "JOIN", "\nJOIN" }, - { "ORDER BY", "\nORDER BY" }, - { "OFFSET", "\nOFFSET" }, - { " )", "\n)" } - }; - - public static string FormatErrorMessage(string msg) - { - msg = newLine.Replace(msg, string.Empty); - - // remove sdk version in the error message which can change in the future. - // e.g. - msg = sdkVersion.Replace(msg, string.Empty); - - // remove activity Id - msg = activityId.Replace(msg, string.Empty); - - return msg; - } - - internal LinqTestOutput(string sqlQuery, string serializedResults, string errorMsg, string inputData) - { - this.SqlQuery = FormatSql(sqlQuery); - this.Results = serializedResults; - this.ErrorMessage = errorMsg; - this.InputData = inputData; - } - - public static String FormatSql(string sqlQuery) - { - const string subqueryCue = "(SELECT"; - bool hasSubquery = sqlQuery.IndexOf(subqueryCue, StringComparison.OrdinalIgnoreCase) > 0; - - StringBuilder sb = new StringBuilder(sqlQuery); - foreach (KeyValuePair kv in newlineKeywords) - { - sb.Replace(kv.Key, kv.Value); - } - - if (!hasSubquery) return sb.ToString(); - - const string oneTab = " "; - const string startCue = "SELECT"; - const string endCue = ")"; - - string[] tokens = sb.ToString().Split('\n'); - bool firstSelect = true; - sb.Length = 0; - StringBuilder indentSb = new StringBuilder(); - for (int i = 0; i < tokens.Length; ++i) - { - if (tokens[i].StartsWith(startCue, StringComparison.OrdinalIgnoreCase)) - { - if (!firstSelect) indentSb.Append(oneTab); else firstSelect = false; - - } - else if (tokens[i].StartsWith(endCue, StringComparison.OrdinalIgnoreCase)) - { - indentSb.Length -= oneTab.Length; - } - - sb.Append(indentSb).Append(tokens[i]).Append("\n"); - } - - return sb.ToString(); - } - - public override void SerializeAsXml(XmlWriter xmlWriter) - { - xmlWriter.WriteStartElement(nameof(this.SqlQuery)); - xmlWriter.WriteCData(this.SqlQuery); - xmlWriter.WriteEndElement(); - if (this.InputData != null) - { - xmlWriter.WriteStartElement("InputData"); - xmlWriter.WriteCData(this.InputData); - xmlWriter.WriteEndElement(); - } - if (this.Results != null) - { - xmlWriter.WriteStartElement("Results"); - xmlWriter.WriteCData(this.Results); - xmlWriter.WriteEndElement(); - } - if (this.ErrorMessage != null) - { - xmlWriter.WriteStartElement("ErrorMessage"); - xmlWriter.WriteCData(LinqTestOutput.FormatErrorMessage(this.ErrorMessage)); - xmlWriter.WriteEndElement(); - } - } - } - - class SystemTextJsonLinqSerializer : CosmosLinqSerializer - { - private readonly JsonObjectSerializer systemTextJsonSerializer; - - public SystemTextJsonLinqSerializer(JsonSerializerOptions jsonSerializerOptions) - { - this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); - } - - public override T FromStream(Stream stream) - { - if (stream == null) - throw new ArgumentNullException(nameof(stream)); - - using (stream) - { - if (stream.CanSeek && stream.Length == 0) - { - return default; - } - - if (typeof(Stream).IsAssignableFrom(typeof(T))) - { - return (T)(object)stream; - } - - return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); - } - } - - public override Stream ToStream(T input) - { - MemoryStream streamPayload = new MemoryStream(); - this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); - streamPayload.Position = 0; - return streamPayload; - } - - public override string SerializeMemberName(MemberInfo memberInfo) - { - JsonPropertyNameAttribute jsonPropertyNameAttribute = memberInfo.GetCustomAttribute(true); - - string memberName = !string.IsNullOrEmpty(jsonPropertyNameAttribute?.Name) - ? jsonPropertyNameAttribute.Name - : memberInfo.Name; - - return memberName; - } + } } - class SystemTextJsonSerializer : CosmosSerializer - { - private readonly JsonObjectSerializer systemTextJsonSerializer; - - public SystemTextJsonSerializer(JsonSerializerOptions jsonSerializerOptions) - { - this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); - } - - public override T FromStream(Stream stream) - { - if (stream == null) - throw new ArgumentNullException(nameof(stream)); - - using (stream) - { - if (stream.CanSeek && stream.Length == 0) - { - return default; - } - - if (typeof(Stream).IsAssignableFrom(typeof(T))) - { - return (T)(object)stream; - } - - return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); - } - } - - public override Stream ToStream(T input) - { - MemoryStream streamPayload = new MemoryStream(); - this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); - streamPayload.Position = 0; - return streamPayload; - } - } -} + /// + /// A base class that determines equality based on its json representation + /// + public class LinqTestObject + { + private string json; + + protected virtual string SerializeForTestBaseline() + { + return JsonConvert.SerializeObject(this); + } + + public override string ToString() + { + // simple cached serialization + this.json ??= this.SerializeForTestBaseline(); + return this.json; + } + + public override bool Equals(object obj) + { + if (!(obj is LinqTestObject && + obj.GetType().IsAssignableFrom(this.GetType()) && + this.GetType().IsAssignableFrom(obj.GetType()))) return false; + if (obj == null) return false; + + return this.ToString().Equals(obj.ToString()); + } + + public override int GetHashCode() + { + return this.ToString().GetHashCode(); + } + } + + public class LinqTestInput : BaselineTestInput + { + internal static Regex classNameRegex = new Regex("(value\\(.+?\\+)?\\<\\>.+?__([A-Za-z]+)((\\d+_\\d+(`\\d+\\[.+?\\])?\\)(\\.value)?)|\\d+`\\d+)"); + internal static Regex invokeCompileRegex = new Regex("(Convert\\()?Invoke\\([^.]+\\.[^.,]+(\\.Compile\\(\\))?, b\\)(\\.Cast\\(\\))?(\\))?"); + + // As the tests are executed sequentially + // We can store the random seed in a static variable for diagnostics + internal static int RandomSeed = -1; + + internal int randomSeed = -1; + internal Expression> Expression { get; } + internal string expressionStr; + internal string inputData; + + // We skip the verification between Cosmos DB and actual query restuls in the following cases + // - unordered query since the results are not deterministics for LinQ results and actual query results + // - scenarios not supported in LINQ, e.g. sequence doesn't contain element. + internal bool skipVerification; + + internal LinqTestInput( + string description, + Expression> expr, + bool skipVerification = false, + string expressionStr = null, + string inputData = null) + : base(description) + { + this.Expression = expr ?? throw new ArgumentNullException($"{nameof(expr)} must not be null."); + this.skipVerification = skipVerification; + this.expressionStr = expressionStr; + this.inputData = inputData; + } + + public static string FilterInputExpression(string input) + { + StringBuilder expressionSb = new StringBuilder(input); + // simplify full qualified class name + // e.g. before: value(Microsoft.Azure.Documents.Services.Management.Tests.LinqSQLTranslationTest+<>c__DisplayClass7_0), after: DisplayClass + // before: <>f__AnonymousType14`2(, after: AnonymousType( + // value(Microsoft.Azure.Documents.Services.Management.Tests.LinqProviderTests.LinqTranslationBaselineTests +<> c__DisplayClass24_0`1[System.String]).value + Match match = classNameRegex.Match(expressionSb.ToString()); + while (match.Success) + { + expressionSb = expressionSb.Replace(match.Groups[0].Value, match.Groups[2].Value); + match = match.NextMatch(); + } + + // remove the Invoke().Compile() string from the Linq scanning tests + match = invokeCompileRegex.Match(expressionSb.ToString()); + while (match.Success) + { + expressionSb = expressionSb.Replace(match.Groups[0].Value, string.Empty); + match = match.NextMatch(); + } + + expressionSb.Insert(0, "query"); + + return expressionSb.ToString(); + } + + public override void SerializeAsXml(XmlWriter xmlWriter) + { + if (xmlWriter == null) + { + throw new ArgumentNullException($"{nameof(xmlWriter)} cannot be null."); + } + + this.expressionStr ??= LinqTestInput.FilterInputExpression(this.Expression.Body.ToString()); + + xmlWriter.WriteStartElement("Description"); + xmlWriter.WriteCData(this.Description); + xmlWriter.WriteEndElement(); + xmlWriter.WriteStartElement("Expression"); + xmlWriter.WriteCData(this.expressionStr); + xmlWriter.WriteEndElement(); + } + } + + public class LinqTestOutput : BaselineTestOutput + { + internal static Regex sdkVersion = new Regex("(,\\W*)?documentdb-dotnet-sdk[^]]+"); + internal static Regex activityId = new Regex("(,\\W*)?ActivityId:.+", RegexOptions.Multiline); + internal static Regex newLine = new Regex("(\r\n|\r|\n)"); + + internal string SqlQuery { get; } + internal string ErrorMessage { get; } + internal string Results { get; } + internal string InputData { get; } + + private static readonly Dictionary newlineKeywords = new Dictionary() { + { "SELECT", "\nSELECT" }, + { "FROM", "\nFROM" }, + { "WHERE", "\nWHERE" }, + { "JOIN", "\nJOIN" }, + { "ORDER BY", "\nORDER BY" }, + { "OFFSET", "\nOFFSET" }, + { " )", "\n)" } + }; + + public static string FormatErrorMessage(string msg) + { + msg = newLine.Replace(msg, string.Empty); + + // remove sdk version in the error message which can change in the future. + // e.g. + msg = sdkVersion.Replace(msg, string.Empty); + + // remove activity Id + msg = activityId.Replace(msg, string.Empty); + + return msg; + } + + internal LinqTestOutput(string sqlQuery, string serializedResults, string errorMsg, string inputData) + { + this.SqlQuery = FormatSql(sqlQuery); + this.Results = serializedResults; + this.ErrorMessage = errorMsg; + this.InputData = inputData; + } + + public static String FormatSql(string sqlQuery) + { + const string subqueryCue = "(SELECT"; + bool hasSubquery = sqlQuery.IndexOf(subqueryCue, StringComparison.OrdinalIgnoreCase) > 0; + + StringBuilder sb = new StringBuilder(sqlQuery); + foreach (KeyValuePair kv in newlineKeywords) + { + sb.Replace(kv.Key, kv.Value); + } + + if (!hasSubquery) return sb.ToString(); + + const string oneTab = " "; + const string startCue = "SELECT"; + const string endCue = ")"; + + string[] tokens = sb.ToString().Split('\n'); + bool firstSelect = true; + sb.Length = 0; + StringBuilder indentSb = new StringBuilder(); + for (int i = 0; i < tokens.Length; ++i) + { + if (tokens[i].StartsWith(startCue, StringComparison.OrdinalIgnoreCase)) + { + if (!firstSelect) indentSb.Append(oneTab); else firstSelect = false; + + } + else if (tokens[i].StartsWith(endCue, StringComparison.OrdinalIgnoreCase)) + { + indentSb.Length -= oneTab.Length; + } + + sb.Append(indentSb).Append(tokens[i]).Append("\n"); + } + + return sb.ToString(); + } + + public override void SerializeAsXml(XmlWriter xmlWriter) + { + xmlWriter.WriteStartElement(nameof(this.SqlQuery)); + xmlWriter.WriteCData(this.SqlQuery); + xmlWriter.WriteEndElement(); + if (this.InputData != null) + { + xmlWriter.WriteStartElement("InputData"); + xmlWriter.WriteCData(this.InputData); + xmlWriter.WriteEndElement(); + } + if (this.Results != null) + { + xmlWriter.WriteStartElement("Results"); + xmlWriter.WriteCData(this.Results); + xmlWriter.WriteEndElement(); + } + if (this.ErrorMessage != null) + { + xmlWriter.WriteStartElement("ErrorMessage"); + xmlWriter.WriteCData(LinqTestOutput.FormatErrorMessage(this.ErrorMessage)); + xmlWriter.WriteEndElement(); + } + } + } + + class SystemTextJsonLinqSerializer : CosmosLinqSerializer + { + private readonly JsonObjectSerializer systemTextJsonSerializer; + + public SystemTextJsonLinqSerializer(JsonSerializerOptions jsonSerializerOptions) + { + this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); + } + + public override T FromStream(Stream stream) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + + using (stream) + { + if (stream.CanSeek && stream.Length == 0) + { + return default; + } + + if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + return (T)(object)stream; + } + + return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); + } + } + + public override Stream ToStream(T input) + { + MemoryStream streamPayload = new MemoryStream(); + this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); + streamPayload.Position = 0; + return streamPayload; + } + + public override string SerializeMemberName(MemberInfo memberInfo) + { + JsonPropertyNameAttribute jsonPropertyNameAttribute = memberInfo.GetCustomAttribute(true); + + string memberName = !string.IsNullOrEmpty(jsonPropertyNameAttribute?.Name) + ? jsonPropertyNameAttribute.Name + : memberInfo.Name; + + return memberName; + } + } + + class SystemTextJsonSerializer : CosmosSerializer + { + private readonly JsonObjectSerializer systemTextJsonSerializer; + + public SystemTextJsonSerializer(JsonSerializerOptions jsonSerializerOptions) + { + this.systemTextJsonSerializer = new JsonObjectSerializer(jsonSerializerOptions); + } + + public override T FromStream(Stream stream) + { + if (stream == null) + throw new ArgumentNullException(nameof(stream)); + + using (stream) + { + if (stream.CanSeek && stream.Length == 0) + { + return default; + } + + if (typeof(Stream).IsAssignableFrom(typeof(T))) + { + return (T)(object)stream; + } + + return (T)this.systemTextJsonSerializer.Deserialize(stream, typeof(T), default); + } + } + + public override Stream ToStream(T input) + { + MemoryStream streamPayload = new MemoryStream(); + this.systemTextJsonSerializer.Serialize(streamPayload, input, input.GetType(), default); + streamPayload.Position = 0; + return streamPayload; + } + } +} diff --git a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj index 296072a411..61e7745427 100644 --- a/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj +++ b/Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj @@ -37,6 +37,7 @@ + @@ -164,6 +165,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest