Skip to content

Commit

Permalink
Feature/enumerable count support (#33)
Browse files Browse the repository at this point in the history
* feat: added support of count call on related entity

Co-authored-by: Ilya Belyanskiy <win7user20@gmail.com>
  • Loading branch information
win7user10 and Ilya Belyanskiy committed Mar 2, 2022
1 parent 57f0b42 commit 44c9a38
Show file tree
Hide file tree
Showing 51 changed files with 687 additions and 135 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services;
using Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
using Laraue.EfCoreTriggers.Common.SqlGeneration;
using Laraue.EfCoreTriggers.Common.TriggerBuilders;

namespace Laraue.EfCoreTriggers.Common.Converters.MethodCall.Enumerable
{
/// <summary>
/// Base visitor for <see cref="IEnumerable{T}"/> extensions.
/// </summary>
public abstract class BaseEnumerableVisitor : BaseMethodCallVisitor
{
/// <inheritdoc />
protected override Type ReflectedType => typeof(System.Linq.Enumerable);

private readonly IDbSchemaRetriever _schemaRetriever;
private readonly ISqlGenerator _sqlGenerator;
private readonly IExpressionVisitorFactory _expressionVisitorFactory;

/// <inheritdoc />
protected BaseEnumerableVisitor(
IExpressionVisitorFactory visitorFactory,
IDbSchemaRetriever schemaRetriever,
ISqlGenerator sqlGenerator)
: base(visitorFactory)
{
_schemaRetriever = schemaRetriever;
_sqlGenerator = sqlGenerator;
_expressionVisitorFactory = visitorFactory;
}

public override SqlBuilder Visit(MethodCallExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
var whereExpressions = new HashSet<Expression>();
var exp = GetFlattenExpressions(expression, whereExpressions);

var baseMember = exp as MemberExpression;
var enumerableMemberType = baseMember.Type;

var originalSetMemberExpression = (ParameterExpression) baseMember.Expression;
var originalSetType = originalSetMemberExpression?.Type;

if (!typeof(IEnumerable).IsAssignableFrom(enumerableMemberType) || !enumerableMemberType.IsGenericType)
{
throw new NotSupportedException($"Don't know how to translate expression {expression}");
}

var entityType = baseMember.Type.GetGenericArguments()[0];

var otherArguments = expression.Arguments
.Skip(1)
.ToArray();

var selectSql = Visit(otherArguments, argumentTypes, visitedMembers);

if (selectSql.Item2 is not null)
{
whereExpressions.Add(selectSql.Item2);
}

var finalSql = SqlBuilder.FromString("(");

finalSql.WithIdent(x=> x
.Append("SELECT ")
.Append(selectSql.Item1)
.AppendNewLine($"FROM {_schemaRetriever.GetTableName(entityType)}")
.AppendNewLine($"INNER JOIN {_schemaRetriever.GetTableName(originalSetType)} ON "));

var keys = _schemaRetriever.GetForeignKeyMembers(entityType, originalSetType);

var joinParts = new List<string>();
foreach (var key in keys)
{
var column1Sql = _sqlGenerator.GetColumnSql(entityType, key.ForeignKey, ArgumentType.Default);

var argument2Type = argumentTypes.Get(originalSetMemberExpression);

var column2WhereSql = _sqlGenerator.GetVariableSql(originalSetType, key.PrincipalKey, argument2Type);
visitedMembers.AddMember(argument2Type, key.PrincipalKey);

var column2JoinSql = _sqlGenerator.GetColumnSql(originalSetType, key.PrincipalKey, ArgumentType.Default);

joinParts.Add($"{column1Sql} = {column2JoinSql}");
joinParts.Add($"{column1Sql} = {column2WhereSql}");
}

foreach (var e in whereExpressions)
{
joinParts.Add(_expressionVisitorFactory.Visit(e, argumentTypes, visitedMembers));
}

finalSql.AppendJoin(" AND ", joinParts);

finalSql.Append(")");

return finalSql;
}

private Expression GetFlattenExpressions(MethodCallExpression methodCallExpression, HashSet<Expression> whereExpressions)
{
while (true)
{
if (methodCallExpression.Arguments[0] is not MethodCallExpression childCall ||
childCall.Method.Name != nameof(System.Linq.Enumerable.Where))
{
return methodCallExpression.Arguments[0];
}

whereExpressions.Add(childCall.Arguments[1]);

methodCallExpression = childCall;
}
}

protected abstract (SqlBuilder, Expression) Visit(
Expression[] arguments,
ArgumentTypes argumentTypes,
VisitedMembers visitedMembers);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services;
using Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
using Laraue.EfCoreTriggers.Common.SqlGeneration;
using Laraue.EfCoreTriggers.Common.TriggerBuilders;

namespace Laraue.EfCoreTriggers.Common.Converters.MethodCall.Enumerable.Count;

public class CountVisitor : BaseEnumerableVisitor
{
protected override string MethodName => nameof(System.Linq.Enumerable.Count);

private readonly IExpressionVisitorFactory _expressionVisitorFactory;

public CountVisitor(
IExpressionVisitorFactory visitorFactory,
IDbSchemaRetriever schemaRetriever,
ISqlGenerator sqlGenerator)
: base(visitorFactory, schemaRetriever, sqlGenerator)
{
_expressionVisitorFactory = visitorFactory;
}

protected override (SqlBuilder, Expression) Visit(Expression[] arguments, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
return (SqlBuilder.FromString("count(*)"), arguments.FirstOrDefault());
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Converters.MethodCall;
using Laraue.EfCoreTriggers.Common.Converters.MethodCall.Enumerable.Count;
using Laraue.EfCoreTriggers.Common.Services;
using Laraue.EfCoreTriggers.Common.Services.Impl;
using Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
Expand Down Expand Up @@ -78,6 +79,7 @@ public static IServiceCollection AddDefaultServices(this IServiceCollection serv
.AddScoped<IMemberInfoVisitor<LambdaExpression>, SetLambdaExpressionVisitor>()
.AddScoped<IMemberInfoVisitor<MemberInitExpression>, SetMemberInitExpressionVisitor>()
.AddScoped<IMemberInfoVisitor<NewExpression>, SetNewExpressionVisitor>()
.AddScoped<IMemberInfoVisitor<BinaryExpression>, SetBinaryExpressionVisitor>()

.AddScoped<IDbSchemaRetriever, EfCoreDbSchemaRetriever>()

Expand All @@ -86,6 +88,9 @@ public static IServiceCollection AddDefaultServices(this IServiceCollection serv
.AddExpressionVisitor<MemberExpression, MemberExpressionVisitor>()
.AddExpressionVisitor<ConstantExpression, ConstantExpressionVisitor>()
.AddExpressionVisitor<MethodCallExpression, MethodCallExpressionVisitor>()
.AddExpressionVisitor<LambdaExpression, LambdaExpressionVisitor>()

.AddMethodCallConverter<CountVisitor>()

.AddScoped<IUpdateExpressionVisitor, UpdateExpressionVisitor>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,9 @@
</PackageReference>
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="6.0.2" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="JetBrains.Annotations" Version="2021.3.0" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ public interface IDbSchemaRetriever
/// <param name="type"></param>
/// <returns></returns>
PropertyInfo[] GetPrimaryKeyMembers(Type type);

KeyInfo[] GetForeignKeyMembers(Type type, Type type2);
}
6 changes: 6 additions & 0 deletions src/Laraue.EfCoreTriggers.Common/Services/ISqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,10 @@ public interface ISqlGenerator
/// </summary>
/// <returns></returns>
char GetDelimiter();

/// <summary>
///
/// </summary>
/// <returns></returns>
string GetVariableSql(Type type, MemberInfo member, ArgumentType argumentType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;

namespace Laraue.EfCoreTriggers.Common.Services.Impl;

Expand Down Expand Up @@ -116,4 +117,28 @@ public PropertyInfo[] GetPrimaryKeyMembers(Type type)
.Select(x => x.PropertyInfo)
.ToArray();
}

public KeyInfo[] GetForeignKeyMembers(Type type, Type type2)
{
var entityType = Model.FindEntityType(type);

var outerForeignKey = entityType.GetForeignKeys()
.FirstOrDefault(x => x.PrincipalEntityType.ClrType == type2);

var outerKey = outerForeignKey
.PrincipalKey
.Properties;

var innerKey = outerForeignKey.Properties;

var keys = outerKey.Zip(innerKey)
.Select(x => new KeyInfo
{
PrincipalKey = x.First.PropertyInfo,
ForeignKey = x.Second.PropertyInfo,
})
.ToArray();

return keys;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,29 @@ public SqlBuilder Visit(Expression expression, ArgumentTypes argumentTypes, Visi
MethodCallExpression methodCall => Visit(methodCall, argumentTypes, visitedMembers),
UnaryExpression unary => Visit(unary, argumentTypes, visitedMembers),
NewExpression @new => Visit(@new, argumentTypes, visitedMembers),
LambdaExpression lambda => Visit(lambda, argumentTypes, visitedMembers),
_ => throw new NotSupportedException($"Expression of type {expression.GetType()} is not supported")
};
}

private SqlBuilder Visit<TExpression>(TExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
where TExpression : Expression
{
var visitor = GetExpressionVisitor<TExpression>();

return visitor.Visit(expression, argumentTypes, visitedMembers);
}

public IExpressionVisitor<TExpression> GetExpressionVisitor<TExpression>()
where TExpression : Expression
{
var visitor = _provider.GetService<IExpressionVisitor<TExpression>>();

if (visitor is null)
{
throw new NotSupportedException("Not supported " + typeof(TExpression));
}
return visitor.Visit(expression, argumentTypes, visitedMembers);

return visitor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,12 @@ SqlBuilder Visit(
Expression expression,
ArgumentTypes argumentTypes,
VisitedMembers visitedMembers);

/// <summary>
/// Get expression visitor for the passed expression type.
/// </summary>
/// <typeparam name="TExpression"></typeparam>
/// <returns></returns>
IExpressionVisitor<TExpression> GetExpressionVisitor<TExpression>()
where TExpression : Expression;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services.Impl.SetExpressionVisitors;
using Laraue.EfCoreTriggers.Common.SqlGeneration;
using Laraue.EfCoreTriggers.Common.TriggerBuilders;

namespace Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;

public class LambdaExpressionVisitor : BaseExpressionVisitor<LambdaExpression>
{
private readonly IExpressionVisitorFactory _factory;

public LambdaExpressionVisitor(IExpressionVisitorFactory factory)
{
_factory = factory;
}

public override SqlBuilder Visit(LambdaExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
return _factory.Visit(expression.Body, argumentTypes, visitedMembers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ public MemberExpressionVisitor(ISqlGenerator generator)
public override SqlBuilder Visit(MemberExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
argumentTypes ??= new ArgumentTypes();
var parameterExpression = (ParameterExpression)expression.Expression;
var memberName = parameterExpression.Name;
if (!argumentTypes.TryGetValue(memberName, out var argumentType))
{
argumentType = ArgumentType.Default;
}

var argumentType = argumentTypes.Get((ParameterExpression) expression.Expression);

visitedMembers.AddMember(argumentType, expression.Member);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,32 @@ namespace Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
/// <inheritdoc />
public class MethodCallExpressionVisitor : BaseExpressionVisitor<MethodCallExpression>
{
private readonly IMethodCallVisitor[] _converters;
private readonly IMethodCallVisitor[] _visitors;

/// <inheritdoc />
public MethodCallExpressionVisitor(IEnumerable<IMethodCallVisitor> methodCallVisitors)
{
_converters = methodCallVisitors.Reverse().ToArray();
_visitors = methodCallVisitors.Reverse().ToArray();
}

/// <inheritdoc />
/// <inheritdoc />s
public override SqlBuilder Visit(MethodCallExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
foreach (var converter in _converters)
var visitor = GetVisitor(expression);

return visitor.Visit(expression, argumentTypes, visitedMembers);
}

public IMethodCallVisitor GetVisitor(MethodCallExpression expression)
{
foreach (var converter in _visitors)
{
if (converter.IsApplicable(expression))
{
return converter.Visit(expression, argumentTypes, visitedMembers);
return converter;
}
}

throw new NotSupportedException($"Expression {expression.Method.Name} is not supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public Dictionary<MemberInfo, SqlBuilder> Visit(Expression expression, ArgumentT
LambdaExpression lambdaExpression => Visit(lambdaExpression, argumentTypes, visitedMembers),
MemberInitExpression memberInitExpression => Visit(memberInitExpression, argumentTypes, visitedMembers),
NewExpression newExpression => Visit(newExpression, argumentTypes, visitedMembers),
BinaryExpression binaryExpression => Visit(binaryExpression, argumentTypes, visitedMembers),
_ => throw new NotSupportedException($"Expression of type {expression.GetType()} is not supported")
};
}
Expand Down
Loading

0 comments on commit 44c9a38

Please sign in to comment.