Skip to content

Commit

Permalink
Add support for private and protected methods in DbFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
pmiddleton authored and smitpatel committed Oct 19, 2017
1 parent 746ff59 commit 843929d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 46 deletions.
19 changes: 13 additions & 6 deletions src/EFCore.Relational/Infrastructure/RelationalModelCustomizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,20 @@ protected virtual void FindDbFunctions([NotNull] ModelBuilder modelBuilder, [Not
Check.NotNull(modelBuilder, nameof(modelBuilder));
Check.NotNull(context, nameof(context));

var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.FlattenHierarchy)
.Where(mi => mi.IsPublic
&& mi.GetCustomAttributes(typeof(DbFunctionAttribute)).Any());
var contextType = context.GetType();

foreach (var function in functions)
{
modelBuilder.HasDbFunction(function);
while(contextType != typeof(DbContext))
{
var functions = contextType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance
| BindingFlags.Static | BindingFlags.DeclaredOnly)
.Where(mi => mi.GetCustomAttributes(typeof(DbFunctionAttribute)).Any());

foreach (var function in functions)
{
modelBuilder.HasDbFunction(function);
}

contextType = contextType.BaseType;
}
}

Expand Down
146 changes: 106 additions & 40 deletions test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ public static int DuplicateNameTest()

public class MyBaseContext : DbContext
{
public static readonly string[] FunctionNames =
{
nameof(MyBaseContext.StaticPublicBase),
nameof(MyBaseContext.StaticProtectedBase),
nameof(MyBaseContext.StaticPrivateBase),
nameof(MyBaseContext.StaticInteranlBase),
nameof(MyBaseContext.StaticProtectedInteralBase),
nameof(MyBaseContext.InstancePublicBase),
nameof(MyBaseContext.InstanceProtectedBase),
nameof(MyBaseContext.InstancePrivateBase),
nameof(MyBaseContext.InstanceInteranlBase),
nameof(MyBaseContext.InstanceProtectedInteralBase),
};

public static void Foo()
{
}
Expand All @@ -42,22 +56,57 @@ public static void Skip2()
private static void Skip()
{
}

[DbFunction]
public static int StaticBase()
{
throw new Exception();
}
public static int StaticPublicBase() => throw new Exception();

[DbFunction]
public int NonStaticBase()
{
throw new Exception();
}
protected static int StaticProtectedBase() => throw new Exception();

[DbFunction]
private static int StaticPrivateBase() => throw new Exception();

[DbFunction]
internal static int StaticInteranlBase() => throw new Exception();

[DbFunction]
protected internal static int StaticProtectedInteralBase() => throw new Exception();

[DbFunction]
public int InstancePublicBase() => throw new Exception();

[DbFunction]
protected int InstanceProtectedBase() => throw new Exception();

[DbFunction]
private int InstancePrivateBase() => throw new Exception();

[DbFunction]
internal int InstanceInteranlBase() => throw new Exception();

[DbFunction]
protected internal int InstanceProtectedInteralBase() => throw new Exception();

[DbFunction]
public virtual int VirtualBase() => throw new Exception();
}

