From 2a588244cb70421d24280c480cc1019c79074bcf Mon Sep 17 00:00:00 2001 From: Smit Patel Date: Fri, 3 Apr 2020 14:32:30 -0700 Subject: [PATCH] Check if parameter.Name is not null before comparing it with query parameter Resolves #20485 Parameters which are used inside lambda should get replaced with appropriate shaper/selector. All other parameters should be query parameters otherwise it is an error. --- ...yExpressionTranslatingExpressionVisitor.cs | 23 ++++-- ...ntityEqualityRewritingExpressionVisitor.cs | 80 ++++++++++++++----- .../Query/SimpleQueryTestBase.cs | 24 ++++++ 3 files changed, 104 insertions(+), 23 deletions(-) diff --git a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs index 50528bb17a4..56e4f253c90 100644 --- a/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs +++ b/src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs @@ -609,12 +609,25 @@ protected override Expression VisitExtension(Expression extensionExpression) protected override Expression VisitParameter(ParameterExpression parameterExpression) { - if (parameterExpression.Name.StartsWith(CompiledQueryParameterPrefix, StringComparison.Ordinal)) + if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue20485", out var enabled) && enabled) { - return Expression.Call( - _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), - QueryCompilationContext.QueryContextParameter, - Expression.Constant(parameterExpression.Name)); + if (parameterExpression.Name.StartsWith(CompiledQueryParameterPrefix, StringComparison.Ordinal)) + { + return Expression.Call( + _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterExpression.Name)); + } + } + else + { + if (parameterExpression.Name?.StartsWith(CompiledQueryParameterPrefix, StringComparison.Ordinal) == true) + { + return Expression.Call( + _getParameterValueMethodInfo.MakeGenericMethod(parameterExpression.Type), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(parameterExpression.Name)); + } } throw new InvalidOperationException(CoreStrings.TranslationFailed(parameterExpression.Print())); diff --git a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs index a8625507780..ba489c8557c 100644 --- a/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs +++ b/src/EFCore/Query/Internal/EntityEqualityRewritingExpressionVisitor.cs @@ -429,8 +429,26 @@ private Expression VisitContainsMethodCall(MethodCallExpression methodCallExpres rewrittenSource = Expression.Constant(keyList, keyListType); } + else if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue20485", out var enabled) && enabled + && newSource is ParameterExpression listParam2 + && listParam2.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)) + { + // The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution. + var lambda = Expression.Lambda( + Expression.Call( + _parameterListValueExtractor.MakeGenericMethod(entityType.ClrType, keyProperty.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(listParam2.Name, typeof(string)), + Expression.Constant(keyProperty, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter + ); + + var newParameterName = + $"{RuntimeParameterPrefix}{listParam2.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{keyProperty.Name}"; + rewrittenSource = _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + } else if (newSource is ParameterExpression listParam - && listParam.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)) + && listParam.Name?.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal) == true) { // The source list is a parameter. Add a runtime parameter that will contain a list of the extracted keys for each execution. var lambda = Expression.Lambda( @@ -935,24 +953,49 @@ private Expression CreatePropertyAccessExpression(Expression target, IProperty p return Expression.Constant(property.GetGetter().GetClrValue(value), property.ClrType.MakeNullable()); } - // If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new - // parameter to be added at runtime, with the value of the property on the base parameter. - if (target is ParameterExpression baseParameterExpression + if (AppContext.TryGetSwitch("Microsoft.EntityFrameworkCore.Issue20485", out var enabled) && enabled) + { + // If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new + // parameter to be added at runtime, with the value of the property on the base parameter. + if (target is ParameterExpression baseParameterExpression && baseParameterExpression.Name.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal)) + { + // Generate an expression to get the base parameter from the query context's parameter list, and extract the + // property from that + var lambda = Expression.Lambda( + Expression.Call( + _parameterValueExtractor.MakeGenericMethod(property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(baseParameterExpression.Name, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter); + + var newParameterName = + $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}"; + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + } + } + else { - // Generate an expression to get the base parameter from the query context's parameter list, and extract the - // property from that - var lambda = Expression.Lambda( - Expression.Call( - _parameterValueExtractor.MakeGenericMethod(property.ClrType.MakeNullable()), - QueryCompilationContext.QueryContextParameter, - Expression.Constant(baseParameterExpression.Name, typeof(string)), - Expression.Constant(property, typeof(IProperty))), - QueryCompilationContext.QueryContextParameter); - - var newParameterName = - $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}"; - return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + // If the target is a query parameter, we can't simply add a property access over it, but must instead cause a new + // parameter to be added at runtime, with the value of the property on the base parameter. + if (target is ParameterExpression baseParameterExpression + && baseParameterExpression.Name?.StartsWith(CompiledQueryCache.CompiledQueryParameterPrefix, StringComparison.Ordinal) == true) + { + // Generate an expression to get the base parameter from the query context's parameter list, and extract the + // property from that + var lambda = Expression.Lambda( + Expression.Call( + _parameterValueExtractor.MakeGenericMethod(property.ClrType.MakeNullable()), + QueryCompilationContext.QueryContextParameter, + Expression.Constant(baseParameterExpression.Name, typeof(string)), + Expression.Constant(property, typeof(IProperty))), + QueryCompilationContext.QueryContextParameter); + + var newParameterName = + $"{RuntimeParameterPrefix}{baseParameterExpression.Name.Substring(CompiledQueryCache.CompiledQueryParameterPrefix.Length)}_{property.Name}"; + return _queryCompilationContext.RegisterRuntimeParameter(newParameterName, lambda); + } } return target.CreateEFPropertyExpression(property); @@ -1036,7 +1079,8 @@ protected struct EntityOrDtoType public static EntityOrDtoType FromEntityReferenceExpression(EntityReferenceExpression ere) => new EntityOrDtoType { - EntityType = ere.IsEntityType ? ere.EntityType : null, DtoType = ere.IsDtoType ? ere.DtoType : null + EntityType = ere.IsEntityType ? ere.EntityType : null, + DtoType = ere.IsDtoType ? ere.DtoType : null }; public static EntityOrDtoType FromDtoType(Dictionary dtoType) diff --git a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs index 91a47b85193..c4c8d320f7c 100644 --- a/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs +++ b/test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using Microsoft.EntityFrameworkCore.Diagnostics; using Microsoft.EntityFrameworkCore.Internal; +using Microsoft.EntityFrameworkCore.Query.Internal; using Microsoft.EntityFrameworkCore.TestModels.Northwind; using Microsoft.EntityFrameworkCore.TestUtilities; using Xunit; @@ -5630,6 +5631,29 @@ public virtual Task AsQueryable_in_query_server_evals(bool isAsync) elementAsserter: (e, a) => AssertCollection(e, a, ordered: true)); } + [ConditionalTheory] + [MemberData(nameof(IsAsyncData))] + public virtual async Task Null_parameter_name_works(bool isAsync) + { + using var context = CreateContext(); + var customerDbSet = context.Set().AsQueryable(); + + var parameter = Expression.Parameter(typeof(Customer)); + var body = Expression.Equal(parameter, Expression.Default(typeof(Customer))); + var queryExpression = Expression.Call( + QueryableMethods.Where.MakeGenericMethod(typeof(Customer)), + customerDbSet.Expression, + Expression.Quote(Expression.Lambda(body, parameter))); + + var query = ((IAsyncQueryProvider)customerDbSet.Provider).CreateQuery(queryExpression); + + var result = isAsync + ? (await query.ToListAsync()) + : query.ToList(); + + Assert.Empty(result); + } + protected async Task AssertTranslationFailed(Func testCode) { Assert.Contains(