From 4b673872b88fa98b3f7d57574014b8ea49ff6888 Mon Sep 17 00:00:00 2001 From: Oleksandr Poliakov Date: Tue, 18 Mar 2025 17:06:23 -0700 Subject: [PATCH] CSHARP-4779: Support Dictionary(IEnumerable> collection) constructor in LINQ3 --- .../Ast/Optimizers/AstSimplifier.cs | 57 +++++++ .../Reflection/DictionaryConstructor.cs | 34 ++++ ...essionToAggregationExpressionTranslator.cs | 104 ++++++++++++ ...essionToAggregationExpressionTranslator.cs | 4 + ...nToAggregationExpressionTranslatorTests.cs | 155 ++++++++++++++++++ 5 files changed, 354 insertions(+) create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs create mode 100644 src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs create mode 100644 tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs index 206fd8b308c..5b1fd53d478 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Linq; using MongoDB.Bson; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; @@ -454,8 +455,42 @@ public override AstNode VisitMapExpression(AstMapExpression node) } } + if (node.In is AstComputedDocumentExpression inComputedDocumentExpression && + inComputedDocumentExpression.Fields.All(f => f.Value is AstGetFieldExpression getFieldExpression && getFieldExpression.Input == node.As && getFieldExpression.CanBeConvertedToFieldPath())) + { + + // { $map : { input : { $map : { input : , as : "y", in : { A : "$$y.FieldA" } } }, as: "v", in : { B : '$$v.A' } } } => { $map : { input : , as: "v", in : { B : "$$v.FieldA" } } } + if (node.Input is AstMapExpression inputMapExpression && + inputMapExpression.In is AstComputedDocumentExpression innerInComputedDocumentExpression) + { + var simplified = AstExpression.Map( + inputMapExpression.Input, + inputMapExpression.As, + AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, innerInComputedDocumentExpression.Fields)))); + + return Visit(simplified); + } + + // { $map : { input : [{ A: "$$ROOT.FieldA" }], as : "v", in: { B : "$$v.A" } } } => [{ B : "$FieldA }] + if (node.Input is AstComputedArrayExpression inputArrayExpression && + inputArrayExpression.Items.All(i => i is AstComputedDocumentExpression)) + { + var simplified = AstExpression.ComputedArray(inputArrayExpression.Items.Select(i => + AstExpression.ComputedDocument(inComputedDocumentExpression.Fields.Select(f => RemapField(f, node.As.Name, ((AstComputedDocumentExpression)i).Fields))))); + return Visit(simplified); + } + } + return base.VisitMapExpression(node); + static AstComputedField RemapField(AstComputedField field, string @as, IEnumerable innerFields) + { + var fieldPath = ((AstGetFieldExpression)field.Value).ConvertToFieldPath().Replace($"$${@as}.", string.Empty); + var innerField = innerFields.Single(f => f.Path == fieldPath); + + return AstExpression.ComputedField(field.Path, innerField.Value); + } + static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField) { if (getField.Input is AstGetFieldExpression nestedInputGetField) @@ -574,7 +609,29 @@ arg is AstBinaryExpression argBinaryExpression && return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2); } + // { $arrayToObject : [[{ k : 'A', v : '$A' }, { k : 'B', v : '$B' }]] } => { A : '$A', B : '$B' } + if (node.Operator is AstUnaryOperator.ArrayToObject && + arg is AstComputedArrayExpression computedArrayExpression && + computedArrayExpression.Items.All( + i => i is AstComputedDocumentExpression computedDocumentExpression && + computedDocumentExpression.Fields.FirstOrDefault(f => f.Path == "k")?.Value is AstConstantExpression && + computedDocumentExpression.Fields.Any(f => f.Path == "v")) + ) + { + var fields = computedArrayExpression.Items.Select(KeyValuePairDocumentToComputedField); + return AstExpression.ComputedDocument(fields); + } + return node.Update(arg); + + static AstComputedField KeyValuePairDocumentToComputedField(AstExpression expression) + { + var documentExpression = (AstComputedDocumentExpression)expression; + var keyExpression = documentExpression.Fields.First(f => f.Path == "k").Value; + var valueExpression = documentExpression.Fields.First(f => f.Path == "v").Value; + + return AstExpression.ComputedField(((AstConstantExpression)keyExpression).Value.AsString, valueExpression); + } } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs new file mode 100644 index 00000000000..37be3984270 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs @@ -0,0 +1,34 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class DictionaryConstructor + { + // public static methods + public static bool IsIEnumerableKeyValuePairConstructor(ConstructorInfo ctor) + { + var parameters = ctor.GetParameters(); + return parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) && + enumerableType.IsConstructedGenericType && + enumerableType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..72dd19ccb18 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -0,0 +1,104 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + internal static class NewDictionaryExpressionToAggregationExpressionTranslator + { + public static TranslatedExpression Translate(TranslationContext context, NewExpression expression) + { + var arguments = expression.Arguments; + var collectionExpression = arguments.Single(); + var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression); + + if (collectionTranslation.Serializer is IBsonArraySerializer bsonArraySerializer && + bsonArraySerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + IBsonSerializer keySerializer = null; + IBsonSerializer valueSerializer = null; + AstExpression collectionTranslationAst; + + if (itemSerializationInfo.Serializer is IRepresentationConfigurable { Representation: BsonType.Array }) + { + collectionTranslationAst = collectionTranslation.Ast; + } + else if (itemSerializationInfo.Serializer is IBsonDocumentSerializer itemDocumentSerializer) + { + if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo) || + !itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not provide member serialization info for required fields."); + } + + if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") + { + collectionTranslationAst = collectionTranslation.Ast; + } + else + { + keySerializer = keyMemberSerializationInfo.Serializer; + valueSerializer = valueMemberSerializationInfo.Serializer; + + var pairVar = AstExpression.Var("pair"); + var computedDocumentAst = AstExpression.ComputedDocument([ + AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), + AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) + ]); + collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); + } + } + else + { + throw new ExpressionNotSupportedException(expression, because: $"document serializer class {itemSerializationInfo.Serializer.GetType()} does not implement {nameof(IBsonDocumentSerializer)}"); + } + + if (keySerializer is not IRepresentationConfigurable { Representation: BsonType.String }) + { + throw new ExpressionNotSupportedException(expression, because: "key did not serialize as a string"); + } + + var ast = AstExpression.Unary(AstUnaryOperator.ArrayToObject, collectionTranslationAst); + var resultSerializer = CreateDictionarySerializer(keySerializer, valueSerializer); + return new TranslatedExpression(expression, ast, resultSerializer); + } + + throw new ExpressionNotSupportedException(expression); + } + + public static bool CanTranslate(NewExpression expression) + => expression.Type.IsConstructedGenericType && + expression.Type.GetGenericTypeDefinition() == typeof(Dictionary<,>) && + DictionaryConstructor.IsIEnumerableKeyValuePairConstructor(expression.Constructor); + + private static IBsonSerializer CreateDictionarySerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer) + { + var dictionaryType = typeof(Dictionary<,>).MakeGenericType(keySerializer.ValueType, valueSerializer.ValueType); + var serializerType = typeof(DictionaryInterfaceImplementerSerializer<,,>).MakeGenericType(dictionaryType, keySerializer.ValueType, valueSerializer.ValueType); + + return (IBsonSerializer)Activator.CreateInstance(serializerType, DictionaryRepresentation.Document, keySerializer, valueSerializer); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs index b54f431e516..af521d05658 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs @@ -50,6 +50,10 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr { return NewKeyValuePairExpressionToAggregationExpressionTranslator.Translate(context, expression); } + if (NewDictionaryExpressionToAggregationExpressionTranslator.CanTranslate(expression)) + { + return NewDictionaryExpressionToAggregationExpressionTranslator.Translate(context, expression); + } return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty()); } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs new file mode 100644 index 00000000000..59931e8f199 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs @@ -0,0 +1,155 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#if NET6_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER + +using System; +using System.Collections.Generic; +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Linq; +using MongoDB.Driver.TestHelpers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + public class NewDictionaryExpressionToAggregationExpressionTranslatorTests : LinqIntegrationTest + { + public NewDictionaryExpressionToAggregationExpressionTranslatorTests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair("A", d.A), new KeyValuePair("B", d.B) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["A"] = "a", ["B"] = "b" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_Create_should_translate() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { KeyValuePair.Create("A", d.A), KeyValuePair.Create("B", d.B) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["A"] = "a", ["B"] = "b" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate_Guid_as_string_key() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair(d.GuidAsString, d.A) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { $arrayToObject : [[{ k : '$GuidAsString', v : '$A' }]] }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ [Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE")] = "a" }); + } + + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate_dynamic_array() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + d.Items.Select(i => new KeyValuePair(i.P, i.W)))); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { $arrayToObject : { $map: { input: '$Items', as: 'i', in: { k: '$$i.P', v: '$$i.W' } } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["x"] = "y" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_throws_on_non_string_key() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair(42, d.A) })); + + var exception = Record.Exception(() => queryable.ToList()); + + exception.Should().NotBeNull(); + exception.Should().BeOfType(); + } + + public class C + { + public string A { get; set; } + + public string B { get; set; } + + [BsonRepresentation(BsonType.String)] + public Guid GuidAsString { get; set; } + + public Item[] Items { get; set; } + } + + public class Item + { + public string P { get; set; } + + public string W { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C + { + A = "a", + B = "b", + GuidAsString = Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE"), + Items = [ new Item { P = "x", W = "y" } ] + }, + ]; + } + } +} +#endif