Skip to content

improvement: Rework IndexedContext to reuse the previously calculated scopes #22898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 43 additions & 19 deletions compiler/src/dotty/tools/dotc/interactive/Completion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ case class Completion(label: String, description: String, symbols: List[Symbol])

object Completion:

def scopeContext(pos: SourcePosition)(using Context): CompletionResult =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally this would get deduplicated with normal rawCompletions being invoked, but I will try to do it as a separate step afterwards.

val tpdPath = Interactive.pathTo(ctx.compilationUnit.tpdTree, pos.span)
val completionContext = Interactive.contextOfPath(tpdPath).withPhase(Phases.typerPhase)
inContext(completionContext):
val untpdPath = Interactive.resolveTypedOrUntypedPath(tpdPath, pos)
val mode = completionMode(untpdPath, pos, forSymbolSearch = true)
val rawPrefix = completionPrefix(untpdPath, pos)
val completer = new Completer(mode, pos, untpdPath, _ => true)
completer.scopeCompletions

/** Get possible completions from tree at `pos`
*
* @return offset and list of symbols for possible completions
Expand All @@ -60,7 +70,6 @@ object Completion:
val mode = completionMode(untpdPath, pos)
val rawPrefix = completionPrefix(untpdPath, pos)
val completions = rawCompletions(pos, mode, rawPrefix, tpdPath, untpdPath)

postProcessCompletions(untpdPath, completions, rawPrefix)

/** Get possible completions from tree at `pos`
Expand Down Expand Up @@ -89,7 +98,7 @@ object Completion:
*
* Otherwise, provide no completion suggestion.
*/
def completionMode(path: List[untpd.Tree], pos: SourcePosition): Mode = path match
def completionMode(path: List[untpd.Tree], pos: SourcePosition, forSymbolSearch: Boolean = false): Mode = path match
// Ignore `package foo@@` and `package foo.bar@@`
case ((_: tpd.Select) | (_: tpd.Ident)):: (_ : tpd.PackageDef) :: _ => Mode.None
case GenericImportSelector(sel) =>
Expand All @@ -102,11 +111,14 @@ object Completion:
case untpd.Literal(Constants.Constant(_: String)) :: _ => Mode.Term | Mode.Scope // literal completions
case (ref: untpd.RefTree) :: _ =>
val maybeSelectMembers = if ref.isInstanceOf[untpd.Select] then Mode.Member else Mode.Scope

if (ref.name.isTermName) Mode.Term | maybeSelectMembers
if (forSymbolSearch) then Mode.Term | Mode.Type | maybeSelectMembers
else if (ref.name.isTermName) Mode.Term | maybeSelectMembers
else if (ref.name.isTypeName) Mode.Type | maybeSelectMembers
else Mode.None

case (_: tpd.TypeTree | _: tpd.MemberDef) :: _ if forSymbolSearch => Mode.Type | Mode.Term
case (_: tpd.CaseDef) :: _ if forSymbolSearch => Mode.Type | Mode.Term
case Nil if forSymbolSearch => Mode.Type | Mode.Term
case _ => Mode.None

/** When dealing with <errors> in varios palces we check to see if they are
Expand Down Expand Up @@ -174,12 +186,12 @@ object Completion:
case _ => None

private object StringContextApplication:
def unapply(path: List[tpd.Tree]): Option[tpd.Apply] =
def unapply(path: List[tpd.Tree]): Option[tpd.Apply] =
path match
case tpd.Select(qual @ tpd.Apply(tpd.Select(tpd.Select(_, StdNames.nme.StringContext), _), _), _) :: _ =>
Some(qual)
case _ => None


/** Inspect `path` to determine the offset where the completion result should be inserted. */
def completionOffset(untpdPath: List[untpd.Tree]): Int =
Expand Down Expand Up @@ -230,14 +242,14 @@ object Completion:
val result = adjustedPath match
// Ignore synthetic select from `This` because in code it was `Ident`
// See example in dotty.tools.languageserver.CompletionTest.syntheticThis
case tpd.Select(qual @ tpd.This(_), _) :: _ if qual.span.isSynthetic => completer.scopeCompletions
case tpd.Select(qual @ tpd.This(_), _) :: _ if qual.span.isSynthetic => completer.scopeCompletions.names
case StringContextApplication(qual) =>
completer.scopeCompletions ++ completer.selectionCompletions(qual)
case tpd.Select(qual, _) :: _ if qual.typeOpt.hasSimpleKind =>
completer.scopeCompletions.names ++ completer.selectionCompletions(qual)
case tpd.Select(qual, _) :: _ if qual.typeOpt.hasSimpleKind =>
completer.selectionCompletions(qual)
case tpd.Select(qual, _) :: _ => Map.empty
case (tree: tpd.ImportOrExport) :: _ => completer.directMemberCompletions(tree.expr)
case _ => completer.scopeCompletions
case _ => completer.scopeCompletions.names

