Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1159 #1246

Merged
merged 1 commit into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,7 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override

def makeAssignment(pos: inox.utils.Position, resSelect: Expr, outerEffect: Effect, effect: Effect, effectCond: Option[Expr]): Expr = {
val (cond, result) = select(pos, resSelect.getType, resSelect, outerEffect.path.toSeq)
val combinedCond = Some(cond).zip(effectCond).map {
case (BooleanLiteral(b1), c2) =>
if (b1) c2 else BooleanLiteral(false)
case (c1, BooleanLiteral(b2)) =>
if (b2) c1 else BooleanLiteral(false)
case (c1, c2) =>
And(c1, c2)
}.getOrElse(cond)
val combinedCond = andJoin(cond +: effectCond.toSeq)

// We only overwrite the receiver when it is an actual mutable type.
// This is necessary to handle immutable types being upcasted to `Any`, which is mutable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,46 +242,75 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type)

if (modifiedVars.isEmpty) fdWithoutSideEffects else {
val freshVars: Seq[Variable] = modifiedVars.map(_.freshen)

val newParams: Seq[ValDef] = inner.params ++ freshVars.map(_.toVal)
val freshVarDecls: Seq[Variable] = freshVars.map(_.freshen)

val rewritingMap: Map[Variable, Variable] = modifiedVars.zip(freshVarDecls).toMap
val freshBody = postMap {
case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e))
case v: Variable => rewritingMap.get(v)
case _ => None
} (bd)

val wrappedBody = freshVars.zip(freshVarDecls).foldLeft(freshBody) {
(body, p) => LetVar(p._2.toVal, p._1, body)
def recHelper(body: Expr,
// This map is used by the postcondition rewriting to replace the result variable
// of the ensuring clause to its freshened couter-part
extraVarReplace: Map[Variable, Expr],
// Set to true only for the postcondition, as we want to not rewrite `old` expression
// found in the body of the function so that we can catch these invalid uses in ImperativeCleanup
rewriteOldExpr: Boolean)
(bodyWrapper: (Expr, Seq[Variable]) => Expr): (Expr, Expr => Expr, Seq[Variable]) = {
assert(extraVarReplace.keySet.intersect(modifiedVars.toSet).isEmpty)
val freshVarDecls: Seq[Variable] = modifiedVars.map(_.freshen)
val rewritingMap: Map[Variable, Variable] = modifiedVars.zip(freshVarDecls).toMap

val freshBody = postMap {
case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e))
case v: Variable => rewritingMap.get(v).orElse(extraVarReplace.get(v))
case Old(v: Variable) if rewriteOldExpr && freshVarDecls.contains(v) =>
Some(freshVars(freshVarDecls.indexOf(v)))
case _ => None
} (body)
val wrappedBody = bodyWrapper(freshBody, freshVarDecls)

val (res, scope, fun) = toFunction(wrappedBody)(using State(state.parent, Set(),
state.localsMapping.map { case (v, (fd, mvs)) =>
(v, (fd, mvs.map(v => rewritingMap.getOrElse(v, v))))
} + (fd.id -> (fd, freshVarDecls))
))

(res, scope, freshVarDecls.map(fun))
}

val (fdRes, fdScope, fdFun) = toFunction(wrappedBody)(using State(state.parent, Set(),
state.localsMapping.map { case (v, (fd, mvs)) =>
(v, (fd, mvs.map(v => rewritingMap.getOrElse(v, v))))
} + (fd.id -> ((fd, freshVarDecls)))
))

val newRes = Tuple(fdRes +: freshVarDecls.map(fdFun))
val (fdRes, fdScope, fdDecls) = recHelper(bd, Map.empty, rewriteOldExpr = false) {
case (freshBody, freshVarDecls) =>
freshVars.zip(freshVarDecls).foldLeft(freshBody) {
(body, p) => LetVar(p._2.toVal, p._1, body)
}
}
val newRes = Tuple(fdRes +: fdDecls)
val newBody = fdScope(newRes)

val newReturnType = TupleType(inner.returnType +: modifiedVars.map(_.tpe))

