From 5c4b3415ea6df6ff6969d6cff2115955004f904b Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Sun, 15 Aug 2021 15:01:07 +0200 Subject: [PATCH] Cosmos FromSql Closes #17311 --- .../Extensions/CosmosQueryableExtensions.cs | 55 +- .../CosmosDiscriminatorConvention.cs | 8 +- .../Properties/CosmosStrings.Designer.cs | 8 + .../Properties/CosmosStrings.resx | 3 + .../CosmosQueryTranslationPostprocessor.cs | 3 +- ...yableMethodTranslatingExpressionVisitor.cs | 24 +- ...osShapedQueryCompilingExpressionVisitor.cs | 2 - .../Query/Internal/FromSqlExpression.cs | 108 +++ .../Internal/FromSqlQueryRootExpression.cs | 141 ++++ .../Query/Internal/ISqlExpressionFactory.cs | 8 + .../Query/Internal/QuerySqlGenerator.cs | 121 +++- .../Internal/QuerySqlGeneratorFactory.cs | 20 +- .../Query/Internal/SelectExpression.cs | 31 +- .../Query/Internal/SqlExpressionFactory.cs | 14 + .../Query/Internal/SqlExpressionVisitor.cs | 11 + .../Properties/RelationalStrings.Designer.cs | 8 + .../Properties/RelationalStrings.resx | 3 + ...mSqlParameterExpandingExpressionVisitor.cs | 144 ++-- .../Internal/FromSqlQueryRootExpression.cs | 2 +- .../Query/QuerySqlGenerator.cs | 25 +- .../Query/SqlExpressions/SelectExpression.cs | 2 +- .../EFCore.Cosmos.FunctionalTests.csproj | 4 + .../Query/FromSqlQueryCosmosTest.cs | 615 ++++++++++++++++++ 23 files changed, 1238 insertions(+), 122 deletions(-) create mode 100644 src/EFCore.Cosmos/Query/Internal/FromSqlExpression.cs create mode 100644 src/EFCore.Cosmos/Query/Internal/FromSqlQueryRootExpression.cs create mode 100644 test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs diff --git a/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs b/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs index b83dc622604..eb55e255aa8 100644 --- a/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs +++ b/src/EFCore.Cosmos/Extensions/CosmosQueryableExtensions.cs @@ -5,6 +5,8 @@ using System.Linq; using System.Linq.Expressions; using System.Reflection; +using JetBrains.Annotations; +using Microsoft.EntityFrameworkCore.Cosmos.Query.Internal; using Microsoft.EntityFrameworkCore.Query; using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.Utilities; @@ -13,7 +15,7 @@ namespace Microsoft.EntityFrameworkCore { /// - /// Cosmos DB specific extension methods for LINQ queries. + /// Cosmos-specific extension methods for LINQ queries. /// public static class CosmosQueryableExtensions { @@ -46,5 +48,56 @@ source.Provider is EntityQueryProvider Expression.Constant(partitionKey))) : source; } + + /// + /// + /// Creates a LINQ query based on a raw SQL query. + /// + /// + /// You can compose on top of the raw SQL query using LINQ operators: + /// + /// context.Blogs.FromSqlRaw("SELECT * FROM root c).OrderBy(b => b.Name) + /// + /// As with any API that accepts SQL it is important to parameterize any user input to protect against a SQL injection + /// attack. You can include parameter place holders in the SQL query string and then supply parameter values as additional + /// arguments. Any parameter values you supply will automatically be converted to a Cosmos parameter: + /// + /// context.Blogs.FromSqlRaw(""SELECT * FROM root c WHERE c["Name"] = {0})", userSuppliedSearchTerm) + /// + /// The type of the elements of . + /// + /// An to use as the base of the raw SQL query (typically a ). + /// + /// The raw SQL query. + /// The values to be assigned to parameters. + /// An representing the raw SQL query. + [StringFormatMethod("sql")] + public static IQueryable FromSqlRaw( + this IQueryable source, + [NotParameterized] string sql, + params object[] parameters) + where TEntity : class + { + Check.NotNull(source, nameof(source)); + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(parameters, nameof(parameters)); + + var queryRootExpression = (QueryRootExpression)source.Expression; + + var entityType = queryRootExpression.EntityType; + + Check.DebugAssert( + (entityType.BaseType is null && !entityType.GetDirectlyDerivedTypes().Any()) + || entityType.FindDiscriminatorProperty() is not null, + "Found FromSql on a TPT entity type, but TPT isn't supported on Cosmos"); + + var fromSqlQueryRootExpression = new FromSqlQueryRootExpression( + queryRootExpression.QueryProvider!, + entityType, + sql, + Expression.Constant(parameters)); + + return source.Provider.CreateQuery(fromSqlQueryRootExpression); + } } } diff --git a/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs b/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs index 60d76dbfb03..623a51283c6 100644 --- a/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs +++ b/src/EFCore.Cosmos/Metadata/Conventions/CosmosDiscriminatorConvention.cs @@ -102,14 +102,14 @@ private void ProcessEntityType(IConventionEntityTypeBuilder entityTypeBuilder) return; } - if (!entityType.IsDocumentRoot()) + if (entityType.IsDocumentRoot()) { - entityTypeBuilder.HasNoDiscriminator(); + entityTypeBuilder.HasDiscriminator(typeof(string)) + ?.HasValue(entityType, entityType.ShortName()); } else { - entityTypeBuilder.HasDiscriminator(typeof(string)) - ?.HasValue(entityType, entityType.ShortName()); + entityTypeBuilder.HasNoDiscriminator(); } } diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs index d4202d762da..bf066fe8ecf 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.Designer.cs @@ -86,6 +86,14 @@ public static string InvalidDerivedTypeInEntityProjection(object? derivedType, o GetString("InvalidDerivedTypeInEntityProjection", nameof(derivedType), nameof(entityType)), derivedType, entityType); + /// + /// A FromSqlExpression has an invalid arguments expression type '{expressionType}' or value type '{valueType}'. + /// + public static string InvalidFromSqlArguments(object? expressionType, object? valueType) + => string.Format( + GetString("InvalidFromSqlArguments", nameof(expressionType), nameof(valueType)), + expressionType, valueType); + /// /// Unable to generate a valid 'id' value to execute a 'ReadItem' query. This usually happens when the value provided for one of the properties is 'null' or an empty string. Please supply a value that's not 'null' or an empty string. /// diff --git a/src/EFCore.Cosmos/Properties/CosmosStrings.resx b/src/EFCore.Cosmos/Properties/CosmosStrings.resx index 2f21c2993e5..3213e25a55d 100644 --- a/src/EFCore.Cosmos/Properties/CosmosStrings.resx +++ b/src/EFCore.Cosmos/Properties/CosmosStrings.resx @@ -144,6 +144,9 @@ The specified entity type '{derivedType}' is not derived from '{entityType}'. + + A FromSqlExpression has an invalid arguments expression type '{expressionType}' or value type '{valueType}'. + Unable to generate a valid 'id' value to execute a 'ReadItem' query. This usually happens when the value provided for one of the properties is 'null' or an empty string. Please supply a value that's not 'null' or an empty string. diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs index 4f520ef4a66..6c61757941d 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryTranslationPostprocessor.cs @@ -44,8 +44,7 @@ public override Expression Process(Expression query) { query = base.Process(query); - if (query is ShapedQueryExpression shapedQueryExpression - && shapedQueryExpression.QueryExpression is SelectExpression selectExpression) + if (query is ShapedQueryExpression { QueryExpression: SelectExpression selectExpression }) { // Cosmos does not have nested select expression so this should be safe. selectExpression.ApplyProjection(); diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs index 6977c168f13..5966236baf8 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs @@ -128,7 +128,7 @@ public override Expression Visit(Expression expression) var readItemExpression = new ReadItemExpression(entityType, propertyParameterList); - return CreateShapedQueryExpression(readItemExpression, entityType) + return CreateShapedQueryExpression(entityType, readItemExpression) .UpdateResultCardinality(ResultCardinality.Single); } } @@ -187,6 +187,24 @@ static bool TryGetPartitionKeyProperty(IEntityType entityType, out IProperty par } } + /// + protected override Expression VisitExtension(Expression extensionExpression) + { + switch (extensionExpression) + { + case FromSqlQueryRootExpression fromSqlQueryRootExpression: + return CreateShapedQueryExpression( + fromSqlQueryRootExpression.EntityType, + _sqlExpressionFactory.Select( + fromSqlQueryRootExpression.EntityType, + fromSqlQueryRootExpression.Sql, + fromSqlQueryRootExpression.Argument)); + + default: + return base.VisitExtension(extensionExpression); + } + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -246,10 +264,10 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(IEntityType var selectExpression = _sqlExpressionFactory.Select(entityType); - return CreateShapedQueryExpression(selectExpression, entityType); + return CreateShapedQueryExpression(entityType, selectExpression); } - private ShapedQueryExpression CreateShapedQueryExpression(Expression queryExpression, IEntityType entityType) + private ShapedQueryExpression CreateShapedQueryExpression(IEntityType entityType, Expression queryExpression) => new( queryExpression, new EntityShaperExpression( diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs index 7474b41f5b4..5045fc389f5 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosShapedQueryCompilingExpressionVisitor.cs @@ -65,7 +65,6 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery switch (shapedQueryExpression.QueryExpression) { case SelectExpression selectExpression: - shaperBody = new CosmosProjectionBindingRemovingExpressionVisitor( selectExpression, jObjectParameter, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) @@ -92,7 +91,6 @@ protected override Expression VisitShapedQuery(ShapedQueryExpression shapedQuery Expression.Constant(_threadSafetyChecksEnabled)); case ReadItemExpression readItemExpression: - shaperBody = new CosmosProjectionBindingRemovingReadItemExpressionVisitor( readItemExpression, jObjectParameter, QueryCompilationContext.QueryTrackingBehavior == QueryTrackingBehavior.TrackAll) diff --git a/src/EFCore.Cosmos/Query/Internal/FromSqlExpression.cs b/src/EFCore.Cosmos/Query/Internal/FromSqlExpression.cs new file mode 100644 index 00000000000..09edce8fb2f --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/FromSqlExpression.cs @@ -0,0 +1,108 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Utilities; + +#nullable disable + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class FromSqlExpression : RootReferenceExpression, IPrintableExpression + { + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public FromSqlExpression(IEntityType entityType, string alias, string sql, Expression arguments) : base(entityType, alias) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(arguments, nameof(arguments)); + + Sql = sql; + Arguments = arguments; + } + + /// + public override string Alias => base.Alias!; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string Sql { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Arguments { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual FromSqlExpression Update(Expression arguments) + { + Check.NotNull(arguments, nameof(arguments)); + + return arguments != Arguments + ? new FromSqlExpression(EntityType, Alias, Sql, arguments) + : this; + } + + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + Check.NotNull(visitor, nameof(visitor)); + + return this; + } + + /// + public override Type Type + => typeof(object); + + /// + void IPrintableExpression.Print(ExpressionPrinter expressionPrinter) + { + Check.NotNull(expressionPrinter, nameof(expressionPrinter)); + + expressionPrinter.Append(Sql); + } + + /// + public override bool Equals(object obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is FromSqlExpression fromSqlExpression + && Equals(fromSqlExpression)); + + private bool Equals(FromSqlExpression fromSqlExpression) + => base.Equals(fromSqlExpression) + && Sql == fromSqlExpression.Sql + && ExpressionEqualityComparer.Instance.Equals(Arguments, fromSqlExpression.Arguments); + + /// + public override int GetHashCode() + => HashCode.Combine(base.GetHashCode(), Sql); + } +} diff --git a/src/EFCore.Cosmos/Query/Internal/FromSqlQueryRootExpression.cs b/src/EFCore.Cosmos/Query/Internal/FromSqlQueryRootExpression.cs new file mode 100644 index 00000000000..2ac394e4c4a --- /dev/null +++ b/src/EFCore.Cosmos/Query/Internal/FromSqlQueryRootExpression.cs @@ -0,0 +1,141 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq.Expressions; +using Microsoft.EntityFrameworkCore.Metadata; +using Microsoft.EntityFrameworkCore.Query; +using Microsoft.EntityFrameworkCore.Utilities; + +namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal +{ + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public class FromSqlQueryRootExpression : QueryRootExpression + { + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public FromSqlQueryRootExpression( + IAsyncQueryProvider queryProvider, + IEntityType entityType, + string sql, + Expression argument) + : base(queryProvider, entityType) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(argument, nameof(argument)); + + Sql = sql; + Argument = argument; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public FromSqlQueryRootExpression( + IEntityType entityType, + string sql, + Expression argument) + : base(entityType) + { + Check.NotEmpty(sql, nameof(sql)); + Check.NotNull(argument, nameof(argument)); + + Sql = sql; + Argument = argument; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual string Sql { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual Expression Argument { get; } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override Expression DetachQueryProvider() + => new FromSqlQueryRootExpression(EntityType, Sql, Argument); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override Expression VisitChildren(ExpressionVisitor visitor) + { + var argument = visitor.Visit(Argument); + + return argument != Argument + ? new FromSqlQueryRootExpression(EntityType, Sql, argument) + : this; + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected override void Print(ExpressionPrinter expressionPrinter) + { + Check.NotNull(expressionPrinter, nameof(expressionPrinter)); + + base.Print(expressionPrinter); + expressionPrinter.Append($".FromSql({Sql}, "); + expressionPrinter.Visit(Argument); + expressionPrinter.AppendLine(")"); + } + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override bool Equals(object? obj) + => obj != null + && (ReferenceEquals(this, obj) + || obj is FromSqlQueryRootExpression queryRootExpression + && Equals(queryRootExpression)); + + private bool Equals(FromSqlQueryRootExpression queryRootExpression) + => base.Equals(queryRootExpression) + && Sql == queryRootExpression.Sql + && ExpressionEqualityComparer.Instance.Equals(Argument, queryRootExpression.Argument); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public override int GetHashCode() + => HashCode.Combine(base.GetHashCode(), Sql, ExpressionEqualityComparer.Instance.GetHashCode(Argument)); + } +} diff --git a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs index 9652f0e02dd..3cc1a477d39 100644 --- a/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/ISqlExpressionFactory.cs @@ -276,5 +276,13 @@ SqlConditionalExpression Condition( /// doing so can result in application failures when updating to a new Entity Framework Core release. /// SelectExpression Select(IEntityType entityType); + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + SelectExpression Select(IEntityType entityType, string sql, Expression argument); } } diff --git a/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs b/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs index 1a75b103766..b93ea6e2e0b 100644 --- a/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs +++ b/src/EFCore.Cosmos/Query/Internal/QuerySqlGenerator.cs @@ -5,9 +5,9 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; -using System.Text; using Microsoft.EntityFrameworkCore.Cosmos.Internal; using Microsoft.EntityFrameworkCore.Cosmos.Storage.Internal; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Storage; using Microsoft.EntityFrameworkCore.Utilities; using Newtonsoft.Json; @@ -25,10 +25,12 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal /// public class QuerySqlGenerator : SqlExpressionVisitor { - private readonly StringBuilder _sqlBuilder = new(); + private readonly ITypeMappingSource _typeMappingSource; + private readonly IndentedStringBuilder _sqlBuilder = new(); private IReadOnlyDictionary _parameterValues; private List _sqlParameters; private bool _useValueProjection; + private ParameterNameGenerator _parameterNameGenerator; private readonly IDictionary _operatorMap = new Dictionary { @@ -64,6 +66,15 @@ public class QuerySqlGenerator : SqlExpressionVisitor { ExpressionType.Not, "~" } }; + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public QuerySqlGenerator(ITypeMappingSource typeMappingSource) + => _typeMappingSource = typeMappingSource; + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -77,6 +88,7 @@ public virtual CosmosSqlQuery GetSqlQuery( _sqlBuilder.Clear(); _parameterValues = parameterValues; _sqlParameters = new List(); + _parameterNameGenerator = new ParameterNameGenerator(); Visit(selectExpression); @@ -108,7 +120,7 @@ protected override Expression VisitObjectArrayProjection(ObjectArrayProjectionEx { Check.NotNull(objectArrayProjectionExpression, nameof(objectArrayProjectionExpression)); - _sqlBuilder.Append(objectArrayProjectionExpression); + _sqlBuilder.Append(objectArrayProjectionExpression.ToString()); return objectArrayProjectionExpression; } @@ -123,7 +135,7 @@ protected override Expression VisitKeyAccess(KeyAccessExpression keyAccessExpres { Check.NotNull(keyAccessExpression, nameof(keyAccessExpression)); - _sqlBuilder.Append(keyAccessExpression); + _sqlBuilder.Append(keyAccessExpression.ToString()); return keyAccessExpression; } @@ -138,7 +150,7 @@ protected override Expression VisitObjectAccess(ObjectAccessExpression objectAcc { Check.NotNull(objectAccessExpression, nameof(objectAccessExpression)); - _sqlBuilder.Append(objectAccessExpression); + _sqlBuilder.Append(objectAccessExpression.ToString()); return objectAccessExpression; } @@ -180,7 +192,7 @@ protected override Expression VisitRootReference(RootReferenceExpression rootRef { Check.NotNull(rootReferenceExpression, nameof(rootReferenceExpression)); - _sqlBuilder.Append(rootReferenceExpression); + _sqlBuilder.Append(rootReferenceExpression.ToString()); return rootReferenceExpression; } @@ -225,7 +237,14 @@ protected override Expression VisitSelect(SelectExpression selectExpression) _sqlBuilder.AppendLine(); - _sqlBuilder.Append("FROM root "); + if (selectExpression.FromExpression is FromSqlExpression) + { + _sqlBuilder.Append("FROM "); + } + else + { + _sqlBuilder.Append("FROM root "); + } Visit(selectExpression.FromExpression); _sqlBuilder.AppendLine(); @@ -272,6 +291,73 @@ protected override Expression VisitSelect(SelectExpression selectExpression) return selectExpression; } + /// + protected override Expression VisitFromSql(FromSqlExpression fromSqlExpression) + { + Check.NotNull(fromSqlExpression, nameof(fromSqlExpression)); + + var sql = fromSqlExpression.Sql; + + string[] substitutions; + + switch (fromSqlExpression.Arguments) + { + case ParameterExpression { Name : not null } parameterExpression + when _parameterValues.TryGetValue(parameterExpression.Name, out var parameterValue) + && parameterValue is object[] parameterValues: + { + substitutions = new string[parameterValues.Length]; + for (var i = 0; i < parameterValues.Length; i++) + { + var parameterName = _parameterNameGenerator.GenerateNext(); + _sqlParameters.Add(new SqlParameter(parameterName, parameterValues[i])); + substitutions[i] = parameterName; + } + + break; + } + + case ConstantExpression { Value : object[] constantValues }: + { + substitutions = new string[constantValues.Length]; + for (var i = 0; i < constantValues.Length; i++) + { + var value = constantValues[i]; + substitutions[i] = GenerateConstant(value, _typeMappingSource.FindMapping(value.GetType())); + } + + break; + } + + default: + throw new ArgumentOutOfRangeException( + nameof(fromSqlExpression), + fromSqlExpression.Arguments, + CosmosStrings.InvalidFromSqlArguments( + fromSqlExpression.Arguments.GetType(), + fromSqlExpression.Arguments is ConstantExpression constantExpression + ? constantExpression.Value?.GetType() + : null)); + } + + // ReSharper disable once CoVariantArrayConversion + // InvariantCulture not needed since substitutions are all strings + sql = string.Format(sql, substitutions); + + _sqlBuilder.AppendLine("("); + + using (_sqlBuilder.Indent()) + { + _sqlBuilder.AppendLines(sql); + } + + _sqlBuilder + .Append(") ") + .Append(fromSqlExpression.Alias); + + return fromSqlExpression; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -350,7 +436,7 @@ protected override Expression VisitSqlUnary(SqlUnaryExpression sqlUnaryExpressio private void GenerateList( IReadOnlyList items, Action generationAction, - Action joinAction = null) + Action joinAction = null) { joinAction ??= (isb => isb.Append(", ")); @@ -375,13 +461,18 @@ protected override Expression VisitSqlConstant(SqlConstantExpression sqlConstant { Check.NotNull(sqlConstantExpression, nameof(sqlConstantExpression)); - var jToken = GenerateJToken(sqlConstantExpression.Value, sqlConstantExpression.TypeMapping); - - _sqlBuilder.Append(jToken == null ? "null" : jToken.ToString(Formatting.None)); + _sqlBuilder.Append(GenerateConstant(sqlConstantExpression.Value, sqlConstantExpression.TypeMapping)); return sqlConstantExpression; } + private string GenerateConstant(object value, CoreTypeMapping typeMapping) + { + var jToken = GenerateJToken(value, typeMapping); + + return jToken is null ? "null" : jToken.ToString(Formatting.None); + } + private JToken GenerateJToken(object value, CoreTypeMapping typeMapping) { if (value?.GetType().IsInteger() == true) @@ -488,5 +579,13 @@ protected override Expression VisitSqlFunction(SqlFunctionExpression sqlFunction return sqlFunctionExpression; } + + private sealed class ParameterNameGenerator + { + private int _count; + + public string GenerateNext() + => "@p" + _count++; + } } } diff --git a/src/EFCore.Cosmos/Query/Internal/QuerySqlGeneratorFactory.cs b/src/EFCore.Cosmos/Query/Internal/QuerySqlGeneratorFactory.cs index 49e82ecf67f..c2c1fb107f1 100644 --- a/src/EFCore.Cosmos/Query/Internal/QuerySqlGeneratorFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/QuerySqlGeneratorFactory.cs @@ -1,6 +1,9 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.EntityFrameworkCore.Storage; +using Microsoft.EntityFrameworkCore.Utilities; + namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal { /// @@ -11,6 +14,21 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal /// public class QuerySqlGeneratorFactory : IQuerySqlGeneratorFactory { + private readonly ITypeMappingSource _typeMappingSource; + + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public QuerySqlGeneratorFactory(ITypeMappingSource typeMappingSource) + { + Check.NotNull(typeMappingSource, nameof(typeMappingSource)); + + _typeMappingSource = typeMappingSource; + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -18,6 +36,6 @@ public class QuerySqlGeneratorFactory : IQuerySqlGeneratorFactory /// doing so can result in application failures when updating to a new Entity Framework Core release. /// public virtual QuerySqlGenerator Create() - => new(); + => new(_typeMappingSource); } } diff --git a/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs b/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs index cae1cbbed00..7a861eaf5b2 100644 --- a/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs +++ b/src/EFCore.Cosmos/Query/Internal/SelectExpression.cs @@ -45,6 +45,19 @@ public SelectExpression(IEntityType entityType) _projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, FromExpression); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public SelectExpression(IEntityType entityType, string sql, Expression argument) + { + Container = entityType.GetContainer(); + FromExpression = new FromSqlExpression(entityType, RootAlias, sql, argument); + _projectionMapping[new ProjectionMember()] = new EntityProjectionExpression(entityType, new RootReferenceExpression(entityType, RootAlias)); + } + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -166,18 +179,14 @@ public virtual void SetPartitionKey(IProperty partitionKeyProperty, Expression e /// public virtual string GetPartitionKey(IReadOnlyDictionary parameterValues) { - switch (_partitionKeyValue) + return _partitionKeyValue switch { - case ConstantExpression constantExpression: - return GetString(_partitionKeyValueConverter, constantExpression.Value); - - case ParameterExpression parameterExpression - when parameterValues.TryGetValue(parameterExpression.Name, out var value): - return GetString(_partitionKeyValueConverter, value); - - default: - return null; - } + ConstantExpression constantExpression + => GetString(_partitionKeyValueConverter, constantExpression.Value), + ParameterExpression parameterExpression when parameterValues.TryGetValue(parameterExpression.Name, out var value) + => GetString(_partitionKeyValueConverter, value), + _ => null + }; static string GetString(ValueConverter converter, object value) => converter is null diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs index b7420e92756..f9d9239103c 100644 --- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs +++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs @@ -515,6 +515,20 @@ public virtual SelectExpression Select(IEntityType entityType) return selectExpression; } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + public virtual SelectExpression Select(IEntityType entityType, string sql, Expression argument) + { + var selectExpression = new SelectExpression(entityType, sql, argument); + AddDiscriminator(selectExpression, entityType); + + return selectExpression; + } + private void AddDiscriminator(SelectExpression selectExpression, IEntityType entityType) { var concreteEntityTypes = entityType.GetConcreteDerivedTypesInclusive().ToList(); diff --git a/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs index 151de23dc5f..5bdc0c898fe 100644 --- a/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/SqlExpressionVisitor.cs @@ -46,6 +46,9 @@ protected override Expression VisitExtension(Expression extensionExpression) case ObjectArrayProjectionExpression arrayProjectionExpression: return VisitObjectArrayProjection(arrayProjectionExpression); + case FromSqlExpression fromSqlExpression: + return VisitFromSql(fromSqlExpression); + case RootReferenceExpression rootReferenceExpression: return VisitRootReference(rootReferenceExpression); @@ -83,6 +86,14 @@ protected override Expression VisitExtension(Expression extensionExpression) return base.VisitExtension(extensionExpression); } + /// + /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to + /// the same compatibility standards as public APIs. It may be changed or removed without notice in + /// any release. You should only use it directly in your code with extreme caution and knowing that + /// doing so can result in application failures when updating to a new Entity Framework Core release. + /// + protected abstract Expression VisitFromSql(FromSqlExpression fromSqlExpression); + /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in diff --git a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs index 61d1ee3dc1d..b442af24873 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs +++ b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs @@ -692,6 +692,14 @@ public static string InvalidDerivedTypeInEntityProjection(object? derivedType, o GetString("InvalidDerivedTypeInEntityProjection", nameof(derivedType), nameof(entityType)), derivedType, entityType); + /// + /// A FromSqlExpression has an invalid arguments expression type '{expressionType}' or value type '{valueType}'. + /// + public static string InvalidFromSqlArguments(object? expressionType, object? valueType) + => string.Format( + GetString("InvalidFromSqlArguments", nameof(expressionType), nameof(valueType)), + expressionType, valueType); + /// /// The grouping key '{keySelector}' is of type '{keyType}' which is not valid key. /// diff --git a/src/EFCore.Relational/Properties/RelationalStrings.resx b/src/EFCore.Relational/Properties/RelationalStrings.resx index 763417d25c9..1533ef0bad3 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.resx +++ b/src/EFCore.Relational/Properties/RelationalStrings.resx @@ -378,6 +378,9 @@ The specified entity type '{derivedType}' is not derived from '{entityType}'. + + A FromSqlExpression has an invalid arguments expression type '{expressionType}' or value type '{valueType}'. + The grouping key '{keySelector}' is of type '{keyType}' which is not valid key. diff --git a/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs b/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs index 08f9baf030a..08bc511144a 100644 --- a/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/Internal/FromSqlParameterExpandingExpressionVisitor.cs @@ -92,94 +92,88 @@ public virtual SelectExpression Expand( [return: NotNullIfNotNull("expression")] public override Expression? Visit(Expression? expression) { - if (expression is FromSqlExpression fromSql) + if (expression is not FromSqlExpression fromSql) { - if (!_visitedFromSqlExpressions.TryGetValue(fromSql, out var updatedFromSql)) - { - switch (fromSql.Arguments) + return base.Visit(expression); + } + + if (_visitedFromSqlExpressions.TryGetValue(fromSql, out var visitedFromSql)) + { + return visitedFromSql; + } + + switch (fromSql.Arguments) + { + case ParameterExpression parameterExpression: + // parameter value will never be null. It could be empty object[] + var parameterValues = (object[])_parametersValues[parameterExpression.Name!]!; + _canCache = false; + + var subParameters = new List(parameterValues.Length); + // ReSharper disable once ForCanBeConvertedToForeach + for (var i = 0; i < parameterValues.Length; i++) { - case ParameterExpression parameterExpression: - // parameter value will never be null. It could be empty object[] - var parameterValues = (object[])_parametersValues[parameterExpression.Name!]!; - _canCache = false; - - var subParameters = new List(parameterValues.Length); - // ReSharper disable once ForCanBeConvertedToForeach - for (var i = 0; i < parameterValues.Length; i++) + var parameterName = _parameterNameGenerator.GenerateNext(); + if (parameterValues[i] is DbParameter dbParameter) + { + if (string.IsNullOrEmpty(dbParameter.ParameterName)) + { + dbParameter.ParameterName = parameterName; + } + else { - var parameterName = _parameterNameGenerator.GenerateNext(); - if (parameterValues[i] is DbParameter dbParameter) - { - if (string.IsNullOrEmpty(dbParameter.ParameterName)) - { - dbParameter.ParameterName = parameterName; - } - else - { - parameterName = dbParameter.ParameterName; - } - - subParameters.Add(new RawRelationalParameter(parameterName, dbParameter)); - } - else - { - subParameters.Add( - new TypeMappedRelationalParameter( - parameterName, - parameterName, - _typeMappingSource.GetMappingForValue(parameterValues[i]), - parameterValues[i]?.GetType().IsNullableType())); - } + parameterName = dbParameter.ParameterName; } - updatedFromSql = fromSql.Update( - Expression.Constant(new CompositeRelationalParameter(parameterExpression.Name!, subParameters))); + subParameters.Add(new RawRelationalParameter(parameterName, dbParameter)); + } + else + { + subParameters.Add( + new TypeMappedRelationalParameter( + parameterName, + parameterName, + _typeMappingSource.GetMappingForValue(parameterValues[i]), + parameterValues[i]?.GetType().IsNullableType())); + } + } - _visitedFromSqlExpressions[fromSql] = updatedFromSql; - break; + return _visitedFromSqlExpressions[fromSql] = fromSql.Update( + Expression.Constant(new CompositeRelationalParameter(parameterExpression.Name!, subParameters))); - case ConstantExpression constantExpression: - var existingValues = constantExpression.GetConstantValue(); - var constantValues = new object?[existingValues.Length]; - for (var i = 0; i < existingValues.Length; i++) + case ConstantExpression constantExpression: + var existingValues = constantExpression.GetConstantValue(); + var constantValues = new object?[existingValues.Length]; + for (var i = 0; i < existingValues.Length; i++) + { + var value = existingValues[i]; + if (value is DbParameter dbParameter) + { + var parameterName = _parameterNameGenerator.GenerateNext(); + if (string.IsNullOrEmpty(dbParameter.ParameterName)) { - var value = existingValues[i]; - if (value is DbParameter dbParameter) - { - var parameterName = _parameterNameGenerator.GenerateNext(); - if (string.IsNullOrEmpty(dbParameter.ParameterName)) - { - dbParameter.ParameterName = parameterName; - } - else - { - parameterName = dbParameter.ParameterName; - } - - constantValues[i] = new RawRelationalParameter(parameterName, dbParameter); - } - else - { - constantValues[i] = _sqlExpressionFactory.Constant( - value, _typeMappingSource.GetMappingForValue(value)); - } + dbParameter.ParameterName = parameterName; + } + else + { + parameterName = dbParameter.ParameterName; } - updatedFromSql = fromSql.Update(Expression.Constant(constantValues, typeof(object?[]))); - - _visitedFromSqlExpressions[fromSql] = updatedFromSql; - break; - - default: - Check.DebugAssert(false, "FromSql.Arguments must be Constant/ParameterExpression"); - break; + constantValues[i] = new RawRelationalParameter(parameterName, dbParameter); + } + else + { + constantValues[i] = _sqlExpressionFactory.Constant( + value, _typeMappingSource.GetMappingForValue(value)); + } } - } - return updatedFromSql; - } + return _visitedFromSqlExpressions[fromSql] = fromSql.Update(Expression.Constant(constantValues, typeof(object?[]))); - return base.Visit(expression); + default: + Check.DebugAssert(false, "FromSql.Arguments must be Constant/ParameterExpression"); + return null; + } } } } diff --git a/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs b/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs index 041f2cbc066..ca7dc4cae04 100644 --- a/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs +++ b/src/EFCore.Relational/Query/Internal/FromSqlQueryRootExpression.cs @@ -125,7 +125,7 @@ public override bool Equals(object? obj) private bool Equals(FromSqlQueryRootExpression queryRootExpression) => base.Equals(queryRootExpression) - && string.Equals(Sql, queryRootExpression.Sql, StringComparison.OrdinalIgnoreCase) + && Sql == queryRootExpression.Sql && ExpressionEqualityComparer.Instance.Equals(Argument, queryRootExpression.Argument); /// diff --git a/src/EFCore.Relational/Query/QuerySqlGenerator.cs b/src/EFCore.Relational/Query/QuerySqlGenerator.cs index 032941ba7cd..44c54be4cee 100644 --- a/src/EFCore.Relational/Query/QuerySqlGenerator.cs +++ b/src/EFCore.Relational/Query/QuerySqlGenerator.cs @@ -363,8 +363,7 @@ private void GenerateFromSql(FromSqlExpression fromSqlExpression) switch (fromSqlExpression.Arguments) { - case ConstantExpression constantExpression - when constantExpression.Value is CompositeRelationalParameter compositeRelationalParameter: + case ConstantExpression { Value: CompositeRelationalParameter compositeRelationalParameter }: { var subParameters = compositeRelationalParameter.RelationalParameters; substitutions = new string[subParameters.Count]; @@ -378,8 +377,7 @@ private void GenerateFromSql(FromSqlExpression fromSqlExpression) break; } - case ConstantExpression constantExpression - when constantExpression.Value is object[] constantValues: + case ConstantExpression { Value: object[] constantValues }: { substitutions = new string[constantValues.Length]; for (var i = 0; i < constantValues.Length; i++) @@ -398,15 +396,22 @@ private void GenerateFromSql(FromSqlExpression fromSqlExpression) break; } - } - if (substitutions != null) - { - // ReSharper disable once CoVariantArrayConversion - // InvariantCulture not needed since substitutions are all strings - sql = string.Format(sql, substitutions); + default: + throw new ArgumentOutOfRangeException( + nameof(fromSqlExpression), + fromSqlExpression.Arguments, + RelationalStrings.InvalidFromSqlArguments( + fromSqlExpression.Arguments.GetType(), + fromSqlExpression.Arguments is ConstantExpression constantExpression + ? constantExpression.Value?.GetType() + : null)); } + // ReSharper disable once CoVariantArrayConversion + // InvariantCulture not needed since substitutions are all strings + sql = string.Format(sql, substitutions); + _relationalCommandBuilder.AppendLines(sql); } diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs index 71d03c4cc27..25abf270b86 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.cs @@ -2839,7 +2839,7 @@ EntityProjectionExpression LiftEntityProjectionFromSubquery(EntityProjectionExpr } /// - /// Checks whether this representes a which is not composed upon. + /// Checks whether this represents a which is not composed upon. /// /// A bool value indicating a non-composed . public bool IsNonComposedFromSql() diff --git a/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj b/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj index ca49776d0ea..c4671ee3f09 100644 --- a/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj +++ b/test/EFCore.Cosmos.FunctionalTests/EFCore.Cosmos.FunctionalTests.csproj @@ -15,6 +15,10 @@ + + + + diff --git a/test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs b/test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs new file mode 100644 index 00000000000..420e3dd0ffa --- /dev/null +++ b/test/EFCore.Cosmos.FunctionalTests/Query/FromSqlQueryCosmosTest.cs @@ -0,0 +1,615 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.EntityFrameworkCore.TestModels.Northwind; +using Microsoft.EntityFrameworkCore.TestUtilities; +using Microsoft.EntityFrameworkCore.Utilities; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.EntityFrameworkCore.Query +{ + public class FromSqlQueryCosmosTest : QueryTestBase> + { + private static readonly string _eol = Environment.NewLine; + + public FromSqlQueryCosmosTest( + NorthwindQueryCosmosFixture fixture, + ITestOutputHelper testOutputHelper) + : base(fixture) + { + ClearLog(); + Fixture.TestSqlLoggerFactory.SetTestOutputHelper(testOutputHelper); + } + + protected NorthwindContext CreateContext() + => Fixture.CreateContext(); + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_simple(bool async) + { + using var context = CreateContext(); + var query = context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""ContactName""] LIKE '%z%'"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(14, actual.Length); + Assert.Equal(14, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c WHERE c[""ContactName""] LIKE '%z%' +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_simple_columns_out_of_order(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Equal(91, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_simple_columns_out_of_order_and_extra_columns(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""PostalCode""] AS Foo, c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Equal(91, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT c[""id""], c[""Discriminator""], c[""Region""], c[""PostalCode""], c[""PostalCode""] AS Foo, c[""Phone""], c[""Fax""], c[""CustomerID""], c[""Country""], c[""ContactTitle""], c[""ContactName""], c[""CompanyName""], c[""City""], c[""Address""] FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c").Where(c => c.ContactName.Contains("z")); + + var sql = query.ToQueryString(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(14, actual.Length); + Assert.Equal(14, context.ChangeTracker.Entries().Count()); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_composed_after_removing_whitespaces(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + _eol + " " + _eol + _eol + _eol + "SELECT" + _eol + "* FROM root c") + .Where(c => c.ContactName.Contains("z")); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(14, actual.Length); + + AssertSql( + @"SELECT c +FROM ( + + + + + SELECT + * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_composed_compiled(bool async) + { + if (async) + { + var query = EF.CompileAsyncQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = await query(context).ToListAsync(); + + Assert.Equal(14, actual.Count); + } + } + else + { + var query = EF.CompileQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = query(context).ToArray(); + + Assert.Equal(14, actual.Length); + } + } + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_composed_compiled_with_parameter(bool async) + { + if (async) + { + var query = EF.CompileAsyncQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""CustomerID""] = {0}", "CONSH") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = await query(context).ToListAsync(); + + Assert.Single(actual); + } + } + else + { + var query = EF.CompileQuery( + (NorthwindContext context) => context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""CustomerID""] = {0}", "CONSH") + .Where(c => c.ContactName.Contains("z"))); + + using (var context = CreateContext()) + { + var actual = query(context).ToArray(); + + Assert.Single(actual); + } + } + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c WHERE c[""CustomerID""] = ""CONSH"" +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND CONTAINS(c[""ContactName""], ""z""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_multiple_line_query(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * +FROM root c +WHERE c[""City""] = 'London'"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(6, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + + AssertSql( + @"SELECT c +FROM ( + SELECT * + FROM root c + WHERE c[""City""] = 'London' +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_composed_multiple_line_query(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * +FROM root c") + .Where(c => c.City == "London"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(6, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + + AssertSql( + @"SELECT c +FROM ( + SELECT * + FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""City""] = ""London""))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_parameters(bool async) + { + var city = "London"; + var contactTitle = "Sales Representative"; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""City""] = {0} AND c[""ContactTitle""] = {1}", city, + contactTitle); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + AssertSql( + @"@p0='London' +@p1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_parameters_inline(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""City""] = {0} AND c[""ContactTitle""] = {1}", "London", + "Sales Representative"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + AssertSql( + @"@p0='London' +@p1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_null_parameter(bool async) + { + uint? reportsTo = null; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""ReportsTo""] = {0} OR (IS_NULL(c[""ReportsTo""]) AND IS_NULL({0}))", reportsTo); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Single(actual); + + AssertSql( + @"@p0=null + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""ReportsTo""] = @p0 OR (IS_NULL(c[""ReportsTo""]) AND IS_NULL(@p0)) +) c +WHERE (c[""Discriminator""] = ""Employee"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public async Task FromSqlRaw_queryable_with_parameters_and_closure(bool async) + { + var city = "London"; + var contactTitle = "Sales Representative"; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * FROM root c WHERE c[""City""] = {0}", city) + .Where(c => c.ContactTitle == contactTitle); + var queryString = query.ToQueryString(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + AssertSql( + @"@p0='London' +@__contactTitle_1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""ContactTitle""] = @__contactTitle_1))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_cache_key_includes_query_string(bool async) + { + using var context = CreateContext(); + var query = context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""City""] = 'London'"); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(6, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + + query = context.Set() + .FromSqlRaw(@"SELECT * FROM root c WHERE c[""City""] = 'Seattle'"); + + actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Single(actual); + Assert.True(actual.All(c => c.City == "Seattle")); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = 'London' +) c +WHERE (c[""Discriminator""] = ""Customer"")", + // + @"SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = 'Seattle' +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_with_parameters_cache_key_includes_parameters(bool async) + { + var city = "London"; + var contactTitle = "Sales Representative"; + var sql = @"SELECT * FROM root c WHERE c[""City""] = {0} AND c[""ContactTitle""] = {1}"; + + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(sql, city, contactTitle); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(3, actual.Length); + Assert.True(actual.All(c => c.City == "London")); + Assert.True(actual.All(c => c.ContactTitle == "Sales Representative")); + + city = "Madrid"; + contactTitle = "Accounting Manager"; + + query = context.Set().FromSqlRaw(sql, city, contactTitle); + + actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(2, actual.Length); + Assert.True(actual.All(c => c.City == "Madrid")); + Assert.True(actual.All(c => c.ContactTitle == "Accounting Manager")); + + AssertSql( + @"@p0='London' +@p1='Sales Representative' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")", + // + @"@p0='Madrid' +@p1='Accounting Manager' + +SELECT c +FROM ( + SELECT * FROM root c WHERE c[""City""] = @p0 AND c[""ContactTitle""] = @p1 +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_as_no_tracking_not_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(@"SELECT * FROM root c") + .AsNoTracking(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Empty(context.ChangeTracker.Entries()); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_projection_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw( + @"SELECT * +FROM root c +WHERE NOT c[""Discontinued""] AND ((c[""UnitsInStock""] + c[""UnitsOnOrder""]) < c[""ReorderLevel""])") + .Select(p => p.ProductName); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(2, actual.Length); + + AssertSql( + @"SELECT c[""ProductName""] +FROM ( + SELECT * + FROM root c + WHERE NOT c[""Discontinued""] AND ((c[""UnitsInStock""] + c[""UnitsOnOrder""]) < c[""ReorderLevel""]) +) c +WHERE (c[""Discriminator""] = ""Product"")"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_composed_with_nullable_predicate(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(@"SELECT * FROM root c") + .Where(c => c.ContactName == c.CompanyName); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Empty(actual); + + AssertSql( + @"SELECT c +FROM ( + SELECT * FROM root c +) c +WHERE ((c[""Discriminator""] = ""Customer"") AND (c[""ContactName""] = c[""CompanyName""]))"); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_does_not_parameterize_interpolated_string(bool async) + { + using var context = CreateContext(); + var propertyName = "OrderID"; + var max = 10250; + var query = context.Orders.FromSqlRaw($@"SELECT * FROM root c WHERE c[""{propertyName}""] < {{0}}", max); + + var actual = async + ? await query.ToListAsync() + : query.ToList(); + + Assert.Equal(2, actual.Count); + } + + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task FromSqlRaw_queryable_simple_projection_not_composed(bool async) + { + using var context = CreateContext(); + var query = context.Set().FromSqlRaw(@"SELECT * FROM root c") + .Select( + c => new { c.CustomerID, c.City }) + .AsNoTracking(); + + var actual = async + ? await query.ToArrayAsync() + : query.ToArray(); + + Assert.Equal(91, actual.Length); + Assert.Empty(context.ChangeTracker.Entries()); + + AssertSql( + @"SELECT c[""CustomerID""], c[""City""] +FROM ( + SELECT * FROM root c +) c +WHERE (c[""Discriminator""] = ""Customer"")"); + } + + private void AssertSql(params string[] expected) + => Fixture.TestSqlLoggerFactory.AssertBaseline(expected); + + protected void ClearLog() + => Fixture.TestSqlLoggerFactory.Clear(); + } +}