Skip to content

Commit

Permalink
Merge pull request #70931 from CyrusNajmabadi/extractMethodRef
Browse files Browse the repository at this point in the history
Update extract-method to respect returning by-ref
  • Loading branch information
CyrusNajmabadi authored Nov 27, 2023
2 parents 153899b + 4cbb9da commit 132e539
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

namespace Microsoft.CodeAnalysis.CSharp.ExtractMethod
{
using static SyntaxFactory;

internal partial class CSharpMethodExtractor
{
private abstract partial class CSharpCodeGenerator : CodeGenerator<StatementSyntax, SyntaxNode, CSharpCodeGenerationOptions>
Expand Down Expand Up @@ -99,7 +101,7 @@ protected override IMethodSymbol GenerateMethodDefinition(
accessibility: Accessibility.Private,
modifiers: CreateMethodModifiers(),
returnType: AnalyzerResult.ReturnType,
refKind: RefKind.None,
refKind: AnalyzerResult.ReturnsByRef ? RefKind.Ref : RefKind.None,
explicitInterfaceImplementations: default,
name: _methodName.ToString(),
typeParameters: CreateMethodTypeParameters(),
Expand Down Expand Up @@ -571,47 +573,48 @@ protected override StatementSyntax CreateReturnStatement(string identifierName =
protected override ExpressionSyntax CreateCallSignature()
{
var methodName = CreateMethodNameForInvocation().WithAdditionalAnnotations(Simplifier.Annotation);
var arguments = new List<ArgumentSyntax>();
var isLocalFunction = LocalFunction && ShouldLocalFunctionCaptureParameter(SemanticDocument.Root);

using var _ = ArrayBuilder<ArgumentSyntax>.GetInstance(out var arguments);

foreach (var argument in AnalyzerResult.MethodParameters)
{
if (!isLocalFunction || !argument.CanBeCapturedByLocalFunction)
{
var modifier = GetParameterRefSyntaxKind(argument.ParameterModifier);
var refOrOut = modifier == SyntaxKind.None ? default : SyntaxFactory.Token(modifier);
arguments.Add(SyntaxFactory.Argument(SyntaxFactory.IdentifierName(argument.Name)).WithRefOrOutKeyword(refOrOut));
var refOrOut = modifier == SyntaxKind.None ? default : Token(modifier);
arguments.Add(Argument(IdentifierName(argument.Name)).WithRefOrOutKeyword(refOrOut));
}
}

var invocation = SyntaxFactory.InvocationExpression(methodName,
SyntaxFactory.ArgumentList(SyntaxFactory.SeparatedList(arguments)));

var shouldPutAsyncModifier = this.SelectionResult.ShouldPutAsyncModifier();
if (!shouldPutAsyncModifier)
{
return invocation;
}

if (this.SelectionResult.ShouldCallConfigureAwaitFalse())
var invocation = (ExpressionSyntax)InvocationExpression(methodName, ArgumentList(SeparatedList(arguments)));
if (this.SelectionResult.ShouldPutAsyncModifier())
{
if (AnalyzerResult.ReturnType.GetMembers().Any(static x => x is IMethodSymbol
{
Name: nameof(Task.ConfigureAwait),
Parameters: { Length: 1 } parameters
} && parameters[0].Type.SpecialType == SpecialType.System_Boolean))
if (this.SelectionResult.ShouldCallConfigureAwaitFalse())
{
invocation = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
invocation,
SyntaxFactory.IdentifierName(nameof(Task.ConfigureAwait))),
SyntaxFactory.ArgumentList(SyntaxFactory.SingletonSeparatedList(
SyntaxFactory.Argument(SyntaxFactory.LiteralExpression(SyntaxKind.FalseLiteralExpression)))));
if (AnalyzerResult.ReturnType.GetMembers().Any(static x => x is IMethodSymbol
{
Name: nameof(Task.ConfigureAwait),
Parameters: [{ Type.SpecialType: SpecialType.System_Boolean }],
}))
{
invocation = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
invocation,
IdentifierName(nameof(Task.ConfigureAwait))),
ArgumentList(SingletonSeparatedList(
Argument(LiteralExpression(SyntaxKind.FalseLiteralExpression)))));
}
}

invocation = AwaitExpression(invocation);
}

return SyntaxFactory.AwaitExpression(invocation);
if (AnalyzerResult.ReturnsByRef)
invocation = RefExpression(invocation);

return invocation;
}

