Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rochala committed Aug 27, 2024
1 parent cddbc28 commit a4ad2a5
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 64 deletions.
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/NavigateAST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ object NavigateAST {
def isBetterFit(currentBest: List[Positioned], candidate: List[Positioned]): Boolean =
if currentBest.isEmpty && candidate.nonEmpty then true
else if currentBest.nonEmpty && candidate.nonEmpty then
val bestSpan= currentBest.head.span
val bestSpan = currentBest.head.span
val candidateSpan = candidate.head.span

bestSpan != candidateSpan &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ final class AutoImportsProvider(
val results = symbols.result.filter(isExactMatch(_, name))

if results.nonEmpty then
val correctedPos = CompletionPos.infer(pos, params, path, false).toSourcePosition
val correctedPos =
CompletionPos.infer(pos, params, path, wasCursorApplied = false).toSourcePosition
val mkEdit =
path match
// if we are in import section just specify full name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ import org.eclipse.lsp4j.InsertTextMode
import org.eclipse.lsp4j.Range as LspRange
import org.eclipse.lsp4j.TextEdit

object CompletionProvider:
val allKeywords =
val softKeywords = Tokens.softModifierNames + nme.as + nme.derives + nme.extension + nme.throws + nme.using
Tokens.keywords.toList.map(Tokens.tokenString) ++ softKeywords.map(_.toString)

class CompletionProvider(
search: SymbolSearch,
cachingDriver: InteractiveDriver,
Expand Down Expand Up @@ -74,7 +79,20 @@ class CompletionProvider(

val tpdPath = tpdPath0 match
case Select(qual, name) :: tail
// If for any reason we end up in param after lifting, we want to inline the synthetic val
/** If for any reason we end up in param after lifting, we want to inline the synthetic val:
* List(1).iterator.sliding@@ will be transformed into:
*
* 1| val $1$: Iterator[Int] = List.apply[Int]([1 : Int]*).iterator
* 2| {
* 3| def $anonfun(size: Int, step: Int): $1$.GroupedIterator[Int] =
* 4| $1$.sliding[Int](size, step)
* 5| closure($anonfun)
* 6| }:((Int, Int) => Iterator[Int]#GroupedIterator[Int])
*
* With completion being run at line 4 at @@:
* 4| $1$.sliding@@[Int](size, step)
*
*/
if qual.symbol.is(Flags.Synthetic) && qual.symbol.name.isInstanceOf[DerivedName] =>
qual.symbol.defTree match
case valdef: ValDef => Select(valdef.rhs, name) :: tail
Expand Down Expand Up @@ -135,10 +153,6 @@ class CompletionProvider(
)
end completions

val allKeywords =
val softKeywords = Tokens.softModifierNames + nme.as + nme.derives + nme.extension + nme.throws + nme.using
Tokens.keywords.toList.map(Tokens.tokenString) ++ softKeywords.map(_.toString)

/**
* In case if completion comes from empty line like:
* {{{
Expand All @@ -156,8 +170,8 @@ class CompletionProvider(
val offset = params.offset().nn
val query = Completion.naiveCompletionPrefix(text, offset)

if offset > 0 && text.charAt(offset - 1).isUnicodeIdentifierPart && !allKeywords.contains(query) then
false -> text
if offset > 0 && text.charAt(offset - 1).isUnicodeIdentifierPart
&& !CompletionProvider.allKeywords.contains(query) then false -> text
else
val isStartMultilineComment =

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ object NamedArgCompletions:
.zipWithIndex
.forall { case (pair, index) =>
FuzzyArgMatcher(m.tparams)
.doMatch(allArgsProvided = index != 0)
.doMatch(allArgsProvided = index != 0, ident)
.tupled(pair)
} =>
m
Expand Down Expand Up @@ -385,12 +385,13 @@ class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context):
* We check the args types not the result type.
*/
def doMatch(
allArgsProvided: Boolean
allArgsProvided: Boolean,
ident: Option[Ident]
)(expectedArgs: List[Symbols.Symbol], actualArgs: List[Tree]) =
(expectedArgs.length == actualArgs.length ||
(!allArgsProvided && expectedArgs.length >= actualArgs.length)) &&
actualArgs.zipWithIndex.forall {
case (Ident(name), _) => true
case (arg: Ident, _) if ident.contains(arg) => true
case (NamedArg(name, arg), _) =>
expectedArgs.exists { expected =>
expected.name == name && (!arg.hasType || arg.typeOpt.unfold
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,8 +530,11 @@ object OverrideCompletions:
object OverrideExtractor:
def unapply(path: List[Tree])(using Context) =
path match
// class FooImpl extends Foo:
// def x|
// abstract class Val:
// def hello: Int = 2
//
// class Main extends Val:
// def h|
case (dd: (DefDef | ValDef)) :: (t: Template) :: (td: TypeDef) :: _
if t.parents.nonEmpty =>
val completing =
Expand All @@ -547,12 +550,13 @@ object OverrideCompletions:
)
)

// class FooImpl extends Foo:
// abstract class Val:
// def hello: Int = 2
//
// class Main extends Val:
// ov|
case (ident: Ident) :: (t: Template) :: (td: TypeDef) :: _
if t.parents.nonEmpty && "override".startsWith(
ident.name.show.replace(Cursor.value, "")
) =>
if t.parents.nonEmpty && "override".startsWith(ident.name.show.replace(Cursor.value, "")) =>
Some(
(
td,
Expand All @@ -563,15 +567,13 @@ object OverrideCompletions:
)
)

// abstract class Val:
// def hello: Int = 2
//
// class Main extends Val:
// def@@
case (id: Ident) :: (t: Template) :: (td: TypeDef) :: _
if t.parents.nonEmpty && "def".startsWith(
id.name.decoded.replace(
Cursor.value,
"",
)
) =>
if t.parents.nonEmpty && "def".startsWith(id.name.decoded.replace(Cursor.value, "")) =>
Some(
(
td,
Expand All @@ -581,6 +583,10 @@ object OverrideCompletions:
None,
)
)

// abstract class Val:
// def hello: Int = 2
//
// class Main extends Val:
// he@@
case (id: Ident) :: (t: Template) :: (td: TypeDef) :: _
Expand All @@ -595,6 +601,9 @@ object OverrideCompletions:
)
)

// abstract class Val:
// def hello: Int = 2
//
// class Main extends Val:
// hello@ // this transforms into this.hello, thus is a Select
case (sel @ Select(th: This, name)) :: (t: Template) :: (td: TypeDef) :: _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,19 @@ class CompilerCachingSuite extends BasePCSuite:

val timeout = 5.seconds

private def checkCompilationCount(params: VirtualFileParams, expected: Int): Unit =
private def checkCompilationCount(expected: Int): Unit =
presentationCompiler match
case pc: ScalaPresentationCompiler =>
val compilations= pc.compilerAccess.withNonInterruptableCompiler(Some(params))(-1, EmptyCancelToken) { driver =>
val compilations = pc.compilerAccess.withNonInterruptableCompiler(None)(-1, EmptyCancelToken) { driver =>
driver.compiler().currentCtx.runId
}.get(timeout.length, timeout.unit)
assertEquals(expected, compilations, s"Expected $expected compilations but got $compilations")
case _ => throw IllegalStateException("Presentation compiler should always be of type of ScalaPresentationCompiler")

private def getContext(params: VirtualFileParams): Context =
private def getContext(): Context =
presentationCompiler match
case pc: ScalaPresentationCompiler =>
pc.compilerAccess.withNonInterruptableCompiler(Some(params))(null, EmptyCancelToken) { driver =>
pc.compilerAccess.withNonInterruptableCompiler(None)(null, EmptyCancelToken) { driver =>
driver.compiler().currentCtx
}.get(timeout.length, timeout.unit)
case _ => throw IllegalStateException("Presentation compiler should always be of type of ScalaPresentationCompiler")
Expand All @@ -44,76 +44,73 @@ class CompilerCachingSuite extends BasePCSuite:
def beforeEach: Unit =
presentationCompiler.restart()

// We want to run art least one compilation, so runId points at 3.
// We want to run at least one compilation, so runId points at 3.
// This will ensure that we use the same driver, not recreate fresh one on each call
val dryRunParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "dryRun", 1, EmptyCancelToken)
checkCompilationCount(dryRunParams, 2)
val freshContext = getContext(dryRunParams)
checkCompilationCount(2)
val freshContext = getContext()
presentationCompiler.complete(dryRunParams).get(timeout.length, timeout.unit)
checkCompilationCount(dryRunParams, 3)
val dryRunContext = getContext(dryRunParams)
checkCompilationCount(3)
val dryRunContext = getContext()
assert(freshContext != dryRunContext)


@Test
def `cursor-compilation-does-not-corrupt-cache`: Unit =
val contextPreCompilation = getContext()

val fakeParamsCursor = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = new", 15, EmptyCancelToken)
val fakeParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = ne", 14, EmptyCancelToken)

val contextPreCompilation = getContext(fakeParams)

presentationCompiler.complete(fakeParams).get(timeout.length, timeout.unit)
val contextPostFirst = getContext(fakeParams)
val contextPostFirst = getContext()
assert(contextPreCompilation != contextPostFirst)
checkCompilationCount(fakeParams, 4)
checkCompilationCount(4)

val fakeParamsCursor = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = new", 15, EmptyCancelToken)
presentationCompiler.complete(fakeParamsCursor).get(timeout.length, timeout.unit)
val contextPostCursor = getContext(fakeParamsCursor)
val contextPostCursor = getContext()
assert(contextPreCompilation != contextPostCursor)
assert(contextPostFirst == contextPostCursor)
checkCompilationCount(fakeParamsCursor, 4)
checkCompilationCount(4)

presentationCompiler.complete(fakeParams).get(timeout.length, timeout.unit)
val contextPostSecond = getContext(fakeParams)
val contextPostSecond = getContext()
assert(contextPreCompilation != contextPostSecond)
assert(contextPostFirst == contextPostCursor)
assert(contextPostCursor == contextPostSecond)
checkCompilationCount(fakeParamsCursor, 4)
checkCompilationCount(4)

@Test
def `compilation-for-same-snippet-is-cached`: Unit =
val fakeParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = ne", 14, EmptyCancelToken)

val contextPreCompilation = getContext(fakeParams)
val contextPreCompilation = getContext()

val fakeParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = ne", 14, EmptyCancelToken)
presentationCompiler.complete(fakeParams).get(timeout.length, timeout.unit)
val contextPostFirst = getContext(fakeParams)
val contextPostFirst = getContext()
assert(contextPreCompilation != contextPostFirst)
checkCompilationCount(fakeParams, 4)
checkCompilationCount(4)

presentationCompiler.complete(fakeParams).get(timeout.length, timeout.unit)
val contextPostSecond = getContext(fakeParams)
val contextPostSecond = getContext()
assert(contextPreCompilation != contextPostFirst)
assert(contextPostSecond == contextPostFirst)
checkCompilationCount(fakeParams, 4)
checkCompilationCount(4)

@Test
def `compilation-for-different-snippet-is-not-cached`: Unit =

val fakeParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = prin", 16, EmptyCancelToken)
val fakeParams2 = CompilerOffsetParams(Paths.get("Test2.scala").toUri(), "def hello = prin", 16, EmptyCancelToken)
val fakeParams3 = CompilerOffsetParams(Paths.get("Test2.scala").toUri(), "def hello = print", 17, EmptyCancelToken)

checkCompilationCount(fakeParams, 3)
checkCompilationCount(3)
val fakeParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = prin", 16, EmptyCancelToken)
presentationCompiler.complete(fakeParams).get(timeout.length, timeout.unit)
checkCompilationCount(fakeParams, 4)
checkCompilationCount(4)

val fakeParams2 = CompilerOffsetParams(Paths.get("Test2.scala").toUri(), "def hello = prin", 16, EmptyCancelToken)
presentationCompiler.complete(fakeParams2).get(timeout.length, timeout.unit)
checkCompilationCount(fakeParams2, 5)
checkCompilationCount(5)

val fakeParams3 = CompilerOffsetParams(Paths.get("Test2.scala").toUri(), "def hello = print", 17, EmptyCancelToken)
presentationCompiler.complete(fakeParams3).get(timeout.length, timeout.unit)
checkCompilationCount(fakeParams3, 6)
checkCompilationCount(6)


private val testFunctions: List[OffsetParams => CompletableFuture[_]] = List(
Expand All @@ -137,14 +134,14 @@ class CompilerCachingSuite extends BasePCSuite:
@Test
def `different-api-calls-reuse-cache`: Unit =
val fakeParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = ne", 13, EmptyCancelToken)

presentationCompiler.complete(fakeParams).get(timeout.length, timeout.unit)
val contextBefore = getContext(fakeParams)

val contextBefore = getContext()

val differentContexts = testFunctions.map: f =>
f(fakeParams).get(timeout.length, timeout.unit)
checkCompilationCount(fakeParams, 4)
getContext(fakeParams)
checkCompilationCount(4)
getContext()
.toSet

assert(differentContexts == Set(contextBefore))
Expand All @@ -155,12 +152,12 @@ class CompilerCachingSuite extends BasePCSuite:
import scala.concurrent.ExecutionContext.Implicits.global

val fakeParams = CompilerOffsetParams(Paths.get("Test.scala").toUri(), "def hello = ne", 13, EmptyCancelToken)

presentationCompiler.complete(fakeParams).get(timeout.length, timeout.unit)
val contextBefore = getContext(fakeParams)

val contextBefore = getContext()

val futures = testFunctions.map: f =>
f(fakeParams).asScala.map(_ => getContext(fakeParams))
f(fakeParams).asScala.map(_ => getContext())

val res = Await.result(Future.sequence(futures), timeout).toSet
assert(res == Set(contextBefore))
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,22 @@ class CompletionArgSuite extends BaseCompletionSuite:
topLines = Some(1),
)

@Test def `second-first22` =
check(
"""|object Main {
| def foo(aaa: Int, bbb: Int, ccc: Int) = aaa + bbb + ccc
| val k = foo (
| bbb = 123,
| aa@@,
| ccc = 123,
| )
|}
|""".stripMargin,
"""|aaa = : Int
|""".stripMargin,
topLines = Some(1),
)

@Test def `second-first3` =
check(
"""|object Main {
Expand Down

0 comments on commit a4ad2a5

Please sign in to comment.