Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check for particular collection expr attribute by name, as it can be provided internally in packages #70179

Merged
merged 2 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ public CSharpUseCollectionExpressionForCreateDiagnosticAnalyzer()
{
}

protected override bool IsSupported(Compilation compilation)
=> compilation.CollectionBuilderAttribute() is not null;

protected override void InitializeWorker(CodeBlockStartAnalysisContext<SyntaxKind> context)
=> context.RegisterSyntaxNodeAction(AnalyzeInvocationExpression, SyntaxKind.InvocationExpression);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,7 @@ bool IsConstructibleCollectionType(ITypeSymbol? type)
return true;

// If it has a [CollectionBuilder] attribute on it, it is a valid collection expression type.
var collectionBuilderType = compilation.CollectionBuilderAttribute();
if (namedType.GetAttributes().Any(a => a.AttributeClass?.Equals(collectionBuilderType) is true))
if (namedType.GetAttributes().Any(a => a.AttributeClass.IsCollectionBuilderAttribute()))
return true;

// At this point, all that is left are collection-initializer types. These need to derive from
Expand Down Expand Up @@ -513,23 +512,20 @@ public static bool IsCollectionFactoryCreate(
}

memberAccess = memberAccessExpression;
var createMethod = semanticModel.GetSymbolInfo(memberAccessExpression, cancellationToken).Symbol as IMethodSymbol;
if (createMethod is not { IsStatic: true })
if (semanticModel.GetSymbolInfo(memberAccessExpression, cancellationToken).Symbol is not IMethodSymbol { IsStatic: true } createMethod)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (semanticModel.GetSymbolInfo(memberAccessExpression, cancellationToken).Symbol is not IMethodSymbol { IsStatic: true } createMethod)
if (semanticModel.GetSymbolInfo(memberAccessExpression, cancellationToken).Symbol is not IMethodSymbol { IsStatic: true } createMethod)

return false;

var factoryType = semanticModel.GetSymbolInfo(memberAccessExpression.Expression, cancellationToken).Symbol as INamedTypeSymbol;
if (factoryType is null)
if (semanticModel.GetSymbolInfo(memberAccessExpression.Expression, cancellationToken).Symbol is not INamedTypeSymbol factoryType)
return false;

var compilation = semanticModel.Compilation;

// The pattern is a type like `ImmutableArray` (non-generic), returning an instance of `ImmutableArray<T>`. The
// actual collection type (`ImmutableArray<T>`) has to have a `[CollectionBuilder(...)]` attribute on it that
// then points at the factory type.
var collectionBuilderAttribute = compilation.CollectionBuilderAttribute()!;
var collectionBuilderAttributeData = createMethod.ReturnType.OriginalDefinition
.GetAttributes()
.FirstOrDefault(a => collectionBuilderAttribute.Equals(a.AttributeClass));
.FirstOrDefault(a => a.AttributeClass.IsCollectionBuilderAttribute());
if (collectionBuilderAttributeData?.ConstructorArguments is not [{ Value: ITypeSymbol collectionBuilderType }, { Value: CreateName }])
return false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,6 @@ public static ImmutableArray<IAssemblySymbol> GetReferencedAssemblySymbols(this
public static INamedTypeSymbol? OnSerializedAttribute(this Compilation compilation)
=> compilation.GetTypeByMetadataName(typeof(OnSerializedAttribute).FullName!);

public static INamedTypeSymbol? CollectionBuilderAttribute(this Compilation compilation)
=> compilation.GetTypeByMetadataName("System.Runtime.CompilerServices.CollectionBuilderAttribute");

public static INamedTypeSymbol? ComRegisterFunctionAttribute(this Compilation compilation)
=> compilation.GetTypeByMetadataName(typeof(ComRegisterFunctionAttribute).FullName!);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -619,5 +619,24 @@ private static bool IsEqualsObject(ISymbol member)

public static INamedTypeSymbol TryConstruct(this INamedTypeSymbol type, ITypeSymbol[] typeArguments)
=> typeArguments.Length > 0 ? type.Construct(typeArguments) : type;

public static bool IsCollectionBuilderAttribute([NotNullWhen(true)] this INamedTypeSymbol? type)
=> type is
{
Name: "CollectionBuilderAttribute",
ContainingNamespace:
{
Name: nameof(System.Runtime.CompilerServices),
ContainingNamespace:
{
Name: nameof(System.Runtime),
ContainingNamespace:
{
Name: nameof(System),
ContainingNamespace.IsGlobalNamespace: true,
}
}
}
};
}
}