From b4b299c9db3412f78418f19575c9f3b417132be0 Mon Sep 17 00:00:00 2001 From: Paul Middleton Date: Mon, 4 Sep 2017 20:44:00 -0500 Subject: [PATCH] DBFunctions - Add support for instance methods. --- .../RelationalModelCustomizer.cs | 6 +- .../Metadata/Internal/DbFunction.cs | 8 +- .../Properties/RelationalStrings.Designer.cs | 16 +- .../Properties/RelationalStrings.resx | 6 +- .../SqlTranslatingExpressionVisitor.cs | 4 +- .../Metadata/DbFunctionMetadataTests.cs | 111 +- .../Query/UdfDbFunctionSqlServerTests.cs | 1008 +++++++++++++++-- 7 files changed, 1030 insertions(+), 129 deletions(-) diff --git a/src/EFCore.Relational/Infrastructure/RelationalModelCustomizer.cs b/src/EFCore.Relational/Infrastructure/RelationalModelCustomizer.cs index 796d58e308b..3e321b89702 100644 --- a/src/EFCore.Relational/Infrastructure/RelationalModelCustomizer.cs +++ b/src/EFCore.Relational/Infrastructure/RelationalModelCustomizer.cs @@ -72,10 +72,8 @@ protected virtual void FindDbFunctions([NotNull] ModelBuilder modelBuilder, [Not Check.NotNull(modelBuilder, nameof(modelBuilder)); Check.NotNull(context, nameof(context)); - var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy) - .Where( - mi => mi.IsStatic - && mi.IsPublic + var functions = context.GetType().GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static | BindingFlags.FlattenHierarchy) + .Where(mi => mi.IsPublic && mi.GetCustomAttributes(typeof(DbFunctionAttribute)).Any()); foreach (var function in functions) diff --git a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs index a1130f993d9..cfde55f1b84 100644 --- a/src/EFCore.Relational/Metadata/Internal/DbFunction.cs +++ b/src/EFCore.Relational/Metadata/Internal/DbFunction.cs @@ -68,9 +68,11 @@ private DbFunction( throw new ArgumentException(RelationalStrings.DbFunctionGenericMethodNotSupported(methodInfo.DisplayName())); } - if (!methodInfo.IsStatic) + if (!methodInfo.IsStatic + && !typeof(DbContext).IsAssignableFrom(methodInfo.DeclaringType)) { - throw new ArgumentException(RelationalStrings.DbFunctionMethodMustBeStatic(methodInfo.DisplayName())); + throw new ArgumentException( + RelationalStrings.DbFunctionInvalidInstanceType(methodInfo.DisplayName(), methodInfo.DeclaringType.ShortDisplayName())); } if (methodInfo.ReturnType == null @@ -102,7 +104,7 @@ public static IEnumerable GetDbFunctions([NotNull] IModel model, [N } private static string BuildAnnotationName(string annotationPrefix, MethodBase methodBase) - => $@"{annotationPrefix}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})"; + => $@"{annotationPrefix}{methodBase.DeclaringType.ShortDisplayName()}{methodBase.Name}({string.Join(",", methodBase.GetParameters().Select(p => p.ParameterType.Name))})"; /// /// This API supports the Entity Framework Core infrastructure and is not intended to be used diff --git a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs index 19d036ce91d..9a6ffa82604 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs +++ b/src/EFCore.Relational/Properties/RelationalStrings.Designer.cs @@ -848,14 +848,6 @@ public static string DbFunctionInvalidParameterType([CanBeNull] object parameter GetString("DbFunctionInvalidParameterType", nameof(parameter), nameof(function), nameof(type)), parameter, function, type); - /// - /// The DbFunction '{function}' must be a static method. Non-static methods are not supported. - /// - public static string DbFunctionMethodMustBeStatic([CanBeNull] object function) - => string.Format( - GetString("DbFunctionMethodMustBeStatic", nameof(function)), - function); - /// /// The DbFunction '{function}' is generic. Generic methods are not supported. /// @@ -872,6 +864,14 @@ public static string DbFunctionExpressionIsNotMethodCall([CanBeNull] object expr GetString("DbFunctionExpressionIsNotMethodCall", nameof(expression)), expression); + /// + /// The DbFunction '{function}' defined on type '{type}' must be either a static method or an instance method defined on a DbContext subclass. Instance methods on other types are not supported. + /// + public static string DbFunctionInvalidInstanceType([CanBeNull] object function, [CanBeNull] object type) + => string.Format( + GetString("DbFunctionInvalidInstanceType", nameof(function), nameof(type)), + function, type); + /// /// An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection. /// diff --git a/src/EFCore.Relational/Properties/RelationalStrings.resx b/src/EFCore.Relational/Properties/RelationalStrings.resx index e37cc82e9c4..be7aac66d6f 100644 --- a/src/EFCore.Relational/Properties/RelationalStrings.resx +++ b/src/EFCore.Relational/Properties/RelationalStrings.resx @@ -425,15 +425,15 @@ The parameter '{parameter}' for the DbFunction '{function}' has an invalid type '{type}'. Ensure the parameter type can be mapped by the current provider. - - The DbFunction '{function}' must be a static method. Non-static methods are not supported. - The DbFunction '{function}' is generic. Generic methods are not supported. The provided DbFunction expression '{expression}' is invalid. The expression should be a lambda expression containing a single method call to the target static method. Default values can be provided as arguments if required. E.g. () => SomeClass.SomeMethod(null, 0) + + The DbFunction '{function}' defined on type '{type}' must be either a static method or an instance method defined on a DbContext subclass. Instance methods on other types are not supported. + An ambient transaction has been detected. The ambient transaction needs to be completed before beginning a transaction on this connection. diff --git a/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs index e8d6aefa026..09a40cd027a 100644 --- a/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/ExpressionVisitors/SqlTranslatingExpressionVisitor.cs @@ -613,7 +613,9 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp { Check.NotNull(methodCallExpression, nameof(methodCallExpression)); - var operand = Visit(methodCallExpression.Object); + var operand = _queryModelVisitor.QueryCompilationContext.Model.Relational().FindDbFunction(methodCallExpression.Method) != null + ? methodCallExpression.Object + : Visit(methodCallExpression.Object); if (operand != null || methodCallExpression.Object == null) diff --git a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs index 0791a06bdf0..0b23c8103c5 100644 --- a/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs +++ b/test/EFCore.Relational.Tests/Metadata/DbFunctionMetadataTests.cs @@ -4,6 +4,7 @@ using System; using System.Linq.Expressions; using System.Reflection; +using Microsoft.EntityFrameworkCore.Infrastructure; using Microsoft.EntityFrameworkCore.Internal; using Microsoft.EntityFrameworkCore.Metadata.Conventions; using Microsoft.EntityFrameworkCore.Metadata.Conventions.Internal; @@ -15,9 +16,21 @@ namespace Microsoft.EntityFrameworkCore.Metadata { public class DbFunctionMetadataTests { + public class MyNonDbContext + { + public int NonStatic() + { + throw new Exception(); + } + + public static int DuplicateNameTest() + { + throw new Exception(); + } + } + public class MyBaseContext : DbContext { - [DbFunction] public static void Foo() { } @@ -29,11 +42,22 @@ public static void Skip2() private static void Skip() { } + + [DbFunction] + public static int StaticBase() + { + throw new Exception(); + } + + [DbFunction] + public int NonStaticBase() + { + throw new Exception(); + } } public class MyDerivedContext : MyBaseContext { - [DbFunction] public static void Bar() { } @@ -47,8 +71,20 @@ private static void Skip4() } [DbFunction] - public void NonStatic() + public static int StaticDerived() + { + throw new Exception(); + } + + [DbFunction] + public int NonStaticDerived() + { + throw new Exception(); + } + + public static int DuplicateNameTest() { + throw new Exception(); } } @@ -92,16 +128,77 @@ public static int MethodH(T a, string b) } [Fact] - public virtual void Detects_non_static_function_on_dbcontext() + public virtual void DbFunctions_with_duplicate_names_and_parameters_on_different_types_dont_collide() { var modelBuilder = GetModelBuilder(); - var methodInfo + var Dup1methodInfo = typeof(MyDerivedContext) - .GetRuntimeMethod(nameof(MyDerivedContext.NonStatic), new Type[] { }); + .GetRuntimeMethod(nameof(MyDerivedContext.DuplicateNameTest), new Type[] { }); + + var Dup2methodInfo + = typeof(MyNonDbContext) + .GetRuntimeMethod(nameof(MyNonDbContext.DuplicateNameTest), new Type[] { }); + + var dbFunc1 = modelBuilder.HasDbFunction(Dup1methodInfo).HasName("Dup1").Metadata; + var dbFunc2 = modelBuilder.HasDbFunction(Dup2methodInfo).HasName("Dup2").Metadata; + + Assert.Equal("Dup1", dbFunc1.FunctionName); + Assert.Equal("Dup2", dbFunc2.FunctionName); + } + + [Fact] + public virtual void Finds_dbFunctions_on_dbContext() + { + var modelBuilder = GetModelBuilder(); + + var customizer = new RelationalModelCustomizer(new ModelCustomizerDependencies(new DbSetFinder())); + + customizer.Customize(modelBuilder, new MyDerivedContext()); + + Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction( + typeof(MyDerivedContext) + .GetRuntimeMethod(nameof(MyBaseContext.NonStaticBase), new Type[] { }))); + + Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction( + typeof(MyBaseContext) + .GetRuntimeMethod(nameof(MyBaseContext.StaticBase), new Type[] { }))); + + Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction( + typeof(MyDerivedContext) + .GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { }))); + + Assert.NotNull(modelBuilder.Model.Relational().FindDbFunction( + typeof(MyDerivedContext) + .GetRuntimeMethod(nameof(MyDerivedContext.NonStaticDerived), new Type[] { }))); + } + + [Fact] + public virtual void Non_static_function_on_dbcontext_does_not_throw() + { + var modelBuilder = GetModelBuilder(); + + var methodInfo + = typeof(MyDerivedContext) + .GetRuntimeMethod(nameof(MyDerivedContext.NonStaticBase), new Type[] { }); + + var dbFunc = modelBuilder.HasDbFunction(methodInfo).Metadata; + + Assert.Equal("NonStaticBase", dbFunc.FunctionName); + Assert.Equal(typeof(int), dbFunc.MethodInfo.ReturnType); + } + + [Fact] + public virtual void Non_static_function_on_non_dbcontext_throws() + { + var modelBuilder = GetModelBuilder(); + + var methodInfo + = typeof(MyNonDbContext) + .GetRuntimeMethod(nameof(MyNonDbContext.NonStatic), new Type[] { }); Assert.Equal( - RelationalStrings.DbFunctionMethodMustBeStatic("MyDerivedContext.NonStatic"), + RelationalStrings.DbFunctionInvalidInstanceType(methodInfo.DisplayName(), typeof(MyNonDbContext).ShortDisplayName()), Assert.Throws(() => modelBuilder.HasDbFunction(methodInfo)).Message); } diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs index 61987f53d81..a102fb35c5f 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/UdfDbFunctionSqlServerTests.cs @@ -59,32 +59,32 @@ public enum ReportingPeriod Fall } - public static long MyCustomLength(string s) + public static long MyCustomLengthStatic(string s) { throw new Exception(); } - public static bool IsDate(string date) + public static bool IsDateStatic(string date) { throw new Exception(); } - public static int AddOne(int num) + public static int AddOneStatic(int num) { return num + 1; } - public static int AddFive(int number) + public static int AddFiveStatic(int number) { return number + 5; } - public static int CustomerOrderCount(int customerId) + public static int CustomerOrderCountStatic(int customerId) { throw new NotImplementedException(); } - public static int CustomerOrderCountWithClient(int customerId) + public static int CustomerOrderCountWithClientStatic(int customerId) { switch (customerId) { @@ -101,22 +101,89 @@ public static int CustomerOrderCountWithClient(int customerId) } } - public static string StarValue(int starCount, int value) + public static string StarValueStatic(int starCount, int value) { throw new NotImplementedException(); } - public static bool IsTopCustomer(int customerId) + public static bool IsTopCustomerStatic(int customerId) { throw new NotImplementedException(); } - public static int GetCustomerWithMostOrdersAfterDate(DateTime? startDate) + public static int GetCustomerWithMostOrdersAfterDateStatic(DateTime? startDate) { throw new NotImplementedException(); } - public static DateTime? GetReportingPeriodStartDate(ReportingPeriod periodId) + public static DateTime? GetReportingPeriodStartDateStatic(ReportingPeriod periodId) + { + throw new NotImplementedException(); + } + + public long MyCustomLengthInstance(string s) + { + throw new Exception(); + } + + public bool IsDateInstance(string date) + { + throw new Exception(); + } + + public int AddOneInstance(int num) + { + return num + 1; + } + + public int AddFiveInstance(int number) + { + return number + 5; + } + + public int CustomerOrderCountInstance(int customerId) + { + throw new NotImplementedException(); + } + + public int CustomerOrderCountWithClientInstance(int customerId) + { + switch (customerId) + { + case 1: + return 3; + case 2: + return 2; + case 3: + return 1; + case 4: + return 0; + default: + throw new Exception(); + } + } + + public string StarValueInstance(int starCount, int value) + { + throw new NotImplementedException(); + } + + public bool IsTopCustomerInstance(int customerId) + { + throw new NotImplementedException(); + } + + public int GetCustomerWithMostOrdersAfterDateInstance(DateTime? startDate) + { + throw new NotImplementedException(); + } + + public DateTime? GetReportingPeriodStartDateInstance(ReportingPeriod periodId) + { + throw new NotImplementedException(); + } + + public string DollarValueInstance(int starCount, string value) { throw new NotImplementedException(); } @@ -130,29 +197,48 @@ public UDFSqlContext(DbContextOptions options) protected override void OnModelCreating(ModelBuilder modelBuilder) { - modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCount))); - modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCountWithClient))).HasName("CustomerOrderCount"); - modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(StarValue))); - modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(IsTopCustomer))); - modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerWithMostOrdersAfterDate))); - modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetReportingPeriodStartDate))); - modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(IsDate))).HasSchema(""); + //Static + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCountStatic))).HasName("CustomerOrderCount"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCountWithClientStatic))).HasName("CustomerOrderCount"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(StarValueStatic))).HasName("StarValue"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(IsTopCustomerStatic))).HasName("IsTopCustomer"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerWithMostOrdersAfterDateStatic))).HasName("GetCustomerWithMostOrdersAfterDate"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetReportingPeriodStartDateStatic))).HasName("GetReportingPeriodStartDate"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(IsDateStatic))).HasSchema("").HasName("IsDate"); - var methodInfo = typeof(UDFSqlContext).GetMethod(nameof(MyCustomLength)); + var methodInfo = typeof(UDFSqlContext).GetMethod(nameof(MyCustomLengthStatic)); modelBuilder.HasDbFunction(methodInfo) .HasTranslation(args => new SqlFunctionExpression("len", methodInfo.ReturnType, args)); + + //Instance + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCountInstance))).HasName("CustomerOrderCount"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(CustomerOrderCountWithClientInstance))).HasName("CustomerOrderCount"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(StarValueInstance))).HasName("StarValue"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(IsTopCustomerInstance))).HasName("IsTopCustomer"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetCustomerWithMostOrdersAfterDateInstance))).HasName("GetCustomerWithMostOrdersAfterDate"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(GetReportingPeriodStartDateInstance))).HasName("GetReportingPeriodStartDate"); + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(IsDateInstance))).HasSchema("").HasName("IsDate"); + + modelBuilder.HasDbFunction(typeof(UDFSqlContext).GetMethod(nameof(DollarValueInstance))).HasName("DollarValue"); + + var methodInfo2 = typeof(UDFSqlContext).GetMethod(nameof(MyCustomLengthInstance)); + + modelBuilder.HasDbFunction(methodInfo2) + .HasTranslation(args => new SqlFunctionExpression("len", methodInfo2.ReturnType, args)); } } #region Scalar Tests + #region Static + [Fact] - private void Scalar_Function_Extension_Method() + private void Scalar_Function_Extension_Method_Static() { using (var context = CreateContext()) { - var len = context.Customers.Count(c => UDFSqlContext.IsDate(c.FirstName) == false); + var len = context.Customers.Count(c => UDFSqlContext.IsDateStatic(c.FirstName) == false); Assert.Equal(3, len); @@ -169,14 +255,14 @@ THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT) } [Fact] - private void Scalar_Function_With_Translator_Translates() + private void Scalar_Function_With_Translator_Translates_Static() { using (var context = CreateContext()) { var customerId = 3; var len = context.Customers.Where(c => c.Id == customerId) - .Select(c => UDFSqlContext.MyCustomLength(c.LastName)).Single(); + .Select(c => UDFSqlContext.MyCustomLengthStatic(c.LastName)).Single(); Assert.Equal(5, len); @@ -192,7 +278,7 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_ClientEval_Method_As_Translateable_Method_Parameter() + public void Scalar_Function_ClientEval_Method_As_Translateable_Method_Parameter_Static() { using (var context = CreateContext()) { @@ -202,19 +288,19 @@ public void Scalar_Function_ClientEval_Method_As_Translateable_Method_Parameter( select new { c.FirstName, - OrderCount = UDFSqlContext.CustomerOrderCount(UDFSqlContext.AddFive(c.Id - 5)) + OrderCount = UDFSqlContext.CustomerOrderCountStatic(UDFSqlContext.AddFiveStatic(c.Id - 5)) }).Single()); } } [Fact] - public void Scalar_Function_Constant_Parameter() + public void Scalar_Function_Constant_Parameter_Static() { using (var context = CreateContext()) { var customerId = 1; - var custs = context.Customers.Select(c => UDFSqlContext.CustomerOrderCount(customerId)).ToList(); + var custs = context.Customers.Select(c => UDFSqlContext.CustomerOrderCountStatic(customerId)).ToList(); Assert.Equal(3, custs.Count); @@ -229,7 +315,7 @@ SELECT [dbo].CustomerOrderCount(@__customerId_0) } [Fact] - public void Scalar_Function_Anonymous_Type_Select_Correlated() + public void Scalar_Function_Anonymous_Type_Select_Correlated_Static() { using (var context = CreateContext()) { @@ -238,7 +324,7 @@ public void Scalar_Function_Anonymous_Type_Select_Correlated() select new { c.LastName, - OrderCount = UDFSqlContext.CustomerOrderCount(c.Id) + OrderCount = UDFSqlContext.CustomerOrderCountStatic(c.Id) }).Single(); Assert.Equal("One", cust.LastName); @@ -254,7 +340,7 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Anonymous_Type_Select_Not_Correlated() + public void Scalar_Function_Anonymous_Type_Select_Not_Correlated_Static() { using (var context = CreateContext()) { @@ -263,7 +349,7 @@ public void Scalar_Function_Anonymous_Type_Select_Not_Correlated() select new { c.LastName, - OrderCount = UDFSqlContext.CustomerOrderCount(1) + OrderCount = UDFSqlContext.CustomerOrderCountStatic(1) }).Single(); Assert.Equal("One", cust.LastName); @@ -279,7 +365,7 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Anonymous_Type_Select_Parameter() + public void Scalar_Function_Anonymous_Type_Select_Parameter_Static() { using (var context = CreateContext()) { @@ -290,7 +376,7 @@ public void Scalar_Function_Anonymous_Type_Select_Parameter() select new { c.LastName, - OrderCount = UDFSqlContext.CustomerOrderCount(customerId) + OrderCount = UDFSqlContext.CustomerOrderCountStatic(customerId) }).Single(); Assert.Equal("One", cust.LastName); @@ -309,7 +395,7 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Anonymous_Type_Select_Nested() + public void Scalar_Function_Anonymous_Type_Select_Nested_Static() { using (var context = CreateContext()) { @@ -321,7 +407,7 @@ public void Scalar_Function_Anonymous_Type_Select_Nested() select new { c.LastName, - OrderCount = UDFSqlContext.StarValue(starCount, UDFSqlContext.CustomerOrderCount(customerId)) + OrderCount = UDFSqlContext.StarValueStatic(starCount, UDFSqlContext.CustomerOrderCountStatic(customerId)) }).Single(); Assert.Equal("Three", cust.LastName); @@ -341,12 +427,12 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Where_Correlated() + public void Scalar_Function_Where_Correlated_Static() { using (var context = CreateContext()) { var cust = (from c in context.Customers - where UDFSqlContext.IsTopCustomer(c.Id) + where UDFSqlContext.IsTopCustomerStatic(c.Id) select c.Id.ToString().ToLower()).ToList(); Assert.Equal(1, cust.Count); @@ -361,14 +447,14 @@ WHERE [dbo].IsTopCustomer([c].[Id]) = 1", } [Fact] - public void Scalar_Function_Where_Not_Correlated() + public void Scalar_Function_Where_Not_Correlated_Static() { using (var context = CreateContext()) { var startDate = DateTime.Parse("4/1/2000"); var custId = (from c in context.Customers - where UDFSqlContext.GetCustomerWithMostOrdersAfterDate(startDate) == c.Id + where UDFSqlContext.GetCustomerWithMostOrdersAfterDateStatic(startDate) == c.Id select c.Id).SingleOrDefault(); Assert.Equal(custId, 2); @@ -385,15 +471,15 @@ WHERE [dbo].GetCustomerWithMostOrdersAfterDate(@__startDate_0) = [c].[Id]", } [Fact] - public void Scalar_Function_Where_Parameter() + public void Scalar_Function_Where_Parameter_Static() { using (var context = CreateContext()) { var period = UDFSqlContext.ReportingPeriod.Winter; var custId = (from c in context.Customers - where c.Id == UDFSqlContext.GetCustomerWithMostOrdersAfterDate( - UDFSqlContext.GetReportingPeriodStartDate(period)) + where c.Id == UDFSqlContext.GetCustomerWithMostOrdersAfterDateStatic( + UDFSqlContext.GetReportingPeriodStartDateStatic(period)) select c.Id).SingleOrDefault(); Assert.Equal(custId, 1); @@ -410,13 +496,13 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Where_Nested() + public void Scalar_Function_Where_Nested_Static() { using (var context = CreateContext()) { var custId = (from c in context.Customers - where c.Id == UDFSqlContext.GetCustomerWithMostOrdersAfterDate( - UDFSqlContext.GetReportingPeriodStartDate( + where c.Id == UDFSqlContext.GetCustomerWithMostOrdersAfterDateStatic( + UDFSqlContext.GetReportingPeriodStartDateStatic( UDFSqlContext.ReportingPeriod.Winter)) select c.Id).SingleOrDefault(); @@ -432,12 +518,12 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Let_Correlated() + public void Scalar_Function_Let_Correlated_Static() { using (var context = CreateContext()) { var cust = (from c in context.Customers - let orderCount = UDFSqlContext.CustomerOrderCount(c.Id) + let orderCount = UDFSqlContext.CustomerOrderCountStatic(c.Id) where c.Id == 2 select new { @@ -458,12 +544,12 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Let_Not_Correlated() + public void Scalar_Function_Let_Not_Correlated_Static() { using (var context = CreateContext()) { var cust = (from c in context.Customers - let orderCount = UDFSqlContext.CustomerOrderCount(2) + let orderCount = UDFSqlContext.CustomerOrderCountStatic(2) where c.Id == 2 select new { @@ -484,14 +570,14 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Let_Not_Parameter() + public void Scalar_Function_Let_Not_Parameter_Static() { var customerId = 2; using (var context = CreateContext()) { var cust = (from c in context.Customers - let orderCount = UDFSqlContext.CustomerOrderCount(customerId) + let orderCount = UDFSqlContext.CustomerOrderCountStatic(customerId) where c.Id == customerId select new { @@ -515,7 +601,7 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Function_Let_Nested() + public void Scalar_Function_Let_Nested_Static() { using (var context = CreateContext()) { @@ -523,7 +609,7 @@ public void Scalar_Function_Let_Nested() var starCount = 3; var cust = (from c in context.Customers - let orderCount = UDFSqlContext.StarValue(starCount, UDFSqlContext.CustomerOrderCount(customerId)) + let orderCount = UDFSqlContext.StarValueStatic(starCount, UDFSqlContext.CustomerOrderCountStatic(customerId)) where c.Id == customerId select new { @@ -548,12 +634,12 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Nested_Function_Unwind_Client_Eval_Where() + public void Scalar_Nested_Function_Unwind_Client_Eval_Where_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 2 == UDFSqlContext.AddOne(c.Id) + where 2 == UDFSqlContext.AddOneStatic(c.Id) select c.Id).Single(); Assert.Equal(1, results); @@ -566,12 +652,12 @@ public void Scalar_Nested_Function_Unwind_Client_Eval_Where() } [Fact] - public void Scalar_Nested__Function_Unwind_Client_Eval_OrderBy() + public void Scalar_Nested__Function_Unwind_Client_Eval_OrderBy_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - orderby UDFSqlContext.AddOne(c.Id) + orderby UDFSqlContext.AddOneStatic(c.Id) select c.Id).ToList(); Assert.Equal(3, results.Count); @@ -586,13 +672,13 @@ orderby UDFSqlContext.AddOne(c.Id) } [Fact] - public void Scalar_Nested_Function_Unwind_Client_Eval_Select() + public void Scalar_Nested_Function_Unwind_Client_Eval_Select_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers orderby c.Id - select UDFSqlContext.AddOne(c.Id)).ToList(); + select UDFSqlContext.AddOneStatic(c.Id)).ToList(); Assert.Equal(3, results.Count); Assert.True(results.SequenceEqual(Enumerable.Range(2, 3))); @@ -607,12 +693,12 @@ FROM [Customers] AS [c] } [Fact] - public void Scalar_Nested_Function_Client_BCL_UDF() + public void Scalar_Nested_Function_Client_BCL_UDF_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 2 == UDFSqlContext.AddOne(Math.Abs(UDFSqlContext.CustomerOrderCountWithClient(c.Id))) + where 2 == UDFSqlContext.AddOneStatic(Math.Abs(UDFSqlContext.CustomerOrderCountWithClientStatic(c.Id))) select c.Id).Single(); Assert.Equal(3, results); @@ -625,12 +711,12 @@ public void Scalar_Nested_Function_Client_BCL_UDF() } [Fact] - public void Scalar_Nested_Function_Client_UDF_BCL() + public void Scalar_Nested_Function_Client_UDF_BCL_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 2 == UDFSqlContext.AddOne(UDFSqlContext.CustomerOrderCountWithClient(Math.Abs(c.Id))) + where 2 == UDFSqlContext.AddOneStatic(UDFSqlContext.CustomerOrderCountWithClientStatic(Math.Abs(c.Id))) select c.Id).Single(); Assert.Equal(3, results); @@ -643,12 +729,12 @@ public void Scalar_Nested_Function_Client_UDF_BCL() } [Fact] - public void Scalar_Nested_Function_BCL_Client_UDF() + public void Scalar_Nested_Function_BCL_Client_UDF_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 2 == Math.Abs(UDFSqlContext.AddOne(UDFSqlContext.CustomerOrderCountWithClient(c.Id))) + where 2 == Math.Abs(UDFSqlContext.AddOneStatic(UDFSqlContext.CustomerOrderCountWithClientStatic(c.Id))) select c.Id).Single(); Assert.Equal(3, results); @@ -661,12 +747,12 @@ public void Scalar_Nested_Function_BCL_Client_UDF() } [Fact] - public void Scalar_Nested_Function_BCL_UDF_Client() + public void Scalar_Nested_Function_BCL_UDF_Client_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 1 == Math.Abs(UDFSqlContext.CustomerOrderCountWithClient(UDFSqlContext.AddOne(c.Id))) + where 1 == Math.Abs(UDFSqlContext.CustomerOrderCountWithClientStatic(UDFSqlContext.AddOneStatic(c.Id))) select c.Id).Single(); Assert.Equal(2, results); @@ -679,12 +765,12 @@ public void Scalar_Nested_Function_BCL_UDF_Client() } [Fact] - public void Scalar_Nested_Function_UDF_BCL_Client() + public void Scalar_Nested_Function_UDF_BCL_Client_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 1 == UDFSqlContext.CustomerOrderCountWithClient(Math.Abs(UDFSqlContext.AddOne(c.Id))) + where 1 == UDFSqlContext.CustomerOrderCountWithClientStatic(Math.Abs(UDFSqlContext.AddOneStatic(c.Id))) select c.Id).Single(); Assert.Equal(2, results); @@ -697,12 +783,12 @@ public void Scalar_Nested_Function_UDF_BCL_Client() } [Fact] - public void Scalar_Nested_Function_UDF_Client_BCL() + public void Scalar_Nested_Function_UDF_Client_BCL_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 1 == UDFSqlContext.CustomerOrderCountWithClient(UDFSqlContext.AddOne(Math.Abs(c.Id))) + where 1 == UDFSqlContext.CustomerOrderCountWithClientStatic(UDFSqlContext.AddOneStatic(Math.Abs(c.Id))) select c.Id).Single(); Assert.Equal(2, results); @@ -715,12 +801,12 @@ public void Scalar_Nested_Function_UDF_Client_BCL() } [Fact] - public void Scalar_Nested_Function_Client_BCL() + public void Scalar_Nested_Function_Client_BCL_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 3 == UDFSqlContext.AddOne(Math.Abs(c.Id)) + where 3 == UDFSqlContext.AddOneStatic(Math.Abs(c.Id)) select c.Id).Single(); Assert.Equal(2, results); @@ -733,12 +819,12 @@ public void Scalar_Nested_Function_Client_BCL() } [Fact] - public void Scalar_Nested_Function_Client_UDF() + public void Scalar_Nested_Function_Client_UDF_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 2 == UDFSqlContext.AddOne(UDFSqlContext.CustomerOrderCountWithClient(c.Id)) + where 2 == UDFSqlContext.AddOneStatic(UDFSqlContext.CustomerOrderCountWithClientStatic(c.Id)) select c.Id).Single(); Assert.Equal(3, results); @@ -751,12 +837,12 @@ public void Scalar_Nested_Function_Client_UDF() } [Fact] - public void Scalar_Nested_Function_BCL_Client() + public void Scalar_Nested_Function_BCL_Client_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 3 == Math.Abs(UDFSqlContext.AddOne(c.Id)) + where 3 == Math.Abs(UDFSqlContext.AddOneStatic(c.Id)) select c.Id).Single(); Assert.Equal(2, results); @@ -769,12 +855,12 @@ public void Scalar_Nested_Function_BCL_Client() } [Fact] - public void Scalar_Nested_Function_BCL_UDF() + public void Scalar_Nested_Function_BCL_UDF_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 3 == Math.Abs(UDFSqlContext.CustomerOrderCount(c.Id)) + where 3 == Math.Abs(UDFSqlContext.CustomerOrderCountStatic(c.Id)) select c.Id).Single(); Assert.Equal(1, results); @@ -789,12 +875,12 @@ FROM [Customers] AS [c] [Fact] - public void Scalar_Nested_Function_UDF_Client() + public void Scalar_Nested_Function_UDF_Client_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 2 == UDFSqlContext.CustomerOrderCountWithClient(UDFSqlContext.AddOne(c.Id)) + where 2 == UDFSqlContext.CustomerOrderCountWithClientStatic(UDFSqlContext.AddOneStatic(c.Id)) select c.Id).Single(); Assert.Equal(1, results); @@ -807,12 +893,12 @@ public void Scalar_Nested_Function_UDF_Client() } [Fact] - public void Scalar_Nested_Function_UDF_BCL() + public void Scalar_Nested_Function_UDF_BCL_Static() { using (var context = CreateContext()) { var results = (from c in context.Customers - where 3 == UDFSqlContext.CustomerOrderCount(Math.Abs(c.Id)) + where 3 == UDFSqlContext.CustomerOrderCountStatic(Math.Abs(c.Id)) select c.Id).Single(); Assert.Equal(1, results); @@ -827,36 +913,752 @@ FROM [Customers] AS [c] #endregion - public class SqlServerUDFFixture : SharedStoreFixtureBase + #region Instance + + [Fact] + public void Scalar_Function_Non_Static() { - protected override string StoreName { get; } = "UDFDbFunctionSqlServerTests"; - protected override Type ContextType { get; } = typeof(UDFSqlContext); - protected override ITestStoreFactory TestStoreFactory => SqlServerTestStoreFactory.Instance; + using (var context = CreateContext()) + { + var custName = (from c in context.Customers + where c.Id == 1 + select new + { + Id = context.StarValueInstance(4, c.Id), + LastName = context.DollarValueInstance(2, c.LastName) + }).Single(); - public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ServiceProvider.GetRequiredService(); + Assert.Equal(custName.LastName, "$$One"); - public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + Assert.Equal( + @"SELECT TOP(2) [dbo].StarValue(4, [c].[Id]) AS [Id], [dbo].DollarValue(2, [c].[LastName]) AS [LastName] +FROM [Customers] AS [c] +WHERE [c].[Id] = 1", + Sql, + ignoreLineEndingDifferences: true); + } + } + + + [Fact] + private void Scalar_Function_Extension_Method_Instance() + { + using (var context = CreateContext()) { - base.AddOptions(builder); - return builder.ConfigureWarnings(w => w.Ignore(RelationalEventId.QueryClientEvaluationWarning)); + var len = context.Customers.Count(c => context.IsDateInstance(c.FirstName) == false); + + Assert.Equal(3, len); + + Assert.Equal( + @"SELECT COUNT(*) +FROM [Customers] AS [c] +WHERE CASE + WHEN IsDate([c].[FirstName]) = 1 + THEN CAST(1 AS BIT) ELSE CAST(0 AS BIT) +END = 0", + Sql, + ignoreLineEndingDifferences: true); } + } - protected override void Seed(DbContext context) + [Fact] + private void Scalar_Function_With_Translator_Translates_Instance() + { + using (var context = CreateContext()) { - context.Database.EnsureCreated(); + var customerId = 3; - context.Database.ExecuteSqlCommand(@"create function [dbo].[CustomerOrderCount] (@customerId int) - returns int - as - begin - return (select count(id) from orders where customerId = @customerId); - end"); + var len = context.Customers.Where(c => c.Id == customerId) + .Select(c => context.MyCustomLengthInstance(c.LastName)).Single(); - context.Database.ExecuteSqlCommand(@"create function[dbo].[StarValue] (@starCount int, @value nvarchar(max)) - returns nvarchar(max) - as - begin - return replicate('*', @starCount) + @value + Assert.Equal(5, len); + + Assert.Equal( + @"@__customerId_0='3' + +SELECT TOP(2) len([c].[LastName]) +FROM [Customers] AS [c] +WHERE [c].[Id] = @__customerId_0", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_ClientEval_Method_As_Translateable_Method_Parameter_Instance() + { + using (var context = CreateContext()) + { + Assert.Throws( + () => (from c in context.Customers + where c.Id == 1 + select new + { + c.FirstName, + OrderCount = context.CustomerOrderCountInstance(context.AddFiveInstance(c.Id - 5)) + }).Single()); + } + } + + [Fact] + public void Scalar_Function_Constant_Parameter_Instance() + { + using (var context = CreateContext()) + { + var customerId = 1; + + var custs = context.Customers.Select(c => context.CustomerOrderCountInstance(customerId)).ToList(); + + Assert.Equal(3, custs.Count); + + Assert.Equal( + @"@__customerId_1='1' + +SELECT [dbo].CustomerOrderCount(@__customerId_1) +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Anonymous_Type_Select_Correlated_Instance() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + where c.Id == 1 + select new + { + c.LastName, + OrderCount = context.CustomerOrderCountInstance(c.Id) + }).Single(); + + Assert.Equal("One", cust.LastName); + Assert.Equal(3, cust.OrderCount); + + Assert.Equal( + @"SELECT TOP(2) [c].[LastName], [dbo].CustomerOrderCount([c].[Id]) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = 1", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Anonymous_Type_Select_Not_Correlated_Instance() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + where c.Id == 1 + select new + { + c.LastName, + OrderCount = context.CustomerOrderCountInstance(1) + }).Single(); + + Assert.Equal("One", cust.LastName); + Assert.Equal(3, cust.OrderCount); + + Assert.Equal( + @"SELECT TOP(2) [c].[LastName], [dbo].CustomerOrderCount(1) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = 1", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Anonymous_Type_Select_Parameter_Instance() + { + using (var context = CreateContext()) + { + var customerId = 1; + + var cust = (from c in context.Customers + where c.Id == customerId + select new + { + c.LastName, + OrderCount = context.CustomerOrderCountInstance(customerId) + }).Single(); + + Assert.Equal("One", cust.LastName); + Assert.Equal(3, cust.OrderCount); + + Assert.Equal( + @"@__customerId_2='1' +@__customerId_0='1' + +SELECT TOP(2) [c].[LastName], [dbo].CustomerOrderCount(@__customerId_2) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = @__customerId_0", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Anonymous_Type_Select_Nested_Instance() + { + using (var context = CreateContext()) + { + var customerId = 3; + var starCount = 3; + + var cust = (from c in context.Customers + where c.Id == customerId + select new + { + c.LastName, + OrderCount = context.StarValueInstance(starCount, context.CustomerOrderCountInstance(customerId)) + }).Single(); + + Assert.Equal("Three", cust.LastName); + Assert.Equal("***1", cust.OrderCount); + + Assert.Equal( + @"@__starCount_2='3' +@__customerId_4='3' +@__customerId_0='3' + +SELECT TOP(2) [c].[LastName], [dbo].StarValue(@__starCount_2, [dbo].CustomerOrderCount(@__customerId_4)) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = @__customerId_0", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Where_Correlated_Instance() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + where context.IsTopCustomerInstance(c.Id) + select c.Id.ToString().ToLower()).ToList(); + + Assert.Equal(1, cust.Count); + + Assert.Equal( + @"SELECT LOWER(CONVERT(VARCHAR(11), [c].[Id])) +FROM [Customers] AS [c] +WHERE [dbo].IsTopCustomer([c].[Id]) = 1", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Where_Not_Correlated_Instance() + { + using (var context = CreateContext()) + { + var startDate = DateTime.Parse("4/1/2000"); + + var custId = (from c in context.Customers + where context.GetCustomerWithMostOrdersAfterDateInstance(startDate) == c.Id + select c.Id).SingleOrDefault(); + + Assert.Equal(custId, 2); + + Assert.Equal( + @"@__startDate_1='2000-04-01T00:00:00' + +SELECT TOP(2) [c].[Id] +FROM [Customers] AS [c] +WHERE [dbo].GetCustomerWithMostOrdersAfterDate(@__startDate_1) = [c].[Id]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Where_Parameter_Instance() + { + using (var context = CreateContext()) + { + var period = UDFSqlContext.ReportingPeriod.Winter; + + var custId = (from c in context.Customers + where c.Id == context.GetCustomerWithMostOrdersAfterDateInstance( + context.GetReportingPeriodStartDateInstance(period)) + select c.Id).SingleOrDefault(); + + Assert.Equal(custId, 1); + + Assert.Equal( + @"@__period_2='Winter' + +SELECT TOP(2) [c].[Id] +FROM [Customers] AS [c] +WHERE [c].[Id] = [dbo].GetCustomerWithMostOrdersAfterDate([dbo].GetReportingPeriodStartDate(@__period_2))", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Where_Nested_Instance() + { + using (var context = CreateContext()) + { + var custId = (from c in context.Customers + where c.Id == context.GetCustomerWithMostOrdersAfterDateInstance( + context.GetReportingPeriodStartDateInstance( + UDFSqlContext.ReportingPeriod.Winter)) + select c.Id).SingleOrDefault(); + + Assert.Equal(custId, 1); + + Assert.Equal( + @"SELECT TOP(2) [c].[Id] +FROM [Customers] AS [c] +WHERE [c].[Id] = [dbo].GetCustomerWithMostOrdersAfterDate([dbo].GetReportingPeriodStartDate(0))", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Let_Correlated_Instance() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + let orderCount = context.CustomerOrderCountInstance(c.Id) + where c.Id == 2 + select new + { + c.LastName, + OrderCount = orderCount + }).Single(); + + Assert.Equal("Two", cust.LastName); + Assert.Equal(2, cust.OrderCount); + + Assert.Equal( + @"SELECT TOP(2) [c].[LastName], [dbo].CustomerOrderCount([c].[Id]) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = 2", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Let_Not_Correlated_Instance() + { + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + let orderCount = context.CustomerOrderCountInstance(2) + where c.Id == 2 + select new + { + c.LastName, + OrderCount = orderCount + }).Single(); + + Assert.Equal("Two", cust.LastName); + Assert.Equal(2, cust.OrderCount); + + Assert.Equal( + @"SELECT TOP(2) [c].[LastName], [dbo].CustomerOrderCount(2) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = 2", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Let_Not_Parameter_Instance() + { + var customerId = 2; + + using (var context = CreateContext()) + { + var cust = (from c in context.Customers + let orderCount = context.CustomerOrderCountInstance(customerId) + where c.Id == customerId + select new + { + c.LastName, + OrderCount = orderCount + }).Single(); + + Assert.Equal("Two", cust.LastName); + Assert.Equal(2, cust.OrderCount); + + Assert.Equal( + @"@__8__locals1_customerId_1='2' +@__8__locals1_customerId_2='2' + +SELECT TOP(2) [c].[LastName], [dbo].CustomerOrderCount(@__8__locals1_customerId_1) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = @__8__locals1_customerId_2", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Function_Let_Nested_Instance() + { + using (var context = CreateContext()) + { + var customerId = 1; + var starCount = 3; + + var cust = (from c in context.Customers + let orderCount = context.StarValueInstance(starCount, context.CustomerOrderCountInstance(customerId)) + where c.Id == customerId + select new + { + c.LastName, + OrderCount = orderCount + }).Single(); + + Assert.Equal("One", cust.LastName); + Assert.Equal("***3", cust.OrderCount); + + Assert.Equal( + @"@__starCount_1='3' +@__customerId_3='1' +@__customerId_4='1' + +SELECT TOP(2) [c].[LastName], [dbo].StarValue(@__starCount_1, [dbo].CustomerOrderCount(@__customerId_3)) AS [OrderCount] +FROM [Customers] AS [c] +WHERE [c].[Id] = @__customerId_4", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_Unwind_Client_Eval_Where_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 2 == context.AddOneInstance(c.Id) + select c.Id).Single(); + + Assert.Equal(1, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested__Function_Unwind_Client_Eval_OrderBy_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + orderby context.AddOneInstance(c.Id) + select c.Id).ToList(); + + Assert.Equal(3, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(1, 3))); + + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_Unwind_Client_Eval_Select_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + orderby c.Id + select context.AddOneInstance(c.Id)).ToList(); + + Assert.Equal(3, results.Count); + Assert.True(results.SequenceEqual(Enumerable.Range(2, 3))); + + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c] +ORDER BY [c].[Id]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_Client_BCL_UDF_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 2 == context.AddOneInstance(Math.Abs(context.CustomerOrderCountWithClientInstance(c.Id))) + select c.Id).Single(); + + Assert.Equal(3, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_Client_UDF_BCL_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 2 == context.AddOneInstance(context.CustomerOrderCountWithClientInstance(Math.Abs(c.Id))) + select c.Id).Single(); + + Assert.Equal(3, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_BCL_Client_UDF_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 2 == Math.Abs(context.AddOneInstance(context.CustomerOrderCountWithClientInstance(c.Id))) + select c.Id).Single(); + + Assert.Equal(3, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_BCL_UDF_Client_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 1 == Math.Abs(context.CustomerOrderCountWithClientInstance(context.AddOneInstance(c.Id))) + select c.Id).Single(); + + Assert.Equal(2, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_UDF_BCL_Client_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 1 == context.CustomerOrderCountWithClientInstance(Math.Abs(context.AddOneInstance(c.Id))) + select c.Id).Single(); + + Assert.Equal(2, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_UDF_Client_BCL_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 1 == context.CustomerOrderCountWithClientInstance(context.AddOneInstance(Math.Abs(c.Id))) + select c.Id).Single(); + + Assert.Equal(2, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_Client_BCL_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 3 == context.AddOneInstance(Math.Abs(c.Id)) + select c.Id).Single(); + + Assert.Equal(2, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_Client_UDF_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 2 == context.AddOneInstance(context.CustomerOrderCountWithClientInstance(c.Id)) + select c.Id).Single(); + + Assert.Equal(3, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_BCL_Client_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 3 == Math.Abs(context.AddOneInstance(c.Id)) + select c.Id).Single(); + + Assert.Equal(2, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_BCL_UDF_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 3 == Math.Abs(context.CustomerOrderCountInstance(c.Id)) + select c.Id).Single(); + + Assert.Equal(1, results); + Assert.Equal( + @"SELECT TOP(2) [c].[Id] +FROM [Customers] AS [c] +WHERE 3 = ABS([dbo].CustomerOrderCount([c].[Id]))", + Sql, + ignoreLineEndingDifferences: true); + } + } + + + [Fact] + public void Scalar_Nested_Function_UDF_Client_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 2 == context.CustomerOrderCountWithClientInstance(context.AddOneInstance(c.Id)) + select c.Id).Single(); + + Assert.Equal(1, results); + Assert.Equal( + @"SELECT [c].[Id] +FROM [Customers] AS [c]", + Sql, + ignoreLineEndingDifferences: true); + } + } + + [Fact] + public void Scalar_Nested_Function_UDF_BCL_Instance() + { + using (var context = CreateContext()) + { + var results = (from c in context.Customers + where 3 == context.CustomerOrderCountInstance(Math.Abs(c.Id)) + select c.Id).Single(); + + Assert.Equal(1, results); + Assert.Equal( + @"SELECT TOP(2) [c].[Id] +FROM [Customers] AS [c] +WHERE 3 = [dbo].CustomerOrderCount(ABS([c].[Id]))", + Sql, + ignoreLineEndingDifferences: true); + } + } + + #endregion + + #endregion + + public class SqlServerUDFFixture : SharedStoreFixtureBase + { + protected override string StoreName { get; } = "UDFDbFunctionSqlServerTests"; + protected override Type ContextType { get; } = typeof(UDFSqlContext); + protected override ITestStoreFactory TestStoreFactory => SqlServerTestStoreFactory.Instance; + + public TestSqlLoggerFactory TestSqlLoggerFactory => (TestSqlLoggerFactory)ServiceProvider.GetRequiredService(); + + public override DbContextOptionsBuilder AddOptions(DbContextOptionsBuilder builder) + { + base.AddOptions(builder); + return builder.ConfigureWarnings(w => w.Ignore(RelationalEventId.QueryClientEvaluationWarning)); + } + + protected override void Seed(DbContext context) + { + context.Database.EnsureCreated(); + + context.Database.ExecuteSqlCommand(@"create function [dbo].[CustomerOrderCount] (@customerId int) + returns int + as + begin + return (select count(id) from orders where customerId = @customerId); + end"); + + context.Database.ExecuteSqlCommand(@"create function[dbo].[StarValue] (@starCount int, @value nvarchar(max)) + returns nvarchar(max) + as + begin + return replicate('*', @starCount) + @value + end"); + + context.Database.ExecuteSqlCommand(@"create function[dbo].[DollarValue] (@starCount int, @value nvarchar(max)) + returns nvarchar(max) + as + begin + return replicate('$', @starCount) + @value end"); context.Database.ExecuteSqlCommand(@"create function [dbo].[GetReportingPeriodStartDate] (@period int)