From 181ede6e33fc923c7df91c93095013a4de74ad60 Mon Sep 17 00:00:00 2001 From: Damien Guard Date: Thu, 9 Oct 2025 16:28:56 +0100 Subject: [PATCH 1/5] Replace MemoryExtension.Contains/SequenceEqual with Enumerable counterparts. --- ...ingWithOutputExpressionStageDefinitions.cs | 8 +- .../Misc/ClrCompatExpressionRewriter.cs | 92 ++++++ .../ExpressionToExecutableQueryTranslator.cs | 4 +- .../LookupMethodToPipelineTranslator.cs | 4 +- .../ExpressionToSetStageTranslator.cs | 4 +- .../Linq/LinqProviderAdapter.cs | 11 +- .../Jira/CSharp5749Tests.cs | 268 ++++++++++++++++++ 7 files changed, 376 insertions(+), 15 deletions(-) create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 4d62eaea95c..66ed5ceca0a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -63,7 +63,7 @@ private AstStage RenderProjectStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer outputSerializer) { - var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(_output); + var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_output)); var context = TranslationContext.Create(translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); @@ -105,7 +105,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); + var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_groupBy)); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); @@ -149,7 +149,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer, TInput>> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); + var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_groupBy)); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); @@ -187,7 +187,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); + var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_groupBy)); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs new file mode 100644 index 00000000000..71c0ea79175 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs @@ -0,0 +1,92 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Linq.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using static System.Linq.Expressions.Expression; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +/// +/// This visitor rewrites expressions where new features of .NET CLR or +/// C# compiler interfere with LINQ expression tree translation. +/// +internal class ClrCompatExpressionRewriter : ExpressionVisitor +{ + private static readonly ClrCompatExpressionRewriter s_clrCompatExpressionRewriter = new(); + + public static Expression Rewrite(Expression expression) + => s_clrCompatExpressionRewriter.Visit(expression); + + /// + protected override Expression VisitMethodCall(MethodCallExpression methodCall) + { + var method = methodCall.Method; + + // C# 14 introduced implicit casts to ReadOnlySpan and Span + // This then binds Contains and SequenceEqual to the MemoryExtensions instead of Enumerable + // and we can't just add support for MemoryExtensions elsewhere as it's actually invalid + // in an expression tree due to the ref semantics of ReadOnlySpan. + if (method.DeclaringType == typeof(MemoryExtensions) && method.IsGenericMethod) + { + switch (method.Name) + { + // Replace MemoryExtensions.Contains(ReadOnlySpan, T) with Enumerable.Contains(IEnumerable, T) + case nameof(MemoryExtensions.Contains) + when (methodCall.Arguments.Count == 2 + || (methodCall.Arguments.Count == 3 && methodCall.Arguments[2] is ConstantExpression { Value: null })) + && TryUnwrapSpanImplicitCast(methodCall.Arguments[0], out var unwrappedSpanArg): + return Visit( + Call( + EnumerableMethod.Contains.MakeGenericMethod(method.GetGenericArguments()[0]), + unwrappedSpanArg, methodCall.Arguments[1]))!; + + // Replace MemoryExtensions.SequenceEqual(ReadOnlySpan, ReadOnlySpan) with Enumerable.SequenceEqual(IEnumerable, IEnumerable) + case nameof(MemoryExtensions.SequenceEqual) + when methodCall.Arguments.Count == 2 + && TryUnwrapSpanImplicitCast(methodCall.Arguments[0], out var unwrappedSpanArg) + && TryUnwrapSpanImplicitCast(methodCall.Arguments[1], out var unwrappedOtherArg): + return Visit( + Call( + EnumerableMethod.SequenceEqual.MakeGenericMethod(method.GetGenericArguments()[0]), + unwrappedSpanArg, unwrappedOtherArg))!; + } + + // Erase implicit casts to ReadOnlySpan and Span + static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result) + { + if (expression is MethodCallExpression + { + Method: + { + Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType + } + } methodCallExpression + && implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition + && (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>))) + { + result = methodCallExpression.Arguments[0]; + return true; + } + + result = null; + return false; + } + } + + return base.VisitMethodCall(methodCall); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index b96a193e323..29ddddb8e38 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -29,7 +29,7 @@ public static ExecutableQuery> Translate TranslateScalar(provider); body = ExpressionReplacer.Replace(body, queryableParameter, Expression.Constant(queryable)); - body = PartialEvaluator.EvaluatePartially(body); + body = PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(body)); return ExpressionToPipelineTranslator.Translate(context, body); } @@ -338,7 +338,7 @@ private static TranslatedPipeline TranslateLookupPipelineAgainstQueryable>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); var context = TranslationContext.Create(translationOptions, contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -74,7 +74,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(expression); + expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); var parameter = expression.Parameters.Single(); var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); @@ -104,7 +104,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); var parameter = expression.Parameters.Single(); var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); @@ -124,6 +124,7 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { + expression = (Expression>)ClrCompatExpressionRewriter.Rewrite(expression); expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); @@ -141,7 +142,7 @@ internal static BsonDocument TranslateExpressionToFilter( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); var context = TranslationContext.Create(translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -175,7 +176,7 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec return new RenderedProjectionDefinition(null, (IBsonSerializer)inputSerializer); } - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); var context = TranslationContext.Create(translationOptions); var simplifier = forFind ? new AstFindProjectionSimplifier() : new AstSimplifier(); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs new file mode 100644 index 00000000000..aa694c827fd --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs @@ -0,0 +1,268 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Bson; +using MongoDB.Driver.TestHelpers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +public class CSharp5749Tests : LinqIntegrationTest +{ + public CSharp5749Tests(ClassFixture fixture) + : base(fixture) + { + } + +#if NET6_0_OR_GREATER + + [Fact] + public void MemoryExtension_contains_in_where_should_work() + { + var names = new[] { "Two", "Three" }; + + var queryable = MangledQueryable().Where(x => names.Contains(x.Name)); + + Assert.Equal(2, queryable.Count()); + } + + [Fact] + public void MemoryExtension_contains_in_single_should_work() + { + var names = new[] { "Two" }; + + var actual = MangledQueryable().Single(x => names.Contains(x.Name)); + + Assert.Equal("Two", actual.Name); + } + + [Fact] + public void MemoryExtension_contains_in_any_should_work() + { + var ids = new[] { 2 }; + + var actual = MangledQueryable().Any(x => ids.Contains(x.Id)); + + Assert.True(actual); + } + + [Fact] + public void MemoryExtension_contains_in_count_should_work() + { + var ids = new[] { 2 }; + + var actual = MangledQueryable().Count(x => ids.Contains(x.Id)); + + Assert.Equal(1, actual); + } + + [Fact] + public void MemoryExtension_sequenceequal_in_where_should_work() + { + var ratings = new[] { 1, 9, 6 }; + + var queryable = MangledQueryable().Where(x => ratings.SequenceEqual(x.Ratings)); + + Assert.Equal(1, queryable.Count()); + } + + [Fact] + public void MemoryExtension_sequenceequal_in_single_should_work() + { + var ratings = new[] { 1, 9, 6 }; + + var actual = MangledQueryable().Single(x => ratings.SequenceEqual(x.Ratings)); + + Assert.Equal("Three", actual.Name); + } + + [Fact] + public void MemoryExtension_sequenceequas_in_any_should_work() + { + var ratings = new[] { 1, 2, 3, 4, 5 }; + + var actual = MangledQueryable().Any(x => ratings.SequenceEqual(x.Ratings)); + + Assert.True(actual); + } + + [Fact] + public void MemoryExtension_sequenceequas_in_count_should_work() + { + var ratings = new[] { 3, 4, 5, 6, 7 }; + + var actual = MangledQueryable().Count(x => ratings.SequenceEqual(x.Ratings)); + + Assert.Equal(1, actual); + } + +#endif + + public class C + { + public int Id { get; set; } + public string Name { get; set; } + public int[] Ratings { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + BsonDocument.Parse("{ _id : 1, Name: \"One\", Ratings: [1,2,3,4,5] }"), + BsonDocument.Parse("{ _id : 2, Name: \"Two\", Ratings: [3,4,5,6,7] }"), + BsonDocument.Parse("{ _id : 3, Name: \"Three\", Ratings: [1,9,6] }") + ]; + } + + public IQueryable MangledQueryable() + => new ManglingQueryable(Fixture.Collection.AsQueryable()); + + public class ManglingQueryProvider(IQueryProvider innerProvider) : IQueryProvider + { + public IQueryable CreateQuery(Expression expression) + => innerProvider.CreateQuery(EnumerableToMemoryExtensionsMangler.Mangle(expression)); + + public IQueryable CreateQuery(Expression expression) + { + var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(expression); + return innerProvider.CreateQuery(mangledExpression); + } + + public object Execute(Expression expression) + { + var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(expression); + return innerProvider.Execute(mangledExpression); + } + + public TResult Execute(Expression expression) + { + var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(expression); + return innerProvider.Execute(mangledExpression); + } + } + + public class ManglingQueryable(IQueryable innerQueryable) : IOrderedQueryable + { + private readonly ManglingQueryProvider _provider = new(innerQueryable.Provider); + + public Type ElementType => innerQueryable.ElementType; + public Expression Expression => innerQueryable.Expression; + public IQueryProvider Provider => _provider; + + public IEnumerator GetEnumerator() + { + var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(Expression); + var query = innerQueryable.Provider.CreateQuery(mangledExpression); + return query.GetEnumerator(); + } + + System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() + => GetEnumerator(); + } + + public class EnumerableToMemoryExtensionsMangler : ExpressionVisitor + { + private static readonly EnumerableToMemoryExtensionsMangler s_instance = new(); + + public static Expression Mangle(Expression expression) + => s_instance.Visit(expression); + + protected override Expression VisitMethodCall(MethodCallExpression node) + { + // Check if this is Enumerable.Contains or Enumerable.SequenceEqual + if (node.Method.DeclaringType == typeof(Enumerable) && + node.Method.IsGenericMethod && + (node.Method.Name == nameof(Enumerable.Contains) || + node.Method.Name == nameof(Enumerable.SequenceEqual))) + { + var elementType = node.Method.GetGenericArguments()[0]; + + // Get ReadOnlySpan.op_Implicit method + var spanType = typeof(ReadOnlySpan<>).MakeGenericType(elementType); + var opImplicitMethod = spanType.GetMethod("op_Implicit", + BindingFlags.Public | BindingFlags.Static, + null, + [elementType.MakeArrayType()], + null); + + if (opImplicitMethod == null) + return base.VisitMethodCall(node); + + switch (node.Method.Name) + { + case nameof(Enumerable.Contains): + { + var source = node.Arguments[0]; + if (!source.Type.IsArray) + return base.VisitMethodCall(node); + + var containsMethod = typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == nameof(MemoryExtensions.Contains) && + m.IsGenericMethod && + m.GetParameters().Length == 2 && + m.GetParameters()[0].ParameterType.IsGenericType && + m.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == + typeof(ReadOnlySpan<>)) + ?.MakeGenericMethod(elementType) + ?? throw new InvalidOperationException( + "Could not find MemoryExtensions.Contains method."); + + return Expression.Call(containsMethod, + Expression.Call(opImplicitMethod, Visit(source)), + Visit(node.Arguments[1])); + } + + case nameof(Enumerable.SequenceEqual): + { + var first = node.Arguments[0]; + var second = node.Arguments[1]; + if (!first.Type.IsArray || !second.Type.IsArray) + return base.VisitMethodCall(node); + + var sequenceEqualMethod = typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == nameof(MemoryExtensions.SequenceEqual) && + m.IsGenericMethod && + m.GetParameters().Length == 2 && + m.GetParameters()[0].ParameterType.IsGenericType && + m.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == + typeof(ReadOnlySpan<>) && + m.GetParameters()[1].ParameterType.IsGenericType && + m.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == + typeof(ReadOnlySpan<>)) + ?.MakeGenericMethod(elementType) + ?? throw new InvalidOperationException( + "Could not find MemoryExtensions.SequenceEquals method."); + + return Expression.Call(sequenceEqualMethod, + Expression.Call(opImplicitMethod, Visit(first)), + Expression.Call(opImplicitMethod, Visit(second))); + } + } + } + + return base.VisitMethodCall(node); + } + } +} From e112dd54962b1e7fe58ec999268ef2d59348d84a Mon Sep 17 00:00:00 2001 From: Damien Guard Date: Fri, 10 Oct 2025 11:41:31 +0100 Subject: [PATCH 2/5] Simplify calling both Partial & Compat into a new Preprocessor. --- ...ingWithOutputExpressionStageDefinitions.cs | 8 ++--- ....cs => ClrCompatExpressionPreprocessor.cs} | 12 +++---- .../Misc/LinqExpressionPreprocessor.cs | 33 +++++++++++++++++++ .../ExpressionToExecutableQueryTranslator.cs | 4 +-- .../LookupMethodToPipelineTranslator.cs | 4 +-- .../ExpressionToSetStageTranslator.cs | 4 +-- .../Linq/LinqProviderAdapter.cs | 13 ++++---- 7 files changed, 55 insertions(+), 23 deletions(-) rename src/MongoDB.Driver/Linq/Linq3Implementation/Misc/{ClrCompatExpressionRewriter.cs => ClrCompatExpressionPreprocessor.cs} (88%) create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 66ed5ceca0a..40e41bbd51c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -63,7 +63,7 @@ private AstStage RenderProjectStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer outputSerializer) { - var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_output)); + var partiallyEvaluatedOutput = (Expression>)LinqExpressionPreprocessor.Preprocess(_output); var context = TranslationContext.Create(translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); @@ -105,7 +105,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_groupBy)); + var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); @@ -149,7 +149,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer, TInput>> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_groupBy)); + var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); @@ -187,7 +187,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(_groupBy)); + var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs similarity index 88% rename from src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs rename to src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs index 71c0ea79175..0963bd0ddfa 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs @@ -24,12 +24,12 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; /// This visitor rewrites expressions where new features of .NET CLR or /// C# compiler interfere with LINQ expression tree translation. /// -internal class ClrCompatExpressionRewriter : ExpressionVisitor +internal class ClrCompatExpressionPreprocessor : ExpressionVisitor { - private static readonly ClrCompatExpressionRewriter s_clrCompatExpressionRewriter = new(); + private static readonly ClrCompatExpressionPreprocessor s_clrCompatExpressionPreprocessor = new(); - public static Expression Rewrite(Expression expression) - => s_clrCompatExpressionRewriter.Visit(expression); + public static Expression Preprocess(Expression expression) + => s_clrCompatExpressionPreprocessor.Visit(expression); /// protected override Expression VisitMethodCall(MethodCallExpression methodCall) @@ -46,8 +46,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCall) { // Replace MemoryExtensions.Contains(ReadOnlySpan, T) with Enumerable.Contains(IEnumerable, T) case nameof(MemoryExtensions.Contains) - when (methodCall.Arguments.Count == 2 - || (methodCall.Arguments.Count == 3 && methodCall.Arguments[2] is ConstantExpression { Value: null })) + when (methodCall.Arguments.Count == 2 || (methodCall.Arguments.Count == 3 && + methodCall.Arguments[2] is ConstantExpression { Value: null })) && TryUnwrapSpanImplicitCast(methodCall.Arguments[0], out var unwrappedSpanArg): return Visit( Call( diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs new file mode 100644 index 00000000000..44bb4051461 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +/// +/// This class is called before we process any LINQ expression trees +/// to perform any necessary pre-processing such as CLR compatibility +/// and partial evaluation. +/// +internal static class LinqExpressionPreprocessor +{ + public static Expression Preprocess(Expression expression) + { + expression = ClrCompatExpressionPreprocessor.Preprocess(expression); + expression = PartialEvaluator.EvaluatePartially(expression); + return expression; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index 29ddddb8e38..a6a89b7639f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -29,7 +29,7 @@ public static ExecutableQuery> Translate TranslateScalar(provider); body = ExpressionReplacer.Replace(body, queryableParameter, Expression.Constant(queryable)); - body = PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(body)); + body = LinqExpressionPreprocessor.Preprocess(body); return ExpressionToPipelineTranslator.Translate(context, body); } @@ -338,7 +338,7 @@ private static TranslatedPipeline TranslateLookupPipelineAgainstQueryable>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions, contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -74,7 +74,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); + expression = (LambdaExpression)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); @@ -104,7 +104,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); @@ -124,8 +124,7 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (Expression>)ClrCompatExpressionRewriter.Rewrite(expression); - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element @@ -142,7 +141,7 @@ internal static BsonDocument TranslateExpressionToFilter( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -176,7 +175,7 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec return new RenderedProjectionDefinition(null, (IBsonSerializer)inputSerializer); } - expression = (Expression>)PartialEvaluator.EvaluatePartially(ClrCompatExpressionRewriter.Rewrite(expression)); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions); var simplifier = forFind ? new AstFindProjectionSimplifier() : new AstSimplifier(); From 04984d1d22b36d119bcc0acf252c4af99e1b8640 Mon Sep 17 00:00:00 2001 From: Damien Guard Date: Tue, 14 Oct 2025 12:20:39 +0100 Subject: [PATCH 3/5] Update comment for clarity. --- .../Misc/ClrCompatExpressionPreprocessor.cs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs index 0963bd0ddfa..382c69d5a8f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs @@ -36,10 +36,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCall) { var method = methodCall.Method; - // C# 14 introduced implicit casts to ReadOnlySpan and Span - // This then binds Contains and SequenceEqual to the MemoryExtensions instead of Enumerable - // and we can't just add support for MemoryExtensions elsewhere as it's actually invalid - // in an expression tree due to the ref semantics of ReadOnlySpan. + // C# 14 compiler changes mean array and list Contains now prefer MemoryExtensions with a ReadOnlySpan cast + // instead of the traditional Enumerable methods. This aliases them back to their original Enumerable methods. if (method.DeclaringType == typeof(MemoryExtensions) && method.IsGenericMethod) { switch (method.Name) From 4a88a8366ccb6fc3c5348ef970f78ea91dd9fec0 Mon Sep 17 00:00:00 2001 From: Damien Guard Date: Tue, 14 Oct 2025 17:14:30 +0100 Subject: [PATCH 4/5] Perform partial evaluation first to use MemoryExtensions where possible in local eval. --- .../Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs index 44bb4051461..b58fc49c963 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs @@ -26,8 +26,8 @@ internal static class LinqExpressionPreprocessor { public static Expression Preprocess(Expression expression) { - expression = ClrCompatExpressionPreprocessor.Preprocess(expression); expression = PartialEvaluator.EvaluatePartially(expression); + expression = ClrCompatExpressionPreprocessor.Preprocess(expression); return expression; } } From bddbb4efaf067716f307af96b3d3e95795e979f2 Mon Sep 17 00:00:00 2001 From: rstam Date: Wed, 15 Oct 2025 14:18:36 -0400 Subject: [PATCH 5/5] CSHARP-5749: Requested changes. --- .../Misc/ClrCompatExpressionPreprocessor.cs | 90 ------ .../Misc/ClrCompatExpressionRewriter.cs | 150 ++++++++++ .../Misc/LinqExpressionPreprocessor.cs | 2 +- .../Misc/MethodInfoExtensions.cs | 47 +++ .../Misc/TypeExtensions.cs | 16 + .../Reflection/EnumerableMethod.cs | 6 + .../Reflection/MemoryExtensionsMethod.cs | 150 ++++++++++ .../Jira/CSharp5749Tests.cs | 281 +++++++++--------- .../LegacyPredicateTranslatorTests.cs | 2 +- .../Translators/PredicateTranslatorTests.cs | 2 +- 10 files changed, 507 insertions(+), 239 deletions(-) delete mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MemoryExtensionsMethod.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs deleted file mode 100644 index 382c69d5a8f..00000000000 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionPreprocessor.cs +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright 2010-present MongoDB Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -using System; -using System.Linq.Expressions; -using MongoDB.Driver.Linq.Linq3Implementation.Reflection; -using static System.Linq.Expressions.Expression; - -namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; - -/// -/// This visitor rewrites expressions where new features of .NET CLR or -/// C# compiler interfere with LINQ expression tree translation. -/// -internal class ClrCompatExpressionPreprocessor : ExpressionVisitor -{ - private static readonly ClrCompatExpressionPreprocessor s_clrCompatExpressionPreprocessor = new(); - - public static Expression Preprocess(Expression expression) - => s_clrCompatExpressionPreprocessor.Visit(expression); - - /// - protected override Expression VisitMethodCall(MethodCallExpression methodCall) - { - var method = methodCall.Method; - - // C# 14 compiler changes mean array and list Contains now prefer MemoryExtensions with a ReadOnlySpan cast - // instead of the traditional Enumerable methods. This aliases them back to their original Enumerable methods. - if (method.DeclaringType == typeof(MemoryExtensions) && method.IsGenericMethod) - { - switch (method.Name) - { - // Replace MemoryExtensions.Contains(ReadOnlySpan, T) with Enumerable.Contains(IEnumerable, T) - case nameof(MemoryExtensions.Contains) - when (methodCall.Arguments.Count == 2 || (methodCall.Arguments.Count == 3 && - methodCall.Arguments[2] is ConstantExpression { Value: null })) - && TryUnwrapSpanImplicitCast(methodCall.Arguments[0], out var unwrappedSpanArg): - return Visit( - Call( - EnumerableMethod.Contains.MakeGenericMethod(method.GetGenericArguments()[0]), - unwrappedSpanArg, methodCall.Arguments[1]))!; - - // Replace MemoryExtensions.SequenceEqual(ReadOnlySpan, ReadOnlySpan) with Enumerable.SequenceEqual(IEnumerable, IEnumerable) - case nameof(MemoryExtensions.SequenceEqual) - when methodCall.Arguments.Count == 2 - && TryUnwrapSpanImplicitCast(methodCall.Arguments[0], out var unwrappedSpanArg) - && TryUnwrapSpanImplicitCast(methodCall.Arguments[1], out var unwrappedOtherArg): - return Visit( - Call( - EnumerableMethod.SequenceEqual.MakeGenericMethod(method.GetGenericArguments()[0]), - unwrappedSpanArg, unwrappedOtherArg))!; - } - - // Erase implicit casts to ReadOnlySpan and Span - static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result) - { - if (expression is MethodCallExpression - { - Method: - { - Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType - } - } methodCallExpression - && implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition - && (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>))) - { - result = methodCallExpression.Arguments[0]; - return true; - } - - result = null; - return false; - } - } - - return base.VisitMethodCall(methodCall); - } -} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs new file mode 100644 index 00000000000..62983df83a7 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs @@ -0,0 +1,150 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +/// +/// This visitor rewrites expressions where new features of .NET CLR or +/// C# compiler interfere with LINQ expression tree translation. +/// +internal class ClrCompatExpressionRewriter : ExpressionVisitor +{ + private static readonly ClrCompatExpressionRewriter __instance = new(); + + public static Expression Rewrite(Expression expression) + => __instance.Visit(expression); + + /// + protected override Expression VisitMethodCall(MethodCallExpression node) + { + node = (MethodCallExpression)base.VisitMethodCall(node); + + var method = node.Method; + var arguments = node.Arguments; + + return method.Name switch + { + "Contains" => VisitContainsMethod(node, method, arguments), + "SequenceEqual" => VisitSequenceEqualMethod(node, method, arguments), + _ => node + }; + + static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) + { + if (method.IsOneOf(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue, MemoryExtensionsMethod.ContainsWithSpanAndValue)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var value = arguments[1]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.Contains.MakeGenericMethod(itemType), + [unwrappedSpan, value]); + } + } + else if (method.Is(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var value = arguments[1]; + var comparer = arguments[2]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.ContainsWithComparer.MakeGenericMethod(itemType), + [unwrappedSpan, value, comparer]); + } + } + + return node; + } + + static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) + { + if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpan)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var other = arguments[1]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + TryUnwrapSpanImplicitCast(other, out var unwrappedOther) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) && + unwrappedOther.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.SequenceEqual.MakeGenericMethod(itemType), + [unwrappedSpan, unwrappedOther]); + } + } + else if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpanAndComparer)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var other = arguments[1]; + var comparer = arguments[2]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + TryUnwrapSpanImplicitCast(other, out var unwrappedOther) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) && + unwrappedOther.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.SequenceEqualWithComparer.MakeGenericMethod(itemType), + [unwrappedSpan, unwrappedOther, comparer]); + } + } + + return node; + } + + // Erase implicit casts to ReadOnlySpan and Span + static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result) + { + if (expression is MethodCallExpression + { + Method: + { + Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType + } + } methodCallExpression + && implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition + && (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>))) + { + result = methodCallExpression.Arguments[0]; + return true; + } + + result = null; + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs index b58fc49c963..d302dd05cb4 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs @@ -26,8 +26,8 @@ internal static class LinqExpressionPreprocessor { public static Expression Preprocess(Expression expression) { + expression = ClrCompatExpressionRewriter.Rewrite(expression); expression = PartialEvaluator.EvaluatePartially(expression); - expression = ClrCompatExpressionPreprocessor.Preprocess(expression); return expression; } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs index d7a61ab07d5..54ef243728b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs @@ -13,12 +13,59 @@ * limitations under the License. */ +using System; using System.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Misc { internal static class MethodInfoExtensions { + public static bool Has1GenericArgument(this MethodInfo method, out Type genericArgument) + { + if (method.IsGenericMethod && + method.GetGenericArguments() is var genericArguments && + genericArguments.Length == 1) + { + genericArgument = genericArguments[0]; + return true; + } + + genericArgument = null; + return false; + } + + public static bool Has2Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2) + { + if (method.GetParameters() is var parameters && + parameters.Length == 2) + { + parameter1 = parameters[0]; + parameter2 = parameters[1]; + return true; + } + + parameter1 = null; + parameter2 = null; + return false; + } + + public static bool Has3Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2, out ParameterInfo parameter3) + { + if (method.GetParameters() is var parameters && + parameters.Length == 3) + { + parameter1 = parameters[0]; + parameter2 = parameters[1]; + parameter3 = parameters[2]; + return true; + } + + parameter1 = null; + parameter2 = null; + parameter3 = null; + return false; + } + public static bool Is(this MethodInfo method, MethodInfo comparand) { if (comparand != null) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs index ccb8f699740..f6d2f758940 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs @@ -248,11 +248,27 @@ public static bool IsNullableOf(this Type type, Type valueType) return type.IsNullable(out var nullableValueType) && nullableValueType == valueType; } + public static bool IsReadOnlySpanOf(this Type type, Type itemType) + { + return + type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) && + type.GetGenericArguments()[0] == itemType; + } + public static bool IsSameAsOrNullableOf(this Type type, Type valueType) { return type == valueType || type.IsNullableOf(valueType); } + public static bool IsSpanOf(this Type type, Type itemType) + { + return + type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(Span<>) && + type.GetGenericArguments()[0] == itemType; + } + public static bool IsSubclassOfOrImplements(this Type type, Type baseTypeOrInterface) { return diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs index 0ae3e99ca4a..a10f2a67531 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs @@ -59,6 +59,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __cast; private static readonly MethodInfo __concat; private static readonly MethodInfo __contains; + private static readonly MethodInfo __containsWithComparer; private static readonly MethodInfo __count; private static readonly MethodInfo __countWithPredicate; private static readonly MethodInfo __defaultIfEmpty; @@ -150,6 +151,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __selectManyWithSelectorTakingIndex; private static readonly MethodInfo __selectWithSelectorTakingIndex; private static readonly MethodInfo __sequenceEqual; + private static readonly MethodInfo __sequenceEqualWithComparer; private static readonly MethodInfo __single; private static readonly MethodInfo __singleOrDefault; private static readonly MethodInfo __singleOrDefaultWithPredicate; @@ -226,6 +228,7 @@ static EnumerableMethod() __cast = ReflectionInfo.Method((IEnumerable source) => source.Cast()); __concat = ReflectionInfo.Method((IEnumerable first, IEnumerable second) => first.Concat(second)); __contains = ReflectionInfo.Method((IEnumerable source, object value) => source.Contains(value)); + __containsWithComparer = ReflectionInfo.Method((IEnumerable source, object value, IEqualityComparer comparer) => source.Contains(value, comparer)); __count = ReflectionInfo.Method((IEnumerable source) => source.Count()); __countWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Count(predicate)); __defaultIfEmpty = ReflectionInfo.Method((IEnumerable source) => source.DefaultIfEmpty()); @@ -317,6 +320,7 @@ static EnumerableMethod() __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); __selectWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Select(selector)); __sequenceEqual = ReflectionInfo.Method((IEnumerable first, IEnumerable second) => first.SequenceEqual(second)); + __sequenceEqualWithComparer = ReflectionInfo.Method((IEnumerable first, IEnumerable second, IEqualityComparer comparer) => first.SequenceEqual(second, comparer)); __single = ReflectionInfo.Method((IEnumerable source) => source.Single()); __singleOrDefault = ReflectionInfo.Method((IEnumerable source) => source.SingleOrDefault()); __singleOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.SingleOrDefault(predicate)); @@ -392,6 +396,7 @@ static EnumerableMethod() public static MethodInfo Cast => __cast; public static MethodInfo Concat => __concat; public static MethodInfo Contains => __contains; + public static MethodInfo ContainsWithComparer => __containsWithComparer; public static MethodInfo Count => __count; public static MethodInfo CountWithPredicate => __countWithPredicate; public static MethodInfo DefaultIfEmpty => __defaultIfEmpty; @@ -483,6 +488,7 @@ static EnumerableMethod() public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; public static MethodInfo SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex; public static MethodInfo SequenceEqual => __sequenceEqual; + public static MethodInfo SequenceEqualWithComparer => __sequenceEqualWithComparer; public static MethodInfo Single => __single; public static MethodInfo SingleOrDefault => __singleOrDefault; public static MethodInfo SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MemoryExtensionsMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MemoryExtensionsMethod.cs new file mode 100644 index 00000000000..59aa2432c5e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MemoryExtensionsMethod.cs @@ -0,0 +1,150 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class MemoryExtensionsMethod +{ + // private static fields + private static readonly MethodInfo __containsWithReadOnlySpanAndValue; + private static readonly MethodInfo __containsWithReadOnlySpanAndValueAndComparer; + private static readonly MethodInfo __containsWithSpanAndValue; + private static readonly MethodInfo __sequenceEqualWithReadOnlySpanAndReadOnlySpan; + private static readonly MethodInfo __sequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer; + private static readonly MethodInfo __sequenceEqualWithSpanAndReadOnlySpan; + private static readonly MethodInfo __sequenceEqualWithSpanAndReadOnlySpanAndComparer; + + // static constructor + static MemoryExtensionsMethod() + { + __containsWithReadOnlySpanAndValue = GetContainsWithReadOnlySpanAndValueMethod(); + __containsWithReadOnlySpanAndValueAndComparer = GetContainsWithReadOnlySpanAndValueAndComparerMethod(); + __containsWithSpanAndValue = GetContainsWithSpanAndValueMethod(); + __sequenceEqualWithReadOnlySpanAndReadOnlySpan = GetSequenceEqualWithReadOnlySpanAndReadOnlySpan(); + __sequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer = GetSequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer(); + __sequenceEqualWithSpanAndReadOnlySpan = GetSequenceEqualWithSpanAndReadOnlySpan(); + __sequenceEqualWithSpanAndReadOnlySpanAndComparer = GetSequenceEqualWithSpanAndReadOnlySpanAndComparer(); + } + + // public static properties + public static MethodInfo ContainsWithReadOnlySpanAndValue => __containsWithReadOnlySpanAndValue; + public static MethodInfo ContainsWithReadOnlySpanAndValueAndComparer => __containsWithReadOnlySpanAndValueAndComparer; + public static MethodInfo ContainsWithSpanAndValue => __containsWithSpanAndValue; + public static MethodInfo SequenceEqualWithReadOnlySpanAndReadOnlySpan => __sequenceEqualWithReadOnlySpanAndReadOnlySpan; + public static MethodInfo SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer => __sequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer; + public static MethodInfo SequenceEqualWithSpanAndReadOnlySpan => __sequenceEqualWithSpanAndReadOnlySpan; + public static MethodInfo SequenceEqualWithSpanAndReadOnlySpanAndComparer => __sequenceEqualWithSpanAndReadOnlySpanAndComparer; + + // private static methods + private static MethodInfo GetContainsWithReadOnlySpanAndValueMethod() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "Contains" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var valueParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + valueParameter.ParameterType == itemType); + } + + private static MethodInfo GetContainsWithReadOnlySpanAndValueAndComparerMethod() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "Contains" && + m.Has1GenericArgument(out var itemType) && + m.Has3Parameters(out var spanParameter, out var valueParameter, out var comparerParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + valueParameter.ParameterType == itemType && + comparerParameter.ParameterType == typeof(IEqualityComparer<>).MakeGenericType(itemType)); + } + + private static MethodInfo GetContainsWithSpanAndValueMethod() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "Contains" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var valueParameter) && + spanParameter.ParameterType.IsSpanOf(itemType) && + valueParameter.ParameterType == itemType); + } + + private static MethodInfo GetSequenceEqualWithReadOnlySpanAndReadOnlySpan() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var otherParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType)); + } + + private static MethodInfo GetSequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has3Parameters(out var spanParameter, out var otherParameter, out var comparerParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType) && + comparerParameter.ParameterType == typeof(IEqualityComparer<>).MakeGenericType(itemType)); + } + + private static MethodInfo GetSequenceEqualWithSpanAndReadOnlySpan() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var otherParameter) && + spanParameter.ParameterType.IsSpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType)); + } + + private static MethodInfo GetSequenceEqualWithSpanAndReadOnlySpanAndComparer() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has3Parameters(out var spanParameter, out var otherParameter, out var comparerParameter) && + spanParameter.ParameterType.IsSpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType) && + comparerParameter.ParameterType == typeof(IEqualityComparer<>).MakeGenericType(itemType)); + } +} diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs index aa694c827fd..470f8828777 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs @@ -15,15 +15,20 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Linq.Expressions; using System.Reflection; +using FluentAssertions; using MongoDB.Bson; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; using MongoDB.Driver.TestHelpers; using Xunit; namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; +#if NET6_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER public class CSharp5749Tests : LinqIntegrationTest { public CSharp5749Tests(ClassFixture fixture) @@ -31,90 +36,96 @@ public CSharp5749Tests(ClassFixture fixture) { } -#if NET6_0_OR_GREATER - [Fact] - public void MemoryExtension_contains_in_where_should_work() + public void MemoryExtensions_Contains_in_Where_should_work() { + var collection = Fixture.Collection; var names = new[] { "Two", "Three" }; - var queryable = MangledQueryable().Where(x => names.Contains(x.Name)); + var queryable = collection.AsQueryable().Where(Rewrite((C x) => names.Contains(x.Name))); - Assert.Equal(2, queryable.Count()); + var results = queryable.ToArray(); + results.Select(x => x.Id).Should().Equal(2, 3); } [Fact] - public void MemoryExtension_contains_in_single_should_work() + public void MemoryExtensions_Contains_in_Single_should_work() { + var collection = Fixture.Collection; var names = new[] { "Two" }; - var actual = MangledQueryable().Single(x => names.Contains(x.Name)); + var result = collection.AsQueryable().Single(Rewrite((C x) => names.Contains(x.Name))); - Assert.Equal("Two", actual.Name); + result.Id.Should().Be(2); } [Fact] - public void MemoryExtension_contains_in_any_should_work() + public void MemoryExtensions_Contains_in_Any_should_work() { + var collection = Fixture.Collection; var ids = new[] { 2 }; - var actual = MangledQueryable().Any(x => ids.Contains(x.Id)); + var result = collection.AsQueryable().Any(Rewrite((C x) => ids.Contains(x.Id))); - Assert.True(actual); + result.Should().BeTrue(); } [Fact] - public void MemoryExtension_contains_in_count_should_work() + public void MemoryExtensions_Contains_in_Count_should_work() { + var collection = Fixture.Collection; var ids = new[] { 2 }; - var actual = MangledQueryable().Count(x => ids.Contains(x.Id)); + var result = collection.AsQueryable().Count(Rewrite((C x) => ids.Contains(x.Id))); - Assert.Equal(1, actual); + result.Should().Be(1); } [Fact] - public void MemoryExtension_sequenceequal_in_where_should_work() + public void MemoryExtensions_SequenceEqual_in_Where_should_work() { + var collection = Fixture.Collection; var ratings = new[] { 1, 9, 6 }; - var queryable = MangledQueryable().Where(x => ratings.SequenceEqual(x.Ratings)); + var queryable = collection.AsQueryable().Where(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); - Assert.Equal(1, queryable.Count()); + var results = queryable.ToArray(); + results.Select(x => x.Id).Should().Equal(3); } [Fact] - public void MemoryExtension_sequenceequal_in_single_should_work() + public void MemoryExtensions_SequenceEqual_in_Single_should_work() { + var collection = Fixture.Collection; var ratings = new[] { 1, 9, 6 }; - var actual = MangledQueryable().Single(x => ratings.SequenceEqual(x.Ratings)); + var result = collection.AsQueryable().Single(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); - Assert.Equal("Three", actual.Name); + result.Id.Should().Be(3); } [Fact] - public void MemoryExtension_sequenceequas_in_any_should_work() + public void MemoryExtensions_SequenceEqual_in_Any_should_work() { + var collection = Fixture.Collection; var ratings = new[] { 1, 2, 3, 4, 5 }; - var actual = MangledQueryable().Any(x => ratings.SequenceEqual(x.Ratings)); + var result = collection.AsQueryable().Any(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); - Assert.True(actual); + result.Should().BeTrue(); } [Fact] - public void MemoryExtension_sequenceequas_in_count_should_work() + public void MemoryExtensions_SequenceEqual_in_Count_should_work() { + var collection = Fixture.Collection; var ratings = new[] { 3, 4, 5, 6, 7 }; - var actual = MangledQueryable().Count(x => ratings.SequenceEqual(x.Ratings)); + var result = collection.AsQueryable().Count(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); - Assert.Equal(1, actual); + result.Should().Be(1); } -#endif - public class C { public int Id { get; set; } @@ -126,143 +137,121 @@ public sealed class ClassFixture : MongoCollectionFixture { protected override IEnumerable InitialData => [ - BsonDocument.Parse("{ _id : 1, Name: \"One\", Ratings: [1,2,3,4,5] }"), - BsonDocument.Parse("{ _id : 2, Name: \"Two\", Ratings: [3,4,5,6,7] }"), - BsonDocument.Parse("{ _id : 3, Name: \"Three\", Ratings: [1,9,6] }") + BsonDocument.Parse("{ _id : 1, Name : \"One\", Ratings : [1, 2, 3, 4, 5] }"), + BsonDocument.Parse("{ _id : 2, Name : \"Two\", Ratings : [3, 4, 5, 6, 7] }"), + BsonDocument.Parse("{ _id : 3, Name : \"Three\", Ratings : [1, 9, 6] }") ]; } - public IQueryable MangledQueryable() - => new ManglingQueryable(Fixture.Collection.AsQueryable()); - - public class ManglingQueryProvider(IQueryProvider innerProvider) : IQueryProvider + private Expression> Rewrite(Expression> predicate) { - public IQueryable CreateQuery(Expression expression) - => innerProvider.CreateQuery(EnumerableToMemoryExtensionsMangler.Mangle(expression)); - - public IQueryable CreateQuery(Expression expression) - { - var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(expression); - return innerProvider.CreateQuery(mangledExpression); - } - - public object Execute(Expression expression) - { - var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(expression); - return innerProvider.Execute(mangledExpression); - } - - public TResult Execute(Expression expression) - { - var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(expression); - return innerProvider.Execute(mangledExpression); - } + return (Expression>)new EnumerableToMemoryExtensionsRewriter().Visit(predicate); } - public class ManglingQueryable(IQueryable innerQueryable) : IOrderedQueryable + public class EnumerableToMemoryExtensionsRewriter : ExpressionVisitor { - private readonly ManglingQueryProvider _provider = new(innerQueryable.Provider); - - public Type ElementType => innerQueryable.ElementType; - public Expression Expression => innerQueryable.Expression; - public IQueryProvider Provider => _provider; - - public IEnumerator GetEnumerator() + protected override Expression VisitMethodCall(MethodCallExpression node) { - var mangledExpression = EnumerableToMemoryExtensionsMangler.Mangle(Expression); - var query = innerQueryable.Provider.CreateQuery(mangledExpression); - return query.GetEnumerator(); - } + node = (MethodCallExpression)base.VisitMethodCall(node); - System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() - => GetEnumerator(); - } + var method = node.Method; + var arguments = node.Arguments; - public class EnumerableToMemoryExtensionsMangler : ExpressionVisitor - { - private static readonly EnumerableToMemoryExtensionsMangler s_instance = new(); - - public static Expression Mangle(Expression expression) - => s_instance.Visit(expression); + return method.Name switch + { + "Contains" => VisitContainsMethod(node, method, arguments), + "SequenceEqual" => VisitSequenceEqualMethod(node, method, arguments), + _ => node + }; - protected override Expression VisitMethodCall(MethodCallExpression node) - { - // Check if this is Enumerable.Contains or Enumerable.SequenceEqual - if (node.Method.DeclaringType == typeof(Enumerable) && - node.Method.IsGenericMethod && - (node.Method.Name == nameof(Enumerable.Contains) || - node.Method.Name == nameof(Enumerable.SequenceEqual))) + static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) { - var elementType = node.Method.GetGenericArguments()[0]; + if (method.Is(EnumerableMethod.Contains)) + { + var itemType = method.GetGenericArguments().Single(); + var source = arguments[0]; + var value = arguments[1]; + + if (source.Type.IsArray) + { + var readOnlySpan = ImplicitCastArrayToSpan(source, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue.MakeGenericMethod(itemType), + [readOnlySpan, value]); + } + } + else if (method.Is(EnumerableMethod.ContainsWithComparer)) + { + var itemType = method.GetGenericArguments().Single(); + var source = arguments[0]; + var value = arguments[1]; + var comparer = arguments[2]; + + if (source.Type.IsArray) + { + var readOnlySpan = ImplicitCastArrayToSpan(source, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer.MakeGenericMethod(itemType), + [readOnlySpan, value, comparer]); + } + } - // Get ReadOnlySpan.op_Implicit method - var spanType = typeof(ReadOnlySpan<>).MakeGenericType(elementType); - var opImplicitMethod = spanType.GetMethod("op_Implicit", - BindingFlags.Public | BindingFlags.Static, - null, - [elementType.MakeArrayType()], - null); - if (opImplicitMethod == null) - return base.VisitMethodCall(node); + return node; + } - switch (node.Method.Name) + static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) + { + if (method.Is(EnumerableMethod.SequenceEqual)) + { + var itemType = method.GetGenericArguments().Single(); + var first = arguments[0]; + var second = arguments[1]; + + if (first.Type.IsArray && second.Type.IsArray) + { + var firstReadOnlySpan = ImplicitCastArrayToSpan(first, typeof(ReadOnlySpan<>), itemType); + var secondReadOnlySpan = ImplicitCastArrayToSpan(second, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan.MakeGenericMethod(itemType), + [firstReadOnlySpan, secondReadOnlySpan]); + } + } + else if (method.Is(EnumerableMethod.SequenceEqualWithComparer)) { - case nameof(Enumerable.Contains): - { - var source = node.Arguments[0]; - if (!source.Type.IsArray) - return base.VisitMethodCall(node); - - var containsMethod = typeof(MemoryExtensions) - .GetMethods(BindingFlags.Public | BindingFlags.Static) - .FirstOrDefault(m => - m.Name == nameof(MemoryExtensions.Contains) && - m.IsGenericMethod && - m.GetParameters().Length == 2 && - m.GetParameters()[0].ParameterType.IsGenericType && - m.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == - typeof(ReadOnlySpan<>)) - ?.MakeGenericMethod(elementType) - ?? throw new InvalidOperationException( - "Could not find MemoryExtensions.Contains method."); - - return Expression.Call(containsMethod, - Expression.Call(opImplicitMethod, Visit(source)), - Visit(node.Arguments[1])); - } - - case nameof(Enumerable.SequenceEqual): - { - var first = node.Arguments[0]; - var second = node.Arguments[1]; - if (!first.Type.IsArray || !second.Type.IsArray) - return base.VisitMethodCall(node); - - var sequenceEqualMethod = typeof(MemoryExtensions) - .GetMethods(BindingFlags.Public | BindingFlags.Static) - .FirstOrDefault(m => - m.Name == nameof(MemoryExtensions.SequenceEqual) && - m.IsGenericMethod && - m.GetParameters().Length == 2 && - m.GetParameters()[0].ParameterType.IsGenericType && - m.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == - typeof(ReadOnlySpan<>) && - m.GetParameters()[1].ParameterType.IsGenericType && - m.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == - typeof(ReadOnlySpan<>)) - ?.MakeGenericMethod(elementType) - ?? throw new InvalidOperationException( - "Could not find MemoryExtensions.SequenceEquals method."); - - return Expression.Call(sequenceEqualMethod, - Expression.Call(opImplicitMethod, Visit(first)), - Expression.Call(opImplicitMethod, Visit(second))); - } + var itemType = method.GetGenericArguments().Single(); + var first = arguments[0]; + var second = arguments[1]; + var comparer = arguments[2]; + + if (first.Type.IsArray && second.Type.IsArray) + { + var firstReadOnlySpan = ImplicitCastArrayToSpan(first, typeof(ReadOnlySpan<>), itemType); + var secondReadOnlySpan = ImplicitCastArrayToSpan(second, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan.MakeGenericMethod(itemType), + [firstReadOnlySpan, secondReadOnlySpan, comparer]); + } } + + return node; } - return base.VisitMethodCall(node); + static Expression ImplicitCastArrayToSpan(Expression value, Type spanType, Type itemType) + { + var opImplicitMethod = spanType.MakeGenericType(itemType).GetMethod( + "op_Implicit", + BindingFlags.Public | BindingFlags.Static, + null, + [itemType.MakeArrayType()], + null); + return Expression.Call(opImplicitMethod, value); + } } } } +#endif diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs index cd49af1955f..fa01543be13 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs @@ -1180,7 +1180,7 @@ public void TestWhereXNotEquals1Not() private void Assert(Expression> expression, int expectedCount, string expectedFilter) { - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index 0869d70822e..c96a2a96b9a 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1150,7 +1150,7 @@ private void Assert(IMongoCollection collection, Expressio public List Assert(IMongoCollection collection, Expression> filter, int expectedCount, BsonDocument expectedFilter) { - filter = (Expression>)PartialEvaluator.EvaluatePartially(filter); + filter = (Expression>)LinqExpressionPreprocessor.Preprocess(filter); var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single();