Skip to content

Commit

Permalink
use callback instead of stream for o1
Browse files — browse the repository at this point in the history
  • Loading branch information
gptlang committed Oct 29, 2024
1 parent ce2a919 commit a6f2e75
Showing 1 changed file with 114 additions and 62 deletions.
176 changes: 114 additions & 62 deletions lua/CopilotChat/copilot.lua
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,118 @@ function Copilot:ask(prompt, opts)

local errored = false
local full_response = ''
---Streaming (SSE) handler: invoked once per line of the chunked HTTP
---response. Accumulates deltas into `full_response`, forwarding progress,
---completion and errors through the caller-supplied callbacks. Set to nil
---for o1 models, which do not support streaming.
---@type fun(err: string, line: string)?
local stream_func = function(err, line)
  -- Ignore empty invocations and everything after the first error.
  if not line or errored then
    return
  end

  if err or vim.startswith(line, '{"error"') then
    err = 'Failed to get response: ' .. (err and vim.inspect(err) or line)
    errored = true
    log.error(err)
    if self.current_job and on_error then
      on_error(err)
    end
    return
  end

  -- Strip the SSE "data: " framing prefix. The pattern is anchored to the
  -- start of the line so a literal "data: " occurring inside the JSON
  -- payload (e.g. in a content delta) is never clobbered.
  line = line:gsub('^data: ', '')
  if line == '' then
    return
  elseif line == '[DONE]' then
    -- End-of-stream sentinel: account tokens, notify, and record history.
    log.trace('Full response: ' .. full_response)
    self.token_count = self.token_count + tiktoken.count(full_response)

    if self.current_job and on_done then
      on_done(full_response, self.token_count + current_count)
    end

    table.insert(self.history, {
      content = full_response,
      role = 'assistant',
    })
    return
  end

  -- luanil maps JSON nulls to Lua nil instead of vim.NIL placeholders.
  local ok, content = pcall(vim.json.decode, line, {
    luanil = {
      object = true,
      array = true,
    },
  })

  if not ok then
    -- Best-effort: log the bad chunk but keep consuming the stream.
    err = 'Failed parse response: \n' .. line .. '\n' .. vim.inspect(content)
    log.error(err)
    return
  end

  if not content.choices or #content.choices == 0 then
    return
  end

  content = content.choices[1].delta.content
  if not content then
    return
  end

  if self.current_job and on_progress then
    on_progress(content)
  end

  -- Collect full response incrementally so we can insert it to history later
  full_response = full_response .. content
end
-- o1 models reject server-sent events; they take the non-streaming
-- callback path instead.
if is_o1(model) then
  stream_func = nil
end

---Non-streaming handler for o1 models: receives the complete HTTP response
---in one piece, emits the whole message via on_progress, accounts tokens,
---fires on_done, and records the reply in the chat history. Set to nil for
---streaming-capable models.
---@type fun(response: table)?
local nonstream_callback = function(response)
  if response.status ~= 200 then
    local err = 'Failed to get response: ' .. tostring(response.status)
    log.error(err)
    if on_error then
      on_error(err)
    end
    return
  end

  -- luanil maps JSON nulls to Lua nil instead of vim.NIL placeholders.
  local ok, content = pcall(vim.json.decode, response.body, {
    luanil = {
      object = true,
      array = true,
    },
  })

  if not ok then
    local err = 'Failed parse response: ' .. vim.inspect(content)
    log.error(err)
    if on_error then
      on_error(err)
    end
    return
  end

  -- Guard against a well-formed 200 body with no usable choice (the stream
  -- path performs the same check) so we never index a nil field below.
  if not content.choices or #content.choices == 0 or not content.choices[1].message then
    local err = 'Failed to get response: no choices in response'
    log.error(err)
    if on_error then
      on_error(err)
    end
    return
  end

  -- luanil may have turned a JSON null content into nil; normalize to ''.
  full_response = content.choices[1].message.content or ''
  if on_progress then
    on_progress(full_response)
  end
  self.token_count = self.token_count + tiktoken.count(full_response)
  if on_done then
    on_done(full_response, self.token_count + current_count)
  end

  table.insert(self.history, {
    content = full_response,
    role = 'assistant',
  })
end

-- Streaming models use stream_func instead; only o1 takes this path.
if not is_o1(model) then
  nonstream_callback = nil
end

self:with_auth(function()
local headers = generate_headers(self.token.token, self.sessionid, self.machineid)
Expand All @@ -451,75 +563,15 @@ function Copilot:ask(prompt, opts)
body = temp_file(body),
proxy = self.proxy,
insecure = self.allow_insecure,
callback = nonstream_callback,
-- Transport-level failure handler (curl/connection errors, not HTTP-level
-- errors): log it and notify the caller, but only while this request is
-- still the active job.
on_error = function(err)
err = 'Failed to get response: ' .. vim.inspect(err)
log.error(err)
if self.current_job and on_error then
on_error(err)
end
end,
-- Inline SSE stream handler: called once per line of the chunked response.
-- Accumulates content deltas into full_response and drives the
-- on_progress/on_done/on_error callbacks.
stream = function(err, line)
-- Ignore empty invocations and everything after the first error.
if not line or errored then
return
end

if err or vim.startswith(line, '{"error"') then
err = 'Failed to get response: ' .. (err and vim.inspect(err) or line)
errored = true
log.error(err)
if self.current_job and on_error then
on_error(err)
end
return
end

-- Strip the SSE "data: " framing prefix.
-- NOTE(review): the pattern is unanchored, so a literal "data: " inside
-- the JSON payload would also be stripped — consider anchoring with '^'.
line = line:gsub('data: ', '')
if line == '' then
return
elseif line == '[DONE]' then
-- End-of-stream sentinel: account tokens, notify, and record history.
log.trace('Full response: ' .. full_response)
self.token_count = self.token_count + tiktoken.count(full_response)

if self.current_job and on_done then
on_done(full_response, self.token_count + current_count)
end

table.insert(self.history, {
content = full_response,
role = 'assistant',
})
return
end

-- luanil maps JSON nulls to Lua nil instead of vim.NIL placeholders.
local ok, content = pcall(vim.json.decode, line, {
luanil = {
object = true,
array = true,
},
})

if not ok then
-- Best-effort: log the bad chunk but keep consuming the stream.
err = 'Failed parse response: \n' .. line .. '\n' .. vim.inspect(content)
log.error(err)
return
end

if not content.choices or #content.choices == 0 then
return
end

content = content.choices[1].delta.content
if not content then
return
end

if self.current_job and on_progress then
on_progress(content)
end

-- Collect full response incrementally so we can insert it to history later
full_response = full_response .. content
end,
stream = stream_func,
})
:after(function()
self.current_job = nil
Expand Down

0 comments on commit a6f2e75

Please sign in to comment.