Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for inferring relative imports from Scala #13775

Merged
merged 2 commits into from
Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/python/pants/backend/java/compile/javac.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def compile_java_source(
union_membership: UnionMembership,
request: CompileJavaSourceRequest,
) -> FallibleClasspathEntry:
# Request the component's direct dependency classpath, and additionally any preqrequisite.
# Request the component's direct dependency classpath, and additionally any prerequisite.
classpath_entry_requests = [
*((request.prerequisite,) if request.prerequisite else ()),
*(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,15 @@ import scala.meta.transversers.Traverser
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.reflect.NameTransformer

case class AnImport(name: String, isWildcard: Boolean)
case class AnImport(
// The partially qualified input name for the import, which must be in scope at
// the import site.
name: String,
// An optional single token alias for the import in this scope.
alias: Option[String],
// True if the import imports all symbols contained within the name.
isWildcard: Boolean,
)

case class Analysis(
providedSymbols: Vector[String],
Expand Down Expand Up @@ -126,12 +134,12 @@ class SourceAnalysisTraverser extends Traverser {
skipProvidedNames = origSkipProvidedNames
}

def recordImport(name: String, isWildcard: Boolean): Unit = {
def recordImport(name: String, alias: Option[String], isWildcard: Boolean): Unit = {
val fullPackageName = nameParts.mkString(".")
if (!importsByScope.contains(fullPackageName)) {
importsByScope(fullPackageName) = ArrayBuffer[AnImport]()
}
importsByScope(fullPackageName).append(AnImport(name, isWildcard))
importsByScope(fullPackageName).append(AnImport(name, alias, isWildcard))
}

def recordConsumedSymbol(name: String): Unit = {
Expand Down Expand Up @@ -225,9 +233,20 @@ class SourceAnalysisTraverser extends Traverser {
val baseName = extractName(ref)
importees.foreach(importee => {
importee match {
case Importee.Wildcard() => recordImport(baseName, true)
case Importee.Name(nameNode) => recordImport(s"${baseName}.${extractName(nameNode)}", false)
case Importee.Rename(nameNode, _) => recordImport(s"${baseName}.${extractName(nameNode)}", false)
case Importee.Wildcard() => recordImport(baseName, None, true)
case Importee.Name(nameNode) => {
recordImport(s"${baseName}.${extractName(nameNode)}", None, false)
}
case Importee.Rename(nameNode, aliasNode) => {
// If a type is aliased to `_`, it is not brought into scope. We still record
// the import though, since compilation will fail if an import is not present.
val alias = extractName(aliasNode)
recordImport(
s"${baseName}.${extractName(nameNode)}",
if (alias == "_") None else Some(alias),
false,
)
}
case _ =>
}
})
Expand Down
53 changes: 33 additions & 20 deletions src/python/pants/backend/scala/dependency_inference/scala_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,17 @@
@dataclass(frozen=True)
class ScalaImport:
name: str
alias: str | None
is_wildcard: bool

@classmethod
def from_json_dict(cls, data: Mapping[str, Any]):
return cls(name=data["name"], is_wildcard=data["isWildcard"])
return cls(name=data["name"], alias=data.get("alias"), is_wildcard=data["isWildcard"])

def to_debug_json_dict(self) -> dict[str, Any]:
return {
"name": self.name,
"alias": self.alias,
"is_wildcard": self.is_wildcard,
}

Expand All @@ -130,29 +132,40 @@ def fully_qualified_consumed_symbols(self) -> Iterator[str]:
have been provided by any wildcard import in scope, as well as being declared in the current
package.
"""
# Collect all wildcard imports.
wildcard_imports_by_scope = {}
for scope, imports in self.imports_by_scope.items():
wildcard_imports = tuple(imp for imp in imports if imp.is_wildcard)
if wildcard_imports:
wildcard_imports_by_scope[scope] = wildcard_imports

for scope, consumed_symbols in self.consumed_symbols_by_scope.items():
parent_scope_wildcards = {
wi
for s, wildcard_imports in wildcard_imports_by_scope.items()
for wi in wildcard_imports
if scope.startswith(s)
}

def scope_and_parents(scope: str) -> Iterator[str]:
while True:
yield scope
if scope == "":
break
scope, _, _ = scope.rpartition(".")

for consumption_scope, consumed_symbols in self.consumed_symbols_by_scope.items():
parent_scopes = tuple(scope_and_parents(consumption_scope))
for symbol in consumed_symbols:
for scope in self.scopes:
yield f"{scope}.{symbol}"
if not self.scopes or "." in symbol:
symbol_rel_prefix, dot_in_symbol, symbol_rel_suffix = symbol.partition(".")
if not self.scopes or dot_in_symbol:
# TODO: Similar to #13545: we assume that a symbol containing a dot might already
# be fully qualified.
yield symbol
for wildcard_scope in parent_scope_wildcards:
yield f"{wildcard_scope.name}.{symbol}"
for parent_scope in parent_scopes:
if parent_scope in self.scopes:
# A package declaration is a parent of this scope, and any of its symbols
# could be in scope.
yield f"{parent_scope}.{symbol}"

for imp in self.imports_by_scope.get(parent_scope, ()):
if imp.is_wildcard:
# There is a wildcard import in a parent scope.
yield f"{imp.name}.{symbol}"
if dot_in_symbol:
# If the parent scope has an import which defines the first token of the
# symbol, then it might be a relative usage of an import.
if imp.alias:
if imp.alias == symbol_rel_prefix:
yield f"{imp.name}.{symbol_rel_suffix}"
elif imp.name.endswith(f".{symbol_rel_prefix}"):
yield f"{imp.name}.{symbol_rel_suffix}"

@classmethod
def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
@pytest.fixture
def rule_runner() -> RuleRunner:
rule_runner = RuleRunner(
preserve_tmpdirs=True,
rules=[
*scala_parser.rules(),
*coursier_fetch_rules(),
Expand All @@ -48,19 +47,35 @@ def rule_runner() -> RuleRunner:
return rule_runner


def test_parser_simple(rule_runner: RuleRunner) -> None:
def _analyze(rule_runner: RuleRunner, source: str) -> ScalaSourceDependencyAnalysis:
rule_runner.write_files(
{
"BUILD": textwrap.dedent(
"""
scala_source(
name="simple-source",
source="SimpleSource.scala",
"BUILD": """scala_source(name="source", source="Source.scala")""",
"Source.scala": source,
}
)

target = rule_runner.get_target(address=Address("", target_name="source"))

source_files = rule_runner.request(
SourceFiles,
[
SourceFilesRequest(
(target.get(SourcesField),),
for_sources_types=(ScalaSourceField,),
enable_codegen=True,
)
],
)

return rule_runner.request(ScalaSourceDependencyAnalysis, [source_files])


def test_parser_simple(rule_runner: RuleRunner) -> None:
analysis = _analyze(
rule_runner,
textwrap.dedent(
"""
),
"SimpleSource.scala": textwrap.dedent(
"""
package org.pantsbuild
package example

Expand Down Expand Up @@ -124,26 +139,7 @@ def this(bar: SomeTypeInSecondaryConstructor) {
}
}
"""
),
}
)

target = rule_runner.get_target(address=Address("", target_name="simple-source"))

source_files = rule_runner.request(
SourceFiles,
[
SourceFilesRequest(
(target.get(SourcesField),),
for_sources_types=(ScalaSourceField,),
enable_codegen=True,
)
],
)

analysis = rule_runner.request(
ScalaSourceDependencyAnalysis,
[source_files],
),
)

assert sorted(list(analysis.provided_symbols)) == [
Expand Down Expand Up @@ -223,12 +219,18 @@ def this(bar: SomeTypeInSecondaryConstructor) {
assert analysis.imports_by_scope == FrozenDict(
{
"org.pantsbuild.example.OuterClass": (
ScalaImport(name="foo.bar.SomeItem", is_wildcard=False),
ScalaImport(name="foo.bar.SomeItem", alias=None, is_wildcard=False),
),
"org.pantsbuild.example": (
ScalaImport(name="scala.collection.mutable.ArrayBuffer", is_wildcard=False),
ScalaImport(name="scala.collection.mutable.HashMap", is_wildcard=False),
ScalaImport(name="java.io", is_wildcard=True),
ScalaImport(
name="scala.collection.mutable.ArrayBuffer", alias=None, is_wildcard=False
),
ScalaImport(
name="scala.collection.mutable.HashMap",
alias="RenamedHashMap",
is_wildcard=False,
),
ScalaImport(name="java.io", alias=None, is_wildcard=True),
),
}
)
Expand Down Expand Up @@ -325,46 +327,49 @@ def this(bar: SomeTypeInSecondaryConstructor) {


def test_extract_package_scopes(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"BUILD": textwrap.dedent(
"""
scala_source(
name="source",
source="Source.scala",
)
"""
),
"Source.scala": textwrap.dedent(
"""
analysis = _analyze(
rule_runner,
textwrap.dedent(
"""
package outer
package more.than.one.part.at.once
package inner
"""
),
}
)

target = rule_runner.get_target(address=Address("", target_name="source"))

source_files = rule_runner.request(
SourceFiles,
[
SourceFilesRequest(
(target.get(SourcesField),),
for_sources_types=(ScalaSourceField,),
enable_codegen=True,
)
],
)

analysis = rule_runner.request(
ScalaSourceDependencyAnalysis,
[source_files],
),
)

assert sorted(analysis.scopes) == [
"outer",
"outer.more.than.one.part.at.once",
"outer.more.than.one.part.at.once.inner",
]


def test_relative_import(rule_runner: RuleRunner) -> None:
analysis = _analyze(
rule_runner,
textwrap.dedent(
"""
import java.io
import scala.{io => sio}
import nada.{io => _}

object OuterObject {
import org.pantsbuild.{io => pio}

val i = io.apply()
val s = sio.apply()
val p = pio.apply()
}
"""
),
)

assert set(analysis.fully_qualified_consumed_symbols()) == {
"io.apply",
"java.io.apply",
"org.pantsbuild.io.apply",
"pio.apply",
"scala.io.apply",
"sio.apply",
}