Skip to content

Commit

Permalink
[internal] scala: extract annotations as consumed types (#13810)
Browse files Browse the repository at this point in the history
Extract annotations as consumed types for dependency inference.

Closes #13751.
  • Loading branch information
Tom Dyas authored Dec 6, 2021
1 parent 3442d08 commit 4c871ed
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,43 +163,56 @@ class SourceAnalysisTraverser extends Traverser {
})
}

def visitMods(mods: List[Mod]): Unit = {
mods.foreach({
case Mod.Annot(init) => apply(init) // rely on `Init` extraction in main parsing match code
case _ => ()
})
}

override def apply(tree: Tree): Unit = tree match {
case Pkg(ref, stats) => {
val name = extractName(ref)
recordScope(name)
withNamePart(name, () => super.apply(stats))
}

case Pkg.Object(_mods, nameNode, templ) => {
case Pkg.Object(mods, nameNode, templ) => {
visitMods(mods)
val name = extractName(nameNode)
recordScope(name)
visitTemplate(templ, name)
}

case Defn.Class(_mods, nameNode, _tparams, _ctor, templ) => {
case Defn.Class(mods, nameNode, _tparams, _ctor, templ) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name, sawClass = true)
visitTemplate(templ, name)
}

case Defn.Trait(_mods, nameNode, _tparams, _ctor, templ) => {
case Defn.Trait(mods, nameNode, _tparams, _ctor, templ) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name, sawTrait = true)
visitTemplate(templ, name)
}

case Defn.Object(_mods, nameNode, templ) => {
case Defn.Object(mods, nameNode, templ) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name, sawObject = true)
visitTemplate(templ, name)
}

case Defn.Type(_mods, nameNode, _tparams, _body) => {
case Defn.Type(mods, nameNode, _tparams, _body) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name)
}

case Defn.Val(_mods, pats, decltpe, rhs) => {
case Defn.Val(mods, pats, decltpe, rhs) => {
visitMods(mods)
pats.headOption.foreach(pat => {
val name = extractName(pat)
recordProvidedName(name)
Expand All @@ -210,7 +223,8 @@ class SourceAnalysisTraverser extends Traverser {
super.apply(rhs)
}

case Defn.Var(_mods, pats, decltpe, rhs) => {
case Defn.Var(mods, pats, decltpe, rhs) => {
visitMods(mods)
pats.headOption.foreach(pat => {
val name = extractName(pat)
recordProvidedName(name)
Expand All @@ -221,7 +235,8 @@ class SourceAnalysisTraverser extends Traverser {
super.apply(rhs)
}

case Defn.Def(_mods, nameNode, _tparams, params, decltpe, body) => {
case Defn.Def(mods, nameNode, _tparams, params, decltpe, body) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name)

Expand Down Expand Up @@ -263,19 +278,22 @@ class SourceAnalysisTraverser extends Traverser {
extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_))
}

case Term.Param(_mods, _name, decltpe, _default) => {
case Term.Param(mods, _name, decltpe, _default) => {
visitMods(mods)
decltpe.foreach(tpe => {
extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_))
})
}

case Ctor.Primary(_mods, _name, params_list) => {
case Ctor.Primary(mods, _name, params_list) => {
visitMods(mods)
params_list.foreach(params => {
params.foreach(param => apply(param))
})
}

case Ctor.Secondary(_mods, _name, params_list, init, stats) => {
case Ctor.Secondary(mods, _name, params_list, init, stats) => {
visitMods(mods)
params_list.foreach(params => {
params.foreach(param => apply(param))
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,39 @@ def test_package_object(rule_runner: RuleRunner) -> None:
),
)
assert sorted(analysis.provided_symbols) == ["foo.bar.Hello"]


def test_extract_annotations(rule_runner: RuleRunner) -> None:
analysis = _analyze(
rule_runner,
textwrap.dedent(
"""
package foo
@objectAnnotation("hello")
object Object {
@deprecated
def foo(arg: String @argAnnotation("foo")): Unit = {}
}
@classAnnotation("world")
class Class {
@valAnnotation val foo = 3
@varAnnotation var bar = 4
}
@traitAnnotation
trait Trait {}
"""
),
)
assert sorted(analysis.fully_qualified_consumed_symbols()) == [
"foo.String",
"foo.Unit",
"foo.classAnnotation",
"foo.deprecated",
"foo.objectAnnotation",
"foo.traitAnnotation",
"foo.valAnnotation",
"foo.varAnnotation",
]

0 comments on commit 4c871ed

Please sign in to comment.