Skip to content

Commit

Permalink
Equivalence checking: add option to specify weights, output resulting…
Browse files Browse the repository at this point in the history
… weights & ctex in the JSON output, lift batch restriction
  • Loading branch information
mario-bucev committed Jun 2, 2023
1 parent fab2b1a commit ed949b0
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 20 deletions.
3 changes: 2 additions & 1 deletion core/src/main/scala/stainless/MainHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ trait MainHelpers extends inox.MainHelpers { self =>
equivchk.optNorm -> Description(EquivChk, "Use function f as normalization function for equivalence checking"),
equivchk.optEquivalenceOutput -> Description(EquivChk, "JSON output file for equivalence checking"),
equivchk.optN -> Description(EquivChk, "Consider the top N models"),
equivchk.optInitScore -> Description(EquivChk, "Initial score for models, must be positive"),
equivchk.optInitScore -> Description(EquivChk, "Initial score for models"),
equivchk.optInitWeights -> Description(EquivChk, "Initial weights for models, overriding the initial score"),
equivchk.optMaxPerm -> Description(EquivChk, "Maximum number of permutations to be tested when matching auxiliary functions"),
) ++ MainHelpers.components.map { component =>
val option = inox.FlagOptionDef(component.name, default = false)
Expand Down
35 changes: 23 additions & 12 deletions core/src/main/scala/stainless/equivchk/EquivalenceChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class EquivalenceChecker(override val trees: Trees,
private val allCandidates: Seq[Identifier],
private val norm: Option[Identifier],
private val N: Int,
private val initScore: Int,
private val initWeights: Map[Identifier, Int],
private val maxMatchingPermutation: Int,
private val maxCtex: Int,
private val maxStepsEval: Int)
Expand All @@ -66,6 +66,7 @@ class EquivalenceChecker(override val trees: Trees,
(using val context: inox.Context)
extends Utils with stainless.utils.CtexRemapping { self =>
import trees._
require(initWeights.keySet == allModels.toSet)

//region Examination and rounds ADTs

Expand All @@ -88,7 +89,7 @@ class EquivalenceChecker(override val trees: Trees,

enum Classification {
case Valid(directModel: Identifier)
case Invalid(ctex: Seq[Map[ValDef, Expr]])
case Invalid(ctex: Seq[Seq[(ValDef, Expr)]])
case Unknown
}

Expand All @@ -111,12 +112,13 @@ class EquivalenceChecker(override val trees: Trees,
// Candidates that will need to be manually inspected...
unknowns: Map[Identifier, UnknownData],
// Incorrect signature
wrongs: Set[Identifier])
case class Ctex(mapping: Map[ValDef, Expr], expected: Expr, got: Expr)
wrongs: Set[Identifier],
weights: Map[Identifier, Int])
case class Ctex(mapping: Seq[(ValDef, Expr)], expected: Expr, got: Expr)
case class ValidData(path: Seq[Identifier], solvingInfo: SolvingInfo)
// The list of counter-examples can be empty; the candidate is still invalid but a ctex could not be extracted
// If the solvingInfo is None, the candidate has been pruned.
case class ErroneousData(ctexs: Seq[Map[ValDef, Expr]], solvingInfo: Option[SolvingInfo])
case class ErroneousData(ctexs: Seq[Seq[(ValDef, Expr)]], solvingInfo: Option[SolvingInfo])
case class UnknownData(solvingInfo: SolvingInfo)
// Note: fromCache and trivial are only relevant for valid candidates
case class SolvingInfo(time: Long, solverName: Option[String], fromCache: Boolean, trivial: Boolean) {
Expand All @@ -125,7 +127,7 @@ class EquivalenceChecker(override val trees: Trees,

def getCurrentResults(): Results = {
val equiv = clusters.map { case (model, clst) => model -> clst.toSet }.toMap
Results(equiv, valid.toMap, erroneous.toMap, unknowns.toMap, signatureMismatch.toSet)
Results(equiv, valid.toMap, erroneous.toMap, unknowns.toMap, signatureMismatch.toSet, models.toMap)
}
//endregion

Expand Down Expand Up @@ -219,7 +221,8 @@ class EquivalenceChecker(override val trees: Trees,
}
}

private val models = mutable.LinkedHashMap.from(allModels.map(_ -> initScore))
// Note: we build the LinkedHashMap using `allModels` as insertion order to be deterministic
private val models = mutable.LinkedHashMap.from(allModels.map(m => m -> initWeights(m)))
private val remainingCandidates = mutable.LinkedHashSet.from(allCandidates)
// Candidate -> set of models for tested so far for the candidate, but resulted in an unknown
private val candidateTestedModels = mutable.Map.from(allCandidates.map(_ -> mutable.Set.empty[Identifier]))
Expand Down Expand Up @@ -253,7 +256,7 @@ class EquivalenceChecker(override val trees: Trees,
val ordCtexs = ctexOrderedArguments(fun, pr)(counterex.vars).toSeq
ordCtexs.foreach(addCtex)
val fd = symbols.functions(fun)
val ctexVars = ordCtexs.map(ctex => fd.params.zip(ctex).toMap)
val ctexVars = ordCtexs.map(ctex => fd.params.zip(ctex))
erroneous += fun -> ErroneousData(ctexVars, Some(extractSolvingInfo(analysis, fun, Seq.empty)))
Some(Set(fun))
} else candidatesCallee.get(fun) match {
Expand Down Expand Up @@ -454,7 +457,7 @@ class EquivalenceChecker(override val trees: Trees,
val candFd = symbols.functions(cand)
// Take all ctex for `cand`, `eqLemma` and `proof`
val ctexOrderedArgs = (Seq(cand, eqLemma) ++ proof.toSeq).flatMap(id => allCtexs.getOrElse(id, Seq.empty))
val ctexsMap = ctexOrderedArgs.map(ctex => candFd.params.zip(ctex).toMap)
val ctexsMap = ctexOrderedArgs.map(ctex => candFd.params.zip(ctex))
erroneous += cand -> ErroneousData(ctexsMap, Some(solvingInfo.withAddedTime(currCumulativeSolvingTime)))
examinationState = ExaminationState.PickNext
RoundConclusion.CandidateClassified(cand, Classification.Invalid(ctexsMap), Set.empty)
Expand Down Expand Up @@ -608,7 +611,7 @@ class EquivalenceChecker(override val trees: Trees,
findMap(samples.zipWithIndex) { case (arg, sampleIx) =>
passTestSample(arg, instParams).map(_ -> sampleIx)
}.map { case ((evalArgs, expected, got), sampleIx) =>
EvalCheck.FailsTest(id, sampleIx, Ctex(cand.params.zip(evalArgs).toMap, expected, got))
EvalCheck.FailsTest(id, sampleIx, Ctex(cand.params.zip(evalArgs), expected, got))
}
}

Expand Down Expand Up @@ -753,7 +756,7 @@ class EquivalenceChecker(override val trees: Trees,
.map { case (ctex, expected, got) =>
// ctex is ordered according to the model, so we need to reorder cand according to the permutation
val candReorg = candPerm.m2c.map(cand.params)
Ctex(candReorg.zip(ctex).toMap, expected, got)
Ctex(candReorg.zip(ctex), expected, got)
}
}
//endregion
Expand Down Expand Up @@ -1091,8 +1094,16 @@ object EquivalenceChecker {
}
val testsOk = testsOk0.toMap

val initWeightsPath = context.options.findOptionOrDefault(optInitWeights)
.toSeq.map { case (fn, w) => CheckFilter.fullNameToPath(fn) -> w }

val initWeights = models.map { mod =>
indexOfPath(Some(initWeightsPath.map(_._1)), mod)
.map(ix => mod -> initWeightsPath(ix)._2)
.getOrElse(mod -> initScore)
}.toMap
class EquivalenceCheckerImpl(override val trees: ts.type, override val symbols: syms.type)
extends EquivalenceChecker(ts, models, functions, norm, n, initScore, maxPerm, maxCtex, defaultMaxStepsEval)(testsOk, symbols)
extends EquivalenceChecker(ts, models, functions, norm, n, initWeights, maxPerm, maxCtex, defaultMaxStepsEval)(testsOk, symbols)
val ec = new EquivalenceCheckerImpl(ts, syms)
class SuccessImpl(override val trees: ts.type, override val symbols: syms.type, override val equivChker: ec.type) extends Creation.Success
new SuccessImpl(ts, syms, ec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ object optEquivalenceOutput extends FileOptionDef {

object optN extends inox.IntOptionDef("equivchk-n", EquivalenceChecker.defaultN, "<int>")
object optInitScore extends inox.IntOptionDef("equivchk-init-score", EquivalenceChecker.defaultInitScore, "<int>")
object optInitWeights extends inox.OptionDef[Map[String, Int]] {
val name = "equivchk-init-weights"
val default = Map.empty
val parser = s => {
def tryParsePair(s: String): Option[(String, Int)] = {
val ix = s.lastIndexOf(':')
if (ix <= 0) None // Also exclude empty function name
else {
val (fn, w) = s.splitAt(ix)
w.drop(1).toIntOption.map(fn -> _)
}
}
val pairs = s.split(",")

def go(i: Int, acc: Map[String, Int]): Option[Map[String, Int]] = {
if (i >= pairs.length) Some(acc)
else tryParsePair(pairs(i)) match {
case Some(p) => go(i + 1, acc + p)
case None => None
}
}
go(0, Map.empty)
}
val usageRhs = "fn1:w1,fn2:w2,..."
}
object optMaxPerm extends inox.IntOptionDef("equivchk-max-perm", EquivalenceChecker.defaultMaxMatchingPermutation, "<int>")
object optMaxCtex extends inox.IntOptionDef("equivchk-max-ctex", EquivalenceChecker.defaultMaxCtex, "<int>")

Expand Down Expand Up @@ -214,8 +239,11 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
val rsonStr = reason match {
case ec.PruningReason.SignatureMismatch => "signature mismatch"
case ec.PruningReason.ByTest(testId, sampleIx, ctex) =>
s"test falsification by ${testId.fullName} sample n°${sampleIx + 1} with ${prettyCtex(ec)(ctex.mapping)}\n Expected: ${ctex.expected} but got: ${ctex.got}"
case ec.PruningReason.ByPreviousCtex(ctex) => s"counter-example falsification with ${prettyCtex(ec)(ctex.mapping)}\n Expected: ${ctex.expected} but got: ${ctex.got}"
s"""test falsification by ${testId.fullName} sample n°${sampleIx + 1} with ${prettyCtex(ec)(ctex.mapping)}
| Expected: ${ctex.expected} but got: ${ctex.got}""".stripMargin
case ec.PruningReason.ByPreviousCtex(ctex) =>
s"""counter-example falsification with ${prettyCtex(ec)(ctex.mapping)}
| Expected: ${ctex.expected} but got: ${ctex.got}""".stripMargin
}
s"${fn.fullName}: $rsonStr"
}
Expand Down Expand Up @@ -291,15 +319,19 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
}
}

private def prettyCtex(ec: EquivalenceChecker)(ctex: Map[ec.trees.ValDef, ec.trees.Expr]): String =
ctex.toSeq.map { case (vd, e) => s"${vd.id.name} -> $e" }.mkString(", ")
private def prettyCtex(ec: EquivalenceChecker)(ctex: Seq[(ec.trees.ValDef, ec.trees.Expr)]): String =
ctex.map { case (vd, e) => s"${vd.id.name} -> $e" }.mkString(", ")

private def dumpResultsJson(out: File, ec: EquivalenceChecker)(res: ec.Results): Unit = {
val equivs = res.equiv.map { case (m, l) => m.fullName -> l.map(_.fullName).toSeq.sorted }
.toSeq.sortBy(_._1)
val errns = res.erroneous.keys.toSeq.map(_.fullName).sorted
val errns = res.erroneous.map { case (fn, errn) =>
fn.fullName -> errn.ctexs.map(_.map { case (vd, expr) => (vd.id.name, expr.toString) })
}.toSeq.sortBy(_._1)
val unknowns = res.unknowns.keys.toSeq.map(_.fullName).sorted
val wrongs = res.wrongs.toSeq.map(_.fullName).sorted
val weights = res.weights.map { case (mod, w) => mod.fullName -> w }.toSeq
.sortBy { case (mod, w) => (-w, mod) }

val json = Json.fromFields(Seq(
"equivalent" -> Json.fromValues(equivs.map { case (m, l) =>
Expand All @@ -308,9 +340,19 @@ class EquivalenceCheckingRun private(override val component: EquivalenceChecking
"functions" -> Json.fromValues(l.map(Json.fromString))
))
}),
"erroneous" -> Json.fromValues(errns.map(Json.fromString)),
"erroneous" -> Json.fromValues(errns.map { case (fn, ctexs) =>
Json.fromFields(Seq(
"function" -> Json.fromString(fn),
"ctexs" -> Json.fromValues(
ctexs.map { ctex =>
Json.fromFields(ctex.map { case (vd, expr) => vd -> Json.fromString(expr) })
}
)
))
}),
"timeout" -> Json.fromValues(unknowns.sorted.map(Json.fromString)),
"wrong" -> Json.fromValues(wrongs.sorted.map(Json.fromString)),
"weights" -> Json.fromFields(weights.map { case (mod, w) => mod -> Json.fromInt(w) })
))
JsonUtils.writeFile(out, json)
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/stainless/frontend/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ package object frontend {

private def batchSymbols(activeComponents: Seq[Component])(using ctx: inox.Context): Boolean = {
ctx.options.findOptionOrDefault(optBatchedProgram) ||
activeComponents.exists(Set(genc.GenCComponent, testgen.ScalaTestGenComponent, testgen.GenCTestGenComponent, equivchk.EquivalenceCheckingComponent).contains) ||
activeComponents.exists(Set(genc.GenCComponent, testgen.ScalaTestGenComponent, testgen.GenCTestGenComponent).contains) ||
ctx.options.findOptionOrDefault(optKeep).nonEmpty
}

Expand Down

0 comments on commit ed949b0

Please sign in to comment.