Skip to content

Commit

Permalink
Equivalence checking as a component
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev committed Mar 3, 2023
1 parent 9804a40 commit 1837a11
Show file tree
Hide file tree
Showing 161 changed files with 5,305 additions and 1,275 deletions.
26 changes: 3 additions & 23 deletions core/src/main/scala/stainless/Component.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
package stainless

import utils.{CheckFilter, DefinitionIdFinder, DependenciesFinder}
import extraction.xlang.trees as xt
import io.circe.*
import extraction.xlang.{trees => xt}
import io.circe._
import stainless.extraction.ExtractionSummary

import java.io.File
import scala.concurrent.Future

trait Component { self =>
Expand All @@ -31,27 +32,6 @@ object optFunctions extends inox.OptionDef[Seq[String]] {
val usageRhs = "f1,f2,..."
}

object optCompareFuns extends inox.OptionDef[Seq[String]] {
val name = "comparefuns"
val default = Seq[String]()
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
val usageRhs = "f1,f2,..."
}

object optModels extends inox.OptionDef[Seq[String]] {
val name = "models"
val default = Seq[String]()
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
val usageRhs = "f1,f2,..."
}

object optNorm extends inox.OptionDef[String] {
val name = "norm"
val default = ""
val parser = inox.OptionParsers.stringParser
val usageRhs = "f"
}

trait ComponentRun { self =>
val component: Component
val trees: ast.Trees
Expand Down
20 changes: 12 additions & 8 deletions core/src/main/scala/stainless/MainHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ trait MainHelpers extends inox.MainHelpers { self =>
case object TestsGeneration extends Category {
override def toString: String = "Tests Generation"
}
case object EquivChk extends Category {
override def toString: String = "Equivalence checking"
}

override protected def getOptions: Map[inox.OptionDef[_], Description] = super.getOptions - inox.solvers.optAssumeChecked ++ Map(
optVersion -> Description(General, "Display the version number"),
optConfigFile -> Description(General, "Path to configuration file, set to false to disable (default: stainless.conf or .stainless.conf)"),
optFunctions -> Description(General, "Only consider functions f1,f2,..."),
optCompareFuns -> Description(General, "Only consider functions f1,f2,... for equivalence checking"),
optModels -> Description(General, "Consider functions f1, f2, ... as model functions for equivalence checking"),
optNorm -> Description(General, "Use function f as normalization function for equivalence checking"),
extraction.utils.optDebugObjects -> Description(General, "Only print debug output for functions/adts named o1,o2,..."),
extraction.utils.optDebugPhases -> Description(General, {
// f interpolator does not process escape sequence, we workaround that with the following trick.
Expand All @@ -44,6 +44,7 @@ trait MainHelpers extends inox.MainHelpers { self =>
evaluators.optCodeGen -> Description(Evaluators, "Use code generating evaluator"),
codegen.optInstrumentFields -> Description(Evaluators, "Instrument ADT field access during code generation"),
codegen.optSmallArrays -> Description(Evaluators, "Assume all arrays fit into memory during code generation"),
verification.optSilent -> Description(Verification, "Do not print any message when a verification condition fails due to invalidity or timeout"),
verification.optFailEarly -> Description(Verification, "Halt verification as soon as a check fails (invalid or unknown)"),
verification.optFailInvalid -> Description(Verification, "Halt verification as soon as a check is invalid"),
verification.optVCCache -> Description(Verification, "Enable caching of verification conditions"),
Expand Down Expand Up @@ -77,6 +78,13 @@ trait MainHelpers extends inox.MainHelpers { self =>
utils.Caches.optCacheDir -> Description(General, "Specify the directory in which cache files should be stored"),
testgen.optOutputFile -> Description(TestsGeneration, "Specify the output file"),
testgen.optGenCIncludes -> Description(TestsGeneration, "(GenC variant only) Specify header includes"),
equivchk.optCompareFuns -> Description(EquivChk, "Only consider functions f1,f2,... for equivalence checking"),
equivchk.optModels -> Description(EquivChk, "Consider functions f1, f2, ... as model functions for equivalence checking"),
equivchk.optNorm -> Description(EquivChk, "Use function f as normalization function for equivalence checking"),
equivchk.optEquivalenceOutput -> Description(EquivChk, "JSON output file for equivalence checking"),
equivchk.optN -> Description(EquivChk, "Consider the top N models"),
equivchk.optInitScore -> Description(EquivChk, "Initial score for models, must be positive"),
equivchk.optMaxPerm -> Description(EquivChk, "Maximum number of permutations to be tested when matching auxiliary functions"),
) ++ MainHelpers.components.map { component =>
val option = inox.FlagOptionDef(component.name, default = false)
option -> Description(Pipelines, component.description)
Expand Down Expand Up @@ -108,6 +116,7 @@ trait MainHelpers extends inox.MainHelpers { self =>
frontend.DebugSectionRecovery,
frontend.DebugSectionExtraDeps,
genc.DebugSectionGenC,
equivchk.DebugSectionEquivChk
)

override protected def displayVersion(reporter: inox.Reporter): Unit = {
Expand Down Expand Up @@ -186,11 +195,6 @@ trait MainHelpers extends inox.MainHelpers { self =>
}

import ctx.{reporter, timers}

if (extraction.trace.Trace.optionsError) {
reporter.fatalError(s"Equivalence checking for --comparefuns and --models only works in batched mode.")
}

if (!useParallelism) {
reporter.warning(s"Parallelism is disabled.")
}
Expand Down
61 changes: 61 additions & 0 deletions core/src/main/scala/stainless/ast/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,65 @@ trait TypeOps extends inox.ast.TypeOps {
}.transform(tpe)
}

protected class Unsolvable extends Exception
protected def unsolvable = throw new Unsolvable

/** Collects the constraints that need to be solved for [[unify]].
* Note: this is an override point. */
protected def unificationConstraints(t1: Type, t2: Type, free: Seq[TypeParameter]): List[(TypeParameter, Type)] = (t1, t2) match {
case (adt: ADTType, _) if adt.lookupSort.isEmpty => unsolvable
case (_, adt: ADTType) if adt.lookupSort.isEmpty => unsolvable

case _ if t1 == t2 => Nil

case (adt1: ADTType, adt2: ADTType) if adt1.id == adt2.id =>
(adt1.tps zip adt2.tps).toList flatMap (p => unificationConstraints(p._1, p._2, free))

case (rt: RefinementType, _) => unificationConstraints(rt.getType, t2, free)
case (_, rt: RefinementType) => unificationConstraints(t1, rt.getType, free)

case (pi: PiType, _) => unificationConstraints(pi.getType, t2, free)
case (_, pi: PiType) => unificationConstraints(t1, pi.getType, free)

case (sigma: SigmaType, _) => unificationConstraints(sigma.getType, t2, free)
case (_, sigma: SigmaType) => unificationConstraints(t1, sigma.getType, free)

case (tp: TypeParameter, _) if !(typeOps.typeParamsOf(t2) contains tp) && (free contains tp) => List(tp -> t2)
case (_, tp: TypeParameter) if !(typeOps.typeParamsOf(t1) contains tp) && (free contains tp) => List(tp -> t1)
case (_: TypeParameter, _) => unsolvable
case (_, _: TypeParameter) => unsolvable

case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) if ts1.size == ts2.size =>
(ts1 zip ts2).toList flatMap (p => unificationConstraints(p._1, p._2, free))
case _ => unsolvable
}

/** Solves the constraints collected by [[unificationConstraints]].
* Note: this is an override point. */
protected def unificationSolution(const: List[(Type, Type)]): List[(TypeParameter, Type)] = const match {
case Nil => Nil
case (tp: TypeParameter, t) :: tl =>
val replaced = tl map { case (t1, t2) =>
(typeOps.instantiateType(t1, Map(tp -> t)), typeOps.instantiateType(t2, Map(tp -> t)))
}
(tp -> t) :: unificationSolution(replaced)
case (adt: ADTType, _) :: tl if adt.lookupSort.isEmpty => unsolvable
case (_, adt: ADTType) :: tl if adt.lookupSort.isEmpty => unsolvable
case (ADTType(id1, tps1), ADTType(id2, tps2)) :: tl if id1 == id2 =>
unificationSolution((tps1 zip tps2).toList ++ tl)
case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) :: tl if ts1.size == ts2.size =>
unificationSolution((ts1 zip ts2).toList ++ tl)
case _ =>
unsolvable
}

/** Unifies two types, under a set of free variables */
def unify(t1: Type, t2: Type, free: Seq[TypeParameter]): Option[List[(TypeParameter, Type)]] = {
try {
Some(unificationSolution(unificationConstraints(t1, t2, free)))
} catch {
case _: Unsolvable => None
}
}

}
Loading

0 comments on commit 1837a11

Please sign in to comment.