val newSpecs = specs.map {
case Postcondition(post @ Lambda(Seq(res), postBody)) =>
/*
Essentially translates:
(res: (R, T1, T2, ...)) => {
// ...
pureFnCapturingModifiedVars
// ...
}
(where R is the result type of the function in question, and T1, T2, ... the types of the modified vars)
into:
(res: (R, T1, T2, ...)) => {
val modVar1 = res._2
val modVar2 = res._3
// ...
pureFnCapturingModifiedVars(modVar1, modVar2, ...)
// ...
}
*/
val newRes = ValDef(res.id.freshen, newReturnType)

val newBody = replaceSingle(
(modifiedVars.zip(freshVars).map { case (ov, nv) => Old(ov) -> nv } ++
modifiedVars.zipWithIndex.map { case (v, i) =>
(v -> TupleSelect(newRes.toVariable, i+2)): (Expr, Expr)
} :+ (res.toVariable -> TupleSelect(newRes.toVariable, 1))).toMap,
postBody
)

val (r, scope, _) = toFunction(newBody)
Postcondition(Lambda(Seq(newRes), scope(r)).setPos(post))
val (pcRes, pcScope, _) = recHelper(postBody, Map(res.toVariable -> TupleSelect(newRes.toVariable, 1)), rewriteOldExpr = true) {
case (freshBody, freshVarDecls) =>
freshVarDecls.zipWithIndex.foldLeft(freshBody) {
case (body, (vr, ix)) => LetVar(vr.toVal, TupleSelect(newRes.toVariable, ix + 2), body)
}
}
Postcondition(Lambda(Seq(newRes), pcScope(pcRes)).setPos(post))

case spec => spec.transform { cond =>
val fresh = replaceFromSymbols((modifiedVars zip freshVars).toMap, cond)
Expand Down
19 changes: 19 additions & 0 deletions frontends/benchmarks/imperative/invalid/i1159a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import stainless.lang._

object i1159a {

def foo(): Unit = {
var i = 0
def isZero = i == 0

def inside: Unit = {
require(i <= 10)
i += 1
}.ensuring(_ =>
isZero
)

inside
}

}
18 changes: 18 additions & 0 deletions frontends/benchmarks/imperative/invalid/i1159b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import stainless.lang._

object i1159b {

def foo(): Unit = {
var i = 0
def isZero = i == 0

def inside: Unit = {
require(i <= 10)
}.ensuring(_ =>
isZero // Reported as invalid, even though we never modify i (due to the hoisting of `inside`)
)

inside
}

}
22 changes: 22 additions & 0 deletions frontends/benchmarks/imperative/invalid/i1159c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import stainless.lang._

object i1159c {

def foo(): Unit = {
var i: BigInt = 0
var j: BigInt = 1
def isAnswerToLife = i + j == 42

def inside: Unit = {
require(i <= 10)
i = 41
j = 0
// Oh no, we are off-by-one to the answer to life :(
}.ensuring(_ =>
isAnswerToLife
)

inside
}

}
23 changes: 23 additions & 0 deletions frontends/benchmarks/imperative/invalid/i1159d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import stainless.lang._

object i1159d {

def foo(): Unit = {
var i: BigInt = 0
var j: BigInt = 1
def evenSumIsEven = {
require(i % 2 == 0 && j % 2 == 0)
(i + j) % 2 == 0
}

def inside: Unit = {
i = 41
j = 0
}.ensuring(_ =>
evenSumIsEven
)

inside
}

}
56 changes: 56 additions & 0 deletions frontends/benchmarks/imperative/valid/i1159.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import stainless.lang._

object i1159 {

def foo1(): Unit = {
var i = 1234
def isZero = i == 0

def inside: Unit = {
require(i <= 2000)
i = 0
}.ensuring(_ =>
isZero
)

inside
}

def foo2(): Unit = {
var i: BigInt = 0
var j: BigInt = 1
def isAnswerToLife = i + j == 42

def inside(wantAnswer: Boolean): Unit = {
if (wantAnswer) {
i = 22
j = 20
} else {
i = 0xbadca11
j = 0
}
}.ensuring(_ =>
wantAnswer ==> isAnswerToLife
)

// We don't know what we want
(inside(true), inside(false))
}

def foo3(): Unit = {
var i: BigInt = 0
var j: BigInt = 1
def evenSumIsEven = {
require(i % 2 == 0 && j % 2 == 0)
(i + j) % 2 == 0
}

def inside(newI: BigInt, newJ: BigInt): Unit = {
i = newI
j = newJ
}.ensuring(_ =>
(newI % 2 == 0 && newJ % 2 == 0) ==> evenSumIsEven
)
}

}