Skip to content

Commit

Permalink
Query: Throw when capturing unknown type constants in shaper
Browse files Browse the repository at this point in the history
3.0 work for #13048
  • Loading branch information
smitpatel committed Aug 11, 2019
1 parent ae7f03d commit edc4d9a
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ public class CosmosShapedQueryCompilingExpressionVisitor : ShapedQueryCompilingE
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public CosmosShapedQueryCompilingExpressionVisitor(
QueryCompilationContext queryCompilationContext,
ShapedQueryCompilingExpressionVisitorDependencies dependencies,
QueryCompilationContext queryCompilationContext,
ISqlExpressionFactory sqlExpressionFactory,
IQuerySqlGeneratorFactory querySqlGeneratorFactory)
: base(queryCompilationContext, dependencies)
: base(dependencies, queryCompilationContext)
{
_sqlExpressionFactory = sqlExpressionFactory;
_querySqlGeneratorFactory = querySqlGeneratorFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public CosmosShapedQueryCompilingExpressionVisitorFactory(
/// </summary>
public virtual ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
=> new CosmosShapedQueryCompilingExpressionVisitor(
queryCompilationContext,
_dependencies,
queryCompilationContext,
_sqlExpressionFactory,
_querySqlGeneratorFactory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ public partial class InMemoryShapedQueryCompilingExpressionVisitor : ShapedQuery
private readonly IDiagnosticsLogger<DbLoggerCategory.Query> _logger;

public InMemoryShapedQueryCompilingExpressionVisitor(
QueryCompilationContext queryCompilationContext,
ShapedQueryCompilingExpressionVisitorDependencies dependencies)
: base(queryCompilationContext, dependencies)
ShapedQueryCompilingExpressionVisitorDependencies dependencies,
QueryCompilationContext queryCompilationContext)
: base(dependencies, queryCompilationContext)
{
_contextType = queryCompilationContext.ContextType;
_logger = queryCompilationContext.Logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public InMemoryShapedQueryCompilingExpressionVisitorFactory(ShapedQueryCompiling
}

public virtual ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
=> new InMemoryShapedQueryCompilingExpressionVisitor(queryCompilationContext, _dependencies);
=> new InMemoryShapedQueryCompilingExpressionVisitor(_dependencies, queryCompilationContext);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ public RelationalShapedQueryCompilingExpressionVisitorFactory(
public virtual ShapedQueryCompilingExpressionVisitor Create(QueryCompilationContext queryCompilationContext)
{
return new RelationalShapedQueryCompilingExpressionVisitor(
queryCompilationContext,
_dependencies,
_relationalDependencies);
_relationalDependencies,
queryCompilationContext);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ public partial class RelationalShapedQueryCompilingExpressionVisitor : ShapedQue
private readonly ISet<string> _tags;

public RelationalShapedQueryCompilingExpressionVisitor(
QueryCompilationContext queryCompilationContext,
ShapedQueryCompilingExpressionVisitorDependencies dependencies,
RelationalShapedQueryCompilingExpressionVisitorDependencies relationalDependencies)
: base(queryCompilationContext, dependencies)
RelationalShapedQueryCompilingExpressionVisitorDependencies relationalDependencies,
QueryCompilationContext queryCompilationContext)
: base(dependencies, queryCompilationContext)
{
RelationalDependencies = relationalDependencies;

Expand Down
48 changes: 44 additions & 4 deletions src/EFCore/Query/ShapedQueryCompilingExpressionVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.ChangeTracking.Internal;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Storage;
Expand All @@ -32,10 +33,11 @@ private static readonly PropertyInfo _cancellationTokenMemberInfo

private readonly Expression _cancellationTokenParameter;
private readonly EntityMaterializerInjectingExpressionVisitor _entityMaterializerInjectingExpressionVisitor;
private readonly ConstantVerifyingExpressionVisitor _constantVerifyingExpressionVisitor;

protected ShapedQueryCompilingExpressionVisitor(
QueryCompilationContext queryCompilationContext,
ShapedQueryCompilingExpressionVisitorDependencies dependencies)
ShapedQueryCompilingExpressionVisitorDependencies dependencies,
QueryCompilationContext queryCompilationContext)
{
Dependencies = dependencies;

Expand All @@ -46,6 +48,8 @@ protected ShapedQueryCompilingExpressionVisitor(
dependencies.EntityMaterializerSource,
queryCompilationContext.IsTracking);

_constantVerifyingExpressionVisitor = new ConstantVerifyingExpressionVisitor(dependencies.TypeMappingSource);

IsAsync = queryCompilationContext.IsAsync;

if (queryCompilationContext.IsAsync)
Expand Down Expand Up @@ -113,7 +117,7 @@ private static async Task<TSource> SingleAsync<TSource>(
{
await using (var enumerator = asyncEnumerable.GetAsyncEnumerator(cancellationToken))
{
if (!(await enumerator.MoveNextAsync()))
if (!await enumerator.MoveNextAsync())
{
throw new InvalidOperationException();
}
Expand All @@ -124,6 +128,7 @@ private static async Task<TSource> SingleAsync<TSource>(
{
throw new InvalidOperationException();
}

return result;
}
}
Expand Down Expand Up @@ -153,7 +158,42 @@ private static async Task<TSource> SingleOrDefaultAsync<TSource>(
protected abstract Expression VisitShapedQueryExpression(ShapedQueryExpression shapedQueryExpression);

protected virtual Expression InjectEntityMaterializers(Expression expression)
=> _entityMaterializerInjectingExpressionVisitor.Inject(expression);
{
_constantVerifyingExpressionVisitor.Visit(expression);

return _entityMaterializerInjectingExpressionVisitor.Inject(expression);
}

private class ConstantVerifyingExpressionVisitor : ExpressionVisitor
{
private readonly ITypeMappingSource _typeMappingSource;

public ConstantVerifyingExpressionVisitor(ITypeMappingSource typeMappingSource)
{
_typeMappingSource = typeMappingSource;
}

protected override Expression VisitConstant(ConstantExpression constantExpression)
{
if (constantExpression.Value == null
|| _typeMappingSource.FindMapping(constantExpression.Type) != null)
{
return constantExpression;
}

throw new InvalidOperationException(
$"Client projection contains reference to constant expression of type: {constantExpression.Type.DisplayName()}. " +
"This could potentially cause memory leak.");
}

protected override Expression VisitExtension(Expression extensionExpression)
{
return extensionExpression is EntityShaperExpression
|| extensionExpression is ProjectionBindingExpression
? extensionExpression
: base.VisitExtension(extensionExpression);
}
}

private class EntityMaterializerInjectingExpressionVisitor : ExpressionVisitor
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;
using Microsoft.Extensions.DependencyInjection;

Expand Down Expand Up @@ -54,24 +55,40 @@ public sealed class ShapedQueryCompilingExpressionVisitorDependencies
/// </summary>
[EntityFrameworkInternal]
public ShapedQueryCompilingExpressionVisitorDependencies(
[NotNull] IEntityMaterializerSource entityMaterializerSource)
[NotNull] IEntityMaterializerSource entityMaterializerSource,
[NotNull] ITypeMappingSource typeMappingSource)
{
Check.NotNull(entityMaterializerSource, nameof(entityMaterializerSource));
Check.NotNull(typeMappingSource, nameof(typeMappingSource));

EntityMaterializerSource = entityMaterializerSource;
TypeMappingSource = typeMappingSource;
}

/// <summary>
/// The materializer source.
/// </summary>
public IEntityMaterializerSource EntityMaterializerSource { get; }

/// <summary>
/// The type mapping source.
/// </summary>
public ITypeMappingSource TypeMappingSource { get; }

/// <summary>
/// Clones this dependency parameter object with one service replaced.
/// </summary>
/// <param name="entityMaterializerSource"> A replacement for the current dependency of this type. </param>
/// <returns> A new parameter object with the given service replaced. </returns>
public ShapedQueryCompilingExpressionVisitorDependencies With([NotNull] IEntityMaterializerSource entityMaterializerSource)
=> new ShapedQueryCompilingExpressionVisitorDependencies(entityMaterializerSource);
=> new ShapedQueryCompilingExpressionVisitorDependencies(entityMaterializerSource, TypeMappingSource);

/// <summary>
/// Clones this dependency parameter object with one service replaced.
/// </summary>
/// <param name="typeMappingSource"> A replacement for the current dependency of this type. </param>
/// <returns> A new parameter object with the given service replaced. </returns>
public ShapedQueryCompilingExpressionVisitorDependencies With([NotNull] ITypeMappingSource typeMappingSource)
=> new ShapedQueryCompilingExpressionVisitorDependencies(EntityMaterializerSource, typeMappingSource);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,23 @@ public override Task QueryType_with_included_navs_multi_level(bool isAsync)
{
return base.QueryType_with_included_navs_multi_level(isAsync);
}

[ConditionalTheory(Skip = "Issue#17050")]
public override void Client_code_using_instance_in_static_method()
{
base.Client_code_using_instance_in_static_method();
}

[ConditionalTheory(Skip = "Issue#17050")]
public override void Client_code_using_instance_method_throws()
{
base.Client_code_using_instance_method_throws();
}

[ConditionalTheory(Skip = "Issue#17050")]
public override void Client_code_using_instance_in_anonymous_type()
{
base.Client_code_using_instance_in_anonymous_type();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public virtual void Executes_stored_procedure_with_generated_parameter()
}
}

[ConditionalFact]
[ConditionalFact(Skip = "Issue#17019")]
public virtual void Throws_on_concurrent_command()
{
using (var context = CreateContext())
Expand Down Expand Up @@ -221,7 +221,7 @@ public virtual async Task Executes_stored_procedure_with_generated_parameter_asy
}
}

[ConditionalFact]
[ConditionalFact(Skip = "Issue#17019")]
public virtual async Task Throws_on_concurrent_command_async()
{
using (var context = CreateContext())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7092,7 +7092,7 @@ public virtual Task Multiple_includes_with_client_method_around_entity_and_also_
return Task.CompletedTask;
}

public TEntity Client<TEntity>(TEntity entity) => entity;
public static TEntity Client<TEntity>(TEntity entity) => entity;

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,7 @@ public virtual Task GroupBy_empty_key_Aggregate(bool isAsync)
.Select(g => g.Sum(o => o.OrderID)));
}

[ConditionalTheory]
[ConditionalTheory(Skip = "Issue#17048")]
[MemberData(nameof(IsAsyncData))]
public virtual Task GroupBy_empty_key_Aggregate_Key(bool isAsync)
{
Expand Down
40 changes: 37 additions & 3 deletions test/EFCore.Specification.Tests/Query/SimpleQueryTestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3683,7 +3683,7 @@ orderby o.OrderID
}
}

[ConditionalFact(Skip = "Deadlock")]
[ConditionalFact(Skip = "Issue#17019")]
public virtual void Throws_on_concurrent_query_list()
{
using (var context = CreateContext())
Expand Down Expand Up @@ -3719,7 +3719,7 @@ public virtual void Throws_on_concurrent_query_list()
}
}

