Skip to content

Commit

Permalink
feat(intellij): add support to collect declaration snippets. (#3394)
Browse files Browse the repository at this point in the history
* feat(intellij): add support to collect declaration snippets.

* fix(intellij): add try-catch for findTargetElement.
  • Loading branch information
icycodes authored Nov 13, 2024
1 parent 5c14645 commit 2ffe638
Show file tree
Hide file tree
Showing 10 changed files with 385 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
package com.tabbyml.intellijtabby

import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.editor.Document
import com.intellij.openapi.fileEditor.FileDocumentManager
import com.intellij.openapi.fileEditor.TextEditor
import com.intellij.openapi.fileEditor.ex.FileEditorManagerEx
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiManager
import com.intellij.util.messages.Topic

fun <L : Any> Project.safeSyncPublisher(topic: Topic<L>): L? {
Expand All @@ -16,3 +26,50 @@ fun <L : Any> Project.safeSyncPublisher(topic: Topic<L>): L? {
}
}
}


fun Project.findVirtualFile(fileUri: String): VirtualFile? {
val virtualFileManager = VirtualFileManager.getInstance()
return virtualFileManager.findFileByUrl(fileUri)
}

fun Project.findDocument(fileUri: String): Document? {
return findVirtualFile(fileUri)?.let { findDocument(it) }
}

fun Project.findDocument(virtualFile: VirtualFile): Document? {
val fileDocumentManager = FileDocumentManager.getInstance()
return runReadAction { fileDocumentManager.getDocument(virtualFile) }
}

fun Project.findPsiFile(fileUri: String): PsiFile? {
return findVirtualFile(fileUri)?.let { findPsiFile(it) }
}

fun Project.findPsiFile(virtualFile: VirtualFile): PsiFile? {
val psiManager = PsiManager.getInstance(this)
return runReadAction { psiManager.findFile(virtualFile) }
}

fun Project.findEditor(fileUri: String): TextEditor? {
return findVirtualFile(fileUri)?.let { findEditor(it) }
}

fun Project.findEditor(virtualFile: VirtualFile): TextEditor? {
val fileEditorManager = FileEditorManagerEx.getInstanceEx(this)

return runInEdtAndWait {
fileEditorManager.getEditors(virtualFile)
}.firstOrNull { editor -> editor is TextEditor } as? TextEditor?
}

