Skip to content

Commit

Permalink
Scala parser improvements (#15839)
Browse files Browse the repository at this point in the history
This PR includes:
1. Multiple improvements for scala parser (visit trait and class constructor, visit the right-hand side of a type definition, visit declarations.
2. Fix bug with scala parser visiting qualifier using the parent class `apply` vs and not the `ScalaParser.apply`, caused it to miss some consumed types.
3.  Fixed a bug with the extracted name of val type.
4. Add an option to scala parser subsystem to automatically add the package "package object" as a dependency.
5. Add an option to scala subsystem to enable/disable `add_dependencies_on_all_siblings`
  • Loading branch information
somdoron authored Jun 17, 2022
1 parent 960d401 commit 2b93e45
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boole
class SourceAnalysisTraverser extends Traverser {
val nameParts = ArrayBuffer[String]()
var skipProvidedNames = false
var visitInitArgs = false

val providedSymbolsByScope = HashMap[String, HashMap[String, ProvidedSymbol]]()
val importsByScope = HashMap[String, ArrayBuffer[AnImport]]()
Expand Down Expand Up @@ -186,7 +187,11 @@ class SourceAnalysisTraverser extends Traverser {

def visitMods(mods: List[Mod]): Unit = {
mods.foreach({
case Mod.Annot(init) => apply(init) // rely on `Init` extraction in main parsing match code
case Mod.Annot(init) =>
val currentVisitInitArgs = visitInitArgs
visitInitArgs = true
apply(init) // rely on `Init` extraction in main parsing match code
visitInitArgs = currentVisitInitArgs
case _ => ()
})
}
Expand All @@ -206,17 +211,19 @@ class SourceAnalysisTraverser extends Traverser {
visitTemplate(templ, name)
}

case Defn.Class(mods, nameNode, _tparams, _ctor, templ) => {
case Defn.Class(mods, nameNode, _tparams, ctor, templ) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name, sawClass = true)
apply(ctor)
visitTemplate(templ, name)
}

case Defn.Trait(mods, nameNode, _tparams, _ctor, templ) => {
case Defn.Trait(mods, nameNode, _tparams, ctor, templ) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name, sawTrait = true)
apply(ctor)
visitTemplate(templ, name)
}

Expand All @@ -227,10 +234,11 @@ class SourceAnalysisTraverser extends Traverser {
visitTemplate(templ, name)
}

case Defn.Type(mods, nameNode, _tparams, _body) => {
case Defn.Type(mods, nameNode, _tparams, body) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name)
extractNamesFromTypeTree(body).foreach(recordConsumedSymbol(_))
}

case Defn.Val(mods, pats, decltpe, rhs) => {
Expand All @@ -240,7 +248,7 @@ class SourceAnalysisTraverser extends Traverser {
recordProvidedName(name)
})
decltpe.foreach(tpe => {
recordConsumedSymbol(extractName(tpe))
extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_))
})
super.apply(rhs)
}
Expand Down Expand Up @@ -271,6 +279,22 @@ class SourceAnalysisTraverser extends Traverser {
withSuppressProvidedNames(() => apply(body))
}

case Decl.Def(mods, _nameNode, _tparams, params, decltpe) => {
visitMods(mods)
extractNamesFromTypeTree(decltpe).foreach(recordConsumedSymbol(_))
params.foreach(param => apply(param))
}

case Decl.Val(mods, _pats, decltpe) => {
visitMods(mods)
extractNamesFromTypeTree(decltpe).foreach(recordConsumedSymbol(_))
}

case Decl.Var(mods, _pats, decltpe) => {
visitMods(mods)
extractNamesFromTypeTree(decltpe).foreach(recordConsumedSymbol(_))
}

case Import(importers) => {
importers.foreach({ case Importer(ref, importees) =>
val baseName = extractName(ref)
Expand All @@ -296,8 +320,11 @@ class SourceAnalysisTraverser extends Traverser {
})
}

case Init(tpe, _name, _argss) => {
case Init(tpe, _name, argss) => {
extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_))

if (visitInitArgs)
argss.foreach(_.foreach(apply))
}

