Skip to content

Commit

Permalink
Fix epfl-lara#480 and add more info for missing dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev committed Apr 27, 2023
1 parent b6ccd2a commit e8b5971
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 47 deletions.
77 changes: 64 additions & 13 deletions core/src/main/scala/stainless/frontend/Recovery.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package stainless
package frontend

import utils.StringUtils
import extraction.xlang.{trees => xt}
import utils.{StringUtils, XLangDependenciesFinder}
import extraction.xlang.trees as xt
import inox.utils.Position
import stainless.ast.SymbolIdentifier
import stainless.extraction.oo.{DefinitionTraverser, Trees}
import stainless.utils.XLangDependenciesFinder.{DependencyInfo, IdentifierKind}

object DebugSectionRecovery extends inox.DebugSection("recovery")

Expand All @@ -14,6 +18,7 @@ object DebugSectionRecovery extends inox.DebugSection("recovery")
sealed abstract class RecoveryResult
object RecoveryResult {
final case class Success(symbols: xt.Symbols) extends RecoveryResult
// For each definition, the symbol we could not find as well as its occurrences
final case class Failure(failures: Seq[(xt.Definition, Set[Identifier])]) extends RecoveryResult
}

Expand Down Expand Up @@ -47,9 +52,9 @@ class Recovery(symbols: xt.Symbols)(using val context: inox.Context) {
)