private fun <T> runInEdtAndWait(runnable: () -> T): T {
val app = ApplicationManager.getApplication()
if (app.isDispatchThread) {
return runnable()
} else {
var resultRef: T? = null
app.invokeAndWait { resultRef = runnable() }
@Suppress("UNCHECKED_CAST") return resultRef as T
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package com.tabbyml.intellijtabby.chat
import com.google.gson.Gson
import com.google.gson.annotations.SerializedName
import com.google.gson.reflect.TypeToken
import com.intellij.openapi.application.ReadAction
import com.intellij.openapi.application.invokeLater
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.command.WriteCommandAction
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.Logger
Expand Down Expand Up @@ -165,18 +165,18 @@ class ChatBrowser(private val project: Project) : JBCefBrowser(

private fun getActiveFileContext(useSelectedText: Boolean = true): FileContext? {
return FileEditorManager.getInstance(project).selectedTextEditor?.let { editor ->
ReadAction.compute<Triple<String, Int, Int>?, Throwable> {
runReadAction {
val document = editor.document
if (useSelectedText) {
val selectionModel = editor.selectionModel
val text = selectionModel.selectedText.takeUnless { it.isNullOrBlank() } ?: return@compute null
val text = selectionModel.selectedText.takeUnless { it.isNullOrBlank() } ?: return@runReadAction null
Triple(
text,
document.getLineNumber(selectionModel.selectionStart) + 1,
document.getLineNumber(selectionModel.selectionEnd) + 1,
)
} else {
val text = document.text.takeUnless { it.isBlank() } ?: return@compute null
val text = document.text.takeUnless { it.isBlank() } ?: return@runReadAction null
Triple(
text,
1,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package com.tabbyml.intellijtabby.languageSupport

import com.intellij.codeInsight.TargetElementUtil
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiRecursiveElementWalkingVisitor
import com.tabbyml.intellijtabby.findEditor
import com.tabbyml.intellijtabby.languageSupport.LanguageSupportProvider.*
import org.eclipse.lsp4j.SemanticTokenTypes

/**
* The default implementation of [LanguageSupportProvider].
* This implementation relies on [TargetElementUtil] and tries to find the navigation target at each position in the
* editor to provide semantic tokens and declarations.
* This implementation may not work effectively for all languages.
*/
open class DefaultLanguageSupportProvider : LanguageSupportProvider {
private val logger = logger<DefaultLanguageSupportProvider>()
private val targetElementUtil = TargetElementUtil.getInstance()

override fun provideSemanticTokensRange(project: Project, fileRange: FileRange): List<SemanticToken>? {
val psiFile = fileRange.file
val editor = project.findEditor(psiFile.virtualFile) ?: return null

return runReadAction {
val leafElements = mutableListOf<PsiElement>()
psiFile.accept(object : PsiRecursiveElementWalkingVisitor() {
override fun visitElement(element: PsiElement) {
if (element.children.isEmpty() &&
element.text.matches(Regex("\\w+")) &&
fileRange.range.contains(element.textRange) &&
leafElements.none { it.textRange.intersects(element.textRange) }
) {
leafElements.add(element)
}
if (element.textRange.intersects(fileRange.range)) {
super.visitElement(element)
}
}
})

leafElements.mapNotNull {
val target = try {
targetElementUtil.findTargetElement(
editor.editor,
TargetElementUtil.ELEMENT_NAME_ACCEPTED or TargetElementUtil.REFERENCED_ELEMENT_ACCEPTED,
it.textRange.startOffset
)
} catch (e: Exception) {
logger.debug("Failed to find target element when providing semantic tokens", e)
null
}
if (target == it || target == null || target.text == null) {
null
} else {
SemanticToken(
text = it.text,
range = it.textRange,
type = SemanticTokenTypes.Type, // Default to use `Type` as the token type as we don't know the actual type
)
}
}
}
}

override fun provideDeclaration(project: Project, filePosition: FilePosition): List<FileRange>? {
val psiFile = filePosition.file
val editor = project.findEditor(psiFile.virtualFile) ?: return null

return runReadAction {
val target = try {
targetElementUtil.findTargetElement(
editor.editor,
TargetElementUtil.ELEMENT_NAME_ACCEPTED or TargetElementUtil.REFERENCED_ELEMENT_ACCEPTED,
filePosition.offset
)
} catch (e: Exception) {
logger.debug("Failed to find target element at ${psiFile.virtualFile.url}:${filePosition.offset}", e)
null
}
val file = target?.containingFile ?: return@runReadAction listOf()
val range = target.textRange
listOf(FileRange(file, range))
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.tabbyml.intellijtabby.languageSupport

import com.intellij.openapi.project.Project
import com.intellij.openapi.util.TextRange
import com.intellij.psi.PsiFile

interface LanguageSupportProvider {
data class FilePosition(
val file: PsiFile,
val offset: Int,
)

data class FileRange(
val file: PsiFile,
val range: TextRange,
)

data class SemanticToken(
val text: String,
val range: TextRange,
/**
* See [org.eclipse.lsp4j.SemanticTokenTypes]
*/
val type: String,
/**
* See [org.eclipse.lsp4j.SemanticTokenModifiers]
*/
val modifiers: List<String> = emptyList(),
)

/**
* Find all semantic tokens in the given [fileRange].
* For now, this function is only used to find tokens that reference a declaration, which will be used to invoke [provideDeclaration] later.
* So it is safe to only contain these tokens in the result, like class names, function names, etc.
*
* If no tokens are found, return an empty list.
* If the provider does not support the given document, return null.
*/
fun provideSemanticTokensRange(project: Project, fileRange: FileRange): List<SemanticToken>? {
return null
}

/**
* Get the declaration location for the token at the given [filePosition].
* If no declaration is found, return an empty list.
* If the provider does not support the given document, return null.
*/
fun provideDeclaration(project: Project, filePosition: FilePosition): List<FileRange>? {
return null
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.tabbyml.intellijtabby.languageSupport

import com.intellij.openapi.components.Service
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.extensions.ExtensionPointName
import com.intellij.openapi.project.Project
import com.tabbyml.intellijtabby.languageSupport.LanguageSupportProvider.*


@Service(Service.Level.PROJECT)
class LanguageSupportService(private val project: Project) {
private val logger = logger<LanguageSupportService>()
private val languageSupportProviderExtensionPoint: ExtensionPointName<LanguageSupportProvider> =
ExtensionPointName.create("com.tabbyml.intellij-tabby.languageSupportProvider")
private val defaultLanguageSupportProvider = DefaultLanguageSupportProvider()

fun provideSemanticTokensRange(fileRange: FileRange): List<SemanticToken>? {
var semanticTokens: List<SemanticToken>? = null
for (provider in languageSupportProviderExtensionPoint.extensionList) {
semanticTokens = provider.provideSemanticTokensRange(project, fileRange)
if (semanticTokens != null) {
logger.trace("Semantic tokens provided by ${provider.javaClass.name}: $semanticTokens")
break
}
}
if (semanticTokens == null) {
semanticTokens = defaultLanguageSupportProvider.provideSemanticTokensRange(project, fileRange)
logger.trace("Semantic tokens provided by default provider: $semanticTokens")
}
return semanticTokens
}

fun provideDeclaration(position: FilePosition): List<FileRange>? {
var declaration: List<FileRange>? = null
for (provider in languageSupportProviderExtensionPoint.extensionList) {
declaration = provider.provideDeclaration(project, position)
if (declaration != null) {
logger.trace("Declaration provided by ${provider.javaClass.name}: $declaration")
break
}
}
if (declaration == null) {
declaration = defaultLanguageSupportProvider.provideDeclaration(project, position)
logger.trace("Declaration provided by default provider: $declaration")
}
return declaration
}
}
Loading

0 comments on commit 2ffe638

Please sign in to comment.