Skip to content

Commit

Permalink
improvement: add debug adapter for running main class to metals
Browse files Browse the repository at this point in the history
  • Loading branch information
kasiaMarek committed May 20, 2024
1 parent 99f098a commit a641cf6
Show file tree
Hide file tree
Showing 9 changed files with 523 additions and 19 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ lazy val metals = project
V.lsp4j,
// for DAP
V.dap4j,
"ch.epfl.scala" %% "scala-debug-adapter" % V.debugAdapter,
// for finding paths of global log/cache directories
"dev.dirs" % "directories" % "26",
// for Java formatting
Expand Down Expand Up @@ -731,7 +732,6 @@ lazy val metalsDependencies = project
"ch.epfl.scala" % "bloop-maven-plugin" % V.mavenBloop,
"ch.epfl.scala" %% "gradle-bloop" % V.gradleBloop,
"com.sourcegraph" % "semanticdb-java" % V.javaSemanticdb,
"ch.epfl.scala" %% "scala-debug-adapter" % V.debugAdapter intransitive (),
"org.foundweekends.giter8" %% "giter8" % V.gitter8Version intransitive (),
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1291,6 +1291,15 @@ object MetalsEnrichments
)
}

implicit class XtensionDebugSessionParams(params: b.DebugSessionParams) {
def asScalaMainClass(): Option[b.ScalaMainClass] =
params.getDataKind() match {
case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS =>
decodeJson(params.getData(), classOf[b.ScalaMainClass])
case _ => None
}
}

