diff --git a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/LambdaMethodReference.java b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/LambdaMethodReference.java index 282b1e6dc..2b6809a25 100644 --- a/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/LambdaMethodReference.java +++ b/baseline-error-prone/src/main/java/com/palantir/baseline/errorprone/LambdaMethodReference.java @@ -17,6 +17,7 @@ package com.palantir.baseline.errorprone; import com.google.auto.service.AutoService; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import com.google.errorprone.BugPattern; import com.google.errorprone.VisitorState; @@ -25,16 +26,26 @@ import com.google.errorprone.fixes.SuggestedFixes; import com.google.errorprone.matchers.Description; import com.google.errorprone.util.ASTHelpers; +import com.google.errorprone.util.ErrorProneToken; import com.sun.source.tree.BlockTree; +import com.sun.source.tree.ClassTree; import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.IdentifierTree; import com.sun.source.tree.LambdaExpressionTree; import com.sun.source.tree.MethodInvocationTree; import com.sun.source.tree.ReturnTree; import com.sun.source.tree.StatementTree; import com.sun.source.tree.Tree; +import com.sun.source.tree.VariableTree; +import com.sun.tools.javac.code.Flags; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Type; +import com.sun.tools.javac.parser.Tokens; +import java.util.List; +import java.util.Objects; import java.util.Optional; +import java.util.Set; +import javax.annotation.Nullable; @AutoService(BugChecker.class) @BugPattern( @@ -44,18 +55,13 @@ providesFix = BugPattern.ProvidesFix.REQUIRES_HUMAN_ATTENTION, severity = BugPattern.SeverityLevel.SUGGESTION, summary = "Lambda should be a method reference") +@SuppressWarnings("checkstyle:CyclomaticComplexity") public final class LambdaMethodReference extends BugChecker implements BugChecker.LambdaExpressionTreeMatcher { private static final String MESSAGE = "Lambda should be a method reference"; @Override public Description matchLambdaExpression(LambdaExpressionTree tree, VisitorState state) { - // Only handle simple no-arg method references for the time being, don't worry about - // simplifying map.forEach((k, v) -> func(k, v)) to map.forEach(this::func) - if (tree.getParameters().size() > 1) { - return Description.NO_MATCH; - } - LambdaExpressionTree.BodyKind bodyKind = tree.getBodyKind(); Tree body = tree.getBody(); // n.b. These checks are meant to avoid any and all cleverness. The goal is to be confident @@ -91,56 +97,190 @@ public Description matchLambdaExpression(LambdaExpressionTree tree, VisitorState private Description checkMethodInvocation( MethodInvocationTree methodInvocation, LambdaExpressionTree root, VisitorState state) { - if (!methodInvocation.getArguments().isEmpty() - || !methodInvocation.getTypeArguments().isEmpty()) { + Symbol.MethodSymbol methodSymbol = ASTHelpers.getSymbol(methodInvocation); + if (methodSymbol == null + || !methodInvocation.getTypeArguments().isEmpty() + || hasExplicitParameterTypes(root, state)) { return Description.NO_MATCH; } - Symbol.MethodSymbol methodSymbol = ASTHelpers.getSymbol(methodInvocation); - if (methodSymbol == null || shouldIgnore(methodSymbol, root, methodInvocation)) { + + ExpressionTree receiver = ASTHelpers.getReceiver(methodInvocation); + boolean isLocal = isLocal(methodInvocation); + if (!isLocal && !(receiver instanceof IdentifierTree)) { return Description.NO_MATCH; } - return buildDescription(root) - .setMessage(MESSAGE) - .addFix(buildFix(methodSymbol, methodInvocation, root, state)) - .build(); + if (methodInvocation.getArguments().isEmpty() && root.getParameters().size() == 1) { + return convertVariableInstanceMethods(methodSymbol, methodInvocation, root, state); + } + + if (methodInvocation.getArguments().size() == root.getParameters().size()) { + return convertMethodInvocations(methodSymbol, methodInvocation, root, state); + } + + return Description.NO_MATCH; } - private static boolean shouldIgnore( - Symbol.MethodSymbol methodSymbol, LambdaExpressionTree root, MethodInvocationTree methodInvocation) { - if (!methodSymbol.isStatic()) { - if (root.getParameters().size() == 1) { - Symbol paramSymbol = ASTHelpers.getSymbol(Iterables.getOnlyElement(root.getParameters())); - Symbol receiverSymbol = ASTHelpers.getSymbol(ASTHelpers.getReceiver(methodInvocation)); - return !paramSymbol.equals(receiverSymbol); + private static boolean hasExplicitParameterTypes(LambdaExpressionTree lambda, VisitorState state) { + for (VariableTree varTree : lambda.getParameters()) { + boolean expectComma = false; + // Must avoid refactoring lambdas which declare explicit parameter types + for (ErrorProneToken token : state.getTokensForNode(varTree)) { + if (token.kind() == Tokens.TokenKind.EOF) { + return false; + } else if ((token.kind() == Tokens.TokenKind.IDENTIFIER && expectComma) + || (token.kind() == Tokens.TokenKind.COMMA && !expectComma)) { + return true; + } + expectComma = !expectComma; } - return true; } - return !root.getParameters().isEmpty(); + return false; + } + + private Description convertVariableInstanceMethods( + Symbol.MethodSymbol methodSymbol, + MethodInvocationTree methodInvocation, + LambdaExpressionTree root, + VisitorState state) { + Symbol paramSymbol = ASTHelpers.getSymbol(Iterables.getOnlyElement(root.getParameters())); + Symbol receiverSymbol = ASTHelpers.getSymbol(ASTHelpers.getReceiver(methodInvocation)); + if (!paramSymbol.equals(receiverSymbol)) { + return Description.NO_MATCH; + } + return buildFix(methodSymbol, methodInvocation, root, state, isLocal(methodInvocation)) + .map(fix -> + buildDescription(root).setMessage(MESSAGE).addFix(fix).build()) + .orElse(Description.NO_MATCH); + } + + private Description convertMethodInvocations( + Symbol.MethodSymbol methodSymbol, + MethodInvocationTree methodInvocation, + LambdaExpressionTree root, + VisitorState state) { + List methodParams = getSymbols(methodInvocation.getArguments()); + List lambdaParam = getSymbols(root.getParameters()); + + // We are guaranteed that all of root params are symbols so equality should handle cases where methodInvocation + // arguments are not symbols or are out of order + if (!methodParams.equals(lambdaParam)) { + return Description.NO_MATCH; + } + + return buildFix(methodSymbol, methodInvocation, root, state, isLocal(methodInvocation)) + .map(fix -> + buildDescription(root).setMessage(MESSAGE).addFix(fix).build()) + .orElse(Description.NO_MATCH); + } + + private static List getSymbols(List params) { + return params.stream() + .map(ASTHelpers::getSymbol) + .filter(Objects::nonNull) + .collect(ImmutableList.toImmutableList()); } private static Optional buildFix( Symbol.MethodSymbol symbol, MethodInvocationTree invocation, LambdaExpressionTree root, - VisitorState state) { + VisitorState state, + boolean isLocal) { + if (isAmbiguousMethod(symbol, ASTHelpers.getReceiver(invocation), state)) { + return Optional.empty(); + } + SuggestedFix.Builder builder = SuggestedFix.builder(); - return toMethodReference(qualifyTarget(symbol, invocation, builder, state)) + return qualifyTarget(symbol, invocation, root, builder, state, isLocal) + .flatMap(LambdaMethodReference::toMethodReference) .map(qualified -> builder.replace(root, qualified).build()); } - private static String qualifyTarget( + private static boolean isAmbiguousMethod( + Symbol.MethodSymbol symbol, @Nullable ExpressionTree receiver, VisitorState state) { + if (symbol.isStatic()) { + if (symbol.params().size() != 1) { + return false; + } + Symbol.ClassSymbol classSymbol = ASTHelpers.enclosingClass(symbol); + if (classSymbol == null) { + return false; + } + Set matching = ASTHelpers.findMatchingMethods( + symbol.name, + sym -> sym != null && !sym.isStatic() && sym.getParameters().isEmpty(), + classSymbol.type, + state.getTypes()); + return !matching.isEmpty(); + } else { + if (!symbol.params().isEmpty()) { + return false; + } + if (receiver == null) { + return false; + } + Type receiverType = ASTHelpers.getType(receiver); + if (receiverType == null) { + return false; + } + Set matching = ASTHelpers.findMatchingMethods( + symbol.name, + sym -> sym != null + && sym.isStatic() + && sym.getParameters().size() == 1 + && state.getTypes() + .isAssignable( + state.getTypes().erasure(receiverType), + state.getTypes() + .erasure(sym.params().get(0).type)), + receiverType, + state.getTypes()); + return !matching.isEmpty(); + } + } + + private static Optional qualifyTarget( Symbol.MethodSymbol symbol, MethodInvocationTree invocation, + LambdaExpressionTree root, SuggestedFix.Builder builder, - VisitorState state) { + VisitorState state, + boolean isLocal) { + if (!symbol.isStatic() && isLocal) { + // Validate teh local method is defined in this class + ClassTree enclosingClass = ASTHelpers.findEnclosingNode(state.getPath(), ClassTree.class); + if (enclosingClass == null) { + return Optional.empty(); + } + Type.ClassType enclosingType = ASTHelpers.getType(enclosingClass); + if (!ASTHelpers.findMatchingMethods(symbol.name, symbol::equals, enclosingType, state.getTypes()) + .isEmpty()) { + return Optional.of("this." + symbol.name.toString()); + } + return Optional.empty(); + } + + ExpressionTree receiver = ASTHelpers.getReceiver(invocation); Type receiverType = ASTHelpers.getReceiverType(invocation); if (receiverType == null || receiverType.getLowerBound() != null || receiverType.getUpperBound() != null) { - return SuggestedFixes.qualifyType(state, builder, symbol); + return Optional.of(SuggestedFixes.qualifyType(state, builder, symbol)); + } + Symbol receiverSymbol = ASTHelpers.getSymbol(receiver); + if (!symbol.isStatic() + && receiver instanceof IdentifierTree + && !Objects.equals(ImmutableList.of(receiverSymbol), getSymbols(root.getParameters()))) { + if (!isFinal(receiverSymbol)) { + // Not safe to replace lambdas which lazily reference values with an eager capture. + return Optional.empty(); + } + return Optional.of(state.getSourceForNode(receiver) + '.' + symbol.name.toString()); } - return SuggestedFixes.qualifyType(state, builder, state.getTypes().erasure(receiverType)) - + '.' - + symbol.name.toString(); + + return Optional.of( + SuggestedFixes.qualifyType(state, builder, state.getTypes().erasure(receiverType)) + + '.' + + symbol.name.toString()); } private static Optional toMethodReference(String qualifiedMethodName) { @@ -151,4 +291,15 @@ private static Optional toMethodReference(String qualifiedMethodName) { } return Optional.empty(); } + + private static boolean isLocal(MethodInvocationTree methodInvocationTree) { + ExpressionTree receiver = ASTHelpers.getReceiver(methodInvocationTree); + return receiver == null + || (receiver instanceof IdentifierTree + && "this".equals(((IdentifierTree) receiver).getName().toString())); + } + + private static boolean isFinal(Symbol symbol) { + return (symbol.flags() & (Flags.FINAL | Flags.EFFECTIVELY_FINAL)) != 0; + } } diff --git a/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/LambdaMethodReferenceTest.java b/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/LambdaMethodReferenceTest.java index 84d8eb501..b2bd0c435 100644 --- a/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/LambdaMethodReferenceTest.java +++ b/baseline-error-prone/src/test/java/com/palantir/baseline/errorprone/LambdaMethodReferenceTest.java @@ -20,24 +20,24 @@ import com.google.errorprone.BugCheckerRefactoringTestHelper; import com.google.errorprone.CompilationTestHelper; import java.util.List; +import java.util.Map; import java.util.Optional; -import org.junit.jupiter.api.BeforeEach; +import java.util.function.Supplier; import org.junit.jupiter.api.Test; public class LambdaMethodReferenceTest { - private CompilationTestHelper compilationHelper; - private RefactoringValidator refactoringValidator; + private CompilationTestHelper compile() { + return CompilationTestHelper.newInstance(LambdaMethodReference.class, getClass()); + } - @BeforeEach - public void before() { - compilationHelper = CompilationTestHelper.newInstance(LambdaMethodReference.class, getClass()); - refactoringValidator = RefactoringValidator.of(new LambdaMethodReference(), getClass()); + private RefactoringValidator refactor() { + return RefactoringValidator.of(new LambdaMethodReference(), getClass()); } @Test public void testMethodReference() { - compilationHelper + compile() .addSourceLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -52,8 +52,8 @@ public void testMethodReference() { } @Test - void testFunction() { - compilationHelper + void testInstanceMethod() { + compile() .addSourceLines( "Test.java", "import " + Optional.class.getName() + ';', @@ -67,8 +67,84 @@ void testFunction() { } @Test - public void testPositive_block() { - compilationHelper + void testLocalInstanceMethod() { + compile() + .addSourceLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " // BUG: Diagnostic contains: Lambda should be a method reference", + " return optional.map(v -> bar(v));", + " }", + " private Integer bar(String value) {", + " return value.length();", + " }", + "}") + .doTest(); + } + + @Test + public void testLocalInstanceMethodSupplier() { + compile() + .addSourceLines( + "Test.java", + "import " + ImmutableList.class.getName() + ';', + "import " + List.class.getName() + ';', + "import " + Optional.class.getName() + ';', + "class Test {", + " public List foo(Optional> optional) {", + " // BUG: Diagnostic contains: Lambda should be a method reference", + " return optional.orElseGet(() -> bar());", + " }", + " private List bar() {", + " return ImmutableList.of();", + " }", + "}") + .doTest(); + } + + @Test + void testLocalStaticMethod_multiParam() { + compile() + .addSourceLines( + "Test.java", + "import " + Map.class.getName() + ';', + "class Test {", + " public void foo(Map map) {", + " // BUG: Diagnostic contains: Lambda should be a method reference", + " map.forEach((k, v) -> bar(k, v));", + " }", + " private static void bar(String key, String value) {", + " System.out.println(key + value);", + " }", + "}") + .doTest(); + } + + @Test + public void testLocalMethodSupplier_block() { + compile() + .addSourceLines( + "Test.java", + "import " + ImmutableList.class.getName() + ';', + "import " + List.class.getName() + ';', + "import " + Optional.class.getName() + ';', + "class Test {", + " public List foo(Optional> optional) {", + " // BUG: Diagnostic contains: Lambda should be a method reference", + " return optional.orElseGet(() -> { return bar(); });", + " }", + " private List bar() {", + " return ImmutableList.of();", + " }", + "}") + .doTest(); + } + + @Test + public void testStaticMethod_block() { + compile() .addSourceLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -84,8 +160,8 @@ public void testPositive_block() { } @Test - public void testAutoFix_block() { - refactoringValidator + public void testAutoFix_staticMethod_block() { + refactor() .addInputLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -110,8 +186,34 @@ public void testAutoFix_block() { } @Test - void testAutoFix_instanceMethod() { - refactoringValidator + public void testAutoFix_staticMethodWithParam() { + refactor() + .addInputLines( + "Test.java", + "import " + ImmutableList.class.getName() + ';', + "import " + List.class.getName() + ';', + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional> foo(Optional optional) {", + " return optional.map(v -> ImmutableList.of(v));", + " }", + "}") + .addOutputLines( + "Test.java", + "import " + ImmutableList.class.getName() + ';', + "import " + List.class.getName() + ';', + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional> foo(Optional optional) {", + " return optional.map(ImmutableList::of);", + " }", + "}") + .doTest(BugCheckerRefactoringTestHelper.TestMode.TEXT_MATCH); + } + + @Test + void testAutoFix_InstanceMethod() { + refactor() .addInputLines( "Test.java", "import " + Optional.class.getName() + ';', @@ -132,8 +234,74 @@ void testAutoFix_instanceMethod() { } @Test - void testAutoFix_specificInstanceMethod() { - refactoringValidator + void testNegative_InstanceMethodWithType() { + refactor() + .addInputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map((String v) -> v.length());", + " }", + "}") + .expectUnchanged() + .doTest(); + } + + @Test + void testNegative_ambiguousStaticReference() { + refactor() + .addInputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map(value -> Long.toString(value));", + " }", + "}") + .expectUnchanged() + .doTest(); + } + + @Test + void testNegative_ambiguousInstanceReference() { + refactor() + .addInputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map(value -> value.toString());", + " }", + "}") + .expectUnchanged() + .doTest(); + } + + @Test + void testNegative_ambiguousThis() { + refactor() + .addInputLines( + "Test.java", + "import " + Supplier.class.getName() + ';', + "class Test {", + " class Inner {", + " public Supplier foo() {", + // this::bar is incorrect because 'this' is Inner and 'bar' is defined on 'Test'. + " return () -> bar();", + " }", + " }", + " private String bar() {", + " return \"\";", + " }", + "}") + .expectUnchanged() + .doTest(); + } + + @Test + void testAutoFix_SpecificInstanceMethod() { + refactor() .addInputLines( "Test.java", "import " + Optional.class.getName() + ';', @@ -154,8 +322,8 @@ void testAutoFix_specificInstanceMethod() { } @Test - void testAutoFix_specificInstanceMethod_withTypeParameters() { - refactoringValidator + void testAutoFix_SpecificInstanceMethod_withTypeParameters() { + refactor() .addInputLines( "Test.java", "import " + Optional.class.getName() + ';', @@ -176,20 +344,136 @@ void testAutoFix_specificInstanceMethod_withTypeParameters() { } @Test - public void testNegative_block_localMethod() { - compilationHelper - .addSourceLines( + void testAutoFix_localInstanceMethod() { + refactor() + .addInputLines( "Test.java", - "import " + ImmutableList.class.getName() + ';', - "import " + List.class.getName() + ';', "import " + Optional.class.getName() + ';', "class Test {", - " public List foo(Optional> optional) {", - // A future improvement may rewrite the following to 'orElseGet(this::bar)' - " return optional.orElseGet(() -> { return bar(); });", + " public Optional foo(Optional optional) {", + " return optional.map(v -> bar(v));", " }", - " private List bar() {", - " return ImmutableList.of();", + " private String bar(String v) {", + " return v;", + " }", + "}") + .addOutputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map(this::bar);", + " }", + " private String bar(String v) {", + " return v;", + " }", + "}") + .doTest(); + } + + @Test + void testAutoFix_localInstanceMethod_explicitThis() { + refactor() + .addInputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map(v -> this.bar(v));", + " }", + " private String bar(String v) {", + " return v;", + " }", + "}") + .addOutputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map(this::bar);", + " }", + " private String bar(String v) {", + " return v;", + " }", + "}") + .doTest(); + } + + @Test + void testAutoFix_localStaticMethod() { + refactor() + .addInputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map(v -> bar(v));", + " }", + " private static String bar(String v) {", + " return v;", + " }", + "}") + .addOutputLines( + "Test.java", + "import " + Optional.class.getName() + ';', + "class Test {", + " public Optional foo(Optional optional) {", + " return optional.map(Test::bar);", + " }", + " private static String bar(String v) {", + " return v;", + " }", + "}") + .doTest(); + } + + @Test + void testAutoFix_localStaticMethod_multiParam() { + refactor() + .addInputLines( + "Test.java", + "import " + Map.class.getName() + ';', + "class Test {", + " public void foo(Map map) {", + " map.forEach((k, v) -> bar(k, v));", + " }", + " private static void bar(String key, String value) {", + " System.out.println(key + value);", + " }", + "}") + .addOutputLines( + "Test.java", + "import " + Map.class.getName() + ';', + "class Test {", + " public void foo(Map map) {", + " map.forEach(Test::bar);", + " }", + " private static void bar(String key, String value) {", + " System.out.println(key + value);", + " }", + "}") + .doTest(); + } + + @Test + void testAutoFix_StaticMethod_multiParam() { + refactor() + .addInputLines( + "Test.java", + "import " + Map.class.getName() + ';', + "import " + ImmutableList.class.getName() + ';', + "class Test {", + " public void foo(Map map) {", + " map.forEach((k, v) -> ImmutableList.of(k, v));", + " }", + "}") + .addOutputLines( + "Test.java", + "import " + Map.class.getName() + ';', + "import " + ImmutableList.class.getName() + ';', + "class Test {", + " public void foo(Map map) {", + " map.forEach(ImmutableList::of);", " }", "}") .doTest(); @@ -197,7 +481,7 @@ public void testNegative_block_localMethod() { @Test public void testAutoFix_block_localMethod() { - refactoringValidator + refactor() .addInputLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -218,8 +502,7 @@ public void testAutoFix_block_localMethod() { "import " + Optional.class.getName() + ';', "class Test {", " public List foo(Optional> optional) {", - // This is not modified, may be improved later - " return optional.orElseGet(() -> { return bar(); });", + " return optional.orElseGet(this::bar);", " }", " private List bar() {", " return ImmutableList.of();", @@ -230,7 +513,7 @@ public void testAutoFix_block_localMethod() { @Test public void testNegative_block() { - compilationHelper + compile() .addSourceLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -246,7 +529,7 @@ public void testNegative_block() { @Test public void testPositive_expression() { - compilationHelper + compile() .addSourceLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -263,7 +546,7 @@ public void testPositive_expression() { @Test public void testAutoFix_expression() { - refactoringValidator + refactor() .addInputLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -287,29 +570,9 @@ public void testAutoFix_expression() { .doTest(BugCheckerRefactoringTestHelper.TestMode.TEXT_MATCH); } - @Test - public void testNegative_expression_localMethod() { - compilationHelper - .addSourceLines( - "Test.java", - "import " + ImmutableList.class.getName() + ';', - "import " + List.class.getName() + ';', - "import " + Optional.class.getName() + ';', - "class Test {", - " public List foo(Optional> optional) {", - // A future improvement may rewrite the following to 'orElseGet(this::bar)' - " return optional.orElseGet(() -> bar());", - " }", - " private List bar() {", - " return ImmutableList.of();", - " }", - "}") - .doTest(); - } - @Test public void testNegative_expression_staticMethod() { - compilationHelper + compile() .addSourceLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -324,7 +587,7 @@ public void testNegative_expression_staticMethod() { @Test public void testAutoFix_expression_localMethod() { - refactoringValidator + refactor() .addInputLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -345,8 +608,7 @@ public void testAutoFix_expression_localMethod() { "import " + Optional.class.getName() + ';', "class Test {", " public List foo(Optional> optional) {", - // This is not modified, may be improved later - " return optional.orElseGet(() -> bar());", + " return optional.orElseGet(this::bar);", " }", " private List bar() {", " return ImmutableList.of();", @@ -355,9 +617,54 @@ public void testAutoFix_expression_localMethod() { .doTest(BugCheckerRefactoringTestHelper.TestMode.TEXT_MATCH); } + @Test + public void testAutoFix_expression_referenceMethod() { + refactor() + .addInputLines( + "Test.java", + "import " + ImmutableList.class.getName() + ';', + "import " + List.class.getName() + ';', + "import " + Optional.class.getName() + ';', + "import " + Supplier.class.getName() + ';', + "class Test {", + " public List foo(Optional> a, Supplier> b) {", + " return a.orElseGet(() -> b.get());", + " }", + "}") + .addOutputLines( + "Test.java", + "import " + ImmutableList.class.getName() + ';', + "import " + List.class.getName() + ';', + "import " + Optional.class.getName() + ';', + "import " + Supplier.class.getName() + ';', + "class Test {", + " public List foo(Optional> a, Supplier> b) {", + " return a.orElseGet(b::get);", + " }", + "}") + .doTest(BugCheckerRefactoringTestHelper.TestMode.TEXT_MATCH); + } + + @Test + void testNegative_LocalStaticMethod_multiParam() { + compile() + .addSourceLines( + "Test.java", + "import " + Map.class.getName() + ';', + "class Test {", + " public void foo(Map map) {", + " map.forEach((k, v) -> bar(v, k));", + " }", + " private static void bar(Integer value, String key) {", + " System.out.println(key + value);", + " }", + "}") + .doTest(); + } + @Test public void testNegative_expression() { - compilationHelper + compile() .addSourceLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -373,7 +680,7 @@ public void testNegative_expression() { @Test public void testNegative_expression_chain() { - compilationHelper + compile() .addSourceLines( "Test.java", "import " + ImmutableList.class.getName() + ';', @@ -388,4 +695,36 @@ public void testNegative_expression_chain() { "}") .doTest(); } + + @Test + public void testNegative_dont_eagerly_capture_reference() { + compile() + .addSourceLines( + "Test.java", + "import " + Supplier.class.getName() + ';', + "class Test {", + " private Object mutable = null;", + " public Supplier foo() {", + " mutable = Long.toString(System.nanoTime());", + // mutable::toString would not take later modifications into account + " return () -> mutable.toString();", + " }", + "}") + .doTest(); + } + + @Test + public void testGuavaToJavaUtilOptional() { + refactor() + .addInputLines( + "Test.java", + "import java.util.stream.Stream;", + "class Test {", + " Stream> f(Stream> in) {", + " return in.map(value -> value.toJavaUtil());", + " }", + "}") + .expectUnchanged() + .doTest(); + } } diff --git a/changelog/@unreleased/pr-1365.v2.yml b/changelog/@unreleased/pr-1365.v2.yml new file mode 100644 index 000000000..19b1f566d --- /dev/null +++ b/changelog/@unreleased/pr-1365.v2.yml @@ -0,0 +1,6 @@ +type: improvement +improvement: + description: Convert multi param lambdas and local method invokations to method + references + links: + - https://github.com/palantir/gradle-baseline/pull/1365