Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow fully oblivious types to coexist with nullable-aware base types in partial classes #55861

Merged
merged 7 commits into from
Aug 31, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -289,37 +289,59 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeDeclaredBase
var newBasesBeingResolved = basesBeingResolved.Prepend(this.OriginalDefinition);
var baseInterfaces = ArrayBuilder<NamedTypeSymbol>.GetInstance();

NamedTypeSymbol baseType = null;
TypeWithAnnotations baseType = default;
RikkiGibson marked this conversation as resolved.
Show resolved Hide resolved
SourceLocation baseTypeLocation = null;

var interfaceLocations = SpecializedSymbolCollections.GetPooledSymbolDictionaryInstance<NamedTypeSymbol, SourceLocation>();

foreach (var decl in this.declaration.Declarations)
{
Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> one = MakeOneDeclaredBases(newBasesBeingResolved, decl, diagnostics);
if ((object)one == null) continue;
var (partBase, partInterfaces) = MakeOneDeclaredBases(newBasesBeingResolved, decl, diagnostics);
if (partBase.IsDefault && partInterfaces.IsDefaultOrEmpty) continue;

var partBase = one.Item1;
var partInterfaces = one.Item2;
if (!reportedPartialConflict)
{
if ((object)baseType == null)
if (baseType.IsDefault)
{
baseType = partBase;
baseTypeLocation = decl.NameLocation;
}
else if (baseType.TypeKind == TypeKind.Error && (object)partBase != null)
else if (baseType.TypeKind == TypeKind.Error && !partBase.IsDefault)
{
// if the old base was an error symbol, copy it to the interfaces list so it doesn't get lost
partInterfaces = partInterfaces.Add(baseType);
partInterfaces = partInterfaces.Add((NamedTypeSymbol)baseType.Type);
baseType = partBase;
baseTypeLocation = decl.NameLocation;
}
else if ((object)partBase != null && !TypeSymbol.Equals(partBase, baseType, TypeCompareKind.ConsiderEverything2) && partBase.TypeKind != TypeKind.Error)
else if (!partBase.IsDefault && !partBase.Equals(baseType, TypeCompareKind.ConsiderEverything) && partBase.TypeKind != TypeKind.Error)
{
// the parts do not agree
if (partBase.Equals(baseType, TypeCompareKind.AllNullableIgnoreOptions))
{
if (baseType.VisitType(
type: null,
(type, arg, flag) => !type.NullableAnnotation.IsOblivious(),
RikkiGibson marked this conversation as resolved.
Show resolved Hide resolved
typePredicate: null,
arg: (object)null) is null)
RikkiGibson marked this conversation as resolved.
Show resolved Hide resolved
{
// 'baseType' is completely oblivious. Prefer 'partBase' in this case.
baseType = partBase;
baseTypeLocation = decl.NameLocation;
continue;
}
else if (partBase.VisitType(
type: null,
(type, arg, flag) => !type.NullableAnnotation.IsOblivious(),
RikkiGibson marked this conversation as resolved.
Show resolved Hide resolved
typePredicate: null,
arg: (object)null) is null)
{
// 'partBase' is completely oblivious. Prefer 'baseType' in this case.
continue;
}
}

var info = diagnostics.Add(ErrorCode.ERR_PartialMultipleBases, Locations[0], this);
baseType = new ExtendedErrorTypeSymbol(baseType, LookupResultKind.Ambiguous, info);
baseType = baseType.WithType(new ExtendedErrorTypeSymbol(baseType.Type, LookupResultKind.Ambiguous, info));
baseTypeLocation = decl.NameLocation;
reportedPartialConflict = true;
}
Expand Down Expand Up @@ -347,7 +369,7 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeDeclaredBase
}
}

if ((object)baseType != null)
if (!baseType.IsDefault)
{
Debug.Assert(baseTypeLocation != null);
if (baseType.IsStatic)
Expand Down Expand Up @@ -380,7 +402,7 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeDeclaredBase

diagnostics.Add(Locations[0], useSiteInfo);

return new Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>>(baseType, baseInterfacesRO);
return Tuple.Create((NamedTypeSymbol)baseType.Type, baseInterfacesRO);
}

private static BaseListSyntax GetBaseListOpt(SingleTypeDeclaration decl)
Expand All @@ -395,15 +417,15 @@ private static BaseListSyntax GetBaseListOpt(SingleTypeDeclaration decl)
}

// process the base list for one part of a partial class, or for the only part of any other type declaration.
private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredBases(ConsList<TypeSymbol> newBasesBeingResolved, SingleTypeDeclaration decl, BindingDiagnosticBag diagnostics)
private (TypeWithAnnotations localBase, ImmutableArray<NamedTypeSymbol> interfaces) MakeOneDeclaredBases(ConsList<TypeSymbol> newBasesBeingResolved, SingleTypeDeclaration decl, BindingDiagnosticBag diagnostics)
RikkiGibson marked this conversation as resolved.
Show resolved Hide resolved
{
BaseListSyntax bases = GetBaseListOpt(decl);
if (bases == null)
{
return null;
return default;
}

NamedTypeSymbol localBase = null;
TypeWithAnnotations localBase = default;
var localInterfaces = ArrayBuilder<NamedTypeSymbol>.GetInstance();
var baseBinder = this.DeclaringCompilation.GetBinder(bases);

Expand All @@ -424,11 +446,11 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredB

var location = new SourceLocation(typeSyntax);

TypeSymbol baseType;
TypeWithAnnotations baseType;

if (i == 0 && TypeKind == TypeKind.Class) // allow class in the first position
{
baseType = baseBinder.BindType(typeSyntax, diagnostics, newBasesBeingResolved).Type;
baseType = baseBinder.BindType(typeSyntax, diagnostics, newBasesBeingResolved);

SpecialType baseSpecialType = baseType.SpecialType;
if (IsRestrictedBaseType(baseSpecialType))
Expand All @@ -452,7 +474,7 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredB
}
}