case Term.Param(mods, _name, decltpe, _default) => {
Expand Down Expand Up @@ -325,7 +352,7 @@ class SourceAnalysisTraverser extends Traverser {
case node @ Term.Select(_, _) => {
val name = extractName(node)
recordConsumedSymbol(name)
super.apply(node.qual)
apply(node.qual)
}

case node @ Term.Name(_) => {
Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/backend/scala/dependency_inference/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ async def infer_scala_dependencies_via_source_analysis(
symbols.update(analysis.all_imports())
if scala_infer_subsystem.consumed_types:
symbols.update(analysis.fully_qualified_consumed_symbols())
if scala_infer_subsystem.package_objects:
symbols.update(analysis.scopes)

resolve = tgt[JvmResolveField].normalized_value(jvm)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class NestedClass {
type NestedType = Foo
object NestedObject {
}
def a(a: TraitConsumedType): Integer
}
object OuterObject {
Expand Down Expand Up @@ -246,6 +248,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
assert analysis.consumed_symbols_by_scope == FrozenDict(
{
"org.pantsbuild.example.OuterClass.NestedObject": FrozenOrderedSet(["String"]),
"org.pantsbuild.example.OuterObject": FrozenOrderedSet(["Foo"]),
"org.pantsbuild.example.Functions": FrozenOrderedSet(
[
"TupleTypeArg2",
Expand All @@ -258,18 +261,29 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
"LambdaTypeArg2",
"AParameterType",
"LambdaTypeArg1",
"OuterObject",
"bar",
"OuterObject.NestedVal",
]
),
"org.pantsbuild.example.HasPrimaryConstructor": FrozenOrderedSet(
["bar", "SomeTypeInSecondaryConstructor"]
),
"org.pantsbuild.example.OuterClass": FrozenOrderedSet(["Foo"]),
"org.pantsbuild.example.ApplyQualifier": FrozenOrderedSet(
["Integer", "a", "toInt", "calc.calcFunc"]
["Integer", "a", "toInt", "calc.calcFunc", "calc"]
),
"org.pantsbuild.example.OuterTrait": FrozenOrderedSet(
["Integer", "TraitConsumedType", "Foo"]
),
"org.pantsbuild.example": FrozenOrderedSet(
["ABaseClass", "ATrait1", "ATrait2.Nested", "BaseWithConstructor"]
[
"ABaseClass",
"ATrait1",
"SomeTypeInPrimaryConstructor",
"ATrait2.Nested",
"BaseWithConstructor",
]
),
}
)
Expand All @@ -287,8 +301,12 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
"java.io.ATrait1",
"java.io.ATrait2.Nested",
"java.io.BaseWithConstructor",
"java.io.Foo",
"java.io.OuterObject.NestedVal",
"java.io.OuterObject",
"java.io.SomeTypeInPrimaryConstructor",
"java.io.String",
"java.io.TraitConsumedType",
"java.io.Unit",
"java.io.a",
"java.io.Integer",
Expand All @@ -297,6 +315,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
"java.io.LambdaTypeArg2",
"java.io.SomeTypeInSecondaryConstructor",
"java.io.bar",
"java.io.calc",
"java.io.calc.calcFunc",
"java.io.foo",
"java.io.toInt",
Expand All @@ -307,21 +326,26 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
"org.pantsbuild.example.ABaseClass",
"org.pantsbuild.example.AParameterType",
"org.pantsbuild.example.BaseWithConstructor",
"org.pantsbuild.example.Foo",
"org.pantsbuild.example.Integer",
"org.pantsbuild.example.SomeTypeInSecondaryConstructor",
"org.pantsbuild.example.ATrait1",
"org.pantsbuild.example.ATrait2.Nested",
"org.pantsbuild.example.OuterObject.NestedVal",
"org.pantsbuild.example.SomeTypeInPrimaryConstructor",
"org.pantsbuild.example.String",
"org.pantsbuild.example.TraitConsumedType",
"org.pantsbuild.example.Unit",
"org.pantsbuild.example.a",
"org.pantsbuild.example.bar",
"org.pantsbuild.example.calc",
"org.pantsbuild.example.calc.calcFunc",
"org.pantsbuild.example.foo",
"org.pantsbuild.example.toInt",
"org.pantsbuild.example.LambdaReturnType",
"org.pantsbuild.example.LambdaTypeArg1",
"org.pantsbuild.example.LambdaTypeArg2",
"org.pantsbuild.example.OuterObject",
"org.pantsbuild.example.TupleTypeArg1",
"org.pantsbuild.example.TupleTypeArg2",
"org.pantsbuild.+",
Expand All @@ -330,18 +354,23 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
"org.pantsbuild.ATrait1",
"org.pantsbuild.ATrait2.Nested",
"org.pantsbuild.BaseWithConstructor",
"org.pantsbuild.Foo",
"org.pantsbuild.Integer",
"org.pantsbuild.LambdaReturnType",
"org.pantsbuild.LambdaTypeArg1",
"org.pantsbuild.LambdaTypeArg2",
"org.pantsbuild.OuterObject",
"org.pantsbuild.OuterObject.NestedVal",
"org.pantsbuild.SomeTypeInPrimaryConstructor",
"org.pantsbuild.SomeTypeInSecondaryConstructor",
"org.pantsbuild.String",
"org.pantsbuild.TraitConsumedType",
"org.pantsbuild.TupleTypeArg1",
"org.pantsbuild.TupleTypeArg2",
"org.pantsbuild.Unit",
"org.pantsbuild.a",
"org.pantsbuild.bar",
"org.pantsbuild.calc",
"org.pantsbuild.calc.calcFunc",
"org.pantsbuild.foo",
"org.pantsbuild.toInt",
Expand Down Expand Up @@ -388,11 +417,14 @@ def test_relative_import(rule_runner: RuleRunner) -> None:
)

assert set(analysis.fully_qualified_consumed_symbols()) == {
"io",
"io.apply",
"java.io.apply",
"org.pantsbuild.io.apply",
"pio",
"pio.apply",
"scala.io.apply",
"sio",
"sio.apply",
}

Expand Down Expand Up @@ -441,7 +473,7 @@ def test_extract_annotations(rule_runner: RuleRunner) -> None:
"""
package foo
@objectAnnotation("hello")
@objectAnnotation("hello", SomeType)
object Object {
@deprecated
def foo(arg: String @argAnnotation("foo")): Unit = {}
Expand All @@ -459,6 +491,7 @@ class Class {
),
)
assert sorted(analysis.fully_qualified_consumed_symbols()) == [
"foo.SomeType",
"foo.String",
"foo.Unit",
"foo.classAnnotation",
Expand All @@ -468,3 +501,26 @@ class Class {
"foo.valAnnotation",
"foo.varAnnotation",
]


def test_type_arguments(rule_runner: RuleRunner) -> None:
analysis = _analyze(
rule_runner,
textwrap.dedent(
"""
package foo
object Object {
var a: A[SomeType] = ???
val b: B[AnotherType] = ???
}
"""
),
)
assert sorted(analysis.fully_qualified_consumed_symbols()) == [
"foo.???",
"foo.A",
"foo.AnotherType",
"foo.B",
"foo.SomeType",
]
5 changes: 5 additions & 0 deletions src/python/pants/backend/scala/subsystems/scala_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ class ScalaInferSubsystem(Subsystem):
default=True,
help="Infer a target's dependencies by parsing consumed types from sources.",
)
package_objects = BoolOption(
"--package-objects",
default=True,
help="Add dependency on the package object to every target.",
)
force_add_siblings_as_dependencies = BoolOption(
"--force-add-siblings-as-dependencies",
default=True,
Expand Down

0 comments on commit 2b93e45

Please sign in to comment.