Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stack overflow in the configuration source gen #106511

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,9 @@ private bool IsAssignableTo(ITypeSymbol source, ITypeSymbol dest)
return conversion.IsReference && conversion.IsImplicit;
}

private bool IsUnsupportedType(ITypeSymbol type)
private static HashSet<ITypeSymbol>? s_visitedTypes;
ericstj marked this conversation as resolved.
Show resolved Hide resolved

private bool IsUnsupportedType(ITypeSymbol type, HashSet<ITypeSymbol>? visitedTypes = null)
{
if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
{
Expand All @@ -562,26 +564,68 @@ private bool IsUnsupportedType(ITypeSymbol type)
return true;
}

if (type is IArrayTypeSymbol arrayTypeSymbol)
if (visitedTypes?.Contains(type) is true)
{
return arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType);
// avoid infinite recursion in nested types like
// public record ExampleSettings
// {
// public TreeElement? Tree { get; set; }
// }
// public sealed class TreeElement : Dictionary<string, TreeElement>;
//
// return false for the second call. The type will continue be checked in the first call anyway.
return false;
}

if (IsCollection(type))
IArrayTypeSymbol? arrayTypeSymbol = type as IArrayTypeSymbol;
bool isCollection = IsCollection(type);

if (arrayTypeSymbol is null && !isCollection)
{
return false;
}

bool restoreVisitedTypes = false;
if (visitedTypes is null)
{
visitedTypes = Interlocked.Exchange(ref s_visitedTypes, null) ?? new(SymbolEqualityComparer.Default);
visitedTypes.Clear();
restoreVisitedTypes = true;
}

visitedTypes.Add(type);

bool result;
if (arrayTypeSymbol is not null)
{
result = arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType, visitedTypes);
}
else
{
Debug.Assert(isCollection);

INamedTypeSymbol collectionType = (INamedTypeSymbol)type;

if (IsCandidateDictionary(collectionType, out ITypeSymbol? keyType, out ITypeSymbol? elementType))
{
return IsUnsupportedType(keyType) || IsUnsupportedType(elementType);
result = IsUnsupportedType(keyType, visitedTypes) || IsUnsupportedType(elementType, visitedTypes);
}
else if (TryGetElementType(collectionType, out elementType))
{
return IsUnsupportedType(elementType);
result = IsUnsupportedType(elementType, visitedTypes);
}
else
{
result = false;
}
}

return false;
if (restoreVisitedTypes)
{
Interlocked.Exchange(ref s_visitedTypes, visitedTypes);
}

return result;
}

private bool ConstructorParametersContainUnsupportedType(IMethodSymbol ctor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ bool TryRegisterCore()
}
case DictionarySpec dictionarySpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

bool shouldRegister = _typeIndex.CanBindTo(typeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.KeyTypeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.ElementTypeRef) &&
Expand All @@ -145,6 +148,9 @@ bool TryRegisterCore()
}
case CollectionSpec collectionSpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

if (_typeIndex.GetTypeSpec(collectionSpec.ElementTypeRef) is ComplexTypeSpec)
{
_namespaces.Add("System.Linq");
Expand All @@ -157,8 +163,7 @@ bool TryRegisterCore()
{
// Base case to avoid stack overflow for recursive object graphs.
// Register all object types for gen; we need to throw runtime exceptions in some cases.
bool shouldRegister = true;
_seenTransitiveTypes.Add(typeRef, shouldRegister);
_seenTransitiveTypes.Add(typeRef, true);

// List<string> is used in generated code as a temp holder for formatting
// an error for config properties that don't map to object properties.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,15 @@ public string Color
}
}

public sealed class TreeElement : Dictionary<string, TreeElement>;

public record TypeWithRecursionThroughCollections
{
public TreeElement? Tree { get; set; }
public TreeElement?[]? Flat { get; set; }
ericstj marked this conversation as resolved.
Show resolved Hide resolved
public List<TreeElement>? List { get; set; }
}

public record RecordWithArrayParameter(string[] Array);

public readonly record struct ReadonlyRecordStructTypeOptions(string Color, int Length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,87 @@ public void CanBindOnParametersAndProperties_RecordWithArrayConstructorParameter
Assert.Equal(new string[] { "a", "b", "c" }, options.Array);
}

/// <summary>
/// Test binding to recursive types using Dictionary or Collections.
/// This ensure no stack overflow will occur during the compilation through the source gen or at runtime.
/// </summary>
[Fact]
public void BindToRecursiveTypesTest()
{
string jsonConfig = @"{
""Tree"": {
""Branch1"": {
""Leaf1"": {},
""Leaf2"": {}
},
""Branch2"": {
""Leaf3"": {}
}
},
""Flat"": [
{
""Element1"": {
""SubElement1"": {}
}
},
{
""Element2"": {
""SubElement2"": {}
}
},
{
""Element3"": {}
}
],
""List"": [
{
""Item1"": {
""NestedItem1"": {}
}
},
{
""Item2"": {}
},
]
}";

var configuration = new ConfigurationBuilder()
.AddJsonStream(new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(jsonConfig)))
.Build();

var instance = new TypeWithRecursionThroughCollections();
configuration.Bind(instance);

// Validate the dictionary
Assert.NotNull(instance.Tree);
Assert.Equal(2, instance.Tree.Count);
ericstj marked this conversation as resolved.
Show resolved Hide resolved
Assert.NotNull(instance.Tree["Branch1"]);
Assert.Equal(2, instance.Tree["Branch1"].Count);
Assert.Equal(["Leaf1", "Leaf2"], instance.Tree["Branch1"].Keys);
Assert.Equal(["Leaf3"], instance.Tree["Branch2"].Keys);

// Validate the array
Assert.NotNull(instance.Flat);
Assert.Equal(3, instance.Flat.Length);
Assert.Equal(["Element1"], instance.Flat[0].Keys);
Assert.Equal(["Element2"], instance.Flat[1].Keys);
Assert.Equal(["Element3"], instance.Flat[2].Keys);
Assert.Equal(1, instance.Flat[0].Values.Count);
Assert.Equal(["SubElement1"], instance.Flat[0].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.Flat[1].Values.Count);
Assert.Equal(["SubElement2"], instance.Flat[1].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.Flat[2].Values.Count);

// Validate the List
Assert.NotNull(instance.Flat);
Assert.Equal(2, instance.List.Count);
Assert.Equal(["Item1"], instance.List[0].Keys);
Assert.Equal(["Item2"], instance.List[1].Keys);
Assert.Equal(1, instance.List[0].Values.Count);
Assert.Equal(["NestedItem1"], instance.List[0].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.List[1].Values.Count);
}

[Fact]
public void CanBindReadonlyRecordStructOptions()
{
Expand Down