diff --git a/core/src/main/scala/stainless/Component.scala b/core/src/main/scala/stainless/Component.scala index d6277d77fc..eeab6c252c 100644 --- a/core/src/main/scala/stainless/Component.scala +++ b/core/src/main/scala/stainless/Component.scala @@ -3,10 +3,11 @@ package stainless import utils.{CheckFilter, DefinitionIdFinder, DependenciesFinder} -import extraction.xlang.trees as xt -import io.circe.* +import extraction.xlang.{trees => xt} +import io.circe._ import stainless.extraction.ExtractionSummary +import java.io.File import scala.concurrent.Future trait Component { self => @@ -31,27 +32,6 @@ object optFunctions extends inox.OptionDef[Seq[String]] { val usageRhs = "f1,f2,..." } -object optCompareFuns extends inox.OptionDef[Seq[String]] { - val name = "comparefuns" - val default = Seq[String]() - val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser) - val usageRhs = "f1,f2,..." -} - -object optModels extends inox.OptionDef[Seq[String]] { - val name = "models" - val default = Seq[String]() - val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser) - val usageRhs = "f1,f2,..." -} - -object optNorm extends inox.OptionDef[String] { - val name = "norm" - val default = "" - val parser = inox.OptionParsers.stringParser - val usageRhs = "f" -} - trait ComponentRun { self => val component: Component val trees: ast.Trees diff --git a/core/src/main/scala/stainless/MainHelpers.scala b/core/src/main/scala/stainless/MainHelpers.scala index 5ff349c294..b2136bf770 100644 --- a/core/src/main/scala/stainless/MainHelpers.scala +++ b/core/src/main/scala/stainless/MainHelpers.scala @@ -23,14 +23,14 @@ trait MainHelpers extends inox.MainHelpers { self => case object TestsGeneration extends Category { override def toString: String = "Tests Generation" } + case object EquivChk extends Category { + override def toString: String = "Equivalence checking" + } override protected def getOptions: Map[inox.OptionDef[_], Description] = super.getOptions - inox.solvers.optAssumeChecked ++ Map( optVersion -> Description(General, "Display the version number"), optConfigFile -> Description(General, "Path to configuration file, set to false to disable (default: stainless.conf or .stainless.conf)"), optFunctions -> Description(General, "Only consider functions f1,f2,..."), - optCompareFuns -> Description(General, "Only consider functions f1,f2,... for equivalence checking"), - optModels -> Description(General, "Consider functions f1, f2, ... as model functions for equivalence checking"), - optNorm -> Description(General, "Use function f as normalization function for equivalence checking"), extraction.utils.optDebugObjects -> Description(General, "Only print debug output for functions/adts named o1,o2,..."), extraction.utils.optDebugPhases -> Description(General, { // f interpolator does not process escape sequence, we workaround that with the following trick. @@ -44,6 +44,7 @@ trait MainHelpers extends inox.MainHelpers { self => evaluators.optCodeGen -> Description(Evaluators, "Use code generating evaluator"), codegen.optInstrumentFields -> Description(Evaluators, "Instrument ADT field access during code generation"), codegen.optSmallArrays -> Description(Evaluators, "Assume all arrays fit into memory during code generation"), + verification.optSilent -> Description(Verification, "Do not print any message when a verification condition fails due to invalidity or timeout"), verification.optFailEarly -> Description(Verification, "Halt verification as soon as a check fails (invalid or unknown)"), verification.optFailInvalid -> Description(Verification, "Halt verification as soon as a check is invalid"), verification.optVCCache -> Description(Verification, "Enable caching of verification conditions"), @@ -77,6 +78,13 @@ trait MainHelpers extends inox.MainHelpers { self => utils.Caches.optCacheDir -> Description(General, "Specify the directory in which cache files should be stored"), testgen.optOutputFile -> Description(TestsGeneration, "Specify the output file"), testgen.optGenCIncludes -> Description(TestsGeneration, "(GenC variant only) Specify header includes"), + equivchk.optCompareFuns -> Description(EquivChk, "Only consider functions f1,f2,... for equivalence checking"), + equivchk.optModels -> Description(EquivChk, "Consider functions f1, f2, ... as model functions for equivalence checking"), + equivchk.optNorm -> Description(EquivChk, "Use function f as normalization function for equivalence checking"), + equivchk.optEquivalenceOutput -> Description(EquivChk, "JSON output file for equivalence checking"), + equivchk.optN -> Description(EquivChk, "Consider the top N models"), + equivchk.optInitScore -> Description(EquivChk, "Initial score for models, must be positive"), + equivchk.optMaxPerm -> Description(EquivChk, "Maximum number of permutations to be tested when matching auxiliary functions"), ) ++ MainHelpers.components.map { component => val option = inox.FlagOptionDef(component.name, default = false) option -> Description(Pipelines, component.description) @@ -108,6 +116,7 @@ trait MainHelpers extends inox.MainHelpers { self => frontend.DebugSectionRecovery, frontend.DebugSectionExtraDeps, genc.DebugSectionGenC, + equivchk.DebugSectionEquivChk ) override protected def displayVersion(reporter: inox.Reporter): Unit = { @@ -186,11 +195,6 @@ trait MainHelpers extends inox.MainHelpers { self => } import ctx.{reporter, timers} - - if (extraction.trace.Trace.optionsError) { - reporter.fatalError(s"Equivalence checking for --comparefuns and --models only works in batched mode.") - } - if (!useParallelism) { reporter.warning(s"Parallelism is disabled.") } diff --git a/core/src/main/scala/stainless/ast/TypeOps.scala b/core/src/main/scala/stainless/ast/TypeOps.scala index 5eaff97da1..1b063df6e2 100644 --- a/core/src/main/scala/stainless/ast/TypeOps.scala +++ b/core/src/main/scala/stainless/ast/TypeOps.scala @@ -81,4 +81,65 @@ trait TypeOps extends inox.ast.TypeOps { }.transform(tpe) } + protected class Unsolvable extends Exception + protected def unsolvable = throw new Unsolvable + + /** Collects the constraints that need to be solved for [[unify]]. + * Note: this is an override point. */ + protected def unificationConstraints(t1: Type, t2: Type, free: Seq[TypeParameter]): List[(TypeParameter, Type)] = (t1, t2) match { + case (adt: ADTType, _) if adt.lookupSort.isEmpty => unsolvable + case (_, adt: ADTType) if adt.lookupSort.isEmpty => unsolvable + + case _ if t1 == t2 => Nil + + case (adt1: ADTType, adt2: ADTType) if adt1.id == adt2.id => + (adt1.tps zip adt2.tps).toList flatMap (p => unificationConstraints(p._1, p._2, free)) + + case (rt: RefinementType, _) => unificationConstraints(rt.getType, t2, free) + case (_, rt: RefinementType) => unificationConstraints(t1, rt.getType, free) + + case (pi: PiType, _) => unificationConstraints(pi.getType, t2, free) + case (_, pi: PiType) => unificationConstraints(t1, pi.getType, free) + + case (sigma: SigmaType, _) => unificationConstraints(sigma.getType, t2, free) + case (_, sigma: SigmaType) => unificationConstraints(t1, sigma.getType, free) + + case (tp: TypeParameter, _) if !(typeOps.typeParamsOf(t2) contains tp) && (free contains tp) => List(tp -> t2) + case (_, tp: TypeParameter) if !(typeOps.typeParamsOf(t1) contains tp) && (free contains tp) => List(tp -> t1) + case (_: TypeParameter, _) => unsolvable + case (_, _: TypeParameter) => unsolvable + + case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) if ts1.size == ts2.size => + (ts1 zip ts2).toList flatMap (p => unificationConstraints(p._1, p._2, free)) + case _ => unsolvable + } + + /** Solves the constraints collected by [[unificationConstraints]]. + * Note: this is an override point. */ + protected def unificationSolution(const: List[(Type, Type)]): List[(TypeParameter, Type)] = const match { + case Nil => Nil + case (tp: TypeParameter, t) :: tl => + val replaced = tl map { case (t1, t2) => + (typeOps.instantiateType(t1, Map(tp -> t)), typeOps.instantiateType(t2, Map(tp -> t))) + } + (tp -> t) :: unificationSolution(replaced) + case (adt: ADTType, _) :: tl if adt.lookupSort.isEmpty => unsolvable + case (_, adt: ADTType) :: tl if adt.lookupSort.isEmpty => unsolvable + case (ADTType(id1, tps1), ADTType(id2, tps2)) :: tl if id1 == id2 => + unificationSolution((tps1 zip tps2).toList ++ tl) + case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) :: tl if ts1.size == ts2.size => + unificationSolution((ts1 zip ts2).toList ++ tl) + case _ => + unsolvable + } + + /** Unifies two types, under a set of free variables */ + def unify(t1: Type, t2: Type, free: Seq[TypeParameter]): Option[List[(TypeParameter, Type)]] = { + try { + Some(unificationSolution(unificationConstraints(t1, t2, free))) + } catch { + case _: Unsolvable => None + } + } + } diff --git a/core/src/main/scala/stainless/equivchk/EquivalenceChecker.scala b/core/src/main/scala/stainless/equivchk/EquivalenceChecker.scala new file mode 100644 index 0000000000..1b23e7348e --- /dev/null +++ b/core/src/main/scala/stainless/equivchk/EquivalenceChecker.scala @@ -0,0 +1,1306 @@ +/* Copyright 2009-2021 EPFL, Lausanne */ + +package stainless +package equivchk + +import inox.utils.Position +import io.circe.{Json, JsonObject} +import stainless.equivchk.EquivalenceChecker._ +import stainless.extraction.trace._ +import stainless.utils.{CheckFilter, JsonUtils} +import stainless.verification.{VCResult, VCStatus, VerificationAnalysis} +import stainless.{FreshIdentifier, Identifier, Program, StainlessProgram, evaluators} + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.math.Ordering.IntOrdering +import scala.util.control.NonFatal + +// EquivalenceChecker workflow consists of a preliminary analysis and examinations made up of rounds. +// +// The preliminary analysis is done outside of EquivalenceChecker: it is a general verification pass over all candidates +// to catch for invalid VCs (such as division by zero and so on). +// These invalid VCs are communicated to EquivalenceChecker with reportErroneous. +// Candidates having at least one invalid VCs are classified as "erroneous" and +// are not considered for equivalence checking any further. +// +// After this pass, EquivalenceChecker works in examination and rounds. +// First, we pick a candidate to examine with `pickNextExamination` returning a `NextExamination` +// The picked candidate will be checked for equivalence against models according to various strategies. +// For each strategy, we prepare a so-called "round" with `prepareRound` which will return functions encoding +// equivalence checking condition according to that strategy. +// These functions are then sent to the solver, and the results are communicated back to EquivalenceChecker with `concludeRound`. +// Depending on the result, we can either classify the candidate and go with the next examination +// (indicated by `concludeRound` returning RoundConclusion.CandidateClassified) or we need to try a new strategy +// and go with the next round (`concludeRound` returning RoundConclusion.NextRound). +// +// The strategies are applied in the following order: +// -Pick the top 3 models according to their score. +// -Repeat until we are done or we have tried all 3 models: +// -"Model first without sublemmas": we try to prove equivalence by using the selected model as template for induction. +// Functions inside the candidate and the model do not get any special treatment. +// If the equiv. check succeeds, the candidate is correct. +// Otherwise, if it is invalid, the candidate is classified as erroneous (i.e., not equivalent, it is incorrect) +// If it is inconclusive (timeout), we try the next strategy. +// -"Candidate first without sublemmas": as above, except we use the candidate as template for induction +// -"Model first with sublemmas": uses the selected model for induction; functions calls appearing in model and candidate +// are matched against each other and are checked for equivalence as well. +// If these subfunctions equivalence all succeed, we are done. +// Otherwise (invalid results or timeout), we try another matching of function until we have tried all of them. +// Note that an invalid result does not necessarily mean that the candidate is incorrect, it may be the case that +// the matching of function calls we have tried is not the good one. +// -"Candidate first with sublemmas": as above but we use the candidate for the induction +// -If all these strategies are inconclusive, try with the next model (until we have tried all of the 3). +// +class EquivalenceChecker(override val trees: Trees, + private val allModels: Seq[Identifier], + private val allCandidates: Seq[Identifier], + private val norm: Option[Identifier], + private val N: Int, + private val initScore: Int, + private val maxMatchingPermutation: Int, + private val maxCtex: Int, + private val maxStepsEval: Int) + (private val tests: Map[Identifier, (Seq[trees.Expr], Seq[trees.Type])], + val symbols: trees.Symbols) + (using val context: inox.Context) + extends Utils with stainless.utils.CtexRemapping { self => + import trees._ + + //region Examination and rounds ADTs + + enum NextExamination { + // In both `Done` and `NewCandidate`, `pruned` contains candidates functions that got classified + // without needing any further examination. + case Done(pruned: Map[Identifier, PruningReason], results: Results) + case NewCandidate(cand: Identifier, model: Identifier, strat: EquivCheckStrategy, pruned: Map[Identifier, PruningReason]) + } + + enum RoundConclusion { + case NextRound(cand: Identifier, + model: Identifier, + strat: EquivCheckStrategy, + prunedSubFnsPairs: Set[(Identifier, Identifier, ArgPermutation)]) + case CandidateClassified(cand: Identifier, + classification: Classification, + prunedSubFnsPairs: Set[(Identifier, Identifier, ArgPermutation)]) + } + + enum Classification { + case Valid(directModel: Identifier) + case Invalid(ctex: Seq[Map[ValDef, Expr]]) + case Unknown + } + + enum PruningReason { + case SignatureMismatch + case ByTest(testId: Identifier, sampleIx: Int, ctex: Ctex) + case ByPreviousCtex(ctex: Ctex) + } + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Final results definitions + + case class Results(// Clusters + equiv: Map[Identifier, Set[Identifier]], + valid: Map[Identifier, ValidData], + // Incorrect, either due to not being equivalent or having invalid VCs + erroneous: Map[Identifier, ErroneousData], + // Candidates that will need to be manually inspected... + unknowns: Map[Identifier, UnknownData], + // Incorrect signature + wrongs: Set[Identifier]) + case class Ctex(mapping: Map[ValDef, Expr], expected: Expr, got: Expr) + case class ValidData(path: Seq[Identifier], solvingInfo: SolvingInfo) + // The list of counter-examples can be empty; the candidate is still invalid but a ctex could not be extracted + // If the solvingInfo is None, the candidate has been pruned. + case class ErroneousData(ctexs: Seq[Map[ValDef, Expr]], solvingInfo: Option[SolvingInfo]) + case class UnknownData(solvingInfo: SolvingInfo) + // Note: fromCache and trivial are only relevant for valid candidates + case class SolvingInfo(time: Long, solverName: Option[String], fromCache: Boolean, trivial: Boolean) { + def withAddedTime(extra: Long): SolvingInfo = copy(time = time + extra) + } + + def getCurrentResults(): Results = { + val equiv = clusters.map { case (model, clst) => model -> clst.toSet }.toMap + Results(equiv, valid.toMap, erroneous.toMap, unknowns.toMap, signatureMismatch.toSet) + } + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Definitions for strategies + + enum EquivCheckOrder { + case ModelFirst + case CandidateFirst + } + + case class EquivCheckStrategy(order: EquivCheckOrder, subFnsMatchingStrat: Option[SubFnsMatchingStrat]) { + def pretty: String = (order, subFnsMatchingStrat) match { + case (EquivCheckOrder.ModelFirst, None) => "model first without sublemmas" + case (EquivCheckOrder.CandidateFirst, None) => "candidate first without sublemmas" + case (EquivCheckOrder.ModelFirst, Some(matchingStrat)) => s"model first with sublemmas: ${matchingStrat.curr.pretty}" + case (EquivCheckOrder.CandidateFirst, Some(matchingStrat)) => s"candidate first with sublemmas: ${matchingStrat.curr.pretty}" + } + } + + object EquivCheckStrategy { + def init: EquivCheckStrategy = EquivCheckStrategy(EquivCheckOrder.ModelFirst, None) + } + + // Pairs of model - candidate sub functions with argument permutation + type SubFnsMatching = Matching[Identifier, ArgPermutation] + + case class SubFnsMatchingStrat(curr: SubFnsMatching, rest: Seq[SubFnsMatching], all: Seq[SubFnsMatching]) + + extension (matching: SubFnsMatching) { + def pretty: String = matching.pairs + .map { case ((mod, cand), perm) => + s"${CheckFilter.fixedFullName(mod)} <-> ${CheckFilter.fixedFullName(cand)} (permutation = ${perm.m2c.mkString(", ")})" + } + .mkString(", ") + } + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Private state + + private val ctexEvalPermutationLimit = 16 + + private enum ExaminationState { + case PickNext + case Examining(candidate: Identifier, roundState: RoundState) + } + + private case class RoundState(model: Identifier, + remainingModels: Seq[Identifier], + strat: EquivCheckStrategy, + equivLemmas: EquivLemmas, + cumulativeSolvingTime: Long) + + private enum EquivLemmas { + case ToGenerate + case Generated(eqLemma: Identifier, + proof: Option[Identifier], + sublemmas: Seq[Identifier]) + } + + private case class EquivCheckConf(model: FunDef, candidate: FunDef, strat: EquivCheckStrategy, topLevel: Boolean) { + val (fd1, fd2) = strat.order match { + case EquivCheckOrder.ModelFirst => (model, candidate) + case EquivCheckOrder.CandidateFirst => (candidate, model) + } + } + + // Function called from a candidate (callee) -> candidate(s) (caller) + // (may include themselves) + private val candidatesCallee: Map[Identifier, Set[Identifier]] = { + allCandidates.foldLeft(Map.empty[Identifier, Set[Identifier]]) { + case (acc, cand) => + val callees = symbols.transitiveCallees(cand) + callees.foldLeft(acc) { + case (acc, callee) => + val curr = acc.getOrElse(callee, Set.empty) + acc + (callee -> (curr + cand)) + } + } + } + + private val models = mutable.LinkedHashMap.from(allModels.map(_ -> initScore)) + private val remainingCandidates = mutable.LinkedHashSet.from(allCandidates) + // Candidate -> set of models for tested so far for the candidate, but resulted in an unknown + private val candidateTestedModels = mutable.Map.from(allCandidates.map(_ -> mutable.Set.empty[Identifier])) + private var nbExaminedCandidates = allCandidates.size + private var examinationState: ExaminationState = ExaminationState.PickNext + private val valid = mutable.Map.empty[Identifier, ValidData] + // candidate -> list of counter-examples (can be empty, in which case the candidate is invalid but a ctex could not be extracted) + private val erroneous = mutable.Map.empty[Identifier, ErroneousData] + private val unknowns = mutable.LinkedHashMap.empty[Identifier, UnknownData] + private val signatureMismatch = mutable.ArrayBuffer.empty[Identifier] + private val clusters = mutable.Map.empty[Identifier, mutable.ArrayBuffer[Identifier]] + + // Type -> multiplicity + private case class UnordSig(args: Map[Type, Int]) + // Type -> list of values, whose length is the multiplicity of the type + private case class UnordCtex(args: Map[Type, Seq[Expr]]) + private val ctexsDb = mutable.Map.empty[UnordSig, mutable.Set[UnordCtex]] + + // Set of model subfn - candidate subfn for which we know (by counter example falsification) that do not match + // This is useful for pruning invalid matching without having to re-evaluate the pair. + private val invalidFunctionsPairsCache = mutable.Set.empty[(Identifier, Identifier, ArgPermutation)] + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Public API + + def reportErroneous(pr: StainlessProgram)(analysis: VerificationAnalysis, counterex: pr.Model)(fun: Identifier): Option[Set[Identifier]] = { + if (allCandidates.contains(fun)) { + remainingCandidates -= fun + val ordCtexs = ctexOrderedArguments(fun, pr)(counterex.vars).toSeq + ordCtexs.foreach(addCtex) + val fd = symbols.functions(fun) + val ctexVars = ordCtexs.map(ctex => fd.params.zip(ctex).toMap) + erroneous += fun -> ErroneousData(ctexVars, Some(extractSolvingInfo(analysis, fun, Seq.empty))) + Some(Set(fun)) + } else candidatesCallee.get(fun) match { + case Some(cands) => + // This means this erroneous `fun` is called by all candidates in `cands`. + // `cands` should be of size 1 because a function called by multiple candidates must be either a library fn or + // a provided function which are all assumed to be correct. + cands.foreach { cand => + remainingCandidates -= cand + // No ctex available, because counterex corresponds to the signature of `fun` not necessarily `cand` + erroneous += cand -> ErroneousData(Seq.empty, Some(extractSolvingInfo(analysis, cand, Seq.empty))) + } + Some(cands) + case None => + // Nobody knows about this function + None + } + } + + def pickNextExamination(): NextExamination = { + assert(examinationState == ExaminationState.PickNext) + + val anyModel = symbols.functions(allModels.head) + var picked: Option[Identifier] = None + val pruned = mutable.Map.empty[Identifier, PruningReason] + while (picked.isEmpty && remainingCandidates.nonEmpty) { + val candId = remainingCandidates.head + remainingCandidates -= candId + val cand = symbols.functions(candId) + + if (areSignaturesCompatible(cand, anyModel)) { + evalCheck(anyModel, cand) match { + case EvalCheck.Ok => + picked = Some(candId) + case EvalCheck.FailsTest(testId, sampleIx, ctex) => + erroneous += candId -> ErroneousData(Seq(ctex.mapping), None) + pruned += candId -> PruningReason.ByTest(testId, sampleIx, ctex) + case EvalCheck.FailsCtex(ctex) => + erroneous += candId -> ErroneousData(Seq(ctex.mapping), None) + pruned += candId -> PruningReason.ByPreviousCtex(ctex) + } + } else { + signatureMismatch += candId + pruned += candId -> PruningReason.SignatureMismatch + } + } + + picked match { + case Some(candId) => + val topN = models.toSeq + .filter { case (mod, _ ) => !candidateTestedModels(candId).contains(mod) } // Do not test models for which this candidate got an unknown + .sortBy(-_._2) + .take(N).map(_._1) + if (topN.nonEmpty) { + val strat = EquivCheckStrategy.init + examinationState = ExaminationState.Examining(candId, RoundState(topN.head, topN.tail, strat, EquivLemmas.ToGenerate, 0L)) + NextExamination.NewCandidate(candId, topN.head, strat, pruned.toMap) + } else { + pickNextExamination() match { + case d@NextExamination.Done(_, _) => d.copy(pruned = pruned.toMap ++ d.pruned) + case nc@NextExamination.NewCandidate(_, _, _, _) => nc.copy(pruned = pruned.toMap ++ nc.pruned) + } + } + case None => + if (unknowns.nonEmpty && unknowns.size < nbExaminedCandidates) { + nbExaminedCandidates = unknowns.size + remainingCandidates ++= unknowns.keys + unknowns.clear() + pickNextExamination() match { + case d@NextExamination.Done(_, _) => d.copy(pruned = pruned.toMap ++ d.pruned) + case nc@NextExamination.NewCandidate(_, _, _, _) => nc.copy(pruned = pruned.toMap ++ nc.pruned) + } + } else { + NextExamination.Done(pruned.toMap, getCurrentResults()) + } + } + } + + def prepareRound(): Seq[FunDef] = { + val (cand, roundState) = examinationState match { + case ExaminationState.Examining(cand, roundState) => (cand, roundState) + case ExaminationState.PickNext => + sys.error("Trace must be in `Examining` state") + } + assert(roundState.equivLemmas == EquivLemmas.ToGenerate) + val conf = EquivCheckConf(symbols.functions(roundState.model), symbols.functions(cand), roundState.strat, topLevel = true) + val generated = equivalenceCheck(conf) + val equivLemmas = EquivLemmas.Generated(generated.eqLemma.id, generated.proof.map(_.id), generated.sublemmasAndReplFns.map(_.id)) + examinationState = ExaminationState.Examining(cand, roundState.copy(equivLemmas = equivLemmas)) + generated.eqLemma +: (generated.proof.toSeq ++ generated.sublemmasAndReplFns) + } + + def concludeRound(analysis: VerificationAnalysis): RoundConclusion = examinationState match { + case ExaminationState.Examining(cand, RoundState(model, remainingModels, strat, EquivLemmas.Generated(eqLemma, proof, sublemmas), currCumulativeSolvingTime)) => + val solvingInfo = extractSolvingInfo(analysis, cand, eqLemma +: (proof.toSeq ++ sublemmas)) + + def nextRoundOrUnknown(): RoundConclusion = { + models(model) = models(model) - 1 + (strat.order, strat.subFnsMatchingStrat) match { + case (EquivCheckOrder.ModelFirst, None) => + val nextStrat = EquivCheckStrategy(EquivCheckOrder.CandidateFirst, None) + val nextRS = RoundState(model, remainingModels, nextStrat, EquivLemmas.ToGenerate, currCumulativeSolvingTime + solvingInfo.time) + examinationState = ExaminationState.Examining(cand, nextRS) + RoundConclusion.NextRound(cand, model, nextStrat, Set.empty) + + case (EquivCheckOrder.CandidateFirst, None) => + val subFnsMatching = allSubFnsMatches(model, cand) + val pruned = pruneSubFnsMatching(subFnsMatching) + if (pruned.passed.isEmpty) { + // No matching for subfunctions available, we pick the next model if available + nextModelOrUnknown(pruned.invalidPairs) + } else { + val nextStrat = EquivCheckStrategy(EquivCheckOrder.ModelFirst, + Some(SubFnsMatchingStrat(pruned.passed.head, pruned.passed.tail, pruned.passed))) + val nextRS = RoundState(model, remainingModels, nextStrat, EquivLemmas.ToGenerate, currCumulativeSolvingTime + solvingInfo.time) + examinationState = ExaminationState.Examining(cand, nextRS) + RoundConclusion.NextRound(cand, model, nextStrat, pruned.invalidPairs) + } + + case (EquivCheckOrder.ModelFirst, Some(matchingStrat)) => + // Prune the remaining once again, maybe we got new ctex in the meantime + val pruned = pruneSubFnsMatching(matchingStrat.rest) + if (pruned.passed.nonEmpty) { + // Try with the next matching + val nextStrat = EquivCheckStrategy(EquivCheckOrder.ModelFirst, + Some(matchingStrat.copy(curr = pruned.passed.head, rest = pruned.passed.tail))) + val nextRS = RoundState(model, remainingModels, nextStrat, EquivLemmas.ToGenerate, currCumulativeSolvingTime + solvingInfo.time) + examinationState = ExaminationState.Examining(cand, nextRS) + RoundConclusion.NextRound(cand, model, nextStrat, pruned.invalidPairs) + } else { + // Move to function first with subfns matching, if possible + // Reuse the computed matching instead of computing it again, + // but prune it once again, maybe we got new ctex in the meantime. + val allPruned = pruneSubFnsMatching(matchingStrat.all) + if (allPruned.passed.isEmpty) { + // No matching for subfunctions available. + // We pick the next model if available + nextModelOrUnknown(pruned.invalidPairs ++ allPruned.invalidPairs) + } else { + val nextStrat = EquivCheckStrategy(EquivCheckOrder.CandidateFirst, + Some(SubFnsMatchingStrat( + allPruned.passed.head, allPruned.passed.tail, + // Update `all` with its re-pruned version + allPruned.passed))) + val nextRS = RoundState(model, remainingModels, nextStrat, EquivLemmas.ToGenerate, currCumulativeSolvingTime + solvingInfo.time) + examinationState = ExaminationState.Examining(cand, nextRS) + RoundConclusion.NextRound(cand, model, nextStrat, pruned.invalidPairs ++ allPruned.invalidPairs) + } + } + + case (EquivCheckOrder.CandidateFirst, Some(matchingStrat)) => + val pruned = pruneSubFnsMatching(matchingStrat.rest) + if (pruned.passed.nonEmpty) { + // Try with the next matching + val nextStrat = EquivCheckStrategy(EquivCheckOrder.CandidateFirst, + Some(matchingStrat.copy(curr = pruned.passed.head, rest = pruned.passed.tail))) + val nextRS = RoundState(model, remainingModels, nextStrat, EquivLemmas.ToGenerate, currCumulativeSolvingTime + solvingInfo.time) + examinationState = ExaminationState.Examining(cand, nextRS) + RoundConclusion.NextRound(cand, model, nextStrat, pruned.invalidPairs) + } else { + nextModelOrUnknown(pruned.invalidPairs) + } + } + } + def nextModelOrUnknown(invalidPairs: Set[(Identifier, Identifier, ArgPermutation)]): RoundConclusion = { + candidateTestedModels(cand) += model + if (remainingModels.nonEmpty) { + val nextStrat = EquivCheckStrategy.init + val nextRS = RoundState(remainingModels.head, remainingModels.tail, nextStrat, EquivLemmas.ToGenerate, currCumulativeSolvingTime + solvingInfo.time) + examinationState = ExaminationState.Examining(cand, nextRS) + RoundConclusion.NextRound(cand, remainingModels.head, nextStrat, invalidPairs) + } else { + // oh no, manual inspection incoming + examinationState = ExaminationState.PickNext + unknowns += cand -> UnknownData(solvingInfo.withAddedTime(currCumulativeSolvingTime)) + RoundConclusion.CandidateClassified(cand, Classification.Unknown, invalidPairs) + } + } + + ///////////////////////////////////////////////////////////////////////////////////// + + val report = analysis.toReport + val allCtexs = analysis.vrs.collect { + case (vc, VCResult(VCStatus.Invalid(VCStatus.CounterExample(model)), _, _)) => + ctexOrderedArguments(vc.fid, model.program)(model.vars).map(vc.fid -> _) + }.flatten.groupMap(_._1)(_._2) + + if (report.totalInvalid != 0) { + assert(allCtexs.nonEmpty, "Conspiration!") + allCtexs.foreach { case (_, ctexs) => + ctexs.foreach(addCtex) + } + if (strat.subFnsMatchingStrat.isDefined) { + nextRoundOrUnknown() + } else { + // schade + val candFd = symbols.functions(cand) + // Take all ctex for `cand`, `eqLemma` and `proof` + val ctexOrderedArgs = (Seq(cand, eqLemma) ++ proof.toSeq).flatMap(id => allCtexs.getOrElse(id, Seq.empty)) + val ctexsMap = ctexOrderedArgs.map(ctex => candFd.params.zip(ctex).toMap) + erroneous += cand -> ErroneousData(ctexsMap, Some(solvingInfo.withAddedTime(currCumulativeSolvingTime))) + examinationState = ExaminationState.PickNext + RoundConclusion.CandidateClassified(cand, Classification.Invalid(ctexsMap), Set.empty) + } + } else if (report.totalUnknown != 0) { + nextRoundOrUnknown() + } else { + assert(!models.contains(cand)) + val modelPath = valid.get(model).map(_.path).getOrElse(Seq.empty) + valid += cand -> ValidData(model +: modelPath, solvingInfo.withAddedTime(currCumulativeSolvingTime)) + val currScore = models(model) + models(model) = currScore + (if (currScore > 0) 20 else 100) + models(cand) = 0 // Welcome to the privileged club of models! + clusters.getOrElseUpdate(model, mutable.ArrayBuffer.empty) += cand + examinationState = ExaminationState.PickNext + RoundConclusion.CandidateClassified(cand, Classification.Valid(model), Set.empty) + } + + case _ => sys.error("Trace must be in `Examining` state with `Generated` lemmas") + } + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Generation of functions encoding equivalence + + private case class GeneratedEqLemmas(eqLemma: FunDef, proof: Option[FunDef], sublemmasAndReplFns: Seq[FunDef]) + + // Generate eqLemma and sublemmas for the given top-level model and candidate functions + private def equivalenceCheck(conf: EquivCheckConf): GeneratedEqLemmas = { + import conf.{fd1, fd2} + import exprOps._ + + // For the top-level model and candidate function + val permutation = ArgPermutation(conf.model.params.indices) // No permutation for top-level model and candidate + val (eqLemmaResTopLvl, topLvlRepl) = generateEqLemma(conf, permutation) + // For the sub-functions + val eqLemmasResSubs = conf.strat.subFnsMatchingStrat.toSeq.flatMap { matchingStrat => + matchingStrat.curr.pairs.flatMap { + case ((submod, subcand), perm) => + val newConf = conf.copy(model = symbols.functions(submod), candidate = symbols.functions(subcand), topLevel = false) + val (subres, subRepl) = generateEqLemma(newConf, perm) + Seq(subres.updatedFd) ++ subres.helper.toSeq ++ subRepl.toSeq + } + } + + GeneratedEqLemmas(eqLemmaResTopLvl.updatedFd, eqLemmaResTopLvl.helper, topLvlRepl.toSeq ++ eqLemmasResSubs) + } + + // Generate an eqLemma for the given fd1 and fd2 functions and the given permutation for the candidate function + private def generateEqLemma(conf: EquivCheckConf, perm: ArgPermutation): (ElimTraceInduct, Option[FunDef]) = { + import conf.{fd1, fd2} + import exprOps._ + + assert(areSignaturesCompatibleModuloPerm(conf.model, conf.candidate, perm)) // i.e. fd1 and fd2 + val freshId = FreshIdentifier(CheckFilter.fixedFullName(fd1.id) + "$" + CheckFilter.fixedFullName(fd2.id)) + val eqLemma0 = exprOps.freshenSignature(fd1).copy(id = freshId) + + // Body of fd2, with calls to subfunctions replaced + val fd2Repl = conf.strat.subFnsMatchingStrat.map { matchingStrat => + val replMap = matchingStrat.curr.pairs.map { + case ((submod, subcand), perm) => + conf.strat.order match { + case EquivCheckOrder.ModelFirst => + // f1 = model and f2 = candidate, and we want to replace all calls to candidate subfunctions by their models counterpart + subcand -> (submod, perm.m2c) + case EquivCheckOrder.CandidateFirst => + // Note: perm gives the permutation model ix -> cand ix, so we need to reverse it here + submod -> (subcand, perm.reverse.m2c) + } + } + inductPattern(symbols, fd2, fd2, "replacement", replMap) + .setPos(fd2.getPos) + .copy(flags = Seq(Derived(Some(fd2.id)))) + } + + val newParamTps = eqLemma0.tparams.map { tparam => tparam.tp } + val newParamVars = eqLemma0.params.map { param => param.toVariable } + + val subst = { + val nweParamVarsPermuted = conf.strat.order match { + case EquivCheckOrder.ModelFirst => + // f1 = model, f2 = candidate, so no re-ordering + newParamVars + case EquivCheckOrder.CandidateFirst => + // f1 = candidate, f2 = model: we need to "undo" the ordering + perm.m2c.map(newParamVars) + } + (conf.model.params.map(_.id) zip nweParamVarsPermuted).toMap + } + val tsubst = (conf.model.tparams zip newParamTps).map { case (tparam, targ) => tparam.tp.id -> targ }.toMap + val specializer = new Specializer(eqLemma0, eqLemma0.id, tsubst, subst, Map()) + + val specs = BodyWithSpecs(conf.model.fullBody).specs.filter(s => s.kind == LetKind || s.kind == PreconditionKind) + val pre = specs.map { + case Precondition(cond) => Precondition(specializer.transform(cond)) + case LetInSpec(vd, expr) => LetInSpec(vd, specializer.transform(expr)) + } + val (paramsFun1, paramsFun2) = { + conf.strat.order match { + case EquivCheckOrder.ModelFirst => + // f1 = model, f2 = candidate + (newParamVars, perm.reverse.m2c.map(newParamVars)) + case EquivCheckOrder.CandidateFirst => + (newParamVars, perm.m2c.map(newParamVars)) + } + } + val fun1 = FunctionInvocation(fd1.id, newParamTps, paramsFun1) + val fun2 = FunctionInvocation(fd2Repl.map(_.id).getOrElse(fd2.id), newParamTps, paramsFun2) + + val (normFun1, normFun2) = norm match { + case Some(n) if conf.topLevel => // Norm applies only to top-level model & candidate functions + (FunctionInvocation(n, newParamTps, newParamVars :+ fun1), + FunctionInvocation(n, newParamTps, newParamVars :+ fun2)) + case _ => (fun1, fun2) + } + + val res = ValDef.fresh("res", UnitType()) + val cond = Equals(normFun1, normFun2) + val post = Postcondition(Lambda(Seq(res), cond)) + val body = UnitLiteral() + val withPre = exprOps.reconstructSpecs(pre, Some(body), UnitType()) + val eqLemma1 = eqLemma0.copy( + fullBody = BodyWithSpecs(withPre).withSpec(post).reconstructed, + flags = Seq(Derived(Some(fd1.id)), Annotation("traceInduct", List(StringLiteral(fd1.id.name)))), + returnType = UnitType() + ) + val elim = elimTraceInduct(symbols, eqLemma1) + .getOrElse(sys.error("Impossible, eqLemma is annotated with @traceInduct")) + (elim, fd2Repl) + } + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Evaluation of model & function with collected counter-examples + + private enum EvalCheck { + case Ok + case FailsTest(testId: Identifier, sampleIx: Int, ctex: Ctex) + case FailsCtex(ctex: Ctex) + } + + // Eval check for top level candidate and model + private def evalCheck(model: FunDef, cand: FunDef): EvalCheck = { + assert(areSignaturesCompatible(model, cand)) + + def passAllTests: Option[EvalCheck.FailsTest] = { + def passTest(id: Identifier): Option[EvalCheck.FailsTest] = { + val (samples, instParams) = tests(id) + findMap(samples.zipWithIndex) { case (arg, sampleIx) => + passTestSample(arg, instParams).map(_ -> sampleIx) + }.map { case ((evalArgs, expected, got), sampleIx) => + EvalCheck.FailsTest(id, sampleIx, Ctex(cand.params.zip(evalArgs).toMap, expected, got)) + } + } + + def loop(tests: Seq[Identifier]): Option[EvalCheck.FailsTest] = { + if (tests.isEmpty) None + else passTest(tests.head) match { + case Some(f) => Some(f) + case None => loop(tests.tail) + } + } + + loop(tests.keys.toSeq) + } + + def passTestSample(arg: Expr, instTparams: Seq[Type]): Option[(Seq[Expr], Expr, Expr)] = { + val evalArg = try { + evaluate(arg) match { + case inox.evaluators.EvaluationResults.Successful(evalArg) => evalArg + case _ => + return None // If we cannot evaluate the argument (which should be a tuple), then we consider this test to be "successful" + } + } catch { + case NonFatal(_) => return None + } + val argsSplit = evalArg match { + case Tuple(args) => args + case _ => return None // ditto, we will not crash + } + + val invocationCand = FunctionInvocation(cand.id, instTparams, argsSplit) + val invocationModel = FunctionInvocation(allModels.head, instTparams, argsSplit) // any model will do + try { + (evaluate(invocationCand), evaluate(invocationModel)) match { + case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) => + if (output == expected) None + else Some((argsSplit, expected, output)) + case _ => None + } + } catch { + case NonFatal(_) => None + } + } + + def evaluate(expr: Expr) = { + val syms: symbols.type = symbols + type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type} + val prog: ProgramType = inox.Program(self.trees)(syms) + val sem = new inox.Semantics { + val trees: self.trees.type = self.trees + val symbols: syms.type = syms + val program: prog.type = prog + + def createEvaluator(ctx: inox.Context) = ??? + + def createSolver(ctx: inox.Context) = ??? + } + class EvalImpl(override val program: prog.type, override val context: inox.Context) + (using override val semantics: sem.type) + extends evaluators.RecursiveEvaluator(program, context) + with inox.evaluators.HasDefaultGlobalContext + with inox.evaluators.HasDefaultRecContext + + val evaluator = new EvalImpl(prog, self.context)(using sem) + evaluator.eval(expr) + } + + val permutation = ArgPermutation(model.params.indices) // No permutation for top-level model and candidate + passAllTests + .orElse(evalCheckCtexOnly(model, cand, permutation).map(EvalCheck.FailsCtex.apply)) + .getOrElse(EvalCheck.Ok) + } + + // Eval check for top level candidate and model and their subfunctions + private def evalCheckCtexOnly(model: FunDef, cand: FunDef, candPerm: ArgPermutation): Option[Ctex] = { + assert(areSignaturesCompatibleModuloPerm(model, cand, candPerm)) + val subst = TyParamSubst(IntegerType(), i => Some(IntegerLiteral(i))) + + def passUnordCtex(ctex: UnordCtex): Option[(Seq[Expr], Expr, Expr)] = { + // From `ctex`, generate all possible ordered permutations of args according to the types + // If the type multiplicity is 1 for all params, then there is only one ordered ctex possible + val ctexSeq = ctex.args.toSeq + val perms = cartesianProduct(ctexSeq.map { case (_, args) => args.permutations.toSeq }) + findMap(perms.take(ctexEvalPermutationLimit).toSeq) { perm => + assert(perm.size == ctexSeq.size, "Cartesian product is hard to grasp, yes") + val permTpeMap: Map[Type, Seq[Expr]] = ctexSeq.map(_._1).zip(perm).toMap + assert(permTpeMap.forall { case (tpe, args) => args.forall(_.getType(using symbols) == tpe) }) + + // For each type, the current index within permTpeMap + val tpeIxs = mutable.Map.from(ctexSeq.map(_._1 -> 0)) + val ordArgs = for (vd <- model.params) yield { + val vdTpeInst = substTypeParams(model.tparams, vd.tpe)(using subst) + val arg = permTpeMap(vdTpeInst)(tpeIxs(vdTpeInst)) + tpeIxs(vdTpeInst) = tpeIxs(vdTpeInst) + 1 + arg + } + passOrdCtex(ordArgs).map { case (exp, got) => (ordArgs, exp, got) } + } + } + + def passOrdCtex(args: Seq[Expr]): Option[(Expr, Expr)] = { + val syms: symbols.type = symbols + type ProgramType = inox.Program {val trees: self.trees.type; val symbols: syms.type} + val prog: ProgramType = inox.Program(self.trees)(syms) + val sem = new inox.Semantics { + val trees: prog.trees.type = prog.trees + val symbols: syms.type = prog.symbols + val program: prog.type = prog + + def createEvaluator(ctx: inox.Context) = ??? + + def createSolver(ctx: inox.Context) = ??? + } + class EvalImpl(override val program: prog.type, override val context: inox.Context) + (using override val semantics: sem.type) + extends evaluators.RecursiveEvaluator(program, context) + with inox.evaluators.HasDefaultGlobalContext + with inox.evaluators.HasDefaultRecContext { + override lazy val maxSteps: Int = maxStepsEval + } + val evaluator = new EvalImpl(prog, self.context)(using sem) + + val tparams = model.tparams.map(_ => IntegerType()) + val invocationModel = evaluator.program.trees.FunctionInvocation(model.id, tparams, args) + val invocationCand = evaluator.program.trees.FunctionInvocation(cand.id, tparams, candPerm.reverse.m2c.map(args)) + try { + (evaluator.eval(invocationCand), evaluator.eval(invocationModel)) match { + case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) => + if (output == expected) None + else Some((expected, output)) + case _ => None + } + } catch { + case NonFatal(_) => None + } + } + + // Substitute tparams with IntegerType() + val argsTpe = model.params.map(vd => substTypeParams(model.tparams, vd.tpe)(using subst)) + val unordSig = UnordSig(argsTpe.groupMapReduce(identity)(_ => 1)(_ + _)) + val ctexs = ctexsDb.getOrElse(unordSig, mutable.ArrayBuffer.empty) + findMap(ctexs.toSeq)(passUnordCtex) + .map { case (ctex, expected, got) => + // ctex is ordered according to the model, so we need to reorder cand according to the permutation + val candReorg = candPerm.m2c.map(cand.params) + Ctex(candReorg.zip(ctex).toMap, expected, got) + } + } + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Pruning of sub function matching + + private case class PrunedSubFnsMatching(passed: Seq[SubFnsMatching], invalidPairs: Set[(Identifier, Identifier, ArgPermutation)]) + + private def pruneSubFnsMatching(matching: Seq[SubFnsMatching]): PrunedSubFnsMatching = { + def loop(matching: Seq[SubFnsMatching], + acc: Seq[SubFnsMatching], + invalidPairs: Set[(Identifier, Identifier, ArgPermutation)]): (Seq[SubFnsMatching], Set[(Identifier, Identifier, ArgPermutation)]) = matching match { + case Seq() => (acc, invalidPairs) + case m +: rest => + // If this matching contains pairs that are invalid, skip it and go to the next + val mpairs = m.pairs.map { case ((mod, cand), perm) => (mod, cand, perm) }.toSet + if (mpairs.intersect(invalidPairs).nonEmpty) loop(rest, acc, invalidPairs) + else { + // Otherwise, try to falsify this matching by finding an invalid pair + val newInvPair = findMap(m.pairs.toSeq) { case ((mod, cand), perm) => + evalCheckCtexOnly(symbols.functions(mod), symbols.functions(cand), perm) + .map(_ => (mod, cand, perm)) + } + newInvPair match { + case Some((mod, cand, perm)) => + // A fine addition to my collection of invalid pairs + loop(rest, acc, invalidPairs + ((mod, cand, perm))) + case None => + // This matching passed the pruning, add it to the result + loop(rest, acc :+ m, invalidPairs) + } + } + } + val startInvPairs = invalidFunctionsPairsCache.toSet + val (remaining, invalidPairs) = loop(matching, Seq.empty, startInvPairs) + val extra = invalidPairs -- startInvPairs + invalidFunctionsPairsCache ++= extra + PrunedSubFnsMatching(remaining, extra) + } + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Generation of all possible matching for model and candidate subfunctions + + // Note: does not perform pruning by counter-example evaluation + private def allSubFnsMatches(model: Identifier, cand: Identifier): Seq[SubFnsMatching] = { + import math.Ordering.Implicits.seqOrdering + // Get all the (non-library) function transitive calls in the body of fd - excluding fd itself + def getTransitiveCalls(f: Identifier): Set[FunDef] = + symbols.transitiveCallees(f).filter(_ != f).map(symbols.functions(_)) + .filter(!_.flags.exists(_.name == "library")) + + def compatibleRetTpe(mod: FunDef, cand: FunDef): Boolean = { + // To check return type, substitute all cand tparams by mod's + val substMap = cand.tparams.zip(mod.tparams).map { case (tpd2, tpd1) => tpd2.tp -> (tpd1.tp: Type) }.toMap + mod.returnType == typeOps.instantiateType(cand.returnType, substMap) + } + + // Ensure that we do *not* match `choose` functions created from `choose` expressions. + // If we were to match them, we would unveil `choose` expressions which we don't want to do + // because these must remain hidden behind their `choose` expression counterpart. + def isChooseStub(fd: FunDef): Boolean = fd.id.name == "choose" && (fd.fullBody match { + case Choose(_, _) => true + case _ => false + }) + + val modSubs = getTransitiveCalls(model) + val candSubs = getTransitiveCalls(cand) + // All pairs model-candidate subfns that with compatible signature modulo arg permutation + val allValidPairs = for { + ms <- modSubs + if !isChooseStub(ms) + cs <- candSubs + if !isChooseStub(cs) + // If allArgsPermutations returns empty, then this ms-cs pairs is not compatible + argPerm <- allArgsPermutations(ms.params, ms.tparams, cs.params, cs.tparams) + if compatibleRetTpe(ms, cs) // still needs to check for return type, as allArgsPermutations is only about the arguments + } yield (ms.id, cs.id, argPerm) + + val allValidPairs2: Map[(Identifier, Identifier), Seq[ArgPermutation]] = + allValidPairs.groupMap { case (ms, cs, _) => (ms, cs) }(_._3) + .view.mapValues(_.toSeq.sortBy(_.m2c)).toMap // Sort the arg. permutation to ensure deterministic traversal + + // Matches between identifiers, with all possible permutation + val resMatching0: Set[Matching[Identifier, Seq[ArgPermutation]]] = allMatching(allValidPairs2) + if (resMatching0.isEmpty) return Seq.empty + + // Sort the results to ensure deterministic traversal and picking + val resMatching1: Seq[Matching[Identifier, Seq[ArgPermutation]]] = resMatching0.toSeq.sortBy(_.pairs.keys.toSeq) + + // To avoid explosion of all possible function matching with all combination of argument permutation, + // we distribute the number of maximum permutations per function matching s.t. we don't exceed maxMatchingPermutation + def distributePermutations(budget: Int, perms: Seq[Int]): Seq[Int] = { + type Ix = Int + def helper(budget: Int, remaining: Map[Ix, Int], distributed: Map[Ix, Int]): Seq[Int] = { + assert(budget >= 0) + assert(remaining.forall(_._2 > 0)) + assert(distributed.forall(_._2 >= 0)) + if (remaining.isEmpty) distributed.toSeq.sortBy(_._1).map(_._2) + else { + val toDistr0 = budget / remaining.size + if (toDistr0 == 0) { + val toInc = remaining.keys.toSeq.sorted.take(budget % remaining.size) + val distr2 = distributed ++ toInc.map(ix => ix -> (distributed(ix) + 1)).toMap + distr2.toSeq.sortBy(_._1).map(_._2) + } else { + val toDistr = math.min(toDistr0, remaining.values.min) + val rem2 = remaining.view.mapValues(_ - toDistr).filter(_._2 != 0).toMap + val distr2 = distributed ++ remaining.keys.map(ix => ix -> (distributed(ix) + toDistr)).toMap + val res = helper(budget - toDistr * remaining.size, rem2, distr2) + res + } + } + } + helper(budget, perms.zipWithIndex.map { case (p, ix) => ix -> p }.toMap, perms.indices.map(_ -> 0).toMap) + } + val nbPermutations = resMatching1.map(_.pairs.map(_._2.size).product) + val nbPermutationsDistr = distributePermutations(maxMatchingPermutation, nbPermutations) + + // We want a Matching[Identifier, ArgPermutation] and not a Seq[ArgPermutation] as "edge data" + def allPermutation(m: Matching[Identifier, Seq[ArgPermutation]], maxPerms: Int): Seq[Matching[Identifier, ArgPermutation]] = { + val tuples = m.pairs.toSeq.map { case ((mod, cand), perms) => + perms.map(perm => (mod, cand, perm)) + } + cartesianProduct(tuples).take(maxPerms).map { pairs => + val pairs2 = pairs.map { case (mod, cand, perm) => (mod, cand) -> perm }.toMap + assert(m.pairs.keySet == pairs2.keySet, "Cartesian product is hard") + assert(pairs2.forall { case ((mod, cand), perm) => m.pairs((mod, cand)).contains(perm) }, "Making sense of all of this is hard") + Matching(pairs2) + }.toSeq + } + + val res = resMatching1.zipWithIndex.flatMap { case (m, ix) => allPermutation(m, nbPermutationsDistr(ix)) } + assert(res.size <= maxMatchingPermutation, "oh no, I lied :(") + res + } + + // Note: tparams are not checked for permutation + private def allArgsPermutations(params1: Seq[ValDef], tparams1: Seq[TypeParameterDef], + params2: Seq[ValDef], tparams2: Seq[TypeParameterDef]): Set[ArgPermutation] = + EquivalenceChecker.allArgsPermutations(trees)(params1, tparams1, params2, tparams2) + //endregion + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + //region Miscellaneous + + private def addCtex(ctex: Seq[Expr]): Unit = { + val currNbCtex = ctexsDb.map(_._2.size).sum + if (currNbCtex < maxCtex) { + val sig = UnordSig(ctex.groupMapReduce(_.getType(using symbols))(_ => 1)(_ + _)) + val unordCtex = UnordCtex(ctex.groupBy(_.getType(using symbols))) + val arr = ctexsDb.getOrElseUpdate(sig, mutable.Set.empty) + arr += unordCtex + } + } + + // Note: order enforced + private def areSignaturesCompatible(fd1: FunDef, fd2: FunDef): Boolean = + EquivalenceChecker.areSignaturesCompatible(trees)(fd1, fd2) + + private def areSignaturesCompatibleModuloPerm(fd1: FunDef, fd2: FunDef, argPerm: ArgPermutation): Boolean = + EquivalenceChecker.areSignaturesCompatibleModuloPerm(trees)(fd1, fd2, argPerm) + + // deps refers to equivalence lemma, proofs and sublemmas. + private def extractSolvingInfo(analysis: VerificationAnalysis, cand: Identifier, deps: Seq[Identifier]): SolvingInfo = { + val all = (cand +: deps).toSet + val (time, solvers, fromCache, trivial) = analysis.vrs.foldLeft((0L, Set.empty[String], true, true)) { + case ((time, solvers, fromCache, trivial), (vc, vcRes)) if all(vc.fid) => + (time + vcRes.time.getOrElse(0L), solvers ++ vcRes.solverName.toSet, fromCache && vcRes.isValidFromCache, trivial && vcRes.isTrivial) + case (acc, _) => acc + } + val solver = if (solvers.isEmpty) None else Some(solvers.mkString(", ")) + SolvingInfo(time, solver, fromCache, trivial) + } + + // Order the ctex arguments according to the signature of `fn`, instantiate type parameters to BigInt and fill missing values with default values + private def ctexOrderedArguments(fn: Identifier, prog: StainlessProgram)(ctex: Map[prog.trees.ValDef, prog.trees.Expr]): Option[Seq[Expr]] = { + // `fn` may be a sublemma, etc. that do not originally belong to `symbols`, we pick the symbols of prog which will definitely contain it + val fd = prog.symbols.functions(fn) + val params = fd.params.map(prog2self(prog)(_)) + val tparams = fd.tparams.map(prog2self(prog)(_)) + val subst = TyParamSubst(IntegerType(), i => Some(IntegerLiteral(i))) + tryRemapCtex(params, tparams, prog)(ctex)(using subst, context) match { + case RemappedCtex.Success(args, _) => Some(args) + case _ => None + } + } + + private def findMap[A, B](as: Seq[A])(f: A => Option[B]): Option[B] = + if (as.isEmpty) None + else f(as.head).orElse(findMap(as.tail)(f)) + + // Adapted from inox.utils.SeqUtils, but lazy + private type Tuple[T] = Seq[T] + + private def cartesianProduct[T](seqs: Tuple[Seq[T]]): Iterator[Tuple[T]] = { + val sizes = seqs.map(_.size) + val max = sizes.product + + for (i <- (0 until max).iterator) yield { + var c = i + var sel = -1 + for (s <- sizes) yield { + val index = c % s + c = c / s + sel += 1 + seqs(sel)(index) + } + } + } + //endregion +} + +object EquivalenceChecker { + val defaultN = 3 + val defaultInitScore = 200 + val defaultMaxMatchingPermutation = 16 + val defaultMaxCtex = 1024 + val defaultMaxStepsEval = 512 + + type Path = Seq[String] + + sealed trait Creation { + val trees: Trees + val symbols: trees.Symbols + } + object Creation { + sealed trait Success extends Creation { self => + val equivChker: EquivalenceChecker{val trees: self.trees.type; val symbols: self.symbols.type} + } + sealed trait Failure(val reason: FailureReason) extends Creation + } + + enum FailureReason { + case IllFormedTests(invalid: Map[Identifier, TestExtractionFailure]) + case NoModels + case NoFunctions + case ModelsSignatureMismatch(m1: Identifier, m2: Identifier) + case OverlappingModelsAndFunctions(overlapping: Set[Identifier]) + case MultipleNormFunctions(norms: Set[Identifier]) + case NormSignatureMismatch(norm: Identifier) + case IllegalHyperparameterValue(details: String) + } + + enum TestExtractionFailure { + case ModelTakesNoArg // i.e. the model function does not take any argument, how can we possibly feed any data? + case ReturnTypeMismatch + case UnknownExpr + case NoData + case UnificationFailure + } + + def tryCreate(ts: Trees)(syms: ts.Symbols)(using context: inox.Context): Creation { + val trees: ts.type + val symbols: syms.type + } = { + val pathsOptCandidates: Option[Seq[Path]] = context.options.findOption(equivchk.optCompareFuns) map { functions => + functions map CheckFilter.fullNameToPath + } + val pathsOptModels: Option[Seq[Path]] = context.options.findOption(equivchk.optModels) map { functions => + functions map CheckFilter.fullNameToPath + } + val pathsOptNorm: Option[Seq[Path]] = + Some(Seq(context.options.findOptionOrDefault(equivchk.optNorm)).map(CheckFilter.fullNameToPath)) + + def isNorm(fid: Identifier): Boolean = indexOfPath(pathsOptNorm, fid).isDefined + + def failure(reason: FailureReason) = { + class FailureImpl(override val trees: ts.type, override val symbols: syms.type) extends Creation.Failure(reason) + new FailureImpl(ts, syms) + } + + val n = context.options.findOptionOrDefault(optN) + if (n <= 0) { + return failure(FailureReason.IllegalHyperparameterValue(s"${optN.name} must be strictly positive")) + } + val initScore = context.options.findOptionOrDefault(optInitScore) + // If you want to give negative score to your models, sure, do so! + + val maxPerm = context.options.findOptionOrDefault(optMaxPerm) + if (maxPerm <= 0) { + return failure(FailureReason.IllegalHyperparameterValue(s"${optMaxPerm.name} must be strictly positive")) + } + + val maxCtex = context.options.findOptionOrDefault(optMaxCtex) + if (maxCtex <= 0) { + return failure(FailureReason.IllegalHyperparameterValue(s"${optMaxCtex.name} must be strictly positive")) + } + + val models = syms.functions.values.flatMap { fd => + if (fd.flags.exists(_.name == "library")) None + else indexOfPath(pathsOptModels, fd.id).map(_ -> fd.id) + }.toSeq.distinct.sorted.map(_._2) + + val functions = syms.functions.values.flatMap { fd => + if (fd.flags.exists(_.name == "library")) None + else indexOfPath(pathsOptCandidates, fd.id).map(_ -> fd.id) + }.toSeq.distinct.sorted.map(_._2) + + if (models.isEmpty) { + return failure(FailureReason.NoModels) + } + if (functions.isEmpty) { + return failure(FailureReason.NoFunctions) + } + + val overlapping = models.toSet.intersect(functions.toSet) + if (overlapping.nonEmpty) { + return failure(FailureReason.OverlappingModelsAndFunctions(overlapping)) + } + + validateModelsSignature(ts)(syms, models) match { + case Some((m1, m2)) => + return failure(FailureReason.ModelsSignatureMismatch(m1, m2)) + case None => () + } + + val norm = { + val allNorms = syms.functions.values.filter(elem => isNorm(elem.id)).map(_.id).toSet + if (allNorms.size > 1) return failure(FailureReason.MultipleNormFunctions(allNorms)) + else if (allNorms.isEmpty) None + else { + val norm = allNorms.head + if (checkArgsNorm(ts)(syms, models.head, norm)) Some(norm) + else return failure(FailureReason.NormSignatureMismatch(norm)) + } + } + + val (testsNok0, testsOk0) = syms.functions.values.filter(_.flags.exists(elem => elem.name == "mkTest")).partitionMap { fd => + extractTest(ts)(syms, models.head, fd.id) match { + case success: ExtractedTest.Success => + Right(fd.id -> (success.samples.asInstanceOf[Seq[ts.Expr]], success.instTparams.asInstanceOf[Seq[ts.Type]])) + case ExtractedTest.Failure(reason) => Left(fd.id -> reason) + } + } + val testsNok = testsNok0.toMap + if (testsNok.nonEmpty) { + return failure(FailureReason.IllFormedTests(testsNok)) + } + val testsOk = testsOk0.toMap + + class EquivalenceCheckerImpl(override val trees: ts.type, override val symbols: syms.type) + extends EquivalenceChecker(ts, models, functions, norm, n, initScore, maxPerm, maxCtex, defaultMaxStepsEval)(testsOk, symbols) + val ec = new EquivalenceCheckerImpl(ts, syms) + class SuccessImpl(override val trees: ts.type, override val symbols: syms.type, override val equivChker: ec.type) extends Creation.Success + new SuccessImpl(ts, syms, ec) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + private def validateModelsSignature(ts: Trees)(syms: ts.Symbols, models: Seq[Identifier]): Option[(Identifier, Identifier)] = { + import ts._ + val unordPairs = for { + (m1, i1) <- models.zipWithIndex + m2 <- models.drop(i1 + 1) + } yield (m1, m2) + unordPairs.find { case (m1, m2) => !areSignaturesCompatible(ts)(syms.functions(m1), syms.functions(m2)) } + } + + private def areSignaturesCompatible(ts: Trees)(fd1: ts.FunDef, fd2: ts.FunDef): Boolean = + areSignaturesCompatibleModuloPerm(ts)(fd1.params, fd1.tparams, fd1.returnType, fd2.params, fd2.tparams, fd2.returnType, ArgPermutation(fd1.params.indices)) + + private def areSignaturesCompatibleModuloPerm(ts: Trees)(fd1: ts.FunDef, fd2: ts.FunDef, perm: ArgPermutation): Boolean = + areSignaturesCompatibleModuloPerm(ts)(fd1.params, fd1.tparams, fd1.returnType, fd2.params, fd2.tparams, fd2.returnType, perm) + + private def areSignaturesCompatibleModuloPerm(ts: Trees)( + params1: Seq[ts.ValDef], tparams1: Seq[ts.TypeParameterDef], retTpe1: ts.Type, + params2: Seq[ts.ValDef], tparams2: Seq[ts.TypeParameterDef], retTpe2: ts.Type, + perm: ArgPermutation + ): Boolean = { + import ts._ + // To check signature, substitute all t2 tparams by t1's + val substMap = tparams2.zip(tparams1).map { case (tpd2, tpd1) => tpd2.tp -> (tpd1.tp: Type) }.toMap + + def tpeOk(t1: Type, t2: Type): Boolean = t1 == typeOps.instantiateType(t2, substMap) + + params1.size == params2.size && tparams1.size == tparams2.size && + params1.zip(perm.m2c.map(params2)).forall { case (vd1, vd2) => tpeOk(vd1.tpe, vd2.tpe) } && + tpeOk(retTpe1, retTpe2) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + case class ArgPermutation(m2c: Seq[Int]) { + require(m2c.toSet == m2c.indices.toSet, "Permutations are hard") + + def reverse: ArgPermutation = + // zipWithIndex will give (candidate ix, model ix) + ArgPermutation(m2c.zipWithIndex.sortBy(_._1).map(_._2)) + } + + // All pairs of a matching in pairs.keySet, the "A" is extra information + case class Matching[N, A](pairs: Map[(N, N), A]) { + require(pairs.nonEmpty) + require(isMatching(pairs.keySet), "Matching is hard") + } + + // Note: tparams are not checked for permutation + private def allArgsPermutations(ts: Trees)(vdparams1: Seq[ts.ValDef], tparams1: Seq[ts.TypeParameterDef], + vdparams2: Seq[ts.ValDef], tparams2: Seq[ts.TypeParameterDef]): Set[ArgPermutation] = { + import ts._ + if (vdparams1.size != vdparams2.size || tparams1.size != tparams2.size) return Set.empty + + val params1 = vdparams1.map(_.tpe) + val params2 = vdparams2.map(_.tpe) + + def multiplicity(tps: Seq[Type]): Map[Type, Int] = tps.groupMapReduce(identity)(_ => 1)(_ + _) + + // Substitute all t2 tparams by t1's + val substMap = tparams2.zip(tparams1).map { case (tpd2, tpd1) => tpd2.tp -> (tpd1.tp: Type) }.toMap + + val params2Substed = params2.map(typeOps.instantiateType(_, substMap)) + val p1Mult = multiplicity(params1) + val p2Mult = multiplicity(params2Substed) + if (p1Mult != p2Mult) return Set.empty + + // Type -> list of their index in params2 + val ixsTpe2 = params2Substed.zipWithIndex.groupMap(_._1)(_._2) + // A graph whose nodes are the index of the arguments for params1 and params2 + // There is an edge iff the types are the same + val edges = params1.zipWithIndex.flatMap { case (tp, ix1) => + ixsTpe2(tp).map(ix2 => ix1 -> ix2) + }.toSet + // Sanity check: must of the same type + assert(edges.forall { case (ix1, ix2) => params1(ix1) == params2Substed(ix2) }, "Constructing a graph for matching types is hard, isn't it?") + allMatching(edges.map(_ -> ()).toMap) + .map(m => ArgPermutation(m.pairs.keys.toSeq.sortBy(_._1).map(_._2))) + } + + private def isMatching[T](pairs: Set[(T, T)]): Boolean = { + val l2rs = pairs.groupMap(_._1)(_._2) + val r2ls = pairs.groupMap(_._2)(_._1) + l2rs.forall(_._2.size == 1) && r2ls.forall(_._2.size == 1) + } + + // All matching for the given edges. The "A" is extra information for the given edge (unused for matching but useful for other applications) + private def allMatching[N, A](edges: Map[(N, N), A]): Set[Matching[N, A]] = { + + // Like Matching but without the data, which we will feed once we are done + case class Mtching(pairs: Set[(N, N)]) + + // Remove all edges touching `l` (on the left) + def rmLeft(edges: Set[(N, N)], l: N): Set[(N, N)] = + edges.filter { case (left, _) => left != l } + + def rmRight(edges: Set[(N, N)], r: N): Set[(N, N)] = + edges.filter { case (_, right) => right != r } + + def rec(left: Set[N], edges: Set[(N, N)]): Set[Mtching] = { + if (edges.isEmpty) Set.empty + else if (isMatching(edges)) Set(Mtching(edges)) + else { + assert(left.nonEmpty) + val l = left.head + val allR = edges.filter(_._1 == l).map(_._2) + // All possible sub matching when not taking any edge from l + val skipping = { + val skip = rec(left.tail, rmLeft(edges, l)) + skip.flatMap { m => + // r's that are not matched in the sub problem are added back to the matching solution + val unmatchedRs = allR -- m.pairs.map(_._2) + if (unmatchedRs.isEmpty) Set(m) + else unmatchedRs.map(r => Mtching(m.pairs + ((l, r)))) + } + } + val subs = allR.flatMap { r => + val subs = rec(left.tail, rmLeft(rmRight(edges, r), l)) + if (subs.isEmpty) Set(Mtching(Set((l, r)))) + else subs.map(m => Mtching(m.pairs + ((l, r)))) + } + subs ++ skipping + } + } + + rec(edges.keySet.map(_._1), edges.keySet) + .map(m => Matching(m.pairs.map(e => e -> edges(e)).toMap)) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + private def indexOfPath(paths: Option[Seq[Path]], fid: Identifier): Option[Int] = paths match { + case None => None + case Some(paths) => + // Support wildcard `_` as specified in the documentation. + // A leading wildcard is always assumed. + val path: Path = CheckFilter.fullNameToPath(CheckFilter.fixedFullName(fid)) + val ix = paths indexWhere { p => + if (p endsWith Seq("_")) path containsSlice p.init + else path endsWith p + } + if (ix >= 0) Some(ix) + else None + } + + // To be compatible: + // -norm's first n-1 params must be compatible with the model params + // -norm return type must be compatible with return type of model + // -norm last param must be the same as its return type (i.e. compatible with the return type of model) + private def checkArgsNorm(ts: Trees)(symbols: ts.Symbols, model: Identifier, norm: Identifier): Boolean = { + val m = symbols.functions(model) + val n = symbols.functions(norm) + n.params.nonEmpty && n.params.last.tpe == n.returnType && + areSignaturesCompatibleModuloPerm(ts)(m.params, m.tparams, m.returnType, n.params.init, n.tparams, n.returnType, ArgPermutation(m.params.indices)) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + private enum ExtractedTest { + case Success(ts: Trees)(val samples: Seq[ts.Expr], val instTparams: Seq[ts.Type]) + case Failure(reason: TestExtractionFailure) + } + + private def extractTest(ts: Trees)(syms: ts.Symbols, anyModelId: Identifier, fnId: Identifier): ExtractedTest = { + import ts._ + given ts.Symbols = syms + + val anyModelFd = syms.functions(anyModelId) + if (anyModelFd.params.isEmpty) { + // Why on earth would you do that?? + return ExtractedTest.Failure(TestExtractionFailure.ModelTakesNoArg) + } + + val fd = syms.functions(fnId) + val elemsTpe = fd.returnType match { + case ADTType(id: SymbolIdentifier, Seq(tp)) if id.symbol.path == Seq("stainless", "collection", "List") => tp + case _ => return ExtractedTest.Failure(TestExtractionFailure.ReturnTypeMismatch) + } + + def peel(e: Expr, acc: Seq[Expr]): Either[Expr, Seq[Expr]] = e match { + case ADT(id: SymbolIdentifier, _, Seq(head, tail)) if id.symbol.path == Seq("stainless", "collection", "Cons") => + peel(tail, acc :+ head) + case ADT(id: SymbolIdentifier, _, Seq()) if id.symbol.path == Seq("stainless", "collection", "Nil") => + Right(acc) + case _ => Left(e) + } + + val samples = peel(fd.fullBody, Seq.empty) match { + case Left(_) => return ExtractedTest.Failure(TestExtractionFailure.UnknownExpr) + case Right(Seq()) => return ExtractedTest.Failure(TestExtractionFailure.NoData) + case Right(samplesTupled) => samplesTupled + } + + val modelTpe = { + if (anyModelFd.params.size == 1) anyModelFd.params.head.tpe + else TupleType(anyModelFd.params.map(_.tpe)) + } + val instTparams = syms.unify(elemsTpe, modelTpe, anyModelFd.tparams.map(_.tp)) match { + case Some(mapping0) => + assert(mapping0.map(_._1).toSet == anyModelFd.tparams.map(_.tp).toSet) + val mapping = mapping0.toMap + anyModelFd.tparams.map(tpd => mapping(tpd.tp)) + case None => return ExtractedTest.Failure(TestExtractionFailure.UnificationFailure) + } + ExtractedTest.Success(ts)(samples, instTparams) + } +} \ No newline at end of file diff --git a/core/src/main/scala/stainless/equivchk/EquivalenceCheckingAnalysis.scala b/core/src/main/scala/stainless/equivchk/EquivalenceCheckingAnalysis.scala new file mode 100644 index 0000000000..92a467fad0 --- /dev/null +++ b/core/src/main/scala/stainless/equivchk/EquivalenceCheckingAnalysis.scala @@ -0,0 +1,17 @@ +package stainless +package equivchk + +import inox.utils.{NoPosition, Position} +import stainless.AbstractAnalysis +import stainless.extraction.ExtractionSummary +import stainless.verification.VerificationAnalysis +import EquivalenceCheckingReport._ + +class EquivalenceCheckingAnalysis(val sources: Set[Identifier], + val records: Seq[Record], + val extractionSummary: ExtractionSummary) extends AbstractAnalysis { self => + override val name: String = EquivalenceCheckingComponent.name + override type Report = EquivalenceCheckingReport + + override def toReport: EquivalenceCheckingReport = new EquivalenceCheckingReport(records, sources, extractionSummary) +} diff --git a/core/src/main/scala/stainless/equivchk/EquivalenceCheckingComponent.scala b/core/src/main/scala/stainless/equivchk/EquivalenceCheckingComponent.scala new file mode 100644 index 0000000000..d9812a9292 --- /dev/null +++ b/core/src/main/scala/stainless/equivchk/EquivalenceCheckingComponent.scala @@ -0,0 +1,352 @@ +package stainless +package equivchk + +import inox.Context +import io.circe.Json +import stainless.extraction._ +import stainless.extraction.xlang.{trees => xt} +import stainless.termination.MeasureInference +import stainless.utils.{CheckFilter, JsonUtils} +import stainless.verification._ + +import java.io.File +import scala.concurrent.Future + +object DebugSectionEquivChk extends inox.DebugSection("equivchk") + +object optCompareFuns extends inox.OptionDef[Seq[String]] { + val name = "comparefuns" + val default = Seq[String]() + val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser) + val usageRhs = "f1,f2,..." +} + +object optModels extends inox.OptionDef[Seq[String]] { + val name = "models" + val default = Seq[String]() + val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser) + val usageRhs = "f1,f2,..." +} + +object optNorm extends inox.StringOptionDef("norm", "", "f") + +object optEquivalenceOutput extends FileOptionDef { + val name = "equivchk-output" + val default = new File("equivchk-output.json") +} + +object optN extends inox.IntOptionDef("equivchk-n", EquivalenceChecker.defaultN, "") +object optInitScore extends inox.IntOptionDef("equivchk-init-score", EquivalenceChecker.defaultInitScore, "") +object optMaxPerm extends inox.IntOptionDef("equivchk-max-perm", EquivalenceChecker.defaultMaxMatchingPermutation, "") +object optMaxCtex extends inox.IntOptionDef("equivchk-max-ctex", EquivalenceChecker.defaultMaxCtex, "") + +object EquivalenceCheckingComponent extends Component { + override val name = "equivchk" + override val description = "Equivalence checking of functions" + + override type Report = EquivalenceCheckingReport + override type Analysis = EquivalenceCheckingAnalysis + + override val lowering = { + class LoweringImpl(override val s: trees.type, override val t: trees.type) + extends transformers.ConcreteTreeTransformer(s, t) + inox.transformers.SymbolTransformer(new LoweringImpl(trees, trees)) + } + + override def run(pipeline: StainlessPipeline)(using inox.Context): EquivalenceCheckingRun = { + new EquivalenceCheckingRun(pipeline) + } +} + +class EquivalenceCheckingRun private(override val component: EquivalenceCheckingComponent.type, + override val trees: stainless.trees.type, + override val pipeline: StainlessPipeline, // unused - we have our own + val underlyingRun: VerificationRun) + (using override val context: inox.Context) extends ComponentRun { self => + import EquivalenceCheckingReport._ + def this(pipeline: StainlessPipeline)(using inox.Context) = + this(EquivalenceCheckingComponent, stainless.trees, pipeline, new VerificationRun(pipeline)) + + import component.{Analysis, Report} + import extraction.given + + override def parse(json: Json): Report = ??? + + override def createPipeline = underlyingRun.createPipeline + given givenDebugSection: DebugSectionEquivChk.type = DebugSectionEquivChk + + private val tracePrePipeline: ExtractionPipeline { val s: xlang.trees.type; val t: trace.trees.type } = + xlang.extractor andThen + innerclasses.extractor andThen + methods.extractor andThen + throwing.extractor andThen + imperative.extractor andThen + oo.extractor andThen + innerfuns.extractor andThen + inlining.extractor + + private val tracePostPipeline: ExtractionPipeline { val s: trace.trees.type; val t: trees.type } = + trace.extractor andThen + termination.extractor andThen + extraction.utils.DebugPipeline("MeasureInference", MeasureInference(extraction.trees)) andThen + extraction.utils.DebugPipeline("PartialEvaluation", PartialEvaluation(extraction.trees)) andThen + extraction.completer(trees) + + override def apply(ids: Seq[Identifier], symbols: xt.Symbols): Future[Analysis] = { + val (traceSyms, traceSummary) = tracePrePipeline.extract(symbols) + val (plainSyms, plainSummary) = tracePostPipeline.extract(traceSyms) + val ec = EquivalenceChecker.tryCreate(trace.trees)(traceSyms) match { + case fail: EquivalenceChecker.Creation.Failure => explainAndSayFarewell(fail.reason) + case success: EquivalenceChecker.Creation.Success => success.equivChker + } + val toProcess = createFilter.filter(ids, plainSyms, underlyingRun.component) + for { + gen <- underlyingRun.execute(toProcess, plainSyms, ExtractionSummary.Node(traceSummary, plainSummary)) + invalidVCsCands = counterExamples(gen).flatMap { + case (vc, ctex) => ec.reportErroneous(gen.program)(gen, ctex)(vc.fid).getOrElse(Set.empty) + }.toSeq.distinct + _ = debugInvalidVCsCandidates(invalidVCsCands) + trRes <- equivCheck(ec) + } yield buildAnalysis(ec)(ids, gen, trRes) + } + + private def buildAnalysis(ec: EquivalenceChecker)(ids: Seq[Identifier], general: VerificationAnalysis, trRes: ec.Results): EquivalenceCheckingAnalysis = { + val genRecors = general.toReport.results.map { verifRecord => + Record(verifRecord.id, verifRecord.pos, verifRecord.time, Status.Verification(verifRecord.status), verifRecord.solverName, verifRecord.kind, verifRecord.derivedFrom) + } + val valid = trRes.valid.toSeq.sortBy(_._1).map { + case (v, data) => + assert(data.path.nonEmpty) + val fd = ec.symbols.getFunction(v) + val directModel = data.path.head + Record(v, fd.getPos, data.solvingInfo.time, + Status.Equivalence(EquivalenceStatus.Valid(directModel, data.solvingInfo.fromCache, data.solvingInfo.trivial)), + data.solvingInfo.solverName, "equivalence", fd.source) + } + val errns = trRes.erroneous.toSeq.sortBy(_._1).map { + case (errn, data) => + val fd = ec.symbols.getFunction(errn) + Record(errn, fd.getPos, data.solvingInfo.map(_.time).getOrElse(0L), + Status.Equivalence(EquivalenceStatus.Erroneous), + data.solvingInfo.flatMap(_.solverName), "equivalence", fd.source) + } + val wrgs = trRes.wrongs.toSeq.sorted.map { wrong => + val fd = ec.symbols.getFunction(wrong) + // No solver or time specified because it's a signature mismatch + Record(wrong, fd.getPos, 0L, Status.Equivalence(EquivalenceStatus.Wrong), None, "equivalence", fd.source) + } + val unknws = trRes.unknowns.toSeq.sortBy(_._1).map { + case (unknown, data) => + val fd = ec.symbols.getFunction(unknown) + Record(unknown, fd.getPos, data.solvingInfo.time, Status.Equivalence(EquivalenceStatus.Unknown), data.solvingInfo.solverName, "equivalence", fd.source) + } + + val allRecords = genRecors ++ valid ++ errns ++ wrgs ++ unknws + new EquivalenceCheckingAnalysis(ids.toSet, allRecords, general.extractionSummary) + } + + private def equivCheck(ec: EquivalenceChecker): Future[ec.Results] = { + class Identity(override val s: ec.trees.type, override val t: trace.trees.type) extends transformers.ConcreteTreeTransformer(s, t) + val identity = new Identity(ec.trees, trace.trees) + + def examination(): Future[ec.Results] = { + ec.pickNextExamination() match { + case ec.NextExamination.Done(pruned, res) => + debugPruned(ec)(pruned) + printResults(ec)(res) + context.options.findOption(equivchk.optEquivalenceOutput) match { + case Some(out) => dumpResultsJson(out, ec)(res) + case None => () + } + Future.successful(res) + case ec.NextExamination.NewCandidate(cand, model, strat, pruned) => + debugPruned(ec)(pruned) + debugNewCandidate(ec)(cand, model, strat) + round() + } + } + + def round(): Future[ec.Results] = { + val generated = ec.prepareRound() + val allFns = (ec.symbols.functions -- generated.map(_.id)).values.toSeq ++ generated + val syms: trace.trees.Symbols = trace.trees.NoSymbols + .withSorts(ec.symbols.sorts.values.map(identity.transform).toSeq) + .withFunctions(allFns.map(identity.transform)) + val plainSyms = tracePostPipeline.extract(syms)._1 + underlyingRun.execute(generated.map(_.id), plainSyms, ExtractionSummary.NoSummary) + .flatMap { analysis => + val concl = ec.concludeRound(analysis) + concl match { + case ec.RoundConclusion.NextRound(cand, model, strat, prunedSubFnsPairs) => + debugNewRound(ec)(cand, model, strat) + debugPrunedSubFnsPairs(prunedSubFnsPairs) + round() + case ec.RoundConclusion.CandidateClassified(cand, classification, prunedSubFnsPairs) => + debugClassified(ec)(cand, classification) + debugPrunedSubFnsPairs(prunedSubFnsPairs) + examination() + } + } + } + + examination() + } + + private def counterExamples(analysis: VerificationAnalysis) = { + analysis.vrs.collect { + case (vc, VCResult(VCStatus.Invalid(VCStatus.CounterExample(model)), _, _)) => + vc -> model + }.toMap + } + + private def debugInvalidVCsCandidates(cands: Seq[Identifier]): Unit = { + if (cands.nonEmpty) { + context.reporter.whenDebug(DebugSectionEquivChk) { debug => + debug(s"The following candidates were pruned for having invalid VCs:") + val candsStr = cands.sorted.map(_.fullName).mkString(" ", "\n ", "") + debug(candsStr) + } + } + } + + private def debugPruned(ec: EquivalenceChecker)(pruned: Map[Identifier, ec.PruningReason]): Unit = { + def pretty(fn: Identifier, reason: ec.PruningReason): String = { + val rsonStr = reason match { + case ec.PruningReason.SignatureMismatch => "signature mismatch" + case ec.PruningReason.ByTest(testId, sampleIx, ctex) => + s"test falsification by ${testId.fullName} sample n°${sampleIx + 1} with ${prettyCtex(ec)(ctex.mapping)}\n Expected: ${ctex.expected} but got: ${ctex.got}" + case ec.PruningReason.ByPreviousCtex(ctex) => s"counter-example falsification with ${prettyCtex(ec)(ctex.mapping)}\n Expected: ${ctex.expected} but got: ${ctex.got}" + } + s"${fn.fullName}: $rsonStr" + } + + if (pruned.nonEmpty) { + context.reporter.whenDebug(DebugSectionEquivChk) { debug => + debug("The following functions were pruned:") + val strs = pruned.toSeq.sortBy(_._1).map(pretty.tupled) + strs.foreach(s => debug(s" $s")) + } + } + } + + private def debugNewCandidate(ec: EquivalenceChecker)(cand: Identifier, model: Identifier, strat: ec.EquivCheckStrategy): Unit = { + context.reporter.whenDebug(DebugSectionEquivChk) { debug => + debug(s"Picking new candidate: ${cand.fullName} with model ${model.fullName} and strategy ${strat.pretty}") + } + } + + private def debugNewRound(ec: EquivalenceChecker)(cand: Identifier, model: Identifier, strat: ec.EquivCheckStrategy): Unit = { + context.reporter.whenDebug(DebugSectionEquivChk) { debug => + debug(s"Retry for ${cand.fullName} with model ${model.fullName} and strategy: ${strat.pretty}") + } + } + + private def debugClassified(ec: EquivalenceChecker)(cand: Identifier, classification: ec.Classification): Unit = { + context.reporter.whenDebug(DebugSectionEquivChk) { debug => + val msg = classification match { + case ec.Classification.Valid(model) => s"valid whose direct model is ${model.fullName}" + case ec.Classification.Invalid(ctexs) => + val ctexStr = ctexs.map(prettyCtex(ec)(_)).map(s => s" $s").mkString("\n ") + s"invalid with the following counter-examples:\n $ctexStr" + case ec.Classification.Unknown => "unknown" + } + debug(s"${cand.fullName} is $msg") + } + } + + private def debugPrunedSubFnsPairs(prunedSubFnsPairs: Set[(Identifier, Identifier, EquivalenceChecker.ArgPermutation)]): Unit = { + if (prunedSubFnsPairs.nonEmpty) { + context.reporter.whenDebug(DebugSectionEquivChk) { debug => + val str = prunedSubFnsPairs.toSeq.map { case (mod, cand, perm) => (mod.fullName, cand.fullName, perm) }.sortBy(_._1) + .map { case (mod, cand, perm) => s"$mod <-> $cand (permutation = ${perm.m2c.mkString(", ")})" } + .mkString(", ") + debug(s"(pruned all matching containing the following subfns pairs: $str)") + } + } + } + + private def printResults(ec: EquivalenceChecker)(res: ec.Results): Unit = { + import context.reporter.info + + info("Printing equivalence checking results:") + res.equiv.foreach { case (m, l) => + val lStr = l.map(_.fullName) + info(s"List of functions that are equivalent to model ${m.fullName}: ${lStr.mkString(", ")}") + } + val errns = res.erroneous.keys.toSeq.map(_.fullName).sorted.mkString(", ") + val unknowns = res.unknowns.keys.toSeq.map(_.fullName).sorted.mkString(", ") + val wrongs = res.wrongs.toSeq.map(_.fullName).sorted.mkString(", ") + info(s"List of erroneous functions: $errns") + info(s"List of timed-out functions: $unknowns") + info(s"List of wrong functions: $wrongs") + info(s"Printing the final state:") + res.valid.foreach { case (cand, data) => + val pathStr = data.path.map(_.fullName).mkString(", ") + info(s"Path for the function ${cand.fullName}: $pathStr") + } + res.erroneous.foreach { case (cand, data) => + val ctexsStr = data.ctexs.map(ctex => ctex.map { case (vd, arg) => s"${vd.id.name} -> $arg" }.mkString(", ")) + info(s"Counterexample for the function ${cand.fullName}:") + ctexsStr.foreach(s => info(s" $s")) + } + } + + private def prettyCtex(ec: EquivalenceChecker)(ctex: Map[ec.trees.ValDef, ec.trees.Expr]): String = + ctex.toSeq.map { case (vd, e) => s"${vd.id.name} -> $e" }.mkString(", ") + + private def dumpResultsJson(out: File, ec: EquivalenceChecker)(res: ec.Results): Unit = { + val equivs = res.equiv.map { case (m, l) => m.fullName -> l.map(_.fullName).toSeq.sorted } + .toSeq.sortBy(_._1) + val errns = res.erroneous.keys.toSeq.map(_.fullName).sorted + val unknowns = res.unknowns.keys.toSeq.map(_.fullName).sorted + val wrongs = res.wrongs.toSeq.map(_.fullName).sorted + + val json = Json.fromFields(Seq( + "equivalent" -> Json.fromValues(equivs.map { case (m, l) => + Json.fromFields(Seq( + "model" -> Json.fromString(m), + "functions" -> Json.fromValues(l.map(Json.fromString)) + )) + }), + "erroneous" -> Json.fromValues(errns.map(Json.fromString)), + "timeout" -> Json.fromValues(unknowns.sorted.map(Json.fromString)), + "wrong" -> Json.fromValues(wrongs.sorted.map(Json.fromString)), + )) + JsonUtils.writeFile(out, json) + } + + private def explainAndSayFarewell(reason: EquivalenceChecker.FailureReason): Nothing = { + def pretty(reason: EquivalenceChecker.TestExtractionFailure): String = reason match { + case EquivalenceChecker.TestExtractionFailure.ModelTakesNoArg => "models do not take any argument" + case EquivalenceChecker.TestExtractionFailure.ReturnTypeMismatch => "the return type is not a list" + case EquivalenceChecker.TestExtractionFailure.UnknownExpr => "use of non-literal expression" + case EquivalenceChecker.TestExtractionFailure.NoData => "no given samples" + case EquivalenceChecker.TestExtractionFailure.UnificationFailure => "could not unify the sample type with models argument types" + } + val msg = reason match { + case EquivalenceChecker.FailureReason.IllFormedTests(invalid) => + val msg = invalid.map { case (id, reason) => s" ${id.fullName}: ${pretty(reason)}" } + s"the following tests are ill-formed:\n${msg.mkString("\n")}" + case EquivalenceChecker.FailureReason.NoModels => "there are no specified models" + case EquivalenceChecker.FailureReason.NoFunctions => "there are no specified candidate functions" + case EquivalenceChecker.FailureReason.ModelsSignatureMismatch(m1, m2) => + s"models ${m1.fullName} and ${m2.fullName} signatures do not match" + case EquivalenceChecker.FailureReason.OverlappingModelsAndFunctions(overlapping) => + val verb = if (overlapping.size == 1) "is" else "are" + s"${overlapping.mkString(", ")} $verb both candidate and model" + case EquivalenceChecker.FailureReason.MultipleNormFunctions(norms) => + s"multiple norm functions are specified:\n${norms.toSeq.sorted.map(_.fullName).mkString(" ", "\n ", "")}" + case EquivalenceChecker.FailureReason.NormSignatureMismatch(norm) => + s"norm function ${norm.fullName} signature does not match with model signature" + case EquivalenceChecker.FailureReason.IllegalHyperparameterValue(msg) => msg + } + context.reporter.fatalError(s"Could not create the equivalence check component because $msg") + } + + override private[stainless] def execute(functions0: Seq[Identifier], symbols: trees.Symbols, exSummary: ExtractionSummary): Future[Analysis] = + sys.error("Unreachable because def apply was overridden") + + extension (id: Identifier) { + private def fullName: String = CheckFilter.fixedFullName(id) + } +} diff --git a/core/src/main/scala/stainless/equivchk/EquivalenceCheckingReport.scala b/core/src/main/scala/stainless/equivchk/EquivalenceCheckingReport.scala new file mode 100644 index 0000000000..bd2a5a8261 --- /dev/null +++ b/core/src/main/scala/stainless/equivchk/EquivalenceCheckingReport.scala @@ -0,0 +1,114 @@ +package stainless +package equivchk + +import io.circe._ +import io.circe.generic.semiauto._ +import io.circe.syntax._ +import stainless.extraction._ +import stainless.extraction.xlang.{trees => xt} +import stainless.termination.MeasureInference +import stainless.utils.JsonConvertions.given +import stainless.utils.{CheckFilter, JsonUtils} +import stainless.verification.VerificationReport.{Status => VerificationStatus} +import stainless.verification._ + +object EquivalenceCheckingReport { + + enum Status { + case Verification(status: VerificationStatus) + case Equivalence(status: EquivalenceStatus) + + def isValid: Boolean = this match { + case Verification(status) => status.isValid + case Equivalence(EquivalenceStatus.Valid(_, _, _)) => true + case _ => false + } + + def isValidFromCache: Boolean = this match { + case Verification(status) => status.isValidFromCache + case Equivalence(EquivalenceStatus.Valid(_, fromCache, _)) => fromCache + case _ => false + } + + def isTrivial: Boolean = this match { + case Verification(status) => status.isTrivial + case Equivalence(EquivalenceStatus.Valid(_, _, trivial)) => trivial + case _ => false + } + + def isInvalid: Boolean = this match { + case Verification(status) => status.isInvalid + case Equivalence(EquivalenceStatus.Erroneous | EquivalenceStatus.Wrong) => true + case _ => false + } + + def isInconclusive: Boolean = this match { + case Verification(status) => status.isInconclusive + case Equivalence(EquivalenceStatus.Unknown) => true + case _ => false + } + } + + enum EquivalenceStatus { + case Valid(model: Identifier, fromCache: Boolean, trivial: Boolean) + case Erroneous + case Wrong + case Unknown + } + + case class Record(id: Identifier, pos: inox.utils.Position, time: Long, + status: Status, solverName: Option[String], kind: String, + derivedFrom: Identifier) extends AbstractReportHelper.Record + + given equivStatusDecoder: Decoder[EquivalenceStatus] = deriveDecoder + given equivStatusEncoder: Encoder[EquivalenceStatus] = deriveEncoder + + given statusDecoder: Decoder[Status] = deriveDecoder + given statusEncoder: Encoder[Status] = deriveEncoder + + given recordDecoder: Decoder[Record] = deriveDecoder + given recordEncoder: Encoder[Record] = deriveEncoder +} + +class EquivalenceCheckingReport(override val results: Seq[EquivalenceCheckingReport.Record], + override val sources: Set[Identifier], + override val extractionSummary: ExtractionSummary) extends BuildableAbstractReport[EquivalenceCheckingReport.Record, EquivalenceCheckingReport] { + import EquivalenceCheckingReport.{_, given} + override protected val encoder: Encoder[Record] = recordEncoder + + override protected def build(results: Seq[Record], sources: Set[Identifier]): EquivalenceCheckingReport = + new EquivalenceCheckingReport(results, sources, ExtractionSummary.NoSummary) + + override val name: String = EquivalenceCheckingComponent.name + + override lazy val annotatedRows: Seq[RecordRow] = results map { + case Record(id, pos, time, status, solverName, kind, _) => + val statusName = status match { + case Status.Verification(stat) => stat.name + case Status.Equivalence(EquivalenceStatus.Valid(model, _, _)) => CheckFilter.fixedFullName(model) + case Status.Equivalence(EquivalenceStatus.Wrong) => "signature mismatch" + case Status.Equivalence(EquivalenceStatus.Erroneous) => "erroneous" + case Status.Equivalence(EquivalenceStatus.Unknown) => "unknown" + } + val level = levelOf(status) + val solver = solverName getOrElse "" + val extra = Seq(kind, statusName, solver) + RecordRow(id, pos, level, extra, time) + } + lazy val totalConditions: Int = results.size + lazy val totalTime: Long = results.map(_.time).sum + lazy val totalValid: Int = results.count(_.status.isValid) + lazy val totalValidFromCache: Int = results.count(_.status.isValidFromCache) + lazy val totalTrivial: Int = results.count(_.status.isTrivial) + lazy val totalInvalid: Int = results.count(_.status.isInvalid) + lazy val totalUnknown: Int = results.count(_.status.isInconclusive) + + private def levelOf(status: Status) = { + if (status.isValid) Level.Normal + else if (status.isInconclusive) Level.Warning + else Level.Error + } + + override lazy val stats: ReportStats = + ReportStats(totalConditions, totalTime, totalValid, totalValidFromCache, totalTrivial, totalInvalid, totalUnknown) +} diff --git a/core/src/main/scala/stainless/equivchk/Utils.scala b/core/src/main/scala/stainless/equivchk/Utils.scala new file mode 100644 index 0000000000..acf5315819 --- /dev/null +++ b/core/src/main/scala/stainless/equivchk/Utils.scala @@ -0,0 +1,158 @@ +package stainless +package equivchk + +import stainless.extraction.trace.Trees +import stainless.{FreshIdentifier, Identifier} + +import scala.collection.immutable.SeqMap + +trait Utils { + val trees: Trees + val context: inox.Context + + import trees._ + import exprOps.{BodyWithSpecs, Postcondition, PostconditionKind} + + class Specializer(origFd: FunDef, + newId: Identifier, + tsubst: Map[Identifier, Type], + vsubst: Map[Identifier, Expr], + replacement: Map[Identifier, (Identifier, Seq[Int])]) // replace function calls + extends ConcreteSelfTreeTransformer { + + override def transform(expr: Expr): Expr = expr match { + case v: Variable => + vsubst.getOrElse(v.id, super.transform(v)) + + case fi: FunctionInvocation if fi.id == origFd.id => + val fi1 = FunctionInvocation(newId, tps = fi.tps, args = fi.args) + super.transform(fi1.copiedFrom(fi)) + + case fi: FunctionInvocation if replacement.contains(fi.id) => + val (repl, perm) = replacement(fi.id) + val fi1 = FunctionInvocation(repl, tps = fi.tps, args = perm.map(fi.args)) + super.transform(fi1.copiedFrom(fi)) + + case _ => super.transform(expr) + } + + override def transform(tpe: Type): Type = tpe match { + case tp: TypeParameter => + tsubst.getOrElse(tp.id, super.transform(tp)) + + case _ => super.transform(tpe) + } + } + + // make a copy of the 'model' + // combine the specs of the 'lemma' + // 'suffix': only used for naming + // 'replacement': function calls that are supposed to be replaced according to given mapping and permutation + def inductPattern(symbols: Symbols, model: FunDef, lemma: FunDef, suffix: String, replacement: Map[Identifier, (Identifier, Seq[Int])]) = { + import exprOps._ + import symbols.{_, given} + + val indPattern = exprOps.freshenSignature(model).copy(id = FreshIdentifier(s"${lemma.id}$$$suffix")) + val newParamTps = indPattern.tparams.map { tparam => tparam.tp } + val newParamVars = indPattern.params.map { param => param.toVariable } + + val fi = FunctionInvocation(model.id, newParamTps, newParamVars) + + val tpairs = model.tparams zip fi.tps + val tsubst = tpairs.map { case (tparam, targ) => tparam.tp.id -> targ }.toMap + val subst = (model.params.map(_.id) zip fi.args).toMap + val specializer = new Specializer(model, indPattern.id, tsubst, subst, replacement) + + val fullBodySpecialized = specializer.transform(exprOps.withoutSpecs(model.fullBody).get) + + val specsSubst = (lemma.params.map(_.id) zip newParamVars).toMap ++ (model.params.map(_.id) zip newParamVars).toMap + val specsTsubst = ((lemma.tparams zip fi.tps) ++ (model.tparams zip fi.tps)).map { case (tparam, targ) => tparam.tp.id -> targ }.toMap + val specsSpecializer = new Specializer(indPattern, indPattern.id, specsTsubst, specsSubst, Map()) + + val specs = BodyWithSpecs(model.fullBody).specs ++ BodyWithSpecs(lemma.fullBody).specs.filterNot(_.kind == MeasureKind) + val pre = specs.filterNot(_.kind == PostconditionKind).map(spec => spec match { + case Precondition(cond) => Precondition(specsSpecializer.transform(cond)).setPos(spec) + case LetInSpec(vd, expr) => LetInSpec(vd, specsSpecializer.transform(expr)).setPos(spec) + case Measure(measure) => Measure(specsSpecializer.transform(measure)).setPos(spec) + case s => context.reporter.fatalError(s"Unsupported specs: $s") + }) + + val withPre = exprOps.reconstructSpecs(pre, Some(fullBodySpecialized), indPattern.returnType) + + val speccedLemma = BodyWithSpecs(lemma.fullBody).addPost + val speccedOrig = BodyWithSpecs(model.fullBody).addPost + val postLemma = speccedLemma.getSpec(PostconditionKind).map(post => + specsSpecializer.transform(post.expr)) + val postOrig = speccedOrig.getSpec(PostconditionKind).map(post => specsSpecializer.transform(post.expr)) + + (postLemma, postOrig) match { + case (Some(Lambda(Seq(res1), cond1)), Some(Lambda(Seq(res2), cond2))) => + val res = ValDef.fresh("res", indPattern.returnType) + val freshCond1 = exprOps.replaceFromSymbols(Map(res1 -> res.toVariable), cond1) + val freshCond2 = exprOps.replaceFromSymbols(Map(res2 -> res.toVariable), cond2) + + val cond = andJoin(Seq(freshCond1, freshCond2)) + val post = Postcondition(Lambda(Seq(res), cond)) + + indPattern.copy( + fullBody = BodyWithSpecs(withPre).withSpec(post).reconstructed, + flags = Seq(Derived(Some(lemma.id)), Derived(Some(model.id))) + ).copiedFrom(indPattern) + + case _ => indPattern + } + } + + case class ElimTraceInduct(updatedFd: FunDef, helper: Option[FunDef]) + + def elimTraceInduct(symbols: Symbols, fds: Seq[FunDef]): Map[Identifier, ElimTraceInduct] = + // SeqMap to have deterministic traversal + SeqMap.from(fds.flatMap(fd => elimTraceInduct(symbols, fd).map(res => fd.id -> res))) + + def elimTraceInduct(symbols: Symbols, fd: FunDef): Option[ElimTraceInduct] = { + def findModel(funNames: Seq[Any]): Option[FunctionInvocation] = { + var funInv: Option[FunctionInvocation] = None + + BodyWithSpecs(fd.fullBody).getSpec(PostconditionKind) match { + case Some(Postcondition(post)) => + exprOps.preTraversal { + case _ if funInv.isDefined => // do nothing + case fi@FunctionInvocation(tfd, tps, args) if fi.id != fd.id && symbols.isRecursive(tfd) && (funNames.contains(StringLiteral(tfd.name)) || funNames.contains(StringLiteral(""))) => + val paramVars = fd.params.map(_.toVariable) + val argCheck = args.forall(paramVars.contains) && args.toSet.size == args.size + if (argCheck) funInv = Some(fi) + case _ => + }(post) + case _ => + } + + funInv + } + + fd.flags.collectFirst { + case Annotation("traceInduct", funNames) => + findModel(funNames) match { + case Some(finv) => + // make a helper lemma: + val helper = inductPattern(symbols, symbols.functions(finv.id), fd, "indProof", Map.empty).setPos(fd.getPos) + + // transform the main lemma + val proof = FunctionInvocation(helper.id, finv.tps, fd.params.map(_.toVariable)) + val returnType = typeOps.instantiateType(helper.returnType, (helper.typeArgs zip fd.typeArgs).toMap) + + val body = Let(ValDef.fresh("ind$proof", returnType), proof, exprOps.withoutSpecs(fd.fullBody).get) + val withPre = exprOps.reconstructSpecs(BodyWithSpecs(fd.fullBody).specs, Some(body), fd.returnType) + + val updFn = fd.copy( + fullBody = BodyWithSpecs(withPre).reconstructed, + flags = fd.flags.filterNot(f => f.name == "traceInduct") + ) + ElimTraceInduct(updFn, Some(helper)) + + case None => // there are no recursive calls - no model function + val updFn = fd.copy(flags = fd.flags.filterNot(f => f.name == "traceInduct")) + ElimTraceInduct(updFn, None) + } + } + } +} diff --git a/core/src/main/scala/stainless/extraction/oo/TypeOps.scala b/core/src/main/scala/stainless/extraction/oo/TypeOps.scala index d67d4232d8..9720a6e68b 100644 --- a/core/src/main/scala/stainless/extraction/oo/TypeOps.scala +++ b/core/src/main/scala/stainless/extraction/oo/TypeOps.scala @@ -199,23 +199,13 @@ trait TypeOps extends innerfuns.TypeOps { self => leastUpperBound(t1 +: t2s) != Untyped } - private class Unsolvable extends Exception - protected def unsolvable = throw new Unsolvable - - /** Collects the constraints that need to be solved for [[unify]]. - * Note: this is an override point. */ - protected def unificationConstraints(t1: Type, t2: Type, free: Seq[TypeParameter]): List[(TypeParameter, Type)] = (t1, t2) match { + override protected def unificationConstraints(t1: Type, t2: Type, free: Seq[TypeParameter]): List[(TypeParameter, Type)] = (t1, t2) match { case (ct: ClassType, _) if ct.lookupClass.isEmpty => unsolvable case (_, ct: ClassType) if ct.lookupClass.isEmpty => unsolvable case (ta: TypeApply, _) if ta.lookupTypeDef.isEmpty => unsolvable case (_, ta: TypeApply) if ta.lookupTypeDef.isEmpty => unsolvable - case (adt: ADTType, _) if adt.lookupSort.isEmpty => unsolvable - case (_, adt: ADTType) if adt.lookupSort.isEmpty => unsolvable - - case _ if t1 == t2 => Nil - case (ct1: ClassType, ct2: ClassType) if ct1.tcd.cd == ct2.tcd.cd => (ct1.tps zip ct2.tps).toList flatMap (p => unificationConstraints(p._1, p._2, free)) @@ -228,61 +218,18 @@ trait TypeOps extends innerfuns.TypeOps { self => case (tp1, ta2: TypeApply) => unificationConstraints(tp1, ta2.bounds, free) - case (adt1: ADTType, adt2: ADTType) if adt1.id == adt2.id => - (adt1.tps zip adt2.tps).toList flatMap (p => unificationConstraints(p._1, p._2, free)) - - case (rt: RefinementType, _) => unificationConstraints(rt.getType, t2, free) - case (_, rt: RefinementType) => unificationConstraints(t1, rt.getType, free) - - case (pi: PiType, _) => unificationConstraints(pi.getType, t2, free) - case (_, pi: PiType) => unificationConstraints(t1, pi.getType, free) - - case (sigma: SigmaType, _) => unificationConstraints(sigma.getType, t2, free) - case (_, sigma: SigmaType) => unificationConstraints(t1, sigma.getType, free) - case (TypeBounds(lo, hi, _), tpe) if lo == hi => unificationConstraints(hi, tpe, free) case (tpe, TypeBounds(lo, hi, _)) if lo == hi => unificationConstraints(hi, tpe, free) - case (tp: TypeParameter, _) if !(typeOps.typeParamsOf(t2) contains tp) && (free contains tp) => List(tp -> t2) - case (_, tp: TypeParameter) if !(typeOps.typeParamsOf(t1) contains tp) && (free contains tp) => List(tp -> t1) - case (_: TypeParameter, _) => unsolvable - case (_, _: TypeParameter) => unsolvable - - case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) if ts1.size == ts2.size => - (ts1 zip ts2).toList flatMap (p => unificationConstraints(p._1, p._2, free)) - case _ => unsolvable + case _ => super.unificationConstraints(t1, t2, free) } - /** Solves the constraints collected by [[unificationConstraints]]. - * Note: this is an override point. */ - protected def unificationSolution(const: List[(Type, Type)]): List[(TypeParameter, Type)] = const match { - case Nil => Nil - case (tp: TypeParameter, t) :: tl => - val replaced = tl map { case (t1, t2) => - (typeOps.instantiateType(t1, Map(tp -> t)), typeOps.instantiateType(t2, Map(tp -> t))) - } - (tp -> t) :: unificationSolution(replaced) - case (adt: ADTType, _) :: tl if adt.lookupSort.isEmpty => unsolvable - case (_, adt: ADTType) :: tl if adt.lookupSort.isEmpty => unsolvable - case (ADTType(id1, tps1), ADTType(id2, tps2)) :: tl if id1 == id2 => - unificationSolution((tps1 zip tps2).toList ++ tl) + override protected def unificationSolution(const: List[(Type, Type)]): List[(TypeParameter, Type)] = const match { case (ct: ClassType, _) :: tl if ct.lookupClass.isEmpty => unsolvable case (_, ct: ClassType) :: tl if ct.lookupClass.isEmpty => unsolvable case (ClassType(id1, tps1), ClassType(id2, tps2)) :: tl if id1 == id2 => unificationSolution((tps1 zip tps2).toList ++ tl) - case typeOps.Same(NAryType(ts1, _), NAryType(ts2, _)) :: tl if ts1.size == ts2.size => - unificationSolution((ts1 zip ts2).toList ++ tl) - case _ => - unsolvable - } - - /** Unifies two types, under a set of free variables */ - def unify(t1: Type, t2: Type, free: Seq[TypeParameter]): Option[List[(TypeParameter, Type)]] = { - try { - Some(unificationSolution(unificationConstraints(t1, t2, free))) - } catch { - case _: Unsolvable => None - } + case _ => super.unificationSolution(const) } def patternInType(pat: Pattern): Type = pat match { diff --git a/core/src/main/scala/stainless/extraction/package.scala b/core/src/main/scala/stainless/extraction/package.scala index 4cb41a2d24..db03246555 100644 --- a/core/src/main/scala/stainless/extraction/package.scala +++ b/core/src/main/scala/stainless/extraction/package.scala @@ -54,7 +54,7 @@ package object extraction { "ChooseEncoder" -> "Encodes chooses as functions", "FunctionInlining" -> "Transitively inline marked functions", "LeonInlining" -> "Transitively inline marked functions (closer to what Leon did)", - "Trace" -> "Compare --compareFuns functions for equivalence. Expand @traceInduct", + "TraceInductElimination" -> "Expand @traceInduct", "SizedADTExtraction" -> "Transform calls to 'indexedAt' to the 'SizedADT' tree", "InductElimination" -> "Replace @induct annotation by explicit recursion", "MeasureInference" -> "Infer and inject measures in recursive functions", diff --git a/core/src/main/scala/stainless/extraction/trace/Trace.scala b/core/src/main/scala/stainless/extraction/trace/Trace.scala deleted file mode 100644 index 8e4c270dd2..0000000000 --- a/core/src/main/scala/stainless/extraction/trace/Trace.scala +++ /dev/null @@ -1,962 +0,0 @@ -/* Copyright 2009-2021 EPFL, Lausanne */ - -package stainless -package extraction -package trace - -import stainless.utils.CheckFilter - -class Trace(override val s: Trees, override val t: termination.Trees) - (using override val context: inox.Context) - extends CachingPhase - with NoSummaryPhase - with IdentityFunctions - with IdentitySorts { self => - import s._ - - override protected type TransformerContext = s.Symbols - override protected def getContext(symbols: s.Symbols) = symbols - - private[this] class Identity(override val s: self.s.type, override val t: self.t.type) extends transformers.ConcreteTreeTransformer(s, t) - private[this] val identity = new Identity(self.s, self.t) - - private def evaluate(syms: s.Symbols, expr: Expr) = { - type ProgramType = inox.Program{val trees: self.s.type; val symbols: syms.type} - val prog: ProgramType = inox.Program(self.s)(syms) - val sem = new inox.Semantics { - val trees: self.s.type = self.s - val symbols: syms.type = syms - val program: prog.type = prog - def createEvaluator(ctx: inox.Context) = ??? - def createSolver(ctx: inox.Context) = ??? - } - class EvalImpl(override val program: prog.type, override val context: inox.Context) - (using override val semantics: sem.type) - extends evaluators.RecursiveEvaluator(program, context) - with inox.evaluators.HasDefaultGlobalContext - with inox.evaluators.HasDefaultRecContext - - val evaluator = new EvalImpl(prog, self.context)(using sem) - evaluator.eval(expr) - } - - override protected def extractSymbols(context: TransformerContext, symbols: s.Symbols): (t.Symbols, AllSummaries) = { - import symbols.{given, _} - import exprOps._ - - if (Trace.getModels.isEmpty) { - val models = symbols.functions.values.toList.filter(elem => !elem.flags.exists(_.name == "library") && - isModel(elem.id)).map(elem => elem.id) - Trace.setModels(models) - Trace.nextModel - } - - if (Trace.getFunctions.isEmpty) { - val functions = symbols.functions.values.toList.filter(elem => !elem.flags.exists(_.name == "library") && - shouldBeChecked(elem.id)).map(elem => elem.id) - Trace.setFunctions(functions) - Trace.nextFunction - } - - def sameSignatures(f1: FunDef, f2: FunDef) = { - f1.params.size == f2.params.size && f1.tparams.size == f2.tparams.size && - f1.params.zip(f2.params).forall(arg => arg._1.tpe == arg._2.tpe) && - f1.returnType == f2.returnType - } - - def compatibleSignatures(f1: FunDef, f2: FunDef) = { - f1.params.size == f2.params.size && f1.tparams.size == f2.tparams.size && - f1.params.map(_.tpe).toSet.forall(t => f1.params.map(_.tpe).count(_ == t) == f2.params.map(_.tpe).count(_ == t)) && - f1.returnType == f2.returnType - } - - def sameSignaturesNorm(model: Identifier, norm: Identifier) = { - val m = symbols.functions(model) - val n = symbols.functions(norm) - - n.params.size >= 1 && n.params.init.size == m.params.size && n.tparams.size == m.tparams.size && - n.params.init.zip(m.params).forall(arg => arg._1.tpe == arg._2.tpe) - } - - if (Trace.getNorm.isEmpty) { - val normOpt = symbols.functions.values.toList.find(elem => isNorm(elem.id)).map(elem => elem.id) - - (Trace.getModel, normOpt) match { - case (Some(model), Some(norm)) if sameSignaturesNorm(model, norm) => - Trace.setNorm(normOpt) - case _ => - } - } - - symbols.functions.values.toList.foreach(fd => if (fd.flags.exists(elem => elem.name == "mkTest")) Trace.setMkTest(fd.id)) - - def generateEqLemma: List[s.FunDef] = { - - // a: argument allignment - def evalAllignment(m: FunDef, f: FunDef, a: List[Int]): Boolean = { - - val counterexamples = Trace.state.values.map(elem => elem.counterexample).filter(!_.isEmpty).map(elem => elem.get).filterNot(_.counterexample.isEmpty) - val subCounterexamples = Trace.state.values.flatMap(_.subCounterexamples) - - val allCounterexamples = (counterexamples ++ subCounterexamples) - - Trace.ordering = Trace.ordering ++ Map(f.id -> a) ++ Map(m.id -> a) - - // TODO toString - val validCounterexamples = allCounterexamples.filter { elem => - elem.counterexample.values.size == f.params.size && - m.params.map(_.tpe).map(_.toString).toList == - elem.counterexample.keys.toList.map(_.tpe).map(_.toString).toList && - elem.counterexample.keys.toList.map(_.tpe).map(_.toString).toList == - Range(0, f.params.size).map(i => f.params(a(i))).map(_.tpe).map(_.toString).toList - } - - validCounterexamples.toList.distinctBy(_.counterexample.values).take(2).forall(info => { - val pair = info - - val bval = { - type ProgramType = Program{val trees: pair.prog.trees.type; val symbols: pair.prog.symbols.type} - val prog: ProgramType = pair.prog.asInstanceOf[ProgramType] - val syms: prog.symbols.type = prog.symbols - - val sem = new inox.Semantics { - val trees: prog.trees.type = prog.trees - val symbols: syms.type = prog.symbols - val program: prog.type = prog - def createEvaluator(ctx: inox.Context) = ??? - def createSolver(ctx: inox.Context) = ??? - } - class EvalImpl(override val program: prog.type, override val context: inox.Context) - (using override val semantics: sem.type) - extends evaluators.RecursiveEvaluator(program, context) - with inox.evaluators.HasDefaultGlobalContext - with inox.evaluators.HasDefaultRecContext - val evaluator = new EvalImpl(prog, self.context)(using sem) - - Trace.ordering = Trace.ordering ++ Map(f.id -> a) ++ Map(m.id -> a) - - try { - val argsM = (m.params zip pair.counterexample).map(_._2).map(_._2) - val argsF = (f.params zip pair.counterexample).map(_._2).map(_._2) - val argsFA = Range(0, argsF.size).map(i => argsF(a(i))) - - val invocationM = evaluator.program.trees.FunctionInvocation(m.id, Seq(), argsM) - val invocationF = evaluator.program.trees.FunctionInvocation(f.id, Seq(), argsFA) - - (evaluator.eval(invocationF), evaluator.eval(invocationM)) match { - case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) => { - output == expected - } - case err => - // println(err) - true - } - } catch { - case e => - // println(e) - true - } - - } - - bval - - }) - - } - - def evalCheck(f: FunDef, m: FunDef): Boolean = { - - val counterexamples = (Trace.state.values zip Trace.state.keys).map(elem => (elem._1.counterexample, elem._2)).filter(!_._1.isEmpty).map(elem => (elem._1.get, elem._2)).filterNot(_._1.existing).filterNot(_._1.counterexample.isEmpty).filterNot(_._1.fromEval) - - def passesAllNewTests = counterexamples.forall(counterexample => { - val pair = counterexample._1 - val fun = pair.prog.symbols.functions(counterexample._2) - val mod = pair.prog.symbols.functions(Trace.state(fun.id).directModel.get) - val ref = if (pair.fromFunction) fun else mod - - val bval = { - type ProgramType = Program{val trees: pair.prog.trees.type; val symbols: pair.prog.symbols.type} - val prog: ProgramType = pair.prog.asInstanceOf[ProgramType] - val syms: prog.symbols.type = prog.symbols - - val sem = new inox.Semantics { - val trees: prog.trees.type = prog.trees - val symbols: syms.type = prog.symbols - val program: prog.type = prog - def createEvaluator(ctx: inox.Context) = ??? - def createSolver(ctx: inox.Context) = ??? - } - class EvalImpl(override val program: prog.type, override val context: inox.Context) - (using override val semantics: sem.type) - extends evaluators.RecursiveEvaluator(program, context) - with inox.evaluators.HasDefaultGlobalContext - with inox.evaluators.HasDefaultRecContext - val evaluator = new EvalImpl(prog, self.context)(using sem) - - try { - val invocationF = evaluator.program.trees.FunctionInvocation(f.id, Seq(), ref.params.map(vd => - pair.counterexample.collectFirst({ case (k, v) if(k.id.name == vd.id.name) => v }).get)) - - val invocationM = evaluator.program.trees.FunctionInvocation(m.id, Seq(), ref.params.map(vd => - pair.counterexample.collectFirst({ case (k, v) if(k.id.name == vd.id.name) => v }).get)) - - (evaluator.eval(invocationF), evaluator.eval(invocationM)) match { - case (inox.evaluators.EvaluationResults.Successful(output), inox.evaluators.EvaluationResults.Successful(expected)) => { - if(output != expected) Trace.storeCounterexample(Some(new Trace.Counterexample { - val prog = pair.prog - val counterexample = pair.counterexample.asInstanceOf[Map[this.prog.trees.ValDef,this.prog.trees.Expr]] - val existing = false - val fromEval = true - val fromFunction = pair.fromFunction - } )) - output == expected - } - case err => - // println(err) - true - } - } catch { - case e => - // println(e) - true - } - - } - bval - }) - - def passesAllTests = Trace.getMkTest match { - case Some(t) => { - val test = symbols.functions(t) - - (1 to 5).forall(i => { - val bval = { - - val getInput = s.TupleSelect(FunctionInvocation(test.id, test.tparams.map(_.tp), Seq(IntegerLiteral(i))), 1) - val getRes = s.TupleSelect(FunctionInvocation(test.id, test.tparams.map(_.tp), Seq(IntegerLiteral(i))), 2) - - (evaluate(symbols, getInput), evaluate(symbols, getRes)) match { - case (inox.evaluators.EvaluationResults.Successful(input), inox.evaluators.EvaluationResults.Successful(res)) => { - val paramVars = input match { - case Tuple(pvars) => pvars - case pvar => Seq(pvar) - } - val evalF = s.FunctionInvocation(f.id, f.tparams.map(_.tp), paramVars) - evaluate(symbols, evalF) match { - case inox.evaluators.EvaluationResults.Successful(output) => { - val counterexample = (f.params zip paramVars).toMap - if(output != res) { - val p = Trace.toCounterexample(inox.Program(self.s)(symbols))(counterexample) - Trace.storeCounterexample(p) - } - output == res - } - case _ => true - } - } - case _ => true - } - - } - bval - }) - - } - case None => { - true - } - } - - passesAllTests && passesAllNewTests - } - - // Finds all the function calls in the body of fd - def getFunCalls(fd: FunDef): List[FunDef] = { - var funs: List[Identifier] = List() - s.exprOps.preTraversal { - case fi @ s.FunctionInvocation(tfd, tps, args) if tfd != fd.id //symbols.isRecursive(tfd) - => funs = tfd::funs - case _ => - }(fd.fullBody) - funs.distinct.map(symbols.functions(_)) - } - - // f1Calls: functions that are called from f1 - // f2Calls: functions that are called from f2 - // returns a list of sublemmas for each candidate pair (same signature + name?) + replacement map - // res._1 sublemma + its sublemmas and replacement - // res._2 and res._3 map for replacement - def makeSublemmas(fd1: s.FunDef, fd2: s.FunDef): List[(List[s.FunDef], List[s.FunDef], List[s.FunDef])] = { - val f1Calls = getFunCalls(fd1).filter(!_.flags.exists(_.name == "library")) - val f2Calls = getFunCalls(fd2).filter(!_.flags.exists(_.name == "library")) - - val pairs = f1Calls zip f1Calls.map(m => f2Calls.filter(f => m != f && compatibleSignatures(m, f))) - - // maps each call from fd1 to its "best" match from fd2 - val validpairs = pairs.map(elem => (elem._1, elem._2.find(f => f.id.name == elem._1.id.name && sameSignatures(elem._1, f) && evalAllignment(elem._1, f, Range(0, elem._1.params.size).toList.permutations.toList(0))).orElse(elem._2.find(f => sameSignatures(elem._1, f) && evalAllignment(elem._1, f, Range(0, elem._1.params.size).toList.permutations.toList(0)))))).filter(elem => !elem._2.isEmpty) - - val swappairs = pairs.filter(elem => !validpairs.map(_._1).contains(elem._1)) // pairs -- validpairs - - //jump directly to the next part if it's the second time you see this fd1fd2 combo - //TODO !!! - //val validswappairs = swappairs.map(elem => (elem._1, elem._2.find(f => f.id.name == elem._1.id.name && sameSignatures(elem._1, f) && evalAllignment(elem._1, f, true)).orElse(elem._2.find(f => sameSignatures(elem._1, f) && evalAllignment(elem._1, f, true))).orElse(elem._2.find(f => evalAllignment(elem._1, f, true))))).filter(elem => !elem._2.isEmpty) - val validswappairs = swappairs.map(elem => (elem._1, elem._2.find(f => Range(0, elem._1.params.size).toList.permutations.toList.tail.exists(o => - f.params.map(_.tpe).map(_.toString).toList == - Range(0, elem._1.params.size).map(i => elem._1.params(o(i))).map(_.tpe).map(_.toString).toList && - evalAllignment(elem._1, f, o))))).filter(elem => !elem._2.isEmpty) - - validpairs.map(elem => (elem._1, elem._2) match { - case (m, Some(f)) => (equivalenceCheck(m, f, true), List(m), List(f)) - }) ++ validswappairs.map(elem => (elem._1, elem._2) match { - case (m, Some(f)) => (equivalenceCheck(m, f, true, true), List(m), List(f)) - }) - } - - def equivalenceCheck(fd1: s.FunDef, fd2: s.FunDef, sublemmaGeneration: Boolean, swapping: Boolean = false): List[s.FunDef] = { - val freshId = FreshIdentifier(CheckFilter.fixedFullName(fd1.id) + "$" + CheckFilter.fixedFullName(fd2.id)) - val eqLemma = exprOps.freshenSignature(fd1).copy(id = freshId) - - val sublemmas = if (sublemmaGeneration) makeSublemmas(fd1, fd2) else List() - - // body of fd2, with calls to subfunctions replaced - val replacement: List[FunDef] = sublemmas match { - case Nil => List() - case _ => - val sm = sublemmas.map(_._2).flatten - val sf = sublemmas.map(_._3).flatten - List(inductPattern(symbols, fd2, fd2, "replacement", (sf zip sm).toMap).setPos(fd2.getPos).copy(flags = Seq(s.Derived(Some(fd2.id))))) - } - - val newParamTps = eqLemma.tparams.map{tparam => tparam.tp} - val newParamVars = eqLemma.params.map{param => param.toVariable} - - val fdSpecs = if(Trace.funFirst) fd2 else fd1 - - val subst = (fdSpecs.params.map(_.id) zip newParamVars).toMap - val tsubst = (fdSpecs.tparams zip newParamTps).map { case (tparam, targ) => tparam.tp.id -> targ }.toMap - val specializer = new Specializer(eqLemma, eqLemma.id, tsubst, subst, Map(), symbols) - - val specs = BodyWithSpecs(fdSpecs.fullBody).specs.filter(s => s.kind == LetKind || s.kind == PreconditionKind) - val pre = specs.map(spec => spec match { - case Precondition(cond) => Precondition(specializer.transform(cond)) - case LetInSpec(vd, expr) => LetInSpec(vd, specializer.transform(expr)) - }) - - val newParamVars2 = Trace.ordering.get(fd1.id) match { - case Some(order) => - Range(0, newParamVars.size).map(i => newParamVars(order(i))) - case None => newParamVars - } - - val fun1 = s.FunctionInvocation(fd1.id, newParamTps, newParamVars) - val fun2 = replacement match { - case Nil => - s.FunctionInvocation(fd2.id, newParamTps, newParamVars2) - case h::t => - s.FunctionInvocation(h.id, newParamTps, newParamVars2) - } - - val (normFun1, normFun2) = Trace.getNorm match { - case Some(n) if (sameSignaturesNorm(fun1.id, n)) => ( - s.FunctionInvocation(n, newParamTps, newParamVars :+ fun1), - s.FunctionInvocation(n, newParamTps, newParamVars2 :+ fun2)) - case _ => (fun1, fun2) - } - - val res = s.ValDef.fresh("res", s.UnitType()) - val cond = s.Equals(normFun1, normFun2) - val post = Postcondition(Lambda(Seq(res), cond)) - val body = s.UnitLiteral() - val withPre = exprOps.reconstructSpecs(pre, Some(body), s.UnitType()) - - // return the @traceInduct annotated eqLemma - // + potential sublemmas - // + the coressponding replacement functions - (eqLemma.copy( - fullBody = BodyWithSpecs(withPre).withSpec(post).reconstructed, - flags = Seq(s.Derived(Some(fd1.id)), s.Annotation("traceInduct",List(StringLiteral(fd1.id.name)))), - returnType = s.UnitType() - ).copiedFrom(eqLemma) :: sublemmas.flatMap(_._1)) ++ replacement //++ sublemmas.map(_._2).flatten - } - - (Trace.getModel, Trace.getFunction) match { - case (Some(model), Some(function)) => { - val m = symbols.functions(model) - val f = symbols.functions(function) - - Trace.nextEqCheckState - - if (m.params.size == f.params.size && evalCheck(f, m)) { - val res: List[s.FunDef] = Trace.eqCheckState match { - case Trace.EqCheckState.ModelFirst => - equivalenceCheck(m, f, false) - case Trace.EqCheckState.FunFirst => - equivalenceCheck(f, m, false) - case Trace.EqCheckState.ModelFirstWithSublemmas => - equivalenceCheck(m, f, true) - case Trace.EqCheckState.FunFirstWithSublemmas => - equivalenceCheck(f, m, true) - } - - res match { - case t::sublemmas => - Trace.setTrace(t.id) - Trace.sublemmas = sublemmas.map(_.id) - case _ => - } - res - } - else { - Trace.resetTrace - Trace.resetEqCheckState - List() - } - } - case _ => List() - } - } - - val generatedFunctions = generateEqLemma - val functions = generatedFunctions ++ symbols.functions.values.toList - - val inductFuns = functions.toList.flatMap(fd => if (fd.flags.exists(elem => elem.name == "traceInduct")) { - // find the model for fd - var funInv: Option[s.FunctionInvocation] = None - fd.flags.filter(elem => elem.name == "traceInduct").head match { - case s.Annotation("traceInduct", fun) => { - BodyWithSpecs(fd.fullBody).getSpec(PostconditionKind) match { - case Some(Postcondition(post)) => - s.exprOps.preTraversal { - case _ if funInv.isDefined => // do nothing - case fi @ s.FunctionInvocation(tfd, tps, args) if symbols.isRecursive(tfd) && (fun.contains(StringLiteral(tfd.name)) || fun.contains(StringLiteral(""))) - => { - val paramVars = fd.params.map(_.toVariable) - val argCheck = args.forall(paramVars.contains) && args.toSet.size == args.size - if (argCheck) funInv = Some(fi) - } - case _ => - }(post) - case _ => - } - } - } - - funInv match { - case Some(finv) => { - // make a helper lemma: - val helper = inductPattern(symbols, symbols.functions(finv.id), fd, "indProof", Map()).setPos(fd.getPos) - - // transform the main lemma - val proof = FunctionInvocation(helper.id, finv.tps, fd.params.map(_.toVariable)) - val returnType = typeOps.instantiateType(helper.returnType, (helper.typeArgs zip fd.typeArgs).toMap) - - val body = Let(s.ValDef.fresh("ind$proof", returnType), proof, exprOps.withoutSpecs(fd.fullBody).get) - val withPre = exprOps.reconstructSpecs(BodyWithSpecs(fd.fullBody).specs, Some(body), fd.returnType) - - val lemma = fd.copy( - fullBody = BodyWithSpecs(withPre).reconstructed, - flags = (s.Derived(Some(fd.id)) +: s.Derived(Some(finv.id)) +: (fd.flags.filterNot(f => f.name == "traceInduct"))).distinct - ).copiedFrom(fd).setPos(fd.getPos) - - Trace.getTrace match { - case Some(t) if(t == lemma.id) => Trace.setProof(helper.id) - case _ => - } - - if (Trace.sublemmas.contains(lemma.id)) Trace.sublemmas = helper.id :: Trace.sublemmas - - List(helper, lemma) - } - - case None => { // there are no recursive calls - no model function - val lemma = fd.copy( - flags = (s.Derived(Some(fd.id)) +: (fd.flags.filterNot(f => f.name == "traceInduct"))) - ).copiedFrom(fd).setPos(fd.getPos) - List(lemma) - } - } - } else List()) - - val (extractedSymbols, summary) = super.extractSymbols(context, symbols) - - val extracted = t.NoSymbols - .withSorts(extractedSymbols.sorts.values.toSeq) - .withFunctions((generatedFunctions.map(fun => identity.transform(fun)) ++ extractedSymbols.functions.values).filterNot(fd => fd.flags.exists(elem => elem.name == "traceInduct")).toSeq) - - (registerFunctions(extracted, inductFuns.map(fun => identity.transform(fun))), summary) - } - - // make a copy of the 'model' - // combine the specs of the 'lemma' - // 'suffix': only used for naming - // 'replacement': function calls that are supposed to be replaced according to given mapping - def inductPattern(symbols: s.Symbols, model: FunDef, lemma: FunDef, suffix: String, replacement: Map[s.FunDef, s.FunDef]) = { - import symbols.{given, _} - import exprOps._ - - val indPattern = exprOps.freshenSignature(model).copy(id = FreshIdentifier(s"${lemma.id}$$$suffix")) - val newParamTps = indPattern.tparams.map{tparam => tparam.tp} - val newParamVars = indPattern.params.map{param => param.toVariable} - - val fi = FunctionInvocation(model.id, newParamTps, newParamVars) - - val tpairs = model.tparams zip fi.tps - val tsubst = tpairs.map { case (tparam, targ) => tparam.tp.id -> targ } .toMap - val subst = (model.params.map(_.id) zip fi.args).toMap - val specializer = new Specializer(model, indPattern.id, tsubst, subst, replacement, symbols) - - val fullBodySpecialized = specializer.transform(exprOps.withoutSpecs(model.fullBody).get) - - val specsSubst = (lemma.params.map(_.id) zip newParamVars).toMap ++ (model.params.map(_.id) zip newParamVars).toMap - val specsTsubst = ((lemma.tparams zip fi.tps) ++ (model.tparams zip fi.tps)).map { case (tparam, targ) => tparam.tp.id -> targ }.toMap - val specsSpecializer = new Specializer(indPattern, indPattern.id, specsTsubst, specsSubst, Map(), symbols) - - val specs = BodyWithSpecs(model.fullBody).specs ++ BodyWithSpecs(lemma.fullBody).specs.filterNot(_.kind == MeasureKind) - val pre = specs.filterNot(_.kind == PostconditionKind).map(spec => spec match { - case Precondition(cond) => Precondition(specsSpecializer.transform(cond)).setPos(spec) - case LetInSpec(vd, expr) => LetInSpec(vd, specsSpecializer.transform(expr)).setPos(spec) - case Measure(measure) => Measure(specsSpecializer.transform(measure)).setPos(spec) - case s => context.reporter.fatalError(s"Unsupported specs: $s") - }) - - val withPre = exprOps.reconstructSpecs(pre, Some(fullBodySpecialized), indPattern.returnType) - - val speccedLemma = BodyWithSpecs(lemma.fullBody).addPost - val speccedOrig = BodyWithSpecs(model.fullBody).addPost - val postLemma = speccedLemma.getSpec(PostconditionKind).map(post => - specsSpecializer.transform(post.expr)) - val postOrig = speccedOrig.getSpec(PostconditionKind).map(post => specsSpecializer.transform(post.expr)) - - (postLemma, postOrig) match { - case (Some(Lambda(Seq(res1), cond1)), Some(Lambda(Seq(res2), cond2))) => - val res = ValDef.fresh("res", indPattern.returnType) - val freshCond1 = exprOps.replaceFromSymbols(Map(res1 -> res.toVariable), cond1) - val freshCond2 = exprOps.replaceFromSymbols(Map(res2 -> res.toVariable), cond2) - - val cond = andJoin(Seq(freshCond1, freshCond2)) - val post = Postcondition(Lambda(Seq(res), cond)) - - indPattern.copy( - fullBody = BodyWithSpecs(withPre).withSpec(post).reconstructed, - flags = Seq(s.Derived(Some(lemma.id)), s.Derived(Some(model.id))) - ).copiedFrom(indPattern) - - case _ => indPattern - } - - } - - class Specializer( - origFd: FunDef, - newId: Identifier, - tsubst: Map[Identifier, Type], - vsubst: Map[Identifier, Expr], - replacement: Map[s.FunDef, s.FunDef], // replace function calls - symbols: s.Symbols - ) extends s.ConcreteSelfTreeTransformer { slf => - - def sameSignatures(f1: FunDef, f2: FunDef) = { - f1.params.size == f2.params.size && f1.tparams.size == f2.tparams.size && - f1.params.zip(f2.params).forall(arg => arg._1.tpe == arg._2.tpe) && - f1.returnType == f2.returnType - } - - override def transform(expr: slf.s.Expr): slf.t.Expr = expr match { - case v: Variable => - vsubst.getOrElse(v.id, super.transform(v)) - - case fi: FunctionInvocation if fi.id == origFd.id => - val fi1 = FunctionInvocation(newId, tps = fi.tps, args = fi.args) - super.transform(fi1.copiedFrom(fi)) - - //f1(a, b) -> f2(b, a) - // only complicate if arg signatures do not match ! - case fi @ FunctionInvocation(tfd, tps, args) if replacement.keys.exists(elem => elem.id == fi.id) => - val replacement_fd = replacement.find((k, v) => k.id == fi.id).get._2 - val replacement_id = replacement_fd.id - val fd = symbols.functions(fi.id) - val fi1 = //if (checkArgs(replacement_fd, fd)) - //FunctionInvocation(replacement_id, tps = fi.tps, args = fi.args) - val order = Trace.ordering(fi.id) - - val argsPermutations = fi.args.permutations - val a = argsPermutations.find(a => Range(0, fi.args.size).map(i => a(order(i))) == fi.args).getOrElse(fi.args) - FunctionInvocation(replacement_id, tps = fi.tps, args = a) - super.transform(fi1.copiedFrom(fi)) - - case _ => super.transform(expr) - } - - override def transform(tpe: slf.s.Type): slf.t.Type = tpe match { - case tp: TypeParameter => - tsubst.getOrElse(tp.id, super.transform(tp)) - - case _ => super.transform(tpe) - } - } - - type Path = Seq[String] - - private lazy val pathsOpt: Option[Seq[Path]] = context.options.findOption(optCompareFuns) map { functions => - functions map CheckFilter.fullNameToPath - } - - private lazy val pathsOptModels: Option[Seq[Path]] = context.options.findOption(optModels) map { functions => - functions map CheckFilter.fullNameToPath - } - - private lazy val pathsOptNorm: Option[Seq[Path]] = - Some(Seq(context.options.findOptionOrDefault(optNorm)).map(CheckFilter.fullNameToPath)) - - private def shouldBeChecked(fid: Identifier): Boolean = shouldBeChecked(pathsOpt, fid) - private def isModel(fid: Identifier): Boolean = shouldBeChecked(pathsOptModels, fid) - private def isNorm(fid: Identifier): Boolean = shouldBeChecked(pathsOptNorm, fid) - - private def shouldBeChecked(paths: Option[Seq[Path]], fid: Identifier): Boolean = paths match { - case None => false - - case Some(paths) => - // Support wildcard `_` as specified in the documentation. - // A leading wildcard is always assumed. - val path: Path = CheckFilter.fullNameToPath(CheckFilter.fixedFullName(fid)) - paths exists { p => - if (p endsWith Seq("_")) path containsSlice p.init - else path endsWith p - } - } -} - -object Trace { - var clusters: Map[Identifier, List[Identifier]] = Map() - var errors: List[Identifier] = List() // counter-example is found - var unknowns: List[Identifier] = List() // timeout - var wrong: List[Identifier] = List() // bad signature - - var allModels: Map[Identifier, Int] = Map() - var tmpModels: List[Identifier] = List() - - var allFunctions: List[Identifier] = List() - var tmpFunctions: List[Identifier] = List() - - var model: Option[Identifier] = None - var function: Option[Identifier] = None - var norm: Option[Identifier] = None - var trace: Option[Identifier] = None - var proof: Option[Identifier] = None - var sublemmas: List[Identifier] = List() - var mkTest: Option[Identifier] = None - - var sublemmaGeneration: Boolean = false - - case class State(var directModel: Option[Identifier], var counterexample: Option[Counterexample], var prevModels: List[Identifier], var subCounterexamples: List[Counterexample]) - - var state: Map[Identifier, State] = Map() - - object EqCheckState extends Enumeration { - type EqCheckState = Value - val InitState, ModelFirst, FunFirst, ModelFirstWithSublemmas, FunFirstWithSublemmas = Value - } - - var eqCheckState = EqCheckState.InitState - - def nextEqCheckState: Unit = eqCheckState = eqCheckState match { - case EqCheckState.InitState => EqCheckState.ModelFirst - case EqCheckState.ModelFirst => EqCheckState.FunFirst - case EqCheckState.FunFirst => EqCheckState.ModelFirstWithSublemmas - case EqCheckState.ModelFirstWithSublemmas => EqCheckState.FunFirstWithSublemmas - case EqCheckState.FunFirstWithSublemmas => EqCheckState.InitState - } - - def resetEqCheckState = eqCheckState = EqCheckState.InitState - def isFinalEqCheckState = eqCheckState == EqCheckState.FunFirstWithSublemmas - - def funFirst = eqCheckState == EqCheckState.FunFirst || eqCheckState == EqCheckState.FunFirstWithSublemmas - def withSublemmas = eqCheckState == EqCheckState.ModelFirstWithSublemmas || eqCheckState == EqCheckState.FunFirstWithSublemmas - - var ordering: Map[Identifier, List[Int]] = Map() - var cnt = 0 - - def apply(ts: Trees, tt: termination.Trees)(using inox.Context): ExtractionPipeline { - val s: ts.type - val t: tt.type - } = { - class Impl(override val s: ts.type, override val t: tt.type) extends Trace(s, t) - new Impl(ts, tt) - } - - def setModels(m: List[Identifier]) = { - allModels = m.map(elem => (elem, 200)).toMap - tmpModels = m - clusters = (m zip m.map(_ => Nil)).toMap - state = state ++ (m zip m.map(_ => State(None, None, List(), List()))).toMap - } - - def setFunctions(f: List[Identifier]) = { - allFunctions = f - tmpFunctions = f - cnt = f.size - state = state ++ (f zip f.map(_ => State(None, None, List(), List()))).toMap - } - - def getModels = allModels - def getFunctions = allFunctions - def getModel = model // model for the current iteration - def getFunction = function // function to check in the current iteration - def getNorm = norm - def getMkTest = mkTest - def getTrace = trace - - def setTrace(t: Identifier) = { - proof = None - trace = Some(t) - state(function.get).prevModels = model.get :: state(function.get).prevModels - } - - def resetTrace = { - trace = None - proof = None - sublemmas = List() - } - - def setProof(p: Identifier) = proof = Some(p) - def setNorm(n: Option[Identifier]) = norm = n - def setMkTest(t: Identifier) = mkTest = Some(t) - - // iterate model for the current function - private def nextModel = tmpModels match { - case x::xs => { - tmpModels = xs - model = Some(x) - } - case Nil => model = None - } - - // iterate function to check; reset model - private def nextFunction = { - trace = None - proof = None - tmpFunctions match { - case x::xs => { - val n = 3 - tmpModels = allModels.toList.sortBy(m => -m._2).map(_._1).filterNot(state(x).prevModels.contains).take(n) - - if(tmpModels.isEmpty) tmpModels = allModels.keys.take(1).toList - nextModel - tmpFunctions = xs - function = Some(x) - } - case Nil => function = None - } - } - - private def isDone = function == None - - trait Counterexample { - val prog: inox.Program - val counterexample: Map[prog.trees.ValDef, prog.trees.Expr] - val existing: Boolean - val fromFunction: Boolean - val fromEval: Boolean - } - - var tmpCounterexample: Option[Counterexample] = None - var tmpSubCounterexample: Option[Counterexample] = None - - def toCounterexample(pr: inox.Program)(counterex: Map[pr.trees.ValDef, pr.trees.Expr]): Option[Counterexample] = { - Some(new Counterexample { - val prog: pr.type = pr - val counterexample = counterex - val existing = true - val fromEval = true - val fromFunction = false - }) - } - - def reportCounterexample(pr: inox.Program)(counterex: pr.Model)(fun: Identifier): Unit = { - def isMainCounterexample(fun: Identifier) = { - !function.isEmpty && function.get == fun || - !proof.isEmpty && proof.get == fun || - !trace.isEmpty && trace.get == fun - } - - def isSubCounterexample(fun: Identifier) = { - sublemmas.contains(fun) - } - - val c = Some(new Counterexample { - val prog: pr.type = pr - val counterexample = counterex.vars - val existing = false - val fromEval = false - val fromFunction = funFirst || (function != None && function.get == fun) - }) - - if (isMainCounterexample(fun)) tmpCounterexample = c - else if (isSubCounterexample(fun)) tmpSubCounterexample = c - } - - var counter = 0 - var sublemmacounter = 0 - var flippedcounter = 0 - var valid = 0 - - def nextIteration[T <: AbstractReport[T]](report: AbstractReport[T])(implicit context: inox.Context): Boolean = { - counter = counter + 1 - if (counter % 30 == 0) printEverything - - val sublemmasAreValid = sublemmas.forall(s => !report.hasError(Some(s)) && !report.hasUnknown(Some(s))) - val sublemmasHaveErrors = sublemmas.exists(s => report.hasError(Some(s))) - - (function, trace) match { - case (Some(f), Some(t)) => { - tmpSubCounterexample match { - case Some(c) => state(function.get).subCounterexamples = c::state(function.get).subCounterexamples - case None => - } - if (report.hasError(function) || report.hasError(proof) || report.hasError(trace)) { - if (!withSublemmas) reportError(tmpCounterexample) // only if not in the sublemma state - else reportUnknown - } - else if (report.hasUnknown(function) || report.hasUnknown(proof) || report.hasUnknown(trace)) reportUnknown - else if (sublemmasAreValid) reportValid - else reportUnknown - } - case (Some(f), _) if(state(f).counterexample != None) => - reportError(state(f).counterexample) - counter = counter - 1 - case _ => reportWrong - } - - if(isDone && unknowns.size < cnt) { - cnt = unknowns.size - tmpModels = allModels.keys.toList - tmpFunctions = unknowns.reverse - unknowns = List() - nextFunction - } - // if(isDone){ - // println("COUNTER") - // println(counter) - // println("SUBLEMMA COUNTER") - // println(sublemmacounter) - // println("COUNTER valids") - // println(valid) - // } - !isDone - } - - private def storeCounterexample(counterexample: Option[Counterexample]) = { - state(function.get).counterexample = counterexample - } - - private def reportError[T](counterexample: Option[Counterexample]) = { - resetEqCheckState - errors = function.get::errors // store the counter-example - noLongerUnknown(function.get) - state(function.get).directModel = model - state(function.get).counterexample = counterexample - nextFunction - } - - // if there is a new state go there, otherwise report as unknown - private def reportUnknown = { - allModels = allModels.updated(model.get, allModels(model.get) - 1) - if (isFinalEqCheckState) { - resetEqCheckState - nextModel - if (model == None) { - unknowns = function.get::unknowns - nextFunction - } - } - else { - //nextEqCheckState - } - } - - private def reportValid = { - if(withSublemmas) sublemmacounter = sublemmacounter + 1 - if(funFirst) flippedcounter = flippedcounter + 1 - resetEqCheckState - - if (!allModels.keys.toList.contains(function.get)) { - state(function.get).directModel = model - - val inc = if (allModels(model.get) > 0) 20 else 100 - allModels = allModels.updated(model.get, allModels(model.get) + inc) - allModels = (allModels + (function.get -> 0)) - - clusters = clusters + (function.get -> List()) - } - - clusters = clusters + (model.get -> (function.get::clusters.getOrElse(model.get, List()))) - noLongerUnknown(function.get) - valid = valid + 1 - nextFunction - } - - private def reportWrong = { - resetEqCheckState - if (function != None) { - wrong = function.get::wrong - noLongerUnknown(function.get) - } - resetTrace - nextFunction - } - - private def noLongerUnknown(f: Identifier) = unknowns = unknowns.filterNot(elem => elem == f) - - def optionsError(using ctx: inox.Context): Boolean = - !ctx.options.findOptionOrDefault(frontend.optBatchedProgram) && - (!ctx.options.findOptionOrDefault(optModels).isEmpty || !ctx.options.findOptionOrDefault(optCompareFuns).isEmpty) - - def printEverything(using ctx: inox.Context) = { - import ctx.{ reporter, timers } - if(!clusters.isEmpty || !errors.isEmpty || !unknowns.isEmpty || !wrong.isEmpty) { - reporter.info(s"Printing equivalence checking results:") - allModels.keys.foreach(model => if (!clusters(model).isEmpty) { - val l = clusters(model).map(CheckFilter.fixedFullName).mkString(", ") - val m = CheckFilter.fixedFullName(model) - reporter.info(s"List of functions that are equivalent to model $m: $l") - }) - - val errorneous = errors.map(CheckFilter.fixedFullName).mkString(", ") - reporter.info(s"List of erroneous functions: $errorneous") - val timeouts = unknowns.map(CheckFilter.fixedFullName).mkString(", ") - reporter.info(s"List of timed-out functions: $timeouts") - val wrongs = wrong.map(CheckFilter.fixedFullName).mkString(", ") - reporter.info(s"List of wrong functions: $wrongs") - - reporter.info(s"Printing the final state:") - allFunctions.foreach(f => { - val l = path(f).map(CheckFilter.fixedFullName).mkString(", ") - val m = CheckFilter.fixedFullName(f) - reporter.info(s"Path for the function $m: $l") - }) - - def path(f: Identifier): List[Identifier] = { - val m = state(f).directModel - m match { - case Some(mm) => mm :: path(mm) - case None => List() - } - } - - allFunctions.foreach(f => { - state(f).counterexample match { - case None => //None - case Some(c) => - val m = CheckFilter.fixedFullName(f) - val ce = c.counterexample.map((k, v) => (k.id, v)) - val fe = c.fromEval - reporter.info(s"Counterexample for the function $m: $ce, $fe") - } - }) - - } - - } - -} \ No newline at end of file diff --git a/core/src/main/scala/stainless/extraction/trace/TraceInductElimination.scala b/core/src/main/scala/stainless/extraction/trace/TraceInductElimination.scala new file mode 100644 index 0000000000..b49771abfa --- /dev/null +++ b/core/src/main/scala/stainless/extraction/trace/TraceInductElimination.scala @@ -0,0 +1,44 @@ +package stainless +package extraction +package trace + +import stainless.equivchk.Utils + +class TraceInductElimination(override val s: Trees, override val t: termination.Trees) + (using override val context: inox.Context) + extends CachingPhase + with NoSummaryPhase + with SimplyCachedFunctions + with IdentitySorts + with Utils { self => + import s._ + override val trees: self.s.type = s + + override protected type TransformerContext = s.Symbols + override protected def getContext(symbols: s.Symbols) = symbols + + override protected def registerFunctions(symbols: t.Symbols, functions: Seq[Seq[t.FunDef]]): t.Symbols = + symbols.withFunctions(functions.flatten) + + override protected type FunctionResult = Seq[t.FunDef] + override protected def extractFunction(symbols: TransformerContext, fd: FunDef): (FunctionResult, Unit) = { + val fns = elimTraceInduct(symbols, fd) + .map(res => res.updatedFd +: res.helper.toSeq) + .getOrElse(Seq(fd)) + .map(identity.transform) + (fns, ()) + } + + private[this] class Identity(override val s: self.s.type, override val t: self.t.type) extends transformers.ConcreteTreeTransformer(s, t) + private[this] val identity = new Identity(self.s, self.t) +} + +object TraceInductElimination { + def apply(ts: Trees, tt: termination.Trees)(using inox.Context): ExtractionPipeline { + val s: ts.type + val t: tt.type + } = { + class Impl(override val s: ts.type, override val t: tt.type) extends TraceInductElimination(s, t) + new Impl(ts, tt) + } +} \ No newline at end of file diff --git a/core/src/main/scala/stainless/extraction/trace/package.scala b/core/src/main/scala/stainless/extraction/trace/package.scala index 044dfd0a05..3d36e8dfde 100644 --- a/core/src/main/scala/stainless/extraction/trace/package.scala +++ b/core/src/main/scala/stainless/extraction/trace/package.scala @@ -21,7 +21,7 @@ package object trace { } def extractor(using inox.Context) = { - utils.DebugPipeline("Trace", Trace(trees, termination.trees)) + utils.DebugPipeline("TraceInductElimination", TraceInductElimination(trees, termination.trees)) } def fullExtractor(using inox.Context) = extractor andThen nextExtractor diff --git a/core/src/main/scala/stainless/frontend/BatchedCallBack.scala b/core/src/main/scala/stainless/frontend/BatchedCallBack.scala index 88d9c5488b..c17fa93df0 100644 --- a/core/src/main/scala/stainless/frontend/BatchedCallBack.scala +++ b/core/src/main/scala/stainless/frontend/BatchedCallBack.scala @@ -6,7 +6,6 @@ package frontend import stainless.extraction.xlang.{trees => xt, TreeSanitizer} import stainless.extraction.utils.DebugSymbols import stainless.utils.LibraryFilter -import stainless.extraction.trace.Trace import scala.util.{Try, Success, Failure} import scala.concurrent.Await @@ -101,18 +100,12 @@ class BatchedCallBack(components: Seq[Component])(using val context: inox.Contex reporter.debug(e) reportError(defn.getPos, e.getMessage, symbols) } - - var rerunPipeline = true - while (rerunPipeline) { - val reports = runs map { run => - val ids = symbols.functions.keys.toSeq - val analysis = Await.result(run(ids, symbols), Duration.Inf) - RunReport(run)(analysis.toReport) - } - report = Report(reports) - rerunPipeline = Trace.nextIteration(report) - if (!rerunPipeline) Trace.printEverything + val reports = runs map { run => + val ids = symbols.functions.keys.toSeq + val analysis = Await.result(run(ids, symbols), Duration.Inf) + RunReport(run)(analysis.toReport) } + report = Report(reports) } def stop(): Unit = { diff --git a/core/src/main/scala/stainless/frontend/package.scala b/core/src/main/scala/stainless/frontend/package.scala index 86f2e89120..f3df889373 100644 --- a/core/src/main/scala/stainless/frontend/package.scala +++ b/core/src/main/scala/stainless/frontend/package.scala @@ -43,6 +43,7 @@ package object frontend { genc.GenCComponent, testgen.ScalaTestGenComponent, testgen.GenCTestGenComponent, + equivchk.EquivalenceCheckingComponent ) /** @@ -78,7 +79,7 @@ package object frontend { private def batchSymbols(activeComponents: Seq[Component])(using ctx: inox.Context): Boolean = { ctx.options.findOptionOrDefault(optBatchedProgram) || - activeComponents.exists(Set(genc.GenCComponent, testgen.ScalaTestGenComponent, testgen.GenCTestGenComponent).contains) || + activeComponents.exists(Set(genc.GenCComponent, testgen.ScalaTestGenComponent, testgen.GenCTestGenComponent, equivchk.EquivalenceCheckingComponent).contains) || ctx.options.findOptionOrDefault(optKeep).nonEmpty } diff --git a/core/src/main/scala/stainless/package.scala b/core/src/main/scala/stainless/package.scala index 68669a0399..55c7497291 100644 --- a/core/src/main/scala/stainless/package.scala +++ b/core/src/main/scala/stainless/package.scala @@ -157,10 +157,8 @@ package object stainless { lazy val useParallelism: Boolean = nParallel.isEmpty || nParallel.exists(_ > 1) - private lazy val currentThreadExecutionContext: ExecutionContext = - ExecutionContext.fromExecutor(new java.util.concurrent.Executor { - def execute(runnable: Runnable): Unit = { runnable.run() } - }) + private lazy val singleThreadExecutionContext: ExecutionContext = + ExecutionContext.fromExecutor(Executors.newFixedThreadPool(1)) private lazy val multiThreadedExecutor: java.util.concurrent.ExecutorService = nParallel.map(Executors.newFixedThreadPool(_)).getOrElse(ForkJoinTasks.defaultForkJoinPool) @@ -169,7 +167,7 @@ package object stainless { implicit def executionContext(using ctx: inox.Context): ExecutionContext = if (useParallelism && ctx.reporter.debugSections.isEmpty) multiThreadedExecutionContext - else currentThreadExecutionContext + else singleThreadExecutionContext def shutdown(): Unit = if (useParallelism) multiThreadedExecutor.shutdown() } diff --git a/core/src/main/scala/stainless/testgen/XLangTestGen.scala b/core/src/main/scala/stainless/testgen/XLangTestGen.scala index 0f6c5df021..462fec7148 100644 --- a/core/src/main/scala/stainless/testgen/XLangTestGen.scala +++ b/core/src/main/scala/stainless/testgen/XLangTestGen.scala @@ -2,13 +2,12 @@ package stainless package testgen import stainless.ast.SymbolIdentifier -import stainless.extraction.xlang.trees as xt -import stainless.verification.* +import stainless.extraction.xlang.{trees => xt} +import stainless.verification._ -object XLangTestGen { - import stainless.trees._ - - case class TyParamSubst(numericTpe: Type, numericExprCtor: BigInt => Option[Expr]) +object XLangTestGen extends utils.CtexRemapping { + override val trees: stainless.trees.type = stainless.trees + import trees._ case class TestCaseCtx(origSyms: xt.Symbols, p: inox.Program {val trees: stainless.trees.type}, @@ -64,45 +63,28 @@ object XLangTestGen { return None } - val vdMapping: Map[ValDef, (ValDef, Int)] = cex.vars.keySet.flatMap(cexVd => - faultyFd.params.zipWithIndex.find(_._1.id.name == cexVd.id.name).map(fnVd => cexVd -> fnVd)).toMap - if (vdMapping.size != cex.vars.size) { - recognizeFailure(s"due to the inability of reconciling the variables of the model and the parameters of ${vc.fid}") - return None - } - if (vdMapping.keys.size != cex.vars.size) { - recognizeFailure(s"due to name collision with the parameters of ${vc.fid}") - return None - } - - val missingMapping: Set[(ValDef, Int)] = faultyFd.params.zipWithIndex.toSet -- vdMapping.values.toSet - val (defaultArgsErrs, defaultArgs) = missingMapping.toSeq.partitionMap { - case (fnVd, pos) => - val expr = tryFindDefaultValue(substTypeParams(fnVd.tpe)) - expr.map(e => pos -> e) - } - if (defaultArgsErrs.nonEmpty) { - recognizeFailure( - s"""due to: - |${defaultArgsErrs.flatten.mkString(" - ", "\n", "")}""".stripMargin) - return None - } - - val (cexArgsErr, cexArgs) = cex.vars.toSeq.partitionMap { - case (cexVd, expr) => - val pos = vdMapping(cexVd)._2 - val rewrittenExpr = tryRewriteExpr(expr) - rewrittenExpr.map(e => pos -> e) - } - if (cexArgsErr.nonEmpty) { - recognizeFailure( - s"""due to: - |${cexArgsErr.flatten.mkString(" - ", "\n", "")}""".stripMargin) - return None - } - val args = (defaultArgs ++ cexArgs).sortBy(_._1).map(_._2) - - val testCaseBody = FunctionInvocation(vc.fid, faultyFd.tparams.map(_ => IntegerType()), args) + val (args, newTparams) = tryRemapCtex(faultyFd.params, faultyFd.tparams, tcCtx.p)(tcCtx.cex.vars) match { + case RemappedCtex.Success(args, newTparams) => (args, newTparams) + case RemappedCtex.Failure(FailureReason.NonUniqueParams(name, vds)) => + recognizeFailure(s"due to name collision with the parameters of ${faultyFd.id.name}: $name has multiple correspondances $vds") + return None + case RemappedCtex.Failure(FailureReason.NonUniqueCtexVars(name, vds)) => + recognizeFailure(s"due to name collision with the variables of the counter-examples: $name has multiple correspondances $vds") + return None + case RemappedCtex.Failure(FailureReason.NonUniqueTypeParams(name, tpds)) => + recognizeFailure(s"due to name collision with the type parameters of ${faultyFd.id.name}: $name has multiple correspondances $tpds") + return None + case RemappedCtex.Failure(FailureReason.UnmappedCtexVars(vds)) => + recognizeFailure(s"due to counter-example having extra unknown variables: $vds") + return None + case RemappedCtex.Failure(FailureReason.ExprRewrite(exprs)) => + recognizeFailure(s"the inability of rewriting ${exprs.mkString(", ")}") + return None + case RemappedCtex.Failure(FailureReason.NoSimpleValue(tps)) => + recognizeFailure(s"the inability of find a default values for ${tps.mkString(", ")}") + return None + } + val testCaseBody = FunctionInvocation(vc.fid, newTparams, args) val testCaseId = SymbolIdentifier(s"testCase$testCaseNbr") val testCaseFd = FunDef(testCaseId, Seq.empty, Seq.empty, UnitType(), testCaseBody, Seq.empty) @@ -128,28 +110,6 @@ object XLangTestGen { |$msg""".stripMargin) } - def tryFindDefaultValue(tpe: Type)(using faultyFd: FaultyFd, tcCtx: TestCaseCtx, tps: TyParamSubst, ctx: inox.Context): Either[List[String], Expr] = { - try { - tryRewriteExpr(tcCtx.p.symbols.simplestValue(tpe, allowSolver = false)(using tcCtx.p.getSemantics)) - } catch { - case _: tcCtx.p.symbols.NoSimpleValue => Left(s"the inability of finding a default value for $tpe" :: Nil) - } - } - - def substTypeParams(tpe: Type)(using faultyFd: FaultyFd, tps: TyParamSubst): Type = { - val subst = new SubstTypeParamsAndGenericValues(faultyFd.fd.tparams) - val newTpe = subst.transform(tpe) - assert(subst.failures.isEmpty, "failures are impossible when substituting for types (only for Expr)") - newTpe - } - - def tryRewriteExpr(expr: Expr)(using faultyFd: FaultyFd, tps: TyParamSubst): Either[List[String], Expr] = { - val subst = new SubstTypeParamsAndGenericValues(faultyFd.fd.tparams) - val newExpr = subst.transform(expr) - if (subst.failures.isEmpty) Right(newExpr) - else Left(subst.failures.map(e => s"the inability of rewriting $e")) - } - ///////////////////////////////////////////////////////////////////////////////////////////////// private val unlowering = new UnloweringImpl(trees, xt) @@ -189,31 +149,4 @@ object XLangTestGen { .map((clsId, _) => xt.ClassType(clsId, adtType.tps.map(transform))) } } - - private class SubstTypeParamsAndGenericValues(tparams: Seq[TypeParameterDef]) - (using tyParamSubst: TyParamSubst) extends ConcreteSelfTreeTransformer { - var failures: List[Expr] = Nil - - override def transform(tpe: Type): Type = tpe match { - case tp: TypeParameter if tparams.exists(_.tp.id.name == tp.id.name) => tyParamSubst.numericTpe - case _ => super.transform(tpe) - } - - override def transform(expr: Expr): Expr = expr match { - case GenericValue(tp, id) => - val newExpr = tparams.zipWithIndex - .find((tpDef, _) => tpDef.tp.id.name == tp.id.name) - .flatMap { (_, tpIx) => - tyParamSubst.numericExprCtor(BigInt(id) * BigInt(tparams.size) + BigInt(tpIx)) - } - newExpr match { - case Some(e) => e - case None => - failures = expr :: failures - expr - } - case _ => super.transform(expr) - } - } - } \ No newline at end of file diff --git a/core/src/main/scala/stainless/utils/CtexRemapping.scala b/core/src/main/scala/stainless/utils/CtexRemapping.scala new file mode 100644 index 0000000000..b31ed435a9 --- /dev/null +++ b/core/src/main/scala/stainless/utils/CtexRemapping.scala @@ -0,0 +1,174 @@ +package stainless +package utils + +// Utility for processing counter-examples. +// The main function is `tryRemapCtex` which attempts to map a counter-example back to the function it originated from +// by ordering the ctex variables according to the ordering of the function parameters, by instantiating generic types +// to a given numeric type (TyParamSubst) and by providing default values for function parameters missing in the ctex. +trait CtexRemapping { self => + val trees: ast.Trees + import trees._ + + // Type parameters will be replaced with the given `numericTpe` and generic values with `numericExprCtor` + // (provided it succeeds by returning `Some`). + case class TyParamSubst(numericTpe: Type, numericExprCtor: BigInt => Option[Expr]) + + enum RemappedCtex { + case Success(args: Seq[Expr], tparams: Seq[Type]) + case Failure(reason: FailureReason) + } + enum FailureReason { + case NonUniqueParams(name: String, vds: Seq[ValDef]) + case NonUniqueCtexVars(name: String, vds: Seq[ValDef]) + case NonUniqueTypeParams(name: String, vds: Seq[TypeParameterDef]) + case UnmappedCtexVars(vds: Seq[ValDef]) + case ExprRewrite(exprs: Seq[Expr]) + case NoSimpleValue(tpes: Seq[Type]) + } + + def tryRemapCtex(params: Seq[ValDef], tparams: Seq[TypeParameterDef], p: StainlessProgram) + (ctex: Map[p.trees.ValDef, p.trees.Expr]) + (using tyParamSubst: TyParamSubst, ctx: inox.Context): RemappedCtex = { + // First, ensure that the vds in `params` and `vars` are uniquely identifiable by their string name + // (note that we cannot use the Identifiers because they are generally not the same due to freshening) + val paramsName2Vds = params.groupBy(_.id.name) + paramsName2Vds.find(_._2.size > 1) match { + case Some((name, vds)) => + return RemappedCtex.Failure(FailureReason.NonUniqueParams(name, vds)) + case None => () + } + val ctexName2Vds = ctex.keySet.groupBy(_.id.name) + ctexName2Vds.find(_._2.size > 1) match { + case Some((name, vds)) => + return RemappedCtex.Failure(FailureReason.NonUniqueCtexVars(name, vds.toSeq.map(prog2self(p)(_)))) + case None => () + } + // Do the same for tparams + val tparamsNames2Tpds = tparams.groupBy(_.id.name) + tparamsNames2Tpds.find(_._2.size > 1) match { + case Some((name, tpds)) => + return RemappedCtex.Failure(FailureReason.NonUniqueTypeParams(name, tpds)) + case None => () + } + + val ctexName2Vd = ctexName2Vds.view.mapValues(_.head).toMap + val paramsName2Vd = paramsName2Vds.view.mapValues(_.head).toMap + + val extra = ctexName2Vd.keySet -- paramsName2Vd.keySet + if (extra.nonEmpty) { + // Extra variables in the counter-examples, we do not know what they correspond to + return RemappedCtex.Failure(FailureReason.UnmappedCtexVars(extra.toSeq.map(e => prog2self(p)(ctexName2Vd(e))))) + } + + // It may be the case that the ctex does not contain values for all params, in which case we will try to find a default value + val missing = paramsName2Vd.keySet -- ctexName2Vd.keySet + val (noSimpleVals, simpleVals0) = missing.partitionMap { m => + val vd = paramsName2Vd(m) + val tpe = substTypeParams(tparams, vd.tpe) + tryFindDefaultValue(p, tpe).map(m -> _).toRight(tpe) + } + if (noSimpleVals.nonEmpty) { + return RemappedCtex.Failure(FailureReason.NoSimpleValue(noSimpleVals.toSeq)) + } + val simpleVals = simpleVals0.toMap + val (notRewrittenss, rewritten0) = ctex.toSeq.partitionMap { case (vd, e) => + tryRewriteExpr(tparams, prog2self(p)(e)).map(vd.id.name -> _) + } + if (notRewrittenss.nonEmpty) { + return RemappedCtex.Failure(FailureReason.ExprRewrite(notRewrittenss.flatten)) + } + val rewritten = rewritten0.toMap + assert(rewritten.keySet.intersect(simpleVals.keySet).isEmpty) + val allUnord = rewritten ++ simpleVals + assert(allUnord.keySet == paramsName2Vd.keySet) + val allOrd = params.map(vd => allUnord(vd.id.name)) + val newTparams = tparams.map(_ => tyParamSubst.numericTpe) + RemappedCtex.Success(allOrd, newTparams) + } + + def tryRewriteExpr(tparams: Seq[TypeParameterDef], expr: Expr)(using tyParamSubst: TyParamSubst): Either[Seq[Expr], Expr] = { + if (tparams.isEmpty) Right(expr) + else { + val subst = new SubstTypeParamsAndGenericValues(tparams) + val newExpr = subst.transform(expr) + if (subst.failures.isEmpty) Right(newExpr) + else Left(subst.failures) + } + } + + def substTypeParams(tparams: Seq[TypeParameterDef], tpe: Type)(using tyParamSubst: TyParamSubst): Type = { + if (tparams.isEmpty) tpe + else { + val subst = new SubstTypeParamsAndGenericValues(tparams) + val newTpe = subst.transform(tpe) + assert(subst.failures.isEmpty, "failures are impossible when substituting for types (only for Expr)") + newTpe + } + } + + // Note: assumes that `tpe` has already been rewritten to not contain generic values + def tryFindDefaultValue(p: StainlessProgram, tpe: Type)(using inox.Context): Option[Expr] = { + try { + val dlft = p.symbols.simplestValue(self2prog(p, tpe), allowSolver = false)(using p.getSemantics) + Some(prog2self(p)(dlft)) + } catch { + case _: p.symbols.NoSimpleValue => None + } + } + + def self2prog(pr: StainlessProgram, e: Expr): pr.trees.Expr = { + class Impl(override val s: self.trees.type, override val t: pr.trees.type) extends transformers.ConcreteTreeTransformer(s, t) + new Impl(trees, pr.trees).transform(e) + } + + def self2prog(pr: StainlessProgram, tpe: Type): pr.trees.Type = { + class Impl(override val s: self.trees.type, override val t: pr.trees.type) extends transformers.ConcreteTreeTransformer(s, t) + new Impl(trees, pr.trees).transform(tpe) + } + + def prog2self(pr: StainlessProgram)(tpe: pr.trees.Type): Type = { + class Impl(override val s: pr.trees.type, override val t: self.trees.type) extends transformers.ConcreteTreeTransformer(s, t) + new Impl(pr.trees, trees).transform(tpe) + } + + def prog2self(pr: StainlessProgram)(tpd: pr.trees.TypeParameterDef): TypeParameterDef = { + class Impl(override val s: pr.trees.type, override val t: self.trees.type) extends transformers.ConcreteTreeTransformer(s, t) + new Impl(pr.trees, trees).transform(tpd) + } + + def prog2self(pr: StainlessProgram)(vd: pr.trees.ValDef): ValDef = { + class Impl(override val s: pr.trees.type, override val t: self.trees.type) extends transformers.ConcreteTreeTransformer(s, t) + new Impl(pr.trees, trees).transform(vd) + } + + def prog2self(pr: StainlessProgram)(e: pr.trees.Expr): Expr = { + class Impl(override val s: pr.trees.type, override val t: self.trees.type) extends transformers.ConcreteTreeTransformer(s, t) + new Impl(pr.trees, trees).transform(e) + } + + private class SubstTypeParamsAndGenericValues(tparams: Seq[TypeParameterDef])(using tyParamSubst: TyParamSubst) extends ConcreteSelfTreeTransformer { + var failures: List[Expr] = Nil + + override def transform(tpe: Type): Type = tpe match { + case tp: TypeParameter if tparams.exists(_.tp.id.name == tp.id.name) => tyParamSubst.numericTpe + case _ => super.transform(tpe) + } + + override def transform(expr: Expr): Expr = expr match { + case GenericValue(tp, id) => + val newExpr = tparams.zipWithIndex + .find((tpDef, _) => tpDef.tp.id.name == tp.id.name) + .flatMap { (_, tpIx) => + tyParamSubst.numericExprCtor(BigInt(id) * BigInt(tparams.size) + BigInt(tpIx)) + } + newExpr match { + case Some(e) => e + case None => + failures = expr :: failures + expr + } + case _ => super.transform(expr) + } + } + +} diff --git a/core/src/main/scala/stainless/verification/VerificationChecker.scala b/core/src/main/scala/stainless/verification/VerificationChecker.scala index cc0b1dd9ca..b5f8ab9881 100644 --- a/core/src/main/scala/stainless/verification/VerificationChecker.scala +++ b/core/src/main/scala/stainless/verification/VerificationChecker.scala @@ -12,6 +12,7 @@ import scala.util.{Failure, Success} import scala.concurrent.Future import scala.collection.mutable +object optSilent extends inox.FlagOptionDef("silent-verification", false) object optFailEarly extends inox.FlagOptionDef("fail-early", false) object optFailInvalid extends inox.FlagOptionDef("fail-invalid", false) object optVCCache extends inox.FlagOptionDef("vc-cache", true) @@ -323,7 +324,6 @@ trait VerificationChecker { self => VCResult(status, s.getResultSolver.map(_.name), Some(time)) case SatWithModel(model) if !vc.satisfiability => - extraction.trace.Trace.reportCounterexample(program)(model)(vc.fid) VCResult(VCStatus.Invalid(VCStatus.CounterExample(model)), s.getResultSolver.map(_.name), Some(time)) case Sat if vc.satisfiability => @@ -349,35 +349,37 @@ trait VerificationChecker { self => val vcResultMsg = VCResultMessage(vc, vcres) reporter.debug(vcResultMsg) - reporter.synchronized { - val descr = s" - Result for '${vc.kind}' VC for ${vc.fid.asString} @${vc.getPos}:" - - vcres.status match { - case VCStatus.Valid => - reporter.debug(descr) - reporter.debug(" => VALID") - - case VCStatus.Invalid(reason) => - reporter.warning(descr) - // avoid reprinting VC if --debug=verification is enabled - if (!reporter.isDebugEnabled(using DebugSectionVerification)) - reporter.warning(prettify(vc.condition).asString) - reporter.warning(vc.getPos, " => INVALID") - reason match { - case VCStatus.CounterExample(cex) => - reporter.warning("Found counter-example:") - reporter.warning(" " + cex.asString.replaceAll("\n", "\n ")) - - case VCStatus.Unsatisfiable => - reporter.warning("Property wasn't satisfiable") - } - - case status => - reporter.warning(descr) - // avoid reprinting VC if --debug=verification is enabled - if (!reporter.isDebugEnabled(using DebugSectionVerification)) - reporter.warning(prettify(vc.condition).asString) - reporter.warning(vc.getPos, " => " + status.name.toUpperCase) + val silent = options.findOptionOrDefault(optSilent) + if (!silent) { + reporter.synchronized { + val descr = s" - Result for '${vc.kind}' VC for ${vc.fid.asString} @${vc.getPos}:" + vcres.status match { + case VCStatus.Valid => + reporter.debug(descr) + reporter.debug(" => VALID") + + case VCStatus.Invalid(reason) => + reporter.warning(descr) + // avoid reprinting VC if --debug=verification is enabled + if (!reporter.isDebugEnabled(using DebugSectionVerification)) + reporter.warning(prettify(vc.condition).asString) + reporter.warning(vc.getPos, " => INVALID") + reason match { + case VCStatus.CounterExample(cex) => + reporter.warning("Found counter-example:") + reporter.warning(" " + cex.asString.replaceAll("\n", "\n ")) + + case VCStatus.Unsatisfiable => + reporter.warning("Property wasn't satisfiable") + } + + case status => + reporter.warning(descr) + // avoid reprinting VC if --debug=verification is enabled + if (!reporter.isDebugEnabled(using DebugSectionVerification)) + reporter.warning(prettify(vc.condition).asString) + reporter.warning(vc.getPos, " => " + status.name.toUpperCase) + } } } diff --git a/frontends/benchmarks/equivalence/addHorn.scala b/frontends/benchmarks/equivalence/addHorn/addHorn.scala similarity index 71% rename from frontends/benchmarks/equivalence/addHorn.scala rename to frontends/benchmarks/equivalence/addHorn/addHorn.scala index a5374ee5ce..e90df97886 100644 --- a/frontends/benchmarks/equivalence/addHorn.scala +++ b/frontends/benchmarks/equivalence/addHorn/addHorn.scala @@ -10,20 +10,14 @@ object AddHorn { def add_horn_1(i: BigInt, j: BigInt): BigInt = { require(i >= 0) - if (i == 0) j + if (i == 0) j else add_horn_1(i-1, j+1) } def add_horn_2(i: BigInt, j: BigInt): BigInt = { require(i >= 0) - if (i == 0) j + if (i == 0) j else if (i == 1) j + 1 else add_horn_2(i-1, j+1) } - - @traceInduct("") - def check_add_horn(i: BigInt, j: BigInt): Unit = { - require(i >= 0) - } ensuring(add_horn_1(i, j) == add_horn_2(i, j)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/addHorn/expected_outcome.json b/frontends/benchmarks/equivalence/addHorn/expected_outcome.json new file mode 100644 index 0000000000..3893c93106 --- /dev/null +++ b/frontends/benchmarks/equivalence/addHorn/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "AddHorn.add_horn_1", + "functions": [ + "AddHorn.add_horn_2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/addHorn/test_conf.json b/frontends/benchmarks/equivalence/addHorn/test_conf.json new file mode 100644 index 0000000000..73e7f72641 --- /dev/null +++ b/frontends/benchmarks/equivalence/addHorn/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "AddHorn.add_horn_1" + ], + "comparefuns": [ + "AddHorn.add_horn_2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/Candidate1.scala b/frontends/benchmarks/equivalence/boardgame/Candidate1.scala new file mode 100644 index 0000000000..2d516bb6ae --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/Candidate1.scala @@ -0,0 +1,153 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +object Candidate1 { + + def adjacencyBonus(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + def adj(tile: Tile): BigInt = { + districtKind match { + case DistrictKind.Campus() => tile.base match { + case TileBase.Mountain() => BigInt(2) + case _ => tile.construction match { + case Some(Construction.City(_)) => BigInt(1) + case Some(Construction.District(_)) => BigInt(1) + case _ => BigInt(0) + } + } + case DistrictKind.IndustrialZone() => + val resAdj = tile.resource match { + case Some(Resource.Iron()) => BigInt(2) + case Some(Resource.Coal()) => BigInt(2) + case _ => BigInt(0) + } + tile.construction match { + case Some(Construction.City(_)) => resAdj + BigInt(1) + case Some(Construction.District(_)) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Mine())) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Quarry())) => resAdj + BigInt(2) + case _ => resAdj + } + } + } + def sum(ts: List[Tile], acc: BigInt): BigInt = { + decreases(ts) + ts match { + case Nil() => acc + case Cons(tile, rest) => sum(rest, acc + adj(tile)) + } + } + sum(collectTilesInRing(wm, x, y, 1), 0) + } + + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + + def validCitySettlement(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + noCitiesInHorizon(wm, x, y) && tileFreeForSettlement(wm, x, y) + } + + ///////////////////////////////////// + + def tileInWorld(wm: WorldMap, x: BigInt, y: BigInt): Tile = { + require(0 <= y && y < wm.height) + val xx = (x % wm.width + wm.width) % wm.width + val ix = y * wm.width + xx + wm.tiles(ix) + } + + def tileFreeForSettlement(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + val tile = tileInWorld(wm, x, y) + (tile.base match { + case TileBase.FlatTerrain(_) => true + case TileBase.HillTerrain(_) => true + case _ => false + }) && (tile.construction match { + case Some(Construction.City(_)) => false + case Some(Construction.District(_)) => false + case None() => true + case Some(Construction.Exploitation(_)) => true + }) + } + + def noCitiesInHorizon(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + def loop(ls: List[Tile]): Boolean = { + decreases(ls) + ls match { + case Cons(t, rest) => t.construction match { + case Some(Construction.City(_)) => false + case _ => loop(rest) + } + case Nil() => true + } + } + loop(collectTilesWithinRadius(wm, x, y, 2)) + } + + def collectTilesWithinRadius(wm: WorldMap, x: BigInt, y: BigInt, radius: BigInt): List[Tile] = { + require(0 <= y && y < wm.height) + require(radius >= 0) + require(2 * radius < wm.width) + + def allRings(currRadius: BigInt): List[Tile] = { + decreases(radius - currRadius) + require(0 <= currRadius && currRadius <= radius) + val atThisRadius = collectTilesInRing(wm, x, y, currRadius) + if (currRadius == radius) atThisRadius + else atThisRadius ++ allRings(currRadius + 1) + } + + allRings(0) + } + + def collectTilesInRing(wm: WorldMap, x: BigInt, y: BigInt, ring: BigInt): List[Tile] = { + require(0 <= y && y < wm.height) + require(ring >= 0) + require(2 * ring < wm.width) + + def loop(i: BigInt): List[Tile] = { + require(ring > 0) + require(0 <= i && i < 6 * ring) + decreases(6 * ring - i) + + val corner = i / ring + val rest = i % ring + val diffX = { + if (corner == 0) rest + else if (corner == 1) ring + else if (corner == 2) ring - rest + else if (corner == 3) -rest + else if (corner == 4) -ring + else rest - ring + } + val diffY = { + if (corner == 0) ring - rest + else if (corner == 1) -rest + else if (corner == 2) -ring + else if (corner == 3) rest - ring + else if (corner == 4) rest + else ring + } + + val xx = x + diffX + val yy = y + diffY + val includeThis = { + if (0 <= yy && yy < wm.height) List(tileInWorld(wm, xx, yy)) + else Nil() + } + if (i == 6 * ring - 1) includeThis + else includeThis ++ loop(i + 1) + } + if (ring > 0) loop(0) + else List(tileInWorld(wm, x, y)) + } + +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/Candidate2.scala b/frontends/benchmarks/equivalence/boardgame/Candidate2.scala new file mode 100644 index 0000000000..33a1ea6e92 --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/Candidate2.scala @@ -0,0 +1,70 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +object Candidate2 { + + def adjacencyBonus(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + require(0 <= y && y < wm.height) + adj(wm, x, y + 1, districtKind) + + adj(wm, x + 1, y, districtKind) + + adj(wm, x + 1, y - 1, districtKind) + + adj(wm, x, y - 1, districtKind) + + adj(wm, x - 1, y, districtKind) + + adj(wm, x - 1, y + 1, districtKind) + } + + def adj(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + if (y < 0 || y >= wm.height) BigInt(0) + else { + val tile = tileInWorld(wm, x, y) + districtKind match { + case DistrictKind.Campus() => tile.base match { + case TileBase.Mountain() => BigInt(2) + case _ => tile.construction match { + case Some(Construction.City(_)) => BigInt(1) + case Some(Construction.District(_)) => BigInt(1) + case _ => BigInt(0) + } + } + case DistrictKind.IndustrialZone() => + val resAdj = tile.resource match { + case Some(Resource.Iron()) => BigInt(2) + case Some(Resource.Coal()) => BigInt(2) + case _ => BigInt(0) + } + tile.construction match { + case Some(Construction.City(_)) => resAdj + BigInt(1) + case Some(Construction.District(_)) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Mine())) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Quarry())) => resAdj + BigInt(2) + case _ => resAdj + } + } + } + } + + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + + def validCitySettlement(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + val tile = tileInWorld(wm, x, y) + tile.base match { + case TileBase.FlatTerrain(_) => true + case TileBase.HillTerrain(_) => true + case _ => false + } + } + + ///////////////////////////////////// + + def tileInWorld(wm: WorldMap, x: BigInt, y: BigInt): Tile = { + require(0 <= y && y < wm.height) + val xx = (x % wm.width + wm.width) % wm.width + val ix = y * wm.width + xx + wm.tiles(ix) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/Candidate3.scala b/frontends/benchmarks/equivalence/boardgame/Candidate3.scala new file mode 100644 index 0000000000..9ba7c24840 --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/Candidate3.scala @@ -0,0 +1,140 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +object Candidate3 { + + def adjacencyBonus(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + adj(wm, x, y + 1, districtKind) + + adj(wm, x + 1, y, districtKind) + + adj(wm, x + 1, y - 1, districtKind) + + adj(wm, x, y - 1, districtKind) + + adj(wm, x - 1, y, districtKind) + + adj(wm, x - 1, y + 1, districtKind) + } + + def adj(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + // oops, forgot to check for OOB y... + val tile = tileInWorld(wm, x, y) + districtKind match { + case DistrictKind.Campus() => tile.base match { + case TileBase.Mountain() => BigInt(2) + case _ => tile.construction match { + case Some(Construction.City(_)) => BigInt(1) + case Some(Construction.District(_)) => BigInt(1) + case _ => BigInt(0) + } + } + case DistrictKind.IndustrialZone() => + val resAdj = tile.resource match { + case Some(Resource.Iron()) => BigInt(2) + case Some(Resource.Coal()) => BigInt(2) + case _ => BigInt(0) + } + tile.construction match { + case Some(Construction.City(_)) => resAdj + BigInt(1) + case Some(Construction.District(_)) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Mine())) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Quarry())) => resAdj + BigInt(2) + case _ => resAdj + } + } + } + + + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + + def validCitySettlement(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + noCitiesInHorizon(wm, x, y) // oops, forgot to check whether the tile to settle on is ok... + } + + ///////////////////////////////////// + + def tileInWorld(wm: WorldMap, x: BigInt, y: BigInt): Tile = { + require(0 <= y && y < wm.height) + val xx = (x % wm.width + wm.width) % wm.width + val ix = y * wm.width + xx + wm.tiles(ix) + } + + def noCitiesInHorizon(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + def loop(ls: List[Tile]): Boolean = { + decreases(ls) + ls match { + case Cons(t, rest) => t.construction match { + case Some(Construction.City(_)) => false + case _ => loop(rest) + } + case Nil() => true + } + } + loop(collectTilesWithinRadius(wm, x, y, 2)) + } + + def collectTilesWithinRadius(wm: WorldMap, x: BigInt, y: BigInt, radius: BigInt): List[Tile] = { + require(0 <= y && y < wm.height) + require(radius >= 0) + require(2 * radius < wm.width) + + def allRings(currRadius: BigInt): List[Tile] = { + decreases(radius - currRadius) + require(0 <= currRadius && currRadius <= radius) + val atThisRadius = collectTilesInRing(wm, x, y, currRadius) + if (currRadius == radius) atThisRadius + else atThisRadius ++ allRings(currRadius + 1) + } + + allRings(0) + } + + def collectTilesInRing(wm: WorldMap, x: BigInt, y: BigInt, ring: BigInt): List[Tile] = { + require(0 <= y && y < wm.height) + require(ring >= 0) + require(2 * ring < wm.width) + + def loop(i: BigInt): List[Tile] = { + require(ring > 0) + require(0 <= i && i < 6 * ring) + decreases(6 * ring - i) + + val corner = i / ring + val rest = i % ring + val diffX = { + if (corner == 0) rest + else if (corner == 1) ring + else if (corner == 2) ring - rest + else if (corner == 3) -rest + else if (corner == 4) -ring + else rest - ring + } + val diffY = { + if (corner == 0) ring - rest + else if (corner == 1) -rest + else if (corner == 2) -ring + else if (corner == 3) rest - ring + else if (corner == 4) rest + else ring + } + + val xx = x + diffX + val yy = y + diffY + val includeThis = { + if (0 <= yy && yy < wm.height) List(tileInWorld(wm, xx, yy)) + else Nil() + } + if (i == 6 * ring - 1) includeThis + else includeThis ++ loop(i + 1) + } + if (ring > 0) loop(0) + else List(tileInWorld(wm, x, y)) + } + +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/Candidate4.scala b/frontends/benchmarks/equivalence/boardgame/Candidate4.scala new file mode 100644 index 0000000000..11b02ac038 --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/Candidate4.scala @@ -0,0 +1,112 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +object Candidate4 { + + def adjacencyBonus(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + // No, one must still compute adjacency even for tiles on the y-border of the map + if (0 < y && y < wm.height - 1) { + adj(wm, x, y + 1, districtKind) + + adj(wm, x + 1, y, districtKind) + + adj(wm, x + 1, y - 1, districtKind) + + adj(wm, x, y - 1, districtKind) + + adj(wm, x - 1, y, districtKind) + + adj(wm, x - 1, y + 1, districtKind) + } else BigInt(0) + } + + def adj(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + require(0 <= y && y < wm.height) + val tile = tileInWorld(wm, x, y) + districtKind match { + case DistrictKind.Campus() => tile.base match { + case TileBase.Mountain() => BigInt(2) + case _ => tile.construction match { + case Some(Construction.City(_)) => BigInt(1) + case Some(Construction.District(_)) => BigInt(1) + case _ => BigInt(0) + } + } + case DistrictKind.IndustrialZone() => + val resAdj = tile.resource match { + case Some(Resource.Iron()) => BigInt(2) + case Some(Resource.Coal()) => BigInt(2) + case _ => BigInt(0) + } + tile.construction match { + case Some(Construction.City(_)) => resAdj + BigInt(1) + case Some(Construction.District(_)) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Mine())) => resAdj + BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Quarry())) => resAdj + BigInt(2) + case _ => resAdj + } + } + } + + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + + def validCitySettlement(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + // Desperately trying to do that by-hand, but forgets about the second ring... + val (p1x, p1y) = (x, y + 1) + val (p2x, p2y) = (x + 1, y) + val (p3x, p3y) = (x + 1, y - 1) + val (p4x, p4y) = (x, y - 1) + val (p5x, p5y) = (x - 1, y) + val (p6x, p6y) = (x - 1, y + 1) + tileFreeForSettlement(wm, x, y) && notACity(wm, p1x, p1y) && + notACity(wm, p2x, p2y) && + notACity(wm, p3x, p3y) && + notACity(wm, p4x, p4y) && + notACity(wm, p5x, p5y) && + notACity(wm, p6x, p6y) + } + + ///////////////////////////////////// + + def tileInWorld(wm: WorldMap, x: BigInt, y: BigInt): Tile = { + require(0 <= y && y < wm.height) + val xx = (x % wm.width + wm.width) % wm.width + val ix = y * wm.width + xx + wm.tiles(ix) + } + + def tileFreeForSettlement(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + val tile = tileInWorld(wm, x, y) + (tile.base match { + case TileBase.FlatTerrain(_) => true + case TileBase.HillTerrain(_) => true + case _ => false + }) && notACity(wm, x, y) && notADistrict(wm, x, y) + } + + def notACity(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + !(0 <= y && y < wm.height) || { + val tile = tileInWorld(wm, x, y) + tile.construction match { + case Some(Construction.City(_)) => false + case Some(Construction.District(_)) => true + case Some(Construction.Exploitation(_)) => true + case None() => true + } + } + } + def notADistrict(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + !(0 <= y && y < wm.height) || { + val tile = tileInWorld(wm, x, y) + tile.construction match { + case Some(Construction.District(_)) => false + case Some(Construction.City(_)) => true + case Some(Construction.Exploitation(_)) => true + case None() => true + } + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/Model.scala b/frontends/benchmarks/equivalence/boardgame/Model.scala new file mode 100644 index 0000000000..f7e265d328 --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/Model.scala @@ -0,0 +1,253 @@ +import stainless.lang._ +import stainless.collection._ +import stainless.annotation._ +import defs._ + +object Model { + + // Part 1. Calculating adjacency bonus + // Rules: + // -For an industrial zone: + // -Adjacent iron, coal or quarry: +1 + // -Adjacent mine, city or district: +1/2 + // -For a campus: + // -Adjacent mountain: +1 + // -Adjacent city or district: +1/2 + // Since we have half point, the result is doubled to have integers + def adjacencyBonus1(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + adj(wm, x, y + 1, districtKind) + + adj(wm, x + 1, y, districtKind) + + adj(wm, x + 1, y - 1, districtKind) + + adj(wm, x, y - 1, districtKind) + + adj(wm, x - 1, y, districtKind) + + adj(wm, x - 1, y + 1, districtKind) + } + + def adjacencyBonus2(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + + def sum(acc: BigInt, ts: List[Tile]): BigInt = { + decreases(ts) + ts match { + case Nil() => acc + case Cons(tile, rest) => sum(acc + adj(tile, districtKind), rest) + } + } + sum(0, collectTilesInRing(wm, x, y, 1)) + } + + def testsAdjacencyBonus: List[(WorldMap, BigInt, BigInt, DistrictKind)] = List( + testsAdjacencyBonus1, + ) + + def testsAdjacencyBonus1: (WorldMap, BigInt, BigInt, DistrictKind) = { + val G = Tile(TileBase.FlatTerrain(BaseTerrain.Grassland()), None(), None(), None()) + val M = Tile(TileBase.Mountain(), None(), None(), None()) + val X = G // The emplacement where we would like to compute for potential adjacency + val wm = List( + G, M, X, M, G, + G, G, G, M, G, + G, M, G, G, G, + G, G, G, G, G, + ) + // Note: the coordinates are upside down + (WorldMap(wm, 5, 4), 2, 0, DistrictKind.Campus()) + } + + ////////////////////// + + def adj(wm: WorldMap, x: BigInt, y: BigInt, districtKind: DistrictKind): BigInt = { + if (y < 0 || y >= wm.height) BigInt(0) + else adj(wm(x, y), districtKind) + } + + def adj(tile: Tile, districtKind: DistrictKind): BigInt = { + (districtKind, tile) match { + case (DistrictKind.Campus(), Tile(TileBase.Mountain(), _, _, _)) => BigInt(2) + case (DistrictKind.Campus(), Tile(_, _, _, Some(Construction.City(_)))) => BigInt(1) + case (DistrictKind.Campus(), Tile(_, _, _, Some(Construction.District(_)))) => BigInt(1) + case (DistrictKind.Campus(), _) => BigInt(0) + case (DistrictKind.IndustrialZone(), Tile(_, _, res, ctor)) => + val resAdj = res match { + case Some(Resource.Iron()) => BigInt(2) + case Some(Resource.Coal()) => BigInt(2) + case _ => BigInt(0) + } + val resCtor = ctor match { + case Some(Construction.City(_)) => BigInt(1) + case Some(Construction.District(_)) => BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Mine())) => BigInt(1) + case Some(Construction.Exploitation(ResourceImprovement.Quarry())) => BigInt(2) + case _ => BigInt(0) + } + resAdj + resCtor + } + } + + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////// + + // Part 2. Determining whether a placement is suitable for settling + // -Rules: no other city in a 2-tile range + // -The tile must be adequate for settling (flat or hill terrain, and must not have another city or district on it) + def validCitySettlement(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + tileOkForCity(wm, x, y) && noOtherCitiesInRange(wm, x, y) + } + + def testsValidCitySettlement: List[(WorldMap, BigInt, BigInt)] = List( + testValidCitySettlement1, + testValidCitySettlement2, + testValidCitySettlement3, + ) + + def testValidCitySettlement1: (WorldMap, BigInt, BigInt) = { + // Ok, can be settled + val G = Tile(TileBase.FlatTerrain(BaseTerrain.Grassland()), None(), None(), None()) + val X = G // where we would like to settle + val wm = List( + G, G, G, G, G, + G, G, G, G, G, + G, X, G, G, G, + G, G, G, G, G, + ) + // Note: the coordinates are upside down + (WorldMap(wm, 5, 4), 1, 2) + } + + def testValidCitySettlement2: (WorldMap, BigInt, BigInt) = { + // A lake in the center, we can't settle there + val G = Tile(TileBase.FlatTerrain(BaseTerrain.Grassland()), None(), None(), None()) + val L = Tile(TileBase.Lake(), None(), None(), None()) + val wm = List( + G, G, G, G, G, + G, G, L, G, G, + G, G, G, G, G, + ) + // Note: the coordinates are upside down + (WorldMap(wm, 5, 3), 2, 1) + } + + def testValidCitySettlement3: (WorldMap, BigInt, BigInt) = { + // A city in the second ring of the place where we want to settle + val G = Tile(TileBase.FlatTerrain(BaseTerrain.Grassland()), None(), None(), None()) + val X = G // where we would like to settle + val Y = Tile(TileBase.FlatTerrain(BaseTerrain.Grassland()), None(), None(), Some(Construction.City(42))) // Oh no, someone's already there :( + val wm = List( + G, G, Y, G, G, + G, G, G, G, G, + G, X, G, G, G, + G, G, G, G, G, + ) + // Note: the coordinates are upside down + (WorldMap(wm, 5, 4), 1, 2) + } + + /////////////////////////////////////////////////////////////////////////// + + def tileOkForCity(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + val tile = wm(x, y) + val baseOk = tile.base match { + case TileBase.FlatTerrain(_) => true + case TileBase.HillTerrain(_) => true + case _ => false + } + val ctorOk = tile.construction match { + case None() => true + case Some(Construction.Exploitation(_)) => true // res. improvement removed on settling + case Some(Construction.District(_)) => false + case Some(Construction.City(_)) => false + } + baseOk && ctorOk + } + + def noOtherCitiesInRange(wm: WorldMap, x: BigInt, y: BigInt): Boolean = { + require(0 <= y && y < wm.height) + require(wm.width > 4) + def loop(ls: List[Tile]): Boolean = { + decreases(ls) + ls match { + case Cons(t, rest) => t.construction match { + case Some(Construction.City(_)) => false + case _ => loop(rest) + } + case Nil() => true + } + } + loop(allTilesWithinRadius(wm, x, y, 2)) + } + + // Note: includes the x,y tile as well + def allTilesWithinRadius(wm: WorldMap, x: BigInt, y: BigInt, radius: BigInt): List[Tile] = { + require(0 <= y && y < wm.height) + require(radius >= 0) + require(2 * radius < wm.width) // To avoid repetition of tiles due to wrapping + + def allRings(currRadius: BigInt): List[Tile] = { + decreases(radius - currRadius) + require(0 <= currRadius && currRadius <= radius) + val atThisRadius = collectTilesInRing(wm, x, y, currRadius) + if (currRadius == radius) atThisRadius + else atThisRadius ++ allRings(currRadius + 1) + } + + allRings(0) + } + + def collectTilesInRing(wm: WorldMap, x: BigInt, y: BigInt, radius: BigInt): List[Tile] = { + require(0 <= y && y < wm.height) + require(radius >= 0) + require(2 * radius < wm.width) + + def loop(i: BigInt): List[Tile] = { + require(radius > 0) + require(0 <= i && i < 6 * radius) + decreases(6 * radius - i) + + val corner = i / radius + val rest = i % radius + val diffX = { + if (corner == 0) rest + else if (corner == 1) radius + else if (corner == 2) radius - rest + else if (corner == 3) -rest + else if (corner == 4) -radius + else rest - radius + } + val diffY = { + if (corner == 0) radius - rest + else if (corner == 1) -rest + else if (corner == 2) -radius + else if (corner == 3) rest - radius + else if (corner == 4) rest + else radius + } + + val xx = x + diffX + val yy = y + diffY + val includeThis = { + if (0 <= yy && yy < wm.height) List(wm(xx, yy)) + else Nil() + } + if (i == 6 * radius - 1) includeThis + else includeThis ++ loop(i + 1) + } + if (radius == 0) List(wm(x, y)) + else loop(0) + } + + extension (wm: WorldMap) { + def apply(x: BigInt, y: BigInt): Tile = { + require(0 <= y && y < wm.height) + val xx = (x % wm.width + wm.width) % wm.width + val ix = y * wm.width + xx + wm.tiles(ix) + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/defs.scala b/frontends/benchmarks/equivalence/boardgame/defs.scala new file mode 100644 index 0000000000..f2f6e1bb31 --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/defs.scala @@ -0,0 +1,73 @@ +import stainless.lang._ +import stainless.collection._ + +object defs { + + case class Tile( + base: TileBase, + feature: Option[Feature], + resource: Option[Resource], + construction: Option[Construction] + ) + + // Hexagon tiles, cylinder world (i.e. wraps around x-axis) + case class WorldMap(tiles: List[Tile], width: BigInt, height: BigInt) { + require(width > 0 && height > 0) + require(tiles.length == width * height) + } + + enum TileBase { + case FlatTerrain(base: BaseTerrain) + case HillTerrain(base: BaseTerrain) + case Mountain() + case Lake() + case Coast() + case Ocean() + } + + enum BaseTerrain { + case Plains() + case Grassland() + case Desert() + case Tundra() + case Snow() + } + + enum Feature { + case Forest() + case RainForest() + case Marsh() + // etc. + } + + enum Resource { + case Iron() // can't make Stainless Steel without Iron, so this must be here + case Wheat() + case Rice() + case Stone() // weak + case Crabs() // the best + case Fish() + case Coal() // cursed + // etc. + } + + enum Construction { + case City(id: BigInt) + case District(kind: DistrictKind) + case Exploitation(kind: ResourceImprovement) + } + + enum DistrictKind { + case Campus() + case IndustrialZone() + // etc. + } + + enum ResourceImprovement { + case Farm() + case Fishery() + case Mine() + case Quarry() + // etc. + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/expected_outcome_1.json b/frontends/benchmarks/equivalence/boardgame/expected_outcome_1.json new file mode 100644 index 0000000000..9943855a0d --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/expected_outcome_1.json @@ -0,0 +1,17 @@ +{ + "equivalent": [ + { + "model": "Model.validCitySettlement", + "functions": [ + "Candidate1.validCitySettlement" + ] + } + ], + "erroneous": [ + "Candidate2.validCitySettlement", + "Candidate3.validCitySettlement", + "Candidate4.validCitySettlement" + ], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/expected_outcome_2.json b/frontends/benchmarks/equivalence/boardgame/expected_outcome_2.json new file mode 100644 index 0000000000..3846828b35 --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/expected_outcome_2.json @@ -0,0 +1,22 @@ +{ + "equivalent": [ + { + "model": "Model.adjacencyBonus2", + "functions": [ + "Candidate1.adjacencyBonus" + ] + }, + { + "model": "Model.adjacencyBonus1", + "functions": [ + "Candidate2.adjacencyBonus" + ] + } + ], + "erroneous": [ + "Candidate3.adjacencyBonus", + "Candidate4.adjacencyBonus" + ], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/test_conf_1.json b/frontends/benchmarks/equivalence/boardgame/test_conf_1.json new file mode 100644 index 0000000000..46834a115b --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/test_conf_1.json @@ -0,0 +1,14 @@ +{ + "models": [ + "Model.validCitySettlement" + ], + "comparefuns": [ + "Candidate1.validCitySettlement", + "Candidate2.validCitySettlement", + "Candidate3.validCitySettlement", + "Candidate4.validCitySettlement" + ], + "tests": [ + "Model.testsValidCitySettlement" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/boardgame/test_conf_2.json b/frontends/benchmarks/equivalence/boardgame/test_conf_2.json new file mode 100644 index 0000000000..6a839debd4 --- /dev/null +++ b/frontends/benchmarks/equivalence/boardgame/test_conf_2.json @@ -0,0 +1,15 @@ +{ + "models": [ + "Model.adjacencyBonus1", + "Model.adjacencyBonus2" + ], + "comparefuns": [ + "Candidate1.adjacencyBonus", + "Candidate2.adjacencyBonus", + "Candidate3.adjacencyBonus", + "Candidate4.adjacencyBonus" + ], + "tests": [ + "Model.testsAdjacencyBonus" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/cloudSculpting/Candidate.scala b/frontends/benchmarks/equivalence/cloudSculpting/Candidate.scala new file mode 100644 index 0000000000..b2be501cd7 --- /dev/null +++ b/frontends/benchmarks/equivalence/cloudSculpting/Candidate.scala @@ -0,0 +1,31 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate { + def sculpteurDeNuage[A, B, C](a: A, b: B, c: C, fuel: BigInt, i1: BigInt, i2: BigInt, i3: BigInt, a2b: A => B, b2c: B => C, c2a: C => A): (BigInt, BigInt, BigInt) = { + require(fuel >= 0) + leVraiSculpteurDeNuage(b2c, b, c, fuel, a2b, i2, a, i3, i1, c2a) + } + + def leVraiSculpteurDeNuage[A, B, C](b2c: B => C, b: B, c: C, fuel: BigInt, a2b: A => B, i2: BigInt, a: A, i3: BigInt, i1: BigInt, c2a: C => A): (BigInt, BigInt, BigInt)= { + require(fuel >= 0) + decreases(fuel) + if (fuel == 0) (i1, i2, i3) + else { + val (ii1, ii2, ii3) = mixmash(i2, i3, i1) + leVraiSculpteurDeNuage(b2c, a2b(a), b2c(b), fuel - 1, a2b, ii2, c2a(c), ii3, ii1, c2a) + } + } + + def mixmash(i2: BigInt, i3: BigInt, i1: BigInt): (BigInt, BigInt, BigInt) = { + decreases(if (i2 <= 0) -i2 else i2) + if (i2 == 0) (-1, i3, i1) + else if (i2 > 0) { + val (r1, r2, r3) = mixmash(i2 - 1, i3 + 1, i1) + (r1 + 1, r2, r3) + } else { + val (r1, r2, r3) = mixmash(i2 + 1, i3 - 1, i1) + (r1 - 1, r2, r3) + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/cloudSculpting/Model.scala b/frontends/benchmarks/equivalence/cloudSculpting/Model.scala new file mode 100644 index 0000000000..a3808f65a7 --- /dev/null +++ b/frontends/benchmarks/equivalence/cloudSculpting/Model.scala @@ -0,0 +1,21 @@ +import stainless.lang._ +import stainless.collection._ + +object Model { + def sculpteurDeNuage[A, B, C](a: A, b: B, c: C, fuel: BigInt, i1: BigInt, i2: BigInt, i3: BigInt, a2b: A => B, b2c: B => C, c2a: C => A): (BigInt, BigInt, BigInt) = { + require(fuel >= 0) + leVraiSculpteurDeNuage(a, b, c, fuel, i1, i2, i3, a2b, b2c, c2a) + } + + def leVraiSculpteurDeNuage[A, B, C](a: A, b: B, c: C, fuel: BigInt, i1: BigInt, i2: BigInt, i3: BigInt, a2b: A => B, b2c: B => C, c2a: C => A): (BigInt, BigInt, BigInt) = { + require(fuel >= 0) + decreases(fuel) + if (fuel == 0) (i1, i2, i3) + else { + val (ii1, ii2, ii3) = mixmash(i1, i2, i3) + leVraiSculpteurDeNuage(c2a(c), a2b(a), b2c(b), fuel - 1, ii1, ii2, ii3, a2b, b2c, c2a) + } + } + + def mixmash(i1: BigInt, i2: BigInt, i3: BigInt): (BigInt, BigInt, BigInt) = (i2 - 1, i3 + i2, i1) +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/cloudSculpting/expected_outcome.json b/frontends/benchmarks/equivalence/cloudSculpting/expected_outcome.json new file mode 100644 index 0000000000..0fdf05d7b6 --- /dev/null +++ b/frontends/benchmarks/equivalence/cloudSculpting/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Model.sculpteurDeNuage", + "functions": [ + "Candidate.sculpteurDeNuage" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/cloudSculpting/test_conf.json b/frontends/benchmarks/equivalence/cloudSculpting/test_conf.json new file mode 100644 index 0000000000..546e496e65 --- /dev/null +++ b/frontends/benchmarks/equivalence/cloudSculpting/test_conf.json @@ -0,0 +1,10 @@ +{ + "models": [ + "Model.sculpteurDeNuage" + ], + "comparefuns": [ + "Candidate.sculpteurDeNuage" + ], + "tests": [], + "max-perm": 144 +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/Candidate1.scala b/frontends/benchmarks/equivalence/dup/Candidate1.scala new file mode 100644 index 0000000000..d123e1211d --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/Candidate1.scala @@ -0,0 +1,13 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate1 { + + def dup[S, T](n: BigInt, s: S, t: T): List[(S, T)] = { + decreases(if (n <= 0) BigInt(0) else n) + if (n == 0) Nil() + else if (n < 0) List((s, t), (s, t), (s, t), (s, t), (s, t)) // Ok because norm will remove this + else (s, t) :: dup(n - 1, s, t) + } + +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/Candidate2.scala b/frontends/benchmarks/equivalence/dup/Candidate2.scala new file mode 100644 index 0000000000..0cef1d4b54 --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/Candidate2.scala @@ -0,0 +1,12 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate2 { + + def dup[S, T](n: BigInt, s: S, t: T): List[(S, T)] = { + decreases(if (n <= 0) BigInt(0) else n) + if (n <= 0) List((s, t)) + else (s, t) :: dup(n - 1, s, t) + } + +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/Candidate3.scala b/frontends/benchmarks/equivalence/dup/Candidate3.scala new file mode 100644 index 0000000000..445d008998 --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/Candidate3.scala @@ -0,0 +1,13 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate3 { + + // Wrong signature + def dup[S, T](n: BigInt, t: T, s: S): List[(S, T)] = { + decreases(if (n <= 0) BigInt(0) else n) + if (n <= 0) Nil() + else (s, t) :: dup(n - 1, t, s) + } + +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/Candidate4.scala b/frontends/benchmarks/equivalence/dup/Candidate4.scala new file mode 100644 index 0000000000..bde7b65d26 --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/Candidate4.scala @@ -0,0 +1,11 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate4 { + + def dup[S, T](n: BigInt, s: S, t: T): List[(S, T)] = { + decreases(if (n <= 0) BigInt(0) else n) + if (n <= 0) Nil() + else dup(n - 1, s, t) // duplicating nils, very useful + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/Candidate5.scala b/frontends/benchmarks/equivalence/dup/Candidate5.scala new file mode 100644 index 0000000000..48e09721a9 --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/Candidate5.scala @@ -0,0 +1,10 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate5 { + def dup[S, T](n: BigInt, s: S, t: T): List[(S, T)] = { + decreases(if (n <= 0) BigInt(0) else n) + if (n <= 0) Nil() + else (s, t) :: dup(n - 1, s, t) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/Model.scala b/frontends/benchmarks/equivalence/dup/Model.scala new file mode 100644 index 0000000000..c1d3a1bbd1 --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/Model.scala @@ -0,0 +1,16 @@ +import stainless.lang._ +import stainless.collection._ + +object Model { + + def dup[S, T](n: BigInt, s: S, t: T): List[(S, T)] = { + decreases(if (n <= 0) BigInt(0) else n) + if (n <= 0) Nil() + else (s, t) :: dup(n - 1, s, t) + } + + def norm[S, T](n: BigInt, s: S, t: T, res: List[(S, T)]): List[(S, T)] = { + if (n < 0) Nil() + else res + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/expected_outcome.json b/frontends/benchmarks/equivalence/dup/expected_outcome.json new file mode 100644 index 0000000000..ccde6ffcc3 --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/expected_outcome.json @@ -0,0 +1,19 @@ +{ + "equivalent": [ + { + "model": "Model.dup", + "functions": [ + "Candidate1.dup", + "Candidate5.dup" + ] + } + ], + "erroneous": [ + "Candidate2.dup", + "Candidate4.dup" + ], + "timeout": [], + "wrong": [ + "Candidate3.dup" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/dup/test_conf.json b/frontends/benchmarks/equivalence/dup/test_conf.json new file mode 100644 index 0000000000..bf1b2b0f62 --- /dev/null +++ b/frontends/benchmarks/equivalence/dup/test_conf.json @@ -0,0 +1,14 @@ +{ + "models": [ + "Model.dup" + ], + "comparefuns": [ + "Candidate1.dup", + "Candidate2.dup", + "Candidate3.dup", + "Candidate4.dup", + "Candidate5.dup" + ], + "tests": [], + "norm": "Model.norm" +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/factorial/expected_outcome.json b/frontends/benchmarks/equivalence/factorial/expected_outcome.json new file mode 100644 index 0000000000..ed93b5ef99 --- /dev/null +++ b/frontends/benchmarks/equivalence/factorial/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Factorial.fact14_1", + "functions": [ + "Factorial.fact14_2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/factorial.scala b/frontends/benchmarks/equivalence/factorial/factorial.scala similarity index 69% rename from frontends/benchmarks/equivalence/factorial.scala rename to frontends/benchmarks/equivalence/factorial/factorial.scala index 10f6632247..7b06e00522 100644 --- a/frontends/benchmarks/equivalence/factorial.scala +++ b/frontends/benchmarks/equivalence/factorial/factorial.scala @@ -10,17 +10,12 @@ object Factorial { // Fig. 14 - def fact14_1(n: BigInt): BigInt = - if (n <= 1) 1 + def fact14_1(n: BigInt): BigInt = + if (n <= 1) 1 else n * fact14_1(n-1) def fact14_2(n: BigInt): BigInt = - if (n <= 1) 1 + if (n <= 1) 1 else if (n == 10) 3628800 else n * fact14_2(n-1) - - @traceInduct("") - def check_fact14(n: BigInt): Unit = { - } ensuring(fact14_1(n) == fact14_2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/factorial/test_conf.json b/frontends/benchmarks/equivalence/factorial/test_conf.json new file mode 100644 index 0000000000..761a7d7390 --- /dev/null +++ b/frontends/benchmarks/equivalence/factorial/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Factorial.fact14_1" + ], + "comparefuns": [ + "Factorial.fact14_2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/fibonacci/expected_outcome.json b/frontends/benchmarks/equivalence/fibonacci/expected_outcome.json new file mode 100644 index 0000000000..4b6a8b2ba0 --- /dev/null +++ b/frontends/benchmarks/equivalence/fibonacci/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Fibonacci.f1", + "functions": [ + "Fibonacci.f2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/fibonacci.scala b/frontends/benchmarks/equivalence/fibonacci/fibonacci.scala similarity index 85% rename from frontends/benchmarks/equivalence/fibonacci.scala rename to frontends/benchmarks/equivalence/fibonacci/fibonacci.scala index c8067f2169..2b1af9270c 100644 --- a/frontends/benchmarks/equivalence/fibonacci.scala +++ b/frontends/benchmarks/equivalence/fibonacci/fibonacci.scala @@ -21,9 +21,4 @@ object Fibonacci { else if (n <= 2) 1 else f2(n-2) + f2(n-2) + f2(n-3) } - - @traceInduct("") - def check_f(n: BigInt): Unit = { - } ensuring(f1(n) == f2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/fibonacci/test_conf.json b/frontends/benchmarks/equivalence/fibonacci/test_conf.json new file mode 100644 index 0000000000..7d988fdbfc --- /dev/null +++ b/frontends/benchmarks/equivalence/fibonacci/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Fibonacci.f1" + ], + "comparefuns": [ + "Fibonacci.f2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/foolproof/Candidate.scala b/frontends/benchmarks/equivalence/foolproof/Candidate.scala new file mode 100644 index 0000000000..d0ca5f4403 --- /dev/null +++ b/frontends/benchmarks/equivalence/foolproof/Candidate.scala @@ -0,0 +1,20 @@ +import stainless.collection._ +import stainless.lang._ + +object Candidate { + def choose(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) BigInt(0) else x) + if (x <= 0) y + else if (y <= 0) x + else choose(x - 1, y - 1) + } + + def funnyZip(xs: List[BigInt], ys: List[BigInt]): List[BigInt] = { + decreases(xs) + (xs, ys) match { + case (_, Nil()) => Nil() + case (Nil(), _) => Nil() + case (x :: xs, y :: ys) => choose(x, y) :: funnyZip(xs, ys) + } + } +} diff --git a/frontends/benchmarks/equivalence/foolproof/Model.scala b/frontends/benchmarks/equivalence/foolproof/Model.scala new file mode 100644 index 0000000000..a7f7955140 --- /dev/null +++ b/frontends/benchmarks/equivalence/foolproof/Model.scala @@ -0,0 +1,22 @@ +import stainless.collection._ +import stainless.lang._ + +// Tests whether `choose` matching avoidance do not get fooled by functions named `choose`. +// See max3 for explanation on this "choose matching avoidance" +object Model { + def choose(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) BigInt(0) else x) + if (x <= 0) y + else if (y <= 0) x + else choose(x - 1, y - 1) + } + + def funnyZip(xs: List[BigInt], ys: List[BigInt]): List[BigInt] = { + decreases(xs) + (xs, ys) match { + case (_, Nil()) => Nil() + case (Nil(), _) => Nil() + case (x :: xs, y :: ys) => choose(x, y) :: funnyZip(xs, ys) + } + } +} diff --git a/frontends/benchmarks/equivalence/foolproof/expected_outcome.json b/frontends/benchmarks/equivalence/foolproof/expected_outcome.json new file mode 100644 index 0000000000..17d07c0e48 --- /dev/null +++ b/frontends/benchmarks/equivalence/foolproof/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Model.funnyZip", + "functions": [ + "Candidate.funnyZip" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/foolproof/test_conf.json b/frontends/benchmarks/equivalence/foolproof/test_conf.json new file mode 100644 index 0000000000..c1bb468cac --- /dev/null +++ b/frontends/benchmarks/equivalence/foolproof/test_conf.json @@ -0,0 +1,9 @@ +{ + "models": [ + "Model.funnyZip" + ], + "comparefuns": [ + "Candidate.funnyZip" + ], + "tests": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/fullAlternation/expected_outcome.json b/frontends/benchmarks/equivalence/fullAlternation/expected_outcome.json new file mode 100644 index 0000000000..3917c37b67 --- /dev/null +++ b/frontends/benchmarks/equivalence/fullAlternation/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "FullAlternation.m1", + "functions": [ + "FullAlternation.m2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/fullAlternation.scala b/frontends/benchmarks/equivalence/fullAlternation/fullAlternation.scala similarity index 84% rename from frontends/benchmarks/equivalence/fullAlternation.scala rename to frontends/benchmarks/equivalence/fullAlternation/fullAlternation.scala index 28e5574e60..42cec721bb 100644 --- a/frontends/benchmarks/equivalence/fullAlternation.scala +++ b/frontends/benchmarks/equivalence/fullAlternation/fullAlternation.scala @@ -15,7 +15,7 @@ object FullAlternation { else if (n == 1) 1 else m1(n - 1, !flag) + m1(n - 2, !flag) } - + def m2(n: BigInt, mode: Boolean): BigInt = { if (n < 1) 0 else if (n == 1 || n == 2) 1 @@ -23,12 +23,8 @@ object FullAlternation { var results: BigInt = 0 if (mode) results = m2(n-2, !mode) + m2(n-2, !mode) + m2(n-3, !mode) if (!mode) results = m2(n-1, !mode) + m2(n-2, !mode) - results + results } } - @traceInduct("") - def check_m(n: BigInt, flag: Boolean): Unit = { - } ensuring(m1(n, flag) == m2(n, flag)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/fullAlternation/test_conf.json b/frontends/benchmarks/equivalence/fullAlternation/test_conf.json new file mode 100644 index 0000000000..17c09be5b2 --- /dev/null +++ b/frontends/benchmarks/equivalence/fullAlternation/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "FullAlternation.m1" + ], + "comparefuns": [ + "FullAlternation.m2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith1/Candidate.scala b/frontends/benchmarks/equivalence/funnyarith1/Candidate.scala new file mode 100644 index 0000000000..14b553553e --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith1/Candidate.scala @@ -0,0 +1,33 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +object Candidate { + // Top level + def eval(op: OpKind, x: BigInt, y: BigInt): BigInt = op match { + case OpKind.Sub() => mySub(x, y) + case OpKind.Mul() => myMul(x, y) + case OpKind.Add() => myAdd(x, y) + } + + def myAdd(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) y + else if (x > 0) myAdd(x - 1, y + 1) + else myAdd(x + 1, y - 1) + } + + def mySub(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) -y + else if (x > 0) mySub(x - 1, y - 1) + else mySub(x + 1, y + 1) + } + + def myMul(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) BigInt(0) + else if (x > 0) myAdd(myMul(x - 1, y), y) + else mySub(myMul(x + 1, y), y) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith1/Model.scala b/frontends/benchmarks/equivalence/funnyarith1/Model.scala new file mode 100644 index 0000000000..e165787cca --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith1/Model.scala @@ -0,0 +1,36 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +// Testing subfns matching +// In Candidate.eval, the order of patmat over op is not the same here to ensure +// a different starting matching strategy from the correct one: add <-> myAdd; sub <-> mySub; mul <-> myMul +object Model { + // Top level + def eval(op: OpKind, x: BigInt, y: BigInt): BigInt = op match { + case OpKind.Add() => add(x, y) + case OpKind.Sub() => sub(x, y) + case OpKind.Mul() => mul(x, y) + } + + def add(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) y + else if (x > 0) add(x - 1, y + 1) + else add(x + 1, y - 1) + } + + def sub(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) -y + else if (x > 0) sub(x - 1, y - 1) + else sub(x + 1, y + 1) + } + + def mul(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) BigInt(0) + else if (x > 0) add(mul(x - 1, y), y) + else sub(mul(x + 1, y), y) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith1/defs.scala b/frontends/benchmarks/equivalence/funnyarith1/defs.scala new file mode 100644 index 0000000000..d2fd34d4fd --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith1/defs.scala @@ -0,0 +1,10 @@ +import stainless.lang._ +import stainless.collection._ + +object defs { + enum OpKind { + case Add() + case Sub() + case Mul() + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith1/expected_outcome.json b/frontends/benchmarks/equivalence/funnyarith1/expected_outcome.json new file mode 100644 index 0000000000..21ed1bb472 --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith1/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Model.eval", + "functions": [ + "Candidate.eval" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith1/test_conf.json b/frontends/benchmarks/equivalence/funnyarith1/test_conf.json new file mode 100644 index 0000000000..6da1f221fc --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith1/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Model.eval" + ], + "comparefuns": [ + "Candidate.eval" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith2/Candidate.scala b/frontends/benchmarks/equivalence/funnyarith2/Candidate.scala new file mode 100644 index 0000000000..bf89bb0d95 --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith2/Candidate.scala @@ -0,0 +1,34 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +object Candidate { + + def eval(op: OpKind, x: BigInt, y: BigInt): BigInt = op match { + case OpKind.Sub() => mySub(y, x) + case OpKind.Mul() => myMul(x, y) + case OpKind.Add() => myAdd(x, y) + } + + def myAdd(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) y + else if (x > 0) myAdd(x - 1, y + 1) + else myAdd(x + 1, y - 1) + } + + // Computes y - x and not x - y + def mySub(x: BigInt, y: BigInt): BigInt = { + decreases(if (y <= 0) -y else y) + if (y == 0) -x + else if (y > 0) mySub(x - 1, y - 1) + else mySub(x + 1, y + 1) + } + + def myMul(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) BigInt(0) + else if (x > 0) myAdd(myMul(x - 1, y), y) + else mySub(y, myMul(x + 1, y)) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith2/Model.scala b/frontends/benchmarks/equivalence/funnyarith2/Model.scala new file mode 100644 index 0000000000..5bf43a0e4d --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith2/Model.scala @@ -0,0 +1,34 @@ +import stainless.lang._ +import stainless.collection._ +import defs._ + +// As funnyarith1 but the Candidate swaps the arguments of sub +object Model { + + def eval(op: OpKind, x: BigInt, y: BigInt): BigInt = op match { + case OpKind.Add() => add(x, y) + case OpKind.Sub() => sub(x, y) + case OpKind.Mul() => mul(x, y) + } + + def add(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) y + else if (x > 0) add(x - 1, y + 1) + else add(x + 1, y - 1) + } + + def sub(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) -y + else if (x > 0) sub(x - 1, y - 1) + else sub(x + 1, y + 1) + } + + def mul(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) BigInt(0) + else if (x > 0) add(mul(x - 1, y), y) + else sub(mul(x + 1, y), y) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith2/defs.scala b/frontends/benchmarks/equivalence/funnyarith2/defs.scala new file mode 100644 index 0000000000..d2fd34d4fd --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith2/defs.scala @@ -0,0 +1,10 @@ +import stainless.lang._ +import stainless.collection._ + +object defs { + enum OpKind { + case Add() + case Sub() + case Mul() + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith2/expected_outcome.json b/frontends/benchmarks/equivalence/funnyarith2/expected_outcome.json new file mode 100644 index 0000000000..21ed1bb472 --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith2/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Model.eval", + "functions": [ + "Candidate.eval" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith2/test_conf.json b/frontends/benchmarks/equivalence/funnyarith2/test_conf.json new file mode 100644 index 0000000000..ae55f50bef --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith2/test_conf.json @@ -0,0 +1,9 @@ +{ + "models": [ + "Model.eval" + ], + "comparefuns": [ + "Candidate.eval" + ], + "max-perm": 32 +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith3/Candidate.scala b/frontends/benchmarks/equivalence/funnyarith3/Candidate.scala new file mode 100644 index 0000000000..bbab0ac94a --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith3/Candidate.scala @@ -0,0 +1,30 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate { + + // Top level + def eval(x: BigInt, y: BigInt): BigInt = myMul(x, y) + + def myAdd(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) y + else if (x > 0) myAdd(x - 1, y + 1) + else myAdd(x + 1, y - 1) + } + + // Computes y - x and not x - y + def mySub(x: BigInt, y: BigInt): BigInt = { + decreases(if (y <= 0) -y else y) + if (y == 0) -x + else if (y > 0) mySub(x - 1, y - 1) + else mySub(x + 1, y + 1) + } + + def myMul(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) BigInt(0) + else if (x > 0) myAdd(myMul(x - 1, y), y) + else mySub(y, myMul(x + 1, y)) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith3/Model.scala b/frontends/benchmarks/equivalence/funnyarith3/Model.scala new file mode 100644 index 0000000000..09925d419b --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith3/Model.scala @@ -0,0 +1,32 @@ +import stainless.lang._ +import stainless.collection._ + +// Testing subfunction matching within subfunctions +// That is, we do not only try to match function appearing in top-level `eval` (mul and myMul) +// but also functions transitively appearing in mul and myMul +// Furthermore, Candidate mySub arguments are swapped +object Model { + // Top level + def eval(x: BigInt, y: BigInt): BigInt = mul(x, y) + + def mul(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) BigInt(0) + else if (x > 0) add(mul(x - 1, y), y) + else sub(mul(x + 1, y), y) + } + + def add(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) y + else if (x > 0) add(x - 1, y + 1) + else add(x + 1, y - 1) + } + + def sub(x: BigInt, y: BigInt): BigInt = { + decreases(if (x <= 0) -x else x) + if (x == 0) -y + else if (x > 0) sub(x - 1, y - 1) + else sub(x + 1, y + 1) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith3/expected_outcome.json b/frontends/benchmarks/equivalence/funnyarith3/expected_outcome.json new file mode 100644 index 0000000000..21ed1bb472 --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith3/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Model.eval", + "functions": [ + "Candidate.eval" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/funnyarith3/test_conf.json b/frontends/benchmarks/equivalence/funnyarith3/test_conf.json new file mode 100644 index 0000000000..6da1f221fc --- /dev/null +++ b/frontends/benchmarks/equivalence/funnyarith3/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Model.eval" + ], + "comparefuns": [ + "Candidate.eval" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/halfAlternation/expected_outcome.json b/frontends/benchmarks/equivalence/halfAlternation/expected_outcome.json new file mode 100644 index 0000000000..eaeabf88d2 --- /dev/null +++ b/frontends/benchmarks/equivalence/halfAlternation/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "HalfAlternation.h1", + "functions": [ + "HalfAlternation.h2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/halfAlternation.scala b/frontends/benchmarks/equivalence/halfAlternation/halfAlternation.scala similarity index 87% rename from frontends/benchmarks/equivalence/halfAlternation.scala rename to frontends/benchmarks/equivalence/halfAlternation/halfAlternation.scala index 778b3965e8..644a27e2ce 100644 --- a/frontends/benchmarks/equivalence/halfAlternation.scala +++ b/frontends/benchmarks/equivalence/halfAlternation/halfAlternation.scala @@ -22,8 +22,4 @@ object HalfAlternation { else if ((n % 2) == 0) h2(n-1) + h2(n-2) else h2(n-2) + h2(n-2) + h2(n-3) } - - @traceInduct("") - def check_h(n: BigInt): Unit = { - } ensuring(h1(n) == h2(n)) } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/halfAlternation/test_conf.json b/frontends/benchmarks/equivalence/halfAlternation/test_conf.json new file mode 100644 index 0000000000..98070c62f0 --- /dev/null +++ b/frontends/benchmarks/equivalence/halfAlternation/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "HalfAlternation.h1" + ], + "comparefuns": [ + "HalfAlternation.h2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/i1264/expected_outcome.json b/frontends/benchmarks/equivalence/i1264/expected_outcome.json new file mode 100644 index 0000000000..08b5982174 --- /dev/null +++ b/frontends/benchmarks/equivalence/i1264/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "i1264.replace", + "functions": [ + "i1264.slowReplace" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/i1264/i1264.scala b/frontends/benchmarks/equivalence/i1264/i1264.scala new file mode 100644 index 0000000000..7e8ab82d05 --- /dev/null +++ b/frontends/benchmarks/equivalence/i1264/i1264.scala @@ -0,0 +1,39 @@ +import stainless.collection._ +import stainless.lang._ +import stainless.annotation._ + +object i1264 { + + def split[T](l: List[T], x: T): List[List[T]] = { + decreases(l) + l match { + case Nil() => List[List[T]](List[T]()) + case Cons(y, ys) if x == y => + Nil[T]() :: split(ys, x) + case Cons(y, ys) => + val r = split(ys, x) + (y :: r.head) :: r.tail + } + } + + def join[T](ll: List[List[T]], l: List[T]): List[T] = { + decreases(ll) + ll match { + case Nil() => Nil[T]() + case Cons(l1, Nil()) => l1 + case Cons(l1, ls) => l1 ++ l ++ join(ls, l) + } + } + + def replace[T](l1: List[T], x: T, l2: List[T]): List[T] = { + decreases(l1) + l1 match { + case Nil() => Nil[T]() + case Cons(y, ys) if x == y => l2 ++ replace(ys, x, l2) + case Cons(y, ys) => y :: replace(ys, x, l2) + } + } + + def slowReplace[T](l1: List[T], x: T, l2: List[T]): List[T] = join(split(l1, x), l2) + +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/i1264/test_conf.json b/frontends/benchmarks/equivalence/i1264/test_conf.json new file mode 100644 index 0000000000..1516d6ab48 --- /dev/null +++ b/frontends/benchmarks/equivalence/i1264/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "i1264.replace" + ], + "comparefuns": [ + "i1264.slowReplace" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/inlining/expected_outcome.json b/frontends/benchmarks/equivalence/inlining/expected_outcome.json new file mode 100644 index 0000000000..7d25034f44 --- /dev/null +++ b/frontends/benchmarks/equivalence/inlining/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Inlining.inlining_1", + "functions": [ + "Inlining.inlining_2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/inlining.scala b/frontends/benchmarks/equivalence/inlining/inlining.scala similarity index 80% rename from frontends/benchmarks/equivalence/inlining.scala rename to frontends/benchmarks/equivalence/inlining/inlining.scala index b3523d3568..875e4aeef2 100644 --- a/frontends/benchmarks/equivalence/inlining.scala +++ b/frontends/benchmarks/equivalence/inlining/inlining.scala @@ -13,15 +13,10 @@ object Inlining { else if (x < 0) 0 else x } - + def inlining_2(x: BigInt): BigInt = { if (x > 1) inlining_2(x-2) + BigInt(2) else if (x < 0) 0 else x } - - @traceInduct("") - def check_inlining(n: BigInt): Unit = { - } ensuring(inlining_1(n) == inlining_2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/inlining/test_conf.json b/frontends/benchmarks/equivalence/inlining/test_conf.json new file mode 100644 index 0000000000..49ee99fb0d --- /dev/null +++ b/frontends/benchmarks/equivalence/inlining/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Inlining.inlining_1" + ], + "comparefuns": [ + "Inlining.inlining_2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/isSorted/expected_outcome.json b/frontends/benchmarks/equivalence/isSorted/expected_outcome.json new file mode 100644 index 0000000000..1e54a2eeb5 --- /dev/null +++ b/frontends/benchmarks/equivalence/isSorted/expected_outcome.json @@ -0,0 +1,21 @@ +{ + "equivalent": [ + { + "model": "IsSorted.isSortedR", + "functions": [ + "IsSorted.isSortedB" + ] + }, + { + "model": "IsSorted.isSortedB", + "functions": [ + "IsSorted.isSortedC" + ] + } + ], + "erroneous": [ + "IsSorted.isSortedA" + ], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/isSorted.scala b/frontends/benchmarks/equivalence/isSorted/isSorted.scala similarity index 92% rename from frontends/benchmarks/equivalence/isSorted.scala rename to frontends/benchmarks/equivalence/isSorted/isSorted.scala index d8811d013d..fc15e3b1fb 100644 --- a/frontends/benchmarks/equivalence/isSorted.scala +++ b/frontends/benchmarks/equivalence/isSorted/isSorted.scala @@ -1,7 +1,7 @@ import stainless.annotation._ import stainless.lang._ import stainless.collection._ -import stainless.proof._ +import stainless.proof._ object IsSorted { @@ -20,7 +20,7 @@ object IsSorted { def iter(l: List[Int]): Boolean = if (l.isEmpty) true else if (l.tail.isEmpty) true - else leq(l.head, l.tail.head) && iter(l.tail) + else leq(l.head, l.tail.head) && iter(l.tail) if (l.size < 2) true else l.head <= l.tail.head && iter(l.tail) } @@ -30,10 +30,10 @@ object IsSorted { true else if (!l.tail.isEmpty && l.head > l.tail.head) false - else + else isSortedB(l.tail) } - + def isSortedC(l: List[Int]): Boolean = { def chk(l: List[Int], p: Int, a: Boolean): Boolean = { if (l.isEmpty) a diff --git a/frontends/benchmarks/equivalence/isSorted/test_conf.json b/frontends/benchmarks/equivalence/isSorted/test_conf.json new file mode 100644 index 0000000000..63a98c2325 --- /dev/null +++ b/frontends/benchmarks/equivalence/isSorted/test_conf.json @@ -0,0 +1,10 @@ +{ + "models": [ + "IsSorted.isSortedR" + ], + "comparefuns": [ + "IsSorted.isSortedA", + "IsSorted.isSortedB", + "IsSorted.isSortedC" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven1/Candidate.scala b/frontends/benchmarks/equivalence/iseven1/Candidate.scala new file mode 100644 index 0000000000..5559dec3e6 --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven1/Candidate.scala @@ -0,0 +1,20 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate { + + def isEvenTopLvl(x: BigInt): Boolean = myIsEven(x) + + def myIsOdd(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x <= 0) false + else if (x == 1) true + else !myIsEven(x - 1) + } + def myIsEven(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x < 0) false + else if (x == 0) true + else !myIsOdd(x - 1) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven1/Model.scala b/frontends/benchmarks/equivalence/iseven1/Model.scala new file mode 100644 index 0000000000..5da26b3136 --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven1/Model.scala @@ -0,0 +1,20 @@ +import stainless.lang._ +import stainless.collection._ + +object Model { + + def isEvenTopLvl(x: BigInt): Boolean = isEven(x) + + def isOdd(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x <= 0) false + else if (x == 1) true + else !isEven(x - 1) + } + def isEven(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x < 0) false + else if (x == 0) true + else !isOdd(x - 1) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven1/expected_outcome.json b/frontends/benchmarks/equivalence/iseven1/expected_outcome.json new file mode 100644 index 0000000000..5e8542f5ca --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven1/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Model.isEvenTopLvl", + "functions": [ + "Candidate.isEvenTopLvl" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven1/test_conf.json b/frontends/benchmarks/equivalence/iseven1/test_conf.json new file mode 100644 index 0000000000..b7c388b38c --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven1/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Model.isEvenTopLvl" + ], + "comparefuns": [ + "Candidate.isEvenTopLvl" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven2/Candidate.scala b/frontends/benchmarks/equivalence/iseven2/Candidate.scala new file mode 100644 index 0000000000..1880ab627d --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven2/Candidate.scala @@ -0,0 +1,20 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate { + + def isEvenTopLvl(x: BigInt): Boolean = !myIsOdd(x) && myIsEven(x) // Note: swapped order to cause "pairs" to be mismatched + + def myIsOdd(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x <= 0) false + else if (x == 1) true + else !myIsEven(x - 1) + } + def myIsEven(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x < 0) false + else if (x == 0) true + else myIsEven(x - 2) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven2/Model.scala b/frontends/benchmarks/equivalence/iseven2/Model.scala new file mode 100644 index 0000000000..6ad7761ba7 --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven2/Model.scala @@ -0,0 +1,20 @@ +import stainless.lang._ +import stainless.collection._ + +object Model { + + def isEvenTopLvl(x: BigInt): Boolean = isEven(x) && !isOdd(x) // calls isEven and isOdd to force matching for both of them + + def isOdd(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x <= 0) false + else if (x == 1) true + else !isEven(x - 1) + } + def isEven(x: BigInt): Boolean = { + decreases(if (x <= 0) BigInt(0) else x) + if (x < 0) false + else if (x == 0) true + else isEven(x - 2) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven2/expected_outcome.json b/frontends/benchmarks/equivalence/iseven2/expected_outcome.json new file mode 100644 index 0000000000..5e8542f5ca --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven2/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Model.isEvenTopLvl", + "functions": [ + "Candidate.isEvenTopLvl" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/iseven2/test_conf.json b/frontends/benchmarks/equivalence/iseven2/test_conf.json new file mode 100644 index 0000000000..b7c388b38c --- /dev/null +++ b/frontends/benchmarks/equivalence/iseven2/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Model.isEvenTopLvl" + ], + "comparefuns": [ + "Candidate.isEvenTopLvl" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit1/expected_outcome.json b/frontends/benchmarks/equivalence/limit1/expected_outcome.json new file mode 100644 index 0000000000..1aa27295e5 --- /dev/null +++ b/frontends/benchmarks/equivalence/limit1/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Limit1.limit1_1", + "functions": [ + "Limit1.limit1_2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit1.scala b/frontends/benchmarks/equivalence/limit1/limit1.scala similarity index 81% rename from frontends/benchmarks/equivalence/limit1.scala rename to frontends/benchmarks/equivalence/limit1/limit1.scala index 4864276ae5..a20d2d4333 100644 --- a/frontends/benchmarks/equivalence/limit1.scala +++ b/frontends/benchmarks/equivalence/limit1/limit1.scala @@ -19,9 +19,4 @@ object Limit1 { if (n <= 1) n else n + n-1 + limit1_2(n-2) } - - @traceInduct("") - def check_limit1(n: BigInt): Unit = { - } ensuring(limit1_1(n) == limit1_2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit1/test_conf.json b/frontends/benchmarks/equivalence/limit1/test_conf.json new file mode 100644 index 0000000000..d171903c54 --- /dev/null +++ b/frontends/benchmarks/equivalence/limit1/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Limit1.limit1_1" + ], + "comparefuns": [ + "Limit1.limit1_2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit2/expected_outcome.json b/frontends/benchmarks/equivalence/limit2/expected_outcome.json new file mode 100644 index 0000000000..a567188ce3 --- /dev/null +++ b/frontends/benchmarks/equivalence/limit2/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Limit2.limit2_1", + "functions": [ + "Limit2.limit2_2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit2.scala b/frontends/benchmarks/equivalence/limit2/limit2.scala similarity index 79% rename from frontends/benchmarks/equivalence/limit2.scala rename to frontends/benchmarks/equivalence/limit2/limit2.scala index d0059add92..b37b7071f9 100644 --- a/frontends/benchmarks/equivalence/limit2.scala +++ b/frontends/benchmarks/equivalence/limit2/limit2.scala @@ -17,9 +17,4 @@ object Limit2 { if (n <= 1) n else n + limit2_2(n-1) } - - @traceInduct("") - def check_limit2(n: BigInt): Unit = { - } ensuring(limit2_1(n) == limit2_2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit2/test_conf.json b/frontends/benchmarks/equivalence/limit2/test_conf.json new file mode 100644 index 0000000000..d4b1e2d229 --- /dev/null +++ b/frontends/benchmarks/equivalence/limit2/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Limit2.limit2_1" + ], + "comparefuns": [ + "Limit2.limit2_2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit3/expected_outcome.json b/frontends/benchmarks/equivalence/limit3/expected_outcome.json new file mode 100644 index 0000000000..ca001d9974 --- /dev/null +++ b/frontends/benchmarks/equivalence/limit3/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Limit3.limit3_1", + "functions": [ + "Limit3.limit3_2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit3.scala b/frontends/benchmarks/equivalence/limit3/limit3.scala similarity index 81% rename from frontends/benchmarks/equivalence/limit3.scala rename to frontends/benchmarks/equivalence/limit3/limit3.scala index aa8df17eff..e7751a1e36 100644 --- a/frontends/benchmarks/equivalence/limit3.scala +++ b/frontends/benchmarks/equivalence/limit3/limit3.scala @@ -21,9 +21,4 @@ object Limit3 { else r } } - - @traceInduct("") - def check_limit3(n: BigInt): Unit = { - } ensuring(limit3_1(n) == limit3_2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/limit3/test_conf.json b/frontends/benchmarks/equivalence/limit3/test_conf.json new file mode 100644 index 0000000000..6f3db04488 --- /dev/null +++ b/frontends/benchmarks/equivalence/limit3/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Limit3.limit3_1" + ], + "comparefuns": [ + "Limit3.limit3_2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/max1/expected_outcome.json b/frontends/benchmarks/equivalence/max1/expected_outcome.json new file mode 100644 index 0000000000..0560430fdf --- /dev/null +++ b/frontends/benchmarks/equivalence/max1/expected_outcome.json @@ -0,0 +1,19 @@ +{ + "equivalent": [ + { + "model": "Max.maxR", + "functions": [ + "Max.maxC" + ] + }, + { + "model": "Max.maxC", + "functions": [ + "Max.maxT" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/max.scala b/frontends/benchmarks/equivalence/max1/max.scala similarity index 100% rename from frontends/benchmarks/equivalence/max.scala rename to frontends/benchmarks/equivalence/max1/max.scala diff --git a/frontends/benchmarks/equivalence/max1/test_conf.json b/frontends/benchmarks/equivalence/max1/test_conf.json new file mode 100644 index 0000000000..49a434521f --- /dev/null +++ b/frontends/benchmarks/equivalence/max1/test_conf.json @@ -0,0 +1,9 @@ +{ + "models": [ + "Max.maxR" + ], + "comparefuns": [ + "Max.maxC", + "Max.maxT" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/max2/Candidate1.scala b/frontends/benchmarks/equivalence/max2/Candidate1.scala new file mode 100644 index 0000000000..686ab23cbf --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/Candidate1.scala @@ -0,0 +1,13 @@ +import stainless.collection._ +import stainless.lang._ + +object Candidate1 { + def max(l: List[Int]): Int = { + decreases(l) + l match { + case Nil() => 42 + case Cons(hd, Nil()) => hd + case Cons(hd, tl) => if (hd > max(tl)) hd else max(tl) + } + } +} diff --git a/frontends/benchmarks/equivalence/max2/Candidate2.scala b/frontends/benchmarks/equivalence/max2/Candidate2.scala new file mode 100644 index 0000000000..1453b978a2 --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/Candidate2.scala @@ -0,0 +1,12 @@ +import stainless.collection._ +import stainless.lang._ + +object Candidate2 { + def max(l: List[Int]): Int = { + decreases(l) + l match { + case Nil() => Integer.MIN_VALUE + case Cons(h, t) => if (h > max(t)) h else max(t) + } + } +} diff --git a/frontends/benchmarks/equivalence/max2/Candidate3.scala b/frontends/benchmarks/equivalence/max2/Candidate3.scala new file mode 100644 index 0000000000..a50ce05a8c --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/Candidate3.scala @@ -0,0 +1,18 @@ +import stainless.collection._ +import stainless.lang._ + +object Candidate3 { + def max(l: List[Int]): Int = { + decreases(l) + l match { + case Nil() => Integer.MIN_VALUE + case Cons(hd, tl) => { + tl match { + case Nil() => hd + case Cons(hd1, tl1) => + if (hd > hd1) max(hd :: tl1) else max(hd1 :: tl1) + } + } + } + } +} diff --git a/frontends/benchmarks/equivalence/max2/Candidate4.scala b/frontends/benchmarks/equivalence/max2/Candidate4.scala new file mode 100644 index 0000000000..ce4cbee337 --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/Candidate4.scala @@ -0,0 +1,12 @@ +import stainless.collection._ +import stainless.lang._ + +object Candidate4 { + def max(l: List[Int]): Int = { + decreases(l) + l match { + case Nil() => 0 + case Cons(hd, tl) => if (hd > max(tl)) hd else max(tl) + } + } +} diff --git a/frontends/benchmarks/equivalence/max2/Candidate5.scala b/frontends/benchmarks/equivalence/max2/Candidate5.scala new file mode 100644 index 0000000000..fcc63899ce --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/Candidate5.scala @@ -0,0 +1,6 @@ +import stainless.collection._ +import stainless.lang._ + +object Candidate5 { + def max(l: List[Int]): Int = -1 +} diff --git a/frontends/benchmarks/equivalence/max2/Model.scala b/frontends/benchmarks/equivalence/max2/Model.scala new file mode 100644 index 0000000000..b2258db3a3 --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/Model.scala @@ -0,0 +1,19 @@ +import stainless.collection._ +import stainless.lang._ + +object Model { + + def max(lst: List[Int]): Int = { + decreases(lst) + lst match { + case Nil() => Integer.MIN_VALUE + case Cons(hd, Nil()) => hd + case Cons(hd, tl) => if (hd > max(tl)) hd else max(tl) + } + } + + def norm(l: List[Int], f: Int): Int = { + if (l.isEmpty) -1 + else f + } +} diff --git a/frontends/benchmarks/equivalence/max2/expected_outcome.json b/frontends/benchmarks/equivalence/max2/expected_outcome.json new file mode 100644 index 0000000000..b64cc40d27 --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/expected_outcome.json @@ -0,0 +1,18 @@ +{ + "equivalent": [ + { + "model": "Model.max", + "functions": [ + "Candidate1.max", + "Candidate2.max", + "Candidate3.max" + ] + } + ], + "erroneous": [ + "Candidate4.max", + "Candidate5.max" + ], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/max2/test_conf.json b/frontends/benchmarks/equivalence/max2/test_conf.json new file mode 100644 index 0000000000..bec47d18d7 --- /dev/null +++ b/frontends/benchmarks/equivalence/max2/test_conf.json @@ -0,0 +1,14 @@ +{ + "models": [ + "Model.max" + ], + "comparefuns": [ + "Candidate1.max", + "Candidate2.max", + "Candidate3.max", + "Candidate4.max", + "Candidate5.max" + ], + "tests": [], + "norm": "Model.norm" +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/max3/Candidate.scala b/frontends/benchmarks/equivalence/max3/Candidate.scala new file mode 100644 index 0000000000..d802691770 --- /dev/null +++ b/frontends/benchmarks/equivalence/max3/Candidate.scala @@ -0,0 +1,25 @@ +import stainless.collection._ +import stainless.lang._ + +object Candidate { + def fold(f: (Int, Int) => Int, l: List[Int], a: Int): Int = { + decreases(l) + l match { + case Nil() => a + case Cons(hd, tl) => f(hd, fold(f, tl, a)) + } + } + + def max(lst: List[Int]): Int = { + lst match { + case Nil() => choose((x: Int) => true) + case Cons(hd, tl) => + fold( + (x, y) => if (x > y) x else y, + lst, + hd + ) + } + } + +} diff --git a/frontends/benchmarks/equivalence/max3/Model.scala b/frontends/benchmarks/equivalence/max3/Model.scala new file mode 100644 index 0000000000..c94c014c0d --- /dev/null +++ b/frontends/benchmarks/equivalence/max3/Model.scala @@ -0,0 +1,33 @@ +import stainless.collection._ +import stainless.lang._ + +// This is not expected to verify (it should timeout) +// but here we ensure that the `choose` functions (created from the `choose((x: Int) => true)`) +// for the Model and the Candidate do not get matched because it would make the type-checker unhappy +// (because we would create `choose` expressions when doing the replacement). +object Model { + def fold(f: (Int, Int) => Int, l: List[Int], a: Int): Int = { + decreases(l) + l match { + case Nil() => a + case Cons(hd, tl) => f(hd, fold(f, tl, a)) + } + } + + def max(lst: List[Int]): Int = { + lst match { + case Nil() => choose((x: Int) => true) + case Cons(hd, tl) => + fold( + (x, y) => if (x > y) x else y, + lst, + hd + ) + } + } + + def norm(l: List[Int], f: Int): Int = { + if (l.isEmpty) -1 + else f + } +} diff --git a/frontends/benchmarks/equivalence/max3/expected_outcome.json b/frontends/benchmarks/equivalence/max3/expected_outcome.json new file mode 100644 index 0000000000..39a2312446 --- /dev/null +++ b/frontends/benchmarks/equivalence/max3/expected_outcome.json @@ -0,0 +1,6 @@ +{ + "equivalent": [], + "erroneous": [], + "timeout": ["Candidate.max"], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/max3/test_conf.json b/frontends/benchmarks/equivalence/max3/test_conf.json new file mode 100644 index 0000000000..992aa74a13 --- /dev/null +++ b/frontends/benchmarks/equivalence/max3/test_conf.json @@ -0,0 +1,10 @@ +{ + "models": [ + "Model.max" + ], + "comparefuns": [ + "Candidate.max" + ], + "tests": [], + "norm": "Model.norm" +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/pascal/expected_outcome.json b/frontends/benchmarks/equivalence/pascal/expected_outcome.json new file mode 100644 index 0000000000..f885d7b30d --- /dev/null +++ b/frontends/benchmarks/equivalence/pascal/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Pascal.p1", + "functions": [ + "Pascal.p2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/pascal.scala b/frontends/benchmarks/equivalence/pascal/pascal.scala similarity index 80% rename from frontends/benchmarks/equivalence/pascal.scala rename to frontends/benchmarks/equivalence/pascal/pascal.scala index e5f0a980f0..700d632628 100644 --- a/frontends/benchmarks/equivalence/pascal.scala +++ b/frontends/benchmarks/equivalence/pascal/pascal.scala @@ -20,11 +20,6 @@ object Pascal { def p2(n: BigInt, m: BigInt): BigInt = { if (m < 1 || n < 1 || m > n) 0 else if (m == 1 || n == 1 || m == n) 1 - else p2(n-1, m-1) + p2 (n-2 , m-1) + p2 (n-2 , m) + else p2(n-1, m-1) + p2 (n-2 , m-1) + p2 (n-2 , m) } - - @traceInduct("") - def check_p(n: BigInt, m: BigInt): Unit = { - } ensuring(p1(n, m) == p2(n, m)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/pascal/test_conf.json b/frontends/benchmarks/equivalence/pascal/test_conf.json new file mode 100644 index 0000000000..e04f32e076 --- /dev/null +++ b/frontends/benchmarks/equivalence/pascal/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Pascal.p1" + ], + "comparefuns": [ + "Pascal.p2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/sigma/Candidate.scala b/frontends/benchmarks/equivalence/sigma/Candidate.scala new file mode 100644 index 0000000000..5d21ee2c23 --- /dev/null +++ b/frontends/benchmarks/equivalence/sigma/Candidate.scala @@ -0,0 +1,23 @@ +import stainless.collection._ +import stainless.lang._ + + +object Candidate { + + def sigma(f: BigInt => BigInt, a: BigInt, b: BigInt): BigInt = { + decreases(if (b == a) BigInt(2) else if (b > a) 2 + b - a else a - b) + + def sigma_rec( + sum: BigInt, + i: BigInt, + b: BigInt, + f: BigInt => BigInt + ): BigInt = { + decreases(if (b == i) BigInt(2) else if (b > i) 2 + b - i else i - b) + if (i < b) sigma_rec(sum + f(i), i + BigInt(1), b, f) + else if (i == b) sum + f(i) + else BigInt(0) + } + if (a > b) BigInt(0) else sigma_rec(BigInt(0), a, b, f) + } +} diff --git a/frontends/benchmarks/equivalence/sigma/Model.scala b/frontends/benchmarks/equivalence/sigma/Model.scala new file mode 100644 index 0000000000..71a3016aba --- /dev/null +++ b/frontends/benchmarks/equivalence/sigma/Model.scala @@ -0,0 +1,17 @@ +import stainless.collection._ +import stainless.lang._ + +// This one is expected to timeout, but we want to test that permutation of arguments +// for auxiliary functions does not go wrong when doing a "model first" +// and "candidate first" induction strategy +object Model { + + def sigma(f: BigInt => BigInt, a: BigInt, b: BigInt): BigInt = { + def s(a: BigInt, b: BigInt, f: BigInt => BigInt, acc: BigInt): BigInt = { + decreases(if (b == a) BigInt(2) else if (b > a) 2 + b - a else a - b) + if (a > b) acc else s(a + BigInt(1), b, f, acc + f(a)) + } + + s(a, b, f, BigInt(0)) + } +} diff --git a/frontends/benchmarks/equivalence/sigma/expected_outcome.json b/frontends/benchmarks/equivalence/sigma/expected_outcome.json new file mode 100644 index 0000000000..1019f82529 --- /dev/null +++ b/frontends/benchmarks/equivalence/sigma/expected_outcome.json @@ -0,0 +1,6 @@ +{ + "equivalent": [], + "erroneous": [], + "timeout": ["Candidate.sigma"], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/sigma/test_conf.json b/frontends/benchmarks/equivalence/sigma/test_conf.json new file mode 100644 index 0000000000..d61aa43dd3 --- /dev/null +++ b/frontends/benchmarks/equivalence/sigma/test_conf.json @@ -0,0 +1,9 @@ +{ + "models": [ + "Model.sigma" + ], + "comparefuns": [ + "Candidate.sigma" + ], + "tests": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/sum/expected_outcome.json b/frontends/benchmarks/equivalence/sum/expected_outcome.json new file mode 100644 index 0000000000..d074ef2b07 --- /dev/null +++ b/frontends/benchmarks/equivalence/sum/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Sum.sum1", + "functions": [ + "Sum.sum2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/sum.scala b/frontends/benchmarks/equivalence/sum/sum.scala similarity index 71% rename from frontends/benchmarks/equivalence/sum.scala rename to frontends/benchmarks/equivalence/sum/sum.scala index 8cc315001c..d7d26c4ddf 100644 --- a/frontends/benchmarks/equivalence/sum.scala +++ b/frontends/benchmarks/equivalence/sum/sum.scala @@ -10,16 +10,11 @@ object Sum { // Fig. 5 - two functions are not in lock-step - def sum1(n: BigInt): BigInt = + def sum1(n: BigInt): BigInt = if (n <= 1) n else n + n-1 + sum1(n-2) - def sum2(n: BigInt): BigInt = + def sum2(n: BigInt): BigInt = if (n <= 1) n else n + sum2(n-1) - - @traceInduct("") - def check_sum(n: BigInt): Unit = { - } ensuring(sum1(n) == sum2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/sum/test_conf.json b/frontends/benchmarks/equivalence/sum/test_conf.json new file mode 100644 index 0000000000..d35ea3676a --- /dev/null +++ b/frontends/benchmarks/equivalence/sum/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Sum.sum1" + ], + "comparefuns": [ + "Sum.sum2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/Candidate1.scala b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate1.scala new file mode 100644 index 0000000000..4b1f0536dd --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate1.scala @@ -0,0 +1,31 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate1 { + + def unfoldingSorted[State, Elem](start: State, + next: State => Option[(State, Elem)], + leq: (Elem, Elem) => Boolean, + max: BigInt): List[Elem] = { + def insertSorted(t: Elem, xs: List[Elem]): List[Elem] = { + decreases(xs) + xs match { + case Nil() => Cons(t, Nil()) + case Cons(hd, tl) => + if (leq(t, hd)) t :: xs + else Cons(hd, insertSorted(t, tl)) + } + } + def go(s: State, xs: List[Elem], fuel: BigInt): List[Elem] = { + decreases(if (fuel <= 0) BigInt(0) else fuel) + if (fuel <= 0) xs + else next(s) match { + case Some((nxtS, t)) => + go(nxtS, insertSorted(t, xs), fuel - 1) + case None() => xs + } + } + + go(start, Nil(), max) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/Candidate2.scala b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate2.scala new file mode 100644 index 0000000000..f6f6ee15e6 --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate2.scala @@ -0,0 +1,31 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate2 { + + def unfoldingSorted[State, Elem](start: State, + next: State => Option[(Elem, State)], // oops, should be State, Elem not Elem, State + leq: (Elem, Elem) => Boolean, + max: BigInt): List[Elem] = { + def insertSorted(t: Elem, xs: List[Elem]): List[Elem] = { + decreases(xs) + xs match { + case Nil() => Cons(t, Nil()) + case Cons(hd, tl) => + if (leq(t, hd)) t :: xs + else Cons(hd, insertSorted(t, tl)) + } + } + def go(s: State, xs: List[Elem], fuel: BigInt): List[Elem] = { + decreases(if (fuel <= 0) BigInt(0) else fuel) + if (fuel <= 0) xs + else next(s) match { + case Some((t, nxtS)) => + go(nxtS, insertSorted(t, xs), fuel - 1) + case None() => xs + } + } + + go(start, Nil(), max) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/Candidate3.scala b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate3.scala new file mode 100644 index 0000000000..36a9f0ef45 --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate3.scala @@ -0,0 +1,30 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate3 { + + def unfoldingSorted[State, Elem](start: State, + next: State => Option[(State, Elem)], + leq: (Elem, Elem) => Boolean, + max: BigInt): List[Elem] = { + // Incorrect, this is an append + def insertSorted(t: Elem, xs: List[Elem]): List[Elem] = { + decreases(xs) + xs match { + case Nil() => Cons(t, Nil()) + case Cons(hd, tl) => Cons(hd, insertSorted(t, tl)) + } + } + def go(s: State, xs: List[Elem], fuel: BigInt): List[Elem] = { + decreases(if (fuel <= 0) BigInt(0) else fuel) + if (fuel <= 0) xs + else next(s) match { + case Some((nxtS, t)) => + go(nxtS, insertSorted(t, xs), fuel - 1) + case None() => xs + } + } + + go(start, Nil(), max) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/Candidate4.scala b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate4.scala new file mode 100644 index 0000000000..2d10b681d8 --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate4.scala @@ -0,0 +1,34 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate4 { + + def unfoldingSorted[State, Elem](start: State, + next: State => Option[(State, Elem)], + leq: (Elem, Elem) => Boolean, + max: BigInt): List[Elem] = { + def insertSorted(t: Elem, xs: List[Elem]): List[Elem] = { + decreases(xs) + xs match { + case Nil() => Cons(t, Nil()) + case Cons(hd, tl) => + if (leq(t, hd)) t :: xs + else Cons(hd, insertSorted(t, tl)) + } + } + def go(s: State, xs: List[Elem], fuel: BigInt): List[Elem] = { + decreases(if (fuel <= 0) BigInt(0) else fuel) + if (fuel <= 0) xs + else next(s) match { + case Some((nxtS, t)) => + go(nxtS, insertSorted(t, xs), fuel - 1) + case None() if xs.nonEmpty => + // Incorrect, should stop here + go(s, insertSorted(xs.head, xs), fuel - 1) + case None() => xs + } + } + + go(start, Nil(), max) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/Candidate5.scala b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate5.scala new file mode 100644 index 0000000000..6f08aed129 --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/Candidate5.scala @@ -0,0 +1,29 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate5 { + + def unfoldingSorted[State, Elem](start: State, + next: State => Option[(State, Elem)], + leq: (Elem, Elem) => Boolean, + max: BigInt): List[Elem] = { + def insertSorted(t: Elem, xs: List[Elem]): List[Elem] = { + decreases(xs) + xs match { + case Nil() => Cons(t, Nil()) + case Cons(hd, tl) => + if (leq(t, hd)) t :: xs + else Cons(hd, insertSorted(t, tl)) + } + } + def go(s: State, xs: List[Elem], fuel: BigInt): List[Elem] = { + decreases(if (fuel <= 0) BigInt(0) else fuel) + if (fuel <= 0) xs + else next(s).map { case (nxtS, t) => + go(nxtS, insertSorted(t, xs), fuel - 1) + }.getOrElse(xs) + } + + go(start, Nil(), max) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/Model.scala b/frontends/benchmarks/equivalence/unfoldingSorted/Model.scala new file mode 100644 index 0000000000..b4809f2ce3 --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/Model.scala @@ -0,0 +1,31 @@ +import stainless.lang._ +import stainless.collection._ + +object Model { + + def unfoldingSorted[S, T](start: S, + next: S => Option[(S, T)], + leq: (T, T) => Boolean, + max: BigInt): List[T] = { + def insert(xs: List[T], t: T): List[T] = { + decreases(xs) + xs match { + case Nil() => Cons(t, Nil()) + case Cons(hd, tl) => + if (leq(t, hd)) t :: xs + else Cons(hd, insert(tl, t)) + } + } + def loop(s: S, fuel: BigInt, xs: List[T]): List[T] = { + decreases(if (fuel <= 0) BigInt(0) else fuel) + if (fuel <= 0) xs + else next(s) match { + case Some((nxtS, t)) => + loop(nxtS, fuel - 1, insert(xs, t)) + case None() => xs + } + } + + loop(start, max, Nil()) + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/expected_outcome.json b/frontends/benchmarks/equivalence/unfoldingSorted/expected_outcome.json new file mode 100644 index 0000000000..750d9a497e --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/expected_outcome.json @@ -0,0 +1,19 @@ +{ + "equivalent": [ + { + "model": "Model.unfoldingSorted", + "functions": [ + "Candidate1.unfoldingSorted", + "Candidate5.unfoldingSorted" + ] + } + ], + "erroneous": [ + "Candidate3.unfoldingSorted", + "Candidate4.unfoldingSorted" + ], + "timeout": [], + "wrong": [ + "Candidate2.unfoldingSorted" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unfoldingSorted/test_conf.json b/frontends/benchmarks/equivalence/unfoldingSorted/test_conf.json new file mode 100644 index 0000000000..d6ad3658e0 --- /dev/null +++ b/frontends/benchmarks/equivalence/unfoldingSorted/test_conf.json @@ -0,0 +1,12 @@ +{ + "models": [ + "Model.unfoldingSorted" + ], + "comparefuns": [ + "Candidate1.unfoldingSorted", + "Candidate2.unfoldingSorted", + "Candidate3.unfoldingSorted", + "Candidate4.unfoldingSorted", + "Candidate5.unfoldingSorted" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/uniq1/expected_outcome.json b/frontends/benchmarks/equivalence/uniq1/expected_outcome.json new file mode 100644 index 0000000000..07ee4443cf --- /dev/null +++ b/frontends/benchmarks/equivalence/uniq1/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Uniq.uniqR", + "functions": [ + "Uniq.uniqA" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/uniq1/test_conf.json b/frontends/benchmarks/equivalence/uniq1/test_conf.json new file mode 100644 index 0000000000..cdb25e7bd8 --- /dev/null +++ b/frontends/benchmarks/equivalence/uniq1/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Uniq.uniqR" + ], + "comparefuns": [ + "Uniq.uniqA" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/uniq.scala b/frontends/benchmarks/equivalence/uniq1/uniq.scala similarity index 100% rename from frontends/benchmarks/equivalence/uniq.scala rename to frontends/benchmarks/equivalence/uniq1/uniq.scala diff --git a/frontends/benchmarks/equivalence/uniq2/Candidate1.scala b/frontends/benchmarks/equivalence/uniq2/Candidate1.scala new file mode 100644 index 0000000000..1200f436b3 --- /dev/null +++ b/frontends/benchmarks/equivalence/uniq2/Candidate1.scala @@ -0,0 +1,23 @@ +import stainless.lang._ +import stainless.collection._ + +object Candidate1 { + def check(element: Int, l: List[Int]): Boolean = { + decreases(l) + l match { + case Nil() => false + case Cons(hd, tl) => if (element == hd) true else check(element, tl) + } + } + + def app(l1: List[Int], l2: List[Int]): List[Int] = { + decreases(l1) + l1 match { + case Nil() => l2 + case Cons(hd, tl) => + if (check(hd, l2)) app(tl, l2) else app(tl, l2 ++ List(hd)) + } + } + + def uniq(lst: List[Int]): List[Int] = app(lst, Nil()) +} diff --git a/frontends/benchmarks/equivalence/uniq2/Candidate2.scala b/frontends/benchmarks/equivalence/uniq2/Candidate2.scala new file mode 100644 index 0000000000..8c71d85b5c --- /dev/null +++ b/frontends/benchmarks/equivalence/uniq2/Candidate2.scala @@ -0,0 +1,30 @@ +import stainless.collection._ +import stainless.lang._ +import stainless.annotation._ + +object Candidate2 { + def uniq(lst: List[Int]): List[Int] = { + decreases(lst.size) + lst match { + case Nil() => Nil() + case Cons(hd, tl) => + def drop(a: Int, lst_0: List[Int]): List[Int] = { + decreases(lst_0) + lst_0 match { + case Nil() => Nil() + case Cons(hd_0, tl_0) => + if (a == hd_0) drop(a, tl_0) else hd_0 :: drop(a, tl_0) + } + } + + def lem(a: Int, @induct lst: List[Int]): Unit = { + () + } ensuring(drop(a, lst).size <= lst.size) + + lem(hd, tl) + assert(drop(hd, tl).size <= tl.size) + assert(drop(hd, tl).size < lst.size) + hd :: uniq(drop(hd, tl)) + } + } +} diff --git a/frontends/benchmarks/equivalence/uniq2/Model.scala b/frontends/benchmarks/equivalence/uniq2/Model.scala new file mode 100644 index 0000000000..c3ec21a704 --- /dev/null +++ b/frontends/benchmarks/equivalence/uniq2/Model.scala @@ -0,0 +1,95 @@ +import stainless.collection._ +import stainless.annotation._ +import stainless.lang._ + +object Model { + + def remove_elem_1(e: Int, lst: List[Int]): List[Int] = { + decreases(lst) + lst match { + case Nil() => Nil() + case Cons(hd, tl) => + if (e == hd) remove_elem_1(e, tl) else hd :: remove_elem_1(e, tl) + } + } + + def solution_1(lst: List[Int]): List[Int] = { + decreases(lst) + lst match { + case Nil() => Nil() + case Cons(hd, tl) => hd :: remove_elem_1(hd, solution_1(tl)) + } + } + + def drop_2(lst: List[Int], n: Int): List[Int] = { + decreases(lst) + lst match { + case Nil() => Nil() + case Cons(hd, tl) => if (hd == n) drop_2(tl, n) else hd :: drop_2(tl, n) + } + } + + def lemma_2(n: Int, @induct lst: List[Int]): Unit = { + } ensuring(drop_2(lst, n).size <= lst.size) + + def solution_2(lst: List[Int]): List[Int] = { + decreases(lst.size) + + def lem(n: Int, @stainless.annotation.induct lst: List[Int]): Unit = { + () + } ensuring(drop_2(lst, n).size <= lst.size) + + lst match { + case Nil() => Nil() + case Cons(hd, tl) => + lem(hd, tl) + hd :: solution_2(drop_2(tl, hd)) + } + } + + def is_in_3(lst: List[Int], a: Int): Boolean = { + decreases(lst) + lst match { + case Nil() => false + case Cons(hd, tl) => if (a == hd) true else is_in_3(tl, a) + } + } + + def unique_3(lst1: List[Int], lst2: List[Int]): List[Int] = { + decreases(lst1) + lst1 match { + case Nil() => lst2 + case Cons(hd, tl) => + if (is_in_3(lst2, hd)) unique_3(tl, lst2) else unique_3(tl, lst2 ++ List[Int](hd)) + } + } + + def solution_3(lst: List[Int]): List[Int] = { unique_3(lst, Nil()) } + + def solution_4(lst: List[Int]): List[Int] = { + + def isNotIn_4(tlst: List[Int], c: Int): Boolean = { + decreases(tlst) + tlst match { + case Nil() => true + case Cons(hd, tl) => if (hd == c) false else true && isNotIn_4(tl, c) + } + } + + def uniqSave_4(l1: List[Int], l2: List[Int]): List[Int] = { + decreases(l1) + l1 match { + case Nil() => { l2 } + case Cons(hd, tl) => + if (isNotIn_4(l2, hd)) { + uniqSave_4(tl, l2 ++ List(hd)) + } else { + uniqSave_4(tl, l2) + } + } + } + uniqSave_4(lst, Nil()) + + } + +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/uniq2/expected_outcome.json b/frontends/benchmarks/equivalence/uniq2/expected_outcome.json new file mode 100644 index 0000000000..81bf461feb --- /dev/null +++ b/frontends/benchmarks/equivalence/uniq2/expected_outcome.json @@ -0,0 +1,19 @@ +{ + "equivalent": [ + { + "model": "Model.solution_2", + "functions": [ + "Candidate2.uniq" + ] + }, + { + "model": "Model.solution_3", + "functions": [ + "Candidate1.uniq" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/uniq2/test_conf.json b/frontends/benchmarks/equivalence/uniq2/test_conf.json new file mode 100644 index 0000000000..a3e7cf158a --- /dev/null +++ b/frontends/benchmarks/equivalence/uniq2/test_conf.json @@ -0,0 +1,12 @@ +{ + "models": [ + "Model.solution_1", + "Model.solution_2", + "Model.solution_3", + "Model.solution_4" + ], + "comparefuns": [ + "Candidate1.uniq", + "Candidate2.uniq" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unrolling/expected_outcome.json b/frontends/benchmarks/equivalence/unrolling/expected_outcome.json new file mode 100644 index 0000000000..f1fb7ed5a8 --- /dev/null +++ b/frontends/benchmarks/equivalence/unrolling/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "Unrolling.fact13_1", + "functions": [ + "Unrolling.fact13_2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unrolling/test_conf.json b/frontends/benchmarks/equivalence/unrolling/test_conf.json new file mode 100644 index 0000000000..9a74beef2b --- /dev/null +++ b/frontends/benchmarks/equivalence/unrolling/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "Unrolling.fact13_1" + ], + "comparefuns": [ + "Unrolling.fact13_2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unrolling.scala b/frontends/benchmarks/equivalence/unrolling/unrolling.scala similarity index 83% rename from frontends/benchmarks/equivalence/unrolling.scala rename to frontends/benchmarks/equivalence/unrolling/unrolling.scala index 0f7a1e069a..e8d69dcfbd 100644 --- a/frontends/benchmarks/equivalence/unrolling.scala +++ b/frontends/benchmarks/equivalence/unrolling/unrolling.scala @@ -11,14 +11,14 @@ object Unrolling { // Fig. 13 def fact13_1(n: BigInt): BigInt = - if (n <= 1) 1 + if (n <= 1) 1 else if (n == 2) 2 else if (n == 3) 6 else if (n == 4) 24 else n * (n-1) * (n-2) * (n-3) * fact13_1(n-4) def fact13_2(n: BigInt): BigInt = - if (n <= 1) 1 + if (n <= 1) 1 else if (n == 2) 2 else if (n == 3) 6 else if (n == 4) 24 @@ -28,8 +28,4 @@ object Unrolling { else if (n == 8) 40320 else n * (n-1) * (n-2) * (n-3) * fact13_2(n-4) - @traceInduct("") - def check_fact13(n: BigInt): Unit = { - } ensuring(fact13_1(n) == fact13_2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unused/expected_outcome.json b/frontends/benchmarks/equivalence/unused/expected_outcome.json new file mode 100644 index 0000000000..2141093655 --- /dev/null +++ b/frontends/benchmarks/equivalence/unused/expected_outcome.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "FibonacciUnused.t1", + "functions": [ + "FibonacciUnused.t2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unused/test_conf.json b/frontends/benchmarks/equivalence/unused/test_conf.json new file mode 100644 index 0000000000..b6f36e0f5e --- /dev/null +++ b/frontends/benchmarks/equivalence/unused/test_conf.json @@ -0,0 +1,8 @@ +{ + "models": [ + "FibonacciUnused.t1" + ], + "comparefuns": [ + "FibonacciUnused.t2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/unused.scala b/frontends/benchmarks/equivalence/unused/unused.scala similarity index 87% rename from frontends/benchmarks/equivalence/unused.scala rename to frontends/benchmarks/equivalence/unused/unused.scala index 89f36078c0..17ae51da01 100644 --- a/frontends/benchmarks/equivalence/unused.scala +++ b/frontends/benchmarks/equivalence/unused/unused.scala @@ -26,12 +26,7 @@ object FibonacciUnused { var r3 = t2(n-3) if (n % 2 == 0) results = r2 + r2 + r3 else results = r1 + r2 - results + results } } - - @traceInduct("") - def check_t(n: BigInt): Unit = { - } ensuring(t1(n) == t2(n)) - } \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/WhacAFun.scala b/frontends/benchmarks/equivalence/whac-a-fun/WhacAFun.scala new file mode 100644 index 0000000000..73bff04de4 --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/WhacAFun.scala @@ -0,0 +1,67 @@ +import stainless.lang._ +import stainless.collection._ +import stainless.annotation._ + +object WhacAFun { + def andThen1[A, B, C](f: A => B, g: B => C): A => C = a => g(f(a)) + def andThen2[A, B, C](ff: A => B, gg: B => C): A => C = aa => gg(ff(aa)) + + def compose1[A, B, C](f: B => C, g: A => B): A => C = a => f(g(a)) + def compose2[A, B, C](ff: B => C, gg: A => B): A => C = aa => ff(gg(aa)) + + def flip1[A, B, C](f: (A, B) => C): (B, A) => C = (b, a) => f(a, b) + def flip2[A, B, C](f: (A, B) => C): (B, A) => C = (b, a) => {val res = f(a, b); res } + + def curry1[A, B, C](f: (A, B) => C): A => B => C = a => b => f(a, b) + def curry2[A, B, C](f: (A, B) => C): A => B => C = aa => bb => { val res = f(aa, bb); res } + + def uncurry1[A, B, C](f: A => B => C): (A, B) => C = (a, b) => f(a)(b) + def uncurry2[A, B, C](f: A => B => C): (A, B) => C = (a, b) => { val res = f(a)(b); res } + + /* + // Times out + def rep1[A](n: BigInt)(f: A => A)(a: A) = { + require(n >= 0) + repeat1(n)(f)(a) + } + def rep2[A](n: BigInt)(f: A => A)(a: A) = { + require(n >= 0) + repeat2(n)(f)(a) + } + // Said to be non-equivalent, even though they are :( + def repeat1[A](n: BigInt)(f: A => A): A => A = { + require(n >= 0) + decreases(n) + a => { + if (n == 0) a + else repeat1(n - 1)(f)(f(a)) + } + } + def repeat2[A](n: BigInt)(f: A => A): A => A = { + require(n >= 0) + decreases(n) + if (n == 0) a => a + else a => repeat2(n - 1)(f)(f(a)) + } + */ + + def repeat1[A](n: BigInt)(f: A => A): A => A = { + require(n >= 0) + decreases(n) + a => { + if (n == 0) a + else repeat1(n - 1)(f)(f(a)) + } + } + def repeat2[A](n: BigInt)(f: A => A): A => A = { + require(n >= 0) + decreases(n) + a => { + if (n == 0) a + else { + val fa = f(a) + repeat1(n - 1)(f)(fa) + } + } + } +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_1.json b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_1.json new file mode 100644 index 0000000000..ecdf162b83 --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_1.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "WhacAFun.andThen1", + "functions": [ + "WhacAFun.andThen2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_2.json b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_2.json new file mode 100644 index 0000000000..89b3fea13b --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_2.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "WhacAFun.compose1", + "functions": [ + "WhacAFun.compose2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_3.json b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_3.json new file mode 100644 index 0000000000..319192909d --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_3.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "WhacAFun.flip1", + "functions": [ + "WhacAFun.flip2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_4.json b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_4.json new file mode 100644 index 0000000000..d5ddae44be --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_4.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "WhacAFun.curry1", + "functions": [ + "WhacAFun.curry2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_5.json b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_5.json new file mode 100644 index 0000000000..674e85d76d --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_5.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "WhacAFun.uncurry1", + "functions": [ + "WhacAFun.uncurry2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_6.json b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_6.json new file mode 100644 index 0000000000..74d7fb4e0d --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/expected_outcome_6.json @@ -0,0 +1,13 @@ +{ + "equivalent": [ + { + "model": "WhacAFun.repeat1", + "functions": [ + "WhacAFun.repeat2" + ] + } + ], + "erroneous": [], + "timeout": [], + "wrong": [] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/test_conf_1.json b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_1.json new file mode 100644 index 0000000000..27665da91a --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_1.json @@ -0,0 +1,8 @@ +{ + "models": [ + "WhacAFun.andThen1" + ], + "comparefuns": [ + "WhacAFun.andThen2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/test_conf_2.json b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_2.json new file mode 100644 index 0000000000..fc1b3981ae --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_2.json @@ -0,0 +1,8 @@ +{ + "models": [ + "WhacAFun.compose1" + ], + "comparefuns": [ + "WhacAFun.compose2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/test_conf_3.json b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_3.json new file mode 100644 index 0000000000..5fabba9d50 --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_3.json @@ -0,0 +1,8 @@ +{ + "models": [ + "WhacAFun.flip1" + ], + "comparefuns": [ + "WhacAFun.flip2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/test_conf_4.json b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_4.json new file mode 100644 index 0000000000..41df3cbaa4 --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_4.json @@ -0,0 +1,8 @@ +{ + "models": [ + "WhacAFun.curry1" + ], + "comparefuns": [ + "WhacAFun.curry2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/test_conf_5.json b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_5.json new file mode 100644 index 0000000000..eca2608a60 --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_5.json @@ -0,0 +1,8 @@ +{ + "models": [ + "WhacAFun.uncurry1" + ], + "comparefuns": [ + "WhacAFun.uncurry2" + ] +} \ No newline at end of file diff --git a/frontends/benchmarks/equivalence/whac-a-fun/test_conf_6.json b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_6.json new file mode 100644 index 0000000000..e7566fc19e --- /dev/null +++ b/frontends/benchmarks/equivalence/whac-a-fun/test_conf_6.json @@ -0,0 +1,8 @@ +{ + "models": [ + "WhacAFun.repeat1" + ], + "comparefuns": [ + "WhacAFun.repeat2" + ] +} \ No newline at end of file diff --git a/frontends/common/src/it/scala/stainless/equivchk/EquivChkSuite.scala b/frontends/common/src/it/scala/stainless/equivchk/EquivChkSuite.scala new file mode 100644 index 0000000000..77d4ac5d22 --- /dev/null +++ b/frontends/common/src/it/scala/stainless/equivchk/EquivChkSuite.scala @@ -0,0 +1,207 @@ +package stainless +package equivchk + +import org.scalatest.funsuite.AnyFunSuite +import stainless.utils.{CheckFilter, JsonUtils, YesNoOnly} +import stainless.verification.* +import extraction.xlang.{TreeSanitizer, trees as xt} +import inox.{OptionValue, TestSilentReporter} +import stainless.equivchk.EquivalenceCheckingReport.Status +import stainless.extraction.utils.DebugSymbols + +import java.io.File +import scala.concurrent.Await +import scala.concurrent.duration.* +import scala.util.{Failure, Success, Try} +import _root_.io.circe.JsonObject + +class EquivChkSuite extends ComponentTestSuite { + override val component: EquivalenceCheckingComponent.type = EquivalenceCheckingComponent + + override def configurations = super.configurations.map { seq => + Seq( + inox.optTimeout(3.seconds), + termination.optInferMeasures(true), + termination.optCheckMeasures(YesNoOnly.Yes), + ) ++ seq + } + + private val testConfPattern = "test_conf(_(\\d+))?.json".r + private val expectedOutcomePattern = "expected_outcome(_(\\d+))?.json".r + + override protected def optionsString(options: inox.Options): String = "" + + /////////////////////////////////////////////// + + for (benchmark <- getFolders("equivalence")) { + testEquiv(s"equivalence/$benchmark") + } + + /////////////////////////////////////////////// + + private def getFolders(dir: String): Seq[String] = { + Option(getClass.getResource(s"/$dir")).toSeq.flatMap { dirUrl => + val dirFile = new File(dirUrl.getPath) + Option(dirFile.listFiles().toSeq).getOrElse(Seq.empty).filter(_.isDirectory) + .map(_.getName) + }.sorted + } + + private def testEquiv(benchmarkDir: String): Unit = { + val files = resourceFiles(benchmarkDir, f => f.endsWith(".scala") || f.endsWith(".conf") || f.endsWith(".json")).sorted + if (files.isEmpty) return // Empty folder -- skip + + val scalaFiles = files.filter(_.getName.endsWith(".scala")) + + val confs = files.flatMap(f => f.getName match { + case testConfPattern(_, num) => Some(Option(num).map(_.toInt) -> f) + case _ => None + }).toMap + assert(confs.nonEmpty, s"No test_conf.json found in $benchmarkDir") + + val expectedOutcomes = files.flatMap(f => f.getName match { + case expectedOutcomePattern(_, num) => Some(Option(num).map(_.toInt) -> f) + case _ => None + }).toMap + assert(expectedOutcomes.nonEmpty, s"No expected_outcome.json found in $benchmarkDir") + assert(confs.keySet == expectedOutcomes.keySet, "Test configuration and expected outcome files do not match") + + val runs = confs.keySet.toSeq.sorted.map { num => + (num, confs(num), expectedOutcomes(num)) + } + + ///////////////////////////////////// + + val dummyCtx: inox.Context = inox.TestContext.empty + import dummyCtx.given + val (_, program) = loadFiles(scalaFiles.map(_.getPath)) + assert(dummyCtx.reporter.errorCount == 0, "There should be no error while loading the files") + + val userFiltering = new DebugSymbols { + val name = "UserFiltering" + val context = dummyCtx + val s: xt.type = xt + val t: xt.type = xt + } + + val programSymbols = userFiltering.debugWithoutSummary(frontend.UserFiltering().transform)(program.symbols)._1 + programSymbols.ensureWellFormed + + ///////////////////////////////////// + + for ((num, confFile, expectedOutomeFile) <- runs) { + val conf = parseTestConf(confFile) + val expected = parseExpectedOutcome(expectedOutomeFile) + + val testName = s"$benchmarkDir${num.map(n => s" (variant $n)").getOrElse("")}" + test(testName, ctx => filter(ctx, benchmarkDir)) { ctx0 ?=> + val opts = Seq( + optModels(conf.models.toSeq.sorted), + optCompareFuns(conf.candidates.toSeq.sorted), + optSilent(true)) ++ + conf.norm.map(optNorm.apply).toSeq ++ + conf.n.map(optN.apply).toSeq ++ + conf.initScore.map(optInitScore.apply).toSeq ++ + conf.maxPerm.map(optMaxPerm.apply).toSeq + val ids = programSymbols.functions.keySet.toSeq + val programSymbols2 = { + if (conf.tests.isEmpty) programSymbols + else { + val annotated = programSymbols.functions.view.filter { case (fn, _) => conf.tests(fn.fullName) } + .mapValues(fd => fd.copy(flags = fd.flags :+ xt.Annotation("mkTest", Seq.empty))) + programSymbols.copy(functions = programSymbols.functions ++ annotated) + } + } + // Uncomment the `.copy(...)` to print equiv. chk. output + val ctx = ctx0.withOpts(opts: _*)//.copy(reporter = new inox.DefaultReporter(Set(DebugSectionEquivChk))) + given inox.Context = ctx + val report = Await.result(component.run(extraction.pipeline).apply(ids, programSymbols2), Duration.Inf) + val got = extractResults(conf.candidates, report) + got shouldEqual expected + } + } + } + + private case class EquivResults(equiv: Map[String, Set[String]], + erroneous: Set[String], + timeout: Set[String], + wrong: Set[String]) + private object EquivResults { + def empty: EquivResults = + EquivResults(Map.empty[String, Set[String]], Set.empty[String], Set.empty[String], Set.empty[String]) + } + + private case class TestConf(models: Set[String], + candidates: Set[String], + tests: Set[String], + norm: Option[String], + n: Option[Int], + initScore: Option[Int], + maxPerm: Option[Int]) + + private def extractResults(candidates: Set[String], analysis: component.Analysis): EquivResults = { + import EquivalenceCheckingReport._ + analysis.records.foldLeft(EquivResults.empty) { + case (acc, record) => + val fn = record.id.fullName + if (candidates(fn)) { + record.status match { + case Status.Equivalence(EquivalenceStatus.Valid(mod, _, _)) => + val currCluster = acc.equiv.getOrElse(mod.fullName, Set.empty) + acc.copy(equiv = acc.equiv + (mod.fullName -> (currCluster + fn))) + case Status.Equivalence(EquivalenceStatus.Erroneous) => acc.copy(erroneous = acc.erroneous + fn) + case Status.Equivalence(EquivalenceStatus.Unknown) => acc.copy(timeout = acc.timeout + fn) + case Status.Equivalence(EquivalenceStatus.Wrong) => acc.copy(wrong = acc.wrong + fn) + case Status.Verification(_) => acc + } + } else acc + } + } + + private def parseExpectedOutcome(f: File): EquivResults = { + val json = JsonUtils.parseFile(f) + val jsonObj = json.asObject.getOrCry("Expected top-level json to be an object") + + val equivObj = jsonObj("equivalent").getOrCry("Expected 'equivalent' field") + .asArray.getOrCry("Expected 'equivalent' to be an array") + val equiv = equivObj.map { elem => + val elemObj = elem.asObject.getOrCry("Expected elements in 'equivalent' to be objects") + val model = elemObj("model").getOrCry("Expected a 'model' field in 'equivalent'") + .asString.getOrCry("Expected 'model' to be a string") + val fns = elemObj.getStringArrayOrCry("functions").toSet + model -> fns + }.toMap + val erroneous = jsonObj.getStringArrayOrCry("erroneous").toSet + val timeout = jsonObj.getStringArrayOrCry("timeout").toSet + val wrong = jsonObj.getStringArrayOrCry("wrong").toSet + EquivResults(equiv, erroneous, timeout, wrong) + } + + private def parseTestConf(f: File): TestConf = { + val json = JsonUtils.parseFile(f) + val jsonObj = json.asObject.getOrCry("Expected top-level json to be an object") + val models = jsonObj.getStringArrayOrCry("models") + val candidates = jsonObj.getStringArrayOrCry("comparefuns") + val tests = if (jsonObj.contains("tests")) jsonObj.getStringArrayOrCry("tests") else Seq.empty + val norm = jsonObj("norm").map(_.asString.getOrCry("Expected 'norm' to be a string")) + val n = jsonObj("n").map(_.asNumber.getOrCry("Expected 'n' to be an int").toInt.getOrCry("'n' is too large")) + val initScore = jsonObj("init-score").map(_.asNumber.getOrCry("Expected 'init-score' to be an int").toInt.getOrCry("'init-score' is too large")) + val maxPerm = jsonObj("max-perm").map(_.asNumber.getOrCry("Expected 'max-perm' to be an int").toInt.getOrCry("'max-perm' is too large")) + assert(models.nonEmpty, "At least one model must be specified") + assert(candidates.nonEmpty, "At least one candidate must be specified") + TestConf(models.toSet, candidates.toSet, tests.toSet, norm, n, initScore, maxPerm) + } + + extension[T] (o: Option[T]) { + private def getOrCry(msg: String): T = o.getOrElse(throw AssertionError(msg)) + } + extension (id: Identifier) { + private def fullName: String = CheckFilter.fixedFullName(id) + } + extension (jsonObj: JsonObject) { + private def getStringArrayOrCry(field: String): Seq[String] = + jsonObj(field).getOrCry(s"Expected a '$field' field") + .asArray.getOrCry(s"Expected '$field' to be an array of strings") + .map(_.asString.getOrCry(s"Expected '$field' array elements to be strings")) + } +}