Skip to content

Commit

Permalink
feat: add support fot ImmutableDictionary and `ImmutableSortedDicti…
Browse files Browse the repository at this point in the history
…onary` (#351)
  • Loading branch information
TimothyMakkison authored Apr 18, 2023
1 parent 6ec7e80 commit 587abbc
Show file tree
Hide file tree
Showing 26 changed files with 338 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Descriptors.Mappings;
Expand All @@ -11,6 +12,9 @@ public static class DictionaryMappingBuilder
{
private const string CountPropertyName = nameof(IDictionary<object, object>.Count);

private const string ToImmutableDictionaryMethodName = nameof(ImmutableDictionary.ToImmutableDictionary);
private const string ToImmutableSortedDictionaryMethodName = nameof(ImmutableSortedDictionary.ToImmutableSortedDictionary);

public static ITypeMapping? TryBuildMapping(MappingBuilderContext ctx)
{
if (!ctx.IsConversionEnabled(MappingConversionType.Dictionary))
Expand Down Expand Up @@ -38,6 +42,11 @@ public static class DictionaryMappingBuilder
dictionaryObjectFactory);
}

// if target is an immutable dictionary then use LinqDictionaryMapper
var immutableLinqMapping = ResolveImmutableCollectMethod(ctx, keyMapping, valueMapping);
if (immutableLinqMapping != null)
return immutableLinqMapping;

// the target is not a well known dictionary type
// it should have a an object factory or a parameterless public ctor
if (!ctx.ObjectFactories.TryFindObjectFactory(ctx.Source, ctx.Target, out var objectFactory) && !ctx.Target.HasAccessibleParameterlessConstructor())
Expand Down Expand Up @@ -66,6 +75,14 @@ public static class DictionaryMappingBuilder
if (BuildKeyValueMapping(ctx) is not var (keyMapping, valueMapping))
return null;

// if target is an immutable dictionary then don't create a foreach loop
if (ctx.Target.OriginalDefinition.ImplementsGeneric(ctx.Types.IImmutableDictionaryT, out _))
{
ctx.ReportDiagnostic(DiagnosticDescriptors.CannotMapToReadOnlyMember);
return null;
}

// add values to dictionary by setting key values in a foreach loop
return new ForEachSetDictionaryExistingTargetMapping(
ctx.Source,
ctx.Target,
Expand Down Expand Up @@ -130,4 +147,16 @@ private static (ITypeSymbol, ITypeSymbol)? GetEnumerableKeyValueTypes(MappingBui

return (enumeratedType.TypeArguments[0], enumeratedType.TypeArguments[1]);
}

private static LinqDicitonaryMapping? ResolveImmutableCollectMethod(MappingBuilderContext ctx, ITypeMapping keyMapping, ITypeMapping valueMapping)
{
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableSortedDictionaryT))
return new LinqDicitonaryMapping(ctx.Source, ctx.Target, ctx.Types.ImmutableSortedDictionary.GetStaticGenericMethod(ToImmutableSortedDictionaryMethodName)!, keyMapping, valueMapping);

// if taget is an ImmutableDictionary or implements interface IImmutableDictionary
if (ctx.Target.OriginalDefinition.ImplementsGeneric(ctx.Types.IImmutableDictionaryT, out _))
return new LinqDicitonaryMapping(ctx.Source, ctx.Target, ctx.Types.ImmutableDictionary.GetStaticGenericMethod(ToImmutableDictionaryMethodName)!, keyMapping, valueMapping);

return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ private static LinqEnumerableMapping BuildLinqMapping(
{
var collectMethod = collectMethodName == null
? null
: ResolveStaticMethod(ctx.Types.Enumerable, collectMethodName);
: ctx.Types.Enumerable.GetStaticGenericMethod(collectMethodName);

var selectMethod = elementMapping.IsSynthetic
? null
: ResolveStaticMethod(ctx.Types.Enumerable, SelectMethodName);
: ctx.Types.Enumerable.GetStaticGenericMethod(SelectMethodName);

return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}
Expand All @@ -145,7 +145,7 @@ private static LinqConstructorMapping BuildLinqConstructorMapping(
{
var selectMethod = elementMapping.IsSynthetic
? null
: ResolveStaticMethod(ctx.Types.Enumerable, SelectMethodName);
: ctx.Types.Enumerable.GetStaticGenericMethod(SelectMethodName);

return new LinqConstructorMapping(ctx.Source, ctx.Target, elementMapping, selectMethod);
}
Expand Down Expand Up @@ -210,41 +210,34 @@ private static (bool CanMapWithLinq, string? CollectMethod) ResolveCollectMethod

var selectMethod = elementMapping.IsSynthetic
? null
: ResolveStaticMethod(ctx.Types.Enumerable, SelectMethodName);
: ctx.Types.Enumerable.GetStaticGenericMethod(SelectMethodName);

return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}

private static IMethodSymbol? ResolveImmutableCollectMethod(MappingBuilderContext ctx)
{
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableArrayT))
return ResolveStaticMethod(ctx.Types.ImmutableArray, ToImmutableArrayMethodName);
return ctx.Types.ImmutableArray.GetStaticGenericMethod(ToImmutableArrayMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableListT))
return ResolveStaticMethod(ctx.Types.ImmutableList, ToImmutableListMethodName);
return ctx.Types.ImmutableList.GetStaticGenericMethod(ToImmutableListMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableHashSetT))
return ResolveStaticMethod(ctx.Types.ImmutableHashSet, ToImmutableHashSetMethodName);
return ctx.Types.ImmutableHashSet.GetStaticGenericMethod(ToImmutableHashSetMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableQueueT))
return ResolveStaticMethod(ctx.Types.ImmutableQueue, CreateRangeQueueMethodName);
return ctx.Types.ImmutableQueue.GetStaticGenericMethod(CreateRangeQueueMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableStackT))
return ResolveStaticMethod(ctx.Types.ImmutableStack, CreateRangeStackMethodName);
return ctx.Types.ImmutableStack.GetStaticGenericMethod(CreateRangeStackMethodName);

