diff --git a/src/python/pants/backend/java/compile/javac.py b/src/python/pants/backend/java/compile/javac.py
index 3a27687ec2e..6dc47c5d330 100644
--- a/src/python/pants/backend/java/compile/javac.py
+++ b/src/python/pants/backend/java/compile/javac.py
@@ -52,7 +52,7 @@ async def compile_java_source(
     union_membership: UnionMembership,
     request: CompileJavaSourceRequest,
 ) -> FallibleClasspathEntry:
-    # Request the component's direct dependency classpath, and additionally any preqrequisite.
+    # Request the component's direct dependency classpath, and additionally any prerequisite.
     classpath_entry_requests = [
         *((request.prerequisite,) if request.prerequisite else ()),
         *(
diff --git a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala
index 156e0a48b3d..9c0425f7b02 100644
--- a/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala
+++ b/src/python/pants/backend/scala/dependency_inference/ScalaParser.scala
@@ -16,7 +16,15 @@ import scala.meta.transversers.Traverser
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
 import scala.reflect.NameTransformer
 
-case class AnImport(name: String, isWildcard: Boolean)
+case class AnImport(
+  // The partially qualified input name for the import, which must be in scope at
+  // the import site.
+  name: String,
+  // An optional single token alias for the import in this scope.
+  alias: Option[String],
+  // True if the import imports all symbols contained within the name.
+  isWildcard: Boolean,
+)
 
 case class Analysis(
   providedSymbols: Vector[String],
@@ -126,12 +134,12 @@ class SourceAnalysisTraverser extends Traverser {
     skipProvidedNames = origSkipProvidedNames
   }
 
-  def recordImport(name: String, isWildcard: Boolean): Unit = {
+  def recordImport(name: String, alias: Option[String], isWildcard: Boolean): Unit = {
     val fullPackageName = nameParts.mkString(".")
     if (!importsByScope.contains(fullPackageName)) {
       importsByScope(fullPackageName) = ArrayBuffer[AnImport]()
     }
-    importsByScope(fullPackageName).append(AnImport(name, isWildcard))
+    importsByScope(fullPackageName).append(AnImport(name, alias, isWildcard))
   }
 
   def recordConsumedSymbol(name: String): Unit = {
@@ -225,9 +233,20 @@ class SourceAnalysisTraverser extends Traverser {
       val baseName = extractName(ref)
       importees.foreach(importee => {
         importee match {
-          case Importee.Wildcard() => recordImport(baseName, true)
-          case Importee.Name(nameNode) => recordImport(s"${baseName}.${extractName(nameNode)}", false)
-          case Importee.Rename(nameNode, _) => recordImport(s"${baseName}.${extractName(nameNode)}", false)
+          case Importee.Wildcard() => recordImport(baseName, None, true)
+          case Importee.Name(nameNode) => {
+            recordImport(s"${baseName}.${extractName(nameNode)}", None, false)
+          }
+          case Importee.Rename(nameNode, aliasNode) => {
+            // If a type is aliased to `_`, it is not brought into scope. We still record
+            // the import though, since compilation will fail if an import is not present.
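+            // For example, `import scala.collection.mutable.{HashMap => RenamedHashMap}` is
+            // recorded with name `scala.collection.mutable.HashMap` and alias `RenamedHashMap`.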
+            val alias = extractName(aliasNode)
+            recordImport(
+              s"${baseName}.${extractName(nameNode)}",
+              if (alias == "_") None else Some(alias),
+              false,
+            )
+          }
           case _ =>
         }
       })
diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser.py b/src/python/pants/backend/scala/dependency_inference/scala_parser.py
index 22c9bf4f595..549bff35226 100644
--- a/src/python/pants/backend/scala/dependency_inference/scala_parser.py
+++ b/src/python/pants/backend/scala/dependency_inference/scala_parser.py
@@ -96,15 +96,17 @@
 @dataclass(frozen=True)
 class ScalaImport:
     name: str
+    alias: str | None
     is_wildcard: bool
 
     @classmethod
     def from_json_dict(cls, data: Mapping[str, Any]):
-        return cls(name=data["name"], is_wildcard=data["isWildcard"])
+        return cls(name=data["name"], alias=data.get("alias"), is_wildcard=data["isWildcard"])
 
     def to_debug_json_dict(self) -> dict[str, Any]:
         return {
             "name": self.name,
+            "alias": self.alias,
             "is_wildcard": self.is_wildcard,
         }
 
@@ -130,29 +132,40 @@ def fully_qualified_consumed_symbols(self) -> Iterator[str]:
         have been provided by any wildcard import in scope, as well as being declared in the
         current package.
         """
-        # Collect all wildcard imports.
-        wildcard_imports_by_scope = {}
-        for scope, imports in self.imports_by_scope.items():
-            wildcard_imports = tuple(imp for imp in imports if imp.is_wildcard)
-            if wildcard_imports:
-                wildcard_imports_by_scope[scope] = wildcard_imports
-
-        for scope, consumed_symbols in self.consumed_symbols_by_scope.items():
-            parent_scope_wildcards = {
-                wi
-                for s, wildcard_imports in wildcard_imports_by_scope.items()
-                for wi in wildcard_imports
-                if scope.startswith(s)
-            }
+
+        def scope_and_parents(scope: str) -> Iterator[str]:
+            while True:
+                yield scope
+                if scope == "":
+                    break
+                scope, _, _ = scope.rpartition(".")
+
+        for consumption_scope, consumed_symbols in self.consumed_symbols_by_scope.items():
+            parent_scopes = tuple(scope_and_parents(consumption_scope))
             for symbol in consumed_symbols:
-                for scope in self.scopes:
-                    yield f"{scope}.{symbol}"
-                if not self.scopes or "." in symbol:
+                symbol_rel_prefix, dot_in_symbol, symbol_rel_suffix = symbol.partition(".")
+                if not self.scopes or dot_in_symbol:
                     # TODO: Similar to #13545: we assume that a symbol containing a dot might already
                     # be fully qualified.
                     yield symbol
-                for wildcard_scope in parent_scope_wildcards:
-                    yield f"{wildcard_scope.name}.{symbol}"
+                for parent_scope in parent_scopes:
+                    if parent_scope in self.scopes:
+                        # A package declaration is a parent of this scope, and any of its symbols
+                        # could be in scope.
+                        yield f"{parent_scope}.{symbol}"
+
+                    for imp in self.imports_by_scope.get(parent_scope, ()):
+                        if imp.is_wildcard:
+                            # There is a wildcard import in a parent scope.
+                            yield f"{imp.name}.{symbol}"
+                        if dot_in_symbol:
+                            # If the parent scope has an import which defines the first token of the
+                            # symbol, then it might be a relative usage of an import.
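+                            # For example, given `import scala.collection.mutable`, the usage
+                            # `mutable.HashMap` yields `scala.collection.mutable.HashMap`; given
+                            # `import scala.{io => sio}`, `sio.Source` yields `scala.io.Source`.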
+                            if imp.alias:
+                                if imp.alias == symbol_rel_prefix:
+                                    yield f"{imp.name}.{symbol_rel_suffix}"
+                            elif imp.name.endswith(f".{symbol_rel_prefix}"):
+                                yield f"{imp.name}.{symbol_rel_suffix}"
 
     @classmethod
     def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis:
diff --git a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py
index 41b2e79291b..bddb0879d94 100644
--- a/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py
+++ b/src/python/pants/backend/scala/dependency_inference/scala_parser_test.py
@@ -29,7 +29,6 @@
 @pytest.fixture
 def rule_runner() -> RuleRunner:
     rule_runner = RuleRunner(
-        preserve_tmpdirs=True,
         rules=[
             *scala_parser.rules(),
             *coursier_fetch_rules(),
@@ -48,19 +47,35 @@ def rule_runner() -> RuleRunner:
     return rule_runner
 
 
-def test_parser_simple(rule_runner: RuleRunner) -> None:
+def _analyze(rule_runner: RuleRunner, source: str) -> ScalaSourceDependencyAnalysis:
     rule_runner.write_files(
         {
-            "BUILD": textwrap.dedent(
-                """
-                scala_source(
-                    name="simple-source",
-                    source="SimpleSource.scala",
+            "BUILD": """scala_source(name="source", source="Source.scala")""",
+            "Source.scala": source,
+        }
+    )
+
+    target = rule_runner.get_target(address=Address("", target_name="source"))
+
+    source_files = rule_runner.request(
+        SourceFiles,
+        [
+            SourceFilesRequest(
+                (target.get(SourcesField),),
+                for_sources_types=(ScalaSourceField,),
+                enable_codegen=True,
             )
+        ],
+    )
+
+    return rule_runner.request(ScalaSourceDependencyAnalysis, [source_files])
+
+
+def test_parser_simple(rule_runner: RuleRunner) -> None:
+    analysis = _analyze(
+        rule_runner,
+        textwrap.dedent(
             """
-            ),
-            "SimpleSource.scala": textwrap.dedent(
-                """
             package org.pantsbuild
             package example
@@ -124,26 +139,7 @@ def this(bar: SomeTypeInSecondaryConstructor) {
                 }
               }
             }
             """
-            ),
-        }
-    )
-
-    target = rule_runner.get_target(address=Address("", target_name="simple-source"))
-
-    source_files = rule_runner.request(
-        SourceFiles,
-        [
-            SourceFilesRequest(
-                (target.get(SourcesField),),
-                for_sources_types=(ScalaSourceField,),
-                enable_codegen=True,
-            )
-        ],
-    )
-
-    analysis = rule_runner.request(
-        ScalaSourceDependencyAnalysis,
-        [source_files],
+        ),
     )
 
     assert sorted(list(analysis.provided_symbols)) == [
@@ -223,12 +219,18 @@ def this(bar: SomeTypeInSecondaryConstructor) {
     assert analysis.imports_by_scope == FrozenDict(
         {
             "org.pantsbuild.example.OuterClass": (
-                ScalaImport(name="foo.bar.SomeItem", is_wildcard=False),
+                ScalaImport(name="foo.bar.SomeItem", alias=None, is_wildcard=False),
             ),
             "org.pantsbuild.example": (
-                ScalaImport(name="scala.collection.mutable.ArrayBuffer", is_wildcard=False),
-                ScalaImport(name="scala.collection.mutable.HashMap", is_wildcard=False),
-                ScalaImport(name="java.io", is_wildcard=True),
+                ScalaImport(
+                    name="scala.collection.mutable.ArrayBuffer", alias=None, is_wildcard=False
+                ),
+                ScalaImport(
+                    name="scala.collection.mutable.HashMap",
+                    alias="RenamedHashMap",
+                    is_wildcard=False,
+                ),
+                ScalaImport(name="java.io", alias=None, is_wildcard=True),
             ),
         }
     )
@@ -325,42 +327,15 @@
 
 
 def test_extract_package_scopes(rule_runner: RuleRunner) -> None:
-    rule_runner.write_files(
-        {
-            "BUILD": textwrap.dedent(
-                """
-                scala_source(
-                    name="source",
-                    source="Source.scala",
-                )
-                """
-            ),
-            "Source.scala": textwrap.dedent(
-                """
+    analysis = _analyze(
+        rule_runner,
+        textwrap.dedent(
+            """
             package outer
             package more.than.one.part.at.once
             package 
             inner
             """
-            ),
-        }
-    )
-
-    target = rule_runner.get_target(address=Address("", target_name="source"))
-
-    source_files = rule_runner.request(
-        SourceFiles,
-        [
-            SourceFilesRequest(
-                (target.get(SourcesField),),
-                for_sources_types=(ScalaSourceField,),
-                enable_codegen=True,
-            )
-        ],
-    )
-
-    analysis = rule_runner.request(
-        ScalaSourceDependencyAnalysis,
-        [source_files],
+        ),
     )
 
     assert sorted(analysis.scopes) == [
@@ -368,3 +343,33 @@ def test_extract_package_scopes(rule_runner: RuleRunner) -> None:
         "outer.more.than.one.part.at.once",
         "outer.more.than.one.part.at.once.inner",
     ]
+
+
+def test_relative_import(rule_runner: RuleRunner) -> None:
+    analysis = _analyze(
+        rule_runner,
+        textwrap.dedent(
+            """
+            import java.io
+            import scala.{io => sio}
+            import nada.{io => _}
+
+            object OuterObject {
+              import org.pantsbuild.{io => pio}
+
+              val i = io.apply()
+              val s = sio.apply()
+              val p = pio.apply()
+            }
+            """
+        ),
+    )
+
+    assert set(analysis.fully_qualified_consumed_symbols()) == {
+        "io.apply",
+        "java.io.apply",
+        "org.pantsbuild.io.apply",
+        "pio.apply",
+        "scala.io.apply",
+        "sio.apply",
+    }