Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/enumerable count support #33

Merged
merged 6 commits into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services;
using Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
using Laraue.EfCoreTriggers.Common.SqlGeneration;
using Laraue.EfCoreTriggers.Common.TriggerBuilders;

namespace Laraue.EfCoreTriggers.Common.Converters.MethodCall.Enumerable
{
/// <summary>
/// Base visitor for <see cref="IEnumerable{T}"/> extensions.
/// </summary>
public abstract class BaseEnumerableVisitor : BaseMethodCallVisitor
{
/// <inheritdoc />
protected override Type ReflectedType => typeof(System.Linq.Enumerable);

private readonly IDbSchemaRetriever _schemaRetriever;
private readonly ISqlGenerator _sqlGenerator;
private readonly IExpressionVisitorFactory _expressionVisitorFactory;

/// <inheritdoc />
protected BaseEnumerableVisitor(
IExpressionVisitorFactory visitorFactory,
IDbSchemaRetriever schemaRetriever,
ISqlGenerator sqlGenerator)
: base(visitorFactory)
{
_schemaRetriever = schemaRetriever;
_sqlGenerator = sqlGenerator;
_expressionVisitorFactory = visitorFactory;
}

public override SqlBuilder Visit(MethodCallExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
var whereExpressions = new HashSet<Expression>();
var exp = GetFlattenExpressions(expression, whereExpressions);

var baseMember = exp as MemberExpression;
var enumerableMemberType = baseMember.Type;

var originalSetMemberExpression = (ParameterExpression) baseMember.Expression;
var originalSetType = originalSetMemberExpression?.Type;

if (!typeof(IEnumerable).IsAssignableFrom(enumerableMemberType) || !enumerableMemberType.IsGenericType)
{
throw new NotSupportedException($"Don't know how to translate expression {expression}");
}

var entityType = baseMember.Type.GetGenericArguments()[0];

var otherArguments = expression.Arguments
.Skip(1)
.ToArray();

var selectSql = Visit(otherArguments, argumentTypes, visitedMembers);

if (selectSql.Item2 is not null)
{
whereExpressions.Add(selectSql.Item2);
}

var finalSql = SqlBuilder.FromString("(");

finalSql.WithIdent(x=> x
.Append("SELECT ")
.Append(selectSql.Item1)
.AppendNewLine($"FROM {_schemaRetriever.GetTableName(entityType)}")
.AppendNewLine($"INNER JOIN {_schemaRetriever.GetTableName(originalSetType)} ON "));

var keys = _schemaRetriever.GetForeignKeyMembers(entityType, originalSetType);

var joinParts = new List<string>();
foreach (var key in keys)
{
var column1Sql = _sqlGenerator.GetColumnSql(entityType, key.ForeignKey, ArgumentType.Default);

var argument2Type = argumentTypes.Get(originalSetMemberExpression);

var column2WhereSql = _sqlGenerator.GetVariableSql(originalSetType, key.PrincipalKey, argument2Type);
visitedMembers.AddMember(argument2Type, key.PrincipalKey);

var column2JoinSql = _sqlGenerator.GetColumnSql(originalSetType, key.PrincipalKey, ArgumentType.Default);

joinParts.Add($"{column1Sql} = {column2JoinSql}");
joinParts.Add($"{column1Sql} = {column2WhereSql}");
}

foreach (var e in whereExpressions)
{
joinParts.Add(_expressionVisitorFactory.Visit(e, argumentTypes, visitedMembers));
}

finalSql.AppendJoin(" AND ", joinParts);

finalSql.Append(")");

return finalSql;
}

private Expression GetFlattenExpressions(MethodCallExpression methodCallExpression, HashSet<Expression> whereExpressions)
{
while (true)
{
if (methodCallExpression.Arguments[0] is not MethodCallExpression childCall ||
childCall.Method.Name != nameof(System.Linq.Enumerable.Where))
{
return methodCallExpression.Arguments[0];
}

whereExpressions.Add(childCall.Arguments[1]);

methodCallExpression = childCall;
}
}

protected abstract (SqlBuilder, Expression) Visit(
Expression[] arguments,
ArgumentTypes argumentTypes,
VisitedMembers visitedMembers);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services;
using Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
using Laraue.EfCoreTriggers.Common.SqlGeneration;
using Laraue.EfCoreTriggers.Common.TriggerBuilders;

namespace Laraue.EfCoreTriggers.Common.Converters.MethodCall.Enumerable.Count;

public class CountVisitor : BaseEnumerableVisitor
{
protected override string MethodName => nameof(System.Linq.Enumerable.Count);

private readonly IExpressionVisitorFactory _expressionVisitorFactory;

public CountVisitor(
IExpressionVisitorFactory visitorFactory,
IDbSchemaRetriever schemaRetriever,
ISqlGenerator sqlGenerator)
: base(visitorFactory, schemaRetriever, sqlGenerator)
{
_expressionVisitorFactory = visitorFactory;
}

protected override (SqlBuilder, Expression) Visit(Expression[] arguments, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
return (SqlBuilder.FromString("count(*)"), arguments.FirstOrDefault());
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Converters.MethodCall;
using Laraue.EfCoreTriggers.Common.Converters.MethodCall.Enumerable.Count;
using Laraue.EfCoreTriggers.Common.Services;
using Laraue.EfCoreTriggers.Common.Services.Impl;
using Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
Expand Down Expand Up @@ -78,6 +79,7 @@ public static IServiceCollection AddDefaultServices(this IServiceCollection serv
.AddScoped<IMemberInfoVisitor<LambdaExpression>, SetLambdaExpressionVisitor>()
.AddScoped<IMemberInfoVisitor<MemberInitExpression>, SetMemberInitExpressionVisitor>()
.AddScoped<IMemberInfoVisitor<NewExpression>, SetNewExpressionVisitor>()
.AddScoped<IMemberInfoVisitor<BinaryExpression>, SetBinaryExpressionVisitor>()

.AddScoped<IDbSchemaRetriever, EfCoreDbSchemaRetriever>()

Expand All @@ -86,6 +88,9 @@ public static IServiceCollection AddDefaultServices(this IServiceCollection serv
.AddExpressionVisitor<MemberExpression, MemberExpressionVisitor>()
.AddExpressionVisitor<ConstantExpression, ConstantExpressionVisitor>()
.AddExpressionVisitor<MethodCallExpression, MethodCallExpressionVisitor>()
.AddExpressionVisitor<LambdaExpression, LambdaExpressionVisitor>()

.AddMethodCallConverter<CountVisitor>()

.AddScoped<IUpdateExpressionVisitor, UpdateExpressionVisitor>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,9 @@
</PackageReference>
<PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="6.0.2" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="JetBrains.Annotations" Version="2021.3.0" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ public interface IDbSchemaRetriever
/// <param name="type"></param>
/// <returns></returns>
PropertyInfo[] GetPrimaryKeyMembers(Type type);

KeyInfo[] GetForeignKeyMembers(Type type, Type type2);
}
6 changes: 6 additions & 0 deletions src/Laraue.EfCoreTriggers.Common/Services/ISqlGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,10 @@ public interface ISqlGenerator
/// </summary>
/// <returns></returns>
char GetDelimiter();

/// <summary>
///
/// </summary>
/// <returns></returns>
string GetVariableSql(Type type, MemberInfo member, ArgumentType argumentType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;

namespace Laraue.EfCoreTriggers.Common.Services.Impl;

Expand Down Expand Up @@ -116,4 +117,28 @@ public PropertyInfo[] GetPrimaryKeyMembers(Type type)
.Select(x => x.PropertyInfo)
.ToArray();
}

public KeyInfo[] GetForeignKeyMembers(Type type, Type type2)
{
var entityType = Model.FindEntityType(type);

var outerForeignKey = entityType.GetForeignKeys()
.FirstOrDefault(x => x.PrincipalEntityType.ClrType == type2);

var outerKey = outerForeignKey
.PrincipalKey
.Properties;

var innerKey = outerForeignKey.Properties;

var keys = outerKey.Zip(innerKey)
.Select(x => new KeyInfo
{
PrincipalKey = x.First.PropertyInfo,
ForeignKey = x.Second.PropertyInfo,
})
.ToArray();

return keys;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,29 @@ public SqlBuilder Visit(Expression expression, ArgumentTypes argumentTypes, Visi
MethodCallExpression methodCall => Visit(methodCall, argumentTypes, visitedMembers),
UnaryExpression unary => Visit(unary, argumentTypes, visitedMembers),
NewExpression @new => Visit(@new, argumentTypes, visitedMembers),
LambdaExpression lambda => Visit(lambda, argumentTypes, visitedMembers),
_ => throw new NotSupportedException($"Expression of type {expression.GetType()} is not supported")
};
}

private SqlBuilder Visit<TExpression>(TExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
where TExpression : Expression
{
var visitor = GetExpressionVisitor<TExpression>();

return visitor.Visit(expression, argumentTypes, visitedMembers);
}

public IExpressionVisitor<TExpression> GetExpressionVisitor<TExpression>()
where TExpression : Expression
{
var visitor = _provider.GetService<IExpressionVisitor<TExpression>>();

if (visitor is null)
{
throw new NotSupportedException("Not supported " + typeof(TExpression));
}
return visitor.Visit(expression, argumentTypes, visitedMembers);

return visitor;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,12 @@ SqlBuilder Visit(
Expression expression,
ArgumentTypes argumentTypes,
VisitedMembers visitedMembers);

/// <summary>
/// Get expression visitor for the passed expression type.
/// </summary>
/// <typeparam name="TExpression"></typeparam>
/// <returns></returns>
IExpressionVisitor<TExpression> GetExpressionVisitor<TExpression>()
where TExpression : Expression;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System;
using System.Linq.Expressions;
using Laraue.EfCoreTriggers.Common.Services.Impl.SetExpressionVisitors;
using Laraue.EfCoreTriggers.Common.SqlGeneration;
using Laraue.EfCoreTriggers.Common.TriggerBuilders;

namespace Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;

public class LambdaExpressionVisitor : BaseExpressionVisitor<LambdaExpression>
{
private readonly IExpressionVisitorFactory _factory;

public LambdaExpressionVisitor(IExpressionVisitorFactory factory)
{
_factory = factory;
}

public override SqlBuilder Visit(LambdaExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
return _factory.Visit(expression.Body, argumentTypes, visitedMembers);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ public MemberExpressionVisitor(ISqlGenerator generator)
public override SqlBuilder Visit(MemberExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
argumentTypes ??= new ArgumentTypes();
var parameterExpression = (ParameterExpression)expression.Expression;
var memberName = parameterExpression.Name;
if (!argumentTypes.TryGetValue(memberName, out var argumentType))
{
argumentType = ArgumentType.Default;
}

var argumentType = argumentTypes.Get((ParameterExpression) expression.Expression);

visitedMembers.AddMember(argumentType, expression.Member);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,32 @@ namespace Laraue.EfCoreTriggers.Common.Services.Impl.ExpressionVisitors;
/// <inheritdoc />
public class MethodCallExpressionVisitor : BaseExpressionVisitor<MethodCallExpression>
{
private readonly IMethodCallVisitor[] _converters;
private readonly IMethodCallVisitor[] _visitors;

/// <inheritdoc />
public MethodCallExpressionVisitor(IEnumerable<IMethodCallVisitor> methodCallVisitors)
{
_converters = methodCallVisitors.Reverse().ToArray();
_visitors = methodCallVisitors.Reverse().ToArray();
}

/// <inheritdoc />
/// <inheritdoc />s
public override SqlBuilder Visit(MethodCallExpression expression, ArgumentTypes argumentTypes, VisitedMembers visitedMembers)
{
foreach (var converter in _converters)
var visitor = GetVisitor(expression);

return visitor.Visit(expression, argumentTypes, visitedMembers);
}

public IMethodCallVisitor GetVisitor(MethodCallExpression expression)
{
foreach (var converter in _visitors)
{
if (converter.IsApplicable(expression))
{
return converter.Visit(expression, argumentTypes, visitedMembers);
return converter;
}
}

throw new NotSupportedException($"Expression {expression.Method.Name} is not supported");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public Dictionary<MemberInfo, SqlBuilder> Visit(Expression expression, ArgumentT
LambdaExpression lambdaExpression => Visit(lambdaExpression, argumentTypes, visitedMembers),
MemberInitExpression memberInitExpression => Visit(memberInitExpression, argumentTypes, visitedMembers),
NewExpression newExpression => Visit(newExpression, argumentTypes, visitedMembers),
BinaryExpression binaryExpression => Visit(binaryExpression, argumentTypes, visitedMembers),
_ => throw new NotSupportedException($"Expression of type {expression.GetType()} is not supported")
};
}
Expand Down
Loading