/**
* Strips ANSI colors.
* As long as the color codes are valid this should correctly strip
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ final class RunTestCodeLens(
occurence: SymbolOccurrence,
textDocument: TextDocument,
target: BuildTargetIdentifier,
buildServerCanDebug: Boolean,
): Seq[l.Command] = {
if (occurence.symbol.endsWith("#main().")) {
textDocument.symbols
Expand All @@ -182,7 +181,6 @@ final class RunTestCodeLens(
Nil.asJava,
Nil.asJava,
),
buildServerCanDebug,
isJVM = true,
)
else
Expand Down Expand Up @@ -210,9 +208,9 @@ final class RunTestCodeLens(
commands = {
val main = classes.mainClasses
.get(symbol)
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.map(mainCommand(target, _, isJVM))
.getOrElse(Nil)
val tests =
lazy val tests =
// Currently tests can only be run via DAP
if (clientConfig.isDebuggingProvider() && buildServerCanDebug)
testClasses(target, classes, symbol, isJVM)
Expand All @@ -222,12 +220,12 @@ final class RunTestCodeLens(
.flatMap { symbol =>
classes.mainClasses
.get(symbol)
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.map(mainCommand(target, _, isJVM))
}
.getOrElse(Nil)
val javaMains =
if (path.isJava)
javaLenses(occurrence, textDocument, target, buildServerCanDebug)
javaLenses(occurrence, textDocument, target)
else Nil
main ++ tests ++ fromAnnot ++ javaMains
}
Expand Down Expand Up @@ -260,15 +258,15 @@ final class RunTestCodeLens(
val main =
classes.mainClasses
.get(expectedMainClass)
.map(mainCommand(target, _, buildServerCanDebug, isJVM))
.map(mainCommand(target, _, isJVM))
.getOrElse(Nil)

val fromAnnotations = textDocument.occurrences.flatMap { occ =>
for {
sym <- DebugProvider.mainFromAnnotation(occ, textDocument)
cls <- classes.mainClasses.get(sym)
range <- occurrenceRange(occ, distance)
} yield mainCommand(target, cls, buildServerCanDebug, isJVM).map { cmd =>
} yield mainCommand(target, cls, isJVM).map { cmd =>
new l.CodeLens(range, cmd, null)
}
}.flatten
Expand Down Expand Up @@ -325,7 +323,6 @@ final class RunTestCodeLens(
private def mainCommand(
target: b.BuildTargetIdentifier,
main: b.ScalaMainClass,
buildServerCanDebug: Boolean,
isJVM: Boolean,
): List[l.Command] = {
val javaBinary = buildTargets
Expand Down Expand Up @@ -353,7 +350,7 @@ final class RunTestCodeLens(
sessionParams(target, dataKind, data)
}

if (clientConfig.isDebuggingProvider() && buildServerCanDebug && isJVM)
if (clientConfig.isDebuggingProvider() && isJVM)
List(
command("run", StartRunSession, params),
command("debug", StartDebugSession, params),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import scala.collection.concurrent.TrieMap
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
import scala.util.Failure
import scala.util.Success
import scala.util.Try
Expand Down Expand Up @@ -50,6 +51,10 @@ import scala.meta.internal.metals.clients.language.MetalsQuickPickParams
import scala.meta.internal.metals.clients.language.MetalsStatusParams
import scala.meta.internal.metals.config.RunType
import scala.meta.internal.metals.config.RunType._
import scala.meta.internal.metals.debug.server.DebugeeParamsCreator
import scala.meta.internal.metals.debug.server.MainClassDebugAdapter
import scala.meta.internal.metals.debug.server.MetalsDebugToolsResolver
import scala.meta.internal.metals.debug.server.MetalsDebuggee
import scala.meta.internal.metals.testProvider.TestSuitesProvider
import scala.meta.internal.mtags.DefinitionAlternatives.GlobalSymbol
import scala.meta.internal.mtags.OnDemandSymbolIndex
Expand All @@ -63,6 +68,7 @@ import scala.meta.io.AbsolutePath
import ch.epfl.scala.bsp4j.BuildTargetIdentifier
import ch.epfl.scala.bsp4j.DebugSessionParams
import ch.epfl.scala.bsp4j.ScalaMainClass
import ch.epfl.scala.debugadapter
import ch.epfl.scala.{bsp4j => b}
import com.google.common.net.InetAddresses
import com.google.gson.JsonElement
Expand Down Expand Up @@ -90,11 +96,14 @@ class DebugProvider(
sourceMapper: SourceMapper,
userConfig: () => UserConfiguration,
testProvider: TestSuitesProvider,
) extends Cancelable
)(implicit ec: ExecutionContext)
extends Cancelable
with LogForwarder {

import DebugProvider._

private val debugConfigCreator = new DebugeeParamsCreator(buildTargets)

private val runningLocal = new ju.concurrent.atomic.AtomicBoolean(false)

private val debugSessions = new MutableCancelable()
Expand Down Expand Up @@ -250,13 +259,13 @@ class DebugProvider(
val targets = parameters.getTargets().asScala.toSeq

compilations.compilationFinished(targets).flatMap { _ =>
val conn = buildServer
.startDebugSession(parameters, cancelPromise)
.map { uri =>
val socket = connect(uri)
connectedToServer.trySuccess(())
socket
}
val conn =
startDebugSession(buildServer, parameters, cancelPromise)
.map { uri =>
val socket = connect(uri)
connectedToServer.trySuccess(())
socket
}

conn
.withTimeout(60, TimeUnit.SECONDS)
Expand Down Expand Up @@ -311,6 +320,51 @@ class DebugProvider(
connectedToServer.future.map(_ => server)
}

private def startDebugSession(
buildServer: BuildServerConnection,
params: DebugSessionParams,
cancelPromise: Promise[Unit],
) =
if (buildServer.isDebuggingProvider) {
buildServer.startDebugSession(params, cancelPromise)
} else {
def getDebugee: Future[MetalsDebuggee] = params.getDataKind() match {
case b.DebugSessionParamsDataKind.SCALA_MAIN_CLASS =>
val optDebuggee = for {
id <- params.getTargets().asScala.headOption
projectInfo <- debugConfigCreator.create(id)
scalaMainClass <- params.asScalaMainClass()
} yield {
projectInfo.map(
new MainClassDebugAdapter(
workspace,
scalaMainClass,
_,
userConfig().javaHome,
)
)
}
optDebuggee.getOrElse(
throw new RuntimeException(s"Can't resolve debugee")
)
case _ => throw new RuntimeException(s"Can't resolve debugee")
}

for (debuggee <- getDebugee) yield {
val dapLogger =
new scala.meta.internal.metals.debug.server.DebugLogger()
val resolver = new MetalsDebugToolsResolver()
val handler =
debugadapter.DebugServer.run(
debuggee,
resolver,
dapLogger,
gracePeriod = Duration(5, TimeUnit.SECONDS),
)
handler.uri
}
}

/**
* Given a BuildTargetIdentifier either get the displayName of that build
* target or default to the full URI to display to the user.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package scala.meta.internal.metals.debug.server

import ch.epfl.scala.debugadapter.Logger

class DebugLogger extends Logger {

override def debug(msg: => String): Unit = scribe.debug(msg)

override def info(msg: => String): Unit = scribe.info(msg)

override def warn(msg: => String): Unit = scribe.warn(msg)

override def error(msg: => String): Unit = scribe.error(msg)

override def trace(t: => Throwable): Unit =
scribe.trace(s"$t: ${t.getStackTrace().mkString("\n\t")}")

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package scala.meta.internal.metals.debug.server

import java.nio.file.Path

import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import scala.concurrent.Promise

import scala.meta.internal.metals.BuildTargets
import scala.meta.internal.metals.JavaTarget
import scala.meta.internal.metals.MetalsEnrichments._
import scala.meta.internal.metals.ScalaTarget

import ch.epfl.scala.bsp4j.BuildTargetIdentifier
import ch.epfl.scala.bsp4j.MavenDependencyModule
import ch.epfl.scala.debugadapter.Library
import ch.epfl.scala.debugadapter.Module
import ch.epfl.scala.debugadapter.ScalaVersion
import ch.epfl.scala.debugadapter.SourceDirectory
import ch.epfl.scala.debugadapter.SourceJar
import ch.epfl.scala.debugadapter.UnmanagedEntry

class DebugeeParamsCreator(buildTargets: BuildTargets)(implicit
ec: ExecutionContext
) {
def create(id: BuildTargetIdentifier): Option[Future[DebugeeProject]] = {
val optScalaTarget = buildTargets.scalaTarget(id)
val optJavaTarget = buildTargets.javaTarget(id)
for {
name <- optScalaTarget
.map(_.displayName)
.orElse(optJavaTarget.map(_.displayName))
data <- buildTargets.targetData(id)
} yield {

val libraries = data.buildTargetDependencyModules
.get(id)
.filter(_.nonEmpty)
.getOrElse(Nil)
val debugLibs = libraries.flatMap(createLibrary(_))
val includedInLibs = debugLibs
.flatMap(_.sourceEntries.flatMap {
case SourceJar(jar) => Some(jar)
case _ => None
})
.toSet

val cancelPromise: Promise[Unit] = Promise()

for (
classpath <- buildTargets
.targetClasspath(id, cancelPromise)
.getOrElse(Future.successful(Nil))
.map(_.toAbsoluteClasspath.map(_.toNIO).toSeq)
) yield {

val filteredClassPath = classpath.collect {
case path if includedInLibs(path) => UnmanagedEntry(path)
}.toList

val modules = buildTargets
.allInverseDependencies(id)
.flatMap(id =>
buildTargets.scalaTarget(id).map(createModule(_)).orElse {
buildTargets.javaTarget(id).map(createModule(_))
}
)
.toSeq

new DebugeeProject(
buildTargets.scalaTarget(id).map(_.scalaVersion),
name,
modules,
libraries.flatMap(createLibrary(_)),
filteredClassPath,
classpath,
)
}
}
}

def createLibrary(lib: MavenDependencyModule): Option[Library] = {
def getWithClassifier(s: String) =
Option(lib.getArtifacts())
.flatMap(_.asScala.find(_.getClassifier() == s))
.flatMap(_.getUri().toAbsolutePathSafe)
for {
sources <- getWithClassifier("sources")
jar <- getWithClassifier(null)
} yield new Library(
lib.getName(),
lib.getVersion(),
jar.toNIO,
Seq(SourceJar(sources.toNIO)),
)
}

def createModule(target: ScalaTarget): Module = {
val scalaVersion = ScalaVersion(target.scalaVersion)
new Module(
target.displayName,
Some(scalaVersion),
target.scalac.getOptions().asScala.toSeq,
target.classDirectory.toAbsolutePath.toNIO,
sources(target.id),
)
}

def createModule(target: JavaTarget) =
new Module(
target.displayName,
None,
Nil,
target.classDirectory.toAbsolutePath.toNIO,
sources(target.id),
)

private def sources(id: BuildTargetIdentifier) =
buildTargets.sourceItemsToBuildTargets
.filter(_._2.iterator.asScala.contains(id))
.collect { case (path, _) =>
SourceDirectory(path.toNIO)
}
.toSeq
}

case class DebugeeProject(
scalaVersion: Option[String],
name: String,
modules: Seq[Module],
libraries: Seq[Library],
unmanagedEntries: Seq[UnmanagedEntry],
classpath: Seq[Path],
)
Loading

0 comments on commit a641cf6

Please sign in to comment.