Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev committed Nov 24, 2022
1 parent 5526990 commit c0139cf
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 27 deletions.
69 changes: 42 additions & 27 deletions core/src/main/scala/stainless/extraction/oo/RefinementLifting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ class RefinementLifting(override val s: Trees, override val t: Trees)
case _ => super.transform(e)
}

case s.LetRec(fds, body) =>
val nlfd = fds.map(lfd => transform(Inner(lfd)).toLocal)
t.LetRec(nlfd, transform(body)).copiedFrom(e)

case s.ApplyLetRec(id, tparams, tpe, tps, args) => liftRefinements(tpe) match {
case s.RefinementType(vd, s.BooleanLiteral(true)) =>
val ftTpe = vd.tpe.asInstanceOf[s.FunctionType]
Expand Down Expand Up @@ -197,42 +201,48 @@ class RefinementLifting(override val s: Trees, override val t: Trees)
}

override def transform(tpe: s.Type): t.Type = super.transform(liftRefinements(tpe))
}

override protected def extractFunction(context: TransformerContext, fd: s.FunDef): (t.FunDef, Unit) = {
import s._
import exprOps._
def transform(fd: s.FunAbstraction): t.FunAbstraction = {
import s._
import exprOps._

// FIXME: Shouldn't we propagate `newParams`?
val (newParams, cond) = context.parameterConds(fd.params)
val (newParams, cond) = parameterConds(fd.params)

val specced = exprOps.BodyWithSpecs(fd.fullBody)
val specced = exprOps.BodyWithSpecs(fd.fullBody)

val optPre =
if (cond != s.BooleanLiteral(true)) Some(exprOps.Precondition(cond).setPos(cond))
else None
val optPre =
if (cond != s.BooleanLiteral(true)) Some(exprOps.Precondition(cond).setPos(cond))
else None

val optOldPost = specced.getSpec(exprOps.PostconditionKind).map(_.expr)
val optPost = (context.liftRefinements(fd.returnType) match {
case s.RefinementType(vd2, pred) => optOldPost match {
case Some(post @ s.Lambda(Seq(res), body)) =>
Some(s.Lambda(Seq(res), s.and(
val optOldPost = specced.getSpec(exprOps.PostconditionKind).map(_.expr)
val optPost = (liftRefinements(fd.returnType) match {
case s.RefinementType(vd2, pred) => optOldPost match {
case Some(post@s.Lambda(Seq(res), body)) =>
Some(s.Lambda(Seq(res), s.and(
exprOps.replaceFromSymbols(Map(vd2 -> res.toVariable), pred),
body).copiedFrom(body)).copiedFrom(post))
case Some(post @ s.Lambda(_, _)) =>
sys.error(s"Unexpected number of params for postcondition lambda: $post")
case None =>
Some(s.Lambda(Seq(vd2), pred).copiedFrom(fd))
}
case _ => optOldPost
}).map(exprOps.Postcondition.apply)

(context.transform(fd.copy(
fullBody = specced.addSpec(optPre).withSpec(optPost).reconstructed,
returnType = context.dropRefinements(fd.returnType)
).copiedFrom(fd)), ())
case Some(post@s.Lambda(_, _)) =>
sys.error(s"Unexpected number of params for postcondition lambda: $post")
case None =>
Some(s.Lambda(Seq(vd2), pred).copiedFrom(fd))
}
case _ => optOldPost
}).map(exprOps.Postcondition.apply)

fd.to(t)(
fd.id,
fd.tparams.map(identity.transform),
newParams.map(identity.transform),
transform(dropRefinements(fd.returnType)),
transform(specced.addSpec(optPre).withSpec(optPost).reconstructed),
fd.flags.map(identity.transform)
).copiedFrom(fd)
}
}

override protected def extractFunction(context: TransformerContext, fd: s.FunDef): (t.FunDef, Unit) =
(context.transform(s.Outer(fd)).toFun, ())

override protected def extractSort(context: TransformerContext, sort: s.ADTSort): ((t.ADTSort, Option[t.FunDef]), Unit) = {
import s._
import context.symbols.{given, _}
Expand Down Expand Up @@ -294,6 +304,11 @@ class RefinementLifting(override val s: Trees, override val t: Trees)
// TODO: lift refinements to invariant?
(context.transform(cd), ())
}

private val identity = {
class IdentityImpl(override val s: self.s.type, override val t: self.t.type) extends ConcreteTreeTransformer(s, t)
new IdentityImpl(self.s, self.t)
}
}

object RefinementLifting {
Expand Down
32 changes: 32 additions & 0 deletions frontends/benchmarks/verification/valid/i1214a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import stainless.lang._

object i1214a {
abstract sealed class Ordinal {
def l: BigInt = {
this match {
case Nat() => BigInt(0)
case Transfinite(in) => in.l + 1
}
}.ensuring(res => res >= 0)
}

case class Nat() extends Ordinal
case class Transfinite(in: Ordinal) extends Ordinal //removing Ordinal removes location information

object lemmas {
def bar(a: Ordinal): Unit = {
decreases(a.l, 1)
def helper(c: Transfinite): Unit = {
decreases(c.l, 0)
assert(c.isInstanceOf[Transfinite])
bar(c.in)
}

a match {
case b: Transfinite =>
helper(b)
case Nat() =>
}
}
}
}
17 changes: 17 additions & 0 deletions frontends/benchmarks/verification/valid/i1214b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import stainless.lang._

object i1214b {
abstract sealed class Ordinal
case class Nat() extends Ordinal
case class Transfinite(in: Ordinal) extends Ordinal //removing Ordinal removes location information

def free(c: Transfinite): Unit = {
assert(c.isInstanceOf[Transfinite])
}

def bar(): Unit = {
def inner(c: Transfinite): Unit = {
assert(c.isInstanceOf[Transfinite])
}
}
}

0 comments on commit c0139cf

Please sign in to comment.