Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSHARP-4872: Add support for Append in aggregate expressions. #1569

Merged
merged 1 commit into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ internal static class EnumerableMethod
private static readonly MethodInfo __all;
private static readonly MethodInfo __any;
private static readonly MethodInfo __anyWithPredicate;
private static readonly MethodInfo __append;
private static readonly MethodInfo __averageDecimal;
private static readonly MethodInfo __averageDecimalWithSelector;
private static readonly MethodInfo __averageDouble;
Expand Down Expand Up @@ -138,6 +139,7 @@ internal static class EnumerableMethod
private static readonly MethodInfo __ofType;
private static readonly MethodInfo __orderBy;
private static readonly MethodInfo __orderByDescending;
private static readonly MethodInfo __prepend;
private static readonly MethodInfo __range;
private static readonly MethodInfo __repeat;
private static readonly MethodInfo __reverse;
Expand Down Expand Up @@ -195,6 +197,7 @@ static EnumerableMethod()
__all = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.All(predicate));
__any = ReflectionInfo.Method((IEnumerable<object> source) => source.Any());
__anyWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.Any(predicate));
__append = ReflectionInfo.Method((IEnumerable<object> source, object element) => source.Append(element));
__averageDecimal = ReflectionInfo.Method((IEnumerable<decimal> source) => source.Average());
__averageDecimalWithSelector = ReflectionInfo.Method((IEnumerable<object> source, Func<object, decimal> selector) => source.Average(selector));
__averageDouble = ReflectionInfo.Method((IEnumerable<double> source) => source.Average());
Expand Down Expand Up @@ -301,6 +304,7 @@ static EnumerableMethod()
__ofType = ReflectionInfo.Method((IEnumerable source) => source.OfType<object>());
__orderBy = ReflectionInfo.Method((IEnumerable<object> source, Func<object, object> keySelector) => source.OrderBy(keySelector));
__orderByDescending = ReflectionInfo.Method((IEnumerable<object> source, Func<object, object> keySelector) => source.OrderByDescending(keySelector));
__prepend = ReflectionInfo.Method((IEnumerable<object> source, object element) => source.Prepend(element));
__range = ReflectionInfo.Method((int start, int count) => Enumerable.Range(start, count));
__repeat = ReflectionInfo.Method((object element, int count) => Enumerable.Repeat(element, count));
__reverse = ReflectionInfo.Method((IEnumerable<object> source) => source.Reverse());
Expand Down Expand Up @@ -357,6 +361,7 @@ static EnumerableMethod()
public static MethodInfo All => __all;
public static MethodInfo Any => __any;
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
public static MethodInfo Append => __append;
public static MethodInfo AverageDecimal => __averageDecimal;
public static MethodInfo AverageDecimalWithSelector => __averageDecimalWithSelector;
public static MethodInfo AverageDouble => __averageDouble;
Expand Down Expand Up @@ -463,6 +468,7 @@ static EnumerableMethod()
public static MethodInfo OfType => __ofType;
public static MethodInfo OrderBy => __orderBy;
public static MethodInfo OrderByDescending => __orderByDescending;
public static MethodInfo Prepend => __prepend;
public static MethodInfo Range => __range;
public static MethodInfo Repeat => __repeat;
public static MethodInfo Reverse => __reverse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ internal static class QueryableMethod
private static readonly MethodInfo __all;
private static readonly MethodInfo __any;
private static readonly MethodInfo __anyWithPredicate;
private static readonly MethodInfo __append;
private static readonly MethodInfo __asQueryable;
private static readonly MethodInfo __averageDecimal;
private static readonly MethodInfo __averageDecimalWithSelector;
Expand Down Expand Up @@ -86,6 +87,7 @@ internal static class QueryableMethod
private static readonly MethodInfo __ofType;
private static readonly MethodInfo __orderBy;
private static readonly MethodInfo __orderByDescending;
private static readonly MethodInfo __prepend;
private static readonly MethodInfo __reverse;
private static readonly MethodInfo __select;
private static readonly MethodInfo __selectMany;
Expand Down Expand Up @@ -136,6 +138,7 @@ static QueryableMethod()
__all = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.All(predicate));
__any = ReflectionInfo.Method((IQueryable<object> source) => source.Any());
__anyWithPredicate = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, bool>> predicate) => source.Any(predicate));
__append = ReflectionInfo.Method((IQueryable<object> source, object element) => source.Append(element));
__asQueryable = ReflectionInfo.Method((IEnumerable<object> source) => source.AsQueryable());
__averageDecimal = ReflectionInfo.Method((IQueryable<decimal> source) => source.Average());
__averageDecimalWithSelector = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, decimal>> selector) => source.Average(selector));
Expand Down Expand Up @@ -192,6 +195,7 @@ static QueryableMethod()
__ofType = ReflectionInfo.Method((IQueryable source) => source.OfType<object>());
__orderBy = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object>> keySelector) => source.OrderBy(keySelector));
__orderByDescending = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object>> keySelector) => source.OrderByDescending(keySelector));
__prepend = ReflectionInfo.Method((IQueryable<object> source, object element) => source.Prepend(element));
__reverse = ReflectionInfo.Method((IQueryable<object> source) => source.Reverse());
__select = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, object>> selector) => source.Select(selector));
__selectMany = ReflectionInfo.Method((IQueryable<object> source, Expression<Func<object, IEnumerable<object>>> selector) => source.SelectMany(selector));
Expand Down Expand Up @@ -241,6 +245,7 @@ static QueryableMethod()
public static MethodInfo All => __all;
public static MethodInfo Any => __any;
public static MethodInfo AnyWithPredicate => __anyWithPredicate;
public static MethodInfo Append => __append;
public static MethodInfo AsQueryable => __asQueryable;
public static MethodInfo AverageDecimal => __averageDecimal;
public static MethodInfo AverageDecimalWithSelector => __averageDecimalWithSelector;
Expand Down Expand Up @@ -297,6 +302,7 @@ static QueryableMethod()
public static MethodInfo OfType => __ofType;
public static MethodInfo OrderBy => __orderBy;
public static MethodInfo OrderByDescending => __orderByDescending;
public static MethodInfo Prepend => __prepend;
public static MethodInfo Reverse => __reverse;
public static MethodInfo Select => __select;
public static MethodInfo SelectMany => __selectMany;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ public static AggregationExpression Translate(TranslationContext context, Method
case "AddYears":
return DateTimeAddOrSubtractMethodToAggregationExpressionTranslator.Translate(context, expression);

case "Append":
case "Prepend":
return AppendOrPrependMethodToAggregationExpressionTranslator.Translate(context, expression);

case "Bottom":
case "BottomN":
case "FirstN":
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq.Expressions;
using System.Reflection;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
internal static class AppendOrPrependMethodToAggregationExpressionTranslator
{
private static readonly MethodInfo[] __appendOrPrependMethods =
{
EnumerableMethod.Append,
EnumerableMethod.Prepend,
QueryableMethod.Append,
QueryableMethod.Prepend
};

private static readonly MethodInfo[] __appendMethods =
{
EnumerableMethod.Append,
QueryableMethod.Append
};

public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
{
var method = expression.Method;
var arguments = expression.Arguments;

if (method.IsOneOf(__appendOrPrependMethods))
{
var sourceExpression = arguments[0];
var elementExpression = arguments[1];

var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);

AggregationExpression elementTranslation;
if (elementExpression is ConstantExpression elementConstantExpression)
{
var value = elementConstantExpression.Value;
var serializedValue = SerializationHelper.SerializeValue(itemSerializer, value);
elementTranslation = new AggregationExpression(elementExpression, AstExpression.Constant(serializedValue), itemSerializer);
}
else
{
elementTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, elementExpression);
if (!elementTranslation.Serializer.Equals(itemSerializer))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check for compatible serializers instead? For example if collection is of long and appended item is int.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should never happen. By the time we get here both serializers will always have the exact same ValueType.

The only question is whether they are serialized the same way.

{
throw new ExpressionNotSupportedException(expression, because: "argument serializers are not compatible");
}
}

var ast = method.IsOneOf(__appendMethods) ?
AstExpression.ConcatArrays(sourceTranslation.Ast, AstExpression.ComputedArray(elementTranslation.Ast)) :
AstExpression.ConcatArrays(AstExpression.ComputedArray(elementTranslation.Ast), sourceTranslation.Ast);
var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);

return new AggregationExpression(expression, ast, serializer);
}

