Skip to content

Commit

Permalink
DEV2-4801 Mentions (#147)
Browse files Browse the repository at this point in the history
mentions
  • Loading branch information
amirbilu authored Feb 8, 2024
1 parent 10532b6 commit e9d084d
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 51 deletions.
29 changes: 3 additions & 26 deletions lua/tabnine/chat/codelens.lua
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
local chat = require("tabnine.chat")
local config = require("tabnine.config")
local consts = require("tabnine.consts")
local lsp = require("tabnine.lsp")
local state = require("tabnine.state")
local utils = require("tabnine.utils")
local api = vim.api
local fn = vim.fn

local M = {}
local SYMBOL_KIND = { FUNCTION = 12, CLASS = 5, METHOD = 6 }
local current_symbols = {}
local symbol_under_cursor = nil
local cancel_lsp_request = nil
Expand All @@ -34,34 +34,11 @@ function M.should_display()
and buf_supports_symbols
end

local function flatten_symbols(symbols, result)
result = result or {}

for _, symbol in ipairs(symbols) do
table.insert(result, symbol)

if symbol.children then flatten_symbols(symbol.children, result) end
end

return result
end

function M.collect_symbols(on_collect)
local params = vim.lsp.util.make_position_params()

if cancel_lsp_request then cancel_lsp_request() end

cancel_lsp_request = vim.lsp.buf_request_all(0, "textDocument/documentSymbol", params, function(responses)
current_symbols = {}
for _, response in ipairs(responses) do
if response.result then
for _, result in ipairs(flatten_symbols(response.result)) do
if result.kind == SYMBOL_KIND.FUNCTION or result.kind == SYMBOL_KIND.METHOD then
table.insert(current_symbols, result)
end
end
end
end
cancel_lsp_request = lsp.get_document_symbols("", function(symbols)
current_symbols = symbols
on_collect()
end)
end
Expand Down
72 changes: 71 additions & 1 deletion lua/tabnine/chat/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ local tabnine_binary = require("tabnine.binary")
local utils = require("tabnine.utils")
local api = vim.api
local config = require("tabnine.config")
local lsp = require("tabnine.lsp")
local get_symbols_request = nil

local M = { enabled = false }

Expand All @@ -14,6 +16,19 @@ local chat_state = nil
local chat_settings = nil
local initialized = false

local function to_chat_symbol_kind(kind)
if kind == lsp.SYMBOL_KIND.METHOD then
return "Method"
elseif kind == lsp.SYMBOL_KIND.FUNCTION then
return "Function"
elseif kind == lsp.SYMBOL_KIND.CLASS then
return "Class"
elseif kind == lsp.SYMBOL_KIND.FILE then
return "File"
else
return "Other"
end
end
local function get_diagnostics()
return vim.tbl_map(function(diagnostic)
return {
Expand Down Expand Up @@ -174,14 +189,69 @@ local function register_events(on_init)

chat_binary:register_event("get_selected_code", function(_, answer)
local selected_code = utils.selected_text()

answer({
code = selected_code,
startLine = vim.fn.getpos("'<")[2],
endLine = vim.fn.getpos("'>")[2],
})
end)

chat_binary:register_event("get_symbols", function(request, answer)
if get_symbols_request then get_symbols_request() end
get_symbols_request = lsp.get_document_symbols(request.query, function(document_symbols)
lsp.get_workspace_symbols(request.query, function(workspace_symbols)
answer({
workspaceSymbols = vim.tbl_map(function(symbol)
return {
name = symbol.name,
absolutePath = symbol.location.uri,
relativePath = utils.remove_matching_prefix(symbol.location.uri, fn.getcwd()),
kind = to_chat_symbol_kind(symbol.kind),
range = {
startLine = symbol.location.range.start.line,
startCharacter = symbol.location.range.start.character,
endLine = symbol.location.range["end"].line,
endCharacter = symbol.location.range["end"].character,
},
}
end, workspace_symbols),
documentSymbols = vim.tbl_map(function(symbol)
return {
name = symbol.name,
absolutePath = api.nvim_buf_get_name(0),
relativePath = vim.fn.expand("%"),
kind = to_chat_symbol_kind(symbol.kind),
range = {
startLine = symbol.range.start.line,
startCharacter = symbol.range.start.character,
endLine = symbol.range["end"].line,
endCharacter = symbol.range["end"].character,
},
}
end, document_symbols),
})
end)
end)
end)

chat_binary:register_event("get_symbols_text", function(request, answer)
answer(vim.tbl_map(function(symbol)
local buf = utils.read_file_into_buffer(symbol.absolutePath)
local text = utils.lines_to_str(
api.nvim_buf_get_text(
buf,
symbol.range.startLine,
symbol.range.startCharacter,
symbol.range.endLine,
symbol.range.endCharacter,
{}
)
)
api.nvim_buf_delete(buf, { force = true })
return { id = symbol.id, snippet = text }
end, request.symbols))
end)

chat_binary:register_event("send_event", function(event)
tabnine_binary:request({
Event = { name = event.eventName, properties = event.properties },
Expand Down
1 change: 0 additions & 1 deletion lua/tabnine/chat/setup.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
local auto_commands = require("tabnine.chat.auto_commands")
local chat = require("tabnine.chat")
local codelens = require("tabnine.chat.codelens")
local config = require("tabnine.config")
local features = require("tabnine.features")
local user_commands = require("tabnine.chat.user_commands")
Expand Down
1 change: 0 additions & 1 deletion lua/tabnine/chat/user_commands.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ local M = {}
local api = vim.api
local chat = require("tabnine.chat")
local codelens = require("tabnine.chat.codelens")
local config = require("tabnine.config")

function M.setup()
api.nvim_create_user_command("TabnineChat", function()
Expand Down
78 changes: 78 additions & 0 deletions lua/tabnine/lsp.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
local utils = require("tabnine.utils")
local M = {}

M.SYMBOL_KIND = { FUNCTION = 12, CLASS = 5, METHOD = 6 }

local function is_not_source(symbol_path)
local dirs = { "node_modules", "dist", "build", "target", "out" }
for i, dir in ipairs(dirs) do
if string.sub(symbol_path, 1, string.len(dir)) == dir then return true end
end
return false
end

local function flatten_symbols(symbols, result)
result = result or {}

for _, symbol in ipairs(symbols) do
table.insert(result, symbol)

if symbol.children then flatten_symbols(symbol.children, result) end
end

return result
end

function M.get_workspace_symbols(query, callback)
local params = { query = query }

return vim.lsp.buf_request_all(0, "workspace/symbol", params, function(responses)
local results = {}
for _, response in ipairs(responses) do
if response.result then
for _, result in ipairs(flatten_symbols(response.result)) do
result.location.uri = utils.remove_matching_prefix(result.location.uri, "file://")
if
(
result.kind == M.SYMBOL_KIND.CLASS
or result.kind == M.SYMBOL_KIND.METHOD
or result.kind == M.SYMBOL_KIND.FUNCTION
)
and utils.starts_with(result.location.uri, vim.fn.getcwd())
and not is_not_source(result.location.uri)
then
table.insert(results, result)
end
end
end
end
callback(results)
end)
end

function M.get_document_symbols(query, callback)
local params = {
textDocument = vim.lsp.util.make_text_document_params(),
}
return vim.lsp.buf_request_all(0, "textDocument/documentSymbol", params, function(responses)
local results = {}
for _, response in ipairs(responses) do
if response.result then
for _, result in ipairs(flatten_symbols(response.result)) do
if
(
result.kind == M.SYMBOL_KIND.CLASS
or result.kind == M.SYMBOL_KIND.METHOD
or result.kind == M.SYMBOL_KIND.FUNCTION
) and utils.starts_with(result.name, query)
then
table.insert(results, result)
end
end
end
end
callback(results)
end)
end

return M
28 changes: 6 additions & 22 deletions lua/tabnine/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ end

function M.remove_matching_prefix(str, prefix)
if not M.starts_with(str, prefix) then return str end
return str:sub(#prefix)
return str:sub(#prefix + 1)
end

function M.subset(tbl, from, to)
Expand Down Expand Up @@ -136,29 +136,13 @@ function M.select_range(range)
api.nvim_win_set_cursor(0, { end_row, end_col - 1 })
end

function M.select_range(range)
local start_row, start_col, end_row, end_col = range[1][1], range[1][2], range[2][1], range[2][2]
function M.read_file_into_buffer(file_path)
local content = vim.fn.readfile(file_path)
local bufnr = vim.api.nvim_create_buf(false, true)

local v_table = { charwise = "v", linewise = "V", blockwise = "<C-v>" }
selection_mode = selection_mode or "charwise"
api.nvim_buf_set_lines(bufnr, 0, -1, false, content)

-- Normalise selection_mode
if vim.tbl_contains(vim.tbl_keys(v_table), selection_mode) then selection_mode = v_table[selection_mode] end

-- enter visual mode if normal or operator-pending (no) mode
-- Why? According to https://learnvimscriptthehardway.stevelosh.com/chapters/15.html
-- If your operator-pending mapping ends with some text visually selected, Vim will operate on that text.
-- Otherwise, Vim will operate on the text between the original cursor position and the new position.
local mode = api.nvim_get_mode()
if mode.mode ~= selection_mode then
-- Call to `nvim_replace_termcodes()` is needed for sending appropriate command to enter blockwise mode
selection_mode = vim.api.nvim_replace_termcodes(selection_mode, true, true, true)
api.nvim_cmd({ cmd = "normal", bang = true, args = { selection_mode } }, {})
end

api.nvim_win_set_cursor(0, { start_row, start_col - 1 })
vim.cmd("normal! o")
api.nvim_win_set_cursor(0, { end_row, end_col - 1 })
return bufnr
end

return M

0 comments on commit e9d084d

Please sign in to comment.