Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scala: mark an object that extends another type as recursive #15865

Merged
merged 4 commits into from
Jun 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ case class AnImport(
)

case class Analysis(
providedSymbols: Vector[String],
providedSymbolsEncoded: Vector[String],
providedSymbols: Vector[Analysis.ProvidedSymbol],
providedSymbolsEncoded: Vector[Analysis.ProvidedSymbol],
importsByScope: HashMap[String, ArrayBuffer[AnImport]],
consumedSymbolsByScope: HashMap[String, HashSet[String]],
scopes: Vector[String]
)
object Analysis {
case class ProvidedSymbol(name: String, recursive: Boolean)
}

case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boolean)
case class ProvidedSymbol(sawClass: Boolean, sawTrait: Boolean, sawObject: Boolean, recursive: Boolean)

class SourceAnalysisTraverser extends Traverser {
val nameParts = ArrayBuffer[String]()
Expand Down Expand Up @@ -112,7 +115,8 @@ class SourceAnalysisTraverser extends Traverser {
symbolName: String,
sawClass: Boolean = false,
sawTrait: Boolean = false,
sawObject: Boolean = false
sawObject: Boolean = false,
recursive: Boolean = false
): Unit = {
if (!skipProvidedNames) {
val fullPackageName = nameParts.mkString(".")
Expand All @@ -126,14 +130,16 @@ class SourceAnalysisTraverser extends Traverser {
val newSymbol = ProvidedSymbol(
sawClass = existingSymbol.sawClass || sawClass,
sawTrait = existingSymbol.sawTrait || sawTrait,
sawObject = existingSymbol.sawObject || sawObject
sawObject = existingSymbol.sawObject || sawObject,
recursive = existingSymbol.recursive || recursive
)
providedSymbols(symbolName) = newSymbol
} else {
providedSymbols(symbolName) = ProvidedSymbol(
sawClass = sawClass,
sawTrait = sawTrait,
sawObject = sawObject
sawObject = sawObject,
recursive = recursive
)
}
}
Expand Down Expand Up @@ -230,8 +236,18 @@ class SourceAnalysisTraverser extends Traverser {
case Defn.Object(mods, nameNode, templ) => {
visitMods(mods)
val name = extractName(nameNode)
recordProvidedName(name, sawObject = true)
visitTemplate(templ, name)

// TODO: should object already be recursive?
// an object is recursive if extends another type because we cannot figure out the provided types
// in the parents, we just mark the object as recursive (which is indicated by non-empty inits)
val recursive = !templ.inits.isEmpty
recordProvidedName(name, sawObject = true, recursive = recursive)

// If the object is recursive, no need to provide the symbols inside
tdyas marked this conversation as resolved.
Show resolved Hide resolved
if (recursive)
withSuppressProvidedNames(() => visitTemplate(templ, name))
else
visitTemplate(templ, name)
}

case Defn.Type(mods, nameNode, _tparams, body) => {
Expand Down Expand Up @@ -363,30 +379,30 @@ class SourceAnalysisTraverser extends Traverser {
case node => super.apply(node)
}

def gatherProvidedSymbols(): Vector[String] = {
def gatherProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = {
providedSymbolsByScope
.flatMap({ case (scopeName, symbolsForScope) =>
symbolsForScope.keys.map(symbolName => s"${scopeName}.${symbolName}").toVector
symbolsForScope.map { case(symbolName, symbol) => Analysis.ProvidedSymbol(s"${scopeName}.${symbolName}", symbol.recursive)}.toVector
})
.toVector
}

def gatherEncodedProvidedSymbols(): Vector[String] = {
def gatherEncodedProvidedSymbols(): Vector[Analysis.ProvidedSymbol] = {
providedSymbolsByScope
.flatMap({ case (scopeName, symbolsForScope) =>
val encodedSymbolsForScope = symbolsForScope.flatMap({
case (symbolName, symbol) => {
val encodedSymbolName = NameTransformer.encode(symbolName)
val result = ArrayBuffer[String](encodedSymbolName)
val result = ArrayBuffer[Analysis.ProvidedSymbol](Analysis.ProvidedSymbol(encodedSymbolName, symbol.recursive))
if (symbol.sawObject) {
result.append(encodedSymbolName + "$")
result.append(encodedSymbolName + "$.MODULE$")
result.append(Analysis.ProvidedSymbol(encodedSymbolName + "$", symbol.recursive))
result.append(Analysis.ProvidedSymbol(encodedSymbolName + "$.MODULE$", symbol.recursive))
}
result.toVector
}
})

encodedSymbolsForScope.map(symbolName => s"${scopeName}.${symbolName}")
encodedSymbolsForScope.map(symbol => symbol.copy(name = s"${scopeName}.${symbol.name}"))
})
.toVector
}
Expand Down
82 changes: 82 additions & 0 deletions src/python/pants/backend/scala/dependency_inference/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,85 @@ def main(args: Array[String]): Unit = {
assert deps == InferredDependencies(
[Address("lib", relative_file_path="Library.scala", parameters={"resolve": "scala-2.12"})]
)


@maybe_skip_jdk_test
def test_recursive_objects(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"A/BUILD": dedent(
"""\
scala_sources(name = "a")
"""
),
"A/A.scala": dedent(
"""\
package org.pantsbuild.a

object A {
def funA(): Int = ???
}
"""
),
"B/BUILD": dedent(
"""\
scala_sources(name = "b")
"""
),
"B/B.scala": dedent(
"""\
package org.pantsbuild.b

import org.pantsbuild.a.A

object B extends A {
def funB(): Int = ???
}
"""
),
"C/BUILD": dedent(
"""\
scala_sources(name = "c")
"""
),
"C/C.scala": dedent(
"""\
package org.pantsbuild.c

import org.pantsbuild.b.B.funA

class C {
val x = funA()
}
"""
),
"D/BUILD": dedent(
"""\
scala_sources(name = "d")
"""
),
"D/D.scala": dedent(
"""\
package org.pantsbuild.d

import org.pantsbuild.b.B.funB

class D {
val x = funB()
}
"""
),
}
)

target_b = rule_runner.get_target(Address("B", target_name="b", relative_file_path="B.scala"))
target_c = rule_runner.get_target(Address("C", target_name="c", relative_file_path="C.scala"))
target_d = rule_runner.get_target(Address("D", target_name="d", relative_file_path="D.scala"))

assert rule_runner.request(
InferredDependencies, [InferScalaSourceDependencies(target_c[ScalaSourceField])]
) == InferredDependencies(dependencies=[target_b.address])

assert rule_runner.request(
InferredDependencies, [InferScalaSourceDependencies(target_d[ScalaSourceField])]
) == InferredDependencies(dependencies=[target_b.address])
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,26 @@ def to_debug_json_dict(self) -> dict[str, Any]:
}


@dataclass(frozen=True)
class ScalaProvidedSymbol:
name: str
recursive: bool

@classmethod
def from_json_dict(cls, data: Mapping[str, Any]):
return cls(name=data["name"], recursive=data["recursive"])

def to_debug_json_dict(self) -> dict[str, Any]:
return {
"name": self.name,
"recursive": self.recursive,
}


@dataclass(frozen=True)
class ScalaSourceDependencyAnalysis:
provided_symbols: FrozenOrderedSet[str]
provided_symbols_encoded: FrozenOrderedSet[str]
provided_symbols: FrozenOrderedSet[ScalaProvidedSymbol]
provided_symbols_encoded: FrozenOrderedSet[ScalaProvidedSymbol]
imports_by_scope: FrozenDict[str, tuple[ScalaImport, ...]]
consumed_symbols_by_scope: FrozenDict[str, FrozenOrderedSet[str]]
scopes: FrozenOrderedSet[str]
Expand Down Expand Up @@ -128,8 +144,12 @@ def scope_and_parents(scope: str) -> Iterator[str]:
@classmethod
def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis:
return cls(
provided_symbols=FrozenOrderedSet(d["providedSymbols"]),
provided_symbols_encoded=FrozenOrderedSet(d["providedSymbolsEncoded"]),
provided_symbols=FrozenOrderedSet(
ScalaProvidedSymbol.from_json_dict(v) for v in d["providedSymbols"]
),
provided_symbols_encoded=FrozenOrderedSet(
ScalaProvidedSymbol.from_json_dict(v) for v in d["providedSymbolsEncoded"]
),
imports_by_scope=FrozenDict(
{
key: tuple(ScalaImport.from_json_dict(v) for v in values)
Expand All @@ -147,8 +167,10 @@ def from_json_dict(cls, d: dict) -> ScalaSourceDependencyAnalysis:

def to_debug_json_dict(self) -> dict[str, Any]:
return {
"provided_symbols": list(self.provided_symbols),
"provided_symbols_encoded": list(self.provided_symbols_encoded),
"provided_symbols": [v.to_debug_json_dict() for v in self.provided_symbols],
"provided_symbols_encoded": [
v.to_debug_json_dict() for v in self.provided_symbols_encoded
],
"imports_by_scope": {
key: [v.to_debug_json_dict() for v in values]
for key, values in self.imports_by_scope.items()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pants.backend.scala.dependency_inference.scala_parser import (
AnalyzeScalaSourceRequest,
ScalaImport,
ScalaProvidedSymbol,
ScalaSourceDependencyAnalysis,
)
from pants.backend.scala.target_types import ScalaSourceField, ScalaSourceTarget
Expand Down Expand Up @@ -145,7 +146,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
),
)

assert sorted(analysis.provided_symbols) == [
assert sorted(symbol.name for symbol in analysis.provided_symbols) == [
"org.pantsbuild.example.ASubClass",
"org.pantsbuild.example.ASubTrait",
"org.pantsbuild.example.ApplyQualifier",
Expand Down Expand Up @@ -179,7 +180,7 @@ def func4(a: Integer) = calc.calcFunc(a).toInt
"org.pantsbuild.example.OuterTrait.NestedVar",
]

assert sorted(analysis.provided_symbols_encoded) == [
assert sorted(symbol.name for symbol in analysis.provided_symbols_encoded) == [
"org.pantsbuild.example.ASubClass",
"org.pantsbuild.example.ASubTrait",
"org.pantsbuild.example.ApplyQualifier",
Expand Down Expand Up @@ -441,7 +442,10 @@ def test_package_object(rule_runner: RuleRunner) -> None:
"""
),
)
assert sorted(analysis.provided_symbols) == ["foo.bar", "foo.bar.Hello"]
assert sorted(symbol.name for symbol in analysis.provided_symbols) == [
"foo.bar",
"foo.bar.Hello",
]


def test_source3(rule_runner: RuleRunner) -> None:
Expand Down Expand Up @@ -524,3 +528,28 @@ def test_type_arguments(rule_runner: RuleRunner) -> None:
"foo.B",
"foo.SomeType",
]


def test_recursive_objects(rule_runner: RuleRunner) -> None:
somdoron marked this conversation as resolved.
Show resolved Hide resolved
analysis = _analyze(
rule_runner,
textwrap.dedent(
"""
package foo

object Bar {
def a = ???
}

object Foo extends Bar {
def b = ???
}
"""
),
)

assert sorted(analysis.provided_symbols, key=lambda s: s.name) == [
ScalaProvidedSymbol("foo.Bar", False),
ScalaProvidedSymbol("foo.Bar.a", False),
ScalaProvidedSymbol("foo.Foo", True),
]
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,21 @@ async def map_first_party_scala_targets_to_symbols(
for (address, resolve), analysis in address_and_analysis:
namespace = _symbol_namespace(address)
for symbol in analysis.provided_symbols:
mapping[resolve].insert(symbol, [address], first_party=True, namespace=namespace)
mapping[resolve].insert(
symbol.name,
[address],
first_party=True,
namespace=namespace,
recursive=symbol.recursive,
)
for symbol in analysis.provided_symbols_encoded:
mapping[resolve].insert(symbol, [address], first_party=True, namespace=namespace)
mapping[resolve].insert(
symbol.name,
[address],
first_party=True,
namespace=namespace,
recursive=symbol.recursive,
)

return SymbolMap((resolve, node.frozen()) for resolve, node in mapping.items())

Expand Down