diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index cf36528829b..a88ffb55ea1 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -19,6 +19,7 @@ using MongoDB.Bson.Serialization; using MongoDB.Bson.Serialization.Serializers; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators @@ -123,6 +124,15 @@ public static AggregationExpression TranslateLambdaBody( asRoot ? context.CreateSymbolWithVarName(parameterExpression, varName: "ROOT", parameterSerializer, isCurrent: true) : context.CreateSymbol(parameterExpression, parameterSerializer, isCurrent: false); + + return TranslateLambdaBody(context, lambdaExpression, parameterSymbol); + } + + public static AggregationExpression TranslateLambdaBody( + TranslationContext context, + LambdaExpression lambdaExpression, + Symbol parameterSymbol) + { var lambdaContext = context.WithSymbol(parameterSymbol); var translatedBody = Translate(lambdaContext, lambdaExpression.Body); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs index a1ea0c4ebbd..a6ff40dbf0d 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/FirstOrLastMethodToAggregationExpressionTranslator.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System.Linq; using System.Linq.Expressions; using System.Reflection; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; @@ -97,12 +98,13 @@ public static AggregationExpression Translate(TranslationContext context, Method if (method.IsOneOf(__withPredicateMethods)) { var predicateLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); - var predicateTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, predicateLambda, itemSerializer, asRoot: false); - var predicateParameter = predicateLambda.Parameters[0]; + var parameterExpression = predicateLambda.Parameters.Single(); + var parameterSymbol = context.CreateSymbol(parameterExpression, itemSerializer, isCurrent: false); + var predicateTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, predicateLambda, parameterSymbol); sourceAst = AstExpression.Filter( input: sourceAst, cond: predicateTranslation.Ast, - @as: predicateParameter.Name); + @as: parameterSymbol.Var.Name); } AstExpression ast; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs index fb5337ecc5d..1eff6481de5 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/WhereMethodToAggregationExpressionTranslator.cs @@ -59,7 +59,7 @@ public static AggregationExpression Translate(TranslationContext context, Method var ast = AstExpression.Filter( sourceTranslation.Ast, predicateTranslation.Ast, - predicateParameter.Name, + @as: predicateSymbol.Var.Name, limitTranslation?.Ast); var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5258Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5258Tests.cs new file mode 100644 index 00000000000..bdb3bf5a05f --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5258Tests.cs @@ -0,0 +1,72 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using FluentAssertions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5258Tests : Linq3IntegrationTest + { + [Fact] + public void Select_First_with_predicate_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Select(_x => _x.List.First(_y => _y > 2)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $arrayElemAt : [{ $filter : { input : '$List', as : 'v__0', cond : { $gt : ['$$v__0', 2] } } }, 0] }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Should().Equal(3, 4); + } + + [Fact] + public void Select_Where_should_work() + { + var collection = GetCollection(); + + var queryable = collection.AsQueryable() + .Select(_x => _x.List.Where(_y => _y % 2 == 0)); + + var stages = Translate(collection, queryable); + AssertStages(stages, "{ $project : { _v : { $filter : { input : '$List', as : 'v__0', cond : { $eq : [{ $mod : ['$$v__0', 2] }, 0] } } }, _id : 0 } }"); + + var results = queryable.ToList(); + results.Count.Should().Be(2); + results[0].Should().Equal(); + results[1].Should().Equal(2, 4, 6); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection, + new C { Id = 1, List = [1, 3, 5] }, + new C { Id = 2, List = [2, 4, 6] }); + return collection; + } + + private class C + { + public int Id { get; set; } + public int[] List { get; set; } + } + } +}