Skip to content

Commit

Permalink
fix: improve type inference for generics in invoke insn (#927)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed May 10, 2020
1 parent b1d5ed0 commit 404136c
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package jadx.core.dex.visitors.typeinference;

import jadx.core.dex.instructions.BaseInvokeNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.RootNode;

/**
* Special dynamic bound for invoke with generics.
* Arguments bound type calculated using instance generic type.
*/
public final class TypeBoundInvokeUse implements ITypeBoundDynamic {
private final RootNode root;
private final BaseInvokeNode invokeNode;
private final RegisterArg arg;
private final ArgType genericArgType;

public TypeBoundInvokeUse(RootNode root, BaseInvokeNode invokeNode, RegisterArg arg, ArgType genericArgType) {
this.root = root;
this.invokeNode = invokeNode;
this.arg = arg;
this.genericArgType = genericArgType;
}

@Override
public BoundEnum getBound() {
return BoundEnum.USE;
}

@Override
public ArgType getType(TypeUpdateInfo updateInfo) {
return getArgType(updateInfo.getType(invokeNode.getInstanceArg()), updateInfo.getType(arg));
}

@Override
public ArgType getType() {
return getArgType(invokeNode.getInstanceArg().getType(), arg.getType());
}

private ArgType getArgType(ArgType instanceType, ArgType argType) {
ArgType resultGeneric = root.getTypeUtils().replaceClassGenerics(instanceType, genericArgType);
if (resultGeneric != null) {
return resultGeneric;
}
return argType;
}

@Override
public RegisterArg getArg() {
return arg;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TypeBoundInvokeUse that = (TypeBoundInvokeUse) o;
return invokeNode.equals(that.invokeNode);
}

@Override
public int hashCode() {
return invokeNode.hashCode();
}

@Override
public String toString() {
return "InvokeAssign{" + invokeNode.getCallMth().getShortId()
+ ", argType=" + genericArgType
+ ", currentType=" + getType()
+ ", instanceArg=" + invokeNode.getInstanceArg()
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -298,17 +298,31 @@ private ITypeBound makeUseBound(RegisterArg regArg) {
return null;
}
if (insn instanceof BaseInvokeNode) {
IMethodDetails methodDetails = root.getMethodUtils().getMethodDetails((BaseInvokeNode) insn);
if (methodDetails != null) {
if (methodDetails.getArgTypes().stream().anyMatch(ArgType::containsTypeVariable)) {
// don't add const bound for generic type variables
return null;
}
TypeBoundInvokeUse invokeUseBound = makeInvokeUseBound(regArg, (BaseInvokeNode) insn);
if (invokeUseBound != null) {
return invokeUseBound;
}
}
return new TypeBoundConst(BoundEnum.USE, regArg.getInitType(), regArg);
}

private TypeBoundInvokeUse makeInvokeUseBound(RegisterArg regArg, BaseInvokeNode invoke) {
InsnArg instanceArg = invoke.getInstanceArg();
if (instanceArg == null || instanceArg == regArg) {
return null;
}
IMethodDetails methodDetails = root.getMethodUtils().getMethodDetails(invoke);
if (methodDetails == null) {
return null;
}
int argIndex = invoke.getArgIndex(regArg) - invoke.getFirstArgOffset();
ArgType argType = methodDetails.getArgTypes().get(argIndex);
if (!argType.containsTypeVariable()) {
return null;
}
return new TypeBoundInvokeUse(root, invoke, regArg, argType);
}

private boolean tryPossibleTypes(SSAVar var, ArgType type) {
List<ArgType> types = makePossibleTypesList(type);
for (ArgType candidateType : types) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
import org.slf4j.LoggerFactory;

import jadx.core.Consts;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.PrimitiveType;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.nodes.IMethodDetails;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.dex.nodes.utils.TypeUtils;
import jadx.core.utils.exceptions.JadxOverflowException;
import jadx.core.utils.exceptions.JadxRuntimeException;

Expand Down Expand Up @@ -278,27 +281,58 @@ private Map<InsnType, ITypeListener> initListenerRegistry() {
}

private TypeUpdateResult invokeListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
if (insn.getResult() == null) {
InvokeNode invoke = (InvokeNode) insn;
if (isAssign(invoke, arg)) {
// TODO: implement backward type propagation (from result to instance)
return SAME;
}
if (candidateType.containsTypeVariable()) {
InvokeNode invokeNode = (InvokeNode) insn;
if (isAssign(insn, arg)) {
// TODO: implement backward type propagation (from result to instance)
if (invoke.getInstanceArg() == arg && candidateType.containsGeneric()) {
// resolve result and arg types from generic instance type
IMethodDetails methodDetails = root.getMethodUtils().getMethodDetails(invoke);
if (methodDetails == null) {
return SAME;
} else {
ArgType returnType = root.getMethodUtils().getMethodGenericReturnType(invokeNode);
if (returnType == null) {
return SAME;
}
TypeUtils typeUtils = root.getTypeUtils();
Map<ArgType, ArgType> typeVarsMap = typeUtils.getTypeVariablesMapping(candidateType);
if (typeVarsMap.isEmpty()) {
return SAME;
}

boolean allSame = true;
if (invoke.getResult() != null) {
ArgType returnType = typeUtils.replaceTypeVariablesUsingMap(methodDetails.getReturnType(), typeVarsMap);
if (returnType != null) {
TypeUpdateResult result = updateTypeChecked(updateInfo, invoke.getResult(), returnType);
if (result == REJECT) {
return REJECT;
}
if (result == CHANGED) {
allSame = false;
}
}
ArgType resultGeneric = root.getTypeUtils().replaceClassGenerics(candidateType, returnType);
if (resultGeneric == null) {
return SAME;
}

int argOffset = invoke.getFirstArgOffset();
List<ArgType> argTypes = methodDetails.getArgTypes();
int argsCount = argTypes.size();
for (int i = 0; i < argsCount; i++) {
ArgType genericArgType = argTypes.get(i);
ArgType resultArgType = typeUtils.replaceClassGenerics(candidateType, genericArgType);
if (resultArgType != null) {
InsnArg invokeArg = invoke.getArg(argOffset + i);
TypeUpdateResult result = updateTypeChecked(updateInfo, invokeArg, resultArgType);
if (result == REJECT) {
return REJECT;
}
if (result == CHANGED) {
allSame = false;
}
}
return updateTypeChecked(updateInfo, insn.getResult(), resultGeneric);
}
return allSame ? SAME : CHANGED;
}
return SAME;

}

private TypeUpdateResult sameFirstArgListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
Expand Down Expand Up @@ -377,12 +411,21 @@ private TypeUpdateResult suggestAllSameListener(TypeUpdateInfo updateInfo, InsnN
}

private TypeUpdateResult checkCastListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
if (!isAssign(insn, arg)) {
return SAME;
IndexInsnNode checkCast = (IndexInsnNode) insn;
if (isAssign(insn, arg)) {
InsnArg insnArg = insn.getArg(0);
TypeUpdateResult result = updateTypeChecked(updateInfo, insnArg, candidateType);
return result == REJECT ? SAME : result;
}
if (candidateType.containsGeneric()) {
ArgType castType = (ArgType) checkCast.getIndex();
TypeCompareEnum compResult = comparator.compareTypes(candidateType, castType);
if (compResult == TypeCompareEnum.NARROW_BY_GENERIC) {
// propagate generic type to result
return updateTypeChecked(updateInfo, checkCast.getResult(), candidateType);
}
}
InsnArg insnArg = insn.getArg(0);
TypeUpdateResult result = updateTypeChecked(updateInfo, insnArg, candidateType);
return result == REJECT ? SAME : result;
return SAME;
}

private TypeUpdateResult arrayGetListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package jadx.tests.integration.types;

import java.util.HashMap;
import java.util.Map;

import org.junit.jupiter.api.Test;

import jadx.tests.api.IntegrationTest;

import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;

public class TestGenerics5 extends IntegrationTest {

public static class TestCls {
private InheritableThreadLocal<Map<String, String>> inheritableThreadLocal;

public void put(String key, String val) {
if (key == null) {
throw new IllegalArgumentException("key cannot be null");
}
Map<String, String> map = this.inheritableThreadLocal.get();
if (map == null) {
map = new HashMap<>();
this.inheritableThreadLocal.set(map);
}
map.put(key, val);
}

public void remove(String key) {
Map<String, String> map = this.inheritableThreadLocal.get();
if (map != null) {
map.remove(key);
}
}
}

@Test
public void test() {
noDebugInfo();
assertThat(getClassNode(TestCls.class))
.code()
.countString(2, "Map<String, String> map = this.inheritableThreadLocal.get();");
}
}

0 comments on commit 404136c

Please sign in to comment.