public class MyDerivedContext : MyBaseContext
{
public new static readonly string[] FunctionNames =
{
nameof(MyDerivedContext.StaticPublicDerived),
nameof(MyDerivedContext.StaticProtectedDerived),
nameof(MyDerivedContext.StaticPrivateDerived),
nameof(MyDerivedContext.StaticInteranlDerived),
nameof(MyDerivedContext.StaticProtectedInteralDerived),
nameof(MyDerivedContext.InstancePublicDerived),
nameof(MyDerivedContext.InstanceProtectedDerived),
nameof(MyDerivedContext.InstancePrivateDerived),
nameof(MyDerivedContext.InstanceInteranlDerived),
nameof(MyDerivedContext.InstanceProtectedInteralDerived),
};

public static void Bar()
{
}
Expand All @@ -70,22 +119,43 @@ private static void Skip4()
{
}

[DbFunction]
public static int StaticDerived()
public static int DuplicateNameTest()
{
throw new Exception();
}

[DbFunction]
public int NonStaticDerived()
{
throw new Exception();
}
public static int StaticPublicDerived() => throw new Exception();

public static int DuplicateNameTest()
{
throw new Exception();
}
[DbFunction]
protected static int StaticProtectedDerived() => throw new Exception();

[DbFunction]
private static int StaticPrivateDerived() => throw new Exception();

[DbFunction]
internal static int StaticInteranlDerived() => throw new Exception();

[DbFunction]
protected internal static int StaticProtectedInteralDerived() => throw new Exception();

[DbFunction]
public int InstancePublicDerived() => throw new Exception();

[DbFunction]
protected int InstanceProtectedDerived() => throw new Exception();

[DbFunction]
private int InstancePrivateDerived() => throw new Exception();

[DbFunction]
internal int InstanceInteranlDerived() => throw new Exception();

[DbFunction]
protected internal int InstanceProtectedInteralDerived() => throw new Exception();

[DbFunction]
public override int VirtualBase() => throw new Exception();
}

public static MethodInfo MethodAmi = typeof(TestMethods).GetRuntimeMethod(nameof(TestMethods.MethodA), new[] { typeof(string), typeof(int) });
Expand Down Expand Up @@ -132,16 +202,16 @@ public virtual void DbFunctions_with_duplicate_names_and_parameters_on_different
{
var modelBuilder = GetModelBuilder();

var Dup1methodInfo
var dup1methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.DuplicateNameTest), new Type[] { });

var Dup2methodInfo
var dup2methodInfo
= typeof(MyNonDbContext)
.GetRuntimeMethod(nameof(MyNonDbContext.DuplicateNameTest), new Type[] { });

var dbFunc1 = modelBuilder.HasDbFunction(Dup1methodInfo).HasName("Dup1").Metadata;
var dbFunc2 = modelBuilder.HasDbFunction(Dup2methodInfo).HasName("Dup2").Metadata;
var dbFunc1 = modelBuilder.HasDbFunction(dup1methodInfo).HasName("Dup1").Metadata;
var dbFunc2 = modelBuilder.HasDbFunction(dup2methodInfo).HasName("Dup2").Metadata;

Assert.Equal("Dup1", dbFunc1.FunctionName);
Assert.Equal("Dup2", dbFunc2.FunctionName);
Expand All @@ -156,35 +226,31 @@ public virtual void Finds_dbFunctions_on_dbContext()

customizer.Customize(modelBuilder, new MyDerivedContext());

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyBaseContext.NonStaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyBaseContext)
.GetRuntimeMethod(nameof(MyBaseContext.StaticBase), new Type[] { })));

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));
foreach (var function in MyBaseContext.FunctionNames)
{
Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyBaseContext).GetMethod(function, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)));
}

Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { })));
foreach (var function in MyDerivedContext.FunctionNames)
{
Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction(
typeof(MyDerivedContext).GetMethod(function, BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)));
}
}

[Fact]
public virtual void Non_static_function_on_dbcontext_does_not_throw()
{
var modelBuilder = GetModelBuilder();

var methodInfo
var methodInfo
= typeof(MyDerivedContext)
.GetRuntimeMethod(nameof(MyDerivedContext.NonStaticBase), new Type[] { });
.GetRuntimeMethod(nameof(MyDerivedContext.InstancePublicBase), new Type[] { });

var dbFunc = modelBuilder.HasDbFunction(methodInfo).Metadata;

Assert.Equal("NonStaticBase", dbFunc.FunctionName);
Assert.Equal("InstancePublicBase", dbFunc.FunctionName);
Assert.Equal(typeof(int), dbFunc.MethodInfo.ReturnType);
}

Expand Down

0 comments on commit 843929d

Please sign in to comment.