strategy.recover(d, missings(d.id)) match {
case Left(errs) =>
case Left((d, missing)) =>
reporter.debug(" => FAIL")
Left(errs)
Left((d, missing))

case Right(result) =>
reporter.debug(" => SUCCESS")
Expand Down Expand Up @@ -86,7 +91,7 @@ object Recovery {
symbols.typeDefs.values
).toSeq

val missings = allDefs.toSeq.flatMap { defn =>
val missings = allDefs.flatMap { defn =>
val missingDeps = findMissingDeps(defn, symbols)
if (missingDeps.isEmpty) None
else Some(defn.id -> missingDeps)
Expand All @@ -96,16 +101,13 @@ object Recovery {
symbols
} else {
val recovery = new Recovery(symbols)
val recovered = recovery.recover(missings) match {
val recovered = recovery.recover(missings.view.mapValues(_.keySet).toMap) match {
case RecoveryResult.Success(recovered) =>
recovered

case RecoveryResult.Failure(errors) =>
errors foreach { case (definition, unknowns) =>
ctx.reporter.error(
s"${definition.id.uniqueName} depends on missing dependencies: " +
s"${unknowns map (_.uniqueName) mkString ", "}."
)
reportMissingDependencies(definition, unknowns, missings(definition.id))
}
ctx.reporter.fatalError(s"Cannot recover from missing dependencies")
}
Expand All @@ -114,15 +116,64 @@ object Recovery {
}
}

private def findMissingDeps(defn: xt.Definition, symbols: xt.Symbols): Set[Identifier] = {
val finder = new utils.XLangDependenciesFinder
private def findMissingDeps(defn: xt.Definition, symbols: xt.Symbols): Map[Identifier, DependencyInfo] = {
val finder = new XLangDependenciesFinder
val deps = finder(defn)
deps.filter { dep =>
deps.filter { case (dep, _) =>
!symbols.classes.contains(dep) &&
!symbols.functions.contains(dep) &&
!symbols.typeDefs.contains(dep)
}
}

private def reportMissingDependencies(definition: xt.Definition, unknowns: Set[Identifier], info: Map[Identifier, DependencyInfo])(using ctx: inox.Context): Unit = {
assert(unknowns.subsetOf(info.keySet))
ctx.reporter.error(definition.getPos, s"${definition.id.uniqueName} depends on missing dependencies:")
val missingInfo = info.filter(p => unknowns(p._1))
for ((unknId, info) <- missingInfo) {
ctx.reporter.error(s"Symbol ${unknId.uniqueName}")
for (pos <- info.positions) {
ctx.reporter.error(pos, "")
}
hintSymbolOrigin(unknId, info.kind)
}
}

private def hintSymbolOrigin(id: Identifier, kind: IdentifierKind)(using ctx: inox.Context): Unit = {
id match {
case s: SymbolIdentifier =>
val hd = s.symbol.path.headOption
val orig = {
if (hd.contains("scala")) Some("Scala")
else if (hd.contains("java")) Some("Java")
else None
}

orig match {
case Some(orig) =>
val kindStr = kind match {
case IdentifierKind.Class => "class"
case IdentifierKind.TypeDef => "type definition"
case IdentifierKind.MethodOrFunction => "method"
case IdentifierKind.TypeSelection => "type selection"
}
ctx.reporter.error(
s"""Hint: this $kindStr comes from the $orig standard library and is currently not supported.""".stripMargin)

if (kind == IdentifierKind.Class) {
ctx.reporter.error(
s"""Hint: to use this class, you may create a new class wrapping it in a field, annotated with @extern
| import stainless.annotation.extern
| class ${id.name}(@extern underlying: ${id.fullName}) {
| // ... methods
| }
|See https://epfl-lara.github.io/stainless/wrap.html#a-wrapper-for-triemap for more information.""".stripMargin)
}
case None => ()
}
case _ => ()
}
}
}

/** A strategy to recover a definition with missing dependencies */
Expand Down
57 changes: 41 additions & 16 deletions core/src/main/scala/stainless/utils/XLangDependenciesFinder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
package stainless
package utils

import extraction.xlang.{ trees => xt }
import extraction.xlang.trees as xt
import inox.utils.{NoPosition, Position}

import scala.collection.mutable.{ HashSet => MutableSet }
import scala.collection.mutable.Map as MutableMap

/**
* [[XLangDependenciesFinder]] find the set of dependencies for a function/class,
Expand All @@ -20,23 +21,29 @@ import scala.collection.mutable.{ HashSet => MutableSet }
* the class itself.
*/
class XLangDependenciesFinder {
import XLangDependenciesFinder.*

private val deps: MutableSet[Identifier] = MutableSet.empty
private val deps: MutableMap[Identifier, DependencyInfo] = MutableMap.empty

private abstract class TreeTraverser extends xt.ConcreteOOSelfTreeTraverser {
def traverse(lcd: xt.LocalClassDef): Unit
def traverse(lmd: xt.LocalMethodDef): Unit
def traverse(ltd: xt.LocalTypeDef): Unit
}

private def add(id: Identifier, kind: IdentifierKind, pos: Position): Unit = {
val curr = deps.getOrElseUpdate(id, DependencyInfo(kind, Seq.empty))
deps += id -> curr.copy(positions = curr.positions :+ pos)
}

private val finder = new TreeTraverser {
override def traverse(e: xt.Expr): Unit = e match {
case xt.FunctionInvocation(id, _, _) =>
deps += id
add(id, IdentifierKind.MethodOrFunction, e.getPos)
super.traverse(e)

case xt.MethodInvocation(_, id, _, _) =>
deps += id
add(id, IdentifierKind.MethodOrFunction, e.getPos)
super.traverse(e)

case xt.LetClass(lcds, body) =>
Expand All @@ -50,20 +57,28 @@ class XLangDependenciesFinder {
deps --= lcds.flatMap(_.typeMembers).map(_.id).toSet
}

case xt.ClassConstructor(ct, _) =>
add(ct.id, IdentifierKind.Class, e.getPos)
super.traverse(e)

case xt.LocalClassConstructor(ct, _) =>
add(ct.id, IdentifierKind.Class, e.getPos)
super.traverse(e)

case _ => super.traverse(e)
}

override def traverse(pat: xt.Pattern): Unit = pat match {
case xt.UnapplyPattern(_, _, id, _, _) =>
deps += id
add(id, IdentifierKind.MethodOrFunction, pat.getPos)
super.traverse(pat)

case _ => super.traverse(pat)
}

override def traverse(tpe: xt.Type): Unit = tpe match {
case xt.ClassType(id, _) =>
deps += id
add(id, IdentifierKind.Class, tpe.getPos)
super.traverse(tpe)

case xt.RefinementType(vd, pred) =>
Expand All @@ -72,15 +87,15 @@ class XLangDependenciesFinder {

case xt.TypeSelect(expr, id) =>
expr foreach traverse
deps += id
add(id, IdentifierKind.TypeSelection, tpe.getPos)
super.traverse(tpe)

case _ => super.traverse(tpe)
}

override def traverse(flag: xt.Flag): Unit = flag match {
case xt.IsMethodOf(id) =>
deps += id
add(id, IdentifierKind.Class, NoPosition)
super.traverse(flag)

case _ => super.traverse(flag)
Expand Down Expand Up @@ -108,35 +123,45 @@ class XLangDependenciesFinder {
traverse(ltd.toTypeDef)
}

def apply(defn: xt.Definition): Set[Identifier] = defn match {
def apply(defn: xt.Definition): Map[Identifier, DependencyInfo] = defn match {
case fd: xt.FunDef => apply(fd)
case cd: xt.ClassDef => apply(cd)
case td: xt.TypeDef => apply(td)
case _: xt.ADTSort => sys.error("There should be not sorts at this stage")
}

def apply(fd: xt.FunDef): Set[Identifier] = {
def apply(fd: xt.FunDef): Map[Identifier, DependencyInfo] = {
finder.traverse(fd)
deps -= fd.id
deps --= fd.params.map(_.id)

deps.toSet
deps.toMap
}

def apply(cd: xt.ClassDef): Set[Identifier] = {
def apply(cd: xt.ClassDef): Map[Identifier, DependencyInfo] = {
finder.traverse(cd)
deps -= cd.id
deps --= cd.fields.map(_.id)

deps.toSet
deps.toMap
}

def apply(td: xt.TypeDef): Set[Identifier] = {
def apply(td: xt.TypeDef): Map[Identifier, DependencyInfo] = {
td.tparams foreach finder.traverse
finder.traverse(td.rhs)
td.flags foreach finder.traverse
deps -= td.id

deps.toSet
deps.toMap
}
}
object XLangDependenciesFinder {
case class DependencyInfo(kind: IdentifierKind, positions: Seq[Position])

enum IdentifierKind {
case Class
case TypeDef
case MethodOrFunction
case TypeSelection
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler {
val allOrigPhases = super.phases
val extractionPhase = new ExtractionPhase
val scheduled = Plugins.schedule(allOrigPhases, List(extractionPhase))
// We only care about the phases preceding Stainless.
// We drop the rest as we are not interested in the full compilation pipeline
// (the whole pipeline is used for StainlessPlugin).
val necessary = scheduled.takeWhile(_.forall(_.phaseName != extractionPhase.phaseName))
// We also include init.Checker (which happens in the same mini-phase as FirstTransform, therefore not contained in `necessary`)
necessary :+ List(new init.Checker) :+ List(extractionPhase)
// We only care about the phases preceding Stainless *plus* some phases that are after Stainless,
// namely RefChecker, init.Checker and ForwardDepChecks.
// Note that the Stainless phase is only about extracting the Scala tree into Stainless tree,
// the actual processing is not done as a compiler phase but is done once the compiler finishes.
takeAllPhasesIncluding(scheduled, ForwardDepChecks.name)
}

private class ExtractionPhase extends PluginPhase {
Expand All @@ -37,23 +36,24 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler {
override val runsBefore = Set(FirstTransform.name)
// Note: this must not be instantiated within `run`, because we need the underlying `symbolMapping` in `StainlessExtraction`
// to be shared across multiple compilation unit.
val extraction = new StainlessExtraction(ctx)

override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = {
dottyCtx.reporter match {
case sr: SimpleReporter if sr.hasSafeInitWarnings =>
// Do not run the Stainless extraction phase by returning no compilation units
Nil
case _ =>
super.runOn(units)
}
}
private val extraction = new StainlessExtraction(ctx)

// This method id called for every compilation unit, and in the same thread.
override def run(using dottyCtx: DottyContext): Unit =
extraction.extractUnit.foreach(extracted =>
callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs))
}

// Pick all phases until `including` (with its group included)
private def takeAllPhasesIncluding(phases: List[List[Phase]], including: String): List[List[Phase]] = {
def rec(phases: List[List[Phase]], acc: List[List[Phase]]): List[List[Phase]] = phases match {
case Nil => acc.reverse // Should not happen, since we are interested in trimming the phases
case group :: rest =>
if (group.exists(_.phaseName == including)) (group :: acc).reverse
else rec(rest, group :: acc)
}
rec(phases, Nil)
}
}

private class DottyDriver(args: Seq[String], compiler: DottyCompiler, reporter: DottyReporter) extends Driver {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ class FragmentChecker(inoxCtx: inox.Context)(using override val dottyCtx: DottyC
if (!((sym is ParamAccessor) || (sym is CaseAccessor))) {
// Widening TermRef allows to check for method type parameter, ValDef type, etc.
// But do we not check type for ident, otherwise we will report each occurrence of the variable as erroneous
// which will rapidly overwhelm the console
// which will rapidly overwhelm the console.
tree match {
case _: tpd.Ident => return // Nothing to check further, so we return
case Typed(_ : tpd.Ident, _) =>
Expand Down Expand Up @@ -475,6 +475,7 @@ class FragmentChecker(inoxCtx: inox.Context)(using override val dottyCtx: DottyC
case Try(_, cases, finalizer) =>
if (cases.isEmpty && finalizer.isEmpty) reportError(tree.sourcePos, "try expressions are not supported in Stainless")
else if (cases.isEmpty && !finalizer.isEmpty) reportError(tree.sourcePos, "try-finally expressions are not supported in Stainless")
else if (finalizer.isEmpty) reportError(tree.sourcePos, "try-catch expressions are not supported in Stainless")
else reportError(tree.sourcePos, "try-catch-finally expressions are not supported in Stainless")
traverseChildren(tree)

Expand Down

0 comments on commit e8b5971

Please sign in to comment.