interactiv.println(i"""completion info with pos = $pos,
| term = ${completer.mode.is(Mode.Term)},
Expand Down Expand Up @@ -338,6 +350,7 @@ object Completion:
(completionMode.is(Mode.Term) && (sym.isTerm || sym.is(ModuleClass))
|| (completionMode.is(Mode.Type) && (sym.isType || sym.isStableMember)))
)
end isValidCompletionSymbol

given ScopeOrdering(using Context): Ordering[Seq[SingleDenotation]] with
val order =
Expand Down Expand Up @@ -371,7 +384,7 @@ object Completion:
* (even if the import follows it syntactically)
* - a more deeply nested import shadowing a member or a local definition causes an ambiguity
*/
def scopeCompletions(using context: Context): CompletionMap =
def scopeCompletions(using context: Context): CompletionResult =

/** Temporary data structure representing denotations with the same name introduced in a given scope
* as a member of a type, by a local definition or by an import clause
Expand All @@ -382,14 +395,19 @@ object Completion:
ScopedDenotations(denots.filter(includeFn), ctx)

val mappings = collection.mutable.Map.empty[Name, List[ScopedDenotations]].withDefaultValue(List.empty)
val renames = collection.mutable.Map.empty[Symbol, Name]
def addMapping(name: Name, denots: ScopedDenotations) =
mappings(name) = mappings(name) :+ denots

ctx.outersIterator.foreach { case ctx @ given Context =>
if ctx.isImportContext then
importedCompletions.foreach { (name, denots) =>
val imported = importedCompletions
imported.names.foreach { (name, denots) =>
addMapping(name, ScopedDenotations(denots, ctx, include(_, name)))
}
imported.renames.foreach { (name, newName) =>
renames(name) = newName
}
else if ctx.owner.isClass then
accessibleMembers(ctx.owner.thisType)
.groupByName.foreach { (name, denots) =>
Expand Down Expand Up @@ -433,7 +451,6 @@ object Completion:
// most deeply nested member or local definition if not shadowed by an import
case Some(local) if local.ctx.scope == first.ctx.scope =>
resultMappings += name -> local.denots

case None if isSingleImport || isImportedInDifferentScope || isSameSymbolImportedDouble =>
resultMappings += name -> first.denots
case None if notConflictingWithDefaults =>
Expand All @@ -443,7 +460,7 @@ object Completion:
}
}

resultMappings
CompletionResult(resultMappings, renames.toMap)
end scopeCompletions

/** Widen only those types which are applied or are exactly nothing
Expand Down Expand Up @@ -485,15 +502,20 @@ object Completion:
/** Completions introduced by imports directly in this context.
* Completions from outer contexts are not included.
*/
private def importedCompletions(using Context): CompletionMap =
private def importedCompletions(using Context): CompletionResult =
val imp = ctx.importInfo
val renames = collection.mutable.Map.empty[Symbol, Name]

if imp == null then
Map.empty
CompletionResult(Map.empty, Map.empty)
else
def fromImport(name: Name, nameInScope: Name): Seq[(Name, SingleDenotation)] =
imp.site.member(name).alternatives
.collect { case denot if include(denot, nameInScope) => nameInScope -> denot }
.collect { case denot if include(denot, nameInScope) =>
if name != nameInScope then
renames(denot.symbol) = nameInScope
nameInScope -> denot
}

val givenImports = imp.importedImplicits
.map { ref => (ref.implicitName: Name, ref.underlyingRef.denot.asSingleDenotation) }
Expand All @@ -519,7 +541,8 @@ object Completion:
fromImport(original.toTypeName, nameInScope.toTypeName)
}.toSeq.groupByName

givenImports ++ wildcardMembers ++ explicitMembers
val results = givenImports ++ wildcardMembers ++ explicitMembers
CompletionResult(results, renames.toMap)
end importedCompletions

/** Completions from implicit conversions including old style extensions using implicit classes */
Expand Down Expand Up @@ -597,7 +620,7 @@ object Completion:

// 1. The extension method is visible under a simple name, by being defined or inherited or imported in a scope enclosing the reference.
val termCompleter = new Completer(Mode.Term, pos, untpdPath, matches)
val extMethodsInScope = termCompleter.scopeCompletions.toList.flatMap:
val extMethodsInScope = termCompleter.scopeCompletions.names.toList.flatMap:
case (name, denots) => denots.collect:
case d: SymDenotation if d.isTerm && d.termRef.symbol.is(Extension) => (d.termRef, name.asTermName)

Expand Down Expand Up @@ -699,6 +722,7 @@ object Completion:

private type CompletionMap = Map[Name, Seq[SingleDenotation]]

case class CompletionResult(names: Map[Name, Seq[SingleDenotation]], renames: Map[Symbol, Name])
/**
* The completion mode: defines what kinds of symbols should be included in the completion
* results.
Expand Down
31 changes: 19 additions & 12 deletions presentation-compiler/src/main/dotty/tools/pc/AutoImports.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ object AutoImports:
case class Select(qual: SymbolIdent, name: String) extends SymbolIdent:
def value: String = s"${qual.value}.$name"

def direct(name: String): SymbolIdent = Direct(name)
def direct(name: String)(using Context): SymbolIdent = Direct(name)

def fullIdent(symbol: Symbol)(using Context): SymbolIdent =
val symbols = symbol.ownersIterator.toList
Expand Down Expand Up @@ -70,7 +70,7 @@ object AutoImports:
importSel: Option[ImportSel]
):

def name: String = ident.value
def name(using Context): String = ident.value

object SymbolImport:

Expand Down Expand Up @@ -189,10 +189,13 @@ object AutoImports:
ownerImport.importSel,
)
else
(
SymbolIdent.direct(symbol.nameBackticked),
Some(ImportSel.Direct(symbol)),
)
renames(symbol) match
case Some(rename) => (SymbolIdent.direct(rename), None)
case None =>
(
SymbolIdent.direct(symbol.nameBackticked),
Some(ImportSel.Direct(symbol)),
)
end val

SymbolImport(
Expand Down Expand Up @@ -223,9 +226,13 @@ object AutoImports:
importSel
)
case None =>
val reverse = symbol.ownersIterator.toList.reverse
val fullName = reverse.drop(1).foldLeft(SymbolIdent.direct(reverse.head.nameBackticked)){
case (acc, sym) => SymbolIdent.Select(acc, sym.nameBackticked(false))
}
SymbolImport(
symbol,
SymbolIdent.direct(symbol.fullNameBackticked),
SymbolIdent.Direct(symbol.fullNameBackticked),
None
)
end match
Expand All @@ -252,7 +259,6 @@ object AutoImports:
val topPadding =
if importPosition.padTop then "\n"
else ""

val formatted = imports
.map {
case ImportSel.Direct(sym) => importName(sym)
Expand All @@ -267,15 +273,16 @@ object AutoImports:
end renderImports

private def importName(sym: Symbol): String =
if indexedContext.importContext.toplevelClashes(sym) then
if indexedContext.toplevelClashes(sym, inImportScope = true) then
s"_root_.${sym.fullNameBackticked(false)}"
else
sym.ownersIterator.zipWithIndex.foldLeft((List.empty[String], false)) { case ((acc, isDone), (sym, idx)) =>
if(isDone || sym.isEmptyPackage || sym.isRoot) (acc, true)
else indexedContext.rename(sym) match
case Some(renamed) => (renamed :: acc, true)
case None if !sym.isPackageObject => (sym.nameBackticked(false) :: acc, false)
case None => (acc, false)
// we can't import first part
case Some(renamed) if idx != 0 => (renamed :: acc, true)
case _ if !sym.isPackageObject => (sym.nameBackticked(false) :: acc, false)
case _ => (acc, false)
}._1.mkString(".")
end AutoImportsGenerator

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ final class AutoImportsProvider(
val path =
Interactive.pathTo(newctx.compilationUnit.tpdTree, pos.span)(using newctx)

val indexedContext = IndexedContext(
Interactive.contextOfPath(path)(using newctx)
val indexedContext = IndexedContext(pos)(
using Interactive.contextOfPath(path)(using newctx)
)
import indexedContext.ctx

Expand Down Expand Up @@ -96,7 +96,7 @@ final class AutoImportsProvider(
text,
tree,
unit.comments,
indexedContext.importContext,
indexedContext,
config
)
(sym: Symbol) => generator.forSymbol(sym)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ final class ExtractMethodProvider(
given locatedCtx: Context =
val newctx = driver.currentCtx.fresh.setCompilationUnit(unit)
Interactive.contextOfPath(path)(using newctx)
val indexedCtx = IndexedContext(locatedCtx)
val indexedCtx = IndexedContext(pos)(using locatedCtx)
val printer =
ShortenedTypePrinter(search, IncludeDefaultParam.Never)(using indexedCtx)
def prettyPrint(tpe: Type) =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ object HoverProvider:
val path = unit
.map(unit => Interactive.pathTo(unit.tpdTree, pos.span))
.getOrElse(Interactive.pathTo(driver.openedTrees(uri), pos))
val indexedContext = IndexedContext(ctx)
val indexedContext = IndexedContext(pos)(using ctx)

def typeFromPath(path: List[Tree]) =
if path.isEmpty then NoType else path.head.typeOpt
Expand Down Expand Up @@ -96,7 +96,7 @@ object HoverProvider:

val printerCtx = Interactive.contextOfPath(path)
val printer = ShortenedTypePrinter(search, IncludeDefaultParam.Include)(
using IndexedContext(printerCtx)
using IndexedContext(pos)(using printerCtx)
)
MetalsInteractive.enclosingSymbolsWithExpressionType(
enclosing,
Expand Down Expand Up @@ -134,7 +134,7 @@ object HoverProvider:
.map(_.docstring())
.mkString("\n")

val expresionTypeOpt =
val expresionTypeOpt =
if symbol.name == StdNames.nme.??? then
InferExpectedType(search, driver, params).infer()
else printer.expressionType(exprTpw)
Expand Down
Loading
Loading