Skip to content

Commit 2746428

Browse files
authored
Code refactoring of initialization checker (#16066)
Code refactoring of initialization checker
2 parents 2844c2b + 9050560 commit 2746428

File tree

2 files changed

+100
-89
lines changed

2 files changed

+100
-89
lines changed

compiler/src/dotty/tools/dotc/transform/init/Checker.scala

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@ import StdNames._
1515
import dotty.tools.dotc.transform._
1616
import Phases._
1717

18+
import scala.collection.mutable
1819

1920
import Semantic._
2021

21-
class Checker extends Phase {
22+
class Checker extends Phase:
2223

2324
override def phaseName: String = Checker.name
2425

@@ -31,17 +32,23 @@ class Checker extends Phase {
3132

3233
override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] =
3334
val checkCtx = ctx.fresh.setPhase(this.start)
34-
Semantic.checkTasks(using checkCtx) {
35-
val traverser = new InitTreeTraverser()
36-
units.foreach { unit => traverser.traverse(unit.tpdTree) }
37-
}
35+
val traverser = new InitTreeTraverser()
36+
units.foreach { unit => traverser.traverse(unit.tpdTree) }
37+
val classes = traverser.getClasses()
38+
39+
Semantic.checkClasses(classes)(using checkCtx)
40+
3841
units
3942

40-
def run(using Context): Unit = {
43+
def run(using Context): Unit =
4144
// ignore, we already called `Semantic.check()` in `runOn`
42-
}
45+
()
46+
47+
class InitTreeTraverser extends TreeTraverser:
48+
private val classes: mutable.ArrayBuffer[ClassSymbol] = new mutable.ArrayBuffer
49+
50+
def getClasses(): List[ClassSymbol] = classes.toList
4351

44-
class InitTreeTraverser(using WorkList) extends TreeTraverser {
4552
override def traverse(tree: Tree)(using Context): Unit =
4653
traverseChildren(tree)
4754
tree match {
@@ -53,29 +60,12 @@ class Checker extends Phase {
5360
mdef match
5461
case tdef: TypeDef if tdef.isClassDef =>
5562
val cls = tdef.symbol.asClass
56-
val thisRef = ThisRef(cls)
57-
if shouldCheckClass(cls) then Semantic.addTask(thisRef)
63+
classes.append(cls)
5864
case _ =>
5965

6066
case _ =>
6167
}
62-
}
63-
64-
private def shouldCheckClass(cls: ClassSymbol)(using Context) = {
65-
val instantiable: Boolean =
66-
cls.is(Flags.Module) ||
67-
!cls.isOneOf(Flags.AbstractOrTrait) && {
68-
// see `Checking.checkInstantiable` in typer
69-
val tp = cls.appliedRef
70-
val stp = SkolemType(tp)
71-
val selfType = cls.givenSelfType.asSeenFrom(stp, cls)
72-
!selfType.exists || stp <:< selfType
73-
}
74-
75-
// A concrete class may not be instantiated if the self type is not satisfied
76-
instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass
77-
}
78-
}
68+
end InitTreeTraverser
7969

8070
object Checker:
8171
val name: String = "initChecker"

compiler/src/dotty/tools/dotc/transform/init/Semantic.scala

Lines changed: 83 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,72 +1206,49 @@ object Semantic:
12061206
cls == defn.AnyValClass ||
12071207
cls == defn.ObjectClass
12081208

1209-
// ----- Work list ---------------------------------------------------
1210-
case class Task(value: ThisRef)
1211-
1212-
class WorkList private[Semantic]():
1213-
private val pendingTasks: mutable.ArrayBuffer[Task] = new mutable.ArrayBuffer
1214-
1215-
def addTask(task: Task): Unit =
1216-
if !pendingTasks.contains(task) then pendingTasks.append(task)
1217-
1218-
/** Process the worklist until done */
1219-
final def work()(using Cache, Context): Unit =
1220-
for task <- pendingTasks
1221-
do doTask(task)
1222-
1223-
/** Check an individual class
1224-
*
1225-
* This method should only be called from the work list scheduler.
1226-
*/
1227-
private def doTask(task: Task)(using Cache, Context): Unit =
1228-
val thisRef = task.value
1229-
val tpl = thisRef.klass.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]
1230-
1231-
@tailrec
1232-
def iterate(): Unit = {
1233-
given Promoted = Promoted.empty(thisRef.klass)
1234-
given Trace = Trace.empty.add(thisRef.klass.defTree)
1235-
given reporter: Reporter.BufferedReporter = new Reporter.BufferedReporter
1209+
// ----- API --------------------------------
12361210

1237-
thisRef.ensureFresh()
1211+
/** Check an individual class
1212+
*
1213+
* The class to be checked must be an instantiable concrete class.
1214+
*/
1215+
private def checkClass(classSym: ClassSymbol)(using Cache, Context): Unit =
1216+
val thisRef = ThisRef(classSym)
1217+
val tpl = classSym.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]
12381218

1239-
// set up constructor parameters
1240-
for param <- tpl.constr.termParamss.flatten do
1241-
thisRef.updateField(param.symbol, Hot)
1219+
@tailrec
1220+
def iterate(): Unit = {
1221+
given Promoted = Promoted.empty(classSym)
1222+
given Trace = Trace.empty.add(classSym.defTree)
1223+
given reporter: Reporter.BufferedReporter = new Reporter.BufferedReporter
12421224

1243-
log("checking " + task) { eval(tpl, thisRef, thisRef.klass) }
1244-
reporter.errors.foreach(_.issue)
1225+
thisRef.ensureFresh()
12451226

1246-
if cache.hasChanged && reporter.errors.isEmpty then
1247-
// code to prepare cache and heap for next iteration
1248-
cache.prepareForNextIteration()
1249-
iterate()
1250-
else
1251-
cache.prepareForNextClass()
1252-
}
1227+
// set up constructor parameters
1228+
for param <- tpl.constr.termParamss.flatten do
1229+
thisRef.updateField(param.symbol, Hot)
12531230

1254-
iterate()
1255-
end doTask
1256-
end WorkList
1257-
inline def workList(using wl: WorkList): WorkList = wl
1231+
log("checking " + classSym) { eval(tpl, thisRef, classSym) }
1232+
reporter.errors.foreach(_.issue)
12581233

1259-
// ----- API --------------------------------
1234+
if cache.hasChanged && reporter.errors.isEmpty then
1235+
// code to prepare cache and heap for next iteration
1236+
cache.prepareForNextIteration()
1237+
iterate()
1238+
else
1239+
cache.prepareForNextClass()
1240+
}
12601241

