Skip to content

Commit

Permalink
simplification of select phase
Browse files Browse the repository at this point in the history
  • Loading branch information
rssh committed May 5, 2024
1 parent 36f2e3b commit 21eba36
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,58 +28,6 @@ class PhaseSelectAndGenerateShiftedMethods(selectedNodes: SelectedNodes) extends





def transformDefDefDisabled(tree: tpd.DefDef)(using Context): tpd.Tree = {

lazy val cpsTransformedAnnot = Symbols.requiredClass("cps.plugin.annotation.CpsTransformed")

if (tree.symbol.denot.is(Flags.Inline)) then
tree
else
val topTree = tree
val optKind = SelectedNodes.detectDefDefSelectKind(tree)
optKind.foreach{kind =>
tree.symbol.addAnnotation(cpsTransformedAnnot)
selectedNodes.addDefDef(tree.symbol,kind)
}
// TODO: try run this onlu on selected nodes
val childTraversor = new TreeTraverser {
override def traverse(tree: Tree)(using Context): Unit = {
tree match
case fun: DefDef if (fun.symbol != topTree.symbol) =>
selectedNodes.getDefDefRecord(tree.symbol) match
case Some(r) =>
if (!r.internal) {
selectedNodes.markAsInternal(tree.symbol)
traverseChildren(tree)
}
case None =>
traverseChildren(tree)
case Block(List(ddef:DefDef), closure:Closure) if ddef.symbol == closure.meth.symbol =>
traverseChildren(tree)
case Block(List(ddef:DefDef), Typed(closure:Closure, tp)) if ddef.symbol == closure.meth.symbol =>
traverseChildren(tree)
case Block(stats, expr) =>
// don't mark local function definitions and templates as internal
for(s <- stats) {
s match
case defDef: DefDef =>
//traverse(defDef)
case tdef: TypeDef =>
// do nothing
case other =>
traverse(other)
}
traverse(expr)
case _ =>
traverseChildren(tree)
}
}
childTraversor.traverse(tree)
tree
}

override def transformValDef(tree: tpd.ValDef)(using Context): tpd.Tree = {
tree.rhs match
case EmptyTree =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ trait RemoveScaffolding {
case Apply(fn, args) =>
fn.symbol.getAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed")) match
case Some(transformedAnnotation) =>
println(s"fn in closure has CpsTransformed annotation, ddef.rhs.tpe.widen=${ddef.rhs.tpe.widen.show}")
println(s"ddef.tpe.widen=${ddef.tpe.widen.show}, tree=${ddef.tpe.widen}")
ddef.tpe.widen match
case mt: MethodOrPoly =>
val nType = mt.derivedLambdaType(resType = ddef.rhs.tpe.widen)
Expand All @@ -58,13 +56,10 @@ trait RemoveScaffolding {
}
TransformUtil.substParamsMap(ddef.rhs, paramsMap)
})
println(s"ddef.symbol.hashCode=${ddef.symbol.hashCode()} nDdef.symbol.hashCode=${nDdef.symbol.hashCode()}")
println(s"fn=${fn}, args=${args}")
cpy.DefDef(tree)(rhs = Block(List(nDdef), Closure(env,ref(newDdefSymbol),tpe)).withSpan(treeBlock.span))
case _ =>
throw CpsTransformException("Assumed that ddef.tpe.widen is MethodOrPoly", ddef.srcPos)
case None =>
println(s"fn in closure has no annotation")
tree
case _ =>
tree
Expand Down Expand Up @@ -112,62 +107,9 @@ trait RemoveScaffolding {

override def transformApply(tree: Apply)(using ctx: Context): Tree = {

val runRetype = false

def retypeFn(fn: Tree) :Tree = {
fn match
case id: Ident =>
if (id.symbol.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed"))) then
println(s"fn has annotation: ${id.symbol.showFullName}")
else
println(s"fn has no annotation")
selectedNodes.getDefDefRecord(id.symbol) match
case Some(selectRecord) =>
println("fn in selectRecord")
case None =>
println(s"fn not in selectRecord, id=${id.show}, id.symbol=${id.symbol.showFullName}")
val retval = ref(id.symbol).withSpan(id.span) // here this will be symbol after phase CpsChangeSymbols
retval
//case sel: Select =>
// val retval = Select(sel.qualifier,sel.name).withSpan(sel.span)
// retval
case _ =>
fn
}

tree match
case Scaffolding.Cpsed(cpsedCall) =>
if (runRetype) then
val cpsedCallRetyped = cpsedCall match
case Apply(fn, args) =>
val retval =
try
val fnRetyped = retypeFn(fn)
Apply(fnRetyped, args).withSpan(cpsedCall.span)
catch
case ex: Throwable =>
println(s"RemoveScaffolding error: fn=${fn.show}, args=${args.map(_.show).mkString(",")}")
throw ex
retval
case _ =>
cpsedCall
cpsedCallRetyped
else
cpsedCall
case Apply(fn, args) =>
if (fn.symbol.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed"))) then
println(s"RemoveScaffolding::Apply, ${tree.show} fn has CpsTransformed annotation: ${fn.symbol.showFullName}")
println(s"fn.tpe.widen=${fn.tpe.widen.show}")
println(s"tree.tpe.widen=${tree.tpe.widen.show}")
tree
else
tree
// selectedNodes.getDefDefRecord(fn.symbol) match
// case Some(selectRecord) =>
// println(s"RemoveScaffolding: foudn apply with selectRecord, tree.tpe=${tree.tpe.show}, ")
// ???
// case None =>
// tree
case Scaffolding.Cpsed(cpsedCall) => cpsedCall
case _ =>
tree
}
Expand All @@ -181,17 +123,6 @@ trait RemoveScaffolding {
tree
}

override def transformSelect(tree: Select)(using Context): Tree = {
if (tree.symbol.hasAnnotation(Symbols.requiredClass("cps.plugin.annotation.CpsTransformed"))) then
println(s"RemoveScaffolding::Select, ${tree.show} has CpsTransformed annotation: ${tree.symbol.showFullName}")
println(s"sel.tpe.widen=${tree.tpe.widen.show}, sel.symbol.info.widen=${tree.symbol.info.widen}")
println(s"sel.qualifier.tpe.widen=${tree.qualifier.tpe.widen.show}, ${tree.tpe.show}")
println(s"sel.qualifier.symbol.infos=${tree.qualifier.symbol.info.show}")
tree
else
tree
}


def retrieveReturnType(ddefType: Type)(using Context): Type = {
ddefType match
Expand Down

0 comments on commit 21eba36

Please sign in to comment.