From 1d1c5d6c9c2debdc71696c53f6894f4a3ebf2f86 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 23 Jul 2025 15:12:28 +0800 Subject: [PATCH 01/12] split symbolic and concrete stack at compile time --- .../scala/wasm/StagedConcolicMiniWasm.scala | 845 +++++++++++++----- 1 file changed, 617 insertions(+), 228 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 833bbc9b..5bc89605 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -26,28 +26,57 @@ trait StagedWasmEvaluator extends SAIOps { trait ReturnSite trait StagedNum { + def tipe: ValueType + } + + trait StagedConcreteNum { def tipe: ValueType = this match { - case I32(_, _) => NumType(I32Type) - case I64(_, _) => NumType(I64Type) - case F32(_, _) => NumType(F32Type) - case F64(_, _) => NumType(F64Type) + case I32C(_) => NumType(I32Type) + case I64C(_) => NumType(I64Type) + case F32C(_) => NumType(F32Type) + case F64C(_) => NumType(F64Type) } def i: Rep[Num] + } + + case class I32C(i: Rep[Num]) extends StagedConcreteNum + case class I64C(i: Rep[Num]) extends StagedConcreteNum + case class F32C(i: Rep[Num]) extends StagedConcreteNum + case class F64C(i: Rep[Num]) extends StagedConcreteNum + + + trait StagedSymbolicNum { + def tipe: ValueType = this match { + case I32S(_) => NumType(I32Type) + case I64S(_) => NumType(I64Type) + case F32S(_) => NumType(F32Type) + case F64S(_) => NumType(F64Type) + } def s: Rep[SymVal] } - case class I32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - case class I64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - case class F32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - case class F64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - def toStagedNum(num: Num): StagedNum = { + case class I32S(s: Rep[SymVal]) extends StagedSymbolicNum + case class I64S(s: Rep[SymVal]) extends StagedSymbolicNum + case class F32S(s: Rep[SymVal]) extends StagedSymbolicNum + case class F64S(s: Rep[SymVal]) extends StagedSymbolicNum + + def toStagedNum(num: Num): StagedConcreteNum = { num match { - case I32V(_) => I32(num, Concrete(num)) - case I64V(_) => I64(num, Concrete(num)) - case F32V(_) => F32(num, Concrete(num)) - case F64V(_) => F64(num, Concrete(num)) + case I32V(_) => I32C(num) + case I64V(_) => I64C(num) + case F32V(_) => F32C(num) + case F64V(_) => F64C(num) + } + } + + def toStagedSymbolicNum(num: Num): StagedSymbolicNum = { + num match { + case I32V(_) => I32S(Concrete(num)) + case I64V(_) => I64S(Concrete(num)) + case F32V(_) => F32S(Concrete(num)) + case F64V(_) => F64S(Concrete(num)) } } @@ -59,12 +88,21 @@ trait StagedWasmEvaluator extends SAIOps { case NumType(F64Type) => 8 } - def toTagger: (Rep[Num], Rep[SymVal]) => StagedNum = { + def concreteTag: (Rep[Num]) => StagedConcreteNum = { + ty match { + case NumType(I32Type) => I32C + case NumType(I64Type) => I64C + case NumType(F32Type) => F32C + case NumType(F64Type) => F64C + } + } + + def symbolicTag: (Rep[SymVal]) => StagedSymbolicNum = { ty match { - case NumType(I32Type) => I32 - case NumType(I64Type) => I64 - case NumType(F32Type) => F32 - case NumType(F64Type) => F64 + case NumType(I32Type) => I32S + case NumType(I64Type) => I64S + case NumType(F32Type) => F32S + case NumType(F64Type) => F64S } } } @@ -82,6 +120,12 @@ trait StagedWasmEvaluator extends SAIOps { (ty, Context(rest, frameTypes)) } + def take(n: Int): Context = { + Predef.assert(n <= stackTypes.size, s"Context.take size $n is larger than stack size ${stackTypes.size}") + val (taken, rest) = stackTypes.splitAt(n) + Context(rest, frameTypes) + } + def shift(offset: Int, size: Int): Context = { // Predef.println(s"[DEBUG] Shifting stack by $offset, size $size, $this") Predef.assert(offset >= 0, s"Context shift offset must be non-negative, get $offset") @@ -126,78 +170,136 @@ trait StagedWasmEvaluator extends SAIOps { val (inst, rest) = (insts.head, insts.tail) inst match { case Drop => - val (_, newCtx) = Stack.pop() + val (ty, newCtx) = ctx.pop() + Stack.popC(ty) + Stack.popS(ty) eval(rest, kont, mkont, trail)(newCtx) case WasmConst(num) => - val newCtx = Stack.push(toStagedNum(num)) + Stack.pushC(toStagedNum(num)) + Stack.pushS(toStagedSymbolicNum(num)) + val newCtx = ctx.push(num.tipe(module)) eval(rest, kont, mkont, trail)(newCtx) case Symbolic(ty) => - val (id, newCtx1) = Stack.pop() + Stack.popC(ty) + val id = Stack.popS(ty) val symVal = id.makeSymbolic() - val concVal = SymEnv.read(symVal) - val tagger = ty.toTagger - val value = tagger(concVal, symVal) - val newCtx2 = Stack.push(value)(newCtx1) - eval(rest, kont, mkont, trail)(newCtx2) + val num = SymEnv.read(symVal.s) + Stack.pushC(ty.concreteTag(num)) + Stack.pushS(symVal) + val newCtx = ctx.pop()._2.push(ty) + eval(rest, kont, mkont, trail)(newCtx) case LocalGet(i) => - val newCtx = Stack.push(Frames.get(i)) + Stack.pushC(Frames.getC(i)) + Stack.pushS(Frames.getS(i)) + val newCtx = ctx.push(ctx.frameTypes(i)) eval(rest, kont, mkont, trail)(newCtx) case LocalSet(i) => - val (num, newCtx) = Stack.pop() - Frames.set(i, num)(newCtx) + val (ty, newCtx) = ctx.pop() + val num = Stack.popC(ty) + val sym = Stack.popS(ty) + Frames.setC(i, num) + Frames.setS(i, sym) eval(rest, kont, mkont, trail)(newCtx) case LocalTee(i) => - val (num, newCtx) = Stack.peek - Frames.set(i, num) - eval(rest, kont, mkont, trail)(newCtx) + val ty = ctx.pop()._1 + val num = Stack.peekC(ty) + val sym = Stack.peekS(ty) + Frames.setC(i, num) + Frames.setS(i, sym) + eval(rest, kont, mkont, trail)(ctx) case GlobalGet(i) => - val newCtx = Stack.push(Globals(i)) + Stack.pushC(Globals.getC(i)) + Stack.pushS(Globals.getS(i)) + val newCtx = ctx.push(module.globals(i).ty.ty) eval(rest, kont, mkont, trail)(newCtx) case GlobalSet(i) => - val (value, newCtx) = Stack.pop() + val (ty, newCtx) = ctx.pop() + val num = Stack.popC(ty) + val sym = Stack.popS(ty) module.globals(i).ty match { - case GlobalType(tipe, true) => Globals(i) = value + case GlobalType(tipe, true) => { + Globals.setC(i, num) + Globals.setS(i, sym) + } case _ => throw new Exception("Cannot set immutable global") } eval(rest, kont, mkont, trail)(newCtx) case Store(StoreOp(align, offset, ty, None)) => - val (value, newCtx1) = Stack.pop() - val (addr, newCtx2) = Stack.pop()(newCtx1) + val (ty1, newCtx1) = ctx.pop() + val value = Stack.popC(ty1) + val symValue = Stack.popS(ty1) + val (ty2, newCtx2) = newCtx1.pop() + val addr = Stack.popC(ty2) + val symAddr = Stack.popS(ty2) Memory.storeInt(addr.toInt, offset, value.toInt) eval(rest, kont, mkont, trail)(newCtx2) case Nop => eval(rest, kont, mkont, trail) case Load(LoadOp(align, offset, ty, None, None)) => - val (addr, newCtx1) = Stack.pop() - val value = Memory.loadInt(addr.toInt, offset) - val newCtx2 = Stack.push(value)(newCtx1) + val (ty1, newCtx1) = ctx.pop() + val addr = Stack.popC(ty1) + Stack.popS(ty1) + val num = Memory.loadIntC(addr.toInt, offset) + val sym = Memory.loadIntS(addr.toInt, offset) + Stack.pushC(num) + Stack.pushS(sym) + val newCtx2 = newCtx1.push(ty) eval(rest, kont, mkont, trail)(newCtx2) case MemorySize => ??? case MemoryGrow => - val (delta, newCtx1) = Stack.pop() + val (ty, newCtx) = ctx.pop() + val delta = Stack.popC(ty) + Stack.popS(ty) val ret = Memory.grow(delta.toInt) val retNum = Values.I32V(ret) + // For now, we assume that the result of memory.grow only depends on the execution path, + // we can relax this by turning it return to a symbol value and mimic the memory.grow's result as input. val retSym = "Concrete".reflectCtrlWith[SymVal](retNum) - val newCtx2 = Stack.push(I32(retNum, retSym))(newCtx1) + Stack.pushC(I32C(retNum)) + Stack.pushS(I32S(retSym)) + val newCtx2 = ctx.push(NumType(I32Type)) eval(rest, kont, mkont, trail)(newCtx2) case MemoryFill => ??? case Unreachable => unreachable() case Test(op) => - val (v, newCtx1) = Stack.pop() - val newCtx2 = Stack.push(evalTestOp(op, v))(newCtx1) + val (ty, newCtx1) = ctx.pop() + val v = Stack.popC(ty) + val s = Stack.popS(ty) + Stack.pushC(evalTestOpC(op, v)) + Stack.pushS(evalTestOpS(op, s)) + val newCtx2 = newCtx1.push(v.tipe) eval(rest, kont, mkont, trail)(newCtx2) case Unary(op) => - val (v, newCtx1) = Stack.pop() - val newCtx2 = Stack.push(evalUnaryOp(op, v))(newCtx1) + val (ty, newCtx1) = ctx.pop() + val v = Stack.popC(ty) + val s = Stack.popS(ty) + val res = evalUnaryOpC(op, v) + Stack.pushC(res) + Stack.pushS(evalUnaryOpS(op, s)) + val newCtx2 = newCtx1.push(res.tipe) eval(rest, kont, mkont, trail)(newCtx2) case Binary(op) => - val (v2, newCtx1) = Stack.pop() - val (v1, newCtx2) = Stack.pop()(newCtx1) - val newCtx3 = Stack.push(evalBinOp(op, v1, v2))(newCtx2) + val (ty2, newCtx1) = ctx.pop() + val v2 = Stack.popC(ty2) + val s2 = Stack.popS(ty2) + val (ty1, newCtx2) = newCtx1.pop() + val v1 = Stack.popC(ty1) + val s1 = Stack.popS(ty1) + val res = evalBinOpC(op, v1, v2) + Stack.pushC(res) + Stack.pushS(evalBinOpS(op, s1, s2)) + val newCtx3 = newCtx2.push(res.tipe) eval(rest, kont, mkont, trail)(newCtx3) case Compare(op) => - val (v2, newCtx1) = Stack.pop() - val (v1, newCtx2) = Stack.pop()(newCtx1) - val newCtx3 = Stack.push(evalRelOp(op, v1, v2))(newCtx2) + val (ty2, newCtx1) = ctx.pop() + val v2 = Stack.popC(ty2) + val s2 = Stack.popS(ty2) + val (ty1, newCtx2) = newCtx1.pop() + val v1 = Stack.popC(ty1) + val s1 = Stack.popS(ty1) + val res = evalRelOpC(op, v1, v2) + Stack.pushC(res) + Stack.pushS(evalRelOpS(op, s1, s2)) + val newCtx3 = newCtx2.push(res.tipe) eval(rest, kont, mkont, trail)(newCtx3) case WasmBlock(ty, inner) => // no need to modify the stack when entering a block @@ -208,7 +310,9 @@ trait StagedWasmEvaluator extends SAIOps { def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the block, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - exitSize - val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + Stack.shiftC(offset, funcTy.out.size) + Stack.shiftS(offset, funcTy.out.size) + val newRestCtx = restCtx.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(newRestCtx) }) eval(inner, restK _, mkont, restK _ :: trail) @@ -219,31 +323,37 @@ trait StagedWasmEvaluator extends SAIOps { def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the loop, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - exitSize - val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + Stack.shiftC(offset, funcTy.out.size) + Stack.shiftS(offset, funcTy.out.size) + val newRestCtx = restCtx.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(newRestCtx) }) val enterSize = ctx.stackTypes.size def loop(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the loop, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - enterSize - val newRestCtx = Stack.shift(offset, funcTy.inps.size)(restCtx) + Stack.shiftC(offset, funcTy.inps.size) + Stack.shiftS(offset, funcTy.inps.size) + val newRestCtx = restCtx.shift(offset, funcTy.inps.size) eval(inner, restK _, mk, loop _ :: trail)(newRestCtx) }) loop(ctx)(mkont) case If(ty, thn, els) => val funcTy = ty.funcType - val (cond, newCtx) = Stack.pop() + val (condTy, newCtx) = ctx.pop() + val cond = Stack.popC(condTy) + val symCond = Stack.popS(condTy) val exitSize = newCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size - // TODO: can we avoid code duplication here? - val dummy = makeDummy def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the if, stackSize =", Stack.size) val offset = restCtx.stackTypes.size - exitSize - val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) + Stack.shiftC(offset, funcTy.out.size) + Stack.shiftS(offset, funcTy.out.size) + val newRestCtx = restCtx.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(newRestCtx) }) // TODO: put the cond.s to path condition - ExploreTree.fillWithIfElse(cond.s) + ExploreTree.fillWithIfElse(symCond.s) if (cond.toInt != 0) { ExploreTree.moveCursor(true) eval(thn, restK _, mkont, restK _ :: trail)(newCtx) @@ -256,10 +366,11 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Jump to $label") trail(label)(ctx)(mkont) case BrIf(label) => - val (cond, newCtx) = Stack.pop() + val (ty, newCtx) = ctx.pop() + val cond = Stack.popC(ty) + val symCond = Stack.popS(ty) info(s"The br_if(${label})'s condition is ", cond.toInt) - // TODO: put the cond.s to path condition - ExploreTree.fillWithIfElse(cond.s) + ExploreTree.fillWithIfElse(symCond.s) if (cond.toInt != 0) { info(s"Jump to $label") ExploreTree.moveCursor(true) @@ -271,12 +382,15 @@ trait StagedWasmEvaluator extends SAIOps { } () case BrTable(labels, default) => - val (label, newCtx) = Stack.pop() + val (ty, newCtx) = ctx.pop() + val label = Stack.popC(ty) + val labelSym = Stack.popS(ty) def aux(choices: List[Int], idx: Int): Rep[Unit] = { if (choices.isEmpty) trail(default)(newCtx)(mkont) else { val cond = (label - toStagedNum(I32V(idx))).isZero() - ExploreTree.fillWithIfElse(cond.s) + val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() + ExploreTree.fillWithIfElse(condSym.s) if (cond.toInt != 0) { ExploreTree.moveCursor(true) trail(choices.head)(newCtx)(mkont) @@ -323,37 +437,49 @@ trait StagedWasmEvaluator extends SAIOps { callee } // Predef.println(s"[DEBUG] locals size: ${locals.size}") - val (args, newCtx) = Stack.take(ty.inps.size) + val newCtx = ctx.take(ty.inps.size) + val argsC = Stack.takeC(ty.inps) + val argsS = Stack.takeS(ty.inps) if (isTail) { // when tail call, return to the caller's return continuation - Frames.popFrame(ctx.frameTypes.size) - Frames.pushFrame(locals) - Frames.putAll(args) + Frames.popFrameC(ctx.frameTypes.size) + Frames.popFrameS(ctx.frameTypes.size) + Frames.pushFrameC(locals) + Frames.pushFrameS(locals) + Frames.putAllC(argsC) + Frames.putAllS(argsS) callee(mkont) } else { // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) val restK: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) - Frames.popFrame(locals.size) + Frames.popFrameC(locals.size) + Frames.popFrameS(locals.size) eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) }) val dummy = makeDummy val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => { restK(mkont) }, dummy) - Frames.pushFrame(locals) - Frames.putAll(args) + Frames.pushFrameC(locals) + Frames.pushFrameS(locals) + Frames.putAllC(argsC) + Frames.putAllS(argsS) callee(newMKont) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") - val (v, newCtx) = Stack.pop() + val (ty, newCtx) = ctx.pop() + val v = Stack.popC(ty) + Stack.popS(ty) println(v.toInt) eval(rest, kont, mkont, trail)(newCtx) case Import("console", "assert", _) => - val (v, newCtx) = Stack.pop() + val (ty, newCtx) = ctx.pop() + val v = Stack.popC(ty) + Stack.popS(ty) runtimeAssert(v.toInt != 0) eval(rest, kont, mkont, trail)(newCtx) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") @@ -361,18 +487,43 @@ trait StagedWasmEvaluator extends SAIOps { } } - def evalTestOp(op: TestOp, value: StagedNum): StagedNum = op match { + def evalTestOpC(op: TestOp, value: StagedConcreteNum): StagedConcreteNum = op match { + case Eqz(_) => value.isZero + } + + def evalTestOpS(op: TestOp, value: StagedSymbolicNum): StagedSymbolicNum = op match { case Eqz(_) => value.isZero } - def evalUnaryOp(op: UnaryOp, value: StagedNum): StagedNum = op match { + def evalUnaryOpC(op: UnaryOp, value: StagedConcreteNum): StagedConcreteNum = op match { + case Clz(_) => value.clz() + case Ctz(_) => value.ctz() + case Popcnt(_) => value.popcnt() + case _ => ??? + } + + def evalUnaryOpS(op: UnaryOp, value: StagedSymbolicNum): StagedSymbolicNum = op match { case Clz(_) => value.clz() case Ctz(_) => value.ctz() case Popcnt(_) => value.popcnt() case _ => ??? } - def evalBinOp(op: BinOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { + def evalBinOpC(op: BinOp, v1: StagedConcreteNum, v2: StagedConcreteNum): StagedConcreteNum = op match { + case Add(_) => v1 + v2 + case Mul(_) => v1 * v2 + case Sub(_) => v1 - v2 + case Shl(_) => v1 << v2 + // case ShrS(_) => v1 >> v2 // TODO: signed shift right + case ShrU(_) => v1 >> v2 + case And(_) => v1 & v2 + case DivS(_) => v1 / v2 + case DivU(_) => v1 / v2 + case _ => + throw new Exception(s"Unknown binary operation $op") + } + + def evalBinOpS(op: BinOp, v1: StagedSymbolicNum, v2: StagedSymbolicNum): StagedSymbolicNum = op match { case Add(_) => v1 + v2 case Mul(_) => v1 * v2 case Sub(_) => v1 - v2 @@ -386,7 +537,21 @@ trait StagedWasmEvaluator extends SAIOps { throw new Exception(s"Unknown binary operation $op") } - def evalRelOp(op: RelOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { + def evalRelOpC(op: RelOp, v1: StagedConcreteNum, v2: StagedConcreteNum): StagedConcreteNum = op match { + case Eq(_) => v1 numEq v2 + case Ne(_) => v1 numNe v2 + case LtS(_) => v1 < v2 + case LtU(_) => v1 ltu v2 + case GtS(_) => v1 > v2 + case GtU(_) => v1 gtu v2 + case LeS(_) => v1 <= v2 + case LeU(_) => v1 leu v2 + case GeS(_) => v1 >= v2 + case GeU(_) => v1 geu v2 + case _ => ??? + } + + def evalRelOpS(op: RelOp, v1: StagedSymbolicNum, v2: StagedSymbolicNum): StagedSymbolicNum = op match { case Eq(_) => v1 numEq v2 case Ne(_) => v1 numNe v2 case LtS(_) => v1 < v2 @@ -426,9 +591,11 @@ trait StagedWasmEvaluator extends SAIOps { } val (instrs, locals) = (funBody.body, funBody.locals) resetStacks() - Frames.pushFrame(locals) + Frames.pushFrameC(locals) + Frames.pushFrameS(locals) eval(instrs, (_: Context) => forwardKont, mkont, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) - Frames.popFrame(locals.size) + Frames.popFrameC(locals.size) + Frames.popFrameS(locals.size) } def evalTop(main: Option[String], printRes: Boolean, dumpTree: Option[String]): Rep[Unit] = { @@ -450,66 +617,78 @@ trait StagedWasmEvaluator extends SAIOps { // stack operations object Stack { - def shift(offset: Int, size: Int)(ctx: Context): Context = { + def shiftC(offset: Int, size: Int) = { if (offset > 0) { "stack-shift".reflectCtrlWith[Unit](offset, size) } - ctx.shift(offset, size) + } + + def shiftS(offset: Int, size: Int) = { + if (offset > 0) { + "sym-stack-shift".reflectCtrlWith[Unit](offset, size) + } } def initialize(): Rep[Unit] = { "stack-init".reflectCtrlWith[Unit]() } - def pop()(implicit ctx: Context): (StagedNum, Context) = { - val (ty, newContext) = ctx.pop() - val num = ty match { - case NumType(I32Type) => I32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - case NumType(I64Type) => I64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - } - (num, newContext) + def popC(ty: ValueType): StagedConcreteNum = ty match { + case NumType(I32Type) => I32C("stack-pop".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64C("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32C("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64C("stack-pop".reflectCtrlWith[Num]()) } - def peek(implicit ctx: Context): (StagedNum, Context) = { - val ty = ctx.stackTypes.head - val num = ty match { - case NumType(I32Type) => I32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - case NumType(I64Type) => I64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - } - (num, ctx) + def popS(ty: ValueType): StagedSymbolicNum = ty match { + case NumType(I32Type) => I32S("sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(I64Type) => I64S("sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F32S("sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(F64Type) => F64S("sym-stack-pop".reflectCtrlWith[SymVal]()) } - def push(num: StagedNum)(implicit ctx: Context): Context = { - num match { - case I32(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - case I64(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - case F32(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - case F64(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - } - ctx.push(num.tipe) + def peekC(ty: ValueType): StagedConcreteNum = ty match { + case NumType(I32Type) => I32C("stack-peek".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64C("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32C("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64C("stack-peek".reflectCtrlWith[Num]()) } - def take(n: Int)(implicit ctx: Context): (List[StagedNum], Context) = n match { - case 0 => (Nil, ctx) - case n => - val (v, newCtx1) = pop() - val (rest, newCtx2) = take(n - 1) - (v::rest, newCtx2) + def peekS(ty: ValueType): StagedSymbolicNum = ty match { + case NumType(I32Type) => I32S("sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(I64Type) => I64S("sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F32S("sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(F64Type) => F64S("sym-stack-peek".reflectCtrlWith[SymVal]()) } - def drop(n: Int)(implicit ctx: Context): Context = { - take(n)._2 + def pushC(num: StagedConcreteNum) = num match { + case I32C(v) => "stack-push".reflectCtrlWith[Unit](v) + case I64C(v) => "stack-push".reflectCtrlWith[Unit](v) + case F32C(v) => "stack-push".reflectCtrlWith[Unit](v) + case F64C(v) => "stack-push".reflectCtrlWith[Unit](v) } - def shift(offset: Rep[Int], size: Rep[Int]): Rep[Unit] = { - if (offset > 0) { - "stack-shift".reflectCtrlWith[Unit](offset, size) - "sym-stack-shift".reflectCtrlWith[Unit](offset, size) - } + def pushS(num: StagedSymbolicNum) = num match { + case I32S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + case I64S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + case F32S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + case F64S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + } + + def takeC(types: List[ValueType]): List[StagedConcreteNum] = types match { + case Nil => Nil + case t :: ts => + val v = popC(t) + val rest = takeC(ts) + v :: rest + } + + def takeS(types: List[ValueType]): List[StagedSymbolicNum] = types match { + case Nil => Nil + case t :: ts => + val v = popS(t) + val rest = takeS(ts) + v :: rest } def print(): Rep[Unit] = { @@ -522,41 +701,72 @@ trait StagedWasmEvaluator extends SAIOps { } object Frames { - def get(i: Int)(implicit ctx: Context): StagedNum = { + def getC(i: Int)(implicit ctx: Context): StagedConcreteNum = { // val offset = ctx.frameTypes.take(i).map(_.size).sum ctx.frameTypes(i) match { - case NumType(I32Type) => I32("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) - case NumType(I64Type) => I64("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) - case NumType(F32Type) => F32("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) - case NumType(F64Type) => F64("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(I32Type) => I32C("frame-get".reflectCtrlWith[Num](i)) + case NumType(I64Type) => I64C("frame-get".reflectCtrlWith[Num](i)) + case NumType(F32Type) => F32C("frame-get".reflectCtrlWith[Num](i)) + case NumType(F64Type) => F64C("frame-get".reflectCtrlWith[Num](i)) } } - def set(i: Int, v: StagedNum)(implicit ctx: Context): Rep[Unit] = { - // val offset = ctx.frameTypes.take(i).map(_.size).sum + def getS(i: Int)(implicit ctx: Context): StagedSymbolicNum = { + ctx.frameTypes(i) match { + case NumType(I32Type) => I32S("sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(I64Type) => I64S("sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(F32Type) => F32S("sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(F64Type) => F64S("sym-frame-get".reflectCtrlWith[SymVal](i)) + } + } + + def setC(i: Int, v: StagedConcreteNum): Rep[Unit] = { v match { - case I32(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) - case I64(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) - case F32(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) - case F64(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) + case I32C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case I64C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F32C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F64C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + } + } + + def setS(i: Int, s: StagedSymbolicNum): Rep[Unit] = { + s match { + case I32S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) + case I64S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) + case F32S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) + case F64S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) } } - def pushFrame(locals: List[ValueType]): Rep[Unit] = { + def pushFrameC(locals: List[ValueType]): Rep[Unit] = { // Predef.println(s"[DEBUG] push frame: $locals") val size = locals.size "frame-push".reflectCtrlWith[Unit](size) + } + + def pushFrameS(locals: List[ValueType]): Rep[Unit] = { + // Predef.println(s"[DEBUG] push frame: $locals") + val size = locals.size "sym-frame-push".reflectCtrlWith[Unit](size) } - def popFrame(size: Int): Rep[Unit] = { + def popFrameC(size: Int): Rep[Unit] = { "frame-pop".reflectCtrlWith[Unit](size) + } + + def popFrameS(size: Int): Rep[Unit] = { "sym-frame-pop".reflectCtrlWith[Unit](size) } - def putAll(args: List[StagedNum])(implicit ctx: Context): Rep[Unit] = { + def putAllC(args: List[StagedConcreteNum]): Rep[Unit] = { for ((arg, i) <- args.view.reverse.zipWithIndex) { - Frames.set(i, arg) + Frames.setC(i, arg) + } + } + + def putAllS(args: List[StagedSymbolicNum]): Rep[Unit] = { + for ((arg, i) <- args.view.reverse.zipWithIndex) { + Frames.setS(i, arg) } } } @@ -567,8 +777,12 @@ trait StagedWasmEvaluator extends SAIOps { // todo: store symbolic value to memory via extract/concat operation } - def loadInt(base: Rep[Int], offset: Int): StagedNum = { - I32("I32V".reflectCtrlWith[Num]("memory-load-int".reflectCtrlWith[Int](base, offset)), "sym-load-int-todo".reflectCtrlWith[SymVal](base, offset)) + def loadIntC(base: Rep[Int], offset: Int): StagedConcreteNum = { + I32C("I32V".reflectCtrlWith[Num]("memory-load-int".reflectCtrlWith[Int](base, offset))) + } + + def loadIntS(base: Rep[Int], offset: Int): StagedSymbolicNum = { + I32S("sym-load-int-todo".reflectCtrlWith[SymVal](base, offset)) } // Returns the previous memory size on success, or -1 if the memory cannot be grown. @@ -603,21 +817,39 @@ trait StagedWasmEvaluator extends SAIOps { // global read/write object Globals { - def apply(i: Int): StagedNum = { + def getC(i: Int): StagedConcreteNum = { module.globals(i).ty match { - case GlobalType(NumType(I32Type), _) => I32("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) - case GlobalType(NumType(I64Type), _) => I64("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) - case GlobalType(NumType(F32Type), _) => F32("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) - case GlobalType(NumType(F64Type), _) => F64("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(I32Type), _) => I32C("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(I64Type), _) => I64C("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F32Type), _) => F32C("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F64Type), _) => F64C("global-get".reflectCtrlWith[Num](i)) } } - def update(i: Int, v: StagedNum): Rep[Unit] = { + def getS(i: Int): StagedSymbolicNum = { module.globals(i).ty match { - case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) - case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) - case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) - case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) + case GlobalType(NumType(I32Type), _) => I32S("sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(I64Type), _) => I64S("sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(F32Type), _) => F32S("sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(F64Type), _) => F64S("sym-global-get".reflectCtrlWith[SymVal](i)) + } + } + + def setC(i: Int, v: StagedConcreteNum): Rep[Unit] = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + } + } + + def setS(i: Int, s: StagedSymbolicNum): Rep[Unit] = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) + case GlobalType(NumType(I64Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) + case GlobalType(NumType(F32Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) + case GlobalType(NumType(F64Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) } } } @@ -652,164 +884,320 @@ trait StagedWasmEvaluator extends SAIOps { } // runtime Num type - implicit class StagedNumOps(num: StagedNum) { + implicit class StagedConcreteNumOps(num: StagedConcreteNum) { def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num.i) - def isZero(): StagedNum = num match { - case I32(x_c, x_s) => I32(Values.I32V("is-zero".reflectCtrlWith[Int](num.toInt)), "sym-is-zero".reflectCtrlWith[SymVal](x_s)) + def isZero(): StagedConcreteNum = num match { + case I32C(x_c) => I32C(Values.I32V("is-zero".reflectCtrlWith[Int](num.toInt))) + } + + def clz(): StagedConcreteNum = num match { + case I32C(x) => I32C("clz".reflectCtrlWith[Num](x)) + case I64C(x) => I64C("clz".reflectCtrlWith[Num](x)) + } + + def ctz(): StagedConcreteNum = num match { + case I32C(x) => I32C("ctz".reflectCtrlWith[Num](x)) + case I64C(x) => I64C("ctz".reflectCtrlWith[Num](x)) + } + + def popcnt(): StagedConcreteNum = num match { + case I32C(x) => I32C("popcnt".reflectCtrlWith[Num](x)) + case I64C(x) => I64C("popcnt".reflectCtrlWith[Num](x)) + } + + def +(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-add".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-add".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-add".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-add".reflectCtrlWith[Num](x, y)) + } + } + + def -(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-sub".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-sub".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-sub".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-sub".reflectCtrlWith[Num](x, y)) + } + } + + def *(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-mul".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-mul".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-mul".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-mul".reflectCtrlWith[Num](x, y)) + } + } + + def /(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-div".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-div".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-div".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-div".reflectCtrlWith[Num](x, y)) + } + } + + def <<(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-shl".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-shl".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-shl".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-shl".reflectCtrlWith[Num](x, y)) + } + } + + def >>(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-shr".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-shr".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-shr".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-shr".reflectCtrlWith[Num](x, y)) + } + } + + def &(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-and".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-and".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-and".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-and".reflectCtrlWith[Num](x, y)) + } } - def clz(): StagedNum = num match { - case I32(x_c, x_s) => I32("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) - case I64(x_c, x_s) => I64("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) + def numEq(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-eq".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-eq".reflectCtrlWith[Num](x, y)) + } } - def ctz(): StagedNum = num match { - case I32(x_c, x_s) => I32("ctz".reflectCtrlWith[Num](x_c), "sym-ctz".reflectCtrlWith[SymVal](x_s)) - case I64(x_c, x_s) => I64("ctz".reflectCtrlWith[Num](x_c), "sym-ctz".reflectCtrlWith[SymVal](x_s)) + def numNe(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-ne".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-ne".reflectCtrlWith[Num](x, y)) + } } - def popcnt(): StagedNum = num match { - case I32(x_c, x_s) => I32("popcnt".reflectCtrlWith[Num](x_c), "sym-popcnt".reflectCtrlWith[SymVal](x_s)) - case I64(x_c, x_s) => I64("popcnt".reflectCtrlWith[Num](x_c), "sym-popcnt".reflectCtrlWith[SymVal](x_s)) + def <(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-lt".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-lt".reflectCtrlWith[Num](x, y)) + } } - def makeSymbolic(): Rep[SymVal] = { - "make-symbolic".reflectCtrlWith[SymVal](num.s) + def ltu(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-ltu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-ltu".reflectCtrlWith[Num](x, y)) + } + } + + def >(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-gt".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-gt".reflectCtrlWith[Num](x, y)) + } + } + + def gtu(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-gtu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-gtu".reflectCtrlWith[Num](x, y)) + } + } + + def <=(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-le".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-le".reflectCtrlWith[Num](x, y)) + } } - def +(rhs: StagedNum): StagedNum = { + def leu(rhs: StagedConcreteNum): StagedConcreteNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32C(x), I32C(y)) => I32C("relation-leu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-leu".reflectCtrlWith[Num](x, y)) } } + def >=(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-ge".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-ge".reflectCtrlWith[Num](x, y)) + } + } + + def geu(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-geu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-geu".reflectCtrlWith[Num](x, y)) + } + } + } + + implicit class StagedSymbolicNumOps(num: StagedSymbolicNum) { + def makeSymbolic(): StagedSymbolicNum = num match { + case I32S(x) => I32S("make-symbolic".reflectCtrlWith[SymVal](x)) + case I64S(x) => I64S("make-symbolic".reflectCtrlWith[SymVal](x)) + case F32S(x) => F32S("make-symbolic".reflectCtrlWith[SymVal](x)) + case F64S(x) => F64S("make-symbolic".reflectCtrlWith[SymVal](x)) + } + + def isZero(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-is-zero".reflectCtrlWith[SymVal](x)) + } + + def clz(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-clz".reflectCtrlWith[SymVal](x)) + case I64S(x) => I64S("sym-clz".reflectCtrlWith[SymVal](x)) + } + + def ctz(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-ctz".reflectCtrlWith[SymVal](x)) + case I64S(x) => I64S("sym-ctz".reflectCtrlWith[SymVal](x)) + } + + def popcnt(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-popcnt".reflectCtrlWith[SymVal](x)) + case I64S(x) => I64S("sym-popcnt".reflectCtrlWith[SymVal](x)) + } + + def +(rhs: StagedSymbolicNum): StagedSymbolicNum = { + (num, rhs) match { + case (I32S(x), I32S(y)) => I32S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + } + } - def -(rhs: StagedNum): StagedNum = { + def -(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) } } - def *(rhs: StagedNum): StagedNum = { + def *(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) } } - def /(rhs: StagedNum): StagedNum = { + def /(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) } } - def <<(rhs: StagedNum): StagedNum = { + def <<(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) } } - def >>(rhs: StagedNum): StagedNum = { + def >>(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) } } - def &(rhs: StagedNum): StagedNum = { + def &(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) } } - def numEq(rhs: StagedNum): StagedNum = { + def numEq(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-eq".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-eq".reflectCtrlWith[SymVal](x, y)) } } - def numNe(rhs: StagedNum): StagedNum = { + def numNe(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-ne".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-ne".reflectCtrlWith[SymVal](x, y)) } } - def <(rhs: StagedNum): StagedNum = { + def <(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-lt".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-lt".reflectCtrlWith[SymVal](x, y)) } } - def ltu(rhs: StagedNum): StagedNum = { + def ltu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("relation-ltu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("relation-ltu".reflectCtrlWith[SymVal](x, y)) } } - def >(rhs: StagedNum): StagedNum = { + def >(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-gt".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-gt".reflectCtrlWith[SymVal](x, y)) } } - def gtu(rhs: StagedNum): StagedNum = { + def gtu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-gtu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-gtu".reflectCtrlWith[SymVal](x, y)) } } - def <=(rhs: StagedNum): StagedNum = { + def <=(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-le".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-le".reflectCtrlWith[SymVal](x, y)) } } - def leu(rhs: StagedNum): StagedNum = { + def leu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-leu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-leu".reflectCtrlWith[SymVal](x, y)) } } - def >=(rhs: StagedNum): StagedNum = { + def >=(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-ge".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-ge".reflectCtrlWith[SymVal](x, y)) } } - def geu(rhs: StagedNum): StagedNum = { + def geu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-geu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-geu".reflectCtrlWith[SymVal](x, y)) } } } @@ -935,7 +1323,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => - shallow(lhs); emit(" - "); shallow(rhs) + // todo: avoid using c++ operator, use explicit method call so operator's precedence issues won't exist + emit("("); shallow(lhs); emit(" - "); shallow(rhs); emit(")") case Node(_, "binary-mul", List(lhs, rhs), _) => shallow(lhs); emit(" * "); shallow(rhs) case Node(_, "binary-div", List(lhs, rhs), _) => From 70f754f3b3d5d78b0d25b1a92e8a994919ef0223 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Sun, 17 Aug 2025 16:40:23 -0400 Subject: [PATCH 02/12] remember both start context & end context, and history between them --- headers/wasm/symbolic_rt.hpp | 13 +- .../scala/wasm/StagedConcolicMiniWasm.scala | 292 ++++++++++++------ 2 files changed, 206 insertions(+), 99 deletions(-) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 18629c80..a139e494 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -531,9 +531,7 @@ class SymEnv_t { return map[symbol->get_id()]; } - void update(std::vector new_env) { - map = std::move(new_env); - } + void update(std::vector new_env) { map = std::move(new_env); } std::string to_string() const { std::string result; @@ -548,9 +546,16 @@ class SymEnv_t { } private: - std::vector map; // The symbolic environment, a vector of Num + std::vector map; // The symbolic environment, a vector of Num }; static SymEnv_t SymEnv; +class Reuse_t { +public: + bool is_reusing() { return false; } +}; + +static Reuse_t Reuse; + #endif // WASM_SYMBOLIC_RT_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 5bc89605..72547377 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -112,18 +112,22 @@ trait StagedWasmEvaluator extends SAIOps { frameTypes: List[ValueType] ) { def push(ty: ValueType): Context = { - Context(ty :: stackTypes, frameTypes) + this.copy(stackTypes = ty :: stackTypes) + } + + def peek: ValueType = { + stackTypes.head } def pop(): (ValueType, Context) = { val (ty :: rest) = stackTypes - (ty, Context(rest, frameTypes)) + (ty, this.copy(stackTypes = rest)) } def take(n: Int): Context = { Predef.assert(n <= stackTypes.size, s"Context.take size $n is larger than stack size ${stackTypes.size}") val (taken, rest) = stackTypes.splitAt(n) - Context(rest, frameTypes) + this.copy(stackTypes = rest) } def shift(offset: Int, size: Int): Context = { @@ -137,11 +141,51 @@ trait StagedWasmEvaluator extends SAIOps { ) } } + + } + + case class ContextTransition(startCtx: Context, history: List[Instr], endCtx: Context) { + def log(instr: Instr): ContextTransition = { + this.copy(history = instr :: history) + } + + def clearHistory: (Context, List[Instr], ContextTransition) = { + (startCtx, history, this.copy(startCtx = endCtx, history = Nil)) + } + + def push(ty: ValueType): ContextTransition = { + this.copy(endCtx = endCtx.push(ty)) + } + + def peek: ValueType = { + endCtx.peek + } + + def pop(): (ValueType, ContextTransition) = { + val (ty, newCtx) = endCtx.pop() + (ty, this.copy(endCtx = newCtx)) + } + + def take(n: Int): ContextTransition = { + this.copy(endCtx = endCtx.take(n)) + } + + def shift(offset: Int, size: Int): ContextTransition = { + this.copy(endCtx = endCtx.shift(offset, size)) + } + } + + object ContextTransition { + + def apply(startCtx: Context) = { + new ContextTransition(startCtx, Nil, startCtx) + } + } type MCont[A] = Unit => A type Cont[A] = (MCont[A]) => A - type Trail[A] = List[Context => Rep[Cont[A]]] + type Trail[A] = List[ContextTransition => Rep[Cont[A]]] // a cache storing the compiled code for each function, to reduce re-compilation val compileCache = new HashMap[Int, Rep[(MCont[Unit]) => Unit]] @@ -156,29 +200,29 @@ trait StagedWasmEvaluator extends SAIOps { }) } - def eval(insts: List[Instr], - kont: Context => Rep[Cont[Unit]], + kont: ContextTransition => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], trail: Trail[Unit]) - (implicit ctx: Context): Rep[Unit] = { - if (insts.isEmpty) return kont(ctx)(mkont) + (implicit oldCT: ContextTransition): Rep[Unit] = { + if (insts.isEmpty) return kont(oldCT)(mkont) // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") // Predef.println(s"[DEBUG] Current context: $ctx") - + implicit val ctx = oldCT.endCtx val (inst, rest) = (insts.head, insts.tail) + val ct = oldCT.log(inst) inst match { case Drop => - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() Stack.popC(ty) Stack.popS(ty) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct1) case WasmConst(num) => Stack.pushC(toStagedNum(num)) Stack.pushS(toStagedSymbolicNum(num)) - val newCtx = ctx.push(num.tipe(module)) - eval(rest, kont, mkont, trail)(newCtx) + val ct1 = ct.push(num.tipe(module)) + eval(rest, kont, mkont, trail)(ct1) case Symbolic(ty) => Stack.popC(ty) val id = Stack.popS(ty) @@ -186,34 +230,34 @@ trait StagedWasmEvaluator extends SAIOps { val num = SymEnv.read(symVal.s) Stack.pushC(ty.concreteTag(num)) Stack.pushS(symVal) - val newCtx = ctx.pop()._2.push(ty) - eval(rest, kont, mkont, trail)(newCtx) + val ct1 = ct.pop()._2.push(ty) + eval(rest, kont, mkont, trail)(ct1) case LocalGet(i) => Stack.pushC(Frames.getC(i)) Stack.pushS(Frames.getS(i)) - val newCtx = ctx.push(ctx.frameTypes(i)) - eval(rest, kont, mkont, trail)(newCtx) + val ct1 = ct.push(ctx.frameTypes(i)) + eval(rest, kont, mkont, trail)(ct1) case LocalSet(i) => - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() val num = Stack.popC(ty) val sym = Stack.popS(ty) Frames.setC(i, num) Frames.setS(i, sym) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct1) case LocalTee(i) => - val ty = ctx.pop()._1 + val ty = ct.peek val num = Stack.peekC(ty) val sym = Stack.peekS(ty) Frames.setC(i, num) Frames.setS(i, sym) - eval(rest, kont, mkont, trail)(ctx) + eval(rest, kont, mkont, trail)(ct) case GlobalGet(i) => Stack.pushC(Globals.getC(i)) Stack.pushS(Globals.getS(i)) - val newCtx = ctx.push(module.globals(i).ty.ty) - eval(rest, kont, mkont, trail)(newCtx) + val ct1 = ct.push(module.globals(i).ty.ty) + eval(rest, kont, mkont, trail)(ct1) case GlobalSet(i) => - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() val num = Stack.popC(ty) val sym = Stack.popS(ty) module.globals(i).ty match { @@ -223,30 +267,30 @@ trait StagedWasmEvaluator extends SAIOps { } case _ => throw new Exception("Cannot set immutable global") } - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct1) case Store(StoreOp(align, offset, ty, None)) => - val (ty1, newCtx1) = ctx.pop() + val (ty1, ct1) = ct.pop() val value = Stack.popC(ty1) val symValue = Stack.popS(ty1) - val (ty2, newCtx2) = newCtx1.pop() + val (ty2, ct2) = ct1.pop() val addr = Stack.popC(ty2) val symAddr = Stack.popS(ty2) Memory.storeInt(addr.toInt, offset, value.toInt) - eval(rest, kont, mkont, trail)(newCtx2) + eval(rest, kont, mkont, trail)(ct2) case Nop => eval(rest, kont, mkont, trail) case Load(LoadOp(align, offset, ty, None, None)) => - val (ty1, newCtx1) = ctx.pop() + val (ty1, ct1) = ct.pop() val addr = Stack.popC(ty1) Stack.popS(ty1) val num = Memory.loadIntC(addr.toInt, offset) val sym = Memory.loadIntS(addr.toInt, offset) Stack.pushC(num) Stack.pushS(sym) - val newCtx2 = newCtx1.push(ty) - eval(rest, kont, mkont, trail)(newCtx2) + val ct2 = ct1.push(ty) + eval(rest, kont, mkont, trail)(ct2) case MemorySize => ??? case MemoryGrow => - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() val delta = Stack.popC(ty) Stack.popS(ty) val ret = Memory.grow(delta.toInt) @@ -256,144 +300,168 @@ trait StagedWasmEvaluator extends SAIOps { val retSym = "Concrete".reflectCtrlWith[SymVal](retNum) Stack.pushC(I32C(retNum)) Stack.pushS(I32S(retSym)) - val newCtx2 = ctx.push(NumType(I32Type)) - eval(rest, kont, mkont, trail)(newCtx2) + val ct2 = ct1.push(NumType(I32Type)) + eval(rest, kont, mkont, trail)(ct2) case MemoryFill => ??? case Unreachable => unreachable() case Test(op) => - val (ty, newCtx1) = ctx.pop() + val (ty, ct1) = ct.pop() val v = Stack.popC(ty) val s = Stack.popS(ty) Stack.pushC(evalTestOpC(op, v)) Stack.pushS(evalTestOpS(op, s)) - val newCtx2 = newCtx1.push(v.tipe) - eval(rest, kont, mkont, trail)(newCtx2) + val ct2 = ct1.push(v.tipe) + eval(rest, kont, mkont, trail)(ct2) case Unary(op) => - val (ty, newCtx1) = ctx.pop() + val (ty, ct1) = ct.pop() val v = Stack.popC(ty) val s = Stack.popS(ty) val res = evalUnaryOpC(op, v) Stack.pushC(res) Stack.pushS(evalUnaryOpS(op, s)) - val newCtx2 = newCtx1.push(res.tipe) - eval(rest, kont, mkont, trail)(newCtx2) + val ct2 = ct1.push(res.tipe) + eval(rest, kont, mkont, trail)(ct2) case Binary(op) => - val (ty2, newCtx1) = ctx.pop() + val (ty2, ct1) = ct.pop() val v2 = Stack.popC(ty2) val s2 = Stack.popS(ty2) - val (ty1, newCtx2) = newCtx1.pop() + val (ty1, ct2) = ct1.pop() val v1 = Stack.popC(ty1) val s1 = Stack.popS(ty1) val res = evalBinOpC(op, v1, v2) Stack.pushC(res) Stack.pushS(evalBinOpS(op, s1, s2)) - val newCtx3 = newCtx2.push(res.tipe) - eval(rest, kont, mkont, trail)(newCtx3) + val ct3 = ct2.push(res.tipe) + eval(rest, kont, mkont, trail)(ct3) case Compare(op) => - val (ty2, newCtx1) = ctx.pop() + val (ty2, ct1) = ct.pop() val v2 = Stack.popC(ty2) val s2 = Stack.popS(ty2) - val (ty1, newCtx2) = newCtx1.pop() + val (ty1, ct2) = ct1.pop() val v1 = Stack.popC(ty1) val s1 = Stack.popS(ty1) val res = evalRelOpC(op, v1, v2) Stack.pushC(res) Stack.pushS(evalRelOpS(op, s1, s2)) - val newCtx3 = newCtx2.push(res.tipe) - eval(rest, kont, mkont, trail)(newCtx3) + val ct3 = ct2.push(res.tipe) + eval(rest, kont, mkont, trail)(ct3) case WasmBlock(ty, inner) => // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the block, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - exitSize + val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) Stack.shiftS(offset, funcTy.out.size) - val newRestCtx = restCtx.shift(offset, funcTy.out.size) - eval(rest, kont, mk, trail)(newRestCtx) + val ct1 = ct.shift(offset, funcTy.out.size) + eval(rest, kont, mk, trail)(ct1) }) - eval(inner, restK _, mkont, restK _ :: trail) + // TODO: extract this into a function + val (oldCtx, history, ct1) = ct.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } + eval(inner, restK _, mkont, restK _ :: trail)(ct1) case Loop(ty, inner) => val funcTy = ty.funcType val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the loop, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - exitSize + val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) Stack.shiftS(offset, funcTy.out.size) - val newRestCtx = restCtx.shift(offset, funcTy.out.size) - eval(rest, kont, mk, trail)(newRestCtx) + val ct1 = ct.shift(offset, funcTy.out.size) + eval(rest, kont, mk, trail)(ct1) }) val enterSize = ctx.stackTypes.size - def loop(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def loop(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the loop, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - enterSize + val offset = ct.endCtx.stackTypes.size - enterSize Stack.shiftC(offset, funcTy.inps.size) Stack.shiftS(offset, funcTy.inps.size) - val newRestCtx = restCtx.shift(offset, funcTy.inps.size) - eval(inner, restK _, mk, loop _ :: trail)(newRestCtx) + val ct1 = ct.shift(offset, funcTy.inps.size) + eval(inner, restK _, mk, loop _ :: trail)(ct1) }) - loop(ctx)(mkont) + val (oldCtx, history, ct1) = ct.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } + loop(ct1)(mkont) case If(ty, thn, els) => val funcTy = ty.funcType - val (condTy, newCtx) = ctx.pop() + val (condTy, ct1) = ct.pop() val cond = Stack.popC(condTy) val symCond = Stack.popS(condTy) - val exitSize = newCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size - def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + val exitSize = ct1.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size + def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the if, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - exitSize + val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) Stack.shiftS(offset, funcTy.out.size) - val newRestCtx = restCtx.shift(offset, funcTy.out.size) - eval(rest, kont, mk, trail)(newRestCtx) + val ct1 = ct.shift(offset, funcTy.out.size) + eval(rest, kont, mk, trail)(ct1) }) - // TODO: put the cond.s to path condition + val (oldCtx, history, ct2) = ct1.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } ExploreTree.fillWithIfElse(symCond.s) if (cond.toInt != 0) { ExploreTree.moveCursor(true) - eval(thn, restK _, mkont, restK _ :: trail)(newCtx) + eval(thn, kont, mkont, trail)(ct2) } else { ExploreTree.moveCursor(false) - eval(els, restK _, mkont, restK _ :: trail)(newCtx) + eval(els, restK _, mkont, restK _ :: trail)(ct2) } () case Br(label) => info(s"Jump to $label") - trail(label)(ctx)(mkont) + val (oldCtx, history, ct1) = ct.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } + trail(label)(ct1)(mkont) case BrIf(label) => - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() val cond = Stack.popC(ty) val symCond = Stack.popS(ty) + val (oldCtx, history, ct2) = ct1.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } info(s"The br_if(${label})'s condition is ", cond.toInt) ExploreTree.fillWithIfElse(symCond.s) if (cond.toInt != 0) { info(s"Jump to $label") ExploreTree.moveCursor(true) - trail(label)(newCtx)(mkont) + trail(label)(ct2)(mkont) } else { info(s"Continue") ExploreTree.moveCursor(false) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct2) } () case BrTable(labels, default) => - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() val label = Stack.popC(ty) val labelSym = Stack.popS(ty) + val (oldCtx, history, ct2) = ct1.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } def aux(choices: List[Int], idx: Int): Rep[Unit] = { - if (choices.isEmpty) trail(default)(newCtx)(mkont) + if (choices.isEmpty) trail(default)(ct2)(mkont) else { val cond = (label - toStagedNum(I32V(idx))).isZero() val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() ExploreTree.fillWithIfElse(condSym.s) if (cond.toInt != 0) { ExploreTree.moveCursor(true) - trail(choices.head)(newCtx)(mkont) + trail(choices.head)(ct2)(mkont) } else { ExploreTree.moveCursor(false) @@ -402,7 +470,7 @@ trait StagedWasmEvaluator extends SAIOps { } } aux(labels, 0) - case Return => trail.last(ctx)(mkont) + case Return => trail.last(ct)(mkont) case Call(f) => evalCall(rest, kont, mkont, trail, f, false) case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true) case _ => @@ -411,16 +479,29 @@ trait StagedWasmEvaluator extends SAIOps { } } + // call the symbolic interpreter to evaluate the history that just executed by + // concrete interpreter + def evalSym(history: List[Instr]) + (implicit ctx: Context): Rep[Unit] = { + // TODO: the context we currently passing is not right, we should use the + // original context where the history start + () + } + def forwardKont: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => mk(())) def evalCall(rest: List[Instr], - kont: Context => Rep[Cont[Unit]], + kont: ContextTransition => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], trail: Trail[Unit], funcIndex: Int, isTail: Boolean) - (implicit ctx: Context): Rep[Unit] = { + (implicit ct: ContextTransition): Rep[Unit] = { + val (oldCtx, history, ct1) = ct.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } module.funcs(funcIndex) match { case FuncDef(_, FuncBodyDef(ty, _, bodyLocals, body)) => val locals = bodyLocals ++ ty.inps @@ -431,23 +512,25 @@ trait StagedWasmEvaluator extends SAIOps { val callee = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the function at $funcIndex, stackSize =", Stack.size) // we can do some check here to ensure the function returns correct size of stack - eval(body, (_: Context) => forwardKont, mk, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) + eval(body, (_: ContextTransition) => forwardKont, mk, ((_: ContextTransition) => forwardKont)::Nil)(ContextTransition(Context(Nil, locals))) }) compileCache(funcIndex) = callee callee } // Predef.println(s"[DEBUG] locals size: ${locals.size}") - val newCtx = ctx.take(ty.inps.size) + val ct2 = ct1.take(ty.inps.size) val argsC = Stack.takeC(ty.inps) val argsS = Stack.takeS(ty.inps) if (isTail) { // when tail call, return to the caller's return continuation - Frames.popFrameC(ctx.frameTypes.size) - Frames.popFrameS(ctx.frameTypes.size) + Frames.popFrameC(ct2.endCtx.frameTypes.size) Frames.pushFrameC(locals) - Frames.pushFrameS(locals) Frames.putAllC(argsC) - Frames.putAllS(argsS) + if (!ReuseManager.isReusing) { + Frames.popFrameS(ct2.endCtx.frameTypes.size) + Frames.pushFrameS(locals) + Frames.putAllS(argsS) + } callee(mkont) } else { // We make a new trail by `restK`, since function creates a new block to escape @@ -456,32 +539,35 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) Frames.popFrameC(locals.size) Frames.popFrameS(locals.size) - eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) + val newCtx = ct2.endCtx.copy(stackTypes = ty.out.reverse ++ ct2.endCtx.stackTypes) + eval(rest, kont, mk, trail)(ContextTransition(newCtx)) }) val dummy = makeDummy val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => { restK(mkont) }, dummy) Frames.pushFrameC(locals) - Frames.pushFrameS(locals) Frames.putAllC(argsC) - Frames.putAllS(argsS) + if (!ReuseManager.isReusing) { + Frames.pushFrameS(locals) + Frames.putAllS(argsS) + } callee(newMKont) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() val v = Stack.popC(ty) Stack.popS(ty) println(v.toInt) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct1) case Import("console", "assert", _) => - val (ty, newCtx) = ctx.pop() + val (ty, ct1) = ct.pop() val v = Stack.popC(ty) Stack.popS(ty) runtimeAssert(v.toInt != 0) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct1) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } @@ -593,7 +679,7 @@ trait StagedWasmEvaluator extends SAIOps { resetStacks() Frames.pushFrameC(locals) Frames.pushFrameS(locals) - eval(instrs, (_: Context) => forwardKont, mkont, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) + eval(instrs, _ => forwardKont, mkont, ((_: ContextTransition) => forwardKont)::Nil)(ContextTransition(Context(Nil, locals))) Frames.popFrameC(locals.size) Frames.popFrameS(locals.size) } @@ -883,6 +969,20 @@ trait StagedWasmEvaluator extends SAIOps { } } + object ReuseManager { + def isReusing: Rep[Boolean] = { + "reuse-is-reusing".reflectCtrlWith[Boolean]() + } + + def turnOnReuse(): Rep[Unit] = { + "reuse-turn-on".reflectCtrlWith[Unit]() + } + + def turnOffReuse(): Rep[Unit] = { + "reuse-turn-off".reflectCtrlWith[Unit]() + } + } + // runtime Num type implicit class StagedConcreteNumOps(num: StagedConcreteNum) { @@ -1395,6 +1495,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("ExploreTree.dump_graphviz("); shallow(f); emit(")") case Node(_, "sym-not", List(s), _) => shallow(s); emit(".negate()") + case Node(_, "reuse-is-reusing", List(), _) => + emit("Reuse.is_reusing()") case Node(_, "dummy", _, _) => emit("std::monostate()") case Node(_, "dummy-op", _, _) => emit("std::monostate()") case Node(_, "no-op", _, _) => From 354cdd259c8e52ac2c7b9966d9cf272fef12e16e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Wed, 20 Aug 2025 14:22:16 -0400 Subject: [PATCH 03/12] (WIP)split symbolic and concrete interpreter --- headers/wasm/concolic_driver.hpp | 5 +- headers/wasm/concrete_rt.hpp | 6 +- headers/wasm/symbolic_rt.hpp | 16 +- .../scala/wasm/StagedConcolicMiniWasm.scala | 181 +++++++++++++----- 4 files changed, 152 insertions(+), 56 deletions(-) diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp index 8e8ca815..799f7fb8 100644 --- a/headers/wasm/concolic_driver.hpp +++ b/headers/wasm/concolic_driver.hpp @@ -49,8 +49,6 @@ inline void ConcolicDriver::run() { auto cond = unexplored->collect_path_conds(); auto result = solver.solve(cond); if (!result.has_value()) { - // TODO: current implementation is buggy, there could be other reachable - // unexplored paths std::cout << "Found an unreachable path, marking it as unreachable..." << std::endl; unexplored->fillUnreachableNode(); @@ -59,6 +57,9 @@ inline void ConcolicDriver::run() { auto new_env = result.value(); SymEnv.update(std::move(new_env)); try { + std::cout << "Now execute the program with symbolic environment: " + << std::endl + << SymEnv.to_string() << std::endl; entrypoint(); std::cout << "Execution finished successfully with symbolic environment:" << std::endl; diff --git a/headers/wasm/concrete_rt.hpp b/headers/wasm/concrete_rt.hpp index a0961453..a9abccf2 100644 --- a/headers/wasm/concrete_rt.hpp +++ b/headers/wasm/concrete_rt.hpp @@ -72,9 +72,7 @@ class Stack_t { Num pop() { #ifdef DEBUG - if (count == 0) { - throw std::runtime_error("Stack underflow"); - } + assert(count > 0 && "Stack underflow"); #endif Num num = stack_ptr[count - 1]; count--; @@ -117,7 +115,7 @@ class Stack_t { void initialize() { // todo: remove this method - reset(); + reset(); } void reset() { count = 0; } diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index a139e494..9c14a741 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -65,6 +65,10 @@ struct SymVal { SymVal negate() const; }; +static SymVal make_symbolic(int index) { + return SymVal(std::make_shared(index)); +} + inline SymVal Concrete(Num num) { return SymVal(std::make_shared(num)); } @@ -142,6 +146,11 @@ class SymStack_t { SymVal pop() { // Pop a symbolic value from the stack + +#ifdef DEBUG + printf("[Debug] poping from stack, size of symbolic stack is: %zu\n", + stack.size()); +#endif auto ret = stack.back(); stack.pop_back(); return ret; @@ -368,12 +377,13 @@ inline NodeBox::NodeBox(NodeBox *parent) parent(parent) {} inline std::monostate NodeBox::fillIfElseNode(SymVal cond) { - // fill the current NodeBox with an ifelse branch node it's unexplored + // fill the current NodeBox with an ifelse branch node when it's unexplored if (dynamic_cast(node.get())) { node = std::make_unique(cond, this); } - assert(dynamic_cast(node.get()) != nullptr && - "Current node is not an IfElseNode, cannot fill it!"); + assert( + dynamic_cast(node.get()) != nullptr && + "Current node is not an Unexplored nor an IfElseNode, cannot fill it!"); return std::monostate(); } diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 72547377..9d547ccf 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -205,7 +205,13 @@ trait StagedWasmEvaluator extends SAIOps { mkont: Rep[MCont[Unit]], trail: Trail[Unit]) (implicit oldCT: ContextTransition): Rep[Unit] = { - if (insts.isEmpty) return kont(oldCT)(mkont) + if (insts.isEmpty) { + val (oldCtx, history, ct) = oldCT.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } + return kont(ct)(mkont) + } // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") // Predef.println(s"[DEBUG] Current context: $ctx") @@ -216,54 +222,42 @@ trait StagedWasmEvaluator extends SAIOps { case Drop => val (ty, ct1) = ct.pop() Stack.popC(ty) - Stack.popS(ty) eval(rest, kont, mkont, trail)(ct1) case WasmConst(num) => Stack.pushC(toStagedNum(num)) - Stack.pushS(toStagedSymbolicNum(num)) val ct1 = ct.push(num.tipe(module)) eval(rest, kont, mkont, trail)(ct1) case Symbolic(ty) => - Stack.popC(ty) - val id = Stack.popS(ty) - val symVal = id.makeSymbolic() + val id = Stack.popC(ty) + val symVal = id.makeSymbolic(ty) val num = SymEnv.read(symVal.s) Stack.pushC(ty.concreteTag(num)) - Stack.pushS(symVal) val ct1 = ct.pop()._2.push(ty) eval(rest, kont, mkont, trail)(ct1) case LocalGet(i) => Stack.pushC(Frames.getC(i)) - Stack.pushS(Frames.getS(i)) val ct1 = ct.push(ctx.frameTypes(i)) eval(rest, kont, mkont, trail)(ct1) case LocalSet(i) => val (ty, ct1) = ct.pop() val num = Stack.popC(ty) - val sym = Stack.popS(ty) Frames.setC(i, num) - Frames.setS(i, sym) eval(rest, kont, mkont, trail)(ct1) case LocalTee(i) => val ty = ct.peek val num = Stack.peekC(ty) - val sym = Stack.peekS(ty) Frames.setC(i, num) - Frames.setS(i, sym) eval(rest, kont, mkont, trail)(ct) case GlobalGet(i) => Stack.pushC(Globals.getC(i)) - Stack.pushS(Globals.getS(i)) val ct1 = ct.push(module.globals(i).ty.ty) eval(rest, kont, mkont, trail)(ct1) case GlobalSet(i) => val (ty, ct1) = ct.pop() val num = Stack.popC(ty) - val sym = Stack.popS(ty) module.globals(i).ty match { case GlobalType(tipe, true) => { Globals.setC(i, num) - Globals.setS(i, sym) } case _ => throw new Exception("Cannot set immutable global") } @@ -271,35 +265,27 @@ trait StagedWasmEvaluator extends SAIOps { case Store(StoreOp(align, offset, ty, None)) => val (ty1, ct1) = ct.pop() val value = Stack.popC(ty1) - val symValue = Stack.popS(ty1) val (ty2, ct2) = ct1.pop() val addr = Stack.popC(ty2) - val symAddr = Stack.popS(ty2) Memory.storeInt(addr.toInt, offset, value.toInt) eval(rest, kont, mkont, trail)(ct2) case Nop => eval(rest, kont, mkont, trail) case Load(LoadOp(align, offset, ty, None, None)) => val (ty1, ct1) = ct.pop() val addr = Stack.popC(ty1) - Stack.popS(ty1) val num = Memory.loadIntC(addr.toInt, offset) - val sym = Memory.loadIntS(addr.toInt, offset) Stack.pushC(num) - Stack.pushS(sym) val ct2 = ct1.push(ty) eval(rest, kont, mkont, trail)(ct2) case MemorySize => ??? case MemoryGrow => val (ty, ct1) = ct.pop() val delta = Stack.popC(ty) - Stack.popS(ty) val ret = Memory.grow(delta.toInt) val retNum = Values.I32V(ret) // For now, we assume that the result of memory.grow only depends on the execution path, // we can relax this by turning it return to a symbol value and mimic the memory.grow's result as input. - val retSym = "Concrete".reflectCtrlWith[SymVal](retNum) Stack.pushC(I32C(retNum)) - Stack.pushS(I32S(retSym)) val ct2 = ct1.push(NumType(I32Type)) eval(rest, kont, mkont, trail)(ct2) case MemoryFill => ??? @@ -307,42 +293,32 @@ trait StagedWasmEvaluator extends SAIOps { case Test(op) => val (ty, ct1) = ct.pop() val v = Stack.popC(ty) - val s = Stack.popS(ty) Stack.pushC(evalTestOpC(op, v)) - Stack.pushS(evalTestOpS(op, s)) val ct2 = ct1.push(v.tipe) eval(rest, kont, mkont, trail)(ct2) case Unary(op) => val (ty, ct1) = ct.pop() val v = Stack.popC(ty) - val s = Stack.popS(ty) val res = evalUnaryOpC(op, v) Stack.pushC(res) - Stack.pushS(evalUnaryOpS(op, s)) val ct2 = ct1.push(res.tipe) eval(rest, kont, mkont, trail)(ct2) case Binary(op) => val (ty2, ct1) = ct.pop() val v2 = Stack.popC(ty2) - val s2 = Stack.popS(ty2) val (ty1, ct2) = ct1.pop() val v1 = Stack.popC(ty1) - val s1 = Stack.popS(ty1) val res = evalBinOpC(op, v1, v2) Stack.pushC(res) - Stack.pushS(evalBinOpS(op, s1, s2)) val ct3 = ct2.push(res.tipe) eval(rest, kont, mkont, trail)(ct3) case Compare(op) => val (ty2, ct1) = ct.pop() val v2 = Stack.popC(ty2) - val s2 = Stack.popS(ty2) val (ty1, ct2) = ct1.pop() val v1 = Stack.popC(ty1) - val s1 = Stack.popS(ty1) val res = evalRelOpC(op, v1, v2) Stack.pushC(res) - Stack.pushS(evalRelOpS(op, s1, s2)) val ct3 = ct2.push(res.tipe) eval(rest, kont, mkont, trail)(ct3) case WasmBlock(ty, inner) => @@ -395,7 +371,6 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType val (condTy, ct1) = ct.pop() val cond = Stack.popC(condTy) - val symCond = Stack.popS(condTy) val exitSize = ct1.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the if, stackSize =", Stack.size) @@ -407,9 +382,11 @@ trait StagedWasmEvaluator extends SAIOps { }) val (oldCtx, history, ct2) = ct1.clearHistory if (!ReuseManager.isReusing) { + // when we are not reusing evalSym(history)(oldCtx) + val symCond = Stack.popS(condTy) + ExploreTree.fillWithIfElse(symCond.s) } - ExploreTree.fillWithIfElse(symCond.s) if (cond.toInt != 0) { ExploreTree.moveCursor(true) eval(thn, kont, mkont, trail)(ct2) @@ -428,13 +405,13 @@ trait StagedWasmEvaluator extends SAIOps { case BrIf(label) => val (ty, ct1) = ct.pop() val cond = Stack.popC(ty) - val symCond = Stack.popS(ty) val (oldCtx, history, ct2) = ct1.clearHistory + info(s"The br_if(${label})'s condition is ", cond.toInt) if (!ReuseManager.isReusing) { evalSym(history)(oldCtx) + val symCond = Stack.popS(ty) + ExploreTree.fillWithIfElse(symCond.s) } - info(s"The br_if(${label})'s condition is ", cond.toInt) - ExploreTree.fillWithIfElse(symCond.s) if (cond.toInt != 0) { info(s"Jump to $label") ExploreTree.moveCursor(true) @@ -457,8 +434,10 @@ trait StagedWasmEvaluator extends SAIOps { if (choices.isEmpty) trail(default)(ct2)(mkont) else { val cond = (label - toStagedNum(I32V(idx))).isZero() - val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() - ExploreTree.fillWithIfElse(condSym.s) + if (!ReuseManager.isReusing) { + val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() + ExploreTree.fillWithIfElse(condSym.s) + } if (cond.toInt != 0) { ExploreTree.moveCursor(true) trail(choices.head)(ct2)(mkont) @@ -483,9 +462,112 @@ trait StagedWasmEvaluator extends SAIOps { // concrete interpreter def evalSym(history: List[Instr]) (implicit ctx: Context): Rep[Unit] = { - // TODO: the context we currently passing is not right, we should use the - // original context where the history start - () + // val func = topFun((_: Rep[Unit]) => evalS(history.reverse)) + // func(()) + evalS(history.reverse) + } + + def evalS(insts: List[Instr]) + (implicit ctx: Context): Rep[Unit] = { + if (insts.isEmpty) return () + + // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") + // Predef.println(s"[DEBUG] Current context: $ctx") + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => + val (ty, newCtx) = ctx.pop() + Stack.popS(ty) + evalS(rest)(newCtx) + case WasmConst(num) => + Stack.pushS(toStagedSymbolicNum(num)) + val newCtx = ctx.push(num.tipe(module)) + evalS(rest)(newCtx) + case Symbolic(ty) => + val id = Stack.popS(ty) + val symVal = id.makeSymbolic(ty) + Stack.pushS(symVal) + val newCtx = ctx.pop()._2.push(ty) + evalS(rest)(newCtx) + case LocalGet(i) => + Stack.pushS(Frames.getS(i)) + val newCtx = ctx.push(ctx.frameTypes(i)) + evalS(rest)(newCtx) + case LocalSet(i) => + val (ty, newCtx) = ctx.pop() + val sym = Stack.popS(ty) + Frames.setS(i, sym) + evalS(rest)(newCtx) + case LocalTee(i) => + val ty = ctx.pop()._1 + val sym = Stack.peekS(ty) + Frames.setS(i, sym) + evalS(rest)(ctx) + case GlobalGet(i) => + Stack.pushS(Globals.getS(i)) + val newCtx = ctx.push(module.globals(i).ty.ty) + evalS(rest)(newCtx) + case GlobalSet(i) => + val (ty, newCtx) = ctx.pop() + val sym = Stack.popS(ty) + module.globals(i).ty match { + case GlobalType(tipe, true) => { + Globals.setS(i, sym) + } + case _ => throw new Exception("Cannot set immutable global") + } + evalS(rest)(newCtx) + case Nop => evalS(rest)(ctx) + case Store(StoreOp(align, offset, ty, None)) => ??? + case Load(LoadOp(align, offset, ty, None, None)) => ??? + case MemorySize => ??? + case MemoryGrow => ??? + case MemoryFill => ??? + case Unreachable => unreachable() + case Test(op) => + val (ty, newCtx1) = ctx.pop() + val s = Stack.popS(ty) + Stack.pushS(evalTestOpS(op, s)) + val newCtx2 = newCtx1.push(s.tipe) + evalS(rest)(newCtx2) + case Unary(op) => + val (ty, newCtx1) = ctx.pop() + val s = Stack.popS(ty) + val res = evalUnaryOpS(op, s) + Stack.pushS(res) + val newCtx2 = newCtx1.push(res.tipe) + evalS(rest)(newCtx2) + case Binary(op) => + val (ty2, newCtx1) = ctx.pop() + val s2 = Stack.popS(ty2) + val (ty1, newCtx2) = newCtx1.pop() + val s1 = Stack.popS(ty1) + val res = evalBinOpS(op, s1, s2) + Stack.pushS(res) + val newCtx3 = newCtx2.push(res.tipe) + evalS(rest)(newCtx3) + case Compare(op) => + val (ty2, newCtx1) = ctx.pop() + val s2 = Stack.popS(ty2) + val (ty1, newCtx2) = newCtx1.pop() + val s1 = Stack.popS(ty1) + val res = evalRelOpS(op, s1, s2) + Stack.pushS(res) + val newCtx3 = newCtx2.push(res.tipe) + evalS(rest)(newCtx3) + case WasmBlock(ty, inner) => () + case Loop(ty, inner) => () + case If(ty, thn, els) => () + case Br(label) => () + case BrIf(label) => () + case BrTable(labels, default) => () + case Return => () + case Call(f) => () + case ReturnCall(f) => () + case _ => + val todo = "todo-op".reflectCtrlWith[Unit]() + evalS(rest) + } } def forwardKont: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => mk(())) @@ -951,6 +1033,7 @@ trait StagedWasmEvaluator extends SAIOps { } def moveCursor(branch: Boolean): Rep[Unit] = { + // when moving cursor from to an unexplored node, we need to change the reuse state "tree-move-cursor".reflectCtrlWith[Unit](branch) } @@ -986,6 +1069,10 @@ trait StagedWasmEvaluator extends SAIOps { // runtime Num type implicit class StagedConcreteNumOps(num: StagedConcreteNum) { + def makeSymbolic(ty: ValueType): StagedSymbolicNum = num match { + case I32C(x) => I32S("make-symbolic-concrete".reflectCtrlWith[SymVal](num.toInt)) + } + def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num.i) def isZero(): StagedConcreteNum = num match { @@ -1142,11 +1229,9 @@ trait StagedWasmEvaluator extends SAIOps { } implicit class StagedSymbolicNumOps(num: StagedSymbolicNum) { - def makeSymbolic(): StagedSymbolicNum = num match { + def makeSymbolic(ty: ValueType): StagedSymbolicNum = num match { case I32S(x) => I32S("make-symbolic".reflectCtrlWith[SymVal](x)) - case I64S(x) => I64S("make-symbolic".reflectCtrlWith[SymVal](x)) - case F32S(x) => F32S("make-symbolic".reflectCtrlWith[SymVal](x)) - case F64S(x) => F64S("make-symbolic".reflectCtrlWith[SymVal](x)) + case _ => throw new RuntimeException("Symbol index must be an i32") } def isZero(): StagedSymbolicNum = num match { @@ -1479,6 +1564,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(num); emit(".toInt()") case Node(_, "make-symbolic", List(num), _) => shallow(num); emit(".makeSymbolic()") + case Node(_, "make-symbolic-concrete", List(num), _) => + emit("make_symbolic("); shallow(num); emit(")") case Node(_, "sym-env-read", List(sym), _) => emit("SymEnv.read("); shallow(sym); emit(")") case Node(_, "assert-true", List(cond), _) => From d82821d5334b3d9e6fa593109efd08fb7e58e98d Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Thu, 21 Aug 2025 14:02:25 -0400 Subject: [PATCH 04/12] ensure that the history has been executed when writing symstack in concrete interpreter --- src/main/scala/wasm/StagedConcolicMiniWasm.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 9d547ccf..ba651e10 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -425,7 +425,6 @@ trait StagedWasmEvaluator extends SAIOps { case BrTable(labels, default) => val (ty, ct1) = ct.pop() val label = Stack.popC(ty) - val labelSym = Stack.popS(ty) val (oldCtx, history, ct2) = ct1.clearHistory if (!ReuseManager.isReusing) { evalSym(history)(oldCtx) @@ -435,6 +434,7 @@ trait StagedWasmEvaluator extends SAIOps { else { val cond = (label - toStagedNum(I32V(idx))).isZero() if (!ReuseManager.isReusing) { + val labelSym = Stack.peekS(ty) val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() ExploreTree.fillWithIfElse(condSym.s) } @@ -449,6 +449,10 @@ trait StagedWasmEvaluator extends SAIOps { } } aux(labels, 0) + if (!ReuseManager.isReusing) { + Stack.popS(ty) + } + () case Return => trail.last(ct)(mkont) case Call(f) => evalCall(rest, kont, mkont, trail, f, false) case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true) @@ -639,17 +643,17 @@ trait StagedWasmEvaluator extends SAIOps { case Import("console", "log", _) | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") - val (ty, ct1) = ct.pop() + val (ty, ct2) = ct1.pop() val v = Stack.popC(ty) Stack.popS(ty) println(v.toInt) - eval(rest, kont, mkont, trail)(ct1) + eval(rest, kont, mkont, trail)(ct2) case Import("console", "assert", _) => - val (ty, ct1) = ct.pop() + val (ty, ct2) = ct1.pop() val v = Stack.popC(ty) Stack.popS(ty) runtimeAssert(v.toInt != 0) - eval(rest, kont, mkont, trail)(ct1) + eval(rest, kont, mkont, trail)(ct2) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } From bfd29e24f6cc47aa26d0d73fb4025d6e729a4f0e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Thu, 21 Aug 2025 14:18:34 -0400 Subject: [PATCH 05/12] remove implicit parameters --- .../scala/wasm/StagedConcolicMiniWasm.scala | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index ba651e10..74fcd0ad 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -204,7 +204,7 @@ trait StagedWasmEvaluator extends SAIOps { kont: ContextTransition => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], trail: Trail[Unit]) - (implicit oldCT: ContextTransition): Rep[Unit] = { + (oldCT: ContextTransition): Rep[Unit] = { if (insts.isEmpty) { val (oldCtx, history, ct) = oldCT.clearHistory if (!ReuseManager.isReusing) { @@ -215,7 +215,6 @@ trait StagedWasmEvaluator extends SAIOps { // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") // Predef.println(s"[DEBUG] Current context: $ctx") - implicit val ctx = oldCT.endCtx val (inst, rest) = (insts.head, insts.tail) val ct = oldCT.log(inst) inst match { @@ -235,8 +234,8 @@ trait StagedWasmEvaluator extends SAIOps { val ct1 = ct.pop()._2.push(ty) eval(rest, kont, mkont, trail)(ct1) case LocalGet(i) => - Stack.pushC(Frames.getC(i)) - val ct1 = ct.push(ctx.frameTypes(i)) + Stack.pushC(Frames.getC(i)(ct.endCtx)) + val ct1 = ct.push(ct.endCtx.frameTypes(i)) eval(rest, kont, mkont, trail)(ct1) case LocalSet(i) => val (ty, ct1) = ct.pop() @@ -269,7 +268,7 @@ trait StagedWasmEvaluator extends SAIOps { val addr = Stack.popC(ty2) Memory.storeInt(addr.toInt, offset, value.toInt) eval(rest, kont, mkont, trail)(ct2) - case Nop => eval(rest, kont, mkont, trail) + case Nop => eval(rest, kont, mkont, trail)(ct) case Load(LoadOp(align, offset, ty, None, None)) => val (ty1, ct1) = ct.pop() val addr = Stack.popC(ty1) @@ -325,7 +324,7 @@ trait StagedWasmEvaluator extends SAIOps { // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val exitSize = ct.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the block, stackSize =", Stack.size) @@ -343,7 +342,7 @@ trait StagedWasmEvaluator extends SAIOps { eval(inner, restK _, mkont, restK _ :: trail)(ct1) case Loop(ty, inner) => val funcTy = ty.funcType - val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val exitSize = ct.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the loop, stackSize =", Stack.size) @@ -353,7 +352,7 @@ trait StagedWasmEvaluator extends SAIOps { val ct1 = ct.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(ct1) }) - val enterSize = ctx.stackTypes.size + val enterSize = ct.endCtx.stackTypes.size def loop(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - enterSize @@ -454,25 +453,25 @@ trait StagedWasmEvaluator extends SAIOps { } () case Return => trail.last(ct)(mkont) - case Call(f) => evalCall(rest, kont, mkont, trail, f, false) - case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true) + case Call(f) => evalCall(rest, kont, mkont, trail, f, false)(ct) + case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true)(ct) case _ => val todo = "todo-op".reflectCtrlWith[Unit]() - eval(rest, kont, mkont, trail) + eval(rest, kont, mkont, trail)(ct) } } // call the symbolic interpreter to evaluate the history that just executed by // concrete interpreter def evalSym(history: List[Instr]) - (implicit ctx: Context): Rep[Unit] = { + (ctx: Context): Rep[Unit] = { // val func = topFun((_: Rep[Unit]) => evalS(history.reverse)) // func(()) - evalS(history.reverse) + evalS(history.reverse)(ctx) } def evalS(insts: List[Instr]) - (implicit ctx: Context): Rep[Unit] = { + (ctx: Context): Rep[Unit] = { if (insts.isEmpty) return () // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") @@ -494,7 +493,7 @@ trait StagedWasmEvaluator extends SAIOps { val newCtx = ctx.pop()._2.push(ty) evalS(rest)(newCtx) case LocalGet(i) => - Stack.pushS(Frames.getS(i)) + Stack.pushS(Frames.getS(i)(ctx)) val newCtx = ctx.push(ctx.frameTypes(i)) evalS(rest)(newCtx) case LocalSet(i) => @@ -570,7 +569,7 @@ trait StagedWasmEvaluator extends SAIOps { case ReturnCall(f) => () case _ => val todo = "todo-op".reflectCtrlWith[Unit]() - evalS(rest) + evalS(rest)(ctx) } } From 8bcc08f7cf9cf76aa10bed8565c8e959c6c048bc Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Thu, 21 Aug 2025 14:23:15 -0400 Subject: [PATCH 06/12] shift when not reusing --- src/main/scala/wasm/StagedConcolicMiniWasm.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 74fcd0ad..9956554b 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -330,7 +330,9 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the block, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - Stack.shiftS(offset, funcTy.out.size) + if (!ReuseManager.isReusing) { + Stack.shiftS(offset, funcTy.out.size) + } val ct1 = ct.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(ct1) }) @@ -348,7 +350,9 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - Stack.shiftS(offset, funcTy.out.size) + if (!ReuseManager.isReusing) { + Stack.shiftS(offset, funcTy.out.size) + } val ct1 = ct.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(ct1) }) @@ -357,7 +361,9 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Entered the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - enterSize Stack.shiftC(offset, funcTy.inps.size) - Stack.shiftS(offset, funcTy.inps.size) + if (!ReuseManager.isReusing) { + Stack.shiftS(offset, funcTy.inps.size) + } val ct1 = ct.shift(offset, funcTy.inps.size) eval(inner, restK _, mk, loop _ :: trail)(ct1) }) @@ -375,7 +381,9 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the if, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - Stack.shiftS(offset, funcTy.out.size) + if (!ReuseManager.isReusing) { + Stack.shiftS(offset, funcTy.out.size) + } val ct1 = ct.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(ct1) }) From 866fcd57163dc7a2596c746592d934c18181831e Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 22 Aug 2025 15:00:32 -0400 Subject: [PATCH 07/12] a dedicated type for empty ContextTransition & remove unuseful parameter --- .../scala/wasm/StagedConcolicMiniWasm.scala | 72 ++++++++++++------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 9956554b..527cf4dc 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -149,8 +149,8 @@ trait StagedWasmEvaluator extends SAIOps { this.copy(history = instr :: history) } - def clearHistory: (Context, List[Instr], ContextTransition) = { - (startCtx, history, this.copy(startCtx = endCtx, history = Nil)) + def clearHistory: (Context, List[Instr], CleanCT) = { + (startCtx, history, CleanCT(endCtx)) } def push(ty: ValueType): ContextTransition = { @@ -175,17 +175,20 @@ trait StagedWasmEvaluator extends SAIOps { } } - object ContextTransition { - - def apply(startCtx: Context) = { - new ContextTransition(startCtx, Nil, startCtx) - } + case class CleanCT(ctx: Context) { + def startCtx: Context = ctx + def history: List[Instr] = Nil + def endCtx: Context = ctx + } + // we can treat every CleanCT as a ContextTransition + implicit def toContextCT(ct: CleanCT): ContextTransition = { + ContextTransition(ct.ctx, Nil, ct.ctx) } type MCont[A] = Unit => A type Cont[A] = (MCont[A]) => A - type Trail[A] = List[ContextTransition => Rep[Cont[A]]] + type Trail[A] = List[CleanCT => Rep[Cont[A]]] // a cache storing the compiled code for each function, to reduce re-compilation val compileCache = new HashMap[Int, Rep[(MCont[Unit]) => Unit]] @@ -201,7 +204,7 @@ trait StagedWasmEvaluator extends SAIOps { } def eval(insts: List[Instr], - kont: ContextTransition => Rep[Cont[Unit]], + kont: CleanCT => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], trail: Trail[Unit]) (oldCT: ContextTransition): Rep[Unit] = { @@ -326,7 +329,7 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType val exitSize = ct.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def restK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the block, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) @@ -346,7 +349,7 @@ trait StagedWasmEvaluator extends SAIOps { val funcTy = ty.funcType val exitSize = ct.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def restK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) @@ -357,7 +360,7 @@ trait StagedWasmEvaluator extends SAIOps { eval(rest, kont, mk, trail)(ct1) }) val enterSize = ct.endCtx.stackTypes.size - def loop(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def loop(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - enterSize Stack.shiftC(offset, funcTy.inps.size) @@ -377,7 +380,7 @@ trait StagedWasmEvaluator extends SAIOps { val (condTy, ct1) = ct.pop() val cond = Stack.popC(condTy) val exitSize = ct1.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size - def restK(ct: ContextTransition): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def restK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the if, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) @@ -460,7 +463,13 @@ trait StagedWasmEvaluator extends SAIOps { Stack.popS(ty) } () - case Return => trail.last(ct)(mkont) + case Return => + // return instruction is also stack-polymorphic + val (oldCtx, history, ct2) = ct.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } + trail.last(ct2)(mkont) case Call(f) => evalCall(rest, kont, mkont, trail, f, false)(ct) case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true)(ct) case _ => @@ -469,6 +478,14 @@ trait StagedWasmEvaluator extends SAIOps { } } + def replayAndClearHistory(ct: ContextTransition): ContextTransition = { + val (oldCtx, history, ct1) = ct.clearHistory + if (!ReuseManager.isReusing) { + evalSym(history)(oldCtx) + } + ct1 + } + // call the symbolic interpreter to evaluate the history that just executed by // concrete interpreter def evalSym(history: List[Instr]) @@ -585,7 +602,7 @@ trait StagedWasmEvaluator extends SAIOps { def evalCall(rest: List[Instr], - kont: ContextTransition => Rep[Cont[Unit]], + kont: CleanCT => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], trail: Trail[Unit], funcIndex: Int, @@ -605,7 +622,7 @@ trait StagedWasmEvaluator extends SAIOps { val callee = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the function at $funcIndex, stackSize =", Stack.size) // we can do some check here to ensure the function returns correct size of stack - eval(body, (_: ContextTransition) => forwardKont, mk, ((_: ContextTransition) => forwardKont)::Nil)(ContextTransition(Context(Nil, locals))) + eval(body, (_: CleanCT) => forwardKont, mk, ((_: CleanCT) => forwardKont)::Nil)(CleanCT(Context(Nil, locals))) }) compileCache(funcIndex) = callee callee @@ -633,7 +650,7 @@ trait StagedWasmEvaluator extends SAIOps { Frames.popFrameC(locals.size) Frames.popFrameS(locals.size) val newCtx = ct2.endCtx.copy(stackTypes = ty.out.reverse ++ ct2.endCtx.stackTypes) - eval(rest, kont, mk, trail)(ContextTransition(newCtx)) + eval(rest, kont, mk, trail)(CleanCT(newCtx)) }) val dummy = makeDummy val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => { @@ -658,7 +675,8 @@ trait StagedWasmEvaluator extends SAIOps { case Import("console", "assert", _) => val (ty, ct2) = ct1.pop() val v = Stack.popC(ty) - Stack.popS(ty) + // TODO: We should also add s into exploration tree + val s = Stack.popS(ty) runtimeAssert(v.toInt != 0) eval(rest, kont, mkont, trail)(ct2) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") @@ -772,12 +790,12 @@ trait StagedWasmEvaluator extends SAIOps { resetStacks() Frames.pushFrameC(locals) Frames.pushFrameS(locals) - eval(instrs, _ => forwardKont, mkont, ((_: ContextTransition) => forwardKont)::Nil)(ContextTransition(Context(Nil, locals))) + eval(instrs, _ => forwardKont, mkont, ((_: CleanCT) => forwardKont)::Nil)(CleanCT(Context(Nil, locals))) Frames.popFrameC(locals.size) Frames.popFrameS(locals.size) } - def evalTop(main: Option[String], printRes: Boolean, dumpTree: Option[String]): Rep[Unit] = { + def evalTop(main: Option[String], printRes: Boolean): Rep[Unit] = { val haltK: Rep[Unit] => Rep[Unit] = (_) => { info("Exiting the program...") if (printRes) { @@ -1486,6 +1504,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.push("); shallow(value); emit(")") case Node(_, "stack-shift", List(offset, size), _) => emit("Stack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") + case Node(_, "sym-stack-shift", List(offset, size), _) => + emit("SymStack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") case Node(_, "sym-stack-pop", _, _) => @@ -1652,12 +1672,12 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv object WasmToCppCompiler { case class GeneratedCpp(source: String, headerFolders: List[String]) - def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean, dumpTree: Option[String]): GeneratedCpp = { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean): GeneratedCpp = { println(s"Now compiling wasm module with entry function $main") val driver = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { - evalTop(main, printRes, dumpTree) + evalTop(main, printRes) } } GeneratedCpp(driver.code, driver.codegen.includePaths.toList) @@ -1668,8 +1688,8 @@ object WasmToCppCompiler { outputCpp: String, outputExe: String, printRes: Boolean, - dumpTree: Option[String]): Unit = { - val generated = compile(moduleInst, main, printRes, dumpTree) + macros: String*): Unit = { + val generated = compile(moduleInst, main, printRes) val code = generated.source val writer = new java.io.PrintWriter(new java.io.File(outputCpp)) @@ -1680,7 +1700,9 @@ object WasmToCppCompiler { } import sys.process._ - val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g -l z3 " + generated.headerFolders.map(f => s"-I$f").mkString(" ") + val includeFlags = generated.headerFolders.map(f => s"-I$f").mkString(" ") + val macroFlags = macros.map(m => s"-D$m").mkString(" ") + val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g -l z3 " + includeFlags + " " + macroFlags if (command.! != 0) { throw new RuntimeException(s"Compilation failed for $outputCpp") } From 40db49a3b31065c30f91d1c7bc4ae81528929552 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 22 Aug 2025 16:50:38 -0400 Subject: [PATCH 08/12] return instruction is stack polymorphic --- benchmarks/wasm/staged/return_poly.wat | 19 ++++++++++++++ headers/wasm/concolic_driver.hpp | 24 ++++++++--------- headers/wasm/smt_solver.hpp | 8 +++--- headers/wasm/symbolic_rt.hpp | 10 +++++++ headers/wasm/utils.hpp | 24 +++++++++++++++++ .../scala/wasm/StagedConcolicMiniWasm.scala | 18 ++++++++----- .../genwasym/TestStagedConcolicEval.scala | 26 +++++++++++++++---- 7 files changed, 101 insertions(+), 28 deletions(-) create mode 100644 benchmarks/wasm/staged/return_poly.wat diff --git a/benchmarks/wasm/staged/return_poly.wat b/benchmarks/wasm/staged/return_poly.wat new file mode 100644 index 00000000..1bab5ef0 --- /dev/null +++ b/benchmarks/wasm/staged/return_poly.wat @@ -0,0 +1,19 @@ +(module + (type (;0;) (func)) + (type (;1;) (func (result i32))) + ;; TODO: It seems that our parser or preprocessor has some problems; the result type of the last line doesn't take effect + (func (result i32) + block + i32.const 21 + i32.const 35 + i32.const 42 + return + end + i32.const 100 + ) + (func (type 0) + call 0 + ;; unreachable + ) + (export "$real_main" (func 1)) +) diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp index 799f7fb8..427a0de8 100644 --- a/headers/wasm/concolic_driver.hpp +++ b/headers/wasm/concolic_driver.hpp @@ -4,6 +4,7 @@ #include "concrete_rt.hpp" #include "smt_solver.hpp" #include "symbolic_rt.hpp" +#include "utils.hpp" #include #include #include @@ -43,34 +44,33 @@ inline void ConcolicDriver::run() { auto unexplored = ExploreTree.pick_unexplored(); if (!unexplored) { - std::cout << "No unexplored nodes found, exiting..." << std::endl; + GENSYM_INFO("No unexplored nodes found, exiting..."); return; } auto cond = unexplored->collect_path_conds(); auto result = solver.solve(cond); if (!result.has_value()) { - std::cout << "Found an unreachable path, marking it as unreachable..." - << std::endl; + GENSYM_INFO("Found an unreachable path, marking it as unreachable..."); unexplored->fillUnreachableNode(); continue; } auto new_env = result.value(); SymEnv.update(std::move(new_env)); try { - std::cout << "Now execute the program with symbolic environment: " - << std::endl - << SymEnv.to_string() << std::endl; + GENSYM_INFO("Now execute the program with symbolic environment: "); + GENSYM_INFO(SymEnv.to_string()); entrypoint(); - std::cout << "Execution finished successfully with symbolic environment:" - << std::endl; - std::cout << SymEnv.to_string() << std::endl; + GENSYM_INFO("Execution finished successfully with symbolic environment:"); + GENSYM_INFO(SymEnv.to_string()); } catch (...) { ExploreTree.fillFailedNode(); - std::cout << "Caught runtime error with symbolic environment:" - << std::endl; - std::cout << SymEnv.to_string() << std::endl; + GENSYM_INFO("Caught runtime error with symbolic environment:"); + GENSYM_INFO(SymEnv.to_string()); return; } +#if defined(RUN_ONCE) + return; +#endif } } diff --git a/headers/wasm/smt_solver.hpp b/headers/wasm/smt_solver.hpp index f2450905..bc8cc9f9 100644 --- a/headers/wasm/smt_solver.hpp +++ b/headers/wasm/smt_solver.hpp @@ -3,6 +3,7 @@ #include "concrete_rt.hpp" #include "symbolic_rt.hpp" +#include "utils.hpp" #include "z3++.h" #include #include @@ -35,8 +36,8 @@ class Solver { std::vector result; // Reference: // https://github.com/Z3Prover/z3/blob/master/examples/c%2B%2B/example.cpp#L59 - - std::cout << "Solved Z3 model" << std::endl << model << std::endl; + GENSYM_INFO("Solved Z3 model"); + GENSYM_INFO(model); for (unsigned i = 0; i < model.size(); ++i) { z3::func_decl var = model[i]; z3::expr value = model.get_const_interp(var); @@ -48,8 +49,7 @@ class Solver { } result[id] = Num(value.get_numeral_int64()); } else { - std::cout << "Find a variable that is not created by GenSym: " << name - << std::endl; + GENSYM_INFO("Find a variable that is not created by GenSym: " + name); } } return result; diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 9c14a741..4ccddb5a 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -3,6 +3,7 @@ #include "concrete_rt.hpp" #include +#include #include #include #include @@ -158,6 +159,15 @@ class SymStack_t { SymVal peek() { return stack.back(); } + std::monostate shift(int32_t offset, int32_t size) { + auto n = stack.size(); + for (size_t i = n - size; i < n; ++i) { + stack[i - offset] = stack[i]; + } + stack.resize(n - offset); + return std::monostate(); + } + void reset() { // Reset the symbolic stack stack.clear(); diff --git a/headers/wasm/utils.hpp b/headers/wasm/utils.hpp index 8a86ac98..ba57a1df 100644 --- a/headers/wasm/utils.hpp +++ b/headers/wasm/utils.hpp @@ -12,4 +12,28 @@ } while (0) #endif +#ifndef NO_DBG +#define GENSYM_DBG(obj) \ + do { \ + std::cout << "LOG: " << obj << " (" << __FILE__ << ":" \ + << std::to_string(__LINE__) << ")" << std::endl; \ + } while (0) +#else +#define GENSYM_LOG(message) \ + do { \ + } while (0) +#endif + +#ifndef NO_INFO +#define GENSYM_INFO(obj) \ + do { \ + std::cout << obj << std::endl; \ + } while (0) +#else +#define GENSYM_INFO(message) \ + do { \ + } while (0) + +#endif + #endif // UTILS_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 527cf4dc..1a050d74 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -175,11 +175,7 @@ trait StagedWasmEvaluator extends SAIOps { } } - case class CleanCT(ctx: Context) { - def startCtx: Context = ctx - def history: List[Instr] = Nil - def endCtx: Context = ctx - } + case class CleanCT(ctx: Context) // we can treat every CleanCT as a ContextTransition implicit def toContextCT(ct: CleanCT): ContextTransition = { @@ -621,8 +617,15 @@ trait StagedWasmEvaluator extends SAIOps { } else { val callee = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the function at $funcIndex, stackSize =", Stack.size) - // we can do some check here to ensure the function returns correct size of stack - eval(body, (_: CleanCT) => forwardKont, mk, ((_: CleanCT) => forwardKont)::Nil)(CleanCT(Context(Nil, locals))) + // the return instruction is also stack polymorphic + def retK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) + val offset = ct.ctx.stackTypes.size - ty.out.size + Stack.shiftC(offset, ty.out.size) + Stack.shiftS(offset, ty.out.size) + mk(()) + }) + eval(body, retK _, mk, retK _::Nil)(CleanCT(Context(Nil, locals))) }) compileCache(funcIndex) = callee callee @@ -631,6 +634,7 @@ trait StagedWasmEvaluator extends SAIOps { val ct2 = ct1.take(ty.inps.size) val argsC = Stack.takeC(ty.inps) val argsS = Stack.takeS(ty.inps) + val exitSize = ty.out.size + ct2.endCtx.stackTypes.size if (isTail) { // when tail call, return to the caller's return continuation Frames.popFrameC(ct2.endCtx.frameTypes.size) diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala index a65d0eda..77ce1ec0 100644 --- a/src/test/scala/genwasym/TestStagedConcolicEval.scala +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -9,12 +9,24 @@ import gensym.wasm.parser._ import gensym.wasm.stagedconcolicminiwasm._ class TestStagedConcolicEval extends FunSuite { - def testFileToCpp(filename: String, main: Option[String] = None, expect: Option[List[Float]]=None) = { + def testFileConcolicCpp(filename: String, main: Option[String] = None) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val cppFile = s"$filename.cpp" val exe = s"$cppFile.exe" val exploreTreeFile = s"$filename.tree.dot" - WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, Some(exploreTreeFile)) + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true) + + import sys.process._ + val result = Process(s"./$exe", None, "TREE_FILE" -> exploreTreeFile).!! + println(result) + } + + // only test concrete execution and its result + def testFileConcreteCpp(filename: String, main: Option[String] = None, expect: Option[List[Float]] = None) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val cppFile = s"$filename.cpp" + val exe = s"$cppFile.exe" + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, "NO_INFO") import sys.process._ val result = s"./$exe".!! @@ -30,13 +42,17 @@ class TestStagedConcolicEval extends FunSuite { }) } - test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main")) } + test("ack-cpp") { testFileConcolicCpp("./benchmarks/wasm/ack.wat", Some("real_main")) } test("bug-finding") { - testFileToCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) + testFileConcolicCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) } test("brtable-bug-finding") { - testFileToCpp("./benchmarks/wasm/staged/brtable_concolic.wat") + testFileConcolicCpp("./benchmarks/wasm/staged/brtable_concolic.wat") + } + + test("return - concrete") { + testFileConcreteCpp("./benchmarks/wasm/staged/return_poly.wat", Some("$real_main"), expect=Some(List(42))) } } From 3351aba2d8293211d6f6f4b84cca6553b8314e54 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Fri, 22 Aug 2025 17:57:14 -0400 Subject: [PATCH 09/12] migrate test cases of StagedMiniwasm & fix If's impl --- .../scala/wasm/StagedConcolicMiniWasm.scala | 2 +- .../genwasym/TestStagedConcolicEval.scala | 48 ++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 1a050d74..a370e71f 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -395,7 +395,7 @@ trait StagedWasmEvaluator extends SAIOps { } if (cond.toInt != 0) { ExploreTree.moveCursor(true) - eval(thn, kont, mkont, trail)(ct2) + eval(thn, restK _, mkont, restK _ :: trail)(ct2) } else { ExploreTree.moveCursor(false) eval(els, restK _, mkont, restK _ :: trail)(ct2) diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala index 77ce1ec0..48c24634 100644 --- a/src/test/scala/genwasym/TestStagedConcolicEval.scala +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -26,7 +26,7 @@ class TestStagedConcolicEval extends FunSuite { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val cppFile = s"$filename.cpp" val exe = s"$cppFile.exe" - WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, "NO_INFO") + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, "NO_INFO", "RUN_ONCE") import sys.process._ val result = s"./$exe".!! @@ -52,7 +52,51 @@ class TestStagedConcolicEval extends FunSuite { testFileConcolicCpp("./benchmarks/wasm/staged/brtable_concolic.wat") } - test("return - concrete") { + test("return-poly - concrete") { testFileConcreteCpp("./benchmarks/wasm/staged/return_poly.wat", Some("$real_main"), expect=Some(List(42))) } + test("ack-cpp - concrete") { testFileConcreteCpp("./benchmarks/wasm/ack.wat", Some("real_main"), expect=Some(List(7))) } + test("power - concrete") { testFileConcreteCpp("./benchmarks/wasm/pow.wat", Some("real_main"), expect=Some(List(1024))) } + test("start - concrete") { testFileConcreteCpp("./benchmarks/wasm/start.wat") } + test("fact - concrete") { testFileConcreteCpp("./benchmarks/wasm/fact.wat", None, expect=Some(List(120))) } + // TODO: Waiting more symbolic operators' implementations + // test("loop - concrete") { testFileConcreteCpp("./benchmarks/wasm/loop.wat", None, expect=Some(List(10))) } + test("even-odd - concrete") { testFileConcreteCpp("./benchmarks/wasm/even_odd.wat", None, expect=Some(List(1))) } + // TODO: Waiting symbolic memory's implementations + // test("load - concrete") { testFileConcreteCpp("./benchmarks/wasm/load.wat", None, expect=Some(List(1))) } + // test("btree - concrete") { testFileConcreteCpp("./benchmarks/wasm/btree/2o1u-unlabeled.wat") } + test("fib - concrete") { testFileConcreteCpp("./benchmarks/wasm/fib.wat", None, expect=Some(List(144))) } + test("tribonacci - concrete") { testFileConcreteCpp("./benchmarks/wasm/tribonacci.wat", None, expect=Some(List(504))) } + + // test("return - concrete") { + // Since all of the thrown exceptions had been captured in concolic driver, this test is not valid anymore + // intercept[java.lang.RuntimeException] { + // testFileConcreteCpp("./benchmarks/wasm/return.wat", Some("$real_main")) + // } + // } + + test("return_call - concrete") { + testFileConcreteCpp("./benchmarks/wasm/sum.wat", Some("sum10"), expect=Some(List(55))) + } + + test("block input - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("real_main"), expect=Some(List(9))) + } + test("loop block input - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("test_loop_input"), expect=Some(List(55))) + } + test("if block input - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("test_if_input"), expect=Some(List(25))) + } + test("block input - poly br - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("test_poly_br"), expect=Some(List(0))) + } + test("loop block - poly br - concrete") { + testFileConcreteCpp("./benchmarks/wasm/loop_poly.wat", None, expect=Some(List(2, 1))) + } + + test("brtable-cpp - concrete") { + testFileConcreteCpp("./benchmarks/wasm/staged/brtable.wat") + } + } From e92ccde6aed044f2e426dab2fe13b9828fe14f62 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Mon, 25 Aug 2025 21:27:39 -0400 Subject: [PATCH 10/12] rename --- .../scala/wasm/StagedConcolicMiniWasm.scala | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index a370e71f..c074fefd 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -199,6 +199,8 @@ trait StagedWasmEvaluator extends SAIOps { }) } + def isSymStackInUse: Rep[Boolean] = !ReuseManager.isReusing + def eval(insts: List[Instr], kont: CleanCT => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], @@ -206,7 +208,7 @@ trait StagedWasmEvaluator extends SAIOps { (oldCT: ContextTransition): Rep[Unit] = { if (insts.isEmpty) { val (oldCtx, history, ct) = oldCT.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } return kont(ct)(mkont) @@ -329,7 +331,7 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the block, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { Stack.shiftS(offset, funcTy.out.size) } val ct1 = ct.shift(offset, funcTy.out.size) @@ -337,7 +339,7 @@ trait StagedWasmEvaluator extends SAIOps { }) // TODO: extract this into a function val (oldCtx, history, ct1) = ct.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } eval(inner, restK _, mkont, restK _ :: trail)(ct1) @@ -349,7 +351,7 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { Stack.shiftS(offset, funcTy.out.size) } val ct1 = ct.shift(offset, funcTy.out.size) @@ -360,14 +362,14 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Entered the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - enterSize Stack.shiftC(offset, funcTy.inps.size) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { Stack.shiftS(offset, funcTy.inps.size) } val ct1 = ct.shift(offset, funcTy.inps.size) eval(inner, restK _, mk, loop _ :: trail)(ct1) }) val (oldCtx, history, ct1) = ct.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } loop(ct1)(mkont) @@ -380,14 +382,14 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the if, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { Stack.shiftS(offset, funcTy.out.size) } val ct1 = ct.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(ct1) }) val (oldCtx, history, ct2) = ct1.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { // when we are not reusing evalSym(history)(oldCtx) val symCond = Stack.popS(condTy) @@ -404,7 +406,7 @@ trait StagedWasmEvaluator extends SAIOps { case Br(label) => info(s"Jump to $label") val (oldCtx, history, ct1) = ct.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } trail(label)(ct1)(mkont) @@ -413,7 +415,7 @@ trait StagedWasmEvaluator extends SAIOps { val cond = Stack.popC(ty) val (oldCtx, history, ct2) = ct1.clearHistory info(s"The br_if(${label})'s condition is ", cond.toInt) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) val symCond = Stack.popS(ty) ExploreTree.fillWithIfElse(symCond.s) @@ -432,14 +434,14 @@ trait StagedWasmEvaluator extends SAIOps { val (ty, ct1) = ct.pop() val label = Stack.popC(ty) val (oldCtx, history, ct2) = ct1.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } def aux(choices: List[Int], idx: Int): Rep[Unit] = { if (choices.isEmpty) trail(default)(ct2)(mkont) else { val cond = (label - toStagedNum(I32V(idx))).isZero() - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { val labelSym = Stack.peekS(ty) val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() ExploreTree.fillWithIfElse(condSym.s) @@ -455,14 +457,14 @@ trait StagedWasmEvaluator extends SAIOps { } } aux(labels, 0) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { Stack.popS(ty) } () case Return => // return instruction is also stack-polymorphic val (oldCtx, history, ct2) = ct.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } trail.last(ct2)(mkont) @@ -476,7 +478,7 @@ trait StagedWasmEvaluator extends SAIOps { def replayAndClearHistory(ct: ContextTransition): ContextTransition = { val (oldCtx, history, ct1) = ct.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } ct1 @@ -605,7 +607,7 @@ trait StagedWasmEvaluator extends SAIOps { isTail: Boolean) (implicit ct: ContextTransition): Rep[Unit] = { val (oldCtx, history, ct1) = ct.clearHistory - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { evalSym(history)(oldCtx) } module.funcs(funcIndex) match { @@ -640,7 +642,7 @@ trait StagedWasmEvaluator extends SAIOps { Frames.popFrameC(ct2.endCtx.frameTypes.size) Frames.pushFrameC(locals) Frames.putAllC(argsC) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { Frames.popFrameS(ct2.endCtx.frameTypes.size) Frames.pushFrameS(locals) Frames.putAllS(argsS) @@ -662,7 +664,7 @@ trait StagedWasmEvaluator extends SAIOps { }, dummy) Frames.pushFrameC(locals) Frames.putAllC(argsC) - if (!ReuseManager.isReusing) { + if (isSymStackInUse) { Frames.pushFrameS(locals) Frames.putAllS(argsS) } From 4a4c82e161e50e40863ed76eceb796faa592f978 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 26 Aug 2025 17:05:20 -0400 Subject: [PATCH 11/12] reuse symbolic states --- headers/wasm/controls.hpp | 5 + headers/wasm/symbolic_rt.hpp | 112 ++++++++++++++++-- .../scala/wasm/StagedConcolicMiniWasm.scala | 75 +++++++----- 3 files changed, 150 insertions(+), 42 deletions(-) create mode 100644 headers/wasm/controls.hpp diff --git a/headers/wasm/controls.hpp b/headers/wasm/controls.hpp new file mode 100644 index 00000000..16fa5136 --- /dev/null +++ b/headers/wasm/controls.hpp @@ -0,0 +1,5 @@ +#include +#include + +using MCont_t = std::function; +using Cont_t = std::function; \ No newline at end of file diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 4ccddb5a..f5dee431 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -2,6 +2,7 @@ #define WASM_SYMBOLIC_RT_HPP #include "concrete_rt.hpp" +#include "controls.hpp" #include #include #include @@ -138,6 +139,8 @@ inline SymVal SymVal::makeSymbolic() const { } } +class Snapshot_t; + class SymStack_t { public: void push(SymVal val) { @@ -173,6 +176,11 @@ class SymStack_t { stack.clear(); } + void reuse(Snapshot_t snapshot); + + size_t size() const { return stack.size(); } + +private: std::vector stack; }; @@ -206,9 +214,46 @@ class SymFrames_t { stack.clear(); } + void reuse(Snapshot_t snapshot); + std::vector stack; }; +// A snapshot of the symbolic state and execution context (control) +class Snapshot_t { +public: + explicit Snapshot_t(); + + SymStack_t get_stack() const { return stack; } + SymFrames_t get_frames() const { return frames; } + +private: + SymStack_t stack; + SymFrames_t frames; +}; + +inline void SymStack_t::reuse(Snapshot_t snapshot) { +// Reusing the symbolic stack from the snapshot +#ifdef DEBUG + std::cout << "Reusing symbolic state from snapshot" << std::endl; + std::cout << "Old stack size = " << stack.size() << std::endl; + std::cout << "New stack size = " << snapshot.get_stack().stack.size() + << std::endl; +#endif + stack = snapshot.get_stack().stack; +} + +inline void SymFrames_t::reuse(Snapshot_t snapshot) { +// Reusing the symbolic frames from the snapshot +#ifdef DEBUG + std::cout << "Reusing symbolic state from snapshot" << std::endl; + std::cout << "Old frame size = " << stack.size() << std::endl; + std::cout << "New frame size = " << snapshot.get_frames().stack.size() + << std::endl; +#endif + stack = snapshot.get_frames().stack; +} + static SymFrames_t SymFrames; struct Node; @@ -218,7 +263,7 @@ struct NodeBox { std::unique_ptr node; NodeBox *parent; - std::monostate fillIfElseNode(SymVal cond); + std::monostate fillIfElseNode(SymVal cond, const Snapshot_t &snapshot); std::monostate fillFinishedNode(); std::monostate fillFailedNode(); std::monostate fillUnreachableNode(); @@ -270,8 +315,9 @@ struct IfElseNode : Node { SymVal cond; std::unique_ptr true_branch; std::unique_ptr false_branch; + Snapshot_t snapshot; - IfElseNode(SymVal cond, NodeBox *parent) + IfElseNode(SymVal cond, NodeBox *parent, Snapshot_t snapshot) : cond(cond), true_branch(std::make_unique(parent)), false_branch(std::make_unique(parent)) {} @@ -386,10 +432,11 @@ inline NodeBox::NodeBox(NodeBox *parent) /* TODO: avoid allocation of unexplored node */ parent(parent) {} -inline std::monostate NodeBox::fillIfElseNode(SymVal cond) { +inline std::monostate NodeBox::fillIfElseNode(SymVal cond, + const Snapshot_t &snapshot) { // fill the current NodeBox with an ifelse branch node when it's unexplored if (dynamic_cast(node.get())) { - node = std::make_unique(cond, this); + node = std::make_unique(cond, this, snapshot); } assert( dynamic_cast(node.get()) != nullptr && @@ -447,6 +494,34 @@ inline std::vector NodeBox::collect_path_conds() { return result; } +class Reuse_t { +public: + Reuse_t() : reuse_flag(false) {} + bool is_reusing() { +#ifdef NO_REUSE + return false; +#endif + return reuse_flag; + } + + void turn_on_reusing() { reuse_flag = true; } + + void turn_off_reusing() { reuse_flag = false; } + +private: + bool reuse_flag; +}; + +static Reuse_t Reuse; + +inline Snapshot_t::Snapshot_t() : stack(SymStack), frames(SymFrames) { +#ifdef DEBUG + std::cout << "Creating snapshot of size " << stack.size() << std::endl; +#endif + assert(!Reuse.is_reusing() && + "Creating snapshot while reusing the symbolic stack"); +} + class ExploreTree_t { public: explicit ExploreTree_t() @@ -455,14 +530,19 @@ class ExploreTree_t { void reset_cursor() { // Reset the cursor to the root of the tree cursor = root.get(); + Reuse.turn_off_reusing(); + // if root cursor is a branch node, then we can reuse the snapshot inside it + if (auto ite = dynamic_cast(cursor->node.get())) { + Reuse.turn_on_reusing(); + } } std::monostate fillFinishedNode() { return cursor->fillFinishedNode(); } std::monostate fillFailedNode() { return cursor->fillFailedNode(); } - std::monostate fillIfElseNode(SymVal cond) { - return cursor->fillIfElseNode(cond); + std::monostate fillIfElseNode(SymVal cond, const Snapshot_t &snapshot) { + return cursor->fillIfElseNode(cond, snapshot); } std::monostate moveCursor(bool branch) { @@ -476,6 +556,19 @@ class ExploreTree_t { } else { cursor = if_else_node->false_branch.get(); } + + if (dynamic_cast(cursor->node.get())) { + // If we meet an unexplored node, resume the snapshot before and keep + // going + +#ifdef DEBUG + std::cout << "Resuming snapshot for unexplored node" << std::endl; +#endif + if (Reuse.is_reusing()) { + Reuse.turn_off_reusing(); + SymStack.reuse(if_else_node->snapshot); + } + } return std::monostate(); } @@ -571,11 +664,4 @@ class SymEnv_t { static SymEnv_t SymEnv; -class Reuse_t { -public: - bool is_reusing() { return false; } -}; - -static Reuse_t Reuse; - #endif // WASM_SYMBOLIC_RT_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index c074fefd..769a0b85 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -199,7 +199,17 @@ trait StagedWasmEvaluator extends SAIOps { }) } - def isSymStackInUse: Rep[Boolean] = !ReuseManager.isReusing + + // TODO: maybe we don't need concern snapshot at compile time at all + trait Snapshot + + // Create a snapshot of the symbolic execution, we should ensure that current symstack is in use + // We don't need to store the control information, since the control is totally decided by concrete states + def makeSnapshot(): Rep[Snapshot] = { + "snapshot-make".reflectCtrlWith[Snapshot]() + } + + def isSymStateInUse: Rep[Boolean] = !ReuseManager.isReusing def eval(insts: List[Instr], kont: CleanCT => Rep[Cont[Unit]], @@ -208,7 +218,7 @@ trait StagedWasmEvaluator extends SAIOps { (oldCT: ContextTransition): Rep[Unit] = { if (insts.isEmpty) { val (oldCtx, history, ct) = oldCT.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } return kont(ct)(mkont) @@ -331,7 +341,7 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the block, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - if (isSymStackInUse) { + if (isSymStateInUse) { Stack.shiftS(offset, funcTy.out.size) } val ct1 = ct.shift(offset, funcTy.out.size) @@ -339,7 +349,7 @@ trait StagedWasmEvaluator extends SAIOps { }) // TODO: extract this into a function val (oldCtx, history, ct1) = ct.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } eval(inner, restK _, mkont, restK _ :: trail)(ct1) @@ -351,7 +361,7 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - if (isSymStackInUse) { + if (isSymStateInUse) { Stack.shiftS(offset, funcTy.out.size) } val ct1 = ct.shift(offset, funcTy.out.size) @@ -362,14 +372,14 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Entered the loop, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - enterSize Stack.shiftC(offset, funcTy.inps.size) - if (isSymStackInUse) { + if (isSymStateInUse) { Stack.shiftS(offset, funcTy.inps.size) } val ct1 = ct.shift(offset, funcTy.inps.size) eval(inner, restK _, mk, loop _ :: trail)(ct1) }) val (oldCtx, history, ct1) = ct.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } loop(ct1)(mkont) @@ -382,18 +392,19 @@ trait StagedWasmEvaluator extends SAIOps { info(s"Exiting the if, stackSize =", Stack.size) val offset = ct.endCtx.stackTypes.size - exitSize Stack.shiftC(offset, funcTy.out.size) - if (isSymStackInUse) { + if (isSymStateInUse) { Stack.shiftS(offset, funcTy.out.size) } val ct1 = ct.shift(offset, funcTy.out.size) eval(rest, kont, mk, trail)(ct1) }) val (oldCtx, history, ct2) = ct1.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { // when we are not reusing evalSym(history)(oldCtx) + val snapshot = makeSnapshot() val symCond = Stack.popS(condTy) - ExploreTree.fillWithIfElse(symCond.s) + ExploreTree.fillWithIfElse(symCond.s, snapshot) } if (cond.toInt != 0) { ExploreTree.moveCursor(true) @@ -406,7 +417,7 @@ trait StagedWasmEvaluator extends SAIOps { case Br(label) => info(s"Jump to $label") val (oldCtx, history, ct1) = ct.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } trail(label)(ct1)(mkont) @@ -415,10 +426,11 @@ trait StagedWasmEvaluator extends SAIOps { val cond = Stack.popC(ty) val (oldCtx, history, ct2) = ct1.clearHistory info(s"The br_if(${label})'s condition is ", cond.toInt) - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) val symCond = Stack.popS(ty) - ExploreTree.fillWithIfElse(symCond.s) + val snapshot = makeSnapshot() + ExploreTree.fillWithIfElse(symCond.s, snapshot) } if (cond.toInt != 0) { info(s"Jump to $label") @@ -434,17 +446,18 @@ trait StagedWasmEvaluator extends SAIOps { val (ty, ct1) = ct.pop() val label = Stack.popC(ty) val (oldCtx, history, ct2) = ct1.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } def aux(choices: List[Int], idx: Int): Rep[Unit] = { if (choices.isEmpty) trail(default)(ct2)(mkont) else { val cond = (label - toStagedNum(I32V(idx))).isZero() - if (isSymStackInUse) { + if (isSymStateInUse) { val labelSym = Stack.peekS(ty) val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() - ExploreTree.fillWithIfElse(condSym.s) + val snapshot = makeSnapshot() + ExploreTree.fillWithIfElse(condSym.s, snapshot) } if (cond.toInt != 0) { ExploreTree.moveCursor(true) @@ -457,14 +470,14 @@ trait StagedWasmEvaluator extends SAIOps { } } aux(labels, 0) - if (isSymStackInUse) { + if (isSymStateInUse) { Stack.popS(ty) } () case Return => // return instruction is also stack-polymorphic val (oldCtx, history, ct2) = ct.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } trail.last(ct2)(mkont) @@ -478,7 +491,7 @@ trait StagedWasmEvaluator extends SAIOps { def replayAndClearHistory(ct: ContextTransition): ContextTransition = { val (oldCtx, history, ct1) = ct.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } ct1 @@ -607,7 +620,7 @@ trait StagedWasmEvaluator extends SAIOps { isTail: Boolean) (implicit ct: ContextTransition): Rep[Unit] = { val (oldCtx, history, ct1) = ct.clearHistory - if (isSymStackInUse) { + if (isSymStateInUse) { evalSym(history)(oldCtx) } module.funcs(funcIndex) match { @@ -634,15 +647,15 @@ trait StagedWasmEvaluator extends SAIOps { } // Predef.println(s"[DEBUG] locals size: ${locals.size}") val ct2 = ct1.take(ty.inps.size) - val argsC = Stack.takeC(ty.inps) - val argsS = Stack.takeS(ty.inps) val exitSize = ty.out.size + ct2.endCtx.stackTypes.size if (isTail) { // when tail call, return to the caller's return continuation + val argsC = Stack.takeC(ty.inps) Frames.popFrameC(ct2.endCtx.frameTypes.size) Frames.pushFrameC(locals) Frames.putAllC(argsC) - if (isSymStackInUse) { + if (isSymStateInUse) { + val argsS = Stack.takeS(ty.inps) Frames.popFrameS(ct2.endCtx.frameTypes.size) Frames.pushFrameS(locals) Frames.putAllS(argsS) @@ -662,9 +675,11 @@ trait StagedWasmEvaluator extends SAIOps { val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => { restK(mkont) }, dummy) + val argsC = Stack.takeC(ty.inps) Frames.pushFrameC(locals) Frames.putAllC(argsC) - if (isSymStackInUse) { + if (isSymStateInUse) { + val argsS = Stack.takeS(ty.inps) Frames.pushFrameS(locals) Frames.putAllS(argsS) } @@ -1059,8 +1074,8 @@ trait StagedWasmEvaluator extends SAIOps { // Exploration tree, object ExploreTree { - def fillWithIfElse(s: Rep[SymVal]): Rep[Unit] = { - "tree-fill-if-else".reflectCtrlWith[Unit](s) + def fillWithIfElse(sym: Rep[SymVal], snapshot: Rep[Snapshot]): Rep[Unit] = { + "tree-fill-if-else".reflectCtrlWith[Unit](sym, snapshot) } def fillWithFinished(): Rep[Unit] = { @@ -1456,7 +1471,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { else if (m.toString.endsWith("I32V")) "I32V" else if (m.toString.endsWith("I64V")) "I64V" else if (m.toString.endsWith("SymVal")) "SymVal" - + else if (m.toString.endsWith("Snapshot")) "Snapshot_t" else super.remap(m) } @@ -1516,6 +1531,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.pop()") case Node(_, "sym-stack-pop", _, _) => emit("SymStack.pop()") + case Node(_, "snapshot-make", _, _) => + emit("Snapshot_t()") case Node(_, "frame-pop", List(i), _) => emit("Frames.popFrame("); shallow(i); emit(")") case Node(_, "sym-frame-pop", List(i), _) => @@ -1607,8 +1624,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("SymEnv.read("); shallow(sym); emit(")") case Node(_, "assert-true", List(cond), _) => emit("GENSYM_ASSERT("); shallow(cond); emit(")") - case Node(_, "tree-fill-if-else", List(s), _) => - emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") + case Node(_, "tree-fill-if-else", List(sym, snapshot), _) => + emit("ExploreTree.fillIfElseNode("); shallow(sym); emit(", "); shallow(snapshot); emit(")") case Node(_, "tree-fill-finished", List(), _) => emit("ExploreTree.fillFinishedNode()") case Node(_, "tree-move-cursor", List(b), _) => From 0c70300a89280658e72a3729c2b221bf1c7f37b5 Mon Sep 17 00:00:00 2001 From: butterunderflow Date: Tue, 26 Aug 2025 17:25:39 -0400 Subject: [PATCH 12/12] tweak --- headers/wasm/symbolic_rt.hpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index f5dee431..94351f07 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -24,6 +24,12 @@ class Symbolic { static int max_id = 0; +#ifdef NO_REUSE +static bool REUSE_MODE = false; +#else +static bool REUSE_MODE = true; +#endif + class Symbol : public Symbolic { public: // TODO: add type information to determine the size of bitvector @@ -498,10 +504,8 @@ class Reuse_t { public: Reuse_t() : reuse_flag(false) {} bool is_reusing() { -#ifdef NO_REUSE - return false; -#endif - return reuse_flag; + // we are in reuse mode and the flag is set + return REUSE_MODE && reuse_flag; } void turn_on_reusing() { reuse_flag = true; } @@ -568,6 +572,11 @@ class ExploreTree_t { Reuse.turn_off_reusing(); SymStack.reuse(if_else_node->snapshot); } + } else if (dynamic_cast(cursor->node.get())) { + // if we are moving to a branch node, we must have reused the symbolic + // states + assert((!REUSE_MODE || Reuse.is_reusing()) && + "Moving to a branch node without reusing symbolic states"); } return std::monostate(); }