1261-
/** Add a checking task to the work list */
1262-
def addTask(thisRef: ThisRef)(using WorkList) = workList.addTask(Task(thisRef))
1242+
iterate()
1243+
end checkClass
12631244

1264-
/** Check the specified tasks
1265-
*
1266-
* Semantic.checkTasks {
1267-
* Semantic.addTask(...)
1268-
* }
1245+
/**
1246+
* Check the specified concrete classes
12691247
*/
1270-
def checkTasks(using Context)(taskBuilder: WorkList ?=> Unit): Unit =
1271-
val workList = new WorkList
1272-
val cache = new Cache
1273-
taskBuilder(using workList)
1274-
workList.work()(using cache, ctx)
1248+
def checkClasses(classes: List[ClassSymbol])(using Context): Unit =
1249+
given Cache()
1250+
for classSym <- classes if isConcreteClass(classSym) do
1251+
checkClass(classSym)
12751252

12761253
// ----- Semantic definition --------------------------------
12771254

@@ -1296,7 +1273,10 @@ object Semantic:
12961273
*
12971274
* This method only handles cache logic and delegates the work to `cases`.
12981275
*
1299-
* The parameter `cacheResult` is used to reduce the size of the cache.
1276+
* @param expr The expression to be evaluated.
1277+
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
1278+
* @param klass The enclosing class where the expression is located.
1279+
* @param cacheResult It is used to reduce the size of the cache.
13001280
*/
13011281
def eval(expr: Tree, thisV: Ref, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) {
13021282
cache.get(thisV, expr) match
@@ -1326,6 +1306,10 @@ object Semantic:
13261306
/** Handles the evaluation of different expressions
13271307
*
13281308
* Note: Recursive call should go to `eval` instead of `cases`.
1309+
*
1310+
* @param expr The expression to be evaluated.
1311+
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
1312+
* @param klass The enclosing class where the expression `expr` is located.
13291313
*/
13301314
def cases(expr: Tree, thisV: Ref, klass: ClassSymbol): Contextual[Value] =
13311315
val trace2 = trace.add(expr)
@@ -1503,7 +1487,14 @@ object Semantic:
15031487
report.error("[Internal error] unexpected tree" + Trace.show, expr)
15041488
Hot
15051489

1506-
/** Handle semantics of leaf nodes */
1490+
/** Handle semantics of leaf nodes
1491+
*
1492+
* For leaf nodes, their semantics is determined by their types.
1493+
*
1494+
* @param tp The type to be evaluated.
1495+
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
1496+
* @param klass The enclosing class where the type `tp` is located.
1497+
*/
15071498
def cases(tp: Type, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("evaluating " + tp.show, printer, (_: Value).show) {
15081499
tp match
15091500
case _: ConstantType =>
@@ -1541,7 +1532,12 @@ object Semantic:
15411532
Hot
15421533
}
15431534

1544-
/** Resolve C.this that appear in `klass` */
1535+
/** Resolve C.this that appear in `klass`
1536+
*
1537+
* @param target The class symbol for `C` for which `C.this` is to be resolved.
1538+
* @param thisV The value for `D.this` where `D` is represented by the parameter `klass`.
1539+
* @param klass The enclosing class where the type `C.this` is located.
1540+
*/
15451541
def resolveThis(target: ClassSymbol, thisV: Value, klass: ClassSymbol): Contextual[Value] = log("resolving " + target.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) {
15461542
if target == klass then thisV
15471543
else if target.is(Flags.Package) then Hot
@@ -1566,7 +1562,12 @@ object Semantic:
15661562

15671563
}
15681564

1569-
/** Compute the outer value that correspond to `tref.prefix` */
1565+
/** Compute the outer value that correspond to `tref.prefix`
1566+
*
1567+
* @param tref The type whose prefix is to be evaluated.
1568+
* @param thisV The value for `C.this` where `C` is represented by the parameter `klass`.
1569+
* @param klass The enclosing class where the type `tref` is located.
1570+
*/
15701571
def outerValue(tref: TypeRef, thisV: Ref, klass: ClassSymbol): Contextual[Value] =
15711572
val cls = tref.classSymbol.asClass
15721573
if tref.prefix == NoPrefix then
@@ -1577,7 +1578,12 @@ object Semantic:
15771578
if cls.isAllOf(Flags.JavaInterface) then Hot
15781579
else cases(tref.prefix, thisV, klass)
15791580

1580-
/** Initialize part of an abstract object in `klass` of the inheritance chain */
1581+
/** Initialize part of an abstract object in `klass` of the inheritance chain
1582+
*
1583+
* @param tpl The class body to be evaluated.
1584+
* @param thisV The value of the current object to be initialized.
1585+
* @param klass The class to which the template belongs.
1586+
*/
15811587
def init(tpl: Template, thisV: Ref, klass: ClassSymbol): Contextual[Value] = log("init " + klass.show, printer, (_: Value).show) {
15821588
val paramsMap = tpl.constr.termParamss.flatten.map { vdef =>
15831589
vdef.name -> thisV.objekt.field(vdef.symbol)
@@ -1782,3 +1788,18 @@ object Semantic:
17821788
if (sym.isEffectivelyFinal || sym.isConstructor) sym
17831789
else sym.matchingMember(cls.appliedRef)
17841790
}
1791+
1792+
private def isConcreteClass(cls: ClassSymbol)(using Context) = {
1793+
val instantiable: Boolean =
1794+
cls.is(Flags.Module) ||
1795+
!cls.isOneOf(Flags.AbstractOrTrait) && {
1796+
// see `Checking.checkInstantiable` in typer
1797+
val tp = cls.appliedRef
1798+
val stp = SkolemType(tp)
1799+
val selfType = cls.givenSelfType.asSeenFrom(stp, cls)
1800+
!selfType.exists || stp <:< selfType
1801+
}
1802+
1803+
// A concrete class may not be instantiated if the self type is not satisfied
1804+
instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass
1805+
}

0 commit comments

Comments
 (0)