diff --git a/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala b/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala index c6c872aeac..c8ca6a15c1 100644 --- a/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala +++ b/core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala @@ -214,14 +214,7 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override def makeAssignment(pos: inox.utils.Position, resSelect: Expr, outerEffect: Effect, effect: Effect, effectCond: Option[Expr]): Expr = { val (cond, result) = select(pos, resSelect.getType, resSelect, outerEffect.path.toSeq) - val combinedCond = Some(cond).zip(effectCond).map { - case (BooleanLiteral(b1), c2) => - if (b1) c2 else BooleanLiteral(false) - case (c1, BooleanLiteral(b2)) => - if (b2) c1 else BooleanLiteral(false) - case (c1, c2) => - And(c1, c2) - }.getOrElse(cond) + val combinedCond = andJoin(cond +: effectCond.toSeq) // We only overwrite the receiver when it is an actual mutable type. // This is necessary to handle immutable types being upcasted to `Any`, which is mutable. diff --git a/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala b/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala index 07b602b064..da49c74baf 100644 --- a/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala +++ b/core/src/main/scala/stainless/extraction/imperative/ImperativeCodeElimination.scala @@ -242,46 +242,75 @@ class ImperativeCodeElimination(override val s: Trees)(override val t: s.type) if (modifiedVars.isEmpty) fdWithoutSideEffects else { val freshVars: Seq[Variable] = modifiedVars.map(_.freshen) - val newParams: Seq[ValDef] = inner.params ++ freshVars.map(_.toVal) - val freshVarDecls: Seq[Variable] = freshVars.map(_.freshen) - - val rewritingMap: Map[Variable, Variable] = modifiedVars.zip(freshVarDecls).toMap - val freshBody = postMap { - case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) - case v: Variable => rewritingMap.get(v) - case _ => None - } (bd) - val wrappedBody = freshVars.zip(freshVarDecls).foldLeft(freshBody) { - (body, p) => LetVar(p._2.toVal, p._1, body) + def recHelper(body: Expr, + // This map is used by the postcondition rewriting to replace the result variable + // of the ensuring clause to its freshened couter-part + extraVarReplace: Map[Variable, Expr], + // Set to true only for the postcondition, as we want to not rewrite `old` expression + // found in the body of the function so that we can catch these invalid uses in ImperativeCleanup + rewriteOldExpr: Boolean) + (bodyWrapper: (Expr, Seq[Variable]) => Expr): (Expr, Expr => Expr, Seq[Variable]) = { + assert(extraVarReplace.keySet.intersect(modifiedVars.toSet).isEmpty) + val freshVarDecls: Seq[Variable] = modifiedVars.map(_.freshen) + val rewritingMap: Map[Variable, Variable] = modifiedVars.zip(freshVarDecls).toMap + + val freshBody = postMap { + case Assignment(v, e) => rewritingMap.get(v).map(nv => Assignment(nv, e)) + case v: Variable => rewritingMap.get(v).orElse(extraVarReplace.get(v)) + case Old(v: Variable) if rewriteOldExpr && freshVarDecls.contains(v) => + Some(freshVars(freshVarDecls.indexOf(v))) + case _ => None + } (body) + val wrappedBody = bodyWrapper(freshBody, freshVarDecls) + + val (res, scope, fun) = toFunction(wrappedBody)(using State(state.parent, Set(), + state.localsMapping.map { case (v, (fd, mvs)) => + (v, (fd, mvs.map(v => rewritingMap.getOrElse(v, v)))) + } + (fd.id -> (fd, freshVarDecls)) + )) + + (res, scope, freshVarDecls.map(fun)) } - val (fdRes, fdScope, fdFun) = toFunction(wrappedBody)(using State(state.parent, Set(), - state.localsMapping.map { case (v, (fd, mvs)) => - (v, (fd, mvs.map(v => rewritingMap.getOrElse(v, v)))) - } + (fd.id -> ((fd, freshVarDecls))) - )) - - val newRes = Tuple(fdRes +: freshVarDecls.map(fdFun)) + val (fdRes, fdScope, fdDecls) = recHelper(bd, Map.empty, rewriteOldExpr = false) { + case (freshBody, freshVarDecls) => + freshVars.zip(freshVarDecls).foldLeft(freshBody) { + (body, p) => LetVar(p._2.toVal, p._1, body) + } + } + val newRes = Tuple(fdRes +: fdDecls) val newBody = fdScope(newRes) - val newReturnType = TupleType(inner.returnType +: modifiedVars.map(_.tpe)) val newSpecs = specs.map { case Postcondition(post @ Lambda(Seq(res), postBody)) => + /* + Essentially translates: + (res: (R, T1, T2, ...)) => { + // ... + pureFnCapturingModifiedVars + // ... + } + (where R is the result type of the function in question, and T1, T2, ... the types of the modified vars) + into: + (res: (R, T1, T2, ...)) => { + val modVar1 = res._2 + val modVar2 = res._3 + // ... + pureFnCapturingModifiedVars(modVar1, modVar2, ...) + // ... + } + */ val newRes = ValDef(res.id.freshen, newReturnType) - - val newBody = replaceSingle( - (modifiedVars.zip(freshVars).map { case (ov, nv) => Old(ov) -> nv } ++ - modifiedVars.zipWithIndex.map { case (v, i) => - (v -> TupleSelect(newRes.toVariable, i+2)): (Expr, Expr) - } :+ (res.toVariable -> TupleSelect(newRes.toVariable, 1))).toMap, - postBody - ) - - val (r, scope, _) = toFunction(newBody) - Postcondition(Lambda(Seq(newRes), scope(r)).setPos(post)) + val (pcRes, pcScope, _) = recHelper(postBody, Map(res.toVariable -> TupleSelect(newRes.toVariable, 1)), rewriteOldExpr = true) { + case (freshBody, freshVarDecls) => + freshVarDecls.zipWithIndex.foldLeft(freshBody) { + case (body, (vr, ix)) => LetVar(vr.toVal, TupleSelect(newRes.toVariable, ix + 2), body) + } + } + Postcondition(Lambda(Seq(newRes), pcScope(pcRes)).setPos(post)) case spec => spec.transform { cond => val fresh = replaceFromSymbols((modifiedVars zip freshVars).toMap, cond) diff --git a/frontends/benchmarks/imperative/invalid/i1159a.scala b/frontends/benchmarks/imperative/invalid/i1159a.scala new file mode 100644 index 0000000000..258b6a8779 --- /dev/null +++ b/frontends/benchmarks/imperative/invalid/i1159a.scala @@ -0,0 +1,19 @@ +import stainless.lang._ + +object i1159a { + + def foo(): Unit = { + var i = 0 + def isZero = i == 0 + + def inside: Unit = { + require(i <= 10) + i += 1 + }.ensuring(_ => + isZero + ) + + inside + } + +} diff --git a/frontends/benchmarks/imperative/invalid/i1159b.scala b/frontends/benchmarks/imperative/invalid/i1159b.scala new file mode 100644 index 0000000000..2d93cb4160 --- /dev/null +++ b/frontends/benchmarks/imperative/invalid/i1159b.scala @@ -0,0 +1,18 @@ +import stainless.lang._ + +object i1159b { + + def foo(): Unit = { + var i = 0 + def isZero = i == 0 + + def inside: Unit = { + require(i <= 10) + }.ensuring(_ => + isZero // Reported as invalid, even though we never modify i (due to the hoisting of `inside`) + ) + + inside + } + +} diff --git a/frontends/benchmarks/imperative/invalid/i1159c.scala b/frontends/benchmarks/imperative/invalid/i1159c.scala new file mode 100644 index 0000000000..61109c8c28 --- /dev/null +++ b/frontends/benchmarks/imperative/invalid/i1159c.scala @@ -0,0 +1,22 @@ +import stainless.lang._ + +object i1159c { + + def foo(): Unit = { + var i: BigInt = 0 + var j: BigInt = 1 + def isAnswerToLife = i + j == 42 + + def inside: Unit = { + require(i <= 10) + i = 41 + j = 0 + // Oh no, we are off-by-one to the answer to life :( + }.ensuring(_ => + isAnswerToLife + ) + + inside + } + +} diff --git a/frontends/benchmarks/imperative/invalid/i1159d.scala b/frontends/benchmarks/imperative/invalid/i1159d.scala new file mode 100644 index 0000000000..1a24973422 --- /dev/null +++ b/frontends/benchmarks/imperative/invalid/i1159d.scala @@ -0,0 +1,23 @@ +import stainless.lang._ + +object i1159d { + + def foo(): Unit = { + var i: BigInt = 0 + var j: BigInt = 1 + def evenSumIsEven = { + require(i % 2 == 0 && j % 2 == 0) + (i + j) % 2 == 0 + } + + def inside: Unit = { + i = 41 + j = 0 + }.ensuring(_ => + evenSumIsEven + ) + + inside + } + +} diff --git a/frontends/benchmarks/imperative/valid/i1159.scala b/frontends/benchmarks/imperative/valid/i1159.scala new file mode 100644 index 0000000000..96e2ee5e7f --- /dev/null +++ b/frontends/benchmarks/imperative/valid/i1159.scala @@ -0,0 +1,56 @@ +import stainless.lang._ + +object i1159 { + + def foo1(): Unit = { + var i = 1234 + def isZero = i == 0 + + def inside: Unit = { + require(i <= 2000) + i = 0 + }.ensuring(_ => + isZero + ) + + inside + } + + def foo2(): Unit = { + var i: BigInt = 0 + var j: BigInt = 1 + def isAnswerToLife = i + j == 42 + + def inside(wantAnswer: Boolean): Unit = { + if (wantAnswer) { + i = 22 + j = 20 + } else { + i = 0xbadca11 + j = 0 + } + }.ensuring(_ => + wantAnswer ==> isAnswerToLife + ) + + // We don't know what we want + (inside(true), inside(false)) + } + + def foo3(): Unit = { + var i: BigInt = 0 + var j: BigInt = 1 + def evenSumIsEven = { + require(i % 2 == 0 && j % 2 == 0) + (i + j) % 2 == 0 + } + + def inside(newI: BigInt, newJ: BigInt): Unit = { + i = newI + j = newJ + }.ensuring(_ => + (newI % 2 == 0 && newJ % 2 == 0) ==> evenSumIsEven + ) + } + +}