Skip to content

Commit

Permalink
fix: insert a cast to idictionary when an explicit setter is present (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
TimothyMakkison authored Apr 18, 2023
1 parent 587abbc commit 20f6f06
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ namespace Riok.Mapperly.Descriptors.MappingBuilders;
public static class DictionaryMappingBuilder
{
private const string CountPropertyName = nameof(IDictionary<object, object>.Count);
private const string SetterIndexerPropertyName = "set_Item";

private const string ToImmutableDictionaryMethodName = nameof(ImmutableDictionary.ToImmutableDictionary);
private const string ToImmutableSortedDictionaryMethodName = nameof(ImmutableSortedDictionary.ToImmutableSortedDictionary);
Expand Down Expand Up @@ -64,7 +65,8 @@ public static class DictionaryMappingBuilder
keyMapping,
valueMapping,
false,
objectFactory: objectFactory);
objectFactory: objectFactory,
explicitCast: GetExplicitIndexer(ctx));
}

public static IExistingTargetMapping? TryBuildExistingTargetMapping(MappingBuilderContext ctx)
Expand All @@ -87,7 +89,8 @@ public static class DictionaryMappingBuilder
ctx.Source,
ctx.Target,
keyMapping,
valueMapping);
valueMapping,
explicitCast: GetExplicitIndexer(ctx));
}

private static (ITypeMapping, ITypeMapping)? BuildKeyValueMapping(MappingBuilderContext ctx)
Expand Down Expand Up @@ -148,6 +151,14 @@ private static (ITypeSymbol, ITypeSymbol)? GetEnumerableKeyValueTypes(MappingBui
return (enumeratedType.TypeArguments[0], enumeratedType.TypeArguments[1]);
}

private static INamedTypeSymbol? GetExplicitIndexer(MappingBuilderContext ctx)
{
if (ctx.Target.ImplementsGeneric(ctx.Types.IDictionaryT, SetterIndexerPropertyName, out var typedInter, out var isExplicit) && !isExplicit)
return null;

return typedInter;
}

private static LinqDicitonaryMapping? ResolveImmutableCollectMethod(MappingBuilderContext ctx, ITypeMapping keyMapping, ITypeMapping valueMapping)
{
if (SymbolEqualityComparer.Default.Equals(ctx.Target.OriginalDefinition, ctx.Types.ImmutableSortedDictionaryT))
Expand All @@ -159,4 +170,5 @@ private static (ITypeSymbol, ITypeSymbol)? GetEnumerableKeyValueTypes(MappingBui

return null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,40 @@ namespace Riok.Mapperly.Descriptors.Mappings.ExistingTarget;
public class ForEachSetDictionaryExistingTargetMapping : ExistingTargetMapping
{
private const string LoopItemVariableName = "item";
private const string ExplicitCastVariableName = "targetDict";
private const string KeyPropertyName = nameof(KeyValuePair<object, object>.Key);
private const string ValuePropertyName = nameof(KeyValuePair<object, object>.Value);

private readonly ITypeMapping _keyMapping;
private readonly ITypeMapping _valueMapping;
private readonly INamedTypeSymbol? _explicitCast;

public ForEachSetDictionaryExistingTargetMapping(
ITypeSymbol sourceType,
ITypeSymbol targetType,
ITypeMapping keyMapping,
ITypeMapping valueMapping)
ITypeMapping valueMapping,
INamedTypeSymbol? explicitCast)
: base(sourceType, targetType)
{
_keyMapping = keyMapping;
_valueMapping = valueMapping;
_explicitCast = explicitCast;
}

public override IEnumerable<StatementSyntax> Build(TypeMappingBuildContext ctx, ExpressionSyntax target)
{
if (_explicitCast != null)
{
var type = FullyQualifiedIdentifier(_explicitCast);
var cast = CastExpression(type, target);

var castedVariable = ctx.NameBuilder.New(ExplicitCastVariableName);
target = IdentifierName(castedVariable);

yield return LocalDeclarationStatement(DeclareVariable(castedVariable, cast));
}

var loopItemVariableName = ctx.NameBuilder.New(LoopItemVariableName);

var convertedKeyExpression = _keyMapping.Build(ctx.WithSource(MemberAccess(loopItemVariableName, KeyPropertyName)));
Expand All @@ -40,13 +55,10 @@ public override IEnumerable<StatementSyntax> Build(TypeMappingBuildContext ctx,
ElementAccess(target, convertedKeyExpression),
convertedValueExpression);

return new StatementSyntax[]
{
ForEachStatement(
yield return ForEachStatement(
VarIdentifier,
Identifier(loopItemVariableName),
ctx.Source,
Block(ExpressionStatement(assignment))),
};
Block(ExpressionStatement(assignment)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ public ForEachSetDictionaryMapping(
ITypeMapping valueMapping,
bool sourceHasCount,
ITypeSymbol? typeToInstantiate = null,
ObjectFactory? objectFactory = null)
: base(new ForEachSetDictionaryExistingTargetMapping(sourceType, targetType, keyMapping, valueMapping))
ObjectFactory? objectFactory = null,
INamedTypeSymbol? explicitCast = null)
: base(new ForEachSetDictionaryExistingTargetMapping(sourceType, targetType, keyMapping, valueMapping, explicitCast))
{
_sourceHasCount = sourceHasCount;
_objectFactory = objectFactory;
Expand Down
46 changes: 35 additions & 11 deletions src/Riok.Mapperly/Helpers/SymbolExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,52 @@ internal static bool ImplementsGeneric(
return genericIntf != null;
}

internal static bool HasImplicitInterfaceMethod(this ITypeSymbol symbol, INamedTypeSymbol inter, string methodName)
internal static bool ImplementsGeneric(
this ITypeSymbol t,
INamedTypeSymbol genericInterfaceSymbol,
string symbolName,
[NotNullWhen(true)] out INamedTypeSymbol? genericIntf,
out bool isExplicit)
{
// return true if symbol is the same interface - does not check that the method is implemented so class A : IList<T> { } will be accepted
if (SymbolEqualityComparer.Default.Equals(symbol.OriginalDefinition, inter))
if (SymbolEqualityComparer.Default.Equals(t.OriginalDefinition, genericInterfaceSymbol))
{
genericIntf = (INamedTypeSymbol)t;
isExplicit = false;
return true;
}

genericIntf = t.AllInterfaces.FirstOrDefault(x => x.IsGenericType && SymbolEqualityComparer.Default.Equals(x.OriginalDefinition, genericInterfaceSymbol));

// return false if it does not implement the interface
if (!symbol.ImplementsGeneric(inter, out var typedInter))
if (genericIntf == null)
{
isExplicit = false;
return false;
}

var interfaceMethodSymbol = typedInter.GetMembers(methodName).OfType<IMethodSymbol>().Single();
var interfaceSymbol = genericIntf.GetMembers(symbolName).First();

var methodInterImplementaton = symbol.FindImplementationForInterfaceMember(interfaceMethodSymbol) as IMethodSymbol;
var symbolImplementaton = t.FindImplementationForInterfaceMember(interfaceSymbol);

// if null then the method is unimplemented
// symbol implements genericInterface but has not implemented the corresponding methods
// this can only occur in unit tests
if (methodInterImplementaton is null)
return false;
if (symbolImplementaton == null)
throw new NotSupportedException("Symbol implementation cannot be null for objects implementing interface.");

// check if symbol is explicit
isExplicit = symbolImplementaton switch
{
IMethodSymbol methodSymbol => methodSymbol.ExplicitInterfaceImplementations.Any(),
IPropertySymbol propertySymbol => propertySymbol.ExplicitInterfaceImplementations.Any(),
_ => throw new NotImplementedException(),
};

// check if methodImplementation is explicit
return !methodInterImplementaton.ExplicitInterfaceImplementations.Any();
return true;
}

internal static bool HasImplicitInterfaceMethod(this ITypeSymbol symbol, INamedTypeSymbol inter, string methodName)
{
return symbol.ImplementsGeneric(inter, methodName, out _, out var isExplicit) && !isExplicit;
}

internal static bool CanConsumeType(this ITypeParameterSymbol typeParameter, Compilation compilation, ITypeSymbol type)
Expand Down
160 changes: 160 additions & 0 deletions test/Riok.Mapperly.Tests/Mapping/DictionaryTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,166 @@ public void DictionaryToCustomDictionaryWithObjectFactory()
""");
}

[Fact]
public void IDictionaryToExplicitDictionaryShouldCast()
{
var source = TestSourceBuilder.Mapping(
"IDictionary<string, string>",
"A",
"""
public class A : IDictionary<string, string>
{
string IDictionary<string, string>.this[string key]
{
get => _dictionaryImplementation[key];
set => _dictionaryImplementation[key] = value;
}
}
""");

TestHelper.GenerateMapper(source)
.Should()
.HaveSingleMethodBody(
"""
var target = new global::A();
var targetDict = (global::System.Collections.Generic.IDictionary<string, string>)target;
foreach (var item in source)
{
targetDict[item.Key] = item.Value;
}
return target;
""");
}

[Fact]
public void DictionaryToImplicitDictionaryShouldNotCast()
{
var source = TestSourceBuilder.Mapping(
"Dictionary<string, string>",
"A",
"""
public class A : IDictionary<string, string>
{
public string this[string key]
{
get => _dictionaryImplementation[key];
set => _dictionaryImplementation[key] = value;
}
}
""");

TestHelper.GenerateMapper(source)
.Should()
.HaveSingleMethodBody(
"""
var target = new global::A();
foreach (var item in source)
{
target[item.Key] = item.Value;
}
return target;
""");
}

[Fact]
public void DictionaryToExistingExplicitDictionaryShouldCast()
{
var source = TestSourceBuilder.Mapping(
"A",
"B",
"class A { public Dictionary<string, string> Values { get; } }",
"class B { public C Values { get; } }",
"""
public class C : IDictionary<string, string>
{
string IDictionary<string, string>.this[string key]
{
get => _dictionaryImplementation[key];
set => _dictionaryImplementation[key] = value;
}
}
""");

TestHelper.GenerateMapper(source)
.Should()
.HaveSingleMethodBody(
"""
var target = new global::B();
var targetDict = (global::System.Collections.Generic.IDictionary<string, string>)target.Values;
foreach (var item in source.Values)
{
targetDict[item.Key] = item.Value;
}
return target;
""");
}

[Fact]
public void DictionaryToExplicitDictionaryWithObjectFactoryShouldCast()
{
var source = TestSourceBuilder.MapperWithBodyAndTypes(
"[ObjectFactory] A CreateA() => new();"
+ "partial A Map(Dictionary<string, string> source);",
"""
public class A : IDictionary<string, string>
{
string IDictionary<string, string>.this[string key]
{
get => _dictionaryImplementation[key];
set => _dictionaryImplementation[key] = value;
}
}
""");

TestHelper.GenerateMapper(source)
.Should()
.HaveSingleMethodBody(
"""
var target = CreateA();
var targetDict = (global::System.Collections.Generic.IDictionary<string, string>)target;
foreach (var item in source)
{
targetDict[item.Key] = item.Value;
}
return target;
""");
}

[Fact]
public void DictionaryToImplicitDictionaryWithObjectFactoryShouldNotCast()
{
var source = TestSourceBuilder.MapperWithBodyAndTypes(
"[ObjectFactory] A CreateA() => new();"
+ "partial A Map(Dictionary<string, string> source);",
"""
public class A : IDictionary<string, string>
{
public string this[string key]
{
get => _dictionaryImplementation[key];
set => _dictionaryImplementation[key] = value;
}
}
""");

TestHelper.GenerateMapper(source)
.Should()
.HaveSingleMethodBody(
"""
var target = CreateA();
foreach (var item in source)
{
target[item.Key] = item.Value;
}
return target;
""");
}

[Fact]
public Task DictionaryToCustomDictionaryWithPrivateCtorShouldDiagnostic()
{
Expand Down

0 comments on commit 20f6f06

Please sign in to comment.