Skip to content

Commit

Permalink
Lower RefLayout opcodes in MachLowering (#198)
Browse files Browse the repository at this point in the history
  • Loading branch information
btwj authored Sep 19, 2023
1 parent dd82b4c commit 7a9af29
Show file tree
Hide file tree
Showing 14 changed files with 416 additions and 76 deletions.
34 changes: 22 additions & 12 deletions aeneas/src/core/Eval.v3
Original file line number Diff line number Diff line change
Expand Up @@ -813,11 +813,11 @@ def evalOp(op: Operator, args: Arguments) -> Result {
}
RefLayoutGetField(offset) => {
var ref = args.ref(0);
return doRefLayoutGetField(args, ref, offset);
return doRefLayoutGetField(args, args.getTypeArg(1), ref, offset);
}
RefLayoutSetField(offset) => {
var ref = args.ref(0);
return doRefLayoutSetField(args, ref, offset, args.getArg(1));
return doRefLayoutSetField(args, args.getTypeArg(1), ref, offset, args.getArg(1));
}
RefLayoutAtRepeatedField(offset, scale, max) => {
var ref = args.ref(0);
Expand All @@ -831,13 +831,25 @@ def evalOp(op: Operator, args: Arguments) -> Result {
var ref = args.ref(0);
var index = args.i(1);
if (u32.view(index) >= u32.view(max)) return args.throw(V3Exception.BoundsCheck, null);
return doRefLayoutGetField(args, ref, offset + scale * index);
return doRefLayoutGetField(args, args.getTypeArg(1), ref, offset + scale * index);
}
RefLayoutSetRepeatedField(offset, scale, max) => {
var ref = args.ref(0);
var index = args.i(1);
if (u32.view(index) >= u32.view(max)) return args.throw(V3Exception.BoundsCheck, null);
return doRefLayoutSetField(args, ref, offset + scale * index, args.getArg(2));
return doRefLayoutSetField(args, args.getTypeArg(1), ref, offset + scale * index, args.getArg(2));
}
ByteArrayGetField(offset) => {
var array = args.r(0);
var i_offset = args.i(1);
// XXX: Refactor so no intermediate ByteArrayOffset object needed
return doRefLayoutGetField(args, args.getTypeArg(0), ByteArrayOffset.new(array, offset), i_offset);
}
ByteArraySetField(offset) => {
var array = args.r(0);
var i_offset = args.i(1);
// XXX: Refactor so no intermediate ByteArrayOffset object needed
return doRefLayoutSetField(args, args.getTypeArg(0), ByteArrayOffset.new(array, offset), i_offset, args.getArg(2));
}

//----------------------------------------------------------------------------
Expand All @@ -864,10 +876,9 @@ def evalOp(op: Operator, args: Arguments) -> Result {
return args.throw("EvalUnimplemented", op.opcode.name);
}

def doRefLayoutGetField(args: Arguments, ref: ByteArrayOffset, offset: int) -> Result {
def doRefLayoutGetField(args: Arguments, fieldType: Type, ref: ByteArrayOffset, offset: int) -> Result {
if (ref == null || ref.array == null) return args.throw(V3Exception.NullCheck, null);
var t = args.getTypeArg(1);
match (t) {
match (fieldType) {
x: IntType => {
var width = (x.width + 7) & ~7; // round up bitwidth
var byteSize = width >> 3;
Expand All @@ -891,14 +902,13 @@ def doRefLayoutGetField(args: Arguments, ref: ByteArrayOffset, offset: int) -> R
if (x.is64) return Float64Val.new(v);
else return Float32Val.new(u32.view(v));
}
_ => return args.throw("EvalException", Strings.format1("invalid RefLayoutField type %q", t.render));
_ => return args.throw("EvalException", Strings.format1("invalid RefLayoutField type %q", fieldType.render));
}
}
def doRefLayoutSetField(args: Arguments, ref: ByteArrayOffset, offset: int, val: Val) -> Result {
def doRefLayoutSetField(args: Arguments, fieldType: Type, ref: ByteArrayOffset, offset: int, val: Val) -> Result {
if (ref == null || ref.array == null) return args.throw(V3Exception.NullCheck, null);
var t = args.getTypeArg(1);
var size = 0, bits: u64 = 0;
match (t) {
match (fieldType) {
x: IntType => {
size = x.packedByteSize();
bits = if(x.width <= 32, u32.view(Int.unbox(val)), u64.view(Long.unboxSU(val, x.signed)));
Expand All @@ -913,7 +923,7 @@ def doRefLayoutSetField(args: Arguments, ref: ByteArrayOffset, offset: int, val:
else if (Float64Val.?(val)) bits = Float64Val.!(val).bits;
}
// TODO: enum types?
_ => return args.throw("EvalException", Strings.format1("invalid RefLayoutField type %q", t.render));
_ => return args.throw("EvalException", Strings.format1("invalid RefLayoutField type %q", fieldType.render));
}
ref.write(false, offset, size, bits);
return Values.BOTTOM;
Expand Down
2 changes: 2 additions & 0 deletions aeneas/src/core/Opcode.v3
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ type Opcode {
case RefLayoutAtRepeatedField(offset: int, scale: int, max: int);
case RefLayoutGetRepeatedField(offset: int, scale: int, max: int);
case RefLayoutSetRepeatedField(offset: int, scale: int, max: int);
case ByteArrayGetField(offset: int);
case ByteArraySetField(offset: int);
// System operations
case SystemCall(syscall: SystemCall);
// Container for VST operations
Expand Down
10 changes: 10 additions & 0 deletions aeneas/src/core/Operator.v3
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,14 @@ component V3Op {
var opcode = Opcode.RefLayoutSetRepeatedField(offset, scale, max);
return newOp0(opcode, [refType, fieldType], [refType, Int.TYPE, fieldType], Void.TYPE);
}
def newByteArrayGetField(offset: int, fieldType: Type) -> Operator {
var opcode = Opcode.ByteArrayGetField(offset);
return newOp0(opcode, [fieldType], [V3.arrayByteType, Int.TYPE], fieldType);
}
def newByteArraySetField(offset: int, fieldType: Type) -> Operator {
var opcode = Opcode.ByteArraySetField(offset);
return newOp0(opcode, [fieldType], [V3.arrayByteType, Int.TYPE, fieldType], Void.TYPE);
}
//----------------------------------------------------------------------------
def bestCallVirtual(spec: IrSpec) -> Operator {
if (spec.receiver.typeCon.kind == V3Kind.CLASS) {
Expand Down Expand Up @@ -594,6 +602,8 @@ def renderOp(op: Operator, buf: StringBuilder) -> StringBuilder {
RefLayoutIn(offset) => rfunc = StringBuilder.putd(_, offset);
RefLayoutGetField(offset) => rfunc = StringBuilder.putd(_, offset);
RefLayoutSetField(offset) => rfunc = StringBuilder.putd(_, offset);
ByteArrayGetField(offset) => rfunc = StringBuilder.putd(_, offset);
ByteArraySetField(offset) => rfunc = StringBuilder.putd(_, offset);
ConditionalThrow(exception) => rfunc = StringBuilder.puts(_, exception);
SystemCall(syscall) => rfunc = StringBuilder.puts(_, syscall.name);
VstSugar(op) => rfunc = StringBuilder.puts(_, op.name);
Expand Down
92 changes: 30 additions & 62 deletions aeneas/src/ir/SsaNormalizer.v3
Original file line number Diff line number Diff line change
Expand Up @@ -369,60 +369,47 @@ class SsaRaNormalizer extends SsaRebuilder {
var fn = normTypeArg(op, 1);
var newArgs = genRefs(app.inputs);
var array = newArgs[0], i_offset = newArgs[an.size];
var facts: Fact.set = Fact.O_NO_BOUNDS_CHECK;
var result: SsaInstr;

match (fn.oldType) {
x: IntType => {
var result = genArrayByteGetMultiple(app, x, array, i_offset, offset);
var wt = Int.getType(false, x.byteSize() * 8);
result = curBlock.opIntViewI0(wt, x, result);
return map1(app, result);
}
x: IntType => result = curBlock.opByteArrayGetField(x, offset, facts, array, i_offset);
x: FloatType => result = curBlock.opByteArrayGetField(x, offset, facts, array, i_offset);
x: EnumType => {
var it = IntType.!(x.enumDecl.tagType);
var result = genArrayByteGetMultiple(app, it, array, i_offset, offset);
var wt = Int.getType(false, it.byteSize() * 8);
result = curBlock.opByteArrayGetField(wt, offset, facts, array, i_offset);
var caseCount = newGraph.intConst(x.enumDecl.cases.length);
var inBound = curBlock.opIntULt(norm.config.ArrayLengthType, it, result, caseCount);
var wt = Int.getType(false, it.byteSize() * 8);
result = curBlock.opIntViewI0(wt, it, result);
result = curBlock.opIntViewI0(it, wt, result);
var split = SsaBlockSplit.new(context, curBlock);
curBlock = split.addIf(inBound);
curBlock = split.addElse();
curBlock = split.finish();
curBlock.opIntViewI0(wt, it, result);
result = split.addPhi(it, [result, newGraph.nullConst(it)]);
return map1(app, result);
}
x: FloatType => {
var it = Int.getType(false, x.total_width);
var result = genArrayByteGetMultiple(app, it, array, i_offset, offset);
var wt = Int.getType(false, it.byteSize() * 8);
result = curBlock.opIntViewI0(wt, it, result);
var floatViewOp = if(x.is64, V3Op.newFloat64ViewI(it), V3Op.newFloat32ViewI(it));
result = curBlock.addApply(app.source, floatViewOp, [result]);
return map1(app, result);
}
_ => context.fail1("unexpected type %q", fn.oldType.render);
}
return map1(app, result);
}
RefLayoutSetField(offset) => {
var fn = normTypeArg(op, 1);
var newArgs = genRefs(args);
var array = newArgs[0], i_offset = newArgs[1], val = newArgs[2];
match(fn.oldType) {
x: IntType => {
genArrayByteSetMultiple(app, x, array, i_offset, offset, val);
}
var facts: Fact.set = Fact.O_NO_BOUNDS_CHECK;
var result: SsaInstr;

match (fn.oldType) {
x: IntType => result = curBlock.opByteArraySetField(fn.oldType, offset, facts, array, i_offset, val);
x: FloatType => result = curBlock.opByteArraySetField(fn.oldType, offset, facts, array, i_offset, val);
x: EnumType => {
var it = IntType.!(x.enumDecl.tagType);
genArrayByteSetMultiple(app, it, array, i_offset, offset, val);
}
x: FloatType => {
var it = Int.getType(false, x.total_width);
var width = x.byteSize();
val = curBlock.opIntView(x, it, val);
genArrayByteSetMultiple(app, it, array, i_offset, offset, val);
var wt = Int.getType(false, it.byteSize() * 8);
result = curBlock.opByteArraySetField(wt, offset, facts, array, i_offset, val);
}
_ => context.fail1("unexpected type %q", fn.oldType.render);
}
return map1(app, result);
}
RefLayoutGetRepeatedField(offset, scale, max) => {
var an = arrayByteNorm;
Expand All @@ -436,38 +423,24 @@ class SsaRaNormalizer extends SsaRebuilder {
}
var result = SsaInstr.!(newGraph.intConst(0));
i_offset = curBlock.opIntAdd(curBlock.opIntMul(index, newGraph.intConst(scale)), i_offset);
var facts: Fact.set = Fact.O_NO_BOUNDS_CHECK;
match (fn.oldType) {
x: IntType => {
var result = genArrayByteGetMultiple(app, x, array, i_offset, offset);
var wt = Int.getType(false, x.byteSize() * 8);
result = curBlock.opIntViewI0(wt, x, result);
return map1(app, result);
}
x: IntType => result = curBlock.opByteArrayGetField(x, offset, facts, array, i_offset);
x: FloatType => result = curBlock.opByteArrayGetField(x, offset, facts, array, i_offset);
x: EnumType => {
var it = IntType.!(x.enumDecl.tagType);
var result = genArrayByteGetMultiple(app, it, array, i_offset, offset);
result = curBlock.opByteArrayGetField(it, offset, facts, array, i_offset);
var caseCount = newGraph.intConst(x.enumDecl.cases.length);
var inBound = curBlock.opIntULt(it, it, result, caseCount);
var split = SsaBlockSplit.new(context, curBlock);
curBlock = split.addIf(inBound);
curBlock = split.addElse();
curBlock = split.finish();
var wt = Int.getType(false, it.byteSize() * 8);
result = curBlock.opIntViewI0(wt, it, result);
result = split.addPhi(it, [result, newGraph.nullConst(it)]);
return map1(app, result);
}
x: FloatType => {
var it = Int.getType(false, x.total_width);
var result = genArrayByteGetMultiple(app, it, array, i_offset, offset);
var wt = Int.getType(false, it.byteSize() * 8);
result = curBlock.opIntViewI0(wt, it, result);
var floatViewOp = if(x.is64, V3Op.newFloat64ViewI(it), V3Op.newFloat32ViewI(it));
result = curBlock.addApply(app.source, floatViewOp, [result]);
return map1(app, result);
}
_ => ;
}
return map1(app, result);
}
RefLayoutSetRepeatedField(offset, scale, max) => {
var fn = normTypeArg(op, 1);
Expand All @@ -479,22 +452,17 @@ class SsaRaNormalizer extends SsaRebuilder {
curBlock.opConditionalThrow(V3Exception.BoundsCheck, oob);
}
i_offset = curBlock.opIntAdd(curBlock.opIntMul(index, newGraph.intConst(scale)), i_offset);
match(val.getType()) {
x: IntType => {
genArrayByteSetMultiple(app, x, array, i_offset, offset, val);
}
var facts: Fact.set = Fact.O_NO_BOUNDS_CHECK;
var result: SsaInstr;
match (fn.oldType) {
x: IntType => result = curBlock.opByteArraySetField(fn.oldType, offset, facts, array, i_offset, val);
x: FloatType => result = curBlock.opByteArraySetField(fn.oldType, offset, facts, array, i_offset, val);
x: EnumType => {
var it = IntType.!(x.enumDecl.tagType);
genArrayByteSetMultiple(app, it, array, i_offset, offset, val);
result = curBlock.opByteArraySetField(it, offset, facts, array, i_offset, val);
}
x: FloatType => {
var it = Int.getType(false, x.total_width);
var op = if(x.is64, V3Op.opIntViewF64, V3Op.opIntViewF32);
val = curBlock.pure(op, [val]);
genArrayByteSetMultiple(app, it, array, i_offset, offset, val);
}
_ => ;
}
return map1(app, result);
}
PtrAtContents => {
// rewrite to PtrAtArrayElem
Expand Down
12 changes: 12 additions & 0 deletions aeneas/src/jvm/JvmType.v3
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,16 @@ component JvmTypes {
def SIG_VOID_BYTE_ARRAY = JvmSig.new([], JvmTypes.BYTE_ARRAY);
def SIG_DOUBLE_DOUBLE_DOUBLE_LONG = JvmSig.new([DOUBLE, DOUBLE, DOUBLE], LONG);
def SIG_DOUBLE_DOUBLE_DOUBLE_BOOLEAN = JvmSig.new([DOUBLE, DOUBLE, DOUBLE], BOOLEAN);
def SIG_BYTE_ARRAY_INT_INT_BYTE = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, INT], BYTE);
def SIG_BYTE_ARRAY_INT_INT_SHORT = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, INT], SHORT);
def SIG_BYTE_ARRAY_INT_INT_INT = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, INT], INT);
def SIG_BYTE_ARRAY_INT_INT_LONG = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, INT], LONG);
def SIG_BYTE_ARRAY_INT_INT_FLOAT = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, INT], FLOAT);
def SIG_BYTE_ARRAY_INT_INT_DOUBLE = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, INT], DOUBLE);
def SIG_BYTE_ARRAY_INT_BYTE_INT_VOID = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, BYTE, INT], VOID);
def SIG_BYTE_ARRAY_INT_SHORT_INT_VOID = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, SHORT, INT], VOID);
def SIG_BYTE_ARRAY_INT_INT_INT_VOID = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, INT, INT], VOID);
def SIG_BYTE_ARRAY_INT_LONG_INT_VOID = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, LONG, INT], VOID);
def SIG_BYTE_ARRAY_INT_FLOAT_INT_VOID = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, FLOAT, INT], VOID);
def SIG_BYTE_ARRAY_INT_DOUBLE_INT_VOID = JvmSig.new([JvmTypes.BYTE_ARRAY, INT, DOUBLE, INT], VOID);
}
45 changes: 45 additions & 0 deletions aeneas/src/jvm/SsaJvmGen.v3
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,51 @@ class SsaJvmGen(jprog: JvmProgram, context: SsaContext, jsig: JvmSig, code: JvmC
emitThrow(exception);
code.patchBranch(b);
}
ByteArrayGetField(offset) => {
var fieldType = op.typeArgs[0];
code.iconst(offset);
match (fieldType) {
x: IntType => {
match (x.packedByteSize()) {
1 => jprog.invokesystem(code, "readBytes1", JvmTypes.SIG_BYTE_ARRAY_INT_INT_BYTE);
2 => jprog.invokesystem(code, "readBytes2", JvmTypes.SIG_BYTE_ARRAY_INT_INT_SHORT);
3 => jprog.invokesystem(code, "readBytes3", JvmTypes.SIG_BYTE_ARRAY_INT_INT_INT);
4 => jprog.invokesystem(code, "readBytes4", JvmTypes.SIG_BYTE_ARRAY_INT_INT_INT);
5 => jprog.invokesystem(code, "readBytes5", JvmTypes.SIG_BYTE_ARRAY_INT_INT_LONG);
6 => jprog.invokesystem(code, "readBytes6", JvmTypes.SIG_BYTE_ARRAY_INT_INT_LONG);
7 => jprog.invokesystem(code, "readBytes7", JvmTypes.SIG_BYTE_ARRAY_INT_INT_LONG);
8 => jprog.invokesystem(code, "readBytes8", JvmTypes.SIG_BYTE_ARRAY_INT_INT_LONG);
}
emitIntTrunc(x);
}
x: FloatType => {
if (x == Float.FLOAT32) jprog.invokesystem(code, "readBytesFloat", JvmTypes.SIG_BYTE_ARRAY_INT_INT_FLOAT);
else jprog.invokesystem(code, "readBytesDouble", JvmTypes.SIG_BYTE_ARRAY_INT_INT_DOUBLE);
}
}
}
ByteArraySetField(offset) => {
var fieldType = op.typeArgs[0];
code.iconst(offset);
match (fieldType) {
x: IntType => {
match (x.packedByteSize()) {
1 => jprog.invokesystem(code, "setBytes1", JvmTypes.SIG_BYTE_ARRAY_INT_BYTE_INT_VOID);
2 => jprog.invokesystem(code, "setBytes2", JvmTypes.SIG_BYTE_ARRAY_INT_SHORT_INT_VOID);
3 => jprog.invokesystem(code, "setBytes3", JvmTypes.SIG_BYTE_ARRAY_INT_INT_INT_VOID);
4 => jprog.invokesystem(code, "setBytes4", JvmTypes.SIG_BYTE_ARRAY_INT_INT_INT_VOID);
5 => jprog.invokesystem(code, "setBytes5", JvmTypes.SIG_BYTE_ARRAY_INT_LONG_INT_VOID);
6 => jprog.invokesystem(code, "setBytes6", JvmTypes.SIG_BYTE_ARRAY_INT_LONG_INT_VOID);
7 => jprog.invokesystem(code, "setBytes7", JvmTypes.SIG_BYTE_ARRAY_INT_LONG_INT_VOID);
8 => jprog.invokesystem(code, "setBytes8", JvmTypes.SIG_BYTE_ARRAY_INT_LONG_INT_VOID);
}
}
x: FloatType => {
if (x == Float.FLOAT32) jprog.invokesystem(code, "setBytesFloat", JvmTypes.SIG_BYTE_ARRAY_INT_FLOAT_INT_VOID);
else jprog.invokesystem(code, "setBytesDouble", JvmTypes.SIG_BYTE_ARRAY_INT_DOUBLE_INT_VOID);
}
}
}
} else {
context.fail1("unexpected opcode in SSA->JVM: %s", op.opcode.name);
}
Expand Down
Loading

0 comments on commit 7a9af29

Please sign in to comment.