Skip to content

Commit

Permalink
Support safety annotations on type-use e.g. Collection<@safe String> (
Browse files Browse the repository at this point in the history
#2187)

Support safety annotations on type-use e.g. `Collection<@safe String>`
  • Loading branch information
carterkozak authored Apr 11, 2022
1 parent 5a64315 commit 8b053f9
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@
import com.sun.source.tree.VariableTree;
import com.sun.source.util.TreePath;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Symbol.TypeVariableSymbol;
import com.sun.tools.javac.code.Symbol.VarSymbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.code.Type.TypeVar;
import java.util.List;
import java.util.Objects;

/**
* Ensures that safe-logging annotated elements are handled correctly by annotated method parameters.
Expand Down Expand Up @@ -72,6 +76,28 @@ public final class IllegalSafeLoggingArgument extends BugChecker
.onClass("com.palantir.logsafe.SafeArg")
.named("of");

private static Type resolveParameterType(Type input, MethodInvocationTree tree, VisitorState state) {
if (input instanceof TypeVar) {
TypeVar typeVar = (TypeVar) input;

Type receiver = ASTHelpers.getReceiverType(tree);
if (receiver == null) {
return input;
}
MethodSymbol methodSymbol = ASTHelpers.getSymbol(tree);
// List<String> -> Collection<E> gives us Collection<String>
Type boundToMethodOwner = state.getTypes().asSuper(receiver, methodSymbol.owner);
List<TypeVariableSymbol> ownerTypeVars = methodSymbol.owner.getTypeParameters();
for (int i = 0; i < ownerTypeVars.size(); i++) {
TypeVariableSymbol ownerVar = ownerTypeVars.get(i);
if (Objects.equals(ownerVar, typeVar.tsym)) {
return boundToMethodOwner.getTypeArguments().get(i);
}
}
}
return input;
}

@Override
public Description matchMethodInvocation(MethodInvocationTree tree, VisitorState state) {
List<? extends ExpressionTree> arguments = tree.getArguments();
Expand All @@ -85,7 +111,10 @@ public Description matchMethodInvocation(MethodInvocationTree tree, VisitorState
List<VarSymbol> parameters = methodSymbol.getParameters();
for (int i = 0; i < parameters.size(); i++) {
VarSymbol parameter = parameters.get(i);
Safety parameterSafety = SafetyAnnotations.getSafety(parameter, state);
Type resolvedParameterType = resolveParameterType(parameter.type, tree, state);
Safety parameterSafety = Safety.mergeAssumingUnknownIsSame(
SafetyAnnotations.getSafety(parameter, state),
SafetyAnnotations.getSafety(resolvedParameterType, state));
if (parameterSafety.allowsAll()) {
// Fast path: all types are accepted, there's no reason to do further analysis.
continue;
Expand Down Expand Up @@ -164,7 +193,7 @@ public Description matchCompoundAssignment(CompoundAssignmentTree tree, VisitorS

private Description handleAssignment(
ExpressionTree assignmentTree, ExpressionTree variable, ExpressionTree expression, VisitorState state) {
Safety variableDeclaredSafety = SafetyAnnotations.getSafety(ASTHelpers.getSymbol(variable), state);
Safety variableDeclaredSafety = SafetyAnnotations.getSafety(variable, state);
if (variableDeclaredSafety.allowsAll()) {
return Description.NO_MATCH;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,56 @@

package com.palantir.baseline.errorprone.safety;

import com.google.common.collect.Multimap;
import com.google.errorprone.VisitorState;
import com.google.errorprone.suppliers.Suppliers;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.Tree;
import com.sun.tools.javac.code.Attribute;
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Symbol.VarSymbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.util.List;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import javax.lang.model.element.Name;
import javax.lang.model.element.TypeElement;

public final class SafetyAnnotations {
private static final String SAFE = "com.palantir.logsafe.Safe";
private static final String UNSAFE = "com.palantir.logsafe.Unsafe";
private static final String DO_NOT_LOG = "com.palantir.logsafe.DoNotLog";

public static Safety getSafety(ExpressionTree input, VisitorState state) {
// Check the result type
ExpressionTree tree = ASTHelpers.stripParentheses(input);
Safety resultTypeSafet = getResultTypeSafety(tree, state);
private static final TypeArgumentHandlers SAFETY_IS_COMBINATION_OF_TYPE_ARGUMENTS = new TypeArgumentHandlers(
new TypeArgumentHandler(Iterable.class),
new TypeArgumentHandler(Iterator.class),
new TypeArgumentHandler(Map.class),
new TypeArgumentHandler(Map.Entry.class),
new TypeArgumentHandler(Multimap.class),
new TypeArgumentHandler(Stream.class),
new TypeArgumentHandler(Optional.class));

// Check the argument symbol itself:
Symbol argumentSymbol = ASTHelpers.getSymbol(tree);
Safety symbolSafety = argumentSymbol != null ? getSafety(argumentSymbol, state) : Safety.UNKNOWN;
return Safety.mergeAssumingUnknownIsSame(resultTypeSafet, symbolSafety);
public static Safety getSafety(Tree tree, VisitorState state) {
// Check the symbol itself:
Symbol treeSymbol = ASTHelpers.getSymbol(tree);
Safety symbolSafety = getSafety(treeSymbol, state);
Type type = tree instanceof ExpressionTree
? ASTHelpers.getResultType((ExpressionTree) tree)
: ASTHelpers.getType(tree);

if (type == null) {
return symbolSafety;
} else {
return Safety.mergeAssumingUnknownIsSame(symbolSafety, getSafety(type, state), getSafety(type.tsym, state));
}
}

public static Safety getSafety(@Nullable Symbol symbol, VisitorState state) {
Expand All @@ -66,6 +91,91 @@ public static Safety getSafety(@Nullable Symbol symbol, VisitorState state) {
return Safety.UNKNOWN;
}

public static Safety getSafety(@Nullable Type type, VisitorState state) {
if (type != null) {
return getSafetyInternal(type, state, null);
}
return Safety.UNKNOWN;
}

private static Safety getSafetyInternal(Type type, VisitorState state, Set<String> dejaVu) {
List<Attribute.TypeCompound> typeAnnotations = type.getAnnotationMirrors();
for (Attribute.TypeCompound annotation : typeAnnotations) {
TypeElement annotationElement =
(TypeElement) annotation.getAnnotationType().asElement();
Name name = annotationElement.getQualifiedName();
if (name.contentEquals(DO_NOT_LOG)) {
return Safety.DO_NOT_LOG;
}
if (name.contentEquals(UNSAFE)) {
return Safety.UNSAFE;
}
if (name.contentEquals(SAFE)) {
return Safety.SAFE;
}
}
return SAFETY_IS_COMBINATION_OF_TYPE_ARGUMENTS.getSafety(type, state, dejaVu);
}

private static final class TypeArgumentHandlers {
private final TypeArgumentHandler[] handlers;

TypeArgumentHandlers(TypeArgumentHandler... handlers) {
this.handlers = handlers;
}

Safety getSafety(Type type, VisitorState state, @Nullable Set<String> dejaVu) {
for (TypeArgumentHandler handler : handlers) {
Safety result = handler.getSafety(type, state, dejaVu);
if (result != null) {
return result;
}
}
return Safety.UNKNOWN;
}
}

private static final class TypeArgumentHandler {
private final com.google.errorprone.suppliers.Supplier<Type> typeSupplier;

TypeArgumentHandler(Class<?> clazz) {
if (clazz.getTypeParameters().length == 0) {
throw new IllegalStateException("Class " + clazz + " has no type parameters");
}
this.typeSupplier = Suppliers.typeFromClass(clazz);
}

@Nullable
Safety getSafety(Type type, VisitorState state, @Nullable Set<String> dejaVu) {
Type baseType = typeSupplier.get(state);
if (ASTHelpers.isSubtype(type, baseType, state)) {
// ensure we're matching the expected type arguments
Set<String> deJaVuToPass = dejaVu == null ? new HashSet<>() : dejaVu;
// Use the string value for cycle detection, the type itself is not guaranteed
// to declare hash/equals.
String typeString = type.toString();
if (!deJaVuToPass.add(typeString)) {
return Safety.UNKNOWN;
}
// Apply the input type arguments to the base type
Type asSubtype = state.getTypes().asSuper(type, baseType.tsym);
Safety safety = Safety.SAFE;
List<Type> typeArguments = asSubtype.getTypeArguments();
for (Type typeArgument : typeArguments) {
Safety safetyBasedOnType = SafetyAnnotations.getSafetyInternal(typeArgument, state, deJaVuToPass);
Safety safetyBasedOnSymbol = SafetyAnnotations.getSafety(typeArgument.tsym, state);
Safety typeArgumentSafety =
Safety.mergeAssumingUnknownIsSame(safetyBasedOnType, safetyBasedOnSymbol);
safety = safety.leastUpperBound(typeArgumentSafety);
}
// remove the type on the way out, otherwise map<Foo,Foo> would break.
deJaVuToPass.remove(typeString);
return safety;
}
return null;
}
}

private static Safety getSuperMethodSafety(MethodSymbol method, VisitorState state) {
Safety safety = Safety.UNKNOWN;
if (!method.isStaticOrInstanceInit()) {
Expand Down Expand Up @@ -99,10 +209,5 @@ private static Safety getSuperMethodParameterSafety(VarSymbol varSymbol, Visitor
return safety;
}

public static Safety getResultTypeSafety(ExpressionTree expression, VisitorState state) {
Type resultType = ASTHelpers.getResultType(expression);
return resultType == null ? Safety.UNKNOWN : getSafety(resultType.tsym, state);
}

private SafetyAnnotations() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
import com.sun.tools.javac.code.Symbol;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Symbol.VarSymbol;
import com.sun.tools.javac.code.Type;
import com.sun.tools.javac.processing.JavacProcessingEnvironment;
import java.io.Closeable;
import java.util.Arrays;
Expand Down Expand Up @@ -291,7 +290,7 @@ public AccessPathStore<Safety> initialStore(UnderlyingAST _underlyingAst, List<L
AccessPathStore.Builder<Safety> result = AccessPathStore.<Safety>empty().toBuilder();

for (LocalVariableNode param : parameters) {
Safety declared = SafetyAnnotations.getSafety((Symbol) param.getElement(), state);
Safety declared = SafetyAnnotations.getSafety(param.getTree(), state);
result.setInformation(AccessPath.fromLocalVariable(param), declared);
}
return result.build();
Expand Down Expand Up @@ -668,9 +667,8 @@ public TransferResult<Safety, AccessPathStore<Safety>> visitAssignment(
AssignmentNode node, TransferInput<Safety, AccessPathStore<Safety>> input) {
ReadableUpdates updates = new ReadableUpdates();
Safety expressionSafety = getValueOfSubNode(input, node.getExpression());
Safety targetSymbolSafety = SafetyAnnotations.getSafety(
ASTHelpers.getSymbol(node.getTarget().getTree()), state);
Safety safety = Safety.mergeAssumingUnknownIsSame(expressionSafety, targetSymbolSafety);
Safety targetSafety = SafetyAnnotations.getSafety(node.getTarget().getTree(), state);
Safety safety = Safety.mergeAssumingUnknownIsSame(expressionSafety, targetSafety);
Node target = node.getTarget();
if (target instanceof LocalVariableNode) {
updates.trySet(target, safety);
Expand Down Expand Up @@ -771,23 +769,20 @@ private static boolean hasNonNullConstantValue(LocalVariableNode node) {
@Override
public TransferResult<Safety, AccessPathStore<Safety>> visitVariableDeclaration(
VariableDeclarationNode node, TransferInput<Safety, AccessPathStore<Safety>> input) {
Safety variableTypeSafety =
SafetyAnnotations.getSafety(ASTHelpers.getSymbol(node.getTree().getType()), state);
Safety variableSafety = SafetyAnnotations.getSafety(ASTHelpers.getSymbol(node.getTree()), state);
Safety variableTypeSafety = SafetyAnnotations.getSafety(node.getTree().getType(), state);
Safety variableSafety = SafetyAnnotations.getSafety(node.getTree(), state);
Safety safety = Safety.mergeAssumingUnknownIsSame(variableTypeSafety, variableSafety);
return noStoreChanges(safety, input);
}

@Override
public TransferResult<Safety, AccessPathStore<Safety>> visitFieldAccess(
FieldAccessNode node, TransferInput<Safety, AccessPathStore<Safety>> input) {
Safety fieldSafety = SafetyAnnotations.getSafety(ASTHelpers.getSymbol(node.getTree()), state);
Type fieldType = ASTHelpers.getType(node.getTree());
Safety typeSafety = fieldType == null ? Safety.UNKNOWN : SafetyAnnotations.getSafety(fieldType.tsym, state);
Safety fieldSafety = SafetyAnnotations.getSafety(node.getTree(), state);
VarSymbol symbol = (VarSymbol) ASTHelpers.getSymbol(node.getTree());
AccessPath maybeAccessPath = AccessPath.fromFieldAccess(node);
Safety flowSafety = fieldSafety(symbol, maybeAccessPath, input.getRegularStore());
Safety safety = Safety.mergeAssumingUnknownIsSame(fieldSafety, typeSafety, flowSafety);
Safety safety = Safety.mergeAssumingUnknownIsSame(fieldSafety, flowSafety);
return noStoreChanges(safety, input);
}

Expand Down Expand Up @@ -950,9 +945,7 @@ public TransferResult<Safety, AccessPathStore<Safety>> visitTypeCast(
private TransferResult<Safety, AccessPathStore<Safety>> handleTypeConversion(
Tree newType, Node original, TransferInput<Safety, AccessPathStore<Safety>> input) {
Safety valueSafety = getValueOfSubNode(input, original);
Type targetType = ASTHelpers.getType(newType);
Safety narrowTargetSafety =
targetType == null ? Safety.UNKNOWN : SafetyAnnotations.getSafety(targetType.tsym, state);
Safety narrowTargetSafety = SafetyAnnotations.getSafety(newType, state);
Safety resultSafety = Safety.mergeAssumingUnknownIsSame(valueSafety, narrowTargetSafety);
return noStoreChanges(resultSafety, input);
}
Expand Down Expand Up @@ -1024,18 +1017,16 @@ private Safety getKnownMethodSafety(

private Safety getMethodSymbolSafety(
MethodInvocationNode node, TransferInput<Safety, AccessPathStore<Safety>> input) {
Safety resultTypeSafety = SafetyAnnotations.getResultTypeSafety(node.getTree(), state);
Safety methodSafety = SafetyAnnotations.getSafety(node.getTree(), state);
MethodSymbol methodSymbol = ASTHelpers.getSymbol(node.getTree());
if (methodSymbol != null) {
Safety methodSafety = Safety.mergeAssumingUnknownIsSame(
SafetyAnnotations.getSafety(methodSymbol, state), resultTypeSafety);
// non-annotated toString inherits type-level safety.
if (methodSafety == Safety.UNKNOWN && TO_STRING.matches(node.getTree(), state)) {
return getValueOfSubNode(input, node.getTarget().getReceiver());
}
return methodSafety;
}
return resultTypeSafety;
return methodSafety;
}

@Override
Expand Down
Loading

0 comments on commit 8b053f9

Please sign in to comment.