Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collectors.toMap handling for streams #938

Merged
merged 23 commits into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not required, but should we update the docs in lines 74-100 and the list of examples in the docs for observableCallToInnerMethodOrLambda to include info on .collect() calls?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a few lines in 8b3f984

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* THE SOFTWARE.
*/

import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.LinkedHashMultimap;
Expand Down Expand Up @@ -97,6 +98,13 @@ class StreamNullabilityPropagator extends BaseNoOpHandler {
* 'Observable.filter'). In general, for observable.a().b().c(), c is the outer call of b and b the outer call
* of a in the chain.
*
* We also support collect-like methods, which take a collector factory method as an argument, e.g.:
*
* stream.filter(...).collect(Collectors.toMap(l1, l2)) (where l1 and l2 are lambdas)
*
* For such scenarios, the lambdas l1 and l2 (or the named method in the equivalent anonymous class) serve
* an equivalent role to the map methods discussed above.
*
* This class works by building the following maps which keep enough state outside of the standard dataflow
* analysis for us to figure out what's going on:
*
Expand All @@ -118,14 +126,30 @@ class StreamNullabilityPropagator extends BaseNoOpHandler {
private final Map<MethodInvocationTree, Tree> observableCallToInnerMethodOrLambda =
new LinkedHashMap<>();

// Maps collect calls in the observable call chain to the relevant inner methods or lambdas.
@AutoValue
abstract static class CollectRecordAndInnerMethod {

static CollectRecordAndInnerMethod create(
CollectLikeMethodRecord collectlikeMethodRecord, Tree innerMethodOrLambda) {
return new AutoValue_StreamNullabilityPropagator_CollectRecordAndInnerMethod(
collectlikeMethodRecord, innerMethodOrLambda);
}

abstract CollectLikeMethodRecord getCollectLikeMethodRecord();

abstract Tree getInnerMethodOrLambda();
}

// Maps collect calls in the observable call chain to the relevant (collect record, inner method
// or lambda) pairs.
// We need a Multimap here since there may be multiple relevant methods / lambdas.
// E.g.: stream.filter(...).collect(Collectors.toMap(l1, l2)) => {l1,l2}
private final Multimap<MethodInvocationTree, Tree> collectCallToInnerMethodsOrLambdas =
LinkedHashMultimap.create();
// E.g.: stream.filter(...).collect(Collectors.toMap(l1, l2)) => (record for toMap, {l1,l2})
private final Multimap<MethodInvocationTree, CollectRecordAndInnerMethod>
collectCallToRecordsAndInnerMethodsOrLambdas = LinkedHashMultimap.create();

// Map from map or collect method (or lambda) to corresponding previous filter method (e.g.
// B.apply => A.filter)
// B.apply => A.filter for the map example above, or l1 => A.filter and l2 => A.filter for the
// collect example above)
private final Map<Tree, MapOrCollectMethodToFilterInstanceRecord> mapOrCollectRecordToFilterMap =
new LinkedHashMap<>();

Expand Down Expand Up @@ -175,7 +199,7 @@ public void onMatchTopLevelClass(
this.filterMethodOrLambdaSet.clear();
this.observableOuterCallInChain.clear();
this.observableCallToInnerMethodOrLambda.clear();
this.collectCallToInnerMethodsOrLambdas.clear();
this.collectCallToRecordsAndInnerMethodsOrLambdas.clear();
this.mapOrCollectRecordToFilterMap.clear();
this.filterToNSMap.clear();
this.bodyToMethodOrLambda.clear();
Expand Down Expand Up @@ -229,10 +253,14 @@ public void onMatchMethodInvocation(
observableCallToInnerMethodOrLambda.put(tree, argTree);
}
} else {
CollectLikeMethodRecord collectlikeMethodRecord =
streamType.getCollectlikeMethodRecord(methodSymbol);
if (collectlikeMethodRecord != null && methodSymbol.getParameters().length() == 1) {
handleCollectCall(tree, collectlikeMethodRecord);
if (methodSymbol.getParameters().length() == 1) {
for (CollectLikeMethodRecord collectlikeMethodRecord :
streamType.getCollectlikeMethodRecords(methodSymbol)) {
boolean handled = handleCollectCall(tree, collectlikeMethodRecord);
if (handled) {
break;
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not for this PR, but wonder if as a follow up you want to extract the other two cases of this if into their own methods too, with onMatchMethodInvocation(...) as just a dispatcher.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #944 to keep this PR simpler

}
}
Expand All @@ -241,12 +269,20 @@ public void onMatchMethodInvocation(

/**
* Handles a call to a collect-like method. If the argument to the method is supported, updates
* the {@link #collectCallToInnerMethodsOrLambdas} map appropriately.
* the {@link #collectCallToRecordsAndInnerMethodsOrLambdas} map appropriately.
*
* @param collectInvocationTree The MethodInvocationTree representing the call to the collect-like
* method.
* @param collectlikeMethodRecord The record representing the collect-like method.
* @return true if the argument to the collect method was a call to the factory method in the
* record, false otherwise.
*/
private void handleCollectCall(
MethodInvocationTree tree, CollectLikeMethodRecord collectlikeMethodRecord) {
ExpressionTree argTree = tree.getArguments().get(0);
private boolean handleCollectCall(
MethodInvocationTree collectInvocationTree, CollectLikeMethodRecord collectlikeMethodRecord) {
ExpressionTree argTree = collectInvocationTree.getArguments().get(0);
if (argTree instanceof MethodInvocationTree) {
// the argument passed to the collect method. We check if this is a call to the collector
// factory method from the record
MethodInvocationTree collectInvokeArg = (MethodInvocationTree) argTree;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is the Collectors.toMap(...) invocation, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, 8b3f984

Symbol.MethodSymbol collectInvokeArgSymbol = ASTHelpers.getSymbol(collectInvokeArg);
if (collectInvokeArgSymbol
Expand All @@ -268,14 +304,21 @@ private void handleCollectCall(
handleMapOrCollectAnonClassBody(
collectlikeMethodRecord,
anonClassBody,
t -> collectCallToInnerMethodsOrLambdas.put(tree, t));
t ->
collectCallToRecordsAndInnerMethodsOrLambdas.put(
collectInvocationTree,
CollectRecordAndInnerMethod.create(collectlikeMethodRecord, t)));
}
} else if (factoryMethodArg instanceof LambdaExpressionTree) {
collectCallToInnerMethodsOrLambdas.put(tree, factoryMethodArg);
collectCallToRecordsAndInnerMethodsOrLambdas.put(
collectInvocationTree,
CollectRecordAndInnerMethod.create(collectlikeMethodRecord, factoryMethodArg));
}
}
return true;
}
}
return false;
}

private void buildObservableCallChain(MethodInvocationTree tree) {
Expand Down Expand Up @@ -312,19 +355,15 @@ private void handleChainFromFilter(
mapOrCollectRecordToFilterMap.put(
observableCallToInnerMethodOrLambda.get(outerCallInChain), record);
}
} else if (collectCallToInnerMethodsOrLambdas.containsKey(outerCallInChain)) {
Symbol.MethodSymbol collectMethod = ASTHelpers.getSymbol(outerCallInChain);
CollectLikeMethodRecord collectlikeMethodRecord =
streamType.getCollectlikeMethodRecord(collectMethod);
if (collectlikeMethodRecord != null) {
// Update mapOrCollectRecordToFilterMap for all relevant methods / lambdas
for (Tree innerMethodOrLambda :
collectCallToInnerMethodsOrLambdas.get(outerCallInChain)) {
MapOrCollectMethodToFilterInstanceRecord record =
new MapOrCollectMethodToFilterInstanceRecord(
collectlikeMethodRecord, filterMethodOrLambda);
mapOrCollectRecordToFilterMap.put(innerMethodOrLambda, record);
}
} else if (collectCallToRecordsAndInnerMethodsOrLambdas.containsKey(outerCallInChain)) {
// Update mapOrCollectRecordToFilterMap for all relevant methods / lambdas
for (CollectRecordAndInnerMethod collectRecordAndInnerMethod :
collectCallToRecordsAndInnerMethodsOrLambdas.get(outerCallInChain)) {
MapOrCollectMethodToFilterInstanceRecord record =
new MapOrCollectMethodToFilterInstanceRecord(
collectRecordAndInnerMethod.getCollectLikeMethodRecord(), filterMethodOrLambda);
mapOrCollectRecordToFilterMap.put(
collectRecordAndInnerMethod.getInnerMethodOrLambda(), record);
}
}
} while (outerCallInChain != null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ public static StreamNullabilityPropagator getJavaStreamNullabilityPropagator() {
ImmutableSet.of(0, 1),
"apply",
ImmutableSet.of(0))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Rx have an equivalent? See getRxStreamNullabilityPropagator()

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RxJava 3 does have an Observable.collect method that takes a Collector:

http://reactivex.io/RxJava/3.x/javadoc/io/reactivex/rxjava3/core/Observable.html#collect-java.util.stream.Collector-

And there was a previous collect API that acts like reduce or fold where this support may be relevant. But I'd rather address our immediate need in this PR, and maybe file a follow-up task on systematically going through the Stream and RxJava APIs to add support for further methods. Sound ok?

.withCollectMethodFromSignature(
"<R,A>collect(java.util.stream.Collector<? super T,A,R>)",
"java.util.stream.Collectors",
"<T,K>groupingBy(java.util.function.Function<? super T,? extends K>)",
ImmutableSet.of(0),
"apply",
ImmutableSet.of(0))
.withCollectMethodFromSignature(
"<R,A>collect(java.util.stream.Collector<? super T,A,R>)",
"com.google.common.collect.ImmutableMap",
"<T,K,V>toImmutableMap(java.util.function.Function<? super T,? extends K>,java.util.function.Function<? super T,? extends V>)",
ImmutableSet.of(0, 1),
"apply",
ImmutableSet.of(0))
// List of methods of java.util.stream.Stream through which we just propagate the
// nullability information of the last call, e.g. m() in
// Observable.filter(...).m().map(...) means the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ public static CollectLikeMethodRecord create(
argsFromStream);
}

/** The fully qualified name of the class that contains the collector factory method. */
/**
* The fully qualified name of the class that contains the collector factory method, e.g., {@code
* java.util.stream.Collectors}.
*/
public abstract String collectorFactoryMethodClass();

/**
Expand All @@ -60,15 +63,22 @@ public static CollectLikeMethodRecord create(
public abstract String collectorFactoryMethodSignature();

/**
* The indices of the arguments to the collector factory method that are passed the elements of
* the stream
* The indices of the arguments to the collector factory method that are lambdas (or anonymous
* classes) which get invoked with the elements of the stream
*/
public abstract ImmutableSet<Integer> argsToCollectorFactoryMethod();

/** Name of the method that gets passed the elements of the stream */
/**
* Name of the method that gets passed the elements of the stream, e.g., "apply" for an anonymous
* class implementing {@link Function}. We assume that all such methods have the same name.
*/
@Override
public abstract String innerMethodName();

/**
* Argument indices to which stream elements are directly passed. We assume the same indices are
* used for all methods getting passed elements from the stream.
*/
@Override
public abstract ImmutableSet<Integer> argsFromStream();
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.predicates.TypePredicate;
import com.google.errorprone.predicates.type.DescendantOf;
Expand All @@ -48,7 +49,7 @@ public class StreamModelBuilder {
private ImmutableSet.Builder<String> filterMethodSimpleNames;
private ImmutableMap.Builder<String, MapLikeMethodRecord> mapMethodSigToRecord;
private ImmutableMap.Builder<String, MapLikeMethodRecord> mapMethodSimpleNameToRecord;
private ImmutableMap.Builder<String, CollectLikeMethodRecord> collectMethodSigToRecord;
private ImmutableMultimap.Builder<String, CollectLikeMethodRecord> collectMethodSigToRecords;
private ImmutableSet.Builder<String> passthroughMethodSigs;
private ImmutableSet.Builder<String> passthroughMethodSimpleNames;

Expand All @@ -75,7 +76,7 @@ private void finalizeOpenStreamTypeRecord() {
filterMethodSimpleNames.build(),
mapMethodSigToRecord.build(),
mapMethodSimpleNameToRecord.build(),
collectMethodSigToRecord.build(),
collectMethodSigToRecords.build(),
passthroughMethodSigs.build(),
passthroughMethodSimpleNames.build()));
}
Expand Down Expand Up @@ -109,7 +110,7 @@ private void initializeBuilders() {
this.filterMethodSimpleNames = ImmutableSet.builder();
this.mapMethodSigToRecord = ImmutableMap.builder();
this.mapMethodSimpleNameToRecord = ImmutableMap.builder();
this.collectMethodSigToRecord = ImmutableMap.builder();
this.collectMethodSigToRecords = ImmutableMultimap.builder();
this.passthroughMethodSigs = ImmutableSet.builder();
this.passthroughMethodSimpleNames = ImmutableSet.builder();
}
Expand Down Expand Up @@ -172,14 +173,35 @@ public StreamModelBuilder withMapMethodAllFromName(
return this;
}

/**
* Add a model for a collect method that takes a collector factory method as its argument to the
* last stream type. See the methods of {@link CollectLikeMethodRecord} for further details.
*
* @param collectMethodSig The full sub-signature (everything except the receiver type) of the
* collect method, e.g. {@code "<R,A>collect(java.util.stream.Collector<? super T,A,R>)"}.
* @param collectorFactoryMethodClass The fully qualified name of the class that contains the
* collector factory method; see {@link
* CollectLikeMethodRecord#collectorFactoryMethodClass()}.
* @param collectorFactoryMethodSig The signature of the factory method that creates the collector
* instance passed to the collect method; see {@link
* CollectLikeMethodRecord#collectorFactoryMethodSignature()}.
* @param argsToCollectorFactoryMethod The indices of the arguments to the collector factory
* method; see {@link CollectLikeMethodRecord#argsToCollectorFactoryMethod()}.
* @param innerMethodName Name of the method that gets passed the elements of the stream; see
* {@link CollectLikeMethodRecord#innerMethodName()}.
* @param argsFromStream Argument indices to which stream elements are directly passed; see {@link
* CollectLikeMethodRecord#argsFromStream()}.
* @return This builder (for chaining).
* @see CollectLikeMethodRecord
*/
public StreamModelBuilder withCollectMethodFromSignature(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Javadoc, please, specially because this has a lot of arguments with complex to understand semantics (e.g. argsToCollectorFactoryMethod vs argsToCollectorFactoryMethod). An example in the docs might even be a good idea :)

Edit: After reading a bit more, an alternative is to link to the docs on CollectLikeMethodRecord, but I also wouldn't mind a bit of redundant documentation of this arguments here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String collectMethodSig,
String collectorFactoryMethodClass,
String collectorFactoryMethodSig,
ImmutableSet<Integer> argsToCollectorFactoryMethod,
String innerMethodName,
ImmutableSet<Integer> argsFromStream) {
this.collectMethodSigToRecord.put(
this.collectMethodSigToRecords.put(
collectMethodSig,
CollectLikeMethodRecord.create(
collectorFactoryMethodClass,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
*/
import static com.uber.nullaway.NullabilityUtil.castToNonNull;

import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.errorprone.VisitorState;
import com.google.errorprone.predicates.TypePredicate;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Type;
import javax.annotation.Nullable;

/** An immutable model describing a class from a stream-based API such as RxJava. */
public class StreamTypeRecord {
Expand All @@ -47,7 +48,7 @@ public class StreamTypeRecord {
private final ImmutableMap<String, MapLikeMethodRecord> mapMethodSigToRecord;
private final ImmutableMap<String, MapLikeMethodRecord> mapMethodSimpleNameToRecord;

private final ImmutableMap<String, CollectLikeMethodRecord> collectMethodSigToRecord;
private final ImmutableMultimap<String, CollectLikeMethodRecord> collectMethodSigToRecords;

// List of methods of java.util.stream.Stream through which we just propagate the nullability
// information of the last call, e.g. m() in Observable.filter(...).m().map(...) means the
Expand All @@ -64,15 +65,15 @@ public StreamTypeRecord(
ImmutableSet<String> filterMethodSimpleNames,
ImmutableMap<String, MapLikeMethodRecord> mapMethodSigToRecord,
ImmutableMap<String, MapLikeMethodRecord> mapMethodSimpleNameToRecord,
ImmutableMap<String, CollectLikeMethodRecord> collectMethodSigToRecord,
ImmutableMultimap<String, CollectLikeMethodRecord> collectMethodSigToRecords,
ImmutableSet<String> passthroughMethodSigs,
ImmutableSet<String> passthroughMethodSimpleNames) {
this.typePredicate = typePredicate;
this.filterMethodSigs = filterMethodSigs;
this.filterMethodSimpleNames = filterMethodSimpleNames;
this.mapMethodSigToRecord = mapMethodSigToRecord;
this.mapMethodSimpleNameToRecord = mapMethodSimpleNameToRecord;
this.collectMethodSigToRecord = collectMethodSigToRecord;
this.collectMethodSigToRecords = collectMethodSigToRecords;
this.passthroughMethodSigs = passthroughMethodSigs;
this.passthroughMethodSimpleNames = passthroughMethodSimpleNames;
}
Expand Down Expand Up @@ -101,9 +102,9 @@ record =
return record;
}

@Nullable
public CollectLikeMethodRecord getCollectlikeMethodRecord(Symbol.MethodSymbol methodSymbol) {
return collectMethodSigToRecord.get(methodSymbol.toString());
public ImmutableCollection<CollectLikeMethodRecord> getCollectlikeMethodRecords(
Symbol.MethodSymbol methodSymbol) {
return collectMethodSigToRecords.get(methodSymbol.toString());
}

public boolean isPassthroughMethod(Symbol.MethodSymbol methodSymbol) {
Expand Down
Loading