[ConditionalFact(Skip = "Deadlock")]
[ConditionalFact(Skip = "Issue#17019")]
public virtual void Throws_on_concurrent_query_first()
{
using (var context = CreateContext())
Expand Down Expand Up @@ -4290,7 +4290,7 @@ public virtual Task Select_expression_int_to_string(bool isAsync)
e => e.ShipName);
}

[ConditionalTheory]
[ConditionalTheory(Skip = "Issue#17048")]
[MemberData(nameof(IsAsyncData))]
public virtual async Task ToString_with_formatter_is_evaluated_on_the_client(bool isAsync)
{
Expand Down Expand Up @@ -6000,5 +6000,39 @@ public virtual Task Navigation_inside_interpolated_string_is_expanded(bool isAsy
isAsync,
os => os.Select(o => $"CustomerCity:{o.Customer.City}"));
}

[ConditionalFact]
public virtual void Client_code_using_instance_method_throws()
{
using (var context = CreateContext())
{
Assert.Throws<InvalidOperationException>(
() => context.Customers.Select(c => InstanceMethod(c)).ToList());
}
}

private string InstanceMethod(Customer c) => c.City;

[ConditionalFact]
public virtual void Client_code_using_instance_in_static_method()
{
using (var context = CreateContext())
{
Assert.Throws<InvalidOperationException>(
() => context.Customers.Select(c => StaticMethod(this, c)).ToList());
}
}

private static string StaticMethod(SimpleQueryTestBase<TFixture> containingClass, Customer c) => c.City;

[ConditionalFact]
public virtual void Client_code_using_instance_in_anonymous_type()
{
using (var context = CreateContext())
{
Assert.Throws<InvalidOperationException>(
() => context.Customers.Select(c => new { A = this }).ToList());
}
}
}
}

0 comments on commit edc4d9a

Please sign in to comment.