Skip to content

Commit

Permalink
feat: use IEnumerable constructors where possible (#342)
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyMakkison authored Apr 17, 2023
1 parent b5f4559 commit 74d0b52
Show file tree
Hide file tree
Showing 9 changed files with 293 additions and 459 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ public static class EnumerableMappingBuilder
if (immutableLinqMapping is not null)
return immutableLinqMapping;

// if target is a type that takes IEnumerable in its constructor
if (HasEnumerableConstructor(ctx, elementMapping.TargetType))
return BuildLinqConstructorMapping(ctx, elementMapping);

return ctx.IsExpression
? null
: BuildCustomTypeMapping(ctx, elementMapping);
Expand Down Expand Up @@ -117,6 +121,28 @@ private static LinqEnumerableMapping BuildLinqMapping(
return new LinqEnumerableMapping(ctx.Source, ctx.Target, elementMapping, selectMethod, collectMethod);
}

private static bool HasEnumerableConstructor(MappingBuilderContext ctx, ITypeSymbol typeSymbol)
{
if (ctx.Target is not INamedTypeSymbol namedType)
return false;

var typedEnumerable = ctx.Types.IEnumerableT.Construct(typeSymbol);

return namedType.Constructors.Any(m => m.Parameters.Length == 1
&& SymbolEqualityComparer.Default.Equals(m.Parameters[0].Type, typedEnumerable));
}

private static LinqConstructorMapping BuildLinqConstructorMapping(
MappingBuilderContext ctx,
ITypeMapping elementMapping)
{
var selectMethod = elementMapping.IsSynthetic
? null
: ResolveStaticMethod(ctx.Types.Enumerable, SelectMethodName);

return new LinqConstructorMapping(ctx.Source, ctx.Target, elementMapping, selectMethod);
}

private static ExistingTargetMappingMethodWrapper? BuildCustomTypeMapping(
MappingBuilderContext ctx,
ITypeMapping elementMapping)
Expand All @@ -127,12 +153,6 @@ private static LinqEnumerableMapping BuildLinqMapping(
return null;
}

if (ctx.Target.ImplementsGeneric(ctx.Types.StackT, out _))
return new ForEachAddEnumerableMapping(ctx.Source, ctx.Target, elementMapping, objectFactory, nameof(Stack<object>.Push));

if (ctx.Target.ImplementsGeneric(ctx.Types.QueueT, out _))
return new ForEachAddEnumerableMapping(ctx.Source, ctx.Target, elementMapping, objectFactory, nameof(Queue<object>.Enqueue));

// create a foreach loop with add calls if source is not an array
// and ICollection.Add(T): void is implemented and not explicit
// ensures add is not called and immutable types
Expand Down
50 changes: 50 additions & 0 deletions src/Riok.Mapperly/Descriptors/Mappings/LinqConstructorMapping.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Riok.Mapperly.Emit.SyntaxFactoryHelper;

namespace Riok.Mapperly.Descriptors.Mappings;

/// <summary>
/// Represents an enumerable mapping where the target type accepts IEnumerable as a single argument.
/// </summary>
public class LinqConstructorMapping : TypeMapping
{
private const string LambdaParamName = "x";

private readonly ITypeMapping _elementMapping;
private readonly IMethodSymbol? _selectMethod;

public LinqConstructorMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
ITypeMapping elementMapping,
IMethodSymbol? selectMethod)
: base(sourceType, targetType)
{
_elementMapping = elementMapping;
_selectMethod = selectMethod;
}

public override ExpressionSyntax Build(TypeMappingBuildContext ctx)
{
var lambdaParamName = ctx.NameBuilder.New(LambdaParamName);

ExpressionSyntax mappedSource;

// Select / Map if needed
if (_selectMethod != null)
{
var sourceMapExpression = _elementMapping.Build(ctx.WithSource(lambdaParamName));
var convertLambda = SimpleLambdaExpression(Parameter(Identifier(lambdaParamName)))
.WithExpressionBody(sourceMapExpression);
mappedSource = StaticInvocation(_selectMethod, ctx.Source, convertLambda);
}
else
{
mappedSource = _elementMapping.Build(ctx);
}

return CreateInstance(TargetType, mappedSource);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ public partial int ParseableInt(string value)
target.FlatteningIdValue = DirectInt(testObject.Flattening.IdValue);
target.Unflattening.IdValue = DirectInt(testObject.UnflatteningIdValue);
target.SourceTargetSameObjectType = testObject.SourceTargetSameObjectType;
target.StackValue = MapToStack(testObject.StackValue);
target.QueueValue = MapToQueue(testObject.QueueValue);
target.ImmutableArrayValue = global::System.Collections.Immutable.ImmutableArray.ToImmutableArray(global::System.Linq.Enumerable.Select(testObject.ImmutableArrayValue, x1 => ParseableInt(x1)));
target.ImmutableListValue = global::System.Collections.Immutable.ImmutableList.ToImmutableList(global::System.Linq.Enumerable.Select(testObject.ImmutableListValue, x2 => ParseableInt(x2)));
target.ImmutableHashSetValue = global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet(global::System.Linq.Enumerable.Select(testObject.ImmutableHashSetValue, x3 => ParseableInt(x3)));
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(testObject.ImmutableQueueValue, x4 => ParseableInt(x4)));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(testObject.ImmutableStackValue, x5 => ParseableInt(x5)));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(testObject.ImmutableSortedSetValue, x6 => ParseableInt(x6)));
target.StackValue = new global::System.Collections.Generic.Stack<int>(global::System.Linq.Enumerable.Select(testObject.StackValue, x1 => ParseableInt(x1)));
target.QueueValue = new global::System.Collections.Generic.Queue<int>(global::System.Linq.Enumerable.Select(testObject.QueueValue, x2 => ParseableInt(x2)));
target.ImmutableArrayValue = global::System.Collections.Immutable.ImmutableArray.ToImmutableArray(global::System.Linq.Enumerable.Select(testObject.ImmutableArrayValue, x3 => ParseableInt(x3)));
target.ImmutableListValue = global::System.Collections.Immutable.ImmutableList.ToImmutableList(global::System.Linq.Enumerable.Select(testObject.ImmutableListValue, x4 => ParseableInt(x4)));
target.ImmutableHashSetValue = global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet(global::System.Linq.Enumerable.Select(testObject.ImmutableHashSetValue, x5 => ParseableInt(x5)));
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(testObject.ImmutableQueueValue, x6 => ParseableInt(x6)));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(testObject.ImmutableStackValue, x7 => ParseableInt(x7)));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(testObject.ImmutableSortedSetValue, x8 => ParseableInt(x8)));
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Dto.TestEnumDtoByValue)testObject.EnumValue;
target.EnumName = MapToEnumDtoByName(testObject.EnumName);
target.EnumRawValue = (byte)testObject.EnumRawValue;
Expand Down Expand Up @@ -154,14 +154,14 @@ public partial int ParseableInt(string value)
target.NestedNullableTargetNotNullable = MapToTestObjectNested(dto.NestedNullableTargetNotNullable);
target.StringNullableTargetNotNullable = dto.StringNullableTargetNotNullable;
target.SourceTargetSameObjectType = dto.SourceTargetSameObjectType;
target.StackValue = MapToStack1(dto.StackValue);
target.QueueValue = MapToQueue1(dto.QueueValue);
target.ImmutableArrayValue = global::System.Collections.Immutable.ImmutableArray.ToImmutableArray(global::System.Linq.Enumerable.Select(dto.ImmutableArrayValue, x => x.ToString()));
target.ImmutableListValue = global::System.Collections.Immutable.ImmutableList.ToImmutableList(global::System.Linq.Enumerable.Select(dto.ImmutableListValue, x1 => x1.ToString()));
target.ImmutableHashSetValue = global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet(global::System.Linq.Enumerable.Select(dto.ImmutableHashSetValue, x2 => x2.ToString()));
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(dto.ImmutableQueueValue, x3 => x3.ToString()));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(dto.ImmutableStackValue, x4 => x4.ToString()));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(dto.ImmutableSortedSetValue, x5 => x5.ToString()));
target.StackValue = new global::System.Collections.Generic.Stack<string>(global::System.Linq.Enumerable.Select(dto.StackValue, x => x.ToString()));
target.QueueValue = new global::System.Collections.Generic.Queue<string>(global::System.Linq.Enumerable.Select(dto.QueueValue, x1 => x1.ToString()));
target.ImmutableArrayValue = global::System.Collections.Immutable.ImmutableArray.ToImmutableArray(global::System.Linq.Enumerable.Select(dto.ImmutableArrayValue, x2 => x2.ToString()));
target.ImmutableListValue = global::System.Collections.Immutable.ImmutableList.ToImmutableList(global::System.Linq.Enumerable.Select(dto.ImmutableListValue, x3 => x3.ToString()));
target.ImmutableHashSetValue = global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet(global::System.Linq.Enumerable.Select(dto.ImmutableHashSetValue, x4 => x4.ToString()));
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(dto.ImmutableQueueValue, x5 => x5.ToString()));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(dto.ImmutableStackValue, x6 => x6.ToString()));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(dto.ImmutableSortedSetValue, x7 => x7.ToString()));
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Models.TestEnum)dto.EnumValue;
target.EnumName = (global::Riok.Mapperly.IntegrationTests.Models.TestEnum)dto.EnumName;
target.EnumRawValue = (global::Riok.Mapperly.IntegrationTests.Models.TestEnum)dto.EnumRawValue;
Expand Down Expand Up @@ -225,14 +225,14 @@ public partial void UpdateDto(global::Riok.Mapperly.IntegrationTests.Models.Test
target.StringValue = source.StringValue;
target.FlatteningIdValue = DirectInt(source.Flattening.IdValue);
target.SourceTargetSameObjectType = source.SourceTargetSameObjectType;
target.StackValue = MapToStack(source.StackValue);
target.QueueValue = MapToQueue(source.QueueValue);
target.ImmutableArrayValue = global::System.Collections.Immutable.ImmutableArray.ToImmutableArray(global::System.Linq.Enumerable.Select(source.ImmutableArrayValue, x1 => ParseableInt(x1)));
target.ImmutableListValue = global::System.Collections.Immutable.ImmutableList.ToImmutableList(global::System.Linq.Enumerable.Select(source.ImmutableListValue, x2 => ParseableInt(x2)));
target.ImmutableHashSetValue = global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet(global::System.Linq.Enumerable.Select(source.ImmutableHashSetValue, x3 => ParseableInt(x3)));
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(source.ImmutableQueueValue, x4 => ParseableInt(x4)));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(source.ImmutableStackValue, x5 => ParseableInt(x5)));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(source.ImmutableSortedSetValue, x6 => ParseableInt(x6)));
target.StackValue = new global::System.Collections.Generic.Stack<int>(global::System.Linq.Enumerable.Select(source.StackValue, x1 => ParseableInt(x1)));
target.QueueValue = new global::System.Collections.Generic.Queue<int>(global::System.Linq.Enumerable.Select(source.QueueValue, x2 => ParseableInt(x2)));
target.ImmutableArrayValue = global::System.Collections.Immutable.ImmutableArray.ToImmutableArray(global::System.Linq.Enumerable.Select(source.ImmutableArrayValue, x3 => ParseableInt(x3)));
target.ImmutableListValue = global::System.Collections.Immutable.ImmutableList.ToImmutableList(global::System.Linq.Enumerable.Select(source.ImmutableListValue, x4 => ParseableInt(x4)));
target.ImmutableHashSetValue = global::System.Collections.Immutable.ImmutableHashSet.ToImmutableHashSet(global::System.Linq.Enumerable.Select(source.ImmutableHashSetValue, x5 => ParseableInt(x5)));
target.ImmutableQueueValue = global::System.Collections.Immutable.ImmutableQueue.CreateRange(global::System.Linq.Enumerable.Select(source.ImmutableQueueValue, x6 => ParseableInt(x6)));
target.ImmutableStackValue = global::System.Collections.Immutable.ImmutableStack.CreateRange(global::System.Linq.Enumerable.Select(source.ImmutableStackValue, x7 => ParseableInt(x7)));
target.ImmutableSortedSetValue = global::System.Collections.Immutable.ImmutableSortedSet.ToImmutableSortedSet(global::System.Linq.Enumerable.Select(source.ImmutableSortedSetValue, x8 => ParseableInt(x8)));
target.EnumValue = (global::Riok.Mapperly.IntegrationTests.Dto.TestEnumDtoByValue)source.EnumValue;
target.EnumName = MapToEnumDtoByName(source.EnumName);
target.EnumRawValue = (byte)source.EnumRawValue;
Expand All @@ -254,28 +254,6 @@ private partial int PrivateDirectInt(int value)
return target;
}