if (baseType.IsSealed && !this.IsStatic) // Give precedence to ERR_StaticDerivedFromNonObject
if (baseType.Type.IsSealed && !this.IsStatic) // Give precedence to ERR_StaticDerivedFromNonObject
{
diagnostics.Add(ErrorCode.ERR_CantDeriveFromSealedType, location, this, baseType);
continue;
Expand All @@ -469,7 +491,7 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredB
{
baseTypeIsErrorWithoutInterfaceGuess = true;

TypeKind guessTypeKind = baseType.GetNonErrorTypeKindGuess();
TypeKind guessTypeKind = baseType.Type.GetNonErrorTypeKindGuess();
if (guessTypeKind == TypeKind.Interface)
{
//base type is an error *with* a guessed interface
Expand All @@ -481,42 +503,41 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredB
baseType.TypeKind == TypeKind.Delegate ||
baseType.TypeKind == TypeKind.Struct ||
baseTypeIsErrorWithoutInterfaceGuess) &&
((object)localBase == null))
localBase.IsDefault)
{
localBase = (NamedTypeSymbol)baseType;
Debug.Assert((object)localBase != null);
if (this.IsStatic && localBase.SpecialType != SpecialType.System_Object)
localBase = baseType;
Debug.Assert(!localBase.IsDefault);
if (this.IsStatic && localBase.Type.SpecialType != SpecialType.System_Object)
{
// Static class '{0}' cannot derive from type '{1}'. Static classes must derive from object.
var info = diagnostics.Add(ErrorCode.ERR_StaticDerivedFromNonObject, location, this, localBase);
localBase = new ExtendedErrorTypeSymbol(localBase, LookupResultKind.NotReferencable, info);
localBase = localBase.WithType(new ExtendedErrorTypeSymbol(localBase.Type, LookupResultKind.NotReferencable, info));
}
checkPrimaryConstructorBaseType(baseTypeSyntax, localBase);
checkPrimaryConstructorBaseType(baseTypeSyntax, localBase.Type);
continue;
}
}
else
{
baseType = baseBinder.BindType(typeSyntax, diagnostics, newBasesBeingResolved).Type;
baseType = baseBinder.BindType(typeSyntax, diagnostics, newBasesBeingResolved);
}

if (i == 0)
{
checkPrimaryConstructorBaseType(baseTypeSyntax, baseType);
checkPrimaryConstructorBaseType(baseTypeSyntax, baseType.Type);
}

switch (baseType.TypeKind)
{
case TypeKind.Interface:
foreach (var t in localInterfaces)
{
if (t.Equals(baseType, TypeCompareKind.ConsiderEverything))
if (t.Equals(baseType.Type, TypeCompareKind.ConsiderEverything))
{
diagnostics.Add(ErrorCode.ERR_DuplicateInterfaceInBaseList, location, baseType);
}
else if (t.Equals(baseType, TypeCompareKind.ObliviousNullableModifierMatchesAny))
else if (t.Equals(baseType.Type, TypeCompareKind.ObliviousNullableModifierMatchesAny))
{
// duplicates with ?/! differences are reported later, we report local differences between oblivious and ?/! here
diagnostics.Add(ErrorCode.WRN_DuplicateInterfaceWithNullabilityMismatchInBaseList, location, baseType, this);
}
}
Expand All @@ -533,20 +554,20 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredB
diagnostics.Add(ErrorCode.ERR_RefStructInterfaceImpl, location, this, baseType);
}

if (baseType.ContainsDynamic())
if (baseType.Type.ContainsDynamic())
{
diagnostics.Add(ErrorCode.ERR_DeriveFromConstructedDynamic, location, this, baseType);
}

localInterfaces.Add((NamedTypeSymbol)baseType);
localInterfaces.Add((NamedTypeSymbol)baseType.Type);
continue;

case TypeKind.Class:
if (TypeKind == TypeKind.Class)
{
if ((object)localBase == null)
if (localBase.IsDefault)
{
localBase = (NamedTypeSymbol)baseType;
localBase = baseType;
diagnostics.Add(ErrorCode.ERR_BaseClassMustBeFirst, location, baseType);
continue;
}
Expand All @@ -564,7 +585,7 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredB

case TypeKind.Error:
// put the error type in the interface list so we don't lose track of it
localInterfaces.Add((NamedTypeSymbol)baseType);
localInterfaces.Add((NamedTypeSymbol)baseType.Type);
continue;

case TypeKind.Dynamic:
Expand All @@ -586,7 +607,7 @@ private Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>> MakeOneDeclaredB
diagnostics.Add(ErrorCode.ERR_ObjectCantHaveBases, new SourceLocation(name));
}

return new Tuple<NamedTypeSymbol, ImmutableArray<NamedTypeSymbol>>(localBase, localInterfaces.ToImmutableAndFree());
return (localBase, localInterfaces.ToImmutableAndFree());

void checkPrimaryConstructorBaseType(BaseTypeSyntax baseTypeSyntax, TypeSymbol baseType)
{
Expand Down
Loading