protected override StatementSyntax CreateAssignmentExpressionStatement(SyntaxToken identifier, ExpressionSyntax rvalue)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public override bool ContainingScopeHasAsyncKeyword()
return CSharpSyntaxFacts.Instance.GetRootStandaloneExpression(scope);
}

public override ITypeSymbol? GetContainingScopeType()
public override (ITypeSymbol? returnType, bool returnsByRef) GetReturnType()
{
if (GetContainingScope() is not ExpressionSyntax node)
{
Expand All @@ -60,9 +60,7 @@ public override bool ContainingScopeHasAsyncKeyword()
{
var variableDeclExpression = node.GetAncestorOrThis<VariableDeclarationSyntax>();
if (variableDeclExpression != null)
{
return model.GetTypeInfo(variableDeclExpression.Type).Type;
}
return (model.GetTypeInfo(variableDeclExpression.Type).Type, returnsByRef: false);
}

if (node.IsExpressionInCast())
Expand All @@ -72,57 +70,65 @@ public override bool ContainingScopeHasAsyncKeyword()
// 1. if regular binding returns a meaningful type, we use it as it is
// 2. if it doesn't, even if the cast itself wasn't included in the selection, we will treat it
// as it was in the selection
var regularType = GetRegularExpressionType(model, node);
var (regularType, returnsByRef) = GetRegularExpressionType(model, node);
if (regularType != null)
{
return regularType;
}
return (regularType, returnsByRef);

if (node.Parent is CastExpressionSyntax castExpression)
{
return model.GetTypeInfo(castExpression).Type;
}
return (model.GetTypeInfo(castExpression).Type, returnsByRef: false);
}

return GetRegularExpressionType(model, node);
}

private static ITypeSymbol? GetRegularExpressionType(SemanticModel semanticModel, ExpressionSyntax node)
private static (ITypeSymbol? typeSymbol, bool returnsByRef) GetRegularExpressionType(SemanticModel semanticModel, ExpressionSyntax node)
{
// regular case. always use ConvertedType to get implicit conversion right.
var expression = node.GetUnparenthesizedExpression();

var info = semanticModel.GetTypeInfo(expression);
var conv = semanticModel.GetConversion(expression);

if (info.ConvertedType == null || info.ConvertedType.IsErrorType())
var returnsByRef = false;
if (expression is RefExpressionSyntax refExpression)
{
// there is no implicit conversion involved. no need to go further
return info.GetTypeWithAnnotatedNullability();
expression = refExpression.Expression;
returnsByRef = true;
}

// always use converted type if method group
if ((!node.IsKind(SyntaxKind.ObjectCreationExpression) && semanticModel.GetMemberGroup(expression).Length > 0) ||
IsCoClassImplicitConversion(info, conv, semanticModel.Compilation.CoClassType()))
{
return info.GetConvertedTypeWithAnnotatedNullability();
}
var typeSymbol = GetRegularExpressionTypeWorker();
return (typeSymbol, returnsByRef);

// check implicit conversion
if (conv.IsImplicit && (conv.IsConstantExpression || conv.IsEnumeration))
ITypeSymbol? GetRegularExpressionTypeWorker()
{
return info.GetConvertedTypeWithAnnotatedNullability();
}
var info = semanticModel.GetTypeInfo(expression);
var conv = semanticModel.GetConversion(expression);

// use FormattableString if conversion between String and FormattableString
if (info.Type?.SpecialType == SpecialType.System_String &&
info.ConvertedType?.IsFormattableStringOrIFormattable() == true)
{
return info.GetConvertedTypeWithAnnotatedNullability();
}
if (info.ConvertedType == null || info.ConvertedType.IsErrorType())
{
// there is no implicit conversion involved. no need to go further
return info.GetTypeWithAnnotatedNullability();
}

// always use converted type if method group
if ((!node.IsKind(SyntaxKind.ObjectCreationExpression) && semanticModel.GetMemberGroup(expression).Length > 0) ||
IsCoClassImplicitConversion(info, conv, semanticModel.Compilation.CoClassType()))
{
return info.GetConvertedTypeWithAnnotatedNullability();
}

// check implicit conversion
if (conv.IsImplicit && (conv.IsConstantExpression || conv.IsEnumeration))
{
return info.GetConvertedTypeWithAnnotatedNullability();
}

// use FormattableString if conversion between String and FormattableString
if (info.Type?.SpecialType == SpecialType.System_String &&
info.ConvertedType?.IsFormattableStringOrIFormattable() == true)
{
return info.GetConvertedTypeWithAnnotatedNullability();
}

// always try to use type that is more specific than object type if possible.
return !info.Type.IsObjectType() ? info.GetTypeWithAnnotatedNullability() : info.GetConvertedTypeWithAnnotatedNullability();
// always try to use type that is more specific than object type if possible.
return !info.Type.IsObjectType() ? info.GetTypeWithAnnotatedNullability() : info.GetConvertedTypeWithAnnotatedNullability();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ AnonymousMethodExpressionSyntax or
});
}

