diff --git a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala index 6a3c3b4b6ef..2b49c671891 100644 --- a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala +++ b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala @@ -37,6 +37,7 @@ case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boole class SourceAnalysisTraverser extends Traverser { val nameParts = ArrayBuffer[String]() var skipProvidedNames = false + var visitInitArgs = false val providedSymbolsByScope = HashMap[String, HashMap[String, ProvidedSymbol]]() val importsByScope = HashMap[String, ArrayBuffer[AnImport]]() @@ -186,7 +187,11 @@ class SourceAnalysisTraverser extends Traverser { def visitMods(mods: List[Mod]): Unit = { mods.foreach({ - case Mod.Annot(init) => apply(init) // rely on `Init` extraction in main parsing match code + case Mod.Annot(init) => + val currentVisitInitArgs = visitInitArgs + visitInitArgs = true + apply(init) // rely on `Init` extraction in main parsing match code + visitInitArgs = currentVisitInitArgs case _ => () }) } @@ -206,17 +211,19 @@ class SourceAnalysisTraverser extends Traverser { visitTemplate(templ, name) } - case Defn.Class(mods, nameNode, _tparams, _ctor, templ) => { + case Defn.Class(mods, nameNode, _tparams, ctor, templ) => { visitMods(mods) val name = extractName(nameNode) recordProvidedName(name, sawClass = true) + apply(ctor) visitTemplate(templ, name) } - case Defn.Trait(mods, nameNode, _tparams, _ctor, templ) => { + case Defn.Trait(mods, nameNode, _tparams, ctor, templ) => { visitMods(mods) val name = extractName(nameNode) recordProvidedName(name, sawTrait = true) + apply(ctor) visitTemplate(templ, name) } @@ -227,10 +234,11 @@ class SourceAnalysisTraverser extends Traverser { visitTemplate(templ, name) } - case Defn.Type(mods, nameNode, _tparams, _body) => { + case Defn.Type(mods, nameNode, _tparams, body) => { visitMods(mods) val name = extractName(nameNode) recordProvidedName(name) + extractNamesFromTypeTree(body).foreach(recordConsumedSymbol(_)) } case Defn.Val(mods, pats, decltpe, rhs) => { @@ -240,7 +248,7 @@ class SourceAnalysisTraverser extends Traverser { recordProvidedName(name) }) decltpe.foreach(tpe => { - recordConsumedSymbol(extractName(tpe)) + extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_)) }) super.apply(rhs) } @@ -271,6 +279,22 @@ class SourceAnalysisTraverser extends Traverser { withSuppressProvidedNames(() => apply(body)) } + case Decl.Def(mods, _nameNode, _tparams, params, decltpe) => { + visitMods(mods) + extractNamesFromTypeTree(decltpe).foreach(recordConsumedSymbol(_)) + params.foreach(param => apply(param)) + } + + case Decl.Val(mods, _pats, decltpe) => { + visitMods(mods) + extractNamesFromTypeTree(decltpe).foreach(recordConsumedSymbol(_)) + } + + case Decl.Var(mods, _pats, decltpe) => { + visitMods(mods) + extractNamesFromTypeTree(decltpe).foreach(recordConsumedSymbol(_)) + } + case Import(importers) => { importers.foreach({ case Importer(ref, importees) => val baseName = extractName(ref) @@ -296,8 +320,11 @@ class SourceAnalysisTraverser extends Traverser { }) } - case Init(tpe, _name, _argss) => { + case Init(tpe, _name, argss) => { extractNamesFromTypeTree(tpe).foreach(recordConsumedSymbol(_)) + + if (visitInitArgs) + argss.foreach(_.foreach(apply)) } case Term.Param(mods, _name, decltpe, _default) => { @@ -325,7 +352,7 @@ class SourceAnalysisTraverser extends Traverser { case node @ Term.Select(_, _) => { val name = extractName(node) recordConsumedSymbol(name) - super.apply(node.qual) + apply(node.qual) } case node @ Term.Name(_) => { diff --git a/src/python/pants/backend/scala/dependency_inference/rules.py b/src/python/pants/backend/scala/dependency_inference/rules.py index dd613e7c00f..5cfbfed54ed 100644 --- a/src/python/pants/backend/scala/dependency_inference/rules.py +++ b/src/python/pants/backend/scala/dependency_inference/rules.py @@ -78,6 +78,8 @@ async def infer_scala_dependencies_via_source_analysis( symbols.update(analysis.all_imports()) if scala_infer_subsystem.consumed_types: symbols.update(analysis.fully_qualified_consumed_symbols()) + if scala_infer_subsystem.package_objects: + symbols.update(analysis.scopes) resolve = tgt[JvmResolveField].normalized_value(jvm) diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py index 9fbe2ac9fde..4e8c362c9db 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py @@ -104,6 +104,8 @@ class NestedClass { type NestedType = Foo object NestedObject { } + + def a(a: TraitConsumedType): Integer } object OuterObject { @@ -246,6 +248,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt assert analysis.consumed_symbols_by_scope == FrozenDict( { "org.pantsbuild.example.OuterClass.NestedObject": FrozenOrderedSet(["String"]), + "org.pantsbuild.example.OuterObject": FrozenOrderedSet(["Foo"]), "org.pantsbuild.example.Functions": FrozenOrderedSet( [ "TupleTypeArg2", @@ -258,6 +261,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt "LambdaTypeArg2", "AParameterType", "LambdaTypeArg1", + "OuterObject", "bar", "OuterObject.NestedVal", ] @@ -265,11 +269,21 @@ def func4(a: Integer) = calc.calcFunc(a).toInt "org.pantsbuild.example.HasPrimaryConstructor": FrozenOrderedSet( ["bar", "SomeTypeInSecondaryConstructor"] ), + "org.pantsbuild.example.OuterClass": FrozenOrderedSet(["Foo"]), "org.pantsbuild.example.ApplyQualifier": FrozenOrderedSet( - ["Integer", "a", "toInt", "calc.calcFunc"] + ["Integer", "a", "toInt", "calc.calcFunc", "calc"] + ), + "org.pantsbuild.example.OuterTrait": FrozenOrderedSet( + ["Integer", "TraitConsumedType", "Foo"] ), "org.pantsbuild.example": FrozenOrderedSet( - ["ABaseClass", "ATrait1", "ATrait2.Nested", "BaseWithConstructor"] + [ + "ABaseClass", + "ATrait1", + "SomeTypeInPrimaryConstructor", + "ATrait2.Nested", + "BaseWithConstructor", + ] ), } ) @@ -287,8 +301,12 @@ def func4(a: Integer) = calc.calcFunc(a).toInt "java.io.ATrait1", "java.io.ATrait2.Nested", "java.io.BaseWithConstructor", + "java.io.Foo", "java.io.OuterObject.NestedVal", + "java.io.OuterObject", + "java.io.SomeTypeInPrimaryConstructor", "java.io.String", + "java.io.TraitConsumedType", "java.io.Unit", "java.io.a", "java.io.Integer", @@ -297,6 +315,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt "java.io.LambdaTypeArg2", "java.io.SomeTypeInSecondaryConstructor", "java.io.bar", + "java.io.calc", "java.io.calc.calcFunc", "java.io.foo", "java.io.toInt", @@ -307,21 +326,26 @@ def func4(a: Integer) = calc.calcFunc(a).toInt "org.pantsbuild.example.ABaseClass", "org.pantsbuild.example.AParameterType", "org.pantsbuild.example.BaseWithConstructor", + "org.pantsbuild.example.Foo", "org.pantsbuild.example.Integer", "org.pantsbuild.example.SomeTypeInSecondaryConstructor", "org.pantsbuild.example.ATrait1", "org.pantsbuild.example.ATrait2.Nested", "org.pantsbuild.example.OuterObject.NestedVal", + "org.pantsbuild.example.SomeTypeInPrimaryConstructor", "org.pantsbuild.example.String", + "org.pantsbuild.example.TraitConsumedType", "org.pantsbuild.example.Unit", "org.pantsbuild.example.a", "org.pantsbuild.example.bar", + "org.pantsbuild.example.calc", "org.pantsbuild.example.calc.calcFunc", "org.pantsbuild.example.foo", "org.pantsbuild.example.toInt", "org.pantsbuild.example.LambdaReturnType", "org.pantsbuild.example.LambdaTypeArg1", "org.pantsbuild.example.LambdaTypeArg2", + "org.pantsbuild.example.OuterObject", "org.pantsbuild.example.TupleTypeArg1", "org.pantsbuild.example.TupleTypeArg2", "org.pantsbuild.+", @@ -330,18 +354,23 @@ def func4(a: Integer) = calc.calcFunc(a).toInt "org.pantsbuild.ATrait1", "org.pantsbuild.ATrait2.Nested", "org.pantsbuild.BaseWithConstructor", + "org.pantsbuild.Foo", "org.pantsbuild.Integer", "org.pantsbuild.LambdaReturnType", "org.pantsbuild.LambdaTypeArg1", "org.pantsbuild.LambdaTypeArg2", + "org.pantsbuild.OuterObject", "org.pantsbuild.OuterObject.NestedVal", + "org.pantsbuild.SomeTypeInPrimaryConstructor", "org.pantsbuild.SomeTypeInSecondaryConstructor", "org.pantsbuild.String", + "org.pantsbuild.TraitConsumedType", "org.pantsbuild.TupleTypeArg1", "org.pantsbuild.TupleTypeArg2", "org.pantsbuild.Unit", "org.pantsbuild.a", "org.pantsbuild.bar", + "org.pantsbuild.calc", "org.pantsbuild.calc.calcFunc", "org.pantsbuild.foo", "org.pantsbuild.toInt", @@ -388,11 +417,14 @@ def test_relative_import(rule_runner: RuleRunner) -> None: ) assert set(analysis.fully_qualified_consumed_symbols()) == { + "io", "io.apply", "java.io.apply", "org.pantsbuild.io.apply", + "pio", "pio.apply", "scala.io.apply", + "sio", "sio.apply", } @@ -441,7 +473,7 @@ def test_extract_annotations(rule_runner: RuleRunner) -> None: """ package foo - @objectAnnotation("hello") + @objectAnnotation("hello", SomeType) object Object { @deprecated def foo(arg: String @argAnnotation("foo")): Unit = {} @@ -459,6 +491,7 @@ class Class { ), ) assert sorted(analysis.fully_qualified_consumed_symbols()) == [ + "foo.SomeType", "foo.String", "foo.Unit", "foo.classAnnotation", @@ -468,3 +501,26 @@ class Class { "foo.valAnnotation", "foo.varAnnotation", ] + + +def test_type_arguments(rule_runner: RuleRunner) -> None: + analysis = _analyze( + rule_runner, + textwrap.dedent( + """ + package foo + + object Object { + var a: A[SomeType] = ??? + val b: B[AnotherType] = ??? + } + """ + ), + ) + assert sorted(analysis.fully_qualified_consumed_symbols()) == [ + "foo.???", + "foo.A", + "foo.AnotherType", + "foo.B", + "foo.SomeType", + ] diff --git a/src/python/pants/backend/scala/subsystems/scala_infer.py b/src/python/pants/backend/scala/subsystems/scala_infer.py index 72e0601e08c..811492e5222 100644 --- a/src/python/pants/backend/scala/subsystems/scala_infer.py +++ b/src/python/pants/backend/scala/subsystems/scala_infer.py @@ -20,6 +20,11 @@ class ScalaInferSubsystem(Subsystem): default=True, help="Infer a target's dependencies by parsing consumed types from sources.", ) + package_objects = BoolOption( + "--package-objects", + default=True, + help="Add dependency on the package object to every target.", + ) force_add_siblings_as_dependencies = BoolOption( "--force-add-siblings-as-dependencies", default=True,