Skip to content

Commit

Permalink
Fix to #16323 - Query: add support for keyless entity types with defi…
Browse files Browse the repository at this point in the history
…ning query

During nav expansion, when visiting EntityQueryables we now peek into entity metadata looking for defining query and query filter. If they are found, nav expansion applies them and expands navigations in the resulting query (recursively).

Also, fix to #17111 - Cannot use DbSet in query filter

DbSet accesses are converted into EntityQueryables during model finialization. We need to do this early, because otherwise we would need a dbcontext to create a queryable, and that dbcontext could be disposed at the time nav expansion runs.
Since FromSql/FromSqlRaw/FromSqlInterpolated are constrained to DbSet<T> they need to be converted to FromSqlOnQueryable at the same time.
  • Loading branch information
maumar committed Aug 20, 2019
1 parent 9fc67e2 commit 3085673
Show file tree
Hide file tree
Showing 26 changed files with 432 additions and 247 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.EntityFrameworkCore
/// </summary>
public static class RelationalQueryableExtensions
{
private static readonly MethodInfo _fromSqlOnQueryableMethodInfo
internal static readonly MethodInfo FromSqlOnQueryableMethodInfo
= typeof(RelationalQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(FromSqlOnQueryable))
.Single();
Expand Down Expand Up @@ -71,7 +71,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
return source.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
Expression.Constant(sql.Format),
Expression.Constant(parameters)));
Expand Down Expand Up @@ -115,7 +115,7 @@ public static IQueryable<TEntity> FromSql<TEntity>(
return source.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
source.Expression,
Expression.Constant(sql.Format),
Expression.Constant(sql.GetArguments())));
Expand Down Expand Up @@ -164,7 +164,7 @@ public static IQueryable<TEntity> FromSqlRaw<TEntity>(
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql),
Expression.Constant(parameters)));
Expand Down Expand Up @@ -204,7 +204,7 @@ public static IQueryable<TEntity> FromSqlInterpolated<TEntity>(
return queryableSource.Provider.CreateQuery<TEntity>(
Expression.Call(
null,
_fromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
FromSqlOnQueryableMethodInfo.MakeGenericMethod(typeof(TEntity)),
queryableSource.Expression,
Expression.Constant(sql.Format),
Expression.Constant(sql.GetArguments())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ public override ConventionSet CreateConventionSet()
new DbFunctionTypeMappingConvention(Dependencies, RelationalDependencies),
typeof(ValidatingConvention));

ReplaceConvention(
conventionSet.ModelFinalizedConventions,
(QueryFilterDefiningQueryRewritingConvention)new RelationalQueryFilterDefiningQueryRewritingConvention(Dependencies, RelationalDependencies));

return conventionSet;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
{
public class RelationalQueryFilterDefiningQueryRewritingConvention : QueryFilterDefiningQueryRewritingConvention
{
/// <summary>
/// Creates a new instance of <see cref="RelationalQueryFilterDefiningQueryRewritingConvention" />.
/// </summary>
/// <param name="dependencies"> Parameter object containing dependencies for this convention. </param>
/// <param name="relationalDependencies"> Parameter object containing relational dependencies for this convention. </param>
public RelationalQueryFilterDefiningQueryRewritingConvention(
[NotNull] ProviderConventionSetBuilderDependencies dependencies,
[NotNull] RelationalConventionSetBuilderDependencies relationalDependencies)
: base(dependencies)
{
DbSetAccessRewriter = new RelationalDbSetAccessRewritingExpressionVisitor(Dependencies.ContextType);
}

protected class RelationalDbSetAccessRewritingExpressionVisitor : DbSetAccessRewritingExpressionVisitor
{
public RelationalDbSetAccessRewritingExpressionVisitor(Type contextType)
: base(contextType)
{
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.DeclaringType == typeof(RelationalQueryableExtensions)
&& (methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSql)
|| methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSqlRaw)
|| methodCallExpression.Method.Name == nameof(RelationalQueryableExtensions.FromSqlInterpolated)))
{
var newSource = Visit(methodCallExpression.Arguments[0]);
var fromSqlOnQueryableMethod = RelationalQueryableExtensions.FromSqlOnQueryableMethodInfo.MakeGenericMethod(newSource.Type.GetGenericArguments()[0]);

switch (methodCallExpression.Method.Name)
{
case nameof(RelationalQueryableExtensions.FromSqlRaw):
return Expression.Call(
null,
fromSqlOnQueryableMethod,
newSource,
methodCallExpression.Arguments[1],
methodCallExpression.Arguments[2]);

case nameof(RelationalQueryableExtensions.FromSqlInterpolated):
case nameof(RelationalQueryableExtensions.FromSql) when methodCallExpression.Arguments.Count == 2:
var formattableString = Expression.Lambda<Func<FormattableString>>(Expression.Convert(methodCallExpression.Arguments[1], typeof(FormattableString))).Compile().Invoke();

return Expression.Call(
null,
fromSqlOnQueryableMethod,
newSource,
Expression.Constant(formattableString.Format),
Expression.Constant(formattableString.GetArguments()));

case nameof(RelationalQueryableExtensions.FromSql) when methodCallExpression.Arguments.Count == 3:
#pragma warning disable CS0618 // Type or member is obsolete
var rawSqlStringString = Expression.Lambda<Func<RawSqlString>>(Expression.Convert(methodCallExpression.Arguments[1], typeof(RawSqlString))).Compile().Invoke();
#pragma warning restore CS0618 // Type or member is obsolete

return Expression.Call(
null,
fromSqlOnQueryableMethod,
newSource,
Expression.Constant(rawSqlStringString.Format),
methodCallExpression.Arguments[2]);
}
}

return base.VisitMethodCall(methodCallExpression);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.InteropServices.ComTypes;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ public override Expression Visit(Expression expression)
break;

case ConstantExpression constantExpression:
var constantValues = (object[])constantExpression.Value;
for (var i = 0; i < constantValues.Length; i++)
var existingValues = (object[])constantExpression.Value;
var constantValues = new object[existingValues.Length];
for (var i = 0; i < existingValues.Length; i++)
{
var value = constantValues[i];
var value = existingValues[i];
if (value is DbParameter dbParameter)
{
var parameterName = _parameterNameGenerator.GenerateNext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ public virtual ConventionSet CreateConventionSet()
conventionSet.ModelFinalizedConventions.Add(servicePropertyDiscoveryConvention);
conventionSet.ModelFinalizedConventions.Add(nonNullableReferencePropertyConvention);
conventionSet.ModelFinalizedConventions.Add(nonNullableNavigationConvention);
conventionSet.ModelFinalizedConventions.Add(new QueryFilterDefiningQueryRewritingConvention(Dependencies));
conventionSet.ModelFinalizedConventions.Add(new ValidatingConvention(Dependencies));
// Don't add any more conventions to ModelFinalizedConventions after ValidatingConvention

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using Microsoft.EntityFrameworkCore.Metadata.Builders;
using Microsoft.EntityFrameworkCore.Metadata.Conventions.Infrastructure;
using Microsoft.EntityFrameworkCore.Query.Internal;

namespace Microsoft.EntityFrameworkCore.Metadata.Conventions
{
/// <summary>
/// Convention that converts accesses of DbSets inside query filters and defining queries into EntityQueryables.
/// This makes them consistent with how DbSet accesses in the actual queries are represented, which allows for easier processing in the query pipeline.
/// </summary>
public class QueryFilterDefiningQueryRewritingConvention : IModelFinalizedConvention
{
/// <summary>
/// Creates a new instance of <see cref="QueryFilterDefiningQueryRewritingConvention" />.
/// </summary>
/// <param name="dependencies"> Parameter object containing dependencies for this convention. </param>
public QueryFilterDefiningQueryRewritingConvention([NotNull] ProviderConventionSetBuilderDependencies dependencies)
{
Dependencies = dependencies;
DbSetAccessRewriter = new DbSetAccessRewritingExpressionVisitor(dependencies.ContextType);
}

/// <summary>
/// Parameter object containing service dependencies.
/// </summary>
protected virtual ProviderConventionSetBuilderDependencies Dependencies { get; }

/// <summary>
/// Visitor used to rewrite DbSets accesses encountered in query filters and defining queries to EntityQueryables.
/// </summary>
protected virtual DbSetAccessRewritingExpressionVisitor DbSetAccessRewriter { get; [param: NotNull] set; }

/// <summary>
/// Called after a model is finalized.
/// </summary>
/// <param name="modelBuilder"> The builder for the model. </param>
/// <param name="context"> Additional information associated with convention execution. </param>
public virtual void ProcessModelFinalized(
IConventionModelBuilder modelBuilder,
IConventionContext<IConventionModelBuilder> context)
{
foreach (var entityType in modelBuilder.Metadata.GetEntityTypes())
{
var queryFilter = entityType.GetQueryFilter();
if (queryFilter != null)
{
entityType.SetQueryFilter((LambdaExpression)DbSetAccessRewriter.Visit(queryFilter));
}

var definingQuery = entityType.GetDefiningQuery();
if (definingQuery != null)
{
entityType.SetDefiningQuery((LambdaExpression)DbSetAccessRewriter.Visit(definingQuery));
}
}
}

protected class DbSetAccessRewritingExpressionVisitor : ExpressionVisitor
{
private readonly Type _contextType;

public DbSetAccessRewritingExpressionVisitor(Type contextType)
{
_contextType = contextType;
}

protected override Expression VisitMember(MemberExpression memberExpression)
{
if (memberExpression.Expression != null
&& (memberExpression.Expression.Type.IsAssignableFrom(_contextType)
|| _contextType.IsAssignableFrom(memberExpression.Expression.Type))
&& memberExpression.Type.IsGenericType
&& (memberExpression.Type.GetGenericTypeDefinition() == typeof(DbSet<>)
#pragma warning disable CS0618 // Type or member is obsolete
|| memberExpression.Type.GetGenericTypeDefinition() == typeof(DbQuery<>)))
#pragma warning restore CS0618 // Type or member is obsolete
{
return NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(memberExpression.Type.GetGenericArguments()[0]);
}

return base.VisitMember(memberExpression);
}

protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
{
if (methodCallExpression.Method.Name == nameof(DbContext.Set)
&& methodCallExpression.Object != null
&& typeof(DbContext).IsAssignableFrom(methodCallExpression.Object.Type)
&& methodCallExpression.Type.IsGenericType
&& methodCallExpression.Type.GetGenericTypeDefinition() == typeof(DbSet<>))
{
return NullAsyncQueryProvider.Instance.CreateEntityQueryableExpression(methodCallExpression.Type.GetGenericArguments()[0]);
}

return base.VisitMethodCall(methodCallExpression);
}
}
}
}
9 changes: 9 additions & 0 deletions src/EFCore/Properties/CoreStrings.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/EFCore/Properties/CoreStrings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -1204,4 +1204,7 @@
<data name="SetOperationWithDifferentIncludesInOperands" xml:space="preserve">
<value>When performing a set operation, both operands must have the same Include operations.</value>
</data>
<data name="IncludeOnEntityWithDefiningQueryNotSupported" xml:space="preserve">
<value>Include is not supported for entities with defining query. Entity type: '{entityType}'</value>
</data>
</root>
Loading

0 comments on commit 3085673

Please sign in to comment.