From 249090093a6444f12edb9cd233c30d3fa771a798 Mon Sep 17 00:00:00 2001 From: AlekseyTs Date: Wed, 28 May 2025 08:54:44 -0700 Subject: [PATCH] Avoid dereferencing null CheckConstraintsArgs.CurrentCompilation Fixes #78430. --- .../Portable/Symbols/ConstraintsHelper.cs | 22 ++++++++++++--- .../Portable/Symbols/SymbolDistinguisher.cs | 20 ++++++------- .../Test/Emit3/RefStructInterfacesTests.cs | 28 +++++++++++++++++++ 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/Compilers/CSharp/Portable/Symbols/ConstraintsHelper.cs b/src/Compilers/CSharp/Portable/Symbols/ConstraintsHelper.cs index 5291a3d6cc0cf..f632e5f658acd 100644 --- a/src/Compilers/CSharp/Portable/Symbols/ConstraintsHelper.cs +++ b/src/Compilers/CSharp/Portable/Symbols/ConstraintsHelper.cs @@ -523,7 +523,9 @@ public static void CheckAllConstraints(this TypeSymbol type, CheckConstraintsArg internal readonly struct CheckConstraintsArgs { - public readonly CSharpCompilation CurrentCompilation; +#nullable enable + public readonly CSharpCompilation? CurrentCompilation; +#nullable disable public readonly ConversionsBase Conversions; public readonly bool IncludeNullability; public readonly Location Location; @@ -590,7 +592,12 @@ private static bool CheckConstraintsSingleType(TypeSymbol type, in CheckConstrai } else if (type.Kind == SymbolKind.PointerType) { - Binder.CheckManagedAddr(args.CurrentCompilation, ((PointerTypeSymbol)type).PointedAtType, args.Location, args.Diagnostics); +#nullable enable + if (args.CurrentCompilation is not null) + { + Binder.CheckManagedAddr(args.CurrentCompilation, ((PointerTypeSymbol)type).PointedAtType, args.Location, args.Diagnostics); + } +#nullable disable } return false; // continue walking types } @@ -734,6 +741,7 @@ public static bool CheckConstraints(this NamedTypeSymbol type, in CheckConstrain diagnosticsBuilder.Free(); +#nullable enable // we only check for distinct interfaces when the type is not from source, as we // trust that types that are from source have already been checked by the compiler // to prevent this from happening in the first place. @@ -742,7 +750,7 @@ public static bool CheckConstraints(this NamedTypeSymbol type, in CheckConstrain result = false; args.Diagnostics.Add(ErrorCode.ERR_BogusType, args.Location, type); } - +#nullable disable return result; } @@ -958,7 +966,8 @@ private static bool CheckBasicConstraints( { if (typeParameter.AllowsRefLikeType) { - if (args.CurrentCompilation.SourceModule != typeParameter.ContainingModule) +#nullable enable + if (args.CurrentCompilation is not null && args.CurrentCompilation.SourceModule != typeParameter.ContainingModule) { if (MessageID.IDS_FeatureAllowsRefStructConstraint.GetFeatureAvailabilityDiagnosticInfo(args.CurrentCompilation) is { } diagnosticInfo) { @@ -970,6 +979,7 @@ private static bool CheckBasicConstraints( diagnosticsBuilder.Add(new TypeParameterDiagnosticInfo(typeParameter, new UseSiteInfo(new CSDiagnosticInfo(ErrorCode.ERR_RuntimeDoesNotSupportByRefLikeGenerics)))); } } +#nullable disable } else { @@ -1011,6 +1021,7 @@ private static bool CheckBasicConstraints( } else if (managedKind == ManagedKind.UnmanagedWithGenerics) { +#nullable enable // When there is no compilation, we are being invoked through the API IMethodSymbol.ReduceExtensionMethod(...). // In that case we consider the unmanaged constraint to be satisfied as if we were compiling with the latest // language version. The net effect of this is that in some IDE scenarios completion might consider an @@ -1024,6 +1035,7 @@ private static bool CheckBasicConstraints( return false; } } +#nullable disable } } @@ -1203,9 +1215,11 @@ private static void CheckConstraintType( } else { +#nullable enable SymbolDistinguisher distinguisher = new SymbolDistinguisher(args.CurrentCompilation, constraintType.Type, typeArgument.Type); constraintTypeErrorArgument = distinguisher.First; typeArgumentErrorArgument = distinguisher.Second; +#nullable disable } diagnosticsBuilder.Add(new TypeParameterDiagnosticInfo(typeParameter, new UseSiteInfo(new CSDiagnosticInfo(errorCode, containingSymbol.ConstructedFrom(), constraintTypeErrorArgument, typeParameter, typeArgumentErrorArgument)))); diff --git a/src/Compilers/CSharp/Portable/Symbols/SymbolDistinguisher.cs b/src/Compilers/CSharp/Portable/Symbols/SymbolDistinguisher.cs index b71990d8bc9c0..fb604da012007 100644 --- a/src/Compilers/CSharp/Portable/Symbols/SymbolDistinguisher.cs +++ b/src/Compilers/CSharp/Portable/Symbols/SymbolDistinguisher.cs @@ -2,8 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -#nullable disable - using Microsoft.CodeAnalysis.CSharp.Symbols; using Roslyn.Utilities; using System; @@ -24,13 +22,13 @@ namespace Microsoft.CodeAnalysis.CSharp /// internal sealed class SymbolDistinguisher { - private readonly CSharpCompilation _compilation; + private readonly CSharpCompilation? _compilation; private readonly Symbol _symbol0; private readonly Symbol _symbol1; private ImmutableArray _lazyDescriptions; - public SymbolDistinguisher(CSharpCompilation compilation, Symbol symbol0, Symbol symbol1) + public SymbolDistinguisher(CSharpCompilation? compilation, Symbol symbol0, Symbol symbol1) { Debug.Assert(symbol0 != symbol1); CheckSymbolKind(symbol0); @@ -96,8 +94,8 @@ private void MakeDescriptions() Symbol unwrappedSymbol0 = UnwrapSymbol(_symbol0); Symbol unwrappedSymbol1 = UnwrapSymbol(_symbol1); - string location0 = GetLocationString(_compilation, unwrappedSymbol0); - string location1 = GetLocationString(_compilation, unwrappedSymbol1); + string? location0 = GetLocationString(_compilation, unwrappedSymbol0); + string? location1 = GetLocationString(_compilation, unwrappedSymbol1); // The locations should not be equal, but they might be if the same // SyntaxTree is referenced by two different compilations. @@ -158,7 +156,7 @@ private static Symbol UnwrapSymbol(Symbol symbol) } } - private static string GetLocationString(CSharpCompilation compilation, Symbol unwrappedSymbol) + private static string? GetLocationString(CSharpCompilation? compilation, Symbol unwrappedSymbol) { Debug.Assert((object)unwrappedSymbol == UnwrapSymbol(unwrappedSymbol)); @@ -179,10 +177,10 @@ private static string GetLocationString(CSharpCompilation compilation, Symbol un { if (compilation != null) { - PortableExecutableReference metadataReference = compilation.GetMetadataReference(containingAssembly) as PortableExecutableReference; + PortableExecutableReference? metadataReference = compilation.GetMetadataReference(containingAssembly) as PortableExecutableReference; if (metadataReference != null) { - string path = metadataReference.FilePath; + string? path = metadataReference.FilePath; if (!string.IsNullOrEmpty(path)) { return path; @@ -219,7 +217,7 @@ private Symbol GetSymbol() return (_index == 0) ? _distinguisher._symbol0 : _distinguisher._symbol1; } - public override bool Equals(object obj) + public override bool Equals(object? obj) { var other = obj as Description; return other != null && @@ -243,7 +241,7 @@ public override string ToString() return _distinguisher.GetDescription(_index); } - string IFormattable.ToString(string format, IFormatProvider formatProvider) + string IFormattable.ToString(string? format, IFormatProvider? formatProvider) { return ToString(); } diff --git a/src/Compilers/CSharp/Test/Emit3/RefStructInterfacesTests.cs b/src/Compilers/CSharp/Test/Emit3/RefStructInterfacesTests.cs index dfd89fae251fe..3edeaabec1313 100644 --- a/src/Compilers/CSharp/Test/Emit3/RefStructInterfacesTests.cs +++ b/src/Compilers/CSharp/Test/Emit3/RefStructInterfacesTests.cs @@ -29198,5 +29198,33 @@ static async IAsyncEnumerator B() where T : allows ref struct Diagnostic(ErrorCode.ERR_IteratorRefLikeElementType, "B").WithLocation(8, 38) ); } + + [Fact] + [WorkItem("https://github.com/dotnet/roslyn/issues/78430")] + public void Issue78430() + { + var source = +@" +public ref struct TestStruct +{ + public int Prop1 {get; set;} +} + +public static class TestClass +{ + public static void TestExtensionMethod(this T value) + where T : allows ref struct + { + } +} +"; + var comp = CreateCompilation(source, targetFramework: TargetFramework.Net90); + CompileAndVerify(comp, verify: ExecutionConditionUtil.IsMonoOrCoreClr ? Verification.Passes : Verification.Skipped).VerifyDiagnostics(); + + var testStruct = comp.GetTypeByMetadataName("TestStruct"); + var extensionMethodSymbol = comp.GetMember("TestClass.TestExtensionMethod"); + + AssertEx.Equal("void TestStruct.TestExtensionMethod()", extensionMethodSymbol.ReduceExtensionMethod(testStruct, null).ToTestDisplayString()); + } } }