diff --git a/benchmarks/wasm/staged/return_poly.wat b/benchmarks/wasm/staged/return_poly.wat new file mode 100644 index 00000000..1bab5ef0 --- /dev/null +++ b/benchmarks/wasm/staged/return_poly.wat @@ -0,0 +1,19 @@ +(module + (type (;0;) (func)) + (type (;1;) (func (result i32))) + ;; TODO: It seems that our parser or preprocessor has some problems; the result type of the last line doesn't take effect + (func (result i32) + block + i32.const 21 + i32.const 35 + i32.const 42 + return + end + i32.const 100 + ) + (func (type 0) + call 0 + ;; unreachable + ) + (export "$real_main" (func 1)) +) diff --git a/headers/wasm/concolic_driver.hpp b/headers/wasm/concolic_driver.hpp index 8e8ca815..427a0de8 100644 --- a/headers/wasm/concolic_driver.hpp +++ b/headers/wasm/concolic_driver.hpp @@ -4,6 +4,7 @@ #include "concrete_rt.hpp" #include "smt_solver.hpp" #include "symbolic_rt.hpp" +#include "utils.hpp" #include #include #include @@ -43,33 +44,33 @@ inline void ConcolicDriver::run() { auto unexplored = ExploreTree.pick_unexplored(); if (!unexplored) { - std::cout << "No unexplored nodes found, exiting..." << std::endl; + GENSYM_INFO("No unexplored nodes found, exiting..."); return; } auto cond = unexplored->collect_path_conds(); auto result = solver.solve(cond); if (!result.has_value()) { - // TODO: current implementation is buggy, there could be other reachable - // unexplored paths - std::cout << "Found an unreachable path, marking it as unreachable..." - << std::endl; + GENSYM_INFO("Found an unreachable path, marking it as unreachable..."); unexplored->fillUnreachableNode(); continue; } auto new_env = result.value(); SymEnv.update(std::move(new_env)); try { + GENSYM_INFO("Now execute the program with symbolic environment: "); + GENSYM_INFO(SymEnv.to_string()); entrypoint(); - std::cout << "Execution finished successfully with symbolic environment:" - << std::endl; - std::cout << SymEnv.to_string() << std::endl; + GENSYM_INFO("Execution finished successfully with symbolic environment:"); + GENSYM_INFO(SymEnv.to_string()); } catch (...) { ExploreTree.fillFailedNode(); - std::cout << "Caught runtime error with symbolic environment:" - << std::endl; - std::cout << SymEnv.to_string() << std::endl; + GENSYM_INFO("Caught runtime error with symbolic environment:"); + GENSYM_INFO(SymEnv.to_string()); return; } +#if defined(RUN_ONCE) + return; +#endif } } diff --git a/headers/wasm/concrete_rt.hpp b/headers/wasm/concrete_rt.hpp index a0961453..a9abccf2 100644 --- a/headers/wasm/concrete_rt.hpp +++ b/headers/wasm/concrete_rt.hpp @@ -72,9 +72,7 @@ class Stack_t { Num pop() { #ifdef DEBUG - if (count == 0) { - throw std::runtime_error("Stack underflow"); - } + assert(count > 0 && "Stack underflow"); #endif Num num = stack_ptr[count - 1]; count--; @@ -117,7 +115,7 @@ class Stack_t { void initialize() { // todo: remove this method - reset(); + reset(); } void reset() { count = 0; } diff --git a/headers/wasm/controls.hpp b/headers/wasm/controls.hpp new file mode 100644 index 00000000..16fa5136 --- /dev/null +++ b/headers/wasm/controls.hpp @@ -0,0 +1,5 @@ +#include +#include + +using MCont_t = std::function; +using Cont_t = std::function; \ No newline at end of file diff --git a/headers/wasm/smt_solver.hpp b/headers/wasm/smt_solver.hpp index f2450905..bc8cc9f9 100644 --- a/headers/wasm/smt_solver.hpp +++ b/headers/wasm/smt_solver.hpp @@ -3,6 +3,7 @@ #include "concrete_rt.hpp" #include "symbolic_rt.hpp" +#include "utils.hpp" #include "z3++.h" #include #include @@ -35,8 +36,8 @@ class Solver { std::vector result; // Reference: // https://github.com/Z3Prover/z3/blob/master/examples/c%2B%2B/example.cpp#L59 - - std::cout << "Solved Z3 model" << std::endl << model << std::endl; + GENSYM_INFO("Solved Z3 model"); + GENSYM_INFO(model); for (unsigned i = 0; i < model.size(); ++i) { z3::func_decl var = model[i]; z3::expr value = model.get_const_interp(var); @@ -48,8 +49,7 @@ class Solver { } result[id] = Num(value.get_numeral_int64()); } else { - std::cout << "Find a variable that is not created by GenSym: " << name - << std::endl; + GENSYM_INFO("Find a variable that is not created by GenSym: " + name); } } return result; diff --git a/headers/wasm/symbolic_rt.hpp b/headers/wasm/symbolic_rt.hpp index 18629c80..94351f07 100644 --- a/headers/wasm/symbolic_rt.hpp +++ b/headers/wasm/symbolic_rt.hpp @@ -2,7 +2,9 @@ #define WASM_SYMBOLIC_RT_HPP #include "concrete_rt.hpp" +#include "controls.hpp" #include +#include #include #include #include @@ -22,6 +24,12 @@ class Symbolic { static int max_id = 0; +#ifdef NO_REUSE +static bool REUSE_MODE = false; +#else +static bool REUSE_MODE = true; +#endif + class Symbol : public Symbolic { public: // TODO: add type information to determine the size of bitvector @@ -65,6 +73,10 @@ struct SymVal { SymVal negate() const; }; +static SymVal make_symbolic(int index) { + return SymVal(std::make_shared(index)); +} + inline SymVal Concrete(Num num) { return SymVal(std::make_shared(num)); } @@ -133,6 +145,8 @@ inline SymVal SymVal::makeSymbolic() const { } } +class Snapshot_t; + class SymStack_t { public: void push(SymVal val) { @@ -142,6 +156,11 @@ class SymStack_t { SymVal pop() { // Pop a symbolic value from the stack + +#ifdef DEBUG + printf("[Debug] poping from stack, size of symbolic stack is: %zu\n", + stack.size()); +#endif auto ret = stack.back(); stack.pop_back(); return ret; @@ -149,11 +168,25 @@ class SymStack_t { SymVal peek() { return stack.back(); } + std::monostate shift(int32_t offset, int32_t size) { + auto n = stack.size(); + for (size_t i = n - size; i < n; ++i) { + stack[i - offset] = stack[i]; + } + stack.resize(n - offset); + return std::monostate(); + } + void reset() { // Reset the symbolic stack stack.clear(); } + void reuse(Snapshot_t snapshot); + + size_t size() const { return stack.size(); } + +private: std::vector stack; }; @@ -187,9 +220,46 @@ class SymFrames_t { stack.clear(); } + void reuse(Snapshot_t snapshot); + std::vector stack; }; +// A snapshot of the symbolic state and execution context (control) +class Snapshot_t { +public: + explicit Snapshot_t(); + + SymStack_t get_stack() const { return stack; } + SymFrames_t get_frames() const { return frames; } + +private: + SymStack_t stack; + SymFrames_t frames; +}; + +inline void SymStack_t::reuse(Snapshot_t snapshot) { +// Reusing the symbolic stack from the snapshot +#ifdef DEBUG + std::cout << "Reusing symbolic state from snapshot" << std::endl; + std::cout << "Old stack size = " << stack.size() << std::endl; + std::cout << "New stack size = " << snapshot.get_stack().stack.size() + << std::endl; +#endif + stack = snapshot.get_stack().stack; +} + +inline void SymFrames_t::reuse(Snapshot_t snapshot) { +// Reusing the symbolic frames from the snapshot +#ifdef DEBUG + std::cout << "Reusing symbolic state from snapshot" << std::endl; + std::cout << "Old frame size = " << stack.size() << std::endl; + std::cout << "New frame size = " << snapshot.get_frames().stack.size() + << std::endl; +#endif + stack = snapshot.get_frames().stack; +} + static SymFrames_t SymFrames; struct Node; @@ -199,7 +269,7 @@ struct NodeBox { std::unique_ptr node; NodeBox *parent; - std::monostate fillIfElseNode(SymVal cond); + std::monostate fillIfElseNode(SymVal cond, const Snapshot_t &snapshot); std::monostate fillFinishedNode(); std::monostate fillFailedNode(); std::monostate fillUnreachableNode(); @@ -251,8 +321,9 @@ struct IfElseNode : Node { SymVal cond; std::unique_ptr true_branch; std::unique_ptr false_branch; + Snapshot_t snapshot; - IfElseNode(SymVal cond, NodeBox *parent) + IfElseNode(SymVal cond, NodeBox *parent, Snapshot_t snapshot) : cond(cond), true_branch(std::make_unique(parent)), false_branch(std::make_unique(parent)) {} @@ -367,13 +438,15 @@ inline NodeBox::NodeBox(NodeBox *parent) /* TODO: avoid allocation of unexplored node */ parent(parent) {} -inline std::monostate NodeBox::fillIfElseNode(SymVal cond) { - // fill the current NodeBox with an ifelse branch node it's unexplored +inline std::monostate NodeBox::fillIfElseNode(SymVal cond, + const Snapshot_t &snapshot) { + // fill the current NodeBox with an ifelse branch node when it's unexplored if (dynamic_cast(node.get())) { - node = std::make_unique(cond, this); + node = std::make_unique(cond, this, snapshot); } - assert(dynamic_cast(node.get()) != nullptr && - "Current node is not an IfElseNode, cannot fill it!"); + assert( + dynamic_cast(node.get()) != nullptr && + "Current node is not an Unexplored nor an IfElseNode, cannot fill it!"); return std::monostate(); } @@ -427,6 +500,32 @@ inline std::vector NodeBox::collect_path_conds() { return result; } +class Reuse_t { +public: + Reuse_t() : reuse_flag(false) {} + bool is_reusing() { + // we are in reuse mode and the flag is set + return REUSE_MODE && reuse_flag; + } + + void turn_on_reusing() { reuse_flag = true; } + + void turn_off_reusing() { reuse_flag = false; } + +private: + bool reuse_flag; +}; + +static Reuse_t Reuse; + +inline Snapshot_t::Snapshot_t() : stack(SymStack), frames(SymFrames) { +#ifdef DEBUG + std::cout << "Creating snapshot of size " << stack.size() << std::endl; +#endif + assert(!Reuse.is_reusing() && + "Creating snapshot while reusing the symbolic stack"); +} + class ExploreTree_t { public: explicit ExploreTree_t() @@ -435,14 +534,19 @@ class ExploreTree_t { void reset_cursor() { // Reset the cursor to the root of the tree cursor = root.get(); + Reuse.turn_off_reusing(); + // if root cursor is a branch node, then we can reuse the snapshot inside it + if (auto ite = dynamic_cast(cursor->node.get())) { + Reuse.turn_on_reusing(); + } } std::monostate fillFinishedNode() { return cursor->fillFinishedNode(); } std::monostate fillFailedNode() { return cursor->fillFailedNode(); } - std::monostate fillIfElseNode(SymVal cond) { - return cursor->fillIfElseNode(cond); + std::monostate fillIfElseNode(SymVal cond, const Snapshot_t &snapshot) { + return cursor->fillIfElseNode(cond, snapshot); } std::monostate moveCursor(bool branch) { @@ -456,6 +560,24 @@ class ExploreTree_t { } else { cursor = if_else_node->false_branch.get(); } + + if (dynamic_cast(cursor->node.get())) { + // If we meet an unexplored node, resume the snapshot before and keep + // going + +#ifdef DEBUG + std::cout << "Resuming snapshot for unexplored node" << std::endl; +#endif + if (Reuse.is_reusing()) { + Reuse.turn_off_reusing(); + SymStack.reuse(if_else_node->snapshot); + } + } else if (dynamic_cast(cursor->node.get())) { + // if we are moving to a branch node, we must have reused the symbolic + // states + assert((!REUSE_MODE || Reuse.is_reusing()) && + "Moving to a branch node without reusing symbolic states"); + } return std::monostate(); } @@ -531,9 +653,7 @@ class SymEnv_t { return map[symbol->get_id()]; } - void update(std::vector new_env) { - map = std::move(new_env); - } + void update(std::vector new_env) { map = std::move(new_env); } std::string to_string() const { std::string result; @@ -548,7 +668,7 @@ class SymEnv_t { } private: - std::vector map; // The symbolic environment, a vector of Num + std::vector map; // The symbolic environment, a vector of Num }; static SymEnv_t SymEnv; diff --git a/headers/wasm/utils.hpp b/headers/wasm/utils.hpp index 8a86ac98..ba57a1df 100644 --- a/headers/wasm/utils.hpp +++ b/headers/wasm/utils.hpp @@ -12,4 +12,28 @@ } while (0) #endif +#ifndef NO_DBG +#define GENSYM_DBG(obj) \ + do { \ + std::cout << "LOG: " << obj << " (" << __FILE__ << ":" \ + << std::to_string(__LINE__) << ")" << std::endl; \ + } while (0) +#else +#define GENSYM_LOG(message) \ + do { \ + } while (0) +#endif + +#ifndef NO_INFO +#define GENSYM_INFO(obj) \ + do { \ + std::cout << obj << std::endl; \ + } while (0) +#else +#define GENSYM_INFO(message) \ + do { \ + } while (0) + +#endif + #endif // UTILS_HPP \ No newline at end of file diff --git a/src/main/scala/wasm/StagedConcolicMiniWasm.scala b/src/main/scala/wasm/StagedConcolicMiniWasm.scala index 833bbc9b..769a0b85 100644 --- a/src/main/scala/wasm/StagedConcolicMiniWasm.scala +++ b/src/main/scala/wasm/StagedConcolicMiniWasm.scala @@ -26,28 +26,57 @@ trait StagedWasmEvaluator extends SAIOps { trait ReturnSite trait StagedNum { + def tipe: ValueType + } + + trait StagedConcreteNum { def tipe: ValueType = this match { - case I32(_, _) => NumType(I32Type) - case I64(_, _) => NumType(I64Type) - case F32(_, _) => NumType(F32Type) - case F64(_, _) => NumType(F64Type) + case I32C(_) => NumType(I32Type) + case I64C(_) => NumType(I64Type) + case F32C(_) => NumType(F32Type) + case F64C(_) => NumType(F64Type) } def i: Rep[Num] + } + + case class I32C(i: Rep[Num]) extends StagedConcreteNum + case class I64C(i: Rep[Num]) extends StagedConcreteNum + case class F32C(i: Rep[Num]) extends StagedConcreteNum + case class F64C(i: Rep[Num]) extends StagedConcreteNum + + + trait StagedSymbolicNum { + def tipe: ValueType = this match { + case I32S(_) => NumType(I32Type) + case I64S(_) => NumType(I64Type) + case F32S(_) => NumType(F32Type) + case F64S(_) => NumType(F64Type) + } def s: Rep[SymVal] } - case class I32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - case class I64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - case class F32(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - case class F64(i: Rep[Num], s: Rep[SymVal]) extends StagedNum - def toStagedNum(num: Num): StagedNum = { + case class I32S(s: Rep[SymVal]) extends StagedSymbolicNum + case class I64S(s: Rep[SymVal]) extends StagedSymbolicNum + case class F32S(s: Rep[SymVal]) extends StagedSymbolicNum + case class F64S(s: Rep[SymVal]) extends StagedSymbolicNum + + def toStagedNum(num: Num): StagedConcreteNum = { num match { - case I32V(_) => I32(num, Concrete(num)) - case I64V(_) => I64(num, Concrete(num)) - case F32V(_) => F32(num, Concrete(num)) - case F64V(_) => F64(num, Concrete(num)) + case I32V(_) => I32C(num) + case I64V(_) => I64C(num) + case F32V(_) => F32C(num) + case F64V(_) => F64C(num) + } + } + + def toStagedSymbolicNum(num: Num): StagedSymbolicNum = { + num match { + case I32V(_) => I32S(Concrete(num)) + case I64V(_) => I64S(Concrete(num)) + case F32V(_) => F32S(Concrete(num)) + case F64V(_) => F64S(Concrete(num)) } } @@ -59,12 +88,21 @@ trait StagedWasmEvaluator extends SAIOps { case NumType(F64Type) => 8 } - def toTagger: (Rep[Num], Rep[SymVal]) => StagedNum = { + def concreteTag: (Rep[Num]) => StagedConcreteNum = { + ty match { + case NumType(I32Type) => I32C + case NumType(I64Type) => I64C + case NumType(F32Type) => F32C + case NumType(F64Type) => F64C + } + } + + def symbolicTag: (Rep[SymVal]) => StagedSymbolicNum = { ty match { - case NumType(I32Type) => I32 - case NumType(I64Type) => I64 - case NumType(F32Type) => F32 - case NumType(F64Type) => F64 + case NumType(I32Type) => I32S + case NumType(I64Type) => I64S + case NumType(F32Type) => F32S + case NumType(F64Type) => F64S } } } @@ -74,12 +112,22 @@ trait StagedWasmEvaluator extends SAIOps { frameTypes: List[ValueType] ) { def push(ty: ValueType): Context = { - Context(ty :: stackTypes, frameTypes) + this.copy(stackTypes = ty :: stackTypes) + } + + def peek: ValueType = { + stackTypes.head } def pop(): (ValueType, Context) = { val (ty :: rest) = stackTypes - (ty, Context(rest, frameTypes)) + (ty, this.copy(stackTypes = rest)) + } + + def take(n: Int): Context = { + Predef.assert(n <= stackTypes.size, s"Context.take size $n is larger than stack size ${stackTypes.size}") + val (taken, rest) = stackTypes.splitAt(n) + this.copy(stackTypes = rest) } def shift(offset: Int, size: Int): Context = { @@ -93,11 +141,50 @@ trait StagedWasmEvaluator extends SAIOps { ) } } + + } + + case class ContextTransition(startCtx: Context, history: List[Instr], endCtx: Context) { + def log(instr: Instr): ContextTransition = { + this.copy(history = instr :: history) + } + + def clearHistory: (Context, List[Instr], CleanCT) = { + (startCtx, history, CleanCT(endCtx)) + } + + def push(ty: ValueType): ContextTransition = { + this.copy(endCtx = endCtx.push(ty)) + } + + def peek: ValueType = { + endCtx.peek + } + + def pop(): (ValueType, ContextTransition) = { + val (ty, newCtx) = endCtx.pop() + (ty, this.copy(endCtx = newCtx)) + } + + def take(n: Int): ContextTransition = { + this.copy(endCtx = endCtx.take(n)) + } + + def shift(offset: Int, size: Int): ContextTransition = { + this.copy(endCtx = endCtx.shift(offset, size)) + } + } + + case class CleanCT(ctx: Context) + + // we can treat every CleanCT as a ContextTransition + implicit def toContextCT(ct: CleanCT): ContextTransition = { + ContextTransition(ct.ctx, Nil, ct.ctx) } type MCont[A] = Unit => A type Cont[A] = (MCont[A]) => A - type Trail[A] = List[Context => Rep[Cont[A]]] + type Trail[A] = List[CleanCT => Rep[Cont[A]]] // a cache storing the compiled code for each function, to reduce re-compilation val compileCache = new HashMap[Int, Rep[(MCont[Unit]) => Unit]] @@ -113,173 +200,268 @@ trait StagedWasmEvaluator extends SAIOps { } + // TODO: maybe we don't need concern snapshot at compile time at all + trait Snapshot + + // Create a snapshot of the symbolic execution, we should ensure that current symstack is in use + // We don't need to store the control information, since the control is totally decided by concrete states + def makeSnapshot(): Rep[Snapshot] = { + "snapshot-make".reflectCtrlWith[Snapshot]() + } + + def isSymStateInUse: Rep[Boolean] = !ReuseManager.isReusing + def eval(insts: List[Instr], - kont: Context => Rep[Cont[Unit]], + kont: CleanCT => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], trail: Trail[Unit]) - (implicit ctx: Context): Rep[Unit] = { - if (insts.isEmpty) return kont(ctx)(mkont) + (oldCT: ContextTransition): Rep[Unit] = { + if (insts.isEmpty) { + val (oldCtx, history, ct) = oldCT.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } + return kont(ct)(mkont) + } // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") // Predef.println(s"[DEBUG] Current context: $ctx") - val (inst, rest) = (insts.head, insts.tail) + val ct = oldCT.log(inst) inst match { case Drop => - val (_, newCtx) = Stack.pop() - eval(rest, kont, mkont, trail)(newCtx) + val (ty, ct1) = ct.pop() + Stack.popC(ty) + eval(rest, kont, mkont, trail)(ct1) case WasmConst(num) => - val newCtx = Stack.push(toStagedNum(num)) - eval(rest, kont, mkont, trail)(newCtx) + Stack.pushC(toStagedNum(num)) + val ct1 = ct.push(num.tipe(module)) + eval(rest, kont, mkont, trail)(ct1) case Symbolic(ty) => - val (id, newCtx1) = Stack.pop() - val symVal = id.makeSymbolic() - val concVal = SymEnv.read(symVal) - val tagger = ty.toTagger - val value = tagger(concVal, symVal) - val newCtx2 = Stack.push(value)(newCtx1) - eval(rest, kont, mkont, trail)(newCtx2) + val id = Stack.popC(ty) + val symVal = id.makeSymbolic(ty) + val num = SymEnv.read(symVal.s) + Stack.pushC(ty.concreteTag(num)) + val ct1 = ct.pop()._2.push(ty) + eval(rest, kont, mkont, trail)(ct1) case LocalGet(i) => - val newCtx = Stack.push(Frames.get(i)) - eval(rest, kont, mkont, trail)(newCtx) + Stack.pushC(Frames.getC(i)(ct.endCtx)) + val ct1 = ct.push(ct.endCtx.frameTypes(i)) + eval(rest, kont, mkont, trail)(ct1) case LocalSet(i) => - val (num, newCtx) = Stack.pop() - Frames.set(i, num)(newCtx) - eval(rest, kont, mkont, trail)(newCtx) + val (ty, ct1) = ct.pop() + val num = Stack.popC(ty) + Frames.setC(i, num) + eval(rest, kont, mkont, trail)(ct1) case LocalTee(i) => - val (num, newCtx) = Stack.peek - Frames.set(i, num) - eval(rest, kont, mkont, trail)(newCtx) + val ty = ct.peek + val num = Stack.peekC(ty) + Frames.setC(i, num) + eval(rest, kont, mkont, trail)(ct) case GlobalGet(i) => - val newCtx = Stack.push(Globals(i)) - eval(rest, kont, mkont, trail)(newCtx) + Stack.pushC(Globals.getC(i)) + val ct1 = ct.push(module.globals(i).ty.ty) + eval(rest, kont, mkont, trail)(ct1) case GlobalSet(i) => - val (value, newCtx) = Stack.pop() + val (ty, ct1) = ct.pop() + val num = Stack.popC(ty) module.globals(i).ty match { - case GlobalType(tipe, true) => Globals(i) = value + case GlobalType(tipe, true) => { + Globals.setC(i, num) + } case _ => throw new Exception("Cannot set immutable global") } - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct1) case Store(StoreOp(align, offset, ty, None)) => - val (value, newCtx1) = Stack.pop() - val (addr, newCtx2) = Stack.pop()(newCtx1) + val (ty1, ct1) = ct.pop() + val value = Stack.popC(ty1) + val (ty2, ct2) = ct1.pop() + val addr = Stack.popC(ty2) Memory.storeInt(addr.toInt, offset, value.toInt) - eval(rest, kont, mkont, trail)(newCtx2) - case Nop => eval(rest, kont, mkont, trail) + eval(rest, kont, mkont, trail)(ct2) + case Nop => eval(rest, kont, mkont, trail)(ct) case Load(LoadOp(align, offset, ty, None, None)) => - val (addr, newCtx1) = Stack.pop() - val value = Memory.loadInt(addr.toInt, offset) - val newCtx2 = Stack.push(value)(newCtx1) - eval(rest, kont, mkont, trail)(newCtx2) + val (ty1, ct1) = ct.pop() + val addr = Stack.popC(ty1) + val num = Memory.loadIntC(addr.toInt, offset) + Stack.pushC(num) + val ct2 = ct1.push(ty) + eval(rest, kont, mkont, trail)(ct2) case MemorySize => ??? case MemoryGrow => - val (delta, newCtx1) = Stack.pop() + val (ty, ct1) = ct.pop() + val delta = Stack.popC(ty) val ret = Memory.grow(delta.toInt) val retNum = Values.I32V(ret) - val retSym = "Concrete".reflectCtrlWith[SymVal](retNum) - val newCtx2 = Stack.push(I32(retNum, retSym))(newCtx1) - eval(rest, kont, mkont, trail)(newCtx2) + // For now, we assume that the result of memory.grow only depends on the execution path, + // we can relax this by turning it return to a symbol value and mimic the memory.grow's result as input. + Stack.pushC(I32C(retNum)) + val ct2 = ct1.push(NumType(I32Type)) + eval(rest, kont, mkont, trail)(ct2) case MemoryFill => ??? case Unreachable => unreachable() case Test(op) => - val (v, newCtx1) = Stack.pop() - val newCtx2 = Stack.push(evalTestOp(op, v))(newCtx1) - eval(rest, kont, mkont, trail)(newCtx2) + val (ty, ct1) = ct.pop() + val v = Stack.popC(ty) + Stack.pushC(evalTestOpC(op, v)) + val ct2 = ct1.push(v.tipe) + eval(rest, kont, mkont, trail)(ct2) case Unary(op) => - val (v, newCtx1) = Stack.pop() - val newCtx2 = Stack.push(evalUnaryOp(op, v))(newCtx1) - eval(rest, kont, mkont, trail)(newCtx2) + val (ty, ct1) = ct.pop() + val v = Stack.popC(ty) + val res = evalUnaryOpC(op, v) + Stack.pushC(res) + val ct2 = ct1.push(res.tipe) + eval(rest, kont, mkont, trail)(ct2) case Binary(op) => - val (v2, newCtx1) = Stack.pop() - val (v1, newCtx2) = Stack.pop()(newCtx1) - val newCtx3 = Stack.push(evalBinOp(op, v1, v2))(newCtx2) - eval(rest, kont, mkont, trail)(newCtx3) + val (ty2, ct1) = ct.pop() + val v2 = Stack.popC(ty2) + val (ty1, ct2) = ct1.pop() + val v1 = Stack.popC(ty1) + val res = evalBinOpC(op, v1, v2) + Stack.pushC(res) + val ct3 = ct2.push(res.tipe) + eval(rest, kont, mkont, trail)(ct3) case Compare(op) => - val (v2, newCtx1) = Stack.pop() - val (v1, newCtx2) = Stack.pop()(newCtx1) - val newCtx3 = Stack.push(evalRelOp(op, v1, v2))(newCtx2) - eval(rest, kont, mkont, trail)(newCtx3) + val (ty2, ct1) = ct.pop() + val v2 = Stack.popC(ty2) + val (ty1, ct2) = ct1.pop() + val v1 = Stack.popC(ty1) + val res = evalRelOpC(op, v1, v2) + Stack.pushC(res) + val ct3 = ct2.push(res.tipe) + eval(rest, kont, mkont, trail)(ct3) case WasmBlock(ty, inner) => // no need to modify the stack when entering a block // the type system guarantees that we will never take more than the input size from the stack val funcTy = ty.funcType - val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val exitSize = ct.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def restK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the block, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - exitSize - val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) - eval(rest, kont, mk, trail)(newRestCtx) + val offset = ct.endCtx.stackTypes.size - exitSize + Stack.shiftC(offset, funcTy.out.size) + if (isSymStateInUse) { + Stack.shiftS(offset, funcTy.out.size) + } + val ct1 = ct.shift(offset, funcTy.out.size) + eval(rest, kont, mk, trail)(ct1) }) - eval(inner, restK _, mkont, restK _ :: trail) + // TODO: extract this into a function + val (oldCtx, history, ct1) = ct.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } + eval(inner, restK _, mkont, restK _ :: trail)(ct1) case Loop(ty, inner) => val funcTy = ty.funcType - val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size + val exitSize = ct.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + def restK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the loop, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - exitSize - val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) - eval(rest, kont, mk, trail)(newRestCtx) + val offset = ct.endCtx.stackTypes.size - exitSize + Stack.shiftC(offset, funcTy.out.size) + if (isSymStateInUse) { + Stack.shiftS(offset, funcTy.out.size) + } + val ct1 = ct.shift(offset, funcTy.out.size) + eval(rest, kont, mk, trail)(ct1) }) - val enterSize = ctx.stackTypes.size - def loop(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + val enterSize = ct.endCtx.stackTypes.size + def loop(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the loop, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - enterSize - val newRestCtx = Stack.shift(offset, funcTy.inps.size)(restCtx) - eval(inner, restK _, mk, loop _ :: trail)(newRestCtx) + val offset = ct.endCtx.stackTypes.size - enterSize + Stack.shiftC(offset, funcTy.inps.size) + if (isSymStateInUse) { + Stack.shiftS(offset, funcTy.inps.size) + } + val ct1 = ct.shift(offset, funcTy.inps.size) + eval(inner, restK _, mk, loop _ :: trail)(ct1) }) - loop(ctx)(mkont) + val (oldCtx, history, ct1) = ct.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } + loop(ct1)(mkont) case If(ty, thn, els) => val funcTy = ty.funcType - val (cond, newCtx) = Stack.pop() - val exitSize = newCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size - // TODO: can we avoid code duplication here? - val dummy = makeDummy - def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + val (condTy, ct1) = ct.pop() + val cond = Stack.popC(condTy) + val exitSize = ct1.endCtx.stackTypes.size - funcTy.inps.size + funcTy.out.size + def restK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the if, stackSize =", Stack.size) - val offset = restCtx.stackTypes.size - exitSize - val newRestCtx = Stack.shift(offset, funcTy.out.size)(restCtx) - eval(rest, kont, mk, trail)(newRestCtx) + val offset = ct.endCtx.stackTypes.size - exitSize + Stack.shiftC(offset, funcTy.out.size) + if (isSymStateInUse) { + Stack.shiftS(offset, funcTy.out.size) + } + val ct1 = ct.shift(offset, funcTy.out.size) + eval(rest, kont, mk, trail)(ct1) }) - // TODO: put the cond.s to path condition - ExploreTree.fillWithIfElse(cond.s) + val (oldCtx, history, ct2) = ct1.clearHistory + if (isSymStateInUse) { + // when we are not reusing + evalSym(history)(oldCtx) + val snapshot = makeSnapshot() + val symCond = Stack.popS(condTy) + ExploreTree.fillWithIfElse(symCond.s, snapshot) + } if (cond.toInt != 0) { ExploreTree.moveCursor(true) - eval(thn, restK _, mkont, restK _ :: trail)(newCtx) + eval(thn, restK _, mkont, restK _ :: trail)(ct2) } else { ExploreTree.moveCursor(false) - eval(els, restK _, mkont, restK _ :: trail)(newCtx) + eval(els, restK _, mkont, restK _ :: trail)(ct2) } () case Br(label) => info(s"Jump to $label") - trail(label)(ctx)(mkont) + val (oldCtx, history, ct1) = ct.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } + trail(label)(ct1)(mkont) case BrIf(label) => - val (cond, newCtx) = Stack.pop() + val (ty, ct1) = ct.pop() + val cond = Stack.popC(ty) + val (oldCtx, history, ct2) = ct1.clearHistory info(s"The br_if(${label})'s condition is ", cond.toInt) - // TODO: put the cond.s to path condition - ExploreTree.fillWithIfElse(cond.s) + if (isSymStateInUse) { + evalSym(history)(oldCtx) + val symCond = Stack.popS(ty) + val snapshot = makeSnapshot() + ExploreTree.fillWithIfElse(symCond.s, snapshot) + } if (cond.toInt != 0) { info(s"Jump to $label") ExploreTree.moveCursor(true) - trail(label)(newCtx)(mkont) + trail(label)(ct2)(mkont) } else { info(s"Continue") ExploreTree.moveCursor(false) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct2) } () case BrTable(labels, default) => - val (label, newCtx) = Stack.pop() + val (ty, ct1) = ct.pop() + val label = Stack.popC(ty) + val (oldCtx, history, ct2) = ct1.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } def aux(choices: List[Int], idx: Int): Rep[Unit] = { - if (choices.isEmpty) trail(default)(newCtx)(mkont) + if (choices.isEmpty) trail(default)(ct2)(mkont) else { val cond = (label - toStagedNum(I32V(idx))).isZero() - ExploreTree.fillWithIfElse(cond.s) + if (isSymStateInUse) { + val labelSym = Stack.peekS(ty) + val condSym = (labelSym - toStagedSymbolicNum(I32V(idx))).isZero() + val snapshot = makeSnapshot() + ExploreTree.fillWithIfElse(condSym.s, snapshot) + } if (cond.toInt != 0) { ExploreTree.moveCursor(true) - trail(choices.head)(newCtx)(mkont) + trail(choices.head)(ct2)(mkont) } else { ExploreTree.moveCursor(false) @@ -288,12 +470,142 @@ trait StagedWasmEvaluator extends SAIOps { } } aux(labels, 0) - case Return => trail.last(ctx)(mkont) - case Call(f) => evalCall(rest, kont, mkont, trail, f, false) - case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true) + if (isSymStateInUse) { + Stack.popS(ty) + } + () + case Return => + // return instruction is also stack-polymorphic + val (oldCtx, history, ct2) = ct.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } + trail.last(ct2)(mkont) + case Call(f) => evalCall(rest, kont, mkont, trail, f, false)(ct) + case ReturnCall(f) => evalCall(rest, kont, mkont, trail, f, true)(ct) + case _ => + val todo = "todo-op".reflectCtrlWith[Unit]() + eval(rest, kont, mkont, trail)(ct) + } + } + + def replayAndClearHistory(ct: ContextTransition): ContextTransition = { + val (oldCtx, history, ct1) = ct.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } + ct1 + } + + // call the symbolic interpreter to evaluate the history that just executed by + // concrete interpreter + def evalSym(history: List[Instr]) + (ctx: Context): Rep[Unit] = { + // val func = topFun((_: Rep[Unit]) => evalS(history.reverse)) + // func(()) + evalS(history.reverse)(ctx) + } + + def evalS(insts: List[Instr]) + (ctx: Context): Rep[Unit] = { + if (insts.isEmpty) return () + + // Predef.println(s"[DEBUG] Evaluating instructions: ${insts.mkString(", ")}") + // Predef.println(s"[DEBUG] Current context: $ctx") + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => + val (ty, newCtx) = ctx.pop() + Stack.popS(ty) + evalS(rest)(newCtx) + case WasmConst(num) => + Stack.pushS(toStagedSymbolicNum(num)) + val newCtx = ctx.push(num.tipe(module)) + evalS(rest)(newCtx) + case Symbolic(ty) => + val id = Stack.popS(ty) + val symVal = id.makeSymbolic(ty) + Stack.pushS(symVal) + val newCtx = ctx.pop()._2.push(ty) + evalS(rest)(newCtx) + case LocalGet(i) => + Stack.pushS(Frames.getS(i)(ctx)) + val newCtx = ctx.push(ctx.frameTypes(i)) + evalS(rest)(newCtx) + case LocalSet(i) => + val (ty, newCtx) = ctx.pop() + val sym = Stack.popS(ty) + Frames.setS(i, sym) + evalS(rest)(newCtx) + case LocalTee(i) => + val ty = ctx.pop()._1 + val sym = Stack.peekS(ty) + Frames.setS(i, sym) + evalS(rest)(ctx) + case GlobalGet(i) => + Stack.pushS(Globals.getS(i)) + val newCtx = ctx.push(module.globals(i).ty.ty) + evalS(rest)(newCtx) + case GlobalSet(i) => + val (ty, newCtx) = ctx.pop() + val sym = Stack.popS(ty) + module.globals(i).ty match { + case GlobalType(tipe, true) => { + Globals.setS(i, sym) + } + case _ => throw new Exception("Cannot set immutable global") + } + evalS(rest)(newCtx) + case Nop => evalS(rest)(ctx) + case Store(StoreOp(align, offset, ty, None)) => ??? + case Load(LoadOp(align, offset, ty, None, None)) => ??? + case MemorySize => ??? + case MemoryGrow => ??? + case MemoryFill => ??? + case Unreachable => unreachable() + case Test(op) => + val (ty, newCtx1) = ctx.pop() + val s = Stack.popS(ty) + Stack.pushS(evalTestOpS(op, s)) + val newCtx2 = newCtx1.push(s.tipe) + evalS(rest)(newCtx2) + case Unary(op) => + val (ty, newCtx1) = ctx.pop() + val s = Stack.popS(ty) + val res = evalUnaryOpS(op, s) + Stack.pushS(res) + val newCtx2 = newCtx1.push(res.tipe) + evalS(rest)(newCtx2) + case Binary(op) => + val (ty2, newCtx1) = ctx.pop() + val s2 = Stack.popS(ty2) + val (ty1, newCtx2) = newCtx1.pop() + val s1 = Stack.popS(ty1) + val res = evalBinOpS(op, s1, s2) + Stack.pushS(res) + val newCtx3 = newCtx2.push(res.tipe) + evalS(rest)(newCtx3) + case Compare(op) => + val (ty2, newCtx1) = ctx.pop() + val s2 = Stack.popS(ty2) + val (ty1, newCtx2) = newCtx1.pop() + val s1 = Stack.popS(ty1) + val res = evalRelOpS(op, s1, s2) + Stack.pushS(res) + val newCtx3 = newCtx2.push(res.tipe) + evalS(rest)(newCtx3) + case WasmBlock(ty, inner) => () + case Loop(ty, inner) => () + case If(ty, thn, els) => () + case Br(label) => () + case BrIf(label) => () + case BrTable(labels, default) => () + case Return => () + case Call(f) => () + case ReturnCall(f) => () case _ => val todo = "todo-op".reflectCtrlWith[Unit]() - eval(rest, kont, mkont, trail) + evalS(rest)(ctx) } } @@ -301,12 +613,16 @@ trait StagedWasmEvaluator extends SAIOps { def evalCall(rest: List[Instr], - kont: Context => Rep[Cont[Unit]], + kont: CleanCT => Rep[Cont[Unit]], mkont: Rep[MCont[Unit]], trail: Trail[Unit], funcIndex: Int, isTail: Boolean) - (implicit ctx: Context): Rep[Unit] = { + (implicit ct: ContextTransition): Rep[Unit] = { + val (oldCtx, history, ct1) = ct.clearHistory + if (isSymStateInUse) { + evalSym(history)(oldCtx) + } module.funcs(funcIndex) match { case FuncDef(_, FuncBodyDef(ty, _, bodyLocals, body)) => val locals = bodyLocals ++ ty.inps @@ -316,63 +632,102 @@ trait StagedWasmEvaluator extends SAIOps { } else { val callee = topFun((mk: Rep[MCont[Unit]]) => { info(s"Entered the function at $funcIndex, stackSize =", Stack.size) - // we can do some check here to ensure the function returns correct size of stack - eval(body, (_: Context) => forwardKont, mk, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) + // the return instruction is also stack polymorphic + def retK(ct: CleanCT): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { + info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) + val offset = ct.ctx.stackTypes.size - ty.out.size + Stack.shiftC(offset, ty.out.size) + Stack.shiftS(offset, ty.out.size) + mk(()) + }) + eval(body, retK _, mk, retK _::Nil)(CleanCT(Context(Nil, locals))) }) compileCache(funcIndex) = callee callee } // Predef.println(s"[DEBUG] locals size: ${locals.size}") - val (args, newCtx) = Stack.take(ty.inps.size) + val ct2 = ct1.take(ty.inps.size) + val exitSize = ty.out.size + ct2.endCtx.stackTypes.size if (isTail) { // when tail call, return to the caller's return continuation - Frames.popFrame(ctx.frameTypes.size) - Frames.pushFrame(locals) - Frames.putAll(args) + val argsC = Stack.takeC(ty.inps) + Frames.popFrameC(ct2.endCtx.frameTypes.size) + Frames.pushFrameC(locals) + Frames.putAllC(argsC) + if (isSymStateInUse) { + val argsS = Stack.takeS(ty.inps) + Frames.popFrameS(ct2.endCtx.frameTypes.size) + Frames.pushFrameS(locals) + Frames.putAllS(argsS) + } callee(mkont) } else { // We make a new trail by `restK`, since function creates a new block to escape // (more or less like `return`) val restK: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => { info(s"Exiting the function at $funcIndex, stackSize =", Stack.size) - Frames.popFrame(locals.size) - eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size))) + Frames.popFrameC(locals.size) + Frames.popFrameS(locals.size) + val newCtx = ct2.endCtx.copy(stackTypes = ty.out.reverse ++ ct2.endCtx.stackTypes) + eval(rest, kont, mk, trail)(CleanCT(newCtx)) }) val dummy = makeDummy val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => { restK(mkont) }, dummy) - Frames.pushFrame(locals) - Frames.putAll(args) + val argsC = Stack.takeC(ty.inps) + Frames.pushFrameC(locals) + Frames.putAllC(argsC) + if (isSymStateInUse) { + val argsS = Stack.takeS(ty.inps) + Frames.pushFrameS(locals) + Frames.putAllS(argsS) + } callee(newMKont) } case Import("console", "log", _) | Import("spectest", "print_i32", _) => //println(s"[DEBUG] current stack: $stack") - val (v, newCtx) = Stack.pop() + val (ty, ct2) = ct1.pop() + val v = Stack.popC(ty) + Stack.popS(ty) println(v.toInt) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct2) case Import("console", "assert", _) => - val (v, newCtx) = Stack.pop() + val (ty, ct2) = ct1.pop() + val v = Stack.popC(ty) + // TODO: We should also add s into exploration tree + val s = Stack.popS(ty) runtimeAssert(v.toInt != 0) - eval(rest, kont, mkont, trail)(newCtx) + eval(rest, kont, mkont, trail)(ct2) case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") case _ => throw new Exception(s"Definition at $funcIndex is not callable") } } - def evalTestOp(op: TestOp, value: StagedNum): StagedNum = op match { + def evalTestOpC(op: TestOp, value: StagedConcreteNum): StagedConcreteNum = op match { case Eqz(_) => value.isZero } - def evalUnaryOp(op: UnaryOp, value: StagedNum): StagedNum = op match { + def evalTestOpS(op: TestOp, value: StagedSymbolicNum): StagedSymbolicNum = op match { + case Eqz(_) => value.isZero + } + + def evalUnaryOpC(op: UnaryOp, value: StagedConcreteNum): StagedConcreteNum = op match { case Clz(_) => value.clz() case Ctz(_) => value.ctz() case Popcnt(_) => value.popcnt() case _ => ??? } - def evalBinOp(op: BinOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { + def evalUnaryOpS(op: UnaryOp, value: StagedSymbolicNum): StagedSymbolicNum = op match { + case Clz(_) => value.clz() + case Ctz(_) => value.ctz() + case Popcnt(_) => value.popcnt() + case _ => ??? + } + + def evalBinOpC(op: BinOp, v1: StagedConcreteNum, v2: StagedConcreteNum): StagedConcreteNum = op match { case Add(_) => v1 + v2 case Mul(_) => v1 * v2 case Sub(_) => v1 - v2 @@ -386,7 +741,35 @@ trait StagedWasmEvaluator extends SAIOps { throw new Exception(s"Unknown binary operation $op") } - def evalRelOp(op: RelOp, v1: StagedNum, v2: StagedNum): StagedNum = op match { + def evalBinOpS(op: BinOp, v1: StagedSymbolicNum, v2: StagedSymbolicNum): StagedSymbolicNum = op match { + case Add(_) => v1 + v2 + case Mul(_) => v1 * v2 + case Sub(_) => v1 - v2 + case Shl(_) => v1 << v2 + // case ShrS(_) => v1 >> v2 // TODO: signed shift right + case ShrU(_) => v1 >> v2 + case And(_) => v1 & v2 + case DivS(_) => v1 / v2 + case DivU(_) => v1 / v2 + case _ => + throw new Exception(s"Unknown binary operation $op") + } + + def evalRelOpC(op: RelOp, v1: StagedConcreteNum, v2: StagedConcreteNum): StagedConcreteNum = op match { + case Eq(_) => v1 numEq v2 + case Ne(_) => v1 numNe v2 + case LtS(_) => v1 < v2 + case LtU(_) => v1 ltu v2 + case GtS(_) => v1 > v2 + case GtU(_) => v1 gtu v2 + case LeS(_) => v1 <= v2 + case LeU(_) => v1 leu v2 + case GeS(_) => v1 >= v2 + case GeU(_) => v1 geu v2 + case _ => ??? + } + + def evalRelOpS(op: RelOp, v1: StagedSymbolicNum, v2: StagedSymbolicNum): StagedSymbolicNum = op match { case Eq(_) => v1 numEq v2 case Ne(_) => v1 numNe v2 case LtS(_) => v1 < v2 @@ -426,12 +809,14 @@ trait StagedWasmEvaluator extends SAIOps { } val (instrs, locals) = (funBody.body, funBody.locals) resetStacks() - Frames.pushFrame(locals) - eval(instrs, (_: Context) => forwardKont, mkont, ((_: Context) => forwardKont)::Nil)(Context(Nil, locals)) - Frames.popFrame(locals.size) + Frames.pushFrameC(locals) + Frames.pushFrameS(locals) + eval(instrs, _ => forwardKont, mkont, ((_: CleanCT) => forwardKont)::Nil)(CleanCT(Context(Nil, locals))) + Frames.popFrameC(locals.size) + Frames.popFrameS(locals.size) } - def evalTop(main: Option[String], printRes: Boolean, dumpTree: Option[String]): Rep[Unit] = { + def evalTop(main: Option[String], printRes: Boolean): Rep[Unit] = { val haltK: Rep[Unit] => Rep[Unit] = (_) => { info("Exiting the program...") if (printRes) { @@ -450,66 +835,78 @@ trait StagedWasmEvaluator extends SAIOps { // stack operations object Stack { - def shift(offset: Int, size: Int)(ctx: Context): Context = { + def shiftC(offset: Int, size: Int) = { if (offset > 0) { "stack-shift".reflectCtrlWith[Unit](offset, size) } - ctx.shift(offset, size) + } + + def shiftS(offset: Int, size: Int) = { + if (offset > 0) { + "sym-stack-shift".reflectCtrlWith[Unit](offset, size) + } } def initialize(): Rep[Unit] = { "stack-init".reflectCtrlWith[Unit]() } - def pop()(implicit ctx: Context): (StagedNum, Context) = { - val (ty, newContext) = ctx.pop() - val num = ty match { - case NumType(I32Type) => I32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - case NumType(I64Type) => I64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F32("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F64("stack-pop".reflectCtrlWith[Num](), "sym-stack-pop".reflectCtrlWith[SymVal]()) - } - (num, newContext) + def popC(ty: ValueType): StagedConcreteNum = ty match { + case NumType(I32Type) => I32C("stack-pop".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64C("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32C("stack-pop".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64C("stack-pop".reflectCtrlWith[Num]()) } - def peek(implicit ctx: Context): (StagedNum, Context) = { - val ty = ctx.stackTypes.head - val num = ty match { - case NumType(I32Type) => I32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - case NumType(I64Type) => I64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F32("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - case NumType(F32Type) => F64("stack-peek".reflectCtrlWith[Num](), "sym-stack-peek".reflectCtrlWith[SymVal]()) - } - (num, ctx) + def popS(ty: ValueType): StagedSymbolicNum = ty match { + case NumType(I32Type) => I32S("sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(I64Type) => I64S("sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F32S("sym-stack-pop".reflectCtrlWith[SymVal]()) + case NumType(F64Type) => F64S("sym-stack-pop".reflectCtrlWith[SymVal]()) } - def push(num: StagedNum)(implicit ctx: Context): Context = { - num match { - case I32(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - case I64(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - case F32(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - case F64(v, s) => "stack-push".reflectCtrlWith[Unit](v); "sym-stack-push".reflectCtrlWith[Unit](s) - } - ctx.push(num.tipe) + def peekC(ty: ValueType): StagedConcreteNum = ty match { + case NumType(I32Type) => I32C("stack-peek".reflectCtrlWith[Num]()) + case NumType(I64Type) => I64C("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F32C("stack-peek".reflectCtrlWith[Num]()) + case NumType(F32Type) => F64C("stack-peek".reflectCtrlWith[Num]()) } - def take(n: Int)(implicit ctx: Context): (List[StagedNum], Context) = n match { - case 0 => (Nil, ctx) - case n => - val (v, newCtx1) = pop() - val (rest, newCtx2) = take(n - 1) - (v::rest, newCtx2) + def peekS(ty: ValueType): StagedSymbolicNum = ty match { + case NumType(I32Type) => I32S("sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(I64Type) => I64S("sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(F32Type) => F32S("sym-stack-peek".reflectCtrlWith[SymVal]()) + case NumType(F64Type) => F64S("sym-stack-peek".reflectCtrlWith[SymVal]()) } - def drop(n: Int)(implicit ctx: Context): Context = { - take(n)._2 + def pushC(num: StagedConcreteNum) = num match { + case I32C(v) => "stack-push".reflectCtrlWith[Unit](v) + case I64C(v) => "stack-push".reflectCtrlWith[Unit](v) + case F32C(v) => "stack-push".reflectCtrlWith[Unit](v) + case F64C(v) => "stack-push".reflectCtrlWith[Unit](v) } - def shift(offset: Rep[Int], size: Rep[Int]): Rep[Unit] = { - if (offset > 0) { - "stack-shift".reflectCtrlWith[Unit](offset, size) - "sym-stack-shift".reflectCtrlWith[Unit](offset, size) - } + def pushS(num: StagedSymbolicNum) = num match { + case I32S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + case I64S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + case F32S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + case F64S(s) => "sym-stack-push".reflectCtrlWith[Unit](s) + } + + def takeC(types: List[ValueType]): List[StagedConcreteNum] = types match { + case Nil => Nil + case t :: ts => + val v = popC(t) + val rest = takeC(ts) + v :: rest + } + + def takeS(types: List[ValueType]): List[StagedSymbolicNum] = types match { + case Nil => Nil + case t :: ts => + val v = popS(t) + val rest = takeS(ts) + v :: rest } def print(): Rep[Unit] = { @@ -522,41 +919,72 @@ trait StagedWasmEvaluator extends SAIOps { } object Frames { - def get(i: Int)(implicit ctx: Context): StagedNum = { + def getC(i: Int)(implicit ctx: Context): StagedConcreteNum = { // val offset = ctx.frameTypes.take(i).map(_.size).sum ctx.frameTypes(i) match { - case NumType(I32Type) => I32("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) - case NumType(I64Type) => I64("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) - case NumType(F32Type) => F32("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) - case NumType(F64Type) => F64("frame-get".reflectCtrlWith[Num](i), "sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(I32Type) => I32C("frame-get".reflectCtrlWith[Num](i)) + case NumType(I64Type) => I64C("frame-get".reflectCtrlWith[Num](i)) + case NumType(F32Type) => F32C("frame-get".reflectCtrlWith[Num](i)) + case NumType(F64Type) => F64C("frame-get".reflectCtrlWith[Num](i)) } } - def set(i: Int, v: StagedNum)(implicit ctx: Context): Rep[Unit] = { - // val offset = ctx.frameTypes.take(i).map(_.size).sum + def getS(i: Int)(implicit ctx: Context): StagedSymbolicNum = { + ctx.frameTypes(i) match { + case NumType(I32Type) => I32S("sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(I64Type) => I64S("sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(F32Type) => F32S("sym-frame-get".reflectCtrlWith[SymVal](i)) + case NumType(F64Type) => F64S("sym-frame-get".reflectCtrlWith[SymVal](i)) + } + } + + def setC(i: Int, v: StagedConcreteNum): Rep[Unit] = { v match { - case I32(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) - case I64(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) - case F32(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) - case F64(v, s) => "frame-set".reflectCtrlWith[Unit](i, v); "sym-frame-set".reflectCtrlWith[Unit](i, s) + case I32C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case I64C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F32C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + case F64C(v) => "frame-set".reflectCtrlWith[Unit](i, v) + } + } + + def setS(i: Int, s: StagedSymbolicNum): Rep[Unit] = { + s match { + case I32S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) + case I64S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) + case F32S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) + case F64S(s) => "sym-frame-set".reflectCtrlWith[Unit](i, s) } } - def pushFrame(locals: List[ValueType]): Rep[Unit] = { + def pushFrameC(locals: List[ValueType]): Rep[Unit] = { // Predef.println(s"[DEBUG] push frame: $locals") val size = locals.size "frame-push".reflectCtrlWith[Unit](size) + } + + def pushFrameS(locals: List[ValueType]): Rep[Unit] = { + // Predef.println(s"[DEBUG] push frame: $locals") + val size = locals.size "sym-frame-push".reflectCtrlWith[Unit](size) } - def popFrame(size: Int): Rep[Unit] = { + def popFrameC(size: Int): Rep[Unit] = { "frame-pop".reflectCtrlWith[Unit](size) + } + + def popFrameS(size: Int): Rep[Unit] = { "sym-frame-pop".reflectCtrlWith[Unit](size) } - def putAll(args: List[StagedNum])(implicit ctx: Context): Rep[Unit] = { + def putAllC(args: List[StagedConcreteNum]): Rep[Unit] = { for ((arg, i) <- args.view.reverse.zipWithIndex) { - Frames.set(i, arg) + Frames.setC(i, arg) + } + } + + def putAllS(args: List[StagedSymbolicNum]): Rep[Unit] = { + for ((arg, i) <- args.view.reverse.zipWithIndex) { + Frames.setS(i, arg) } } } @@ -567,8 +995,12 @@ trait StagedWasmEvaluator extends SAIOps { // todo: store symbolic value to memory via extract/concat operation } - def loadInt(base: Rep[Int], offset: Int): StagedNum = { - I32("I32V".reflectCtrlWith[Num]("memory-load-int".reflectCtrlWith[Int](base, offset)), "sym-load-int-todo".reflectCtrlWith[SymVal](base, offset)) + def loadIntC(base: Rep[Int], offset: Int): StagedConcreteNum = { + I32C("I32V".reflectCtrlWith[Num]("memory-load-int".reflectCtrlWith[Int](base, offset))) + } + + def loadIntS(base: Rep[Int], offset: Int): StagedSymbolicNum = { + I32S("sym-load-int-todo".reflectCtrlWith[SymVal](base, offset)) } // Returns the previous memory size on success, or -1 if the memory cannot be grown. @@ -603,29 +1035,47 @@ trait StagedWasmEvaluator extends SAIOps { // global read/write object Globals { - def apply(i: Int): StagedNum = { + def getC(i: Int): StagedConcreteNum = { module.globals(i).ty match { - case GlobalType(NumType(I32Type), _) => I32("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) - case GlobalType(NumType(I64Type), _) => I64("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) - case GlobalType(NumType(F32Type), _) => F32("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) - case GlobalType(NumType(F64Type), _) => F64("global-get".reflectCtrlWith[Num](i), "sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(I32Type), _) => I32C("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(I64Type), _) => I64C("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F32Type), _) => F32C("global-get".reflectCtrlWith[Num](i)) + case GlobalType(NumType(F64Type), _) => F64C("global-get".reflectCtrlWith[Num](i)) } } - def update(i: Int, v: StagedNum): Rep[Unit] = { + def getS(i: Int): StagedSymbolicNum = { module.globals(i).ty match { - case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) - case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) - case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) - case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i);"sym-global-set".reflectCtrlWith[Unit](i, v.s) + case GlobalType(NumType(I32Type), _) => I32S("sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(I64Type), _) => I64S("sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(F32Type), _) => F32S("sym-global-get".reflectCtrlWith[SymVal](i)) + case GlobalType(NumType(F64Type), _) => F64S("sym-global-get".reflectCtrlWith[SymVal](i)) + } + } + + def setC(i: Int, v: StagedConcreteNum): Rep[Unit] = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + case GlobalType(NumType(I64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + case GlobalType(NumType(F32Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + case GlobalType(NumType(F64Type), _) => "global-set".reflectCtrlWith[Unit](i, v.i) + } + } + + def setS(i: Int, s: StagedSymbolicNum): Rep[Unit] = { + module.globals(i).ty match { + case GlobalType(NumType(I32Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) + case GlobalType(NumType(I64Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) + case GlobalType(NumType(F32Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) + case GlobalType(NumType(F64Type), _) => "sym-global-set".reflectCtrlWith[Unit](i, s.s) } } } // Exploration tree, object ExploreTree { - def fillWithIfElse(s: Rep[SymVal]): Rep[Unit] = { - "tree-fill-if-else".reflectCtrlWith[Unit](s) + def fillWithIfElse(sym: Rep[SymVal], snapshot: Rep[Snapshot]): Rep[Unit] = { + "tree-fill-if-else".reflectCtrlWith[Unit](sym, snapshot) } def fillWithFinished(): Rep[Unit] = { @@ -633,6 +1083,7 @@ trait StagedWasmEvaluator extends SAIOps { } def moveCursor(branch: Boolean): Rep[Unit] = { + // when moving cursor from to an unexplored node, we need to change the reuse state "tree-move-cursor".reflectCtrlWith[Unit](branch) } @@ -651,165 +1102,337 @@ trait StagedWasmEvaluator extends SAIOps { } } + object ReuseManager { + def isReusing: Rep[Boolean] = { + "reuse-is-reusing".reflectCtrlWith[Boolean]() + } + + def turnOnReuse(): Rep[Unit] = { + "reuse-turn-on".reflectCtrlWith[Unit]() + } + + def turnOffReuse(): Rep[Unit] = { + "reuse-turn-off".reflectCtrlWith[Unit]() + } + } + // runtime Num type - implicit class StagedNumOps(num: StagedNum) { + implicit class StagedConcreteNumOps(num: StagedConcreteNum) { + + def makeSymbolic(ty: ValueType): StagedSymbolicNum = num match { + case I32C(x) => I32S("make-symbolic-concrete".reflectCtrlWith[SymVal](num.toInt)) + } def toInt: Rep[Int] = "num-to-int".reflectCtrlWith[Int](num.i) - def isZero(): StagedNum = num match { - case I32(x_c, x_s) => I32(Values.I32V("is-zero".reflectCtrlWith[Int](num.toInt)), "sym-is-zero".reflectCtrlWith[SymVal](x_s)) + def isZero(): StagedConcreteNum = num match { + case I32C(x_c) => I32C(Values.I32V("is-zero".reflectCtrlWith[Int](num.toInt))) + } + + def clz(): StagedConcreteNum = num match { + case I32C(x) => I32C("clz".reflectCtrlWith[Num](x)) + case I64C(x) => I64C("clz".reflectCtrlWith[Num](x)) + } + + def ctz(): StagedConcreteNum = num match { + case I32C(x) => I32C("ctz".reflectCtrlWith[Num](x)) + case I64C(x) => I64C("ctz".reflectCtrlWith[Num](x)) + } + + def popcnt(): StagedConcreteNum = num match { + case I32C(x) => I32C("popcnt".reflectCtrlWith[Num](x)) + case I64C(x) => I64C("popcnt".reflectCtrlWith[Num](x)) + } + + def +(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-add".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-add".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-add".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-add".reflectCtrlWith[Num](x, y)) + } + } + + def -(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-sub".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-sub".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-sub".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-sub".reflectCtrlWith[Num](x, y)) + } + } + + def *(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-mul".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-mul".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-mul".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-mul".reflectCtrlWith[Num](x, y)) + } + } + + def /(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-div".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-div".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-div".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-div".reflectCtrlWith[Num](x, y)) + } + } + + def <<(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-shl".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-shl".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-shl".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-shl".reflectCtrlWith[Num](x, y)) + } + } + + def >>(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-shr".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-shr".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-shr".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-shr".reflectCtrlWith[Num](x, y)) + } + } + + def &(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("binary-and".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I64C("binary-and".reflectCtrlWith[Num](x, y)) + case (F32C(x), F32C(y)) => F32C("binary-and".reflectCtrlWith[Num](x, y)) + case (F64C(x), F64C(y)) => F64C("binary-and".reflectCtrlWith[Num](x, y)) + } + } + + def numEq(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-eq".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-eq".reflectCtrlWith[Num](x, y)) + } + } + + def numNe(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-ne".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-ne".reflectCtrlWith[Num](x, y)) + } + } + + def <(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-lt".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-lt".reflectCtrlWith[Num](x, y)) + } } - def clz(): StagedNum = num match { - case I32(x_c, x_s) => I32("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) - case I64(x_c, x_s) => I64("clz".reflectCtrlWith[Num](x_c), "sym-clz".reflectCtrlWith[SymVal](x_s)) + def ltu(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-ltu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-ltu".reflectCtrlWith[Num](x, y)) + } } - def ctz(): StagedNum = num match { - case I32(x_c, x_s) => I32("ctz".reflectCtrlWith[Num](x_c), "sym-ctz".reflectCtrlWith[SymVal](x_s)) - case I64(x_c, x_s) => I64("ctz".reflectCtrlWith[Num](x_c), "sym-ctz".reflectCtrlWith[SymVal](x_s)) + def >(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-gt".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-gt".reflectCtrlWith[Num](x, y)) + } } - def popcnt(): StagedNum = num match { - case I32(x_c, x_s) => I32("popcnt".reflectCtrlWith[Num](x_c), "sym-popcnt".reflectCtrlWith[SymVal](x_s)) - case I64(x_c, x_s) => I64("popcnt".reflectCtrlWith[Num](x_c), "sym-popcnt".reflectCtrlWith[SymVal](x_s)) + def gtu(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-gtu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-gtu".reflectCtrlWith[Num](x, y)) + } } - def makeSymbolic(): Rep[SymVal] = { - "make-symbolic".reflectCtrlWith[SymVal](num.s) + def <=(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-le".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-le".reflectCtrlWith[Num](x, y)) + } + } + + def leu(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-leu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-leu".reflectCtrlWith[Num](x, y)) + } } - def +(rhs: StagedNum): StagedNum = { + def >=(rhs: StagedConcreteNum): StagedConcreteNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-add".reflectCtrlWith[Num](x_c, y_c), "sym-binary-add".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32C(x), I32C(y)) => I32C("relation-ge".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-ge".reflectCtrlWith[Num](x, y)) } } + def geu(rhs: StagedConcreteNum): StagedConcreteNum = { + (num, rhs) match { + case (I32C(x), I32C(y)) => I32C("relation-geu".reflectCtrlWith[Num](x, y)) + case (I64C(x), I64C(y)) => I32C("relation-geu".reflectCtrlWith[Num](x, y)) + } + } + } + + implicit class StagedSymbolicNumOps(num: StagedSymbolicNum) { + def makeSymbolic(ty: ValueType): StagedSymbolicNum = num match { + case I32S(x) => I32S("make-symbolic".reflectCtrlWith[SymVal](x)) + case _ => throw new RuntimeException("Symbol index must be an i32") + } + + def isZero(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-is-zero".reflectCtrlWith[SymVal](x)) + } + + def clz(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-clz".reflectCtrlWith[SymVal](x)) + case I64S(x) => I64S("sym-clz".reflectCtrlWith[SymVal](x)) + } + + def ctz(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-ctz".reflectCtrlWith[SymVal](x)) + case I64S(x) => I64S("sym-ctz".reflectCtrlWith[SymVal](x)) + } + + def popcnt(): StagedSymbolicNum = num match { + case I32S(x) => I32S("sym-popcnt".reflectCtrlWith[SymVal](x)) + case I64S(x) => I64S("sym-popcnt".reflectCtrlWith[SymVal](x)) + } + + def +(rhs: StagedSymbolicNum): StagedSymbolicNum = { + (num, rhs) match { + case (I32S(x), I32S(y)) => I32S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-add".reflectCtrlWith[SymVal](x, y)) + } + } - def -(rhs: StagedNum): StagedNum = { + def -(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-sub".reflectCtrlWith[Num](x_c, y_c), "sym-binary-sub".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-sub".reflectCtrlWith[SymVal](x, y)) } } - def *(rhs: StagedNum): StagedNum = { + def *(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-mul".reflectCtrlWith[Num](x_c, y_c), "sym-binary-mul".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-mul".reflectCtrlWith[SymVal](x, y)) } } - def /(rhs: StagedNum): StagedNum = { + def /(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-div".reflectCtrlWith[Num](x_c, y_c), "sym-binary-div".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-div".reflectCtrlWith[SymVal](x, y)) } } - def <<(rhs: StagedNum): StagedNum = { + def <<(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shl".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shl".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-shl".reflectCtrlWith[SymVal](x, y)) } } - def >>(rhs: StagedNum): StagedNum = { + def >>(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-shr".reflectCtrlWith[Num](x_c, y_c), "sym-binary-shr".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-shr".reflectCtrlWith[SymVal](x, y)) } } - def &(rhs: StagedNum): StagedNum = { + def &(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) - case (F32(x_c, x_s), F32(y_c, y_s)) => F32("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) - case (F64(x_c, x_s), F64(y_c, y_s)) => F64("binary-and".reflectCtrlWith[Num](x_c, y_c), "sym-binary-and".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I64S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) + case (F32S(x), F32S(y)) => F32S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) + case (F64S(x), F64S(y)) => F64S("sym-binary-and".reflectCtrlWith[SymVal](x, y)) } } - def numEq(rhs: StagedNum): StagedNum = { + def numEq(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-eq".reflectCtrlWith[Num](x_c, y_c), "sym-relation-eq".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-eq".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-eq".reflectCtrlWith[SymVal](x, y)) } } - def numNe(rhs: StagedNum): StagedNum = { + def numNe(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ne".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ne".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-ne".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-ne".reflectCtrlWith[SymVal](x, y)) } } - def <(rhs: StagedNum): StagedNum = { + def <(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-lt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-lt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-lt".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-lt".reflectCtrlWith[SymVal](x, y)) } } - def ltu(rhs: StagedNum): StagedNum = { + def ltu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ltu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ltu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("relation-ltu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("relation-ltu".reflectCtrlWith[SymVal](x, y)) } } - def >(rhs: StagedNum): StagedNum = { + def >(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gt".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gt".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-gt".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-gt".reflectCtrlWith[SymVal](x, y)) } } - def gtu(rhs: StagedNum): StagedNum = { + def gtu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-gtu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-gtu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-gtu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-gtu".reflectCtrlWith[SymVal](x, y)) } } - def <=(rhs: StagedNum): StagedNum = { + def <=(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-le".reflectCtrlWith[Num](x_c, y_c), "sym-relation-le".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-le".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-le".reflectCtrlWith[SymVal](x, y)) } } - def leu(rhs: StagedNum): StagedNum = { + def leu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-leu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-leu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-leu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-leu".reflectCtrlWith[SymVal](x, y)) } } - def >=(rhs: StagedNum): StagedNum = { + def >=(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-ge".reflectCtrlWith[Num](x_c, y_c), "sym-relation-ge".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-ge".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-ge".reflectCtrlWith[SymVal](x, y)) } } - def geu(rhs: StagedNum): StagedNum = { + def geu(rhs: StagedSymbolicNum): StagedSymbolicNum = { (num, rhs) match { - case (I32(x_c, x_s), I32(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) - case (I64(x_c, x_s), I64(y_c, y_s)) => I32("relation-geu".reflectCtrlWith[Num](x_c, y_c), "sym-relation-geu".reflectCtrlWith[SymVal](x_s, y_s)) + case (I32S(x), I32S(y)) => I32S("sym-relation-geu".reflectCtrlWith[SymVal](x, y)) + case (I64S(x), I64S(y)) => I32S("sym-relation-geu".reflectCtrlWith[SymVal](x, y)) } } } @@ -848,7 +1471,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { else if (m.toString.endsWith("I32V")) "I32V" else if (m.toString.endsWith("I64V")) "I64V" else if (m.toString.endsWith("SymVal")) "SymVal" - + else if (m.toString.endsWith("Snapshot")) "Snapshot_t" else super.remap(m) } @@ -902,10 +1525,14 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("Stack.push("); shallow(value); emit(")") case Node(_, "stack-shift", List(offset, size), _) => emit("Stack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") + case Node(_, "sym-stack-shift", List(offset, size), _) => + emit("SymStack.shift("); shallow(offset); emit(", "); shallow(size); emit(")") case Node(_, "stack-pop", _, _) => emit("Stack.pop()") case Node(_, "sym-stack-pop", _, _) => emit("SymStack.pop()") + case Node(_, "snapshot-make", _, _) => + emit("Snapshot_t()") case Node(_, "frame-pop", List(i), _) => emit("Frames.popFrame("); shallow(i); emit(")") case Node(_, "sym-frame-pop", List(i), _) => @@ -935,7 +1562,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { case Node(_, "binary-add", List(lhs, rhs), _) => shallow(lhs); emit(" + "); shallow(rhs) case Node(_, "binary-sub", List(lhs, rhs), _) => - shallow(lhs); emit(" - "); shallow(rhs) + // todo: avoid using c++ operator, use explicit method call so operator's precedence issues won't exist + emit("("); shallow(lhs); emit(" - "); shallow(rhs); emit(")") case Node(_, "binary-mul", List(lhs, rhs), _) => shallow(lhs); emit(" * "); shallow(rhs) case Node(_, "binary-div", List(lhs, rhs), _) => @@ -990,12 +1618,14 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { shallow(num); emit(".toInt()") case Node(_, "make-symbolic", List(num), _) => shallow(num); emit(".makeSymbolic()") + case Node(_, "make-symbolic-concrete", List(num), _) => + emit("make_symbolic("); shallow(num); emit(")") case Node(_, "sym-env-read", List(sym), _) => emit("SymEnv.read("); shallow(sym); emit(")") case Node(_, "assert-true", List(cond), _) => emit("GENSYM_ASSERT("); shallow(cond); emit(")") - case Node(_, "tree-fill-if-else", List(s), _) => - emit("ExploreTree.fillIfElseNode("); shallow(s); emit(")") + case Node(_, "tree-fill-if-else", List(sym, snapshot), _) => + emit("ExploreTree.fillIfElseNode("); shallow(sym); emit(", "); shallow(snapshot); emit(")") case Node(_, "tree-fill-finished", List(), _) => emit("ExploreTree.fillFinishedNode()") case Node(_, "tree-move-cursor", List(b), _) => @@ -1006,6 +1636,8 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { emit("ExploreTree.dump_graphviz("); shallow(f); emit(")") case Node(_, "sym-not", List(s), _) => shallow(s); emit(".negate()") + case Node(_, "reuse-is-reusing", List(), _) => + emit("Reuse.is_reusing()") case Node(_, "dummy", _, _) => emit("std::monostate()") case Node(_, "dummy-op", _, _) => emit("std::monostate()") case Node(_, "no-op", _, _) => @@ -1063,12 +1695,12 @@ trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEv object WasmToCppCompiler { case class GeneratedCpp(source: String, headerFolders: List[String]) - def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean, dumpTree: Option[String]): GeneratedCpp = { + def compile(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean): GeneratedCpp = { println(s"Now compiling wasm module with entry function $main") val driver = new WasmToCppCompilerDriver[Unit, Unit] { def module: ModuleInstance = moduleInst def snippet(x: Rep[Unit]): Rep[Unit] = { - evalTop(main, printRes, dumpTree) + evalTop(main, printRes) } } GeneratedCpp(driver.code, driver.codegen.includePaths.toList) @@ -1079,8 +1711,8 @@ object WasmToCppCompiler { outputCpp: String, outputExe: String, printRes: Boolean, - dumpTree: Option[String]): Unit = { - val generated = compile(moduleInst, main, printRes, dumpTree) + macros: String*): Unit = { + val generated = compile(moduleInst, main, printRes) val code = generated.source val writer = new java.io.PrintWriter(new java.io.File(outputCpp)) @@ -1091,7 +1723,9 @@ object WasmToCppCompiler { } import sys.process._ - val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g -l z3 " + generated.headerFolders.map(f => s"-I$f").mkString(" ") + val includeFlags = generated.headerFolders.map(f => s"-I$f").mkString(" ") + val macroFlags = macros.map(m => s"-D$m").mkString(" ") + val command = s"g++ -std=c++20 $outputCpp -o $outputExe -O3 -g -l z3 " + includeFlags + " " + macroFlags if (command.! != 0) { throw new RuntimeException(s"Compilation failed for $outputCpp") } diff --git a/src/test/scala/genwasym/TestStagedConcolicEval.scala b/src/test/scala/genwasym/TestStagedConcolicEval.scala index a65d0eda..48c24634 100644 --- a/src/test/scala/genwasym/TestStagedConcolicEval.scala +++ b/src/test/scala/genwasym/TestStagedConcolicEval.scala @@ -9,12 +9,24 @@ import gensym.wasm.parser._ import gensym.wasm.stagedconcolicminiwasm._ class TestStagedConcolicEval extends FunSuite { - def testFileToCpp(filename: String, main: Option[String] = None, expect: Option[List[Float]]=None) = { + def testFileConcolicCpp(filename: String, main: Option[String] = None) = { val moduleInst = ModuleInstance(Parser.parseFile(filename)) val cppFile = s"$filename.cpp" val exe = s"$cppFile.exe" val exploreTreeFile = s"$filename.tree.dot" - WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, Some(exploreTreeFile)) + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true) + + import sys.process._ + val result = Process(s"./$exe", None, "TREE_FILE" -> exploreTreeFile).!! + println(result) + } + + // only test concrete execution and its result + def testFileConcreteCpp(filename: String, main: Option[String] = None, expect: Option[List[Float]] = None) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val cppFile = s"$filename.cpp" + val exe = s"$cppFile.exe" + WasmToCppCompiler.compileToExe(moduleInst, main, cppFile, exe, true, "NO_INFO", "RUN_ONCE") import sys.process._ val result = s"./$exe".!! @@ -30,13 +42,61 @@ class TestStagedConcolicEval extends FunSuite { }) } - test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main")) } + test("ack-cpp") { testFileConcolicCpp("./benchmarks/wasm/ack.wat", Some("real_main")) } test("bug-finding") { - testFileToCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) + testFileConcolicCpp("./benchmarks/wasm/branch-strip-buggy.wat", Some("real_main")) } test("brtable-bug-finding") { - testFileToCpp("./benchmarks/wasm/staged/brtable_concolic.wat") + testFileConcolicCpp("./benchmarks/wasm/staged/brtable_concolic.wat") + } + + test("return-poly - concrete") { + testFileConcreteCpp("./benchmarks/wasm/staged/return_poly.wat", Some("$real_main"), expect=Some(List(42))) + } + test("ack-cpp - concrete") { testFileConcreteCpp("./benchmarks/wasm/ack.wat", Some("real_main"), expect=Some(List(7))) } + test("power - concrete") { testFileConcreteCpp("./benchmarks/wasm/pow.wat", Some("real_main"), expect=Some(List(1024))) } + test("start - concrete") { testFileConcreteCpp("./benchmarks/wasm/start.wat") } + test("fact - concrete") { testFileConcreteCpp("./benchmarks/wasm/fact.wat", None, expect=Some(List(120))) } + // TODO: Waiting more symbolic operators' implementations + // test("loop - concrete") { testFileConcreteCpp("./benchmarks/wasm/loop.wat", None, expect=Some(List(10))) } + test("even-odd - concrete") { testFileConcreteCpp("./benchmarks/wasm/even_odd.wat", None, expect=Some(List(1))) } + // TODO: Waiting symbolic memory's implementations + // test("load - concrete") { testFileConcreteCpp("./benchmarks/wasm/load.wat", None, expect=Some(List(1))) } + // test("btree - concrete") { testFileConcreteCpp("./benchmarks/wasm/btree/2o1u-unlabeled.wat") } + test("fib - concrete") { testFileConcreteCpp("./benchmarks/wasm/fib.wat", None, expect=Some(List(144))) } + test("tribonacci - concrete") { testFileConcreteCpp("./benchmarks/wasm/tribonacci.wat", None, expect=Some(List(504))) } + + // test("return - concrete") { + // Since all of the thrown exceptions had been captured in concolic driver, this test is not valid anymore + // intercept[java.lang.RuntimeException] { + // testFileConcreteCpp("./benchmarks/wasm/return.wat", Some("$real_main")) + // } + // } + + test("return_call - concrete") { + testFileConcreteCpp("./benchmarks/wasm/sum.wat", Some("sum10"), expect=Some(List(55))) + } + + test("block input - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("real_main"), expect=Some(List(9))) + } + test("loop block input - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("test_loop_input"), expect=Some(List(55))) + } + test("if block input - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("test_if_input"), expect=Some(List(25))) } + test("block input - poly br - concrete") { + testFileConcreteCpp("./benchmarks/wasm/block.wat", Some("test_poly_br"), expect=Some(List(0))) + } + test("loop block - poly br - concrete") { + testFileConcreteCpp("./benchmarks/wasm/loop_poly.wat", None, expect=Some(List(2, 1))) + } + + test("brtable-cpp - concrete") { + testFileConcreteCpp("./benchmarks/wasm/staged/brtable.wat") + } + }