From e9d084db05a867139228f4e2e33520b5400fe5fb Mon Sep 17 00:00:00 2001 From: Amir Bilu Date: Thu, 8 Feb 2024 22:30:27 +0200 Subject: [PATCH] DEV2-4801 Mentions (#147) mentions --- lua/tabnine/chat/codelens.lua | 29 ++--------- lua/tabnine/chat/init.lua | 72 ++++++++++++++++++++++++++- lua/tabnine/chat/setup.lua | 1 - lua/tabnine/chat/user_commands.lua | 1 - lua/tabnine/lsp.lua | 78 ++++++++++++++++++++++++++++++ lua/tabnine/utils.lua | 28 +++-------- 6 files changed, 158 insertions(+), 51 deletions(-) create mode 100644 lua/tabnine/lsp.lua diff --git a/lua/tabnine/chat/codelens.lua b/lua/tabnine/chat/codelens.lua index 24e3034..bc59891 100644 --- a/lua/tabnine/chat/codelens.lua +++ b/lua/tabnine/chat/codelens.lua @@ -1,13 +1,13 @@ local chat = require("tabnine.chat") local config = require("tabnine.config") local consts = require("tabnine.consts") +local lsp = require("tabnine.lsp") local state = require("tabnine.state") local utils = require("tabnine.utils") local api = vim.api local fn = vim.fn local M = {} -local SYMBOL_KIND = { FUNCTION = 12, CLASS = 5, METHOD = 6 } local current_symbols = {} local symbol_under_cursor = nil local cancel_lsp_request = nil @@ -34,34 +34,11 @@ function M.should_display() and buf_supports_symbols end -local function flatten_symbols(symbols, result) - result = result or {} - - for _, symbol in ipairs(symbols) do - table.insert(result, symbol) - - if symbol.children then flatten_symbols(symbol.children, result) end - end - - return result -end - function M.collect_symbols(on_collect) - local params = vim.lsp.util.make_position_params() - if cancel_lsp_request then cancel_lsp_request() end - cancel_lsp_request = vim.lsp.buf_request_all(0, "textDocument/documentSymbol", params, function(responses) - current_symbols = {} - for _, response in ipairs(responses) do - if response.result then - for _, result in ipairs(flatten_symbols(response.result)) do - if result.kind == SYMBOL_KIND.FUNCTION or result.kind == SYMBOL_KIND.METHOD then - table.insert(current_symbols, result) - end - end - end - end + cancel_lsp_request = lsp.get_document_symbols("", function(symbols) + current_symbols = symbols on_collect() end) end diff --git a/lua/tabnine/chat/init.lua b/lua/tabnine/chat/init.lua index 3b36647..7a9400e 100644 --- a/lua/tabnine/chat/init.lua +++ b/lua/tabnine/chat/init.lua @@ -4,6 +4,8 @@ local tabnine_binary = require("tabnine.binary") local utils = require("tabnine.utils") local api = vim.api local config = require("tabnine.config") +local lsp = require("tabnine.lsp") +local get_symbols_request = nil local M = { enabled = false } @@ -14,6 +16,19 @@ local chat_state = nil local chat_settings = nil local initialized = false +local function to_chat_symbol_kind(kind) + if kind == lsp.SYMBOL_KIND.METHOD then + return "Method" + elseif kind == lsp.SYMBOL_KIND.FUNCTION then + return "Function" + elseif kind == lsp.SYMBOL_KIND.CLASS then + return "Class" + elseif kind == lsp.SYMBOL_KIND.FILE then + return "File" + else + return "Other" + end +end local function get_diagnostics() return vim.tbl_map(function(diagnostic) return { @@ -174,7 +189,6 @@ local function register_events(on_init) chat_binary:register_event("get_selected_code", function(_, answer) local selected_code = utils.selected_text() - answer({ code = selected_code, startLine = vim.fn.getpos("'<")[2], @@ -182,6 +196,62 @@ local function register_events(on_init) }) end) + chat_binary:register_event("get_symbols", function(request, answer) + if get_symbols_request then get_symbols_request() end + get_symbols_request = lsp.get_document_symbols(request.query, function(document_symbols) + lsp.get_workspace_symbols(request.query, function(workspace_symbols) + answer({ + workspaceSymbols = vim.tbl_map(function(symbol) + return { + name = symbol.name, + absolutePath = symbol.location.uri, + relativePath = utils.remove_matching_prefix(symbol.location.uri, fn.getcwd()), + kind = to_chat_symbol_kind(symbol.kind), + range = { + startLine = symbol.location.range.start.line, + startCharacter = symbol.location.range.start.character, + endLine = symbol.location.range["end"].line, + endCharacter = symbol.location.range["end"].character, + }, + } + end, workspace_symbols), + documentSymbols = vim.tbl_map(function(symbol) + return { + name = symbol.name, + absolutePath = api.nvim_buf_get_name(0), + relativePath = vim.fn.expand("%"), + kind = to_chat_symbol_kind(symbol.kind), + range = { + startLine = symbol.range.start.line, + startCharacter = symbol.range.start.character, + endLine = symbol.range["end"].line, + endCharacter = symbol.range["end"].character, + }, + } + end, document_symbols), + }) + end) + end) + end) + + chat_binary:register_event("get_symbols_text", function(request, answer) + answer(vim.tbl_map(function(symbol) + local buf = utils.read_file_into_buffer(symbol.absolutePath) + local text = utils.lines_to_str( + api.nvim_buf_get_text( + buf, + symbol.range.startLine, + symbol.range.startCharacter, + symbol.range.endLine, + symbol.range.endCharacter, + {} + ) + ) + api.nvim_buf_delete(buf, { force = true }) + return { id = symbol.id, snippet = text } + end, request.symbols)) + end) + chat_binary:register_event("send_event", function(event) tabnine_binary:request({ Event = { name = event.eventName, properties = event.properties }, diff --git a/lua/tabnine/chat/setup.lua b/lua/tabnine/chat/setup.lua index 7b30a1d..4c322eb 100644 --- a/lua/tabnine/chat/setup.lua +++ b/lua/tabnine/chat/setup.lua @@ -1,6 +1,5 @@ local auto_commands = require("tabnine.chat.auto_commands") local chat = require("tabnine.chat") -local codelens = require("tabnine.chat.codelens") local config = require("tabnine.config") local features = require("tabnine.features") local user_commands = require("tabnine.chat.user_commands") diff --git a/lua/tabnine/chat/user_commands.lua b/lua/tabnine/chat/user_commands.lua index bde417a..7103300 100644 --- a/lua/tabnine/chat/user_commands.lua +++ b/lua/tabnine/chat/user_commands.lua @@ -2,7 +2,6 @@ local M = {} local api = vim.api local chat = require("tabnine.chat") local codelens = require("tabnine.chat.codelens") -local config = require("tabnine.config") function M.setup() api.nvim_create_user_command("TabnineChat", function() diff --git a/lua/tabnine/lsp.lua b/lua/tabnine/lsp.lua new file mode 100644 index 0000000..63e2b47 --- /dev/null +++ b/lua/tabnine/lsp.lua @@ -0,0 +1,78 @@ +local utils = require("tabnine.utils") +local M = {} + +M.SYMBOL_KIND = { FUNCTION = 12, CLASS = 5, METHOD = 6 } + +local function is_not_source(symbol_path) + local dirs = { "node_modules", "dist", "build", "target", "out" } + for i, dir in ipairs(dirs) do + if string.sub(symbol_path, 1, string.len(dir)) == dir then return true end + end + return false +end + +local function flatten_symbols(symbols, result) + result = result or {} + + for _, symbol in ipairs(symbols) do + table.insert(result, symbol) + + if symbol.children then flatten_symbols(symbol.children, result) end + end + + return result +end + +function M.get_workspace_symbols(query, callback) + local params = { query = query } + + return vim.lsp.buf_request_all(0, "workspace/symbol", params, function(responses) + local results = {} + for _, response in ipairs(responses) do + if response.result then + for _, result in ipairs(flatten_symbols(response.result)) do + result.location.uri = utils.remove_matching_prefix(result.location.uri, "file://") + if + ( + result.kind == M.SYMBOL_KIND.CLASS + or result.kind == M.SYMBOL_KIND.METHOD + or result.kind == M.SYMBOL_KIND.FUNCTION + ) + and utils.starts_with(result.location.uri, vim.fn.getcwd()) + and not is_not_source(result.location.uri) + then + table.insert(results, result) + end + end + end + end + callback(results) + end) +end + +function M.get_document_symbols(query, callback) + local params = { + textDocument = vim.lsp.util.make_text_document_params(), + } + return vim.lsp.buf_request_all(0, "textDocument/documentSymbol", params, function(responses) + local results = {} + for _, response in ipairs(responses) do + if response.result then + for _, result in ipairs(flatten_symbols(response.result)) do + if + ( + result.kind == M.SYMBOL_KIND.CLASS + or result.kind == M.SYMBOL_KIND.METHOD + or result.kind == M.SYMBOL_KIND.FUNCTION + ) and utils.starts_with(result.name, query) + then + table.insert(results, result) + end + end + end + end + callback(results) + end) +end + +return M diff --git a/lua/tabnine/utils.lua b/lua/tabnine/utils.lua index af9292c..3acd5a2 100644 --- a/lua/tabnine/utils.lua +++ b/lua/tabnine/utils.lua @@ -30,7 +30,7 @@ end function M.remove_matching_prefix(str, prefix) if not M.starts_with(str, prefix) then return str end - return str:sub(#prefix) + return str:sub(#prefix + 1) end function M.subset(tbl, from, to) @@ -136,29 +136,13 @@ function M.select_range(range) api.nvim_win_set_cursor(0, { end_row, end_col - 1 }) end -function M.select_range(range) - local start_row, start_col, end_row, end_col = range[1][1], range[1][2], range[2][1], range[2][2] +function M.read_file_into_buffer(file_path) + local content = vim.fn.readfile(file_path) + local bufnr = vim.api.nvim_create_buf(false, true) - local v_table = { charwise = "v", linewise = "V", blockwise = "" } - selection_mode = selection_mode or "charwise" + api.nvim_buf_set_lines(bufnr, 0, -1, false, content) - -- Normalise selection_mode - if vim.tbl_contains(vim.tbl_keys(v_table), selection_mode) then selection_mode = v_table[selection_mode] end - - -- enter visual mode if normal or operator-pending (no) mode - -- Why? According to https://learnvimscriptthehardway.stevelosh.com/chapters/15.html - -- If your operator-pending mapping ends with some text visually selected, Vim will operate on that text. - -- Otherwise, Vim will operate on the text between the original cursor position and the new position. - local mode = api.nvim_get_mode() - if mode.mode ~= selection_mode then - -- Call to `nvim_replace_termcodes()` is needed for sending appropriate command to enter blockwise mode - selection_mode = vim.api.nvim_replace_termcodes(selection_mode, true, true, true) - api.nvim_cmd({ cmd = "normal", bang = true, args = { selection_mode } }, {}) - end - - api.nvim_win_set_cursor(0, { start_row, start_col - 1 }) - vim.cmd("normal! o") - api.nvim_win_set_cursor(0, { end_row, end_col - 1 }) + return bufnr end return M