Skip to content

Commit

Permalink
CSHARP-4872: Add support for Append in aggregate expressions.
Browse files Browse the repository at this point in the history
  • Loading branch information
rstam committed Dec 7, 2024
1 parent fe331f9 commit 19c9138
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ internal static class EnumerableMethod
private static readonly MethodInfo __all;
private static readonly MethodInfo __any;
private static readonly MethodInfo __anyWithPredicate;
private static readonly MethodInfo __append;
private static readonly MethodInfo __averageDecimal;
private static readonly MethodInfo __averageDecimalWithSelector;
private static readonly MethodInfo __averageDouble;
Expand Down Expand Up @@ -195,6 +196,7 @@ static EnumerableMethod()
__all = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.All(predicate));
__any = ReflectionInfo.Method((IEnumerable<object> source) => source.Any());
__anyWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.Any(predicate));
__append = ReflectionInfo.Method((IEnumerable<object> source, object element) => source.Append(element));
__averageDecimal = ReflectionInfo.Method((IEnumerable<decimal> source) => source.Average());
__averageDecimalWithSelector = ReflectionInfo.Method((IEnumerable<object> source, Func<object, decimal> selector) => source.Average(selector));
__averageDouble = ReflectionInfo.Method((IEnumerable<double> source) => source.Average());
Expand Down Expand Up @@ -357,6 +359,7 @@ static EnumerableMethod()
public static MethodInfo All => __all;
public static MethodInfo Any => __any;
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
public static MethodInfo Append => __append;
public static MethodInfo AverageDecimal => __averageDecimal;
public static MethodInfo AverageDecimalWithSelector => __averageDecimalWithSelector;
public static MethodInfo AverageDouble => __averageDouble;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ internal static class QueryableMethod
private static readonly MethodInfo __all;
private static readonly MethodInfo __any;
private static readonly MethodInfo __anyWithPredicate;
private static readonly MethodInfo __append;
private static readonly MethodInfo __asQueryable;
private static readonly MethodInfo __averageDecimal;
private static readonly MethodInfo __averageDecimalWithSelector;
Expand Down Expand Up @@ -135,6 +136,7 @@ static QueryableMethod()
__all = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.All(predicate));
__any = ReflectionInfo.Method((IQueryable<object> source) => source.Any());
__anyWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.Any(predicate));
__append = ReflectionInfo.Method((IQueryable<object> source, object element) => source.Append(element));
__asQueryable = ReflectionInfo.Method((IEnumerable<object> source) => source.AsQueryable());
__averageDecimal = ReflectionInfo.Method((IQueryable<decimal> source) => source.Average());
__averageDecimalWithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, decimal>> selector) => source.Average(selector));
Expand Down Expand Up @@ -239,6 +241,7 @@ static QueryableMethod()
public static MethodInfo All => __all;
public static MethodInfo Any => __any;
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
public static MethodInfo Append => __append;
public static MethodInfo AsQueryable => __asQueryable;
public static MethodInfo AverageDecimal => __averageDecimal;
public static MethodInfo AverageDecimalWithSelector => __averageDecimalWithSelector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
case "Aggregate": return AggregateMethodToAggregationExpressionTranslator.Translate(context, expression);
case "All": return AllMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Any": return AnyMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Append": return AppendMethodToAggregationExpressionTranslator.Translate(context, expression);
case "AsQueryable": return AsQueryableMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Average": return AverageMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Ceiling": return CeilingMethodToAggregationExpressionTranslator.Translate(context, expression);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq.Expressions;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
internal static class AppendMethodToAggregationExpressionTranslator
{
public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
{
if (EnumerableAppendMethodToAggregationExpressionTranslator.CanTranslate(expression))
{
return EnumerableAppendMethodToAggregationExpressionTranslator.Translate(context, expression);
}

throw new ExpressionNotSupportedException(expression);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq.Expressions;
using System.Reflection;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
internal static class EnumerableAppendMethodToAggregationExpressionTranslator
{
private static readonly MethodInfo[] __appendMethods =
{
EnumerableMethod.Append,
QueryableMethod.Append
};

public static bool CanTranslate(MethodCallExpression expression)
=> expression.Method.IsOneOf(__appendMethods);

public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
{
var method = expression.Method;
var arguments = expression.Arguments;

if (method.IsOneOf(__appendMethods))
{
var firstExpression = arguments[0];
var secondExpression = arguments[1];

var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression);
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, firstTranslation);
var itemSerializer = ArraySerializerHelper.GetItemSerializer(firstTranslation.Serializer);

AggregationExpression secondTranslation;
if (secondExpression is ConstantExpression secondConstantExpression)
{
var value = secondConstantExpression.Value;
var serializedValue = SerializationHelper.SerializeValue(itemSerializer, value);
secondTranslation = new AggregationExpression(secondExpression, AstExpression.Constant(serializedValue), itemSerializer);
}
else
{
secondTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, secondExpression);
if (!secondTranslation.Serializer.Equals(itemSerializer))
{
throw new ExpressionNotSupportedException(expression, because: "argument serializers are not compatible");
}
}

var ast = AstExpression.ConcatArrays(firstTranslation.Ast, AstExpression.ComputedArray(secondTranslation.Ast));
var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);

return new AggregationExpression(expression, ast, serializer);
}

throw new ExpressionNotSupportedException(expression);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq;
using FluentAssertions;
using MongoDB.TestHelpers.XunitExtensions;
using Xunit;

namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
{
public class CSharp4872Tests : Linq3IntegrationTest
{
[Theory]
[ParameterAttributeData]
public void Append_constant_should_work(
[Values(false, true)] bool withNestedAsQueryable)
{
var collection = GetCollection();

var queryable = withNestedAsQueryable ?
collection.AsQueryable().Select(x => x.A.AsQueryable().Append(4).ToList()) :
collection.AsQueryable().Select(x => x.A.Append(4).ToList());

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _v : { $concatArrays : ['$A', [4]] }, _id : 0 } }");

var result = queryable.Single();
result.Should().Equal(1, 2, 3, 4);
}

[Theory]
[ParameterAttributeData]
public void Append_expression_should_work(
[Values(false, true)] bool withNestedAsQueryable)
{
var collection = GetCollection();

var queryable = withNestedAsQueryable ?
collection.AsQueryable().Select(x => x.A.AsQueryable().Append(x.B).ToList()) :
collection.AsQueryable().Select(x => x.A.Append(x.B).ToList());

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _v : { $concatArrays : ['$A', ['$B']] }, _id : 0 } }");

var result = queryable.Single();
result.Should().Equal(1, 2, 3, 4);
}

private IMongoCollection<C> GetCollection()
{
var collection = GetCollection<C>("test");
CreateCollection(
collection,
new C { Id = 1, A = [1, 2, 3], B = 4 });
return collection;
}

private class C
{
public int Id { get; set; }
public int[] A { get; set; }
public int B { get; set; }
}
}
}

0 comments on commit 19c9138

Please sign in to comment.