diff --git a/src/NHibernate.Test/Async/Linq/FunctionTests.cs b/src/NHibernate.Test/Async/Linq/FunctionTests.cs index e47e3d22c37..c39ff901df4 100644 --- a/src/NHibernate.Test/Async/Linq/FunctionTests.cs +++ b/src/NHibernate.Test/Async/Linq/FunctionTests.cs @@ -459,6 +459,33 @@ where item.Discount.Equals(-1) await (ObjectDumper.WriteAsync(query)); } + [Test] + public async Task WhereEnumEqualAsync() + { + var query = from item in db.PatientRecords + where item.Gender.Equals(Gender.Female) + select item; + + await (ObjectDumper.WriteAsync(query)); + + query = from item in db.PatientRecords + where item.Gender.Equals(item.Gender) + select item; + + await (ObjectDumper.WriteAsync(query)); + } + + + [Test] + public async Task WhereObjectEqualAsync() + { + var query = from item in db.PatientRecords + where ((object) item.Gender).Equals(Gender.Female) + select item; + + await (ObjectDumper.WriteAsync(query)); + } + [Test] public async Task WhereEquatableEqualAsync() { diff --git a/src/NHibernate.Test/Linq/FunctionTests.cs b/src/NHibernate.Test/Linq/FunctionTests.cs index b92e7b3ece9..877040e35ba 100644 --- a/src/NHibernate.Test/Linq/FunctionTests.cs +++ b/src/NHibernate.Test/Linq/FunctionTests.cs @@ -491,6 +491,33 @@ where item.Discount.Equals(-1) ObjectDumper.Write(query); } + [Test] + public void WhereEnumEqual() + { + var query = from item in db.PatientRecords + where item.Gender.Equals(Gender.Female) + select item; + + ObjectDumper.Write(query); + + query = from item in db.PatientRecords + where item.Gender.Equals(item.Gender) + select item; + + ObjectDumper.Write(query); + } + + + [Test] + public void WhereObjectEqual() + { + var query = from item in db.PatientRecords + where ((object) item.Gender).Equals(Gender.Female) + select item; + + ObjectDumper.Write(query); + } + [Test] public void WhereEquatableEqual() { diff --git a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs index 6800271f9f0..7c35d09fd13 100644 --- a/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs +++ b/src/NHibernate/Linq/Functions/DefaultLinqToHqlGeneratorsRegistry.cs @@ -111,12 +111,12 @@ public virtual bool TryGetGenerator(MemberInfo property, out IHqlGeneratorForPro public virtual void RegisterGenerator(MethodInfo method, IHqlGeneratorForMethod generator) { - registeredMethods.Add(method, generator); + registeredMethods[method] = generator; } public virtual void RegisterGenerator(MemberInfo property, IHqlGeneratorForProperty generator) { - registeredProperties.Add(property, generator); + registeredProperties[property] = generator; } public void RegisterGenerator(IRuntimeMethodHqlGenerator generator) diff --git a/src/NHibernate/Linq/Functions/EqualsGenerator.cs b/src/NHibernate/Linq/Functions/EqualsGenerator.cs index 82c15c87190..301d405b1be 100644 --- a/src/NHibernate/Linq/Functions/EqualsGenerator.cs +++ b/src/NHibernate/Linq/Functions/EqualsGenerator.cs @@ -58,7 +58,10 @@ public class EqualsGenerator : BaseHqlGeneratorForMethod ReflectHelper.GetMethodDefinition>(x => x.Equals(default(DateTime))), ReflectHelper.GetMethodDefinition>(x => x.Equals(default(DateTimeOffset))), ReflectHelper.GetMethodDefinition>(x => x.Equals(default(TimeSpan))), - ReflectHelper.GetMethodDefinition>(x => x.Equals(default(bool))) + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(bool))), + ReflectHelper.GetMethodDefinition(x => x.Equals(default(object))), // this covers also Enum.Equals + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(object))), + ReflectHelper.GetMethodDefinition>(x => x.Equals(default(Enum))) }; public EqualsGenerator() @@ -72,7 +75,10 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject, { Expression lhs = arguments.Count == 1 ? targetObject : arguments[0]; Expression rhs = arguments.Count == 1 ? arguments[0] : arguments[1]; - + if (lhs.Type.IsEnum) + { + return visitor.Visit(Expression.Equal(lhs, Expression.Convert(rhs, lhs.Type))); + } return visitor.Visit(Expression.Equal(lhs, rhs)); } }