Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stack overflow in the configuration source gen #106511

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,9 @@ private bool IsAssignableTo(ITypeSymbol source, ITypeSymbol dest)
return conversion.IsReference && conversion.IsImplicit;
}

private bool IsUnsupportedType(ITypeSymbol type)
private HashSet<ITypeSymbol>? _visitedTypes = new(SymbolEqualityComparer.Default);

private bool IsUnsupportedType(ITypeSymbol type, HashSet<ITypeSymbol>? visitedTypes = null)
{
if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
{
Expand All @@ -569,25 +571,55 @@ private bool IsUnsupportedType(ITypeSymbol type)
return true;
}

if (type is IArrayTypeSymbol arrayTypeSymbol)
if (visitedTypes?.Contains(type) is true)
{
return arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType);
// avoid infinite recursion in nested types like
// public record RecursiveType
// {
// public TreeElement? Tree { get; set; }
// }
// public sealed class TreeElement : Dictionary<string, TreeElement>;
//
// return false for the second call. The type will continue be checked in the first call anyway.
return false;
}

if (IsCollection(type))
IArrayTypeSymbol? arrayTypeSymbol = type as IArrayTypeSymbol;
if (arrayTypeSymbol is null)
{
INamedTypeSymbol collectionType = (INamedTypeSymbol)type;

if (IsCandidateDictionary(collectionType, out ITypeSymbol? keyType, out ITypeSymbol? elementType))
{
return IsUnsupportedType(keyType) || IsUnsupportedType(elementType);
}
else if (TryGetElementType(collectionType, out elementType))
if (!IsCollection(type))
{
return IsUnsupportedType(elementType);
return false;
}
}

if (visitedTypes is null)
{
visitedTypes = _visitedTypes;
visitedTypes.Clear();
}

visitedTypes.Add(type);

if (arrayTypeSymbol is not null)
{
return arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType, visitedTypes);
}

Debug.Assert(IsCollection(type));

INamedTypeSymbol collectionType = (INamedTypeSymbol)type;

if (IsCandidateDictionary(collectionType, out ITypeSymbol? keyType, out ITypeSymbol? elementType))
{
return IsUnsupportedType(keyType, visitedTypes) || IsUnsupportedType(elementType, visitedTypes);
}

if (TryGetElementType(collectionType, out elementType))
{
return IsUnsupportedType(elementType, visitedTypes);
}

return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ bool TryRegisterCore()
}
case DictionarySpec dictionarySpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

bool shouldRegister = _typeIndex.CanBindTo(typeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.KeyTypeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.ElementTypeRef) &&
Expand All @@ -145,6 +148,9 @@ bool TryRegisterCore()
}
case CollectionSpec collectionSpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

if (_typeIndex.GetTypeSpec(collectionSpec.ElementTypeRef) is ComplexTypeSpec)
{
_namespaces.Add("System.Linq");
Expand All @@ -157,8 +163,7 @@ bool TryRegisterCore()
{
// Base case to avoid stack overflow for recursive object graphs.
// Register all object types for gen; we need to throw runtime exceptions in some cases.
bool shouldRegister = true;
_seenTransitiveTypes.Add(typeRef, shouldRegister);
_seenTransitiveTypes.Add(typeRef, true);

// List<string> is used in generated code as a temp holder for formatting
// an error for config properties that don't map to object properties.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ public string Color
}
}

public sealed class TreeElement : Dictionary<string, TreeElement>;

public record TypeWithRecursionThroughCollections
{
public TreeElement? Tree { get; set; }
public TreeElement?[]? Flat { get; set; }
ericstj marked this conversation as resolved.
Show resolved Hide resolved
public List<TreeElement>? List { get; set; }
}

public record RecordWithArrayParameter(string[] Array);

public readonly record struct ReadonlyRecordStructTypeOptions(string Color, int Length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,87 @@ public void CanBindOnParametersAndProperties_RecordWithArrayConstructorParameter
Assert.Equal(new string[] { "a", "b", "c" }, options.Array);
}

/// <summary>
/// Test binding to recursive types using Dictionary or Collections.
/// This ensure no stack overflow will occur during the compilation through the source gen or at runtime.
/// </summary>
[Fact]
public void BindToRecursiveTypesTest()
{
string jsonConfig = @"{
""Tree"": {
""Branch1"": {
""Leaf1"": {},
""Leaf2"": {}
},
""Branch2"": {
""Leaf3"": {}
}
},
""Flat"": [
{
""Element1"": {
""SubElement1"": {}
}
},
{
""Element2"": {
""SubElement2"": {}
}
},
{
""Element3"": {}
}
],
""List"": [
{
""Item1"": {
""NestedItem1"": {}
}
},
{
""Item2"": {}
},
]
}";

var configuration = new ConfigurationBuilder()
.AddJsonStream(new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(jsonConfig)))
.Build();

var instance = new TypeWithRecursionThroughCollections();
configuration.Bind(instance);

// Validate the dictionary
Assert.NotNull(instance.Tree);
Assert.Equal(2, instance.Tree.Count);
ericstj marked this conversation as resolved.
Show resolved Hide resolved
Assert.NotNull(instance.Tree["Branch1"]);
Assert.Equal(2, instance.Tree["Branch1"].Count);
Assert.Equal(["Leaf1", "Leaf2"], instance.Tree["Branch1"].Keys);
Assert.Equal(["Leaf3"], instance.Tree["Branch2"].Keys);

// Validate the array
Assert.NotNull(instance.Flat);
Assert.Equal(3, instance.Flat.Length);
Assert.Equal(["Element1"], instance.Flat[0].Keys);
Assert.Equal(["Element2"], instance.Flat[1].Keys);
Assert.Equal(["Element3"], instance.Flat[2].Keys);
Assert.Equal(1, instance.Flat[0].Values.Count);
Assert.Equal(["SubElement1"], instance.Flat[0].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.Flat[1].Values.Count);
Assert.Equal(["SubElement2"], instance.Flat[1].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.Flat[2].Values.Count);

// Validate the List
Assert.NotNull(instance.Flat);
Assert.Equal(2, instance.List.Count);
Assert.Equal(["Item1"], instance.List[0].Keys);
Assert.Equal(["Item2"], instance.List[1].Keys);
Assert.Equal(1, instance.List[0].Values.Count);
Assert.Equal(["NestedItem1"], instance.List[0].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.List[1].Values.Count);
}

[Fact]
public void CanBindReadonlyRecordStructOptions()
{
Expand Down
Loading