private global::System.Collections.Generic.Stack<int> MapToStack(global::System.Collections.Generic.Stack<string> source)
{
var target = new global::System.Collections.Generic.Stack<int>();
foreach (var item in source)
{
target.Push(ParseableInt(item));
}

return target;
}

private global::System.Collections.Generic.Queue<int> MapToQueue(global::System.Collections.Generic.Queue<string> source)
{
var target = new global::System.Collections.Generic.Queue<int>();
foreach (var item in source)
{
target.Enqueue(ParseableInt(item));
}

return target;
}

private string MapToString(global::Riok.Mapperly.IntegrationTests.Models.TestEnum source)
{
return source switch
Expand Down Expand Up @@ -324,28 +302,6 @@ private string MapToString(global::Riok.Mapperly.IntegrationTests.Models.TestEnu
return target;
}

private global::System.Collections.Generic.Stack<string> MapToStack1(global::System.Collections.Generic.Stack<int> source)
{
var target = new global::System.Collections.Generic.Stack<string>();
foreach (var item in source)
{
target.Push(item.ToString());
}

return target;
}

private global::System.Collections.Generic.Queue<string> MapToQueue1(global::System.Collections.Generic.Queue<int> source)
{
var target = new global::System.Collections.Generic.Queue<string>();
foreach (var item in source)
{
target.Enqueue(item.ToString());
}

return target;
}

private global::Riok.Mapperly.IntegrationTests.Models.TestEnum MapToTestEnum(string source)
{
return source switch
Expand Down Expand Up @@ -376,4 +332,4 @@ private string MapToString1(global::Riok.Mapperly.IntegrationTests.Dto.TestEnumD
return target;
}
}
}
}
Loading

0 comments on commit 74d0b52

Please sign in to comment.