Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stack overflow in the configuration source gen #106511

Merged
merged 4 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,9 @@ private bool IsAssignableTo(ITypeSymbol source, ITypeSymbol dest)
return conversion.IsReference && conversion.IsImplicit;
}

private bool IsUnsupportedType(ITypeSymbol type)
[ThreadStatic] private static HashSet<ITypeSymbol> ts_visitedTypes;

private bool IsUnsupportedType(ITypeSymbol type, bool rootCall = true)
{
if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
{
Expand All @@ -562,9 +564,28 @@ private bool IsUnsupportedType(ITypeSymbol type)
return true;
}

if (rootCall)
{
ts_visitedTypes ??= new HashSet<ITypeSymbol>(SymbolEqualityComparer.Default);
ts_visitedTypes.Clear();
tarekgh marked this conversation as resolved.
Show resolved Hide resolved
}
else if (ts_visitedTypes.Contains(type))
{
// avoid infinite recursion in nested types like
// public record ExampleSettings
// {
// public TreeElement? Tree { get; set; }
// }
// public sealed class TreeElement : Dictionary<string, TreeElement>;
//
// return false for the second call. The type will continue be checked in the first call anyway.
return false;
}
ts_visitedTypes.Add(type);

if (type is IArrayTypeSymbol arrayTypeSymbol)
{
return arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType);
return arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType, rootCall: false);
}

if (IsCollection(type))
Expand All @@ -573,11 +594,11 @@ private bool IsUnsupportedType(ITypeSymbol type)

if (IsCandidateDictionary(collectionType, out ITypeSymbol? keyType, out ITypeSymbol? elementType))
{
return IsUnsupportedType(keyType) || IsUnsupportedType(elementType);
return IsUnsupportedType(keyType, rootCall: false) || IsUnsupportedType(elementType, rootCall: false);
}
else if (TryGetElementType(collectionType, out elementType))
{
return IsUnsupportedType(elementType);
return IsUnsupportedType(elementType, rootCall: false);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ bool TryRegisterCore()
}
case DictionarySpec dictionarySpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

bool shouldRegister = _typeIndex.CanBindTo(typeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.KeyTypeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.ElementTypeRef) &&
Expand All @@ -145,6 +148,9 @@ bool TryRegisterCore()
}
case CollectionSpec collectionSpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

if (_typeIndex.GetTypeSpec(collectionSpec.ElementTypeRef) is ComplexTypeSpec)
{
_namespaces.Add("System.Linq");
Expand All @@ -157,8 +163,7 @@ bool TryRegisterCore()
{
// Base case to avoid stack overflow for recursive object graphs.
// Register all object types for gen; we need to throw runtime exceptions in some cases.
bool shouldRegister = true;
_seenTransitiveTypes.Add(typeRef, shouldRegister);
_seenTransitiveTypes.Add(typeRef, true);

// List<string> is used in generated code as a temp holder for formatting
// an error for config properties that don't map to object properties.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ public string Color
}
}

public sealed class TreeElement : Dictionary<string, TreeElement>;

public record TypeWithRecursionThroughCollections
{
public TreeElement? Tree { get; set; }
public TreeElement?[]? Flat { get; set; }
ericstj marked this conversation as resolved.
Show resolved Hide resolved
}

public record RecordWithArrayParameter(string[] Array);

public readonly record struct ReadonlyRecordStructTypeOptions(string Color, int Length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,79 @@ public void CanBindOnParametersAndProperties_RecordWithArrayConstructorParameter
Assert.Equal(new string[] { "a", "b", "c" }, options.Array);
}

/// <summary>
/// Test binding to recursive types using Dictionary or Collections.
/// This ensure no stack overflow will occur during the compilation through the source gen or at runtime.
/// </summary>
[Fact]
public void BindToRecursiveTypesTest()
{
string jsonConfig = @"{
""Tree"": {
""Branch1"": {
""Leaf1"": {},
""Leaf2"": {}
},
""Branch2"": {
""Leaf3"": {}
}
},
""Flat"": [
{
""Element1"": {
""SubElement1"": {}
}
},
{
""Element2"": {
""SubElement2"": {}
}
},
{
""Element3"": {}
}
]
}";

var configuration = new ConfigurationBuilder()
.AddJsonStream(new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(jsonConfig)))
.Build();

var instance = new TypeWithRecursionThroughCollections();
configuration.Bind(instance);

Assert.NotNull(instance.Tree);
Assert.Equal(2, instance.Tree.Count);
ericstj marked this conversation as resolved.
Show resolved Hide resolved

Assert.NotNull(instance.Tree["Branch1"]);
Assert.Equal(2, instance.Tree["Branch1"].Count);
Assert.Equal(["Leaf1", "Leaf2"], instance.Tree["Branch1"].Keys);
Assert.Equal([[], []], instance.Tree["Branch1"].Values);

Assert.NotNull(instance.Tree["Branch2"]);
Assert.Equal(1, instance.Tree["Branch2"].Count);
Assert.Equal(["Leaf3"], instance.Tree["Branch2"].Keys);
Assert.Equal([[]], instance.Tree["Branch2"].Values);

Assert.NotNull(instance.Flat);
Assert.Equal(3, instance.Flat.Length);

Assert.Equal(["Element1"], instance.Flat[0].Keys);
Assert.Equal(["Element2"], instance.Flat[1].Keys);
Assert.Equal(["Element3"], instance.Flat[2].Keys);

Assert.Equal(1, instance.Flat[0].Values.Count);
Assert.Equal(["SubElement1"], instance.Flat[0].Values.ToArray()[0].Keys);
Assert.Equal([[]], instance.Flat[0].Values.ToArray()[0].Values);

Assert.Equal(1, instance.Flat[1].Values.Count);
Assert.Equal(["SubElement2"], instance.Flat[1].Values.ToArray()[0].Keys);
Assert.Equal([[]], instance.Flat[1].Values.ToArray()[0].Values);

Assert.Equal(1, instance.Flat[2].Values.Count);
Assert.Equal([[]], instance.Flat[2].Values.ToArray());
}

[Fact]
public void CanBindReadonlyRecordStructOptions()
{
Expand Down