Skip to content

Commit

Permalink
Linq: add enum Equals and object Equals support (#3242)
Browse files Browse the repository at this point in the history
  • Loading branch information
tykovec authored May 17, 2023
1 parent 79ad5b8 commit c52c7f1
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 4 deletions.
27 changes: 27 additions & 0 deletions src/NHibernate.Test/Async/Linq/FunctionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,33 @@ where item.Discount.Equals(-1)
await (ObjectDumper.WriteAsync(query));
}

[Test]
public async Task WhereEnumEqualAsync()
{
var query = from item in db.PatientRecords
where item.Gender.Equals(Gender.Female)
select item;

await (ObjectDumper.WriteAsync(query));

query = from item in db.PatientRecords
where item.Gender.Equals(item.Gender)
select item;

await (ObjectDumper.WriteAsync(query));
}


[Test]
public async Task WhereObjectEqualAsync()
{
var query = from item in db.PatientRecords
where ((object) item.Gender).Equals(Gender.Female)
select item;

await (ObjectDumper.WriteAsync(query));
}

[Test]
public async Task WhereEquatableEqualAsync()
{
Expand Down
27 changes: 27 additions & 0 deletions src/NHibernate.Test/Linq/FunctionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,33 @@ where item.Discount.Equals(-1)
ObjectDumper.Write(query);
}

[Test]
public void WhereEnumEqual()
{
var query = from item in db.PatientRecords
where item.Gender.Equals(Gender.Female)
select item;

ObjectDumper.Write(query);

query = from item in db.PatientRecords
where item.Gender.Equals(item.Gender)
select item;

ObjectDumper.Write(query);
}


[Test]
public void WhereObjectEqual()
{
var query = from item in db.PatientRecords
where ((object) item.Gender).Equals(Gender.Female)
select item;

ObjectDumper.Write(query);
}

[Test]
public void WhereEquatableEqual()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ public virtual bool TryGetGenerator(MemberInfo property, out IHqlGeneratorForPro

public virtual void RegisterGenerator(MethodInfo method, IHqlGeneratorForMethod generator)
{
registeredMethods.Add(method, generator);
registeredMethods[method] = generator;
}

public virtual void RegisterGenerator(MemberInfo property, IHqlGeneratorForProperty generator)
{
registeredProperties.Add(property, generator);
registeredProperties[property] = generator;
}

public void RegisterGenerator(IRuntimeMethodHqlGenerator generator)
Expand Down
10 changes: 8 additions & 2 deletions src/NHibernate/Linq/Functions/EqualsGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ public class EqualsGenerator : BaseHqlGeneratorForMethod
ReflectHelper.GetMethodDefinition<IEquatable<DateTime>>(x => x.Equals(default(DateTime))),
ReflectHelper.GetMethodDefinition<IEquatable<DateTimeOffset>>(x => x.Equals(default(DateTimeOffset))),
ReflectHelper.GetMethodDefinition<IEquatable<TimeSpan>>(x => x.Equals(default(TimeSpan))),
ReflectHelper.GetMethodDefinition<IEquatable<bool>>(x => x.Equals(default(bool)))
ReflectHelper.GetMethodDefinition<IEquatable<bool>>(x => x.Equals(default(bool))),
ReflectHelper.GetMethodDefinition<object>(x => x.Equals(default(object))), // this covers also Enum.Equals
ReflectHelper.GetMethodDefinition<IEquatable<object>>(x => x.Equals(default(object))),
ReflectHelper.GetMethodDefinition<IEquatable<Enum>>(x => x.Equals(default(Enum)))
};

public EqualsGenerator()
Expand All @@ -72,7 +75,10 @@ public override HqlTreeNode BuildHql(MethodInfo method, Expression targetObject,
{
Expression lhs = arguments.Count == 1 ? targetObject : arguments[0];
Expression rhs = arguments.Count == 1 ? arguments[0] : arguments[1];

if (lhs.Type.IsEnum)
{
return visitor.Visit(Expression.Equal(lhs, Expression.Convert(rhs, lhs.Type)));
}
return visitor.Visit(Expression.Equal(lhs, rhs));
}
}
Expand Down

0 comments on commit c52c7f1

Please sign in to comment.