Skip to content

Commit

Permalink
wasm gc: fix issues with virtual calls
Browse files Browse the repository at this point in the history
  • Loading branch information
konsoletyper committed Aug 16, 2024
1 parent 2805631 commit 40fbce0
Show file tree
Hide file tree
Showing 16 changed files with 245 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ private void contributeExceptionUtils() {
analyzer.linkMethod(new MethodReference(WasmGCSupport.class, "aiiobe", ArrayIndexOutOfBoundsException.class))
.use();
analyzer.linkMethod(new MethodReference(WasmGCSupport.class, "cce", ClassCastException.class)).use();
analyzer.linkMethod(new MethodReference(WasmGCSupport.class, "cnse", CloneNotSupportedException.class)).use();
}

private void contributeInitializerUtils() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import org.teavm.backend.wasm.model.expression.WasmIntBinary;
import org.teavm.backend.wasm.model.expression.WasmIntBinaryOperation;
import org.teavm.backend.wasm.model.expression.WasmIntType;
import org.teavm.backend.wasm.model.expression.WasmIntUnary;
import org.teavm.backend.wasm.model.expression.WasmIntUnaryOperation;
import org.teavm.backend.wasm.model.expression.WasmLoadFloat32;
import org.teavm.backend.wasm.model.expression.WasmLoadFloat64;
import org.teavm.backend.wasm.model.expression.WasmLoadInt32;
Expand Down Expand Up @@ -302,6 +304,11 @@ protected WasmExpression nullLiteral(Expr expr) {
return new WasmInt32Constant(0);
}

@Override
protected WasmExpression genIsNull(WasmExpression value) {
return new WasmIntUnary(WasmIntType.INT32, WasmIntUnaryOperation.EQZ, value);
}