public override ITypeSymbol GetContainingScopeType()
public override (ITypeSymbol returnType, bool returnsByRef) GetReturnType()
{
Contract.ThrowIfTrue(SelectionInExpression);

Expand All @@ -73,31 +73,31 @@ public override ITypeSymbol GetContainingScopeType()
case AccessorDeclarationSyntax access:
// property or event case
if (access.Parent == null || access.Parent.Parent == null)
{
return null;
}
return default;

return semanticModel.GetDeclaredSymbol(access.Parent.Parent) switch
{
IPropertySymbol propertySymbol => propertySymbol.Type,
IEventSymbol eventSymbol => eventSymbol.Type,
_ => null,
IPropertySymbol propertySymbol => (propertySymbol.Type, propertySymbol.ReturnsByRef),
IEventSymbol eventSymbol => (eventSymbol.Type, false),
_ => default,
};

case MethodDeclarationSyntax method:
return semanticModel.GetDeclaredSymbol(method).ReturnType;

case ParenthesizedLambdaExpressionSyntax lambda:
return semanticModel.GetLambdaOrAnonymousMethodReturnType(lambda);

case SimpleLambdaExpressionSyntax lambda:
return semanticModel.GetLambdaOrAnonymousMethodReturnType(lambda);
case MethodDeclarationSyntax methodDeclaration:
{
return semanticModel.GetDeclaredSymbol(methodDeclaration) is not IMethodSymbol method
? default
: (method.ReturnType, method.ReturnsByRef);
}

case AnonymousMethodExpressionSyntax anonymous:
return semanticModel.GetLambdaOrAnonymousMethodReturnType(anonymous);
case AnonymousFunctionExpressionSyntax function:
{
return semanticModel.GetSymbolInfo(function).Symbol is not IMethodSymbol method
? default
: (method.ReturnType, method.ReturnsByRef);
}

default:
return null;
return default;
}
}
}
Expand Down
42 changes: 25 additions & 17 deletions src/Features/CSharp/Portable/ExtractMethod/CSharpSelectionResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,31 +75,39 @@ protected CSharpSelectionResult(
{
}

protected override ISyntaxFacts SyntaxFacts => CSharpSyntaxFacts.Instance;
protected override ISyntaxFacts SyntaxFacts
=> CSharpSyntaxFacts.Instance;

public override SyntaxNode GetNodeForDataFlowAnalysis()
{
var node = base.GetNodeForDataFlowAnalysis();

// If we're returning a value by ref we actually want to do the analysis on the underlying expression.
return node is RefExpressionSyntax refExpression
? refExpression.Expression
: node;
}

protected override bool UnderAnonymousOrLocalMethod(SyntaxToken token, SyntaxToken firstToken, SyntaxToken lastToken)
{
var current = token.Parent;
for (; current != null; current = current.Parent)
for (var current = token.Parent; current != null; current = current.Parent)
{
if (current is MemberDeclarationSyntax or
SimpleLambdaExpressionSyntax or
ParenthesizedLambdaExpressionSyntax or
AnonymousMethodExpressionSyntax or
LocalFunctionStatementSyntax)
if (current is MemberDeclarationSyntax)
return false;

if (current is
SimpleLambdaExpressionSyntax or
ParenthesizedLambdaExpressionSyntax or
AnonymousMethodExpressionSyntax or
LocalFunctionStatementSyntax)
{
break;
// make sure the selection contains the lambda
return firstToken.SpanStart <= current.GetFirstToken().SpanStart &&
current.GetLastToken().Span.End <= lastToken.Span.End;
}
}

if (current is null or MemberDeclarationSyntax)
{
return false;
}

// make sure the selection contains the lambda
return firstToken.SpanStart <= current.GetFirstToken().SpanStart &&
current.GetLastToken().Span.End <= lastToken.Span.End;
return false;
}

public override SyntaxNode GetOutermostCallSiteContainerToProcess(CancellationToken cancellationToken)
Expand Down
Loading

0 comments on commit 132e539

Please sign in to comment.