diff --git a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala index 2b49c671891..5286b8a68ba 100644 --- a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala +++ b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala @@ -25,14 +25,17 @@ case class AnImport( ) case class Analysis( - providedSymbols: Vector[String], - providedSymbolsEncoded: Vector[String], + providedSymbols: Vector[Analysis.ProvidedSymbol], + providedSymbolsEncoded: Vector[Analysis.ProvidedSymbol], importsByScope: HashMap[String, ArrayBuffer[AnImport]], consumedSymbolsByScope: HashMap[String, HashSet[String]], scopes: Vector[String] ) +object Analysis { + case class ProvidedSymbol(name: String, recursive: Boolean) +} -case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boolean) +case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boolean, recursive: Boolean) class SourceAnalysisTraverser extends Traverser { val nameParts = ArrayBuffer[String]() @@ -112,7 +115,8 @@ class SourceAnalysisTraverser extends Traverser { symbolName: String, sawClass: Boolean = false, sawTrait: Boolean = false, - sawObject: Boolean = false + sawObject: Boolean = false, + recursive: Boolean = false ): Unit = { if (!skipProvidedNames) { val fullPackageName = nameParts.mkString(".") @@ -126,14 +130,16 @@ class SourceAnalysisTraverser extends Traverser { val newSymbol = ProvidedSymbol( sawClass = existingSymbol.sawClass || sawClass, sawTrait = existingSymbol.sawTrait || sawTrait, - sawObject = existingSymbol.sawObject || sawObject + sawObject = existingSymbol.sawObject || sawObject, + recursive = existingSymbol.recursive || recursive ) providedSymbols(symbolName) = newSymbol } else { providedSymbols(symbolName) = ProvidedSymbol( sawClass = sawClass, sawTrait = sawTrait, - sawObject = sawObject + sawObject = sawObject, + recursive = recursive ) } } @@ -230,8 +236,18 @@ class SourceAnalysisTraverser extends Traverser { case Defn.Object(mods, nameNode, templ) => { visitMods(mods) val name = extractName(nameNode) - recordProvidedName(name, sawObject = true) - visitTemplate(templ, name) + + // TODO: should object already be recursive? + // an object is recursive if extends another type because we cannot figure out the provided types + // in the parents, we just mark the object as recursive (which is indicated by non-empty inits) + val recursive = !templ.inits.isEmpty + recordProvidedName(name, sawObject = true, recursive = recursive) + + // If the object is recursive, no need to provide the symbols inside + if (recursive) + withSuppressProvidedNames(() => visitTemplate(templ, name)) + else + visitTemplate(templ, name) } case Defn.Type(mods, nameNode, _tparams, body) => { @@ -363,30 +379,30 @@ class SourceAnalysisTraverser extends Traverser { case node => super.apply(node) } - def gatherProvidedSymbols(): Vector[String] = { + def gatherProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = { providedSymbolsByScope .flatMap({ case (scopeName, symbolsForScope) => - symbolsForScope.keys.map(symbolName => s"${scopeName}.${symbolName}").toVector + symbolsForScope.map { case(symbolName, symbol) => Analysis.ProvidedSymbol(s"${scopeName}.${symbolName}", symbol.recursive)}.toVector }) .toVector } - def gatherEncodedProvidedSymbols(): Vector[String] = { + def gatherEncodedProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = { providedSymbolsByScope .flatMap({ case (scopeName, symbolsForScope) => val encodedSymbolsForScope = symbolsForScope.flatMap({ case (symbolName, symbol) => { val encodedSymbolName = NameTransformer.encode(symbolName) - val result = ArrayBuffer[String](encodedSymbolName) + val result = ArrayBuffer[Analysis.ProvidedSymbol](Analysis.ProvidedSymbol(encodedSymbolName, symbol.recursive)) if (symbol.sawObject) { - result.append(encodedSymbolName + "$") - result.append(encodedSymbolName + "$.MODULE$") + result.append(Analysis.ProvidedSymbol(encodedSymbolName + "$", symbol.recursive)) + result.append(Analysis.ProvidedSymbol(encodedSymbolName + "$.MODULE$", symbol.recursive)) } result.toVector } }) - encodedSymbolsForScope.map(symbolName => s"${scopeName}.${symbolName}") + encodedSymbolsForScope.map(symbol => symbol.copy(name = s"${scopeName}.${symbol.name}")) }) .toVector } diff --git a/src/python/pants/backend/scala/dependency_inference/rules_test.py b/src/python/pants/backend/scala/dependency_inference/rules_test.py index 66cbc462ff7..e369a02d944 100644 --- a/src/python/pants/backend/scala/dependency_inference/rules_test.py +++ b/src/python/pants/backend/scala/dependency_inference/rules_test.py @@ -354,3 +354,85 @@ def main(args: Array[String]): Unit = { assert deps == InferredDependencies( [Address("lib", relative_file_path="Library.scala", parameters={"resolve": "scala-2.12"})] ) + + +@maybe_skip_jdk_test +def test_recursive_objects(rule_runner: RuleRunner) -> None: + rule_runner.write_files( + { + "A/BUILD": dedent( + """\ + scala_sources(name = "a") + """ + ), + "A/A.scala": dedent( + """\ + package org.pantsbuild.a + + object A { + def funA(): Int = ??? + } + """ + ), + "B/BUILD": dedent( + """\ + scala_sources(name = "b") + """ + ), + "B/B.scala": dedent( + """\ + package org.pantsbuild.b + + import org.pantsbuild.a.A + + object B extends A { + def funB(): Int = ??? + } + """ + ), + "C/BUILD": dedent( + """\ + scala_sources(name = "c") + """ + ), + "C/C.scala": dedent( + """\ + package org.pantsbuild.c + + import org.pantsbuild.b.B.funA + + class C { + val x = funA() + } + """ + ), + "D/BUILD": dedent( + """\ + scala_sources(name = "d") + """ + ), + "D/D.scala": dedent( + """\ + package org.pantsbuild.d + + import org.pantsbuild.b.B.funB + + class D { + val x = funB() + } + """ + ), + } + ) + + target_b = rule_runner.get_target(Address("B", target_name="b", relative_file_path="B.scala")) + target_c = rule_runner.get_target(Address("C", target_name="c", relative_file_path="C.scala")) + target_d = rule_runner.get_target(Address("D", target_name="d", relative_file_path="D.scala")) + + assert rule_runner.request( + InferredDependencies, [InferScalaSourceDependencies(target_c[ScalaSourceField])] + ) == InferredDependencies(dependencies=[target_b.address]) + + assert rule_runner.request( + InferredDependencies, [InferScalaSourceDependencies(target_d[ScalaSourceField])] + ) == InferredDependencies(dependencies=[target_b.address]) diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser.py b/src/python/pants/backend/scala/dependency_inference/scala_parser.py index d782609a0bd..08fe72df3a0 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser.py @@ -69,10 +69,26 @@ def to_debug_json_dict(self) -> dict[str, Any]: } +@dataclass(frozen=True) +class ScalaProvidedSymbol: + name: str + recursive: bool + + @classmethod + def from_json_dict(cls, data: Mapping[str, Any]): + return cls(name=data["name"], recursive=data["recursive"]) + + def to_debug_json_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "recursive": self.recursive, + } + + @dataclass(frozen=True) class ScalaSourceDependencyAnalysis: - provided_symbols: FrozenOrderedSet[str] - provided_symbols_encoded: FrozenOrderedSet[str] + provided_symbols: FrozenOrderedSet[ScalaProvidedSymbol] + provided_symbols_encoded: FrozenOrderedSet[ScalaProvidedSymbol] imports_by_scope: FrozenDict[str, tuple[ScalaImport, ...]] consumed_symbols_by_scope: FrozenDict[str, FrozenOrderedSet[str]] scopes: FrozenOrderedSet[str] @@ -128,8 +144,12 @@ def scope_and_parents(scope: str) -> Iterator[str]: @classmethod def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis: return cls( - provided_symbols=FrozenOrderedSet(d["providedSymbols"]), - provided_symbols_encoded=FrozenOrderedSet(d["providedSymbolsEncoded"]), + provided_symbols=FrozenOrderedSet( + ScalaProvidedSymbol.from_json_dict(v) for v in d["providedSymbols"] + ), + provided_symbols_encoded=FrozenOrderedSet( + ScalaProvidedSymbol.from_json_dict(v) for v in d["providedSymbolsEncoded"] + ), imports_by_scope=FrozenDict( { key: tuple(ScalaImport.from_json_dict(v) for v in values) @@ -147,8 +167,10 @@ def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis: def to_debug_json_dict(self) -> dict[str, Any]: return { - "provided_symbols": list(self.provided_symbols), - "provided_symbols_encoded": list(self.provided_symbols_encoded), + "provided_symbols": [v.to_debug_json_dict() for v in self.provided_symbols], + "provided_symbols_encoded": [ + v.to_debug_json_dict() for v in self.provided_symbols_encoded + ], "imports_by_scope": { key: [v.to_debug_json_dict() for v in values] for key, values in self.imports_by_scope.items() diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py index 4e8c362c9db..fae68c20833 100644 --- a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py +++ b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py @@ -9,6 +9,7 @@ from pants.backend.scala.dependency_inference.scala_parser import ( AnalyzeScalaSourceRequest, ScalaImport, + ScalaProvidedSymbol, ScalaSourceDependencyAnalysis, ) from pants.backend.scala.target_types import ScalaSourceField, ScalaSourceTarget @@ -145,7 +146,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt ), ) - assert sorted(analysis.provided_symbols) == [ + assert sorted(symbol.name for symbol in analysis.provided_symbols) == [ "org.pantsbuild.example.ASubClass", "org.pantsbuild.example.ASubTrait", "org.pantsbuild.example.ApplyQualifier", @@ -179,7 +180,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt "org.pantsbuild.example.OuterTrait.NestedVar", ] - assert sorted(analysis.provided_symbols_encoded) == [ + assert sorted(symbol.name for symbol in analysis.provided_symbols_encoded) == [ "org.pantsbuild.example.ASubClass", "org.pantsbuild.example.ASubTrait", "org.pantsbuild.example.ApplyQualifier", @@ -441,7 +442,10 @@ def test_package_object(rule_runner: RuleRunner) -> None: """ ), ) - assert sorted(analysis.provided_symbols) == ["foo.bar", "foo.bar.Hello"] + assert sorted(symbol.name for symbol in analysis.provided_symbols) == [ + "foo.bar", + "foo.bar.Hello", + ] def test_source3(rule_runner: RuleRunner) -> None: @@ -524,3 +528,28 @@ def test_type_arguments(rule_runner: RuleRunner) -> None: "foo.B", "foo.SomeType", ] + + +def test_recursive_objects(rule_runner: RuleRunner) -> None: + analysis = _analyze( + rule_runner, + textwrap.dedent( + """ + package foo + + object Bar { + def a = ??? + } + + object Foo extends Bar { + def b = ??? + } + """ + ), + ) + + assert sorted(analysis.provided_symbols, key=lambda s: s.name) == [ + ScalaProvidedSymbol("foo.Bar", False), + ScalaProvidedSymbol("foo.Bar.a", False), + ScalaProvidedSymbol("foo.Foo", True), + ] diff --git a/src/python/pants/backend/scala/dependency_inference/symbol_mapper.py b/src/python/pants/backend/scala/dependency_inference/symbol_mapper.py index 61e91be3a00..ab6db09ba90 100644 --- a/src/python/pants/backend/scala/dependency_inference/symbol_mapper.py +++ b/src/python/pants/backend/scala/dependency_inference/symbol_mapper.py @@ -73,9 +73,21 @@ async def map_first_party_scala_targets_to_symbols( for (address, resolve), analysis in address_and_analysis: namespace = _symbol_namespace(address) for symbol in analysis.provided_symbols: - mapping[resolve].insert(symbol, [address], first_party=True, namespace=namespace) + mapping[resolve].insert( + symbol.name, + [address], + first_party=True, + namespace=namespace, + recursive=symbol.recursive, + ) for symbol in analysis.provided_symbols_encoded: - mapping[resolve].insert(symbol, [address], first_party=True, namespace=namespace) + mapping[resolve].insert( + symbol.name, + [address], + first_party=True, + namespace=namespace, + recursive=symbol.recursive, + ) return SymbolMap((resolve, node.frozen()) for resolve, node in mapping.items())