@Override
public void visit(SubscriptExpr expr) {
WasmExpression ptr = getArrayElementPointer(expr);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@
import org.teavm.backend.wasm.model.expression.WasmIntBinary;
import org.teavm.backend.wasm.model.expression.WasmIntBinaryOperation;
import org.teavm.backend.wasm.model.expression.WasmIntType;
import org.teavm.backend.wasm.model.expression.WasmIntUnary;
import org.teavm.backend.wasm.model.expression.WasmIntUnaryOperation;
import org.teavm.backend.wasm.model.expression.WasmReturn;
import org.teavm.backend.wasm.model.expression.WasmSetLocal;
import org.teavm.backend.wasm.model.expression.WasmSwitch;
Expand Down Expand Up @@ -902,9 +900,16 @@ protected WasmExpression generateInvocation(InvocationExpr expr, CallSiteIdentif
var block = new WasmBlock(false);
block.setType(mapType(reference.getReturnType()));

var instanceVar = tempVars.acquire(instanceWasmType);
block.getBody().add(new WasmSetLocal(instanceVar, instance));
instance = new WasmGetLocal(instanceVar);
WasmLocal instanceVar;
var isTemporary = false;
if (instance instanceof WasmGetLocal) {
instanceVar = ((WasmGetLocal) instance).getLocal();
} else {
instanceVar = tempVars.acquire(instanceWasmType);
block.getBody().add(new WasmSetLocal(instanceVar, instance));
instance = new WasmGetLocal(instanceVar);
isTemporary = true;
}

var arguments = new ArrayList<WasmExpression>();
arguments.add(instance);
Expand All @@ -918,7 +923,9 @@ protected WasmExpression generateInvocation(InvocationExpr expr, CallSiteIdentif
var call = generateVirtualCall(instanceVar, reference, arguments);

block.getBody().add(call);
tempVars.release(instanceVar);
if (isTemporary) {
tempVars.release(instanceVar);
}
return block;
}
}
Expand Down Expand Up @@ -1133,7 +1140,7 @@ public void visit(InstanceOfExpr expr) {
result.acceptVisitor(typeInference);
var cachedObject = exprCache.create(result, typeInference.getResult(), expr.getLocation(), block.getBody());

var ifNull = new WasmBranch(genIsZero(cachedObject.expr()), block);
var ifNull = new WasmBranch(genIsNull(cachedObject.expr()), block);
ifNull.setResult(new WasmInt32Constant(0));
block.getBody().add(new WasmDrop(ifNull));

Expand Down Expand Up @@ -1171,7 +1178,7 @@ public void visit(CastExpr expr) {
var wasmSourceType = typeInference.getResult();
var valueToCast = exprCache.create(result, wasmSourceType, expr.getLocation(), block.getBody());

var nullCheck = new WasmBranch(genIsZero(valueToCast.expr()), block);
var nullCheck = new WasmBranch(genIsNull(valueToCast.expr()), block);
nullCheck.setResult(valueToCast.expr());
block.getBody().add(new WasmDrop(nullCheck));

Expand Down Expand Up @@ -1544,9 +1551,7 @@ private static WasmFloatBinaryOperation negate(WasmFloatBinaryOperation op) {
}
}

private WasmExpression genIsZero(WasmExpression value) {
return new WasmIntUnary(WasmIntType.INT32, WasmIntUnaryOperation.EQZ, value);
}
protected abstract WasmExpression genIsNull(WasmExpression value);

protected abstract WasmType mapType(ValueType type);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.teavm.backend.wasm.model.expression.WasmIntBinaryOperation;
import org.teavm.backend.wasm.model.expression.WasmIntType;
import org.teavm.backend.wasm.model.expression.WasmNullConstant;
import org.teavm.backend.wasm.model.expression.WasmReturn;
import org.teavm.backend.wasm.model.expression.WasmSetGlobal;
import org.teavm.backend.wasm.model.expression.WasmStructNewDefault;
import org.teavm.backend.wasm.model.expression.WasmStructSet;
Expand Down Expand Up @@ -122,7 +123,7 @@ public WasmGCClassGenerator(WasmModule module, ClassReaderSource classSource,
standardClasses = new WasmGCStandardClasses(this);
strings = new WasmGCStringPool(standardClasses, module, functionProvider);
supertypeGenerator = new WasmGCSupertypeFunctionGenerator(module, this, names, tagRegistry, functionTypes);
typeMapper = new WasmGCTypeMapper(this);
typeMapper = new WasmGCTypeMapper(this, functionTypes, module);
}

public WasmGCSupertypeFunctionProvider getSupertypeProvider() {
Expand All @@ -143,6 +144,7 @@ public boolean process() {

@Override
public void contributeToInitializerDefinitions(WasmFunction function) {
fillVirtualTableSupertypes();
for (var classInfo : classInfoMap.values()) {
var classInstanceType = classInfo.virtualTableStructure != null
? classInfo.virtualTableStructure
Expand All @@ -152,6 +154,35 @@ public void contributeToInitializerDefinitions(WasmFunction function) {
}
}

private void fillVirtualTableSupertypes() {
for (var classInfo : classInfoMap.values()) {
if (classInfo.virtualTableStructure != null && classInfo.getValueType() instanceof ValueType.Object
&& classInfo.hasOwnVirtualTable) {
var className = ((ValueType.Object) classInfo.getValueType()).getClassName();
classInfo.virtualTableStructure.setSupertype(findVirtualTableSupertype(className));
}
}
}

private WasmStructure findVirtualTableSupertype(String className) {
while (className != null) {
var cls = classSource.get(className);
if (cls == null) {
break;
}
className = cls.getParent();
if (className == null) {
break;
}
var parentInfo = classInfoMap.get(ValueType.object(className));
if (parentInfo != null && parentInfo.virtualTableStructure != null) {
return parentInfo.virtualTableStructure;
}
}
var classClass = classInfoMap.get(ValueType.object("java.lang.Class"));
return classClass != null ? classClass.structure : null;
}

@Override
public void contributeToInitializer(WasmFunction function) {
var classClass = standardClasses.classClass();
Expand Down Expand Up @@ -209,7 +240,8 @@ public WasmGCClassInfo getClassInfo(ValueType type) {
fillFields(classInfo, type);
}
var pointerName = names.forClassInstance(type);
var classStructure = virtualTable != null && virtualTable.hasValidEntries()
classInfo.hasOwnVirtualTable = virtualTable != null && virtualTable.hasValidEntries();
var classStructure = classInfo.hasOwnVirtualTable
? initRegularClassStructure(((ValueType.Object) type).getClassName())
: standardClasses.classClass().getStructure();
classInfo.virtualTableStructure = classStructure;
Expand Down Expand Up @@ -330,21 +362,30 @@ private void initRegularClass(WasmGCClassInfo classInfo, VirtualTable virtualTab
}
}
if (virtualTable != null && virtualTable.hasValidEntries()) {
fillVirtualTableMethods(target, classStructure, classInfo.pointer, virtualTable,
virtualTableFieldOffset, name, new HashSet<>());
fillVirtualTableMethods(target, classStructure, classInfo.pointer, virtualTable, virtualTable,
new HashSet<>());
}
};
}

private int fillVirtualTableMethods(List<WasmExpression> target, WasmStructure structure, WasmGlobal global,
VirtualTable virtualTable, int index, String origin, Set<MethodDescriptor> filled) {
private void fillVirtualTableMethods(List<WasmExpression> target, WasmStructure structure, WasmGlobal global,
VirtualTable virtualTable, VirtualTable original, Set<MethodDescriptor> filled) {
if (virtualTable.getParent() != null) {
fillVirtualTableMethods(target, structure, global, virtualTable.getParent(), original, filled);
}
for (var method : virtualTable.getMethods()) {
var entry = virtualTable.getEntry(method);
var entry = original.getEntry(method);
if (entry != null && entry.getImplementor() != null && filled.add(method)
&& !method.equals(GET_CLASS_METHOD)) {
var fieldIndex = virtualTableFieldOffset + entry.getIndex();
var expectedType = (WasmType.CompositeReference) structure.getFields().get(fieldIndex)
.asUnpackedType();
var expectedFunctionType = (WasmFunctionType) expectedType.composite;
var function = functionProvider.forInstanceMethod(entry.getImplementor());
if (!origin.equals(entry.getImplementor().getClassName())) {
var functionType = getFunctionType(virtualTable.getClassName(), method);
if (!virtualTable.getClassName().equals(entry.getImplementor().getClassName())
|| expectedFunctionType != function.getType()) {
var functionType = typeMapper.getFunctionType(virtualTable.getClassName(), method, true);
functionType.getSupertypes().add(expectedFunctionType);
var wrapperFunction = new WasmFunction(functionType);
module.functions.add(wrapperFunction);
var call = new WasmCall(function);
Expand All @@ -358,17 +399,14 @@ private int fillVirtualTableMethods(List<WasmExpression> target, WasmStructure s
call.getArguments().add(new WasmGetLocal(params[i]));
wrapperFunction.add(params[i]);
}
wrapperFunction.getBody().add(new WasmReturn(call));
function = wrapperFunction;
}
function.setReferenced(true);
var ref = new WasmFunctionReference(function);
target.add(new WasmStructSet(structure, new WasmGetGlobal(global), index + entry.getIndex(), ref));
target.add(new WasmStructSet(structure, new WasmGetGlobal(global), fieldIndex, ref));
}
}
if (virtualTable.getParent() != null) {
index = fillVirtualTableMethods(target, structure, global, virtualTable.getParent(), index, origin,
filled);
}
return index;
}

private WasmStructure initRegularClassStructure(String className) {
Expand All @@ -391,23 +429,12 @@ private void addVirtualTableFields(WasmStructure structure, VirtualTable virtual
if (methodDesc == null) {
structure.getFields().add(WasmType.Reference.FUNC.asStorage());
} else {
var functionType = getFunctionType(virtualTable.getClassName(), methodDesc);
var functionType = typeMapper.getFunctionType(virtualTable.getClassName(), methodDesc, false);
structure.getFields().add(functionType.getReference().asStorage());
}
}
}

private WasmFunctionType getFunctionType(String className, MethodDescriptor methodDesc) {
var returnType = typeMapper.mapType(methodDesc.getResultType());
var javaParamTypes = methodDesc.getParameterTypes();
var paramTypes = new WasmType[javaParamTypes.length + 1];
paramTypes[0] = getClassInfo(className).getType();
for (var i = 0; i < javaParamTypes.length; ++i) {
paramTypes[i + 1] = typeMapper.mapType(javaParamTypes[i]);
}
return functionTypes.of(returnType, paramTypes);
}

private void initArrayClass(WasmGCClassInfo classInfo, ValueType.Array type) {
classInfo.initializer = target -> {
var itemTypeInfo = getClassInfo(type.getItemType());
Expand Down Expand Up @@ -497,6 +524,9 @@ private void fillSimpleClassFields(List<WasmStorageType> fields, String classNam
if (className.equals("java.lang.Object") && field.getName().equals("monitor")) {
continue;
}
if (className.equals("java.lang.Class") && field.getName().equals("platformClass")) {
continue;
}
if (field.hasModifier(ElementModifier.STATIC)) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public class WasmGCClassInfo {
private ValueType valueType;
WasmStructure structure;
WasmArray array;
boolean hasOwnVirtualTable;
WasmStructure virtualTableStructure;
WasmGlobal pointer;
WasmGlobal initializerPointer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,26 @@
*/
package org.teavm.backend.wasm.generate.gc.classes;

import java.util.List;
import org.teavm.backend.wasm.WasmFunctionTypes;
import org.teavm.backend.wasm.model.WasmFunctionType;
import org.teavm.backend.wasm.model.WasmModule;
import org.teavm.backend.wasm.model.WasmPackedType;
import org.teavm.backend.wasm.model.WasmStorageType;
import org.teavm.backend.wasm.model.WasmType;
import org.teavm.model.MethodDescriptor;
import org.teavm.model.ValueType;

public class WasmGCTypeMapper {
private WasmGCClassInfoProvider classInfoProvider;
private WasmFunctionTypes functionTypes;
private WasmModule module;

WasmGCTypeMapper(WasmGCClassInfoProvider classInfoProvider) {
WasmGCTypeMapper(WasmGCClassInfoProvider classInfoProvider, WasmFunctionTypes functionTypes,
WasmModule module) {
this.classInfoProvider = classInfoProvider;
this.functionTypes = functionTypes;
this.module = module;
}

public WasmStorageType mapStorageType(ValueType type) {
Expand Down Expand Up @@ -76,4 +86,21 @@ public WasmType mapType(ValueType type) {
return classInfoProvider.getClassInfo(type).getType();
}
}

public WasmFunctionType getFunctionType(String className, MethodDescriptor methodDesc, boolean fresh) {
var returnType = mapType(methodDesc.getResultType());
var javaParamTypes = methodDesc.getParameterTypes();
var paramTypes = new WasmType[javaParamTypes.length + 1];
paramTypes[0] = classInfoProvider.getClassInfo(className).getType();
for (var i = 0; i < javaParamTypes.length; ++i) {
paramTypes[i + 1] = mapType(javaParamTypes[i]);
}
if (fresh) {
var type = new WasmFunctionType(null, returnType, List.of(paramTypes));
module.types.add(type);
return type;
} else {
return functionTypes.of(returnType, paramTypes);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ protected WasmExpression nullLiteral(Expr expr) {
: WasmType.Reference.STRUCT);
}

@Override
protected WasmExpression genIsNull(WasmExpression value) {
return new WasmReferencesEqual(value, new WasmNullConstant(WasmType.Reference.STRUCT));
}

@Override
protected CallSiteIdentifier generateCallSiteId(TextLocation location) {
return new SimpleCallSite();
Expand Down Expand Up @@ -268,12 +273,19 @@ protected WasmExpression generateVirtualCall(WasmLocal instance, MethodReference
}
var instanceStruct = context.classInfoProvider().getClassInfo(vtable.getClassName()).getStructure();

var actualInstanceType = (WasmType.CompositeReference) instance.getType();
var actualInstanceStruct = (WasmStructure) actualInstanceType.composite;
var actualVtableType = (WasmType.CompositeReference) actualInstanceStruct.getFields().get(0).asUnpackedType();
var actualVtableStruct = (WasmStructure) actualVtableType.composite;

WasmExpression classRef = new WasmStructGet(instanceStruct, new WasmGetLocal(instance),
WasmGCClassInfoProvider.CLASS_FIELD_OFFSET);
var index = context.classInfoProvider().getVirtualMethodsOffset() + vtableIndex;
var vtableStruct = context.classInfoProvider().getClassInfo(vtable.getClassName())
.getVirtualTableStructure();
classRef = new WasmCast(classRef, vtableStruct.getReference());
if (!vtableStruct.isSupertypeOf(actualVtableStruct)) {
classRef = new WasmCast(classRef, vtableStruct.getReference());
}

var functionRef = new WasmStructGet(vtableStruct, classRef, index);
var functionTypeRef = (WasmType.CompositeReference) vtableStruct.getFields().get(index).asUnpackedType();
Expand Down Expand Up @@ -432,7 +444,6 @@ public void visit(InvocationExpr expr) {
result = invocation(expr, null, false);
}


@Override
protected WasmExpression invocation(InvocationExpr expr, List<WasmExpression> resultConsumer, boolean willDrop) {
if (expr.getType() == InvocationType.SPECIAL || expr.getType() == InvocationType.STATIC) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public void contributeToInitializer(WasmFunction function) {
var value = new WasmCall(nextCharArrayFunction);
function.getBody().add(new WasmStructSet(stringStruct, new WasmGetGlobal(str.global),
WasmGCClassInfoProvider.CUSTOM_FIELD_OFFSETS, value));
function.getBody().add(new WasmStructSet(stringStruct, new WasmGetGlobal(str.global),
WasmGCClassInfoProvider.CLASS_FIELD_OFFSET,
new WasmGetGlobal(standardClasses.stringClass().getPointer())));
}
}

Expand Down
Loading

0 comments on commit 40fbce0

Please sign in to comment.