Skip to content

Commit

Permalink
[packing] Simple packing implementation (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
btwj authored Apr 16, 2024
1 parent 80edfda commit b9da4e0
Show file tree
Hide file tree
Showing 28 changed files with 687 additions and 172 deletions.
6 changes: 6 additions & 0 deletions aeneas/src/core/Opcode.v3
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ type Opcode {
case FloatRoundD;
// Reference equality
case RefEq;
// IntRep operations
case IntRepCreate;
case IntRepView;
// Tuple operations
case TupleCreate(length: int);
case TupleGetElem(index: int);
Expand Down Expand Up @@ -241,6 +244,9 @@ component Opcodes {

t[Opcode.RefEq.tag] = P |C;

t[Opcode.IntRepCreate.tag] = P;
t[Opcode.IntRepView.tag] = P;

t[Opcode.TupleCreate.tag] = P;
t[Opcode.TupleGetElem.tag] = P;

Expand Down
7 changes: 7 additions & 0 deletions aeneas/src/core/Operator.v3
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,13 @@ component V3Op {
def newRefEq(t: Type) -> Operator {
return newOp0(Opcode.RefEq, [t], [t, t], type_z);
}
//----------------------------------------------------------------------------
def newIntRepCreate(ft: Type, tt: IntRepType) -> Operator {
return newOp0(Opcode.IntRepCreate, [ft, tt], [ft], tt);
}
def newIntRepView(ft: IntRepType, tt: Type) -> Operator {
return newOp0(Opcode.IntRepView, [ft, tt], [ft], tt);
}
//----------------------------------------------------------------------------
def newTupleCreate(tupleType: Type) -> Operator {
var paramTypes = Lists.toArray(tupleType.nested);
Expand Down
2 changes: 2 additions & 0 deletions aeneas/src/ir/Facts.v3
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ component Facts {
def O_SAFE_DIVIDE = Fact.O_NO_ZERO_CHECK | Fact.O_PURE;
// alias for NO_NULL_CHECK
def O_NO_NAN_CHECK = Fact.O_NO_NULL_CHECK;
// facts for a safe shift
def O_SAFE_SHIFT = Fact.O_NO_SHIFT_CHECK | Fact.O_PURE;

def isLive(ic: IrClass) -> bool {
return (ic.facts & (Fact.C_ALLOCATED | Fact.C_HEAP)) != NONE;
Expand Down
1 change: 1 addition & 0 deletions aeneas/src/ir/Ir.v3
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class IrClass extends IrItem {
var maxClassId: int;
var machSize: int = -1;
var boxing: Boxing;
var packed: bool;

new(ctype, typeArgs, parent, fields, methods) { }
def inherits(m: IrMember) -> bool {
Expand Down
13 changes: 12 additions & 1 deletion aeneas/src/ir/Normalization.v3
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class NormalizerConfig {
var NormalizeRange: bool;

var GetScalar: (Compiler, Program, Type) -> Scalar.set = defaultGetScalar;
var GetBitWidth: (Compiler, Program, Type) -> byte = defaultGetBitWidth;
var MaxScalarWidth: byte = 64;

def setSignatureLimits(maxp: int, maxr: int) {
if (maxp < MaxParams) MaxParams = maxp;
Expand All @@ -24,11 +26,20 @@ class NormalizerConfig {
}
def defaultGetScalar(compiler: Compiler, prog: Program, t: Type) -> Scalar.set {
match (t) {
x: IntType => return if(x.width <= 32, Scalar.B32, Scalar.B64); // XXX: Scalar.R32, once packed refs
x: IntType => return if(x.width <= 32, Scalar.B32 | Scalar.B64, Scalar.B64); // XXX: Scalar.R32, once packed refs
x: FloatType => return if(x.is64, Scalar.F64, Scalar.F32);
_ => return Scalar.Ref;
}
}
def defaultGetBitWidth(compiler: Compiler, prog: Program, t: Type) -> byte {
var target = compiler.target;
match (t) {
x: IntType => return x.width;
x: FloatType => return x.total_width;
x: BoolType => return 1;
_ => return 64;
}
}

// Normalizes a program based on the results of reachability analysis.
def TRANSFERRABLE_FACTS = (Fact.M_ABSTRACT | Fact.M_INLINE | Fact.M_OPERATOR | Fact.M_NEW | Fact.M_EMPTY | Fact.M_EQUALS);
Expand Down
38 changes: 19 additions & 19 deletions aeneas/src/ir/PackingSolver.v3
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Interval(start: byte, end: byte) #unboxed {
}

// A fixed-length mutable pattern of bits.
class PackingPattern(bits: Array<PackingBit>) {
class ScalarPattern(bits: Array<PackingBit>) {
def size = bits.length;

def render(buf: StringBuilder) -> StringBuilder {
Expand Down Expand Up @@ -80,7 +80,7 @@ class PackingPattern(bits: Array<PackingBit>) {
def unassignInterval(i: Interval) -> this {
for (j = i.start; j < i.end; j++) bits[j] = PackingBit.Unassigned;
}
def copy() -> PackingPattern { return PackingPattern.new(Arrays.dup(bits)); }
def copy() -> ScalarPattern { return ScalarPattern.new(Arrays.dup(bits)); }
}

type PackingField #unboxed {
Expand All @@ -106,7 +106,7 @@ class PackingSolver(size: byte, refPatterns: RefPatterns) {
fieldOrder = null;
solution = null;
}
def canDistinguish(state: Array<PackingPattern>, elements: Array<bool>) -> bool {
def canDistinguish(state: Array<ScalarPattern>, elements: Array<bool>) -> bool {
var numElements = 0;
for (i in elements) if (i) numElements++;
if (numElements <= 1) return true;
Expand Down Expand Up @@ -162,12 +162,12 @@ class PackingSolver(size: byte, refPatterns: RefPatterns) {
}
return false;
}
def checkDistinguishable(state: Array<PackingPattern>) -> bool {
def checkDistinguishable(state: Array<ScalarPattern>) -> bool {
if (tryExplicitTaggingHeuristic(state)) return true;
return tryAssignmentHeuristic(state);
}
// Run the backtracking solver algorithm.
def solve(idx: int, state: Array<PackingPattern>) -> bool {
def solve(idx: int, state: Array<ScalarPattern>) -> bool {
if (idx == fieldOrder.length) return checkDistinguishable(state);

var CaseField = fieldOrder[idx];
Expand Down Expand Up @@ -202,12 +202,12 @@ class PackingSolver(size: byte, refPatterns: RefPatterns) {
}
}

var patterns = Array<PackingPattern>.new(cases.length);
var patterns = Array<ScalarPattern>.new(cases.length);
var assignments = HashMap<CaseField, Interval>.new(CaseField.hash, CaseField.==);
if (isRefScalar) {
for (i < cases.length) {
var c = cases[i];
var casePacking: PackingPattern;
var casePacking: ScalarPattern;

var containsRef = false;
for (j < cases[i].length) {
Expand Down Expand Up @@ -269,7 +269,7 @@ class PackingSolver(size: byte, refPatterns: RefPatterns) {
while (1 << i < numCases) i++;
return i;
}
private def tryExplicitTaggingHeuristic(state: Array<PackingPattern>) -> bool {
private def tryExplicitTaggingHeuristic(state: Array<ScalarPattern>) -> bool {
// if there are enough contiguous aligned ?s, we can just use them to tag
var longest: Interval = EMPTY_INTERVAL, curStart: byte = 0;
for (i < size) {
Expand Down Expand Up @@ -300,7 +300,7 @@ class PackingSolver(size: byte, refPatterns: RefPatterns) {
}
return false;
}
private def tryAssignmentHeuristic(state: Array<PackingPattern>) -> bool {
private def tryAssignmentHeuristic(state: Array<ScalarPattern>) -> bool {
// difficult case: we have to build the decision tree and check
var elements = Array<bool>.new(state.length);
for (i < elements.length) elements[i] = true;
Expand Down Expand Up @@ -332,7 +332,7 @@ type PackingProblem(cases: Array<Array<PackingField>>, assignments: Array<(CaseF
}

class PackingSolution(
patterns: Array<PackingPattern>,
patterns: Array<ScalarPattern>,
assignments: HashMap<CaseField, Interval>,
isRef: bool,
problem: PackingProblem) {
Expand Down Expand Up @@ -364,31 +364,31 @@ class PackingSolution(

// Represents a collection of patterns related to references and non-references.
class RefPatterns(
ptrref: PackingPattern,
ptrref: ScalarPattern,
refInterval: Interval,
nonptrref: PackingPattern,
nonref: PackingPattern,
nullref: PackingPattern) {}
nonptrref: ScalarPattern,
nonref: ScalarPattern,
nullref: ScalarPattern) {}

def EMPTY_INTERVAL = Interval(0, 0);

component PackingPatterns {
component ScalarPatterns {
def TAGGED_PTR_64 = RefPatterns.new(
parse("????_????_????_????_...._...._...._...._...._...._...._...._...._...._...._.??0"),
Interval(3, 48),
parse("????_????_????_????_????_????_????_????_????_????_????_????_????_????_????_???1"),
PackingPattern.new(Array<PackingBit>.new(64)),
ScalarPattern.new(Array<PackingBit>.new(64)),
parse("0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_????_????_????") // XXX: bump to 1MiB?
);
def TAGGED_PTR_32 = RefPatterns.new(
parse("...._...._...._...._...._...._...._.??0"),
Interval(3, 32),
parse("????_????_????_????_????_????_????_???1"),
PackingPattern.new(Array<PackingBit>.new(32)),
ScalarPattern.new(Array<PackingBit>.new(32)),
parse("0000_0000_0000_0000_0000_????_????_????") // XXX: bump to 64KiB?
);

def parse(s: string) -> PackingPattern {
def parse(s: string) -> ScalarPattern {
var bits = Vector<PackingBit>.new();
var min = 0;
if (s.length >= 2 && s[0] == '0' && (s[1] == 'b' || s[1] == 'B')) min = 2; // skip 0b prefix if present
Expand All @@ -404,6 +404,6 @@ component PackingPatterns {
_ => bits.put(PackingBit.Assigned(c));
}
}
return PackingPattern.new(bits.extract());
return ScalarPattern.new(bits.extract());
}
}
58 changes: 46 additions & 12 deletions aeneas/src/ir/SsaNormalizer.v3
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ class SsaRaNormalizer extends SsaRebuilder {
var read = curBlock.opArrayGetElem(V3.arrayByteType, indexType, facts,
array, curBlock.opIntAdd(newGraph.intConst(i + offset), i_offset));
read = curBlock.opIntViewI0(Byte.TYPE, wt, read);
var shifted = if(i == 0, read, curBlock.addApplyF(wt.opShl(), [read, newGraph.intConst(8 * i)], Fact.O_PURE | Fact.O_NO_SHIFT_CHECK));
var shifted = if(i == 0, read, curBlock.addApplyF(wt.opShl(), [read, newGraph.intConst(8 * i)], Facts.O_SAFE_SHIFT));
result = curBlock.addApplyF(wt.opOr(), [result, shifted], Fact.O_PURE);
facts |= Fact.O_NO_NULL_CHECK;
}
Expand All @@ -603,7 +603,7 @@ class SsaRaNormalizer extends SsaRebuilder {
var facts: Fact.set = Fact.O_NO_BOUNDS_CHECK;
var indexType = norm.config.RangeStartType;
for (i < size) {
var shifted = curBlock.addApplyF(it.opShr(), [val, newGraph.intConst(8 * i)], Fact.O_PURE | Fact.O_NO_SHIFT_CHECK);
var shifted = curBlock.addApplyF(it.opShr(), [val, newGraph.intConst(8 * i)], Facts.O_SAFE_SHIFT);
var b = curBlock.opIntViewI0(it, Byte.TYPE, shifted);
var write = curBlock.opArraySetElem(array.getType(), indexType, facts,
array, curBlock.opIntAdd(newGraph.intConst(i + offset), i_offset), b);
Expand Down Expand Up @@ -766,7 +766,7 @@ class SsaRaNormalizer extends SsaRebuilder {
}
// Normalize equality between two unboxed variants of unknown case (e.g. x: T == y: T)
def normUnboxedVariantEqual(i_old: SsaApplyOp, vn: VariantNorm, refs: Array<SsaInstr>) -> SsaInstr {
if (vn.isTagless()) return normEqual(i_old, vn, refs); // we can always do a scalar-by-scalar comparison if no boxes are involved.
if (vn.hasNoTag()) return normEqual(i_old, vn, refs); // we can always do a scalar-by-scalar comparison if no boxes are involved.
var ct = ClassType.!(vn.oldType);
if (ct.superType != null) return normUnboxedVariantCaseEqual(i_old, vn, refs);

Expand Down Expand Up @@ -795,7 +795,11 @@ class SsaRaNormalizer extends SsaRebuilder {
}
def normVariantGetTag(vn: VariantNorm, args: Range<SsaInstr>) -> SsaInstr {
if (vn == null) return null;
if (vn.isTagless()) return newGraph.zeroConst();
if (vn.hasNoTag()) return newGraph.zeroConst();
if (vn.hasExplicitTag() && vn.tag.isPacked()) {
var tag = vn.tag, tagIdx = tag.indexes[0], tagInterval = tag.intervals[0];
return genExtractInterval(args[tagIdx], tagInterval, IntRepType.!(vn.at(tagIdx)), tag.tn.at(0));
}
return args[vn.tagIndex()];
}
def normTupleGetElem(i_old: SsaInstr, args: Array<SsaDfEdge>, op: Operator, index: int) {
Expand Down Expand Up @@ -1292,12 +1296,7 @@ class SsaRaNormalizer extends SsaRebuilder {
normType(raField.receiver); // XXX: normType() side-effect of flattening
if (isVariant && raField != null && rc.isUnboxed()) {
// field of unboxed data type
var vals = Array<SsaInstr>.new(nf.length);
var field = rc.variantNorm.fields[raField.orig.index];
for (i < vals.length) {
var idx = field.indexes[i];
vals[i] = genVariantScalarView(rc.variantNorm.at(idx), field.tn.at(i), ai_new[idx]);
}
var vals = genVariantGetField(rc, raField, rc.variantNorm, ai_new);
return mapNnf(i_old, vals);
}
var receiver = ai_new[0];
Expand Down Expand Up @@ -1371,18 +1370,53 @@ class SsaRaNormalizer extends SsaRebuilder {

var result = Array<SsaInstr>.new(vn.size);
for (i < result.length) result[i] = newGraph.nullConst(vn.at(i));
if (!vn.isTagless()) result[vn.tagIndex()] = newGraph.intConst(vn.tagValue);

if (vn.hasExplicitTag()) {
if (vn.tag.isPacked()) {
var tagIdx = vn.tag.indexes[0];
result[tagIdx] = genSetInterval(result[tagIdx], newGraph.intConst(vn.tagValue), vn.tag.intervals[0], vn.tag.tn.newType, IntRepType.!(vn.at(tagIdx)));
} else {
result[vn.tagIndex()] = newGraph.intConst(vn.tagValue);
}
}

for (i < vn.fields.length) {
var f = vn.fields[i];
var fieldRanges = vn.fieldRanges[i], os = fieldRanges.0;
for (j < f.indexes.length) {
var idx = f.indexes[j];
result[idx] = genVariantScalarView(f.tn.at(j), vn.at(idx), ai_inputs[os + j]);
if (!f.isPacked()) result[idx] = genVariantScalarView(f.tn.at(j), vn.at(idx), ai_inputs[os + j]);
else result[idx] = genSetInterval(result[idx], ai_inputs[os + j], f.intervals[j], f.tn.at(j), IntRepType.!(vn.at(idx)));
}
}
return result;
}
def genExtractInterval(scalar: SsaInstr, interval: Interval, ft: IntRepType, tt: Type) -> SsaInstr {
if (interval.start > 0) scalar = curBlock.addApplyF(ft.opShr(), [scalar, newGraph.intConst(interval.start)], Facts.O_SAFE_SHIFT);
scalar = genVariantScalarView(ft, tt, scalar);
return scalar;
}
def genSetInterval(scalar: SsaInstr, value: SsaInstr, interval: Interval, ft: Type, tt: IntRepType) -> SsaInstr {
var intRep = genVariantScalarView(ft, tt, value);
intRep = curBlock.addApplyF(tt.opShl(), [intRep, newGraph.intConst(interval.start)], Facts.O_SAFE_SHIFT);
return curBlock.pure(tt.opOr(), [scalar, intRep]);
}
def genVariantGetField(rc: RaClass, raField: RaField, vn: VariantNorm, ninputs: Array<SsaInstr>) -> Array<SsaInstr> {
var nf = raField.liveFields(norm.ra);
var vals = Array<SsaInstr>.new(nf.length);
var field = rc.variantNorm.fields[raField.orig.index];

for (i < vals.length) {
var idx = field.indexes[i];
if (field.isPacked()) {
var irt = IntRepType.!(vn.at(idx));
vals[i] = genExtractInterval(ninputs[idx], field.intervals[i], IntRepType.!(vn.at(idx)), field.tn.at(i));
} else {
vals[i] = genVariantScalarView(rc.variantNorm.at(idx), field.tn.at(i), ninputs[idx]);
}
}
return vals;
}
def normNullCheck(oldApp: SsaApplyOp, op: Operator) {
var newArgs = genRefs(oldApp.inputs);
if (newArgs.length >= 1) addNullCheck(oldApp, newArgs[0]);
Expand Down
Loading

0 comments on commit b9da4e0

Please sign in to comment.