Skip to content

Commit

Permalink
Fix stack overflow in the configuration source gen (#106511)
Browse files Browse the repository at this point in the history
* Fix StackOverFlow in the Logging source gen

* Address the feedback

* Avoid static field

---------

Co-authored-by: Eric StJohn <ericstj@microsoft.com>
  • Loading branch information
tarekgh and ericstj authored Aug 19, 2024
1 parent b28d3f8 commit f32d1a7
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,9 @@ private bool IsAssignableTo(ITypeSymbol source, ITypeSymbol dest)
return conversion.IsReference && conversion.IsImplicit;
}

private bool IsUnsupportedType(ITypeSymbol type)
private HashSet<ITypeSymbol>? _visitedTypes = new(SymbolEqualityComparer.Default);

private bool IsUnsupportedType(ITypeSymbol type, HashSet<ITypeSymbol>? visitedTypes = null)
{
if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T)
{
Expand All @@ -569,25 +571,55 @@ private bool IsUnsupportedType(ITypeSymbol type)
return true;
}

if (type is IArrayTypeSymbol arrayTypeSymbol)
if (visitedTypes?.Contains(type) is true)
{
return arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType);
// avoid infinite recursion in nested types like
// public record RecursiveType
// {
// public TreeElement? Tree { get; set; }
// }
// public sealed class TreeElement : Dictionary<string, TreeElement>;
//
// return false for the second call. The type will continue be checked in the first call anyway.
return false;
}

if (IsCollection(type))
IArrayTypeSymbol? arrayTypeSymbol = type as IArrayTypeSymbol;
if (arrayTypeSymbol is null)
{
INamedTypeSymbol collectionType = (INamedTypeSymbol)type;

if (IsCandidateDictionary(collectionType, out ITypeSymbol? keyType, out ITypeSymbol? elementType))
{
return IsUnsupportedType(keyType) || IsUnsupportedType(elementType);
}
else if (TryGetElementType(collectionType, out elementType))
if (!IsCollection(type))
{
return IsUnsupportedType(elementType);
return false;
}
}

if (visitedTypes is null)
{
visitedTypes = _visitedTypes;
visitedTypes.Clear();
}

visitedTypes.Add(type);

if (arrayTypeSymbol is not null)
{
return arrayTypeSymbol.Rank > 1 || IsUnsupportedType(arrayTypeSymbol.ElementType, visitedTypes);
}

Debug.Assert(IsCollection(type));

INamedTypeSymbol collectionType = (INamedTypeSymbol)type;

if (IsCandidateDictionary(collectionType, out ITypeSymbol? keyType, out ITypeSymbol? elementType))
{
return IsUnsupportedType(keyType, visitedTypes) || IsUnsupportedType(elementType, visitedTypes);
}

if (TryGetElementType(collectionType, out elementType))
{
return IsUnsupportedType(elementType, visitedTypes);
}

return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ bool TryRegisterCore()
}
case DictionarySpec dictionarySpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

bool shouldRegister = _typeIndex.CanBindTo(typeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.KeyTypeRef) &&
TryRegisterTransitiveTypesForMethodGen(dictionarySpec.ElementTypeRef) &&
Expand All @@ -145,6 +148,9 @@ bool TryRegisterCore()
}
case CollectionSpec collectionSpec:
{
// Base case to avoid stack overflow for recursive object graphs.
_seenTransitiveTypes.Add(typeRef, true);

if (_typeIndex.GetTypeSpec(collectionSpec.ElementTypeRef) is ComplexTypeSpec)
{
_namespaces.Add("System.Linq");
Expand All @@ -157,8 +163,7 @@ bool TryRegisterCore()
{
// Base case to avoid stack overflow for recursive object graphs.
// Register all object types for gen; we need to throw runtime exceptions in some cases.
bool shouldRegister = true;
_seenTransitiveTypes.Add(typeRef, shouldRegister);
_seenTransitiveTypes.Add(typeRef, true);

// List<string> is used in generated code as a temp holder for formatting
// an error for config properties that don't map to object properties.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,15 @@ public string Color
}
}

public sealed class TreeElement : Dictionary<string, TreeElement>;

public record TypeWithRecursionThroughCollections
{
public TreeElement? Tree { get; set; }
public TreeElement?[]? Flat { get; set; }
public List<TreeElement>? List { get; set; }
}

public record RecordWithArrayParameter(string[] Array);

public readonly record struct ReadonlyRecordStructTypeOptions(string Color, int Length);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,87 @@ public void CanBindOnParametersAndProperties_RecordWithArrayConstructorParameter
Assert.Equal(new string[] { "a", "b", "c" }, options.Array);
}

/// <summary>
/// Test binding to recursive types using Dictionary or Collections.
/// This ensure no stack overflow will occur during the compilation through the source gen or at runtime.
/// </summary>
[Fact]
public void BindToRecursiveTypesTest()
{
string jsonConfig = @"{
""Tree"": {
""Branch1"": {
""Leaf1"": {},
""Leaf2"": {}
},
""Branch2"": {
""Leaf3"": {}
}
},
""Flat"": [
{
""Element1"": {
""SubElement1"": {}
}
},
{
""Element2"": {
""SubElement2"": {}
}
},
{
""Element3"": {}
}
],
""List"": [
{
""Item1"": {
""NestedItem1"": {}
}
},
{
""Item2"": {}
},
]
}";

var configuration = new ConfigurationBuilder()
.AddJsonStream(new System.IO.MemoryStream(System.Text.Encoding.UTF8.GetBytes(jsonConfig)))
.Build();

var instance = new TypeWithRecursionThroughCollections();
configuration.Bind(instance);

// Validate the dictionary
Assert.NotNull(instance.Tree);
Assert.Equal(2, instance.Tree.Count);
Assert.NotNull(instance.Tree["Branch1"]);
Assert.Equal(2, instance.Tree["Branch1"].Count);
Assert.Equal(["Leaf1", "Leaf2"], instance.Tree["Branch1"].Keys);
Assert.Equal(["Leaf3"], instance.Tree["Branch2"].Keys);

// Validate the array
Assert.NotNull(instance.Flat);
Assert.Equal(3, instance.Flat.Length);
Assert.Equal(["Element1"], instance.Flat[0].Keys);
Assert.Equal(["Element2"], instance.Flat[1].Keys);
Assert.Equal(["Element3"], instance.Flat[2].Keys);
Assert.Equal(1, instance.Flat[0].Values.Count);
Assert.Equal(["SubElement1"], instance.Flat[0].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.Flat[1].Values.Count);
Assert.Equal(["SubElement2"], instance.Flat[1].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.Flat[2].Values.Count);

// Validate the List
Assert.NotNull(instance.Flat);
Assert.Equal(2, instance.List.Count);
Assert.Equal(["Item1"], instance.List[0].Keys);
Assert.Equal(["Item2"], instance.List[1].Keys);
Assert.Equal(1, instance.List[0].Values.Count);
Assert.Equal(["NestedItem1"], instance.List[0].Values.ToArray()[0].Keys);
Assert.Equal(1, instance.List[1].Values.Count);
}

[Fact]
public void CanBindReadonlyRecordStructOptions()
{
Expand Down

0 comments on commit f32d1a7

Please sign in to comment.