Skip to content

Commit cadd91d

Browse files
author
Julien Couvreur
committed
Extensions: Add ReduceExtensionMember API
1 parent a01d6a0 commit cadd91d

File tree

23 files changed

+739
-28
lines changed

23 files changed

+739
-28
lines changed

src/Compilers/CSharp/Portable/Binder/Binder_Expressions.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8768,12 +8768,12 @@ static MethodGroupResolution resolveMethods(
87688768

87698769
bool inapplicable = false;
87708770
if (method.IsExtensionMethod
8771-
&& (object)method.ReduceExtensionMethod(receiverType, binder.Compilation) == null)
8771+
&& method.ReduceExtensionMethod(receiverType, binder.Compilation) is null)
87728772
{
87738773
inapplicable = true;
87748774
}
87758775
else if (method.GetIsNewExtensionMember()
8776-
&& SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(binder.Compilation, method, receiverType) == null)
8776+
&& SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(binder.Compilation, method, receiverType, wasExtensionFullyInferred: out _) is null)
87778777
{
87788778
inapplicable = true;
87798779
}

src/Compilers/CSharp/Portable/Binder/Binder_Symbols.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1585,7 +1585,7 @@ private void CheckWhatCandidatesWeHave(
15851585
else
15861586
{
15871587
Debug.Assert(symbol.GetIsNewExtensionMember());
1588-
if (SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(this.Compilation, symbol, receiverType) is { } compatibleSubstitutedMember)
1588+
if (SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(this.Compilation, symbol, receiverType, wasExtensionFullyInferred: out _) is { } compatibleSubstitutedMember)
15891589
{
15901590
if (compatibleSubstitutedMember.IsStatic)
15911591
{

src/Compilers/CSharp/Portable/Binder/Semantics/OverloadResolution/MethodTypeInference.cs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,9 @@ private enum Dependency
129129
Indirect = 0x12
130130
}
131131

132-
private readonly CSharpCompilation _compilation;
132+
#nullable enable
133+
private readonly CSharpCompilation? _compilation;
134+
#nullable disable
133135
private readonly ConversionsBase _conversions;
134136
private readonly ImmutableArray<TypeParameterSymbol> _methodTypeParameters;
135137
private readonly NamedTypeSymbol _constructedContainingTypeOfMethod;
@@ -320,14 +322,14 @@ public static MethodTypeInferenceResult Infer(
320322

321323
#nullable enable
322324
private MethodTypeInferrer(
323-
CSharpCompilation compilation,
325+
CSharpCompilation? compilation,
324326
ConversionsBase conversions,
325327
ImmutableArray<TypeParameterSymbol> methodTypeParameters,
326328
NamedTypeSymbol constructedContainingTypeOfMethod,
327329
ImmutableArray<TypeWithAnnotations> formalParameterTypes,
328330
ImmutableArray<RefKind> formalParameterRefKinds,
329331
ImmutableArray<BoundExpression> arguments,
330-
Extensions extensions,
332+
Extensions? extensions,
331333
Dictionary<TypeParameterSymbol, int>? ordinals)
332334
{
333335
_compilation = compilation;
@@ -2856,7 +2858,7 @@ private bool Fix(int iParam, ref CompoundUseSiteInfo<AssemblySymbol> useSiteInfo
28562858
}
28572859

28582860
private static (TypeWithAnnotations Type, bool FromFunctionType) Fix(
2859-
CSharpCompilation compilation,
2861+
CSharpCompilation? compilation,
28602862
ConversionsBase conversions,
28612863
TypeParameterSymbol typeParameter,
28622864
HashSet<TypeWithAnnotations>? exact,
@@ -2989,6 +2991,7 @@ private static (TypeWithAnnotations Type, bool FromFunctionType) Fix(
29892991
var resultType = functionType.GetInternalDelegateType();
29902992
if (hasExpressionTypeConstraint(typeParameter))
29912993
{
2994+
Debug.Assert(compilation is not null); // Tracked by https://github.com/dotnet/roslyn/issues/80658
29922995
var expressionOfTType = compilation.GetWellKnownType(WellKnownType.System_Linq_Expressions_Expression_T);
29932996
resultType = expressionOfTType.Construct(resultType);
29942997
}
@@ -3179,6 +3182,7 @@ private static NamedTypeSymbol GetInterfaceInferenceBound(ImmutableArray<NamedTy
31793182
// Helper methods
31803183
//
31813184

3185+
#nullable enable
31823186
/// <summary>
31833187
/// We apply type inference to an extension type, using the receiver as argument against the
31843188
/// extension parameter.
@@ -3187,7 +3191,7 @@ private static NamedTypeSymbol GetInterfaceInferenceBound(ImmutableArray<NamedTy
31873191
public static ImmutableArray<TypeWithAnnotations> InferTypeArgumentsFromReceiverType(
31883192
NamedTypeSymbol extension,
31893193
BoundExpression receiver,
3190-
CSharpCompilation compilation,
3194+
CSharpCompilation? compilation,
31913195
ConversionsBase conversions,
31923196
ref CompoundUseSiteInfo<AssemblySymbol> useSiteInfo)
31933197
{
@@ -3216,6 +3220,7 @@ public static ImmutableArray<TypeWithAnnotations> InferTypeArgumentsFromReceiver
32163220

32173221
return inferrer.GetInferredTypeArguments(out _);
32183222
}
3223+
#nullable disable
32193224

32203225
////////////////////////////////////////////////////////////////////////////////
32213226
//

src/Compilers/CSharp/Portable/Compilation/CSharpSemanticModel.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,7 @@ private ImmutableArray<ISymbol> LookupSymbolsInternal(
16811681
else
16821682
{
16831683
Debug.Assert(symbol.GetIsNewExtensionMember());
1684-
if (SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(binder.Compilation, symbol, receiverType) is { } compatibleSubstitutedMember)
1684+
if (SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(binder.Compilation, symbol, receiverType, wasExtensionFullyInferred: out _) is { } compatibleSubstitutedMember)
16851685
{
16861686
results.Add(compatibleSubstitutedMember.GetPublicSymbol());
16871687
}

src/Compilers/CSharp/Portable/Symbols/ConstraintsHelper.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -717,8 +717,6 @@ public static bool CheckConstraintsForNamedType(
717717

718718
public static bool CheckConstraints(this NamedTypeSymbol type, in CheckConstraintsArgs args)
719719
{
720-
Debug.Assert(args.CurrentCompilation is object);
721-
722720
if (!RequiresChecking(type))
723721
{
724722
return true;

src/Compilers/CSharp/Portable/Symbols/MemberSymbolExtensions.cs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,8 @@ internal static TMember ConstructIncludingExtension<TMember>(this TMember member
281281
else
282282
{
283283
Debug.Assert(method.GetIsNewExtensionMember());
284-
constructed = (MethodSymbol?)SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(compilation, constructed, receiverType);
285-
286-
if (checkFullyInferred && constructed?.IsGenericMethod == true && typeArguments.IsDefaultOrEmpty)
284+
constructed = (MethodSymbol?)SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(compilation, constructed, receiverType, out bool wasExtensionFullyInferred);
285+
if (checkFullyInferred && (!wasExtensionFullyInferred || (constructed?.IsGenericMethod == true && typeArguments.IsDefaultOrEmpty)))
287286
{
288287
return null;
289288
}
@@ -297,7 +296,13 @@ internal static TMember ConstructIncludingExtension<TMember>(this TMember member
297296
// infer type arguments based off the receiver type if needed, check applicability
298297
Debug.Assert(receiverType is not null);
299298
Debug.Assert(property.GetIsNewExtensionMember());
300-
return (PropertySymbol?)SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(compilation, property, receiverType);
299+
var result = (PropertySymbol?)SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(compilation, property, receiverType, wasExtensionFullyInferred: out bool wasFullyInferred);
300+
if (checkFullyInferred && !wasFullyInferred)
301+
{
302+
return null;
303+
}
304+
305+
return result;
301306
}
302307

303308
throw ExceptionUtilities.UnexpectedValue(member.Kind);

src/Compilers/CSharp/Portable/Symbols/MethodSymbol.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,8 @@ public override TResult Accept<TResult>(CSharpSymbolVisitor<TResult> visitor)
743743
return visitor.VisitMethod(this);
744744
}
745745

746-
public MethodSymbol ReduceExtensionMethod(TypeSymbol receiverType, CSharpCompilation compilation)
746+
#nullable enable
747+
public MethodSymbol? ReduceExtensionMethod(TypeSymbol receiverType, CSharpCompilation? compilation)
747748
{
748749
return ReduceExtensionMethod(receiverType, compilation, wasFullyInferred: out _);
749750
}
@@ -755,7 +756,7 @@ public MethodSymbol ReduceExtensionMethod(TypeSymbol receiverType, CSharpCompila
755756
/// <param name="compilation">The compilation in which constraints should be checked.
756757
/// Should not be null, but if it is null we treat constraints as we would in the latest
757758
/// language version.</param>
758-
public MethodSymbol ReduceExtensionMethod(TypeSymbol receiverType, CSharpCompilation compilation, out bool wasFullyInferred)
759+
public MethodSymbol? ReduceExtensionMethod(TypeSymbol receiverType, CSharpCompilation? compilation, out bool wasFullyInferred)
759760
{
760761
if ((object)receiverType == null)
761762
{
@@ -770,6 +771,7 @@ public MethodSymbol ReduceExtensionMethod(TypeSymbol receiverType, CSharpCompila
770771

771772
return ReducedExtensionMethodSymbol.Create(this, receiverType, compilation, out wasFullyInferred);
772773
}
774+
#nullable disable
773775

774776
/// <summary>
775777
/// If this is an extension method, returns a reduced extension method

src/Compilers/CSharp/Portable/Symbols/PublicModel/MethodSymbol.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,26 @@ ITypeSymbol IMethodSymbol.GetTypeInferredDuringReduction(ITypeParameterSymbol re
201201
GetPublicSymbol();
202202
}
203203

204-
IMethodSymbol IMethodSymbol.ReduceExtensionMethod(ITypeSymbol receiverType)
204+
#nullable enable
205+
IMethodSymbol? IMethodSymbol.ReduceExtensionMethod(ITypeSymbol receiverType)
205206
{
206207
return _underlying.ReduceExtensionMethod(
207208
receiverType.EnsureCSharpSymbolOrNull(nameof(receiverType)), compilation: null).
208209
GetPublicSymbol();
209210
}
210211

212+
IMethodSymbol? IMethodSymbol.ReduceExtensionMember(ITypeSymbol receiverType)
213+
{
214+
if (_underlying.GetIsNewExtensionMember() && SourceMemberContainerTypeSymbol.IsAllowedExtensionMember(_underlying))
215+
{
216+
var csharpReceiver = receiverType.EnsureCSharpSymbolOrNull(nameof(receiverType));
217+
return (IMethodSymbol?)SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(compilation: null, _underlying, csharpReceiver, wasExtensionFullyInferred: out _).GetPublicSymbol();
218+
}
219+
220+
return null;
221+
}
222+
#nullable disable
223+
211224
ImmutableArray<IMethodSymbol> IMethodSymbol.ExplicitInterfaceImplementations
212225
{
213226
get

src/Compilers/CSharp/Portable/Symbols/PublicModel/PropertySymbol.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,17 @@ ImmutableArray<CustomModifier> IPropertySymbol.RefCustomModifiers
115115
IPropertySymbol? IPropertySymbol.PartialImplementationPart => _underlying.PartialImplementationPart.GetPublicSymbol();
116116

117117
bool IPropertySymbol.IsPartialDefinition => (_underlying as SourcePropertySymbol)?.IsPartialDefinition ?? false;
118+
119+
IPropertySymbol? IPropertySymbol.ReduceExtensionMember(ITypeSymbol receiverType)
120+
{
121+
if (_underlying.GetIsNewExtensionMember() && SourceMemberContainerTypeSymbol.IsAllowedExtensionMember(_underlying))
122+
{
123+
var csharpReceiver = receiverType.EnsureCSharpSymbolOrNull(nameof(receiverType));
124+
return (IPropertySymbol?)SourceNamedTypeSymbol.GetCompatibleSubstitutedMember(compilation: null, _underlying, csharpReceiver, wasExtensionFullyInferred: out _).GetPublicSymbol();
125+
}
126+
127+
return null;
128+
}
118129
#nullable disable
119130

120131
#region ISymbol Members

src/Compilers/CSharp/Portable/Symbols/Source/SourceNamedTypeSymbol_Extension.cs

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,20 +1168,25 @@ private static string RawNameToHashString(string rawName)
11681168
return CodeAnalysis.CodeGen.PrivateImplementationDetails.HashToHex(hash);
11691169
}
11701170

1171-
internal static Symbol? GetCompatibleSubstitutedMember(CSharpCompilation compilation, Symbol extensionMember, TypeSymbol receiverType)
1171+
/// <summary>
1172+
/// Given a receiver type, check if we can infer type arguments for the extension block and check for compatibility.
1173+
/// If that is successful, return the substituted extension member and whether the extension block was fully inferred.
1174+
/// </summary>
1175+
internal static Symbol? GetCompatibleSubstitutedMember(CSharpCompilation? compilation, Symbol extensionMember, TypeSymbol receiverType, out bool wasExtensionFullyInferred)
11721176
{
11731177
Debug.Assert(extensionMember.GetIsNewExtensionMember());
11741178

11751179
NamedTypeSymbol extension = extensionMember.ContainingType;
11761180
if (extension.ExtensionParameter is null)
11771181
{
1182+
wasExtensionFullyInferred = false;
11781183
return null;
11791184
}
11801185

11811186
Symbol result;
11821187
if (extensionMember.IsDefinition)
11831188
{
1184-
NamedTypeSymbol? constructedExtension = inferExtensionTypeArguments(extension, receiverType, compilation);
1189+
NamedTypeSymbol? constructedExtension = inferExtensionTypeArguments(extension, receiverType, compilation, out wasExtensionFullyInferred);
11851190
if (constructedExtension is null)
11861191
{
11871192
return null;
@@ -1191,23 +1196,27 @@ private static string RawNameToHashString(string rawName)
11911196
}
11921197
else
11931198
{
1199+
wasExtensionFullyInferred = true;
11941200
result = extensionMember;
11951201
}
11961202

1203+
ConversionsBase conversions = compilation?.Conversions ?? (ConversionsBase)extensionMember.ContainingAssembly.CorLibrary.TypeConversions;
1204+
11971205
Debug.Assert(result.ContainingType.ExtensionParameter is not null);
11981206
var discardedUseSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
1199-
Conversion conversion = compilation.Conversions.ConvertExtensionMethodThisArg(parameterType: result.ContainingType.ExtensionParameter.Type, receiverType, ref discardedUseSiteInfo, isMethodGroupConversion: false);
1207+
Conversion conversion = conversions.ConvertExtensionMethodThisArg(parameterType: result.ContainingType.ExtensionParameter.Type, receiverType, ref discardedUseSiteInfo, isMethodGroupConversion: false);
12001208
if (!conversion.Exists)
12011209
{
12021210
return null;
12031211
}
12041212

12051213
return result;
12061214

1207-
static NamedTypeSymbol? inferExtensionTypeArguments(NamedTypeSymbol extension, TypeSymbol receiverType, CSharpCompilation compilation)
1215+
static NamedTypeSymbol? inferExtensionTypeArguments(NamedTypeSymbol extension, TypeSymbol receiverType, CSharpCompilation? compilation, out bool wasExtensionFullyInferred)
12081216
{
12091217
if (extension.Arity == 0)
12101218
{
1219+
wasExtensionFullyInferred = true;
12111220
return extension;
12121221
}
12131222

@@ -1219,12 +1228,14 @@ private static string RawNameToHashString(string rawName)
12191228

12201229
var discardedUseSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
12211230
ImmutableArray<TypeWithAnnotations> typeArguments = MethodTypeInferrer.InferTypeArgumentsFromReceiverType(extension, receiverValue, compilation, conversions, ref discardedUseSiteInfo);
1222-
if (typeArguments.IsDefault || typeArguments.Any(t => !t.HasType))
1231+
if (typeArguments.IsDefault)
12231232
{
1233+
wasExtensionFullyInferred = false;
12241234
return null;
12251235
}
12261236

1227-
var result = extension.Construct(typeArguments);
1237+
ImmutableArray<TypeWithAnnotations> typeArgsForConstruct = fillNotInferredTypeArguments(extension, typeArguments, out wasExtensionFullyInferred);
1238+
var result = extension.Construct(typeArgsForConstruct);
12281239

12291240
var constraintArgs = new ConstraintsHelper.CheckConstraintsArgs(compilation, conversions, includeNullability: false,
12301241
NoLocation.Singleton, diagnostics: BindingDiagnosticBag.Discarded, template: CompoundUseSiteInfo<AssemblySymbol>.Discarded);
@@ -1237,6 +1248,20 @@ private static string RawNameToHashString(string rawName)
12371248

12381249
return result;
12391250
}
1251+
1252+
static ImmutableArray<TypeWithAnnotations> fillNotInferredTypeArguments(NamedTypeSymbol extension, ImmutableArray<TypeWithAnnotations> typeArgs, out bool wasFullyInferred)
1253+
{
1254+
// For the purpose of construction we use original type parameters in place of type arguments that we couldn't infer from the first argument.
1255+
wasFullyInferred = typeArgs.All(static t => t.HasType);
1256+
if (!wasFullyInferred)
1257+
{
1258+
return typeArgs.ZipAsArray(
1259+
extension.TypeParameters,
1260+
(t, tp) => t.HasType ? t : TypeWithAnnotations.Create(tp));
1261+
}
1262+
1263+
return typeArgs;
1264+
}
12401265
}
12411266
}
12421267
}

0 commit comments

Comments
 (0)