Skip to content

Commit

Permalink
wasm gc: fix issues related to virtual calls
Browse files Browse the repository at this point in the history
  • Loading branch information
konsoletyper committed Aug 13, 2024
1 parent 73edc0c commit a84c5fc
Show file tree
Hide file tree
Showing 12 changed files with 150 additions and 78 deletions.
2 changes: 1 addition & 1 deletion core/src/main/java/org/teavm/backend/c/CTarget.java
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ private void generateStrings(BuildTarget buildTarget, GenerationContext context)

private VirtualTableProvider createVirtualTableProvider(ListableClassHolderSource classes) {
VirtualTableBuilder builder = new VirtualTableBuilder(classes);
builder.setMethodsUsedAtCallSites(VirtualTableBuilder.getMethodsUsedOnCallSites(classes));
builder.setMethodsUsedAtCallSites(VirtualTableBuilder.getMethodsUsedOnCallSites(classes, true));
builder.setMethodCalledVirtually(controller::isVirtual);
return builder.build();
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/org/teavm/backend/wasm/WasmTarget.java
Original file line number Diff line number Diff line change
Expand Up @@ -1070,7 +1070,7 @@ private void renderMemoryLayout(WasmModule module, int address, GCIntrinsic gcIn

private VirtualTableProvider createVirtualTableProvider(ListableClassHolderSource classes) {
var builder = new VirtualTableBuilder(classes);
builder.setMethodsUsedAtCallSites(VirtualTableBuilder.getMethodsUsedOnCallSites(classes));
builder.setMethodsUsedAtCallSites(VirtualTableBuilder.getMethodsUsedOnCallSites(classes, true));
builder.setMethodCalledVirtually(controller::isVirtual);
return builder.build();
}
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/java/org/teavm/backend/wasm/gc/WasmGCUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ public static String findCommonSuperclass(ClassHierarchy hierarchy, ClassReader
if (firstPath.get(0) != secondPath.get(0)) {
return "java.lang.Object";
}
var max = Math.max(firstPath.size(), secondPath.size());
var min = Math.min(firstPath.size(), secondPath.size());
var index = 1;
while (index < max && firstPath.get(index) == secondPath.get(index)) {
while (index < min && firstPath.get(index) == secondPath.get(index)) {
++index;
}
return firstPath.get(index).getName();
return index < firstPath.size() ? firstPath.get(index).getName() : secondPath.get(index).getName();
}

private static List<ClassReader> findPathToRoot(ClassHierarchy hierarchy, ClassReader cls) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.teavm.ast.Statement;
import org.teavm.ast.SubscriptExpr;
import org.teavm.ast.TryCatchStatement;
import org.teavm.ast.UnwrapArrayExpr;
import org.teavm.backend.wasm.WasmFunctionRepository;
import org.teavm.backend.wasm.WasmFunctionTypes;
import org.teavm.backend.wasm.WasmHeap;
Expand Down Expand Up @@ -299,7 +298,7 @@ protected WasmExpression classLiteral(ValueType type) {
}

@Override
protected WasmExpression nullLiteral() {
protected WasmExpression nullLiteral(Expr expr) {
return new WasmInt32Constant(0);
}

Expand Down Expand Up @@ -371,11 +370,6 @@ private WasmExpression getArrayElementPointer(WasmExpression array, WasmExpressi
return new WasmIntBinary(WasmIntType.INT32, WasmIntBinaryOperation.ADD, array, index);
}

@Override
public void visit(UnwrapArrayExpr expr) {
accept(expr.getArray());
}

@Override
protected WasmExpression invocation(InvocationExpr expr, List<WasmExpression> resultConsumer, boolean willDrop) {
if (expr.getMethod().getClassName().equals(ShadowStack.class.getName())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import org.teavm.ast.ThrowStatement;
import org.teavm.ast.TryCatchStatement;
import org.teavm.ast.UnaryExpr;
import org.teavm.ast.UnwrapArrayExpr;
import org.teavm.ast.VariableExpr;
import org.teavm.ast.WhileStatement;
import org.teavm.backend.wasm.WasmRuntime;
Expand Down Expand Up @@ -453,7 +454,7 @@ private WasmExpression nullCheck(Expr value, TextLocation location) {
accept(value);
result.acceptVisitor(typeInference);
block.setType(typeInference.getResult());
var cachedValue = exprCache.create(result, WasmType.INT32, location, block.getBody());
var cachedValue = exprCache.create(result, typeInference.getResult(), location, block.getBody());

var check = new WasmBranch(cachedValue.expr(), block);
check.setResult(cachedValue.expr());
Expand Down Expand Up @@ -555,7 +556,7 @@ public void visit(SequentialStatement statement) {
@Override
public void visit(ConstantExpr expr) {
if (expr.getValue() == null) {
result = nullLiteral();
result = nullLiteral(expr);
} else if (expr.getValue() instanceof Integer) {
result = new WasmInt32Constant((Integer) expr.getValue());
} else if (expr.getValue() instanceof Long) {
Expand All @@ -574,7 +575,7 @@ public void visit(ConstantExpr expr) {
result.setLocation(expr.getLocation());
}

protected abstract WasmExpression nullLiteral();
protected abstract WasmExpression nullLiteral(Expr expr);

protected abstract WasmExpression stringLiteral(String s);

Expand Down Expand Up @@ -894,12 +895,14 @@ protected WasmExpression generateInvocation(InvocationExpr expr, CallSiteIdentif
return block;
} else {
var reference = expr.getMethod();
acceptWithType(expr.getArguments().get(0), ValueType.object(expr.getMethod().getClassName()));
var instanceType = ValueType.object(expr.getMethod().getClassName());
acceptWithType(expr.getArguments().get(0), instanceType);
var instanceWasmType = mapType(instanceType);
var instance = result;
var block = new WasmBlock(false);
block.setType(mapType(reference.getReturnType()));

var instanceVar = tempVars.acquire(WasmType.INT32);
var instanceVar = tempVars.acquire(instanceWasmType);
block.getBody().add(new WasmSetLocal(instanceVar, instance));
instance = new WasmGetLocal(instanceVar);

Expand Down Expand Up @@ -1066,7 +1069,8 @@ public void visit(ArrayFromDataExpr expr) {

for (int i = 0; i < expr.getData().size(); ++i) {
expr.getData().get(i).acceptVisitor(this);
block.getBody().add(storeArrayItem(new WasmGetLocal(array), new WasmInt32Constant(i), result, arrayType));
var arrayData = unwrapArray(new WasmGetLocal(array));
block.getBody().add(storeArrayItem(arrayData, new WasmInt32Constant(i), result, arrayType));
}

block.getBody().add(new WasmGetLocal(array));
Expand Down Expand Up @@ -1126,7 +1130,8 @@ public void visit(InstanceOfExpr expr) {
block.setType(WasmType.INT32);
block.setLocation(expr.getLocation());

var cachedObject = exprCache.create(result, WasmType.INT32, expr.getLocation(), block.getBody());
result.acceptVisitor(typeInference);
var cachedObject = exprCache.create(result, typeInference.getResult(), expr.getLocation(), block.getBody());

var ifNull = new WasmBranch(genIsZero(cachedObject.expr()), block);
ifNull.setResult(new WasmInt32Constant(0));
Expand Down Expand Up @@ -1545,6 +1550,17 @@ private WasmExpression genIsZero(WasmExpression value) {

protected abstract WasmType mapType(ValueType type);

protected WasmExpression unwrapArray(WasmExpression array) {
return array;
}

@Override
public void visit(UnwrapArrayExpr expr) {
accept(expr.getArray());
result = unwrapArray(result);
result.setLocation(expr.getLocation());
}

protected abstract class CallSiteIdentifier {
public abstract void generateRegister(List<WasmExpression> consumer, TextLocation location);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public void contributeToInitializer(WasmFunction function) {
private static VirtualTableProvider createVirtualTableProvider(ListableClassHolderSource classes,
Predicate<MethodReference> virtualMethods) {
var builder = new VirtualTableBuilder(classes);
builder.setMethodsUsedAtCallSites(VirtualTableBuilder.getMethodsUsedOnCallSites(classes));
builder.setMethodsUsedAtCallSites(VirtualTableBuilder.getMethodsUsedOnCallSites(classes, false));
builder.setMethodCalledVirtually(virtualMethods);
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ public class WasmGCClassGenerator implements WasmGCClassInfoProvider, WasmGCInit
private Map<ValueType, WasmGCClassInfo> classInfoMap = new LinkedHashMap<>();
private Queue<WasmGCClassInfo> classInfoQueue = new ArrayDeque<>();
private ObjectIntMap<FieldReference> fieldIndexes = new ObjectIntHashMap<>();
private ObjectIntMap<MethodReference> methodIndexes = new ObjectIntHashMap<>();
private Map<FieldReference, WasmGlobal> staticFieldLocations = new HashMap<>();
private List<Consumer<WasmFunction>> staticFieldInitializers = new ArrayList<>();
private ClassInitializerInfo classInitializerInfo;
Expand Down Expand Up @@ -195,8 +194,7 @@ public WasmGCClassInfo getClassInfo(ValueType type) {
classInfo.structure = new WasmStructure(name != null ? names.forClass(name) : null);
if (name != null) {
var classReader = classSource.get(name);
if (classReader == null || !classReader.hasModifier(ElementModifier.ABSTRACT)
&& !classReader.hasModifier(ElementModifier.INTERFACE)) {
if (classReader == null || !classReader.hasModifier(ElementModifier.INTERFACE)) {
virtualTable = virtualTables.lookup(name);
}
if (classReader != null && classReader.getParent() != null) {
Expand Down Expand Up @@ -237,6 +235,11 @@ public int getClassArrayItemOffset() {
return classArrayItemOffset;
}

@Override
public int getVirtualMethodsOffset() {
return virtualTableFieldOffset;
}

private void initPrimitiveClass(WasmGCClassInfo classInfo, ValueType.Primitive type) {
classInfo.initializer = target -> {
int kind;
Expand Down Expand Up @@ -327,6 +330,10 @@ private void initRegularClass(WasmGCClassInfo classInfo, VirtualTable virtualTab

private int fillVirtualTableMethods(List<WasmExpression> target, WasmStructure structure, WasmGlobal global,
VirtualTable virtualTable, int index, String origin, Set<MethodDescriptor> filled) {
if (virtualTable.getParent() != null) {
index = fillVirtualTableMethods(target, structure, global, virtualTable.getParent(), index, origin,
filled);
}
for (var method : virtualTable.getMethods()) {
var entry = virtualTable.getEntry(method);
if (entry != null && entry.getImplementor() != null && filled.add(method)) {
Expand All @@ -337,24 +344,21 @@ private int fillVirtualTableMethods(List<WasmExpression> target, WasmStructure s
module.functions.add(wrapperFunction);
var call = new WasmCall(function);
var instanceParam = new WasmLocal(getClassInfo(virtualTable.getClassName()).getType());
wrapperFunction.getLocalVariables().add(instanceParam);
wrapperFunction.add(instanceParam);
var castTarget = getClassInfo(entry.getImplementor().getClassName()).getType();
call.getArguments().add(new WasmCast(new WasmGetLocal(instanceParam), castTarget));
var params = new WasmLocal[method.parameterCount()];
for (var i = 0; i < method.parameterCount(); ++i) {
params[i] = new WasmLocal(typeMapper.mapType(method.parameterType(i)));
call.getArguments().add(new WasmGetLocal(params[i]));
wrapperFunction.add(params[i]);
}
wrapperFunction.getLocalVariables().addAll(List.of(params));
}
function.setReferenced(true);
var ref = new WasmFunctionReference(function);
target.add(new WasmStructSet(structure, new WasmGetGlobal(global), index, ref));
}
}
if (virtualTable.getParent() != null) {
index = fillVirtualTableMethods(target, structure, global, virtualTable.getParent(), index, origin,
filled);
++index;
}
return index;
}
Expand All @@ -376,10 +380,12 @@ private void addVirtualTableFields(WasmStructure structure, VirtualTable virtual
addVirtualTableFields(structure, virtualTable.getParent());
}
for (var methodDesc : virtualTable.getMethods()) {
var functionType = getFunctionType(virtualTable.getClassName(), methodDesc);
var methodRef = new MethodReference(virtualTable.getClassName(), methodDesc);
methodIndexes.put(methodRef, structure.getFields().size());
structure.getFields().add(functionType.getReference().asStorage());
if (methodDesc == null) {
structure.getFields().add(WasmType.Reference.FUNC.asStorage());
} else {
var functionType = getFunctionType(virtualTable.getClassName(), methodDesc);
structure.getFields().add(functionType.getReference().asStorage());
}
}
}

Expand Down Expand Up @@ -454,15 +460,6 @@ private WasmGlobal generateStaticFieldLocation(FieldReference fieldRef) {
return global;
}

@Override
public int getVirtualMethodIndex(MethodReference methodRef) {
var result = methodIndexes.getOrDefault(methodRef, -1);
if (result < 0) {
throw new IllegalStateException("Can't get offset of method " + methodRef);
}
return result;
}

private void fillFields(WasmGCClassInfo classInfo, ValueType type) {
var fields = classInfo.structure.getFields();
fields.add(standardClasses.classClass().getType().asStorage());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import org.teavm.backend.wasm.model.WasmGlobal;
import org.teavm.model.FieldReference;
import org.teavm.model.MethodReference;
import org.teavm.model.ValueType;

public interface WasmGCClassInfoProvider {
Expand All @@ -32,7 +31,7 @@ public interface WasmGCClassInfoProvider {

WasmGlobal getStaticFieldLocation(FieldReference fieldRef);

int getVirtualMethodIndex(MethodReference methodRef);
int getVirtualMethodsOffset();

default WasmGCClassInfo getClassInfo(String name) {
return getClassInfo(ValueType.object(name));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
*/
package org.teavm.backend.wasm.generate.gc.methods;

import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.teavm.backend.wasm.BaseWasmFunctionRepository;
import org.teavm.backend.wasm.WasmFunctionTypes;
import org.teavm.backend.wasm.gc.WasmGCMethodReturnTypes;
Expand All @@ -32,6 +38,8 @@
import org.teavm.backend.wasm.runtime.WasmGCSupport;
import org.teavm.model.ClassHierarchy;
import org.teavm.model.ClassReaderSource;
import org.teavm.model.ElementModifier;
import org.teavm.model.ListableClassReaderSource;
import org.teavm.model.MethodReference;
import org.teavm.model.classes.VirtualTableProvider;

Expand All @@ -43,7 +51,7 @@ public class WasmGCGenerationContext implements BaseWasmGenerationContext {
private VirtualTableProvider virtualTables;
private WasmGCTypeMapper typeMapper;
private WasmFunctionTypes functionTypes;
private ClassReaderSource classes;
private ListableClassReaderSource classes;
private ClassHierarchy hierarchy;
private BaseWasmFunctionRepository functions;
private WasmGCSupertypeFunctionProvider supertypeFunctions;
Expand All @@ -55,9 +63,10 @@ public class WasmGCGenerationContext implements BaseWasmGenerationContext {
private WasmFunction cceMethod;
private WasmGlobal exceptionGlobal;
private WasmTag exceptionTag;
private Map<String, Set<String>> interfaceImplementors;

public WasmGCGenerationContext(WasmModule module, VirtualTableProvider virtualTables,
WasmGCTypeMapper typeMapper, WasmFunctionTypes functionTypes, ClassReaderSource classes,
WasmGCTypeMapper typeMapper, WasmFunctionTypes functionTypes, ListableClassReaderSource classes,
ClassHierarchy hierarchy, BaseWasmFunctionRepository functions,
WasmGCSupertypeFunctionProvider supertypeFunctions, WasmGCClassInfoProvider classInfoProvider,
WasmGCStandardClasses standardClasses, WasmGCStringProvider strings,
Expand Down Expand Up @@ -180,4 +189,36 @@ public WasmGCIntrinsicProvider intrinsics() {
public WasmGCMethodReturnTypes returnTypes() {
return returnTypes;
}

public Collection<String> getInterfaceImplementors(String className) {
if (interfaceImplementors == null) {
fillInterfaceImplementors();
}
var result = interfaceImplementors.get(className);
return result != null ? result : List.of();
}

private void fillInterfaceImplementors() {
interfaceImplementors = new HashMap<>();
for (var className : classes.getClassNames()) {
var cls = classes.get(className);
if (!cls.hasModifier(ElementModifier.INTERFACE)) {
for (var itf : cls.getInterfaces()) {
addInterfaceImplementor(className, itf);
}
}
}
}

private void addInterfaceImplementor(String implementorName, String interfaceName) {
var implementorsByKey = interfaceImplementors.computeIfAbsent(interfaceName, k -> new LinkedHashSet<>());
if (implementorsByKey.add(implementorName)) {
var itf = classes.get(implementorName);
if (itf != null) {
for (var parentItf : itf.getInterfaces()) {
addInterfaceImplementor(implementorName, parentItf);
}
}
}
}
}
Loading

0 comments on commit a84c5fc

Please sign in to comment.