if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableSortedSetT))
return ResolveStaticMethod(ctx.Types.ImmutableSortedSet, ToImmutableSortedSetMethodName);
return ctx.Types.ImmutableSortedSet.GetStaticGenericMethod(ToImmutableSortedSetMethodName);

return null;
}

private static IMethodSymbol? ResolveStaticMethod(INamedTypeSymbol namedType, string methodName)
{
return namedType.GetMembers(methodName)
.OfType<IMethodSymbol>()
.FirstOrDefault(m => m.IsStatic && m.IsGenericMethod);
}

private static ITypeSymbol? GetEnumeratedType(MappingBuilderContext ctx, ITypeSymbol type)
{
return type.ImplementsGeneric(ctx.Types.IEnumerableT, out var enumerableIntf)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ namespace Riok.Mapperly.Descriptors.Mappings.ExistingTarget;
public class ForEachSetDictionaryExistingTargetMapping : ExistingTargetMapping
{
private const string LoopItemVariableName = "item";
private const string KeyValueKeyPropertyName = nameof(KeyValuePair<object, object>.Key);
private const string KeyValueValuePropertyName = nameof(KeyValuePair<object, object>.Value);
private const string KeyPropertyName = nameof(KeyValuePair<object, object>.Key);
private const string ValuePropertyName = nameof(KeyValuePair<object, object>.Value);

private readonly ITypeMapping _keyMapping;
private readonly ITypeMapping _valueMapping;
Expand All @@ -33,8 +33,8 @@ public override IEnumerable<StatementSyntax> Build(TypeMappingBuildContext ctx,
{
var loopItemVariableName = ctx.NameBuilder.New(LoopItemVariableName);

var convertedKeyExpression = _keyMapping.Build(ctx.WithSource(MemberAccess(loopItemVariableName, KeyValueKeyPropertyName)));
var convertedValueExpression = _valueMapping.Build(ctx.WithSource(MemberAccess(loopItemVariableName, KeyValueValuePropertyName)));
var convertedKeyExpression = _keyMapping.Build(ctx.WithSource(MemberAccess(loopItemVariableName, KeyPropertyName)));
var convertedValueExpression = _valueMapping.Build(ctx.WithSource(MemberAccess(loopItemVariableName, ValuePropertyName)));

var assignment = Assignment(
ElementAccess(target, convertedKeyExpression),
Expand Down
56 changes: 56 additions & 0 deletions src/Riok.Mapperly/Descriptors/Mappings/LinqDictionaryMapping.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.SyntaxFactoryHelper;

namespace Riok.Mapperly.Descriptors.Mappings;

/// <summary>
/// Represents an enumerable mapping which works by using linq (select + collect).
/// </summary>
public class LinqDicitonaryMapping : TypeMapping
{
private const string LambdaParamName = "x";

private const string KeyPropertyName = nameof(KeyValuePair<object, object>.Key);
private const string ValuePropertyName = nameof(KeyValuePair<object, object>.Value);

private readonly IMethodSymbol _collectMethod;
private readonly ITypeMapping _keyMapping;
private readonly ITypeMapping _valueMapping;

public LinqDicitonaryMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
IMethodSymbol collectMethod,
ITypeMapping keyMapping,
ITypeMapping valueMapping)
: base(sourceType, targetType)
{
_collectMethod = collectMethod;
_keyMapping = keyMapping;
_valueMapping = valueMapping;
}

public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
{
var lambdaParamName = ctx.NameBuilder.New(LambdaParamName);

// if key and value types do not change then use a simple call
// ie: source.ToImmutableDictionary();
if (_keyMapping.IsSynthetic && _valueMapping.IsSynthetic)
return StaticInvocation(_collectMethod, ctx.Source);

// create expressions mapping the key and value and then create the final expression
// ie: source.ToImmutableDictionary(x => x.Key, x=> (int)x.Value);
var keyMapExpression = _keyMapping.Build(ctx.WithSource(MemberAccess(lambdaParamName, KeyPropertyName)));
var keyExpression = SimpleLambdaExpression(Parameter(Identifier(lambdaParamName)))
.WithExpressionBody(keyMapExpression);

var valueMapExpression = _valueMapping.Build(ctx.WithSource(MemberAccess(lambdaParamName, ValuePropertyName)));
var valueExpression = SimpleLambdaExpression(Parameter(Identifier(lambdaParamName)))
.WithExpressionBody(valueMapExpression);

return StaticInvocation(_collectMethod, ctx.Source, keyExpression, valueExpression);
}
}
3 changes: 3 additions & 0 deletions src/Riok.Mapperly/Descriptors/WellKnownTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class WellKnownTypes
private INamedTypeSymbol? _immutableSortedSetT;
private INamedTypeSymbol? _immutableDictionary;
private INamedTypeSymbol? _immutableDictionaryT;
private INamedTypeSymbol? _iImmutableDictionaryT;
private INamedTypeSymbol? _immutableSortedDictionary;
private INamedTypeSymbol? _immutableSortedDictionaryT;

Expand Down Expand Up @@ -98,7 +99,9 @@ internal WellKnownTypes(Compilation compilation)
public INamedTypeSymbol ImmutableSortedSet => _immutableSortedSet ??= GetTypeSymbol(typeof(ImmutableSortedSet));
public INamedTypeSymbol ImmutableSortedSetT => _immutableSortedSetT ??= GetTypeSymbol(typeof(ImmutableSortedSet<>));
public INamedTypeSymbol ImmutableDictionary => _immutableDictionary ??= GetTypeSymbol(typeof(ImmutableDictionary));
public INamedTypeSymbol IImmutableDictionaryT => _iImmutableDictionaryT ??= GetTypeSymbol(typeof(IImmutableDictionary<,>));
public INamedTypeSymbol ImmutableDictionaryT => _immutableDictionaryT ??= GetTypeSymbol(typeof(ImmutableDictionary<,>));

public INamedTypeSymbol ImmutableSortedDictionary => _immutableSortedDictionary ??= GetTypeSymbol(typeof(ImmutableSortedDictionary));
public INamedTypeSymbol ImmutableSortedDictionaryT => _immutableSortedDictionaryT ??= GetTypeSymbol(typeof(ImmutableSortedDictionary<,>));

Expand Down
7 changes: 7 additions & 0 deletions src/Riok.Mapperly/Helpers/SymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ internal static IEnumerable<IMappableMember> GetAccessibleMappableMembers(this I
.WhereNotNull();
}

internal static IMethodSymbol? GetStaticGenericMethod(this INamedTypeSymbol namedType, string methodName)
{
return namedType.GetMembers(methodName)
.OfType<IMethodSymbol>()
.FirstOrDefault(m => m.IsStatic && m.IsGenericMethod);
}

internal static bool ImplementsGeneric(
this ITypeSymbol t,
INamedTypeSymbol genericInterfaceSymbol,
Expand Down
4 changes: 3 additions & 1 deletion test/Riok.Mapperly.IntegrationTests/BaseMapperTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ protected TestObject NewTestObj()
ImmutableHashSetValue = ImmutableHashSet.Create("1", "2", "3"),
ImmutableQueueValue = ImmutableQueue.Create("1", "2", "3"),
ImmutableStackValue = ImmutableStack.Create("1", "2", "3"),
ImmutableSortedSetValue = ImmutableSortedSet.Create("1", "2", "3")
ImmutableSortedSetValue = ImmutableSortedSet.Create("1", "2", "3"),
ImmutableDictionaryValue = new Dictionary<string, string>() { { "1", "1" }, { "2", "2" }, { "3", "3" } }.ToImmutableDictionary(),
ImmutableSortedDictionaryValue = new Dictionary<string, string>() { { "1", "1" }, { "2", "2" }, { "3", "3" } }.ToImmutableSortedDictionary(),
};
}

Expand Down
4 changes: 4 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Dto/TestObjectDto.cs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ public TestObjectDto(int ctorValue, int unknownValue = 10, int ctorValue2 = 100)

public ImmutableSortedSet<int> ImmutableSortedSetValue { get; set; } = ImmutableSortedSet<int>.Empty;

public ImmutableDictionary<int, int> ImmutableDictionaryValue { get; set; } = ImmutableDictionary<int, int>.Empty;

public ImmutableSortedDictionary<int, int> ImmutableSortedDictionaryValue { get; set; } = ImmutableSortedDictionary<int, int>.Empty;

public TestEnumDtoByValue EnumValue { get; set; }

public TestEnumDtoByName EnumName { get; set; }
Expand Down
4 changes: 4 additions & 0 deletions test/Riok.Mapperly.IntegrationTests/Models/TestObject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ public TestObject(int ctorValue, int unknownValue = 10, int ctorValue2 = 100)

public ImmutableSortedSet<string> ImmutableSortedSetValue { get; set; } = ImmutableSortedSet<string>.Empty;

public ImmutableDictionary<string, string> ImmutableDictionaryValue { get; set; } = ImmutableDictionary<string, string>.Empty;

public ImmutableSortedDictionary<string, string> ImmutableSortedDictionaryValue { get; set; } = ImmutableSortedDictionary<string, string>.Empty;

public TestEnum EnumValue { get; set; }

public TestEnum EnumName { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@
2,
3
],
ImmutableDictionaryValue: {
1: 1,
2: 2,
3: 3
},
ImmutableSortedDictionaryValue: {
1: 1,
2: 2,
3: 3
},
EnumValue: DtoValue1,
EnumName: Value10,
EnumRawValue: 20,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ public partial int ParseableInt(string value)
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(testObject.ImmutableQueueValue, x6 => ParseableInt(x6)));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(testObject.ImmutableStackValue, x7 => ParseableInt(x7)));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(testObject.ImmutableSortedSetValue, x8 => ParseableInt(x8)));
target.ImmutableDictionaryValue = global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary(testObject.ImmutableDictionaryValue, x9 => ParseableInt(x9.Key), x9 => ParseableInt(x9.Value));
target.ImmutableSortedDictionaryValue = global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary(testObject.ImmutableSortedDictionaryValue, x10 => ParseableInt(x10.Key), x10 => ParseableInt(x10.Value));
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Dto.TestEnumDtoByValue)testObject.EnumValue;
target.EnumName = MapToEnumDtoByName(testObject.EnumName);
target.EnumRawValue = (byte)testObject.EnumRawValue;
Expand Down Expand Up @@ -162,6 +164,8 @@ public partial int ParseableInt(string value)
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(dto.ImmutableQueueValue, x5 => x5.ToString()));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(dto.ImmutableStackValue, x6 => x6.ToString()));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(dto.ImmutableSortedSetValue, x7 => x7.ToString()));
target.ImmutableDictionaryValue = global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary(dto.ImmutableDictionaryValue, x8 => x8.Key.ToString(), x8 => x8.Value.ToString());
target.ImmutableSortedDictionaryValue = global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary(dto.ImmutableSortedDictionaryValue, x9 => x9.Key.ToString(), x9 => x9.Value.ToString());
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Models.TestEnum)dto.EnumValue;
target.EnumName = (global::Riok.Mapperly.IntegrationTests.Models.TestEnum)dto.EnumName;
target.EnumRawValue = (global::Riok.Mapperly.IntegrationTests.Models.TestEnum)dto.EnumRawValue;
Expand Down Expand Up @@ -233,6 +237,8 @@ public partial void UpdateDto(global::Riok.Mapperly.IntegrationTests.Models.Test
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(source.ImmutableQueueValue, x6 => ParseableInt(x6)));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(source.ImmutableStackValue, x7 => ParseableInt(x7)));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(source.ImmutableSortedSetValue, x8 => ParseableInt(x8)));
target.ImmutableDictionaryValue = global::System.Collections.Immutable.ImmutableDictionary.ToImmutableDictionary(source.ImmutableDictionaryValue, x9 => ParseableInt(x9.Key), x9 => ParseableInt(x9.Value));
target.ImmutableSortedDictionaryValue = global::System.Collections.Immutable.ImmutableSortedDictionary.ToImmutableSortedDictionary(source.ImmutableSortedDictionaryValue, x10 => ParseableInt(x10.Key), x10 => ParseableInt(x10.Value));
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Dto.TestEnumDtoByValue)source.EnumValue;
target.EnumName = MapToEnumDtoByName(source.EnumName);
target.EnumRawValue = (byte)source.EnumRawValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@
2,
3
],
ImmutableDictionaryValue: {
1: 1,
2: 2,
3: 3
},
ImmutableSortedDictionaryValue: {
1: 1,
2: 2,
3: 3
},
EnumValue: DtoValue1,
EnumName: Value10,
EnumRawValue: 20,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,16 @@
2,
3
],
ImmutableDictionaryValue: {
1: 1,
2: 2,
3: 3
},
ImmutableSortedDictionaryValue: {
1: 1,
2: 2,
3: 3
},
EnumValue: DtoValue1,
EnumName: Value10,
EnumRawValue: 20,
Expand Down
Loading

0 comments on commit 587abbc

Please sign in to comment.