throw new ExpressionNotSupportedException(expression);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq;
using FluentAssertions;
using MongoDB.TestHelpers.XunitExtensions;
using Xunit;

namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira
{
public class CSharp4872Tests : Linq3IntegrationTest
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I be able to Append int to collection of longs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but...

The compiler will insert a Convert from int to long so in the end you are actually appending a long.

{
[Theory]
[ParameterAttributeData]
public void Append_constant_should_work(
[Values(false, true)] bool withNestedAsQueryable)
{
var collection = GetCollection();

var queryable = withNestedAsQueryable ?
collection.AsQueryable().Select(x => x.A.AsQueryable().Append(4).ToList()) :
collection.AsQueryable().Select(x => x.A.Append(4).ToList());

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _v : { $concatArrays : ['$A', [4]] }, _id : 0 } }");

var result = queryable.Single();
result.Should().Equal(1, 2, 3, 4);
}

[Theory]
[ParameterAttributeData]
public void Append_expression_should_work(
[Values(false, true)] bool withNestedAsQueryable)
{
var collection = GetCollection();

var queryable = withNestedAsQueryable ?
collection.AsQueryable().Select(x => x.A.AsQueryable().Append(x.B).ToList()) :
collection.AsQueryable().Select(x => x.A.Append(x.B).ToList());

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _v : { $concatArrays : ['$A', ['$B']] }, _id : 0 } }");

var result = queryable.Single();
result.Should().Equal(1, 2, 3, 4);
}

[Theory]
[ParameterAttributeData]
public void Prepend_constant_should_work(
[Values(false, true)] bool withNestedAsQueryable)
{
var collection = GetCollection();

var queryable = withNestedAsQueryable ?
collection.AsQueryable().Select(x => x.A.AsQueryable().Prepend(4).ToList()) :
collection.AsQueryable().Select(x => x.A.Prepend(4).ToList());

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _v : { $concatArrays : [[4], '$A'] }, _id : 0 } }");

var result = queryable.Single();
result.Should().Equal(4, 1, 2, 3);
}

[Theory]
[ParameterAttributeData]
public void Prepend_expression_should_work(
[Values(false, true)] bool withNestedAsQueryable)
{
var collection = GetCollection();

var queryable = withNestedAsQueryable ?
collection.AsQueryable().Select(x => x.A.AsQueryable().Prepend(x.B).ToList()) :
collection.AsQueryable().Select(x => x.A.Prepend(x.B).ToList());

var stages = Translate(collection, queryable);
AssertStages(stages, "{ $project : { _v : { $concatArrays : [['$B'], '$A'] }, _id : 0 } }");

var result = queryable.Single();
result.Should().Equal(4, 1, 2, 3);
}

private IMongoCollection<C> GetCollection()
{
var collection = GetCollection<C>("test");
CreateCollection(
collection,
new C { Id = 1, A = [1, 2, 3], B = 4 });
return collection;
}

private class C
{
public int Id { get; set; }
public int[] A { get; set; }
public int B { get; set; }
}
}
}