Skip to content

Commit

Permalink
#7 update generation to consider IEnumerable
Browse files Browse the repository at this point in the history
  • Loading branch information
cathei committed Feb 12, 2023
1 parent f7f71fa commit ec8c182
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 30 deletions.
84 changes: 68 additions & 16 deletions LinqGen.Generator/CodeGenUtils.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// LinqGen.Generator, Maxwell Keonwoo Kang <code.athei@gmail.com>, 2022

using System;
using System.Linq;

namespace Cathei.LinqGen.Generator;
Expand All @@ -16,6 +17,7 @@ public static class CodeGenUtils
private const string LinqGenStructFunctionTypeName = "IStructFunction";

private const string SystemNamespace = "System";
private const string SystemCollectionsNamespace = "System.Collections";
private const string SystemCollectionsGenericNamespace = "System.Collections.Generic";
private const string SpanTypeName = "Span`1";
private const string ReadOnlySpanTypeName = "ReadOnlySpan`1";
Expand Down Expand Up @@ -66,19 +68,29 @@ public static bool IsUnityNativeArrayOrSlice(ITypeSymbol symbol)
public static bool TryParseStubInterface(INamedTypeSymbol symbol,
out ITypeSymbol inputElementSymbol, out INamedTypeSymbol signatureSymbol)
{
inputElementSymbol = default!;
signatureSymbol = default!;

// generic signature type should not be allowed
// receiver type is: IStub<IEnumerable<T>, TSignature>
if (symbol.TypeArguments.Length != 2 ||
symbol.TypeArguments[1] is not INamedTypeSymbol resultSignatureSymbol ||
!TryGetGenericEnumerableInterface(symbol.TypeArguments[0], out var sourceTypeSymbol) ||
sourceTypeSymbol.TypeArguments.Length != 1)
!TryGetEnumerableInterface(symbol.TypeArguments[0], out var interfaceSymbol))
{
inputElementSymbol = default!;
signatureSymbol = default!;
return false;
}

inputElementSymbol = sourceTypeSymbol.TypeArguments[0];
// find GetEnumerator method
var getEnumeratorSymbol = GetEnumeratorSymbol(interfaceSymbol);
if (getEnumeratorSymbol is null)
return false;

// find Current property
var currentTypeSymbol = GetCurrentSymbol(getEnumeratorSymbol.ReturnType);
if (currentTypeSymbol is null)
return false;

inputElementSymbol = currentTypeSymbol;
signatureSymbol = resultSignatureSymbol;
return true;
}
Expand Down Expand Up @@ -627,10 +639,17 @@ public static bool TryGetGenericListInterface(ITypeSymbol symbol, out INamedType
return interfaceSymbol != null!;
}

public static bool TryGetGenericEnumerableInterface(ITypeSymbol symbol, out INamedTypeSymbol interfaceSymbol)
private static bool TryGetEnumerableInterface(ITypeSymbol symbol, out INamedTypeSymbol interfaceSymbol)
{
interfaceSymbol = GetInterface(symbol, SystemCollectionsGenericNamespace, "IEnumerable`1")!;
return interfaceSymbol != null!;
if (interfaceSymbol != null!)
return true;

interfaceSymbol = GetInterface(symbol, SystemCollectionsNamespace, "IEnumerable")!;
if (interfaceSymbol != null!)
return true;

return false;
}

public static bool TryGetGenericCollectionInterface(ITypeSymbol symbol, out INamedTypeSymbol? interfaceSymbol)
Expand Down Expand Up @@ -668,24 +687,57 @@ public static bool TryGetEquatableSelfInterface(ITypeSymbol symbol, out INamedTy
return interfaceSymbol != null!;
}

private static T? FindDuckTypingSymbol<T>(ITypeSymbol typeSymbol, Func<ITypeSymbol, T?> getter)
{
// find GetEnumerator with same rule as C# duck typing
// TODO fallback to interface implementation
var result = getter(typeSymbol);

if (result != null)
return result;

if (typeSymbol.BaseType != null)
return FindDuckTypingSymbol(typeSymbol.BaseType, getter);

if (typeSymbol.TypeKind == TypeKind.Interface)
{
foreach (var interfaceSymbol in typeSymbol.Interfaces)
{
result = FindDuckTypingSymbol(interfaceSymbol, getter);
if (result != null)
return result;
}
}

return default;
}

public static IMethodSymbol? GetEnumeratorSymbol(ITypeSymbol enumerableSymbol)
{
// find GetEnumerator with same rule as C# duck typing
return enumerableSymbol.GetMembers()
// TODO fallback to interface implementation
return FindDuckTypingSymbol(enumerableSymbol, static x => x.GetMembers()
.OfType<IMethodSymbol>()
.Concat(enumerableSymbol.AllInterfaces.SelectMany(a=>a.GetMembers()).OfType<IMethodSymbol>())
.FirstOrDefault(x =>
x.DeclaredAccessibility == Accessibility.Public &&
x.Name == "GetEnumerator" && x.Parameters.Length == 0 && x.TypeParameters.Length == 0);
.FirstOrDefault(static x => x is
{
DeclaredAccessibility: Accessibility.Public,
Name: "GetEnumerator",
Parameters.Length: 0,
TypeParameters.Length: 0
}));
}

public static ITypeSymbol GetCurrentSymbol(ITypeSymbol enumeratorSymbol)
public static ITypeSymbol? GetCurrentSymbol(ITypeSymbol enumeratorSymbol)
{
// find Current property with same rule as C# duck typing
return enumeratorSymbol.GetMembers()
// TODO fallback to interface implementation
return FindDuckTypingSymbol(enumeratorSymbol, static x => x.GetMembers()
.OfType<IPropertySymbol>()
.First(x => x.DeclaredAccessibility == Accessibility.Public && x.Name == "Current")
.Type;
.FirstOrDefault(static x => x is
{
DeclaredAccessibility: Accessibility.Public,
Name: "Current"
})?.Type);
}

public static INamedTypeSymbol NormalizeSignature(INamedTypeSymbol signature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ public EnumerableGeneration(in LinqGenExpression expression, int id,
{
IsCollection = TryGetGenericCollectionInterface(sourceSymbol, out _);

// TODO fallback to interface specific implementation
var enumeratorSymbol = GetEnumeratorSymbol(sourceSymbol)!.ReturnType;
var elementSymbol = GetCurrentSymbol(enumeratorSymbol);
var elementSymbol = GetCurrentSymbol(enumeratorSymbol)!;

SourceEnumerableType = ParseTypeName(sourceSymbol);
SourceEnumeratorType = ParseTypeName(enumeratorSymbol);
Expand Down
26 changes: 13 additions & 13 deletions LinqGen.Tests/FundamentalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,19 +201,19 @@ public Enumerator GetEnumerator()
}
}

// [Test]
// public void CollectionInterfaceTest()
// {
// ICollection collection = TestData.IntList;
// var generation = collection.Gen();
//
// object value = generation
// .Select(x => x)
// .First();
//
// Assert.AreEqual(collection.Count, generation.Count());
// Assert.AreEqual(TestData.IntList[0], value);
// }
[Test]
public void CollectionInterfaceTest()
{
ICollection collection = TestData.IntList;
var generation = collection.Gen();

int value = generation
.Cast<int>()
.First();

Assert.AreEqual(collection.Count, generation.Count());
Assert.AreEqual(TestData.IntList[0], value);
}

[Test]
public void CollectionGenericInterfaceTest()
Expand Down

0 comments on commit ec8c182

Please sign in to comment.