Skip to content

Commit

Permalink
AntiAliasing: avoid rebuilding mutated objects when possible (#1507)
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev authored Apr 12, 2024
1 parent e9f9193 commit 97caa41
Show file tree
Hide file tree
Showing 13 changed files with 642 additions and 35 deletions.
213 changes: 181 additions & 32 deletions core/src/main/scala/stainless/extraction/imperative/AntiAliasing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -256,39 +256,187 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
}

// NOTE: `args` must refer to the arguments of the function invocation before transformation (the original args)
def mapApplication(formalArgs: Seq[ValDef], args: Seq[Expr], nfi: Expr, nfiType: Type, fiEffects: Set[Effect], env: Env): Expr = {
def mapApplication(formalArgs: Seq[ValDef], args: Seq[Expr], nfi: Expr, nfiType: Type, fiEffects: Set[Effect], isOpaqueOrExtern: Boolean, env: Env): Expr = {

def affectedBindings(updTarget: Target, isReplacement: Boolean): Map[ValDef, Set[Target]] = {
def isAffected(t: Target): Boolean = {
if (isReplacement) t.maybeProperPrefixOf(updTarget)
else t.maybePrefixOf(updTarget) || updTarget.maybePrefixOf(t)
}
env.targets.map {
case (vd, targets) =>
val affected = targets.filter(isAffected)
vd -> affected
}.filter(_._2.nonEmpty)
}

if (fiEffects.exists(e => formalArgs contains e.receiver.toVal)) {
val localEffects: Seq[Set[(Effect, Set[(Effect, Option[Expr])])]] = (formalArgs zip args)
.map { case (vd, arg) => (fiEffects.filter(_.receiver == vd.toVariable), arg) }
.filter { case (effects, _) => effects.nonEmpty }
.map { case (effects, arg) => effects map (e => (e, e on arg)) }
val localEffects = formalArgs.zip(args)
.map { case (vd, arg) =>
// Effects for each parameter
(vd.toVariable, fiEffects.filter(_.receiver == vd.toVariable), arg)
}.filter { case (_, effects, _) => effects.nonEmpty }

val freshRes = ValDef(FreshIdentifier("res"), nfiType).copiedFrom(nfi)

val assgns = (for {
(effects, index) <- localEffects.zipWithIndex
(outerEffect0, innerEffects) <- effects
(effect0, effectCond) <- innerEffects
} yield {
val outerEffect = outerEffect0.removeUnknownAccessor
val effect = effect0.removeUnknownAccessor
val pos = args(index).getPos
val resSelect = TupleSelect(freshRes.toVariable, index + 2)
// Follow all aliases of the updated target (may include self if it has no alias)
val primaryTargs = dealiasTarget(effect.toTarget(effectCond), env)
val assignedPrimaryTargs = primaryTargs.toSeq
.map(t => makeAssignment(pos, resSelect, outerEffect.path.path, t, dropVcs = true))
val updAliasingValDefs = updatedAliasingValDefs(primaryTargs, env, pos)

assignedPrimaryTargs ++ updAliasingValDefs
}).flatten
val extractResults = Block(assgns, TupleSelect(freshRes.toVariable, 1))

if (isMutableType(nfiType)) {
LetVar(freshRes, nfi, extractResults)
} else {
Let(freshRes, nfi, extractResults)
val assgns = localEffects.zipWithIndex.flatMap {
case ((vd, effects, arg), effIndex) =>
// +1 because we are a tuple and +1 because the first component is for the result of the function
val resSelect = TupleSelect(freshRes.toVariable, effIndex + 2)
// All effects on the given parameter, applied to the given argument
val paramWithArgsEffect = for {
outerEffect0 <- effects
(effect0, effectCond) <- outerEffect0 on arg
} yield {
val outerEffect = outerEffect0.removeUnknownAccessor
val effect = effect0.removeUnknownAccessor
val primaryTargs = dealiasTarget(effect.toTarget(effectCond), env)
(outerEffect, primaryTargs)
}
// Suppose we have the following definitions:
// case class Ref(var x: Int, var y: Int)
// case class RefRef(var lhs: Ref, var rhs: Ref)
//
// def modifyLhs(rr: RefRef, v: Int): Unit = {
// rr.lhs.x = v
// rr.lhs.y = v
// }
// def test1(testRR: RefRef): Unit = {
// val rrAlias = testRR
// val lhsAlias = testRR.lhs
// modifyLhs(testRR, 123)
// // ...
// }
// `modifyLhs` is (essentially) transformed as follows by `AntiAliasing` (not here in `mapApplication`):
// def modifyLhs(rr: RefRef, v: Int): (Unit, RefRef) = {
// ((), RefRef(Ref(v, v), rr.rhs)
// }
// The transformed `modifyLhs` returns a copy of the "updated" `rr`.
//
// Our task here in `mapApplication` is to transform the call to `modifyLhs`.
// Intuitively, in this case, we can "update" `testRR` to point to the "updated" version
// returned by `modifyLhs`, and update the aliases accordingly:
// def test1(testRR: RefRef): Unit = {
// val rrAlias = testRR
// val lhsAlias = testRR.lhs
// val res = modifyLhs(testRR, 123)
// testRR = res._2
// rrAlias = testRR
// lhsAlias = testRR.lhs
// // ...
// }
// We can do so because we know precisely the `Targets` of the argument, namely `testRR`
// and we can update its aliases accordingly.
// This correspond to the `Success` case of having a `ModifyingEffect` on `vd` (here: `rr`)
// applied on `arg` (here: `testRR`).
//
// However, sometimes, we may not always succeed in computing the precise targets,
// as in the following example:
// def test2(testRR: RefRef): Unit = {
// val lhsAlias = testRR.lhs
// val rhsAlias = testRR.rhs
// modifyLhs(RefRef(lhsAlias, rhsAlias), 123)
// }
// Here, we are not able to compute the targets of `RefRef(lhsAlias, rhsAlias)`,
// which corresponds to the `Failure` case. As such, we cannot simply "update"
// the `testRR` variable using the returned result as-is (as we did for `test1`).
//
// Instead, we need to apply each effect of `modifyLhs` *individually* on the argument.
// The effects for `modifyLhs` are (stored in `localEffects`):
// rr -> Set(ReplacementEffect(rr.lhs.x), ReplacementEffect(rr.lhs.y)))
// So we need to apply two `ReplacementEffect`, one on `rr.lhs.x` and one on `rr.lhs.y` on the argument.
// Doing so with `paramWithArgsEffect` gives us:
// ReplacementEffect(rr.lhs.x) -> Set(Target(testRR, None, .lhs.x))
// ReplacementEffect(rr.lhs.y) -> Set(Target(testRR, None, .lhs.y))
// which we can then use to update `testRR` (alongside their aliases):
// def test2(testRR: RefRef): Unit = {
// var lhsAlias: Ref = testRR.lhs
// val rhsAlias: Ref = testRR.rhs
// val res: (Unit, RefRef) = modifyLhs(RefRef(lhsAlias, rhsAlias), 123)
// // Note that we "update" each field individually, this is due to
// // having each effect applied separately!
// testRR = RefRef(Ref(res._2.lhs.x, testRR.lhs.y), testRR.rhs)
// lhsAlias = testRR.lhs
// testRR = RefRef(Ref(testRR.lhs.x, res._2.lhs.y), testRR.rhs)
// lhsAlias = testRR.lhs
// // ...
// }
//
// Note that we can always apply this second technique even if we have precise aliases.
// However, this tends to "rebuild" the object instead of reusing the "updated" result
// which can lead to verification inefficiency (and does not work well in presence of
// @opaque or @extern functions).
Try(ModifyingEffect(vd, Path.empty).on(arg)) match {
case Success(modEffect) =>
// Update everything that the argument is aliasing
val primaryTargs = modEffect.flatMap { case (eff, cond) => dealiasTarget(eff.toTarget(cond), env) }
val assignedPrimaryTargs = primaryTargs
// The order of assignments does not matter between "primary targets"
// but it must precede the update of aliases (`updAliasingValDefs`)
.toSeq
.map(t => makeAssignment(arg.getPos, resSelect, Seq.empty, t, dropVcs = true))
// We need to be careful with what we are updating here.
// If we expand on the above example with the following function:
// def t3(refref: RefRef): Unit = {
// val lhs = refref.lhs
// val oldLhs = lhs.x
// replaceLhs(refref, 123)
// assert(lhs.x == oldLhs)
// assert(refref.lhs.x == 123)
// }
// In `replaceLhs`, we have a ReplacementEffect on `rr.lhs`, this means
// that `rr.lhs` is replaced with a new `Ref`, leaving all aliases of `rr.lhs`
// (in `t3`, the `val lhs`) untouched. So, after the call to `replaceLhs`,
// any modification to `rr.lhs` do not alter the other aliases (here, `lhs`).
// The function `t3` should be transformed as follows:
// def t3(refref: RefRef): Unit = {
// val lhs = refref.lhs
// val oldLhs = lhs.x
// val res = replaceLhs(refref, 123)
// refref = res._2
// assert(lhs.x == oldLhs)
// assert(refref.lhs.x == 123)
// }
// In particular, note that we *do not* touch `lhs`: the following transformation is incorrect:
// var lhs = refref.lhs
// val oldLhs = lhs.x
// val res = replaceLhs(refref, 123)
// refref = res._2
// lhs = refref.lhs
// because after the call to `replaceLhs`, `lhs` and `refref.lhs` become unrelated.
// Note that, for @opaque and @extern function, we assume the object was mutated in each of its field
// and therefore update all aliases.
val aliasingVds = {
if (isOpaqueOrExtern) {
primaryTargs.flatMap(affectedBindings(_, false))
} else {
paramWithArgsEffect.flatMap {
case (eff, targs) =>
targs.flatMap(affectedBindings(_, eff.kind == ReplacementKind))
}
}
}
val updAliasingValDefs = aliasingVds
.toSeq // See comment on `assignedPrimaryTargs`
.flatMap { case (vd, targs) =>
targs.map(t => makeAssignment(arg.getPos, t.wrap.get, Seq.empty, Target(vd.toVariable, t.condition, Path.empty), true))
}
assignedPrimaryTargs ++ updAliasingValDefs
case Failure(_) =>
paramWithArgsEffect.toSeq.flatMap { case (outerEffect, primaryTargs) =>
// Update everything that the argument is aliasing
val assignedPrimaryTargs = primaryTargs
.toSeq
.map(t => makeAssignment(arg.getPos, resSelect, outerEffect.path.path, t, dropVcs = true))
// Update everything aliasing the argument
val updAliasingValDefs = updatedAliasingValDefs(primaryTargs, env, arg.getPos)
assignedPrimaryTargs ++ updAliasingValDefs
}
}
}

val extractResults = Block(assgns, TupleSelect(freshRes.toVariable, 1))
Let(freshRes, nfi, extractResults)
} else {
nfi
}
Expand Down Expand Up @@ -724,8 +872,8 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
val nfi = FunctionInvocation(
id, tps, args.map(transform(_, env))
).copiedFrom(fi)

mapApplication(fd.params, args, nfi, fi.tfd.instantiate(analysis.getReturnType(fd)), effects(fd), env)
val isExternOrOpaque = symbols.getFunction(id).flags.exists(f => f == Extern || f == Opaque)
mapApplication(fd.params, args, nfi, fi.tfd.instantiate(analysis.getReturnType(fd)), effects(fd), isExternOrOpaque, env)

case alr @ ApplyLetRec(id, tparams, tpe, tps, args) =>
val fd = Inner(env.locals(id))
Expand All @@ -752,7 +900,8 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
).copiedFrom(alr)

val resultType = typeOps.instantiateType(analysis.getReturnType(fd), (tparams zip tps).toMap)
mapApplication(fd.params, args, nfi, resultType, effects(fd), env)
val isExternOrOpaque = env.locals(id).flags.exists(f => f == Extern || f == Opaque)
mapApplication(fd.params, args, nfi, resultType, effects(fd), isExternOrOpaque, env)

case app @ Application(callee, args) =>
checkAliasing(app, args, env)
Expand All @@ -770,7 +919,7 @@ class AntiAliasing(override val s: Trees)(override val t: s.type)(using override
case (vd, i) if ftEffects(i) => ModifyingEffect(vd.toVariable, Path.empty)
}
val to = makeFunctionTypeExplicit(ft).asInstanceOf[FunctionType].to
mapApplication(params, args, nfi, to, appEffects.toSet, env)
mapApplication(params, args, nfi, to, appEffects.toSet, false, env)
} else {
Application(transform(callee, env), args.map(transform(_, env))).copiedFrom(app)
}
Expand Down
4 changes: 3 additions & 1 deletion frontends/benchmarks/imperative/invalid/ExternMutation.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import stainless.lang._
import stainless.annotation._
import StaticChecks._

object ExternMutation {
case class Box(var value: BigInt)
Expand All @@ -8,7 +10,7 @@ object ExternMutation {
def f2(b: Container[Box]): Unit = ???

def g2(b: Container[Box]) = {
val b0 = b
@ghost val b0 = snapshot(b)
f2(b)
assert(b == b0) // fails because `Container` is mutable
}
Expand Down
32 changes: 32 additions & 0 deletions frontends/benchmarks/imperative/invalid/OpaqueMutation1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import stainless.lang.{ghost => ghostExpr, _}
import stainless.proof._
import stainless.annotation._
import StaticChecks._

object OpaqueMutation1 {

case class Box(var cnt: BigInt, var other: BigInt) {
@opaque // Note the opaque
def secretSauce(x: BigInt): BigInt = cnt + x // Nobody thought of it!

@opaque // Note the opaque here as well
def increment(): Unit = {
@ghost val oldBox = snapshot(this)
cnt += 1
ghostExpr {
unfold(secretSauce(other))
unfold(oldBox.secretSauce(other))
check(oldBox.secretSauce(other) + 1 == this.secretSauce(other))
}
}.ensuring(_ => old(this).secretSauce(other) + 1 == this.secretSauce(other))
}

def test(b: Box): Unit = {
@ghost val oldBox = snapshot(b)
b.increment()
// Note that, even though the implementation of `increment` does not alter `other`,
// we do not have that knowledge here since the function is marked as opaque.
// Therefore, the following is incorrect (but it holds for `b.other`, see the other `valid/OpaqueMutation`)
assert(oldBox.secretSauce(oldBox.other) + 1 == b.secretSauce(oldBox.other))
}
}
33 changes: 33 additions & 0 deletions frontends/benchmarks/imperative/invalid/OpaqueMutation2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import stainless.lang.{ghost => ghostExpr, _}
import stainless.proof._
import stainless.annotation._
import StaticChecks._

object OpaqueMutation2 {
case class SmallerBox(var otherCnt: BigInt)

case class Box(var cnt: BigInt, var smallerBox: SmallerBox) {
@opaque // Note the opaque
def secretSauce(x: BigInt): BigInt = cnt + x // Nobody thought of it!

@opaque // Note the opaque here as well
def increment(): Unit = {
@ghost val oldBox = snapshot(this)
cnt += 1
ghostExpr {
unfold(secretSauce(smallerBox.otherCnt))
unfold(oldBox.secretSauce(smallerBox.otherCnt))
check(oldBox.secretSauce(smallerBox.otherCnt) + 1 == this.secretSauce(smallerBox.otherCnt))
}
}.ensuring(_ => old(this).secretSauce(smallerBox.otherCnt) + 1 == this.secretSauce(smallerBox.otherCnt))
}

def test(b: Box): Unit = {
@ghost val oldBox = snapshot(b)
b.increment()
// Note that, even though the implementation of `increment` does not alter `smallerBox`,
// we do not have that knowledge here since the function is marked as opaque.
// Therefore, the following is incorrect (but it holds for `b.other`, see the other `valid/OpaqueMutation`)
assert(oldBox.secretSauce(oldBox.smallerBox.otherCnt) + 1 == b.secretSauce(oldBox.smallerBox.otherCnt))
}
}
15 changes: 15 additions & 0 deletions frontends/benchmarks/imperative/valid/ExternMutation.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import stainless.annotation._

object ExternMutation {
case class Box(var value: BigInt)
case class Container[@mutable T](t: T)

@extern
def f2(b: Container[Box]): Unit = ???

def g2(b: Container[Box]) = {
val b0 = b
f2(b)
assert(b == b0) // Ok, even though `b` is assumed to be modified because `b0` is an alias of `b`
}
}
4 changes: 2 additions & 2 deletions frontends/benchmarks/imperative/valid/MutableTuple.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ object MutableTuple {
}

def t3(): (Foo, Bar) = {
val bar = Bar(1)
val foo = Foo(2)
val bar = Bar(10)
val foo = Foo(20)
(foo, bar)
}

Expand Down
Loading

0 comments on commit 97caa41

Please sign in to comment.