Skip to content

Commit

Permalink
feat: ask selected code block
Browse files Browse the repository at this point in the history
  • Loading branch information
yetone committed Aug 17, 2024
1 parent dea737b commit bb4056b
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 75 deletions.
82 changes: 61 additions & 21 deletions lua/avante/ai_bot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,12 @@ Replace lines: {{start_line}}-{{end_line}}
Remember: Accurate line numbers are CRITICAL. The range start_line to end_line must include ALL lines to be replaced, from the very first to the very last. Double-check every range before finalizing your response, paying special attention to the start_line to ensure it hasn't shifted down. Ensure that your line numbers perfectly match the original code structure without any overall shift.
]]

local function call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_claude_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key then
error("ANTHROPIC_API_KEY environment variable is not set")
end

local user_prompt = base_user_prompt

local tokens = Config.claude.max_tokens
local headers = {
["Content-Type"] = "application/json",
Expand All @@ -79,33 +77,56 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
text = string.format("<code>```%s\n%s```</code>", code_lang, code_content),
}

if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end

if selected_code_content then
code_prompt_obj.text = string.format("<code_context>```%s\n%s```</code_context>", code_lang, code_content)
end

local message_content = {
code_prompt_obj,
}

if selected_code_content then
local selected_code_obj = {
type = "text",
text = string.format("<code>```%s\n%s```</code>", code_lang, selected_code_content),
}

if Tiktoken.count(selected_code_obj.text) > 1024 then
selected_code_obj.cache_control = { type = "ephemeral" }
end

table.insert(message_content, selected_code_obj)
end

table.insert(message_content, {
type = "text",
text = string.format("<question>%s</question>", question),
})

local user_prompt = base_user_prompt

local user_prompt_obj = {
type = "text",
text = user_prompt,
}

if Tiktoken.count(code_prompt_obj.text) > 1024 then
code_prompt_obj.cache_control = { type = "ephemeral" }
end

if Tiktoken.count(user_prompt_obj.text) > 1024 then
user_prompt_obj.cache_control = { type = "ephemeral" }
end

table.insert(message_content, user_prompt_obj)

local body = {
model = Config.claude.model,
system = system_prompt,
messages = {
{
role = "user",
content = {
code_prompt_obj,
{
type = "text",
text = string.format("<question>%s</question>", question),
},
user_prompt_obj,
},
content = message_content,
},
},
stream = true,
Expand Down Expand Up @@ -154,21 +175,39 @@ local function call_claude_api_stream(question, code_lang, code_content, on_chun
})
end

local function call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
local function call_openai_api_stream(question, code_lang, code_content, selected_code_content, on_chunk, on_complete)
local api_key = os.getenv("OPENAI_API_KEY")
if not api_key and Config.provider == "openai" then
error("OPENAI_API_KEY environment variable is not set")
end

local user_prompt = base_user_prompt
.. "\n\nQUESTION:\n"
.. question
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question

if selected_code_content then
user_prompt = base_user_prompt
.. "\n\nCODE CONTEXT:\n"
.. "```"
.. code_lang
.. "\n"
.. code_content
.. "\n```"
.. "\n\nCODE:\n"
.. "```"
.. code_lang
.. "\n"
.. selected_code_content
.. "\n```"
.. "\n\nQUESTION:\n"
.. question
end

local url, headers, body
if Config.provider == "azure" then
Expand Down Expand Up @@ -258,13 +297,14 @@ end
---@param question string
---@param code_lang string
---@param code_content string
---@param selected_content_content string | nil
---@param on_chunk fun(chunk: string): any
---@param on_complete fun(err: string|nil): any
function M.call_ai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
function M.call_ai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
if Config.provider == "openai" or Config.provider == "azure" then
call_openai_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_openai_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
elseif Config.provider == "claude" then
call_claude_api_stream(question, code_lang, code_content, on_chunk, on_complete)
call_claude_api_stream(question, code_lang, code_content, selected_content_content, on_chunk, on_complete)
end
end

Expand Down
7 changes: 6 additions & 1 deletion lua/avante/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ M.defaults = {
},
},
mappings = {
show_sidebar = "<leader>aa",
ask = "<leader>aa",
edit = "<leader>ae",
diff = {
ours = "co",
theirs = "ct",
Expand All @@ -53,6 +54,10 @@ M.options = {}
---@param opts? avante.Config
function M.setup(opts)
M.options = vim.tbl_deep_extend("force", M.defaults, opts or {})
if M.options.mappings.show_sidebar ~= nil then
vim.api.nvim_err_writeln("avante: mappings.show_sidebar is deprecated, use mappings.ask instead")
M.options.mappings.ask = M.options.mappings.show_sidebar
end
end

M = setmetatable(M, {
Expand Down
12 changes: 10 additions & 2 deletions lua/avante/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ local Tiktoken = require("avante.tiktoken")
local Sidebar = require("avante.sidebar")
local Config = require("avante.config")
local Diff = require("avante.diff")
local Selection = require("avante.selection")

---@class Avante
local M = {
---@type avante.Sidebar[] we use this to track chat command across tabs
sidebars = {},
---@type avante.Sidebar
current = nil,
selection = nil,
_once = false,
}

Expand All @@ -35,7 +37,7 @@ H.commands = function()
end

H.keymaps = function()
vim.keymap.set({ "n" }, Config.mappings.show_sidebar, M.toggle, { noremap = true })
vim.keymap.set({ "n", "v" }, Config.mappings.ask, M.toggle, { noremap = true })
end

H.autocmds = function()
Expand Down Expand Up @@ -76,7 +78,9 @@ H.autocmds = function()
if s then
s:destroy()
end
M.sidebars[tab] = nil
if tab ~= nil then
M.sidebars[tab] = nil
end
end,
})

Expand Down Expand Up @@ -137,6 +141,10 @@ function M.setup(opts)
highlights = Config.highlights.diff,
})

local selection = Selection:new()
selection:setup()
M.selection = selection

-- setup helpers
H.autocmds()
H.commands()
Expand Down
24 changes: 24 additions & 0 deletions lua/avante/range.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
--@class avante.Range
--@field start table Selection start point
--@field start.line number Line number of the selection start
--@field start.col number Column number of the selection start
--@field finish table Selection end point
--@field finish.line number Line number of the selection end
--@field finish.col number Column number of the selection end
local Range = {}
Range.__index = Range
-- Create a selection range
-- @param start table Selection start point
-- @param start.line number Line number of the selection start
-- @param start.col number Column number of the selection start
-- @param finish table Selection end point
-- @param finish.line number Line number of the selection end
-- @param finish.col number Column number of the selection end
function Range.new(start, finish)
local self = setmetatable({}, Range)
self.start = start
self.finish = finish
return self
end

return Range
96 changes: 96 additions & 0 deletions lua/avante/selection.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
local Config = require("avante.config")
local N = require("nui-components")

local api = vim.api
local fn = vim.fn

local NAMESPACE = api.nvim_create_namespace("avante_selection")
local PRIORITY = vim.highlight.priorities.user

local Selection = {}

function Selection:new()
return setmetatable({
hints_popup_extmark_id = nil,
edit_popup_renderer = nil,
augroup = api.nvim_create_augroup("avante_selection", { clear = true }),
}, { __index = self })
end

function Selection:get_virt_text_line()
local current_pos = fn.getpos(".")
local start_pos = fn.getpos("v")

-- Get the current and start position line numbers
local current_line = current_pos[2] - 1 -- 0-indexed
local start_line = start_pos[2] - 1

-- Ensure line numbers are not negative and don't exceed buffer range
local total_lines = api.nvim_buf_line_count(0)
if current_line < 0 then
current_line = 0
end
if start_line < 0 then
start_line = 0
end
if current_line >= total_lines then
current_line = total_lines - 1
end
if start_line >= total_lines then
start_line = total_lines - 1
end

-- Take the first line of the selection to ensure virt_text is always in the top right corner
return math.min(start_line, current_line)
end

function Selection:show_hints_popup()
self:close_hints_popup()

local hint_text = string.format(" [Ask %s] ", Config.mappings.ask)

local virt_text_line = self:get_virt_text_line()

self.hints_popup_extmark_id = vim.api.nvim_buf_set_extmark(0, NAMESPACE, virt_text_line, -1, {
virt_text = { { hint_text, "Keyword" } },
virt_text_pos = "eol",
priority = PRIORITY,
})
end

function Selection:close_hints_popup()
if self.hints_popup_extmark_id then
vim.api.nvim_buf_del_extmark(0, NAMESPACE, self.hints_popup_extmark_id)
self.hints_popup_extmark_id = nil
end
end

function Selection:setup()
vim.api.nvim_create_autocmd({ "ModeChanged" }, {
pattern = { "n:v", "n:V", "n:" }, -- Entering Visual mode from Normal mode
callback = function()
self:show_hints_popup()
end,
})

api.nvim_create_autocmd({ "CursorMoved", "CursorMovedI" }, {
group = self.augroup,
callback = function()
if vim.fn.mode() == "v" or vim.fn.mode() == "V" or vim.fn.mode() == "" then
self:show_hints_popup()
else
self:close_hints_popup()
end
end,
})

api.nvim_create_autocmd({ "ModeChanged" }, {
group = self.augroup,
pattern = { "v:n", "v:i", "v:c" }, -- Switching from visual mode back to normal, insert, or other modes
callback = function()
self:close_hints_popup()
end,
})
end

return Selection
17 changes: 17 additions & 0 deletions lua/avante/selection_result.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
--@class avante.SelectionResult
--@field content string Selected content
--@field range avante.Range Selection range
local SelectionResult = {}
SelectionResult.__index = SelectionResult

-- Create a selection content and range
--@param content string Selected content
--@param range avante.Range Selection range
function SelectionResult.new(content, range)
local self = setmetatable({}, SelectionResult)
self.content = content
self.range = range
return self
end

return SelectionResult
Loading

0 comments on commit bb4056b

Please sign in to comment.