diff --git a/examples/go-chat/main.go b/examples/go-chat/main.go index 7663fb8f4ee..bdbd2ae645b 100644 --- a/examples/go-chat/main.go +++ b/examples/go-chat/main.go @@ -15,19 +15,19 @@ func main() { } messages := []api.Message{ - api.Message{ + { Role: "system", Content: "Provide very brief, concise responses", }, - api.Message{ + { Role: "user", Content: "Name some unusual animals", }, - api.Message{ + { Role: "assistant", Content: "Monotreme, platypus, echidna", }, - api.Message{ + { Role: "user", Content: "which of these is the most dangerous?", }, diff --git a/openai/openai.go b/openai/openai.go index bda42b4da3d..1c1910791c1 100644 --- a/openai/openai.go +++ b/openai/openai.go @@ -32,8 +32,8 @@ type ErrorResponse struct { } type Message struct { - Role string `json:"role"` - Content any `json:"content"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` ToolCalls []ToolCall `json:"tool_calls,omitempty"` } @@ -45,7 +45,7 @@ type Choice struct { type ChunkChoice struct { Index int `json:"index"` - Delta Message `json:"delta"` + Delta Message `json:"delta,omitempty"` FinishReason *string `json:"finish_reason"` } @@ -139,6 +139,8 @@ type CompletionChunk struct { } type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` ID string `json:"id"` Type string `json:"type"` Function struct { @@ -244,6 +246,28 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { } func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { + toolCalls := make([]ToolCall, len(r.Message.ToolCalls)) + for i, tc := range r.Message.ToolCalls { + idx := i + toolCalls[i].Index = &idx + toolCalls[i].ID = toolCallId() + toolCalls[i].Type = "function" + toolCalls[i].Function.Name = tc.Function.Name + + args, err := json.Marshal(tc.Function.Arguments) + if err != nil { + slog.Error("could not marshall function arguments to json", "error", err) + continue + } + + toolCalls[i].Function.Arguments = string(args) + } + + message := Message{Role: "assistant", Content: r.Message.Content} + hasToolCalls := len(toolCalls) > 0 + if hasToolCalls { + message = Message{ToolCalls: toolCalls} + } return ChatCompletionChunk{ Id: id, Object: "chat.completion.chunk", @@ -252,8 +276,12 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { SystemFingerprint: "fp_ollama", Choices: []ChunkChoice{{ Index: 0, - Delta: Message{Role: "assistant", Content: r.Message.Content}, + // Delta: Message{Role: "assistant", Content: r.Message.Content}, + Delta: message, FinishReason: func(reason string) *string { + // if hasToolCalls { + // reason = "tool_calls" + // } if len(reason) > 0 { return &reason } @@ -610,6 +638,7 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) { if chatResponse.Done { _, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) if err != nil { + slog.Error("writeResponse done", "err", err) return 0, err } } diff --git a/server/routes.go b/server/routes.go index 6c470c174c1..5230378b187 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1362,9 +1362,15 @@ func (s *Server) ChatHandler(c *gin.Context) { slog.Debug("chat request", "images", len(images), "prompt", prompt) + toolCallsCh := make(chan []api.ToolCall, 1) + contentCh := make(chan string, 1) + ch := make(chan any) go func() { + var sb strings.Builder defer close(ch) + defer close(toolCallsCh) + defer close(contentCh) if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -1384,8 +1390,17 @@ func (s *Server) ChatHandler(c *gin.Context) { EvalDuration: r.EvalDuration, }, } + sb.WriteString(r.Content) if r.Done { + content := sb.String() + contentCh <- content + if len(req.Tools) > 0 { + if toolCalls, ok := m.parseToolCalls(content); ok { + toolCallsCh <- toolCalls + } + } + res.TotalDuration = time.Since(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart) } @@ -1396,13 +1411,13 @@ func (s *Server) ChatHandler(c *gin.Context) { } }() + toolsRequired := len(req.Tools) > 0 + // no stream response if req.Stream != nil && !*req.Stream { var resp api.ChatResponse - var sb strings.Builder for rr := range ch { switch t := rr.(type) { case api.ChatResponse: - sb.WriteString(t.Message.Content) resp = t case gin.H: msg, ok := t["error"].(string) @@ -1418,10 +1433,11 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - resp.Message.Content = sb.String() - - if len(req.Tools) > 0 { - if toolCalls, ok := m.parseToolCalls(sb.String()); ok { + content := <-contentCh + resp.Message.Content = content + if toolsRequired { + toolCalls := <-toolCallsCh + if len(toolCalls) > 0 { resp.Message.ToolCalls = toolCalls resp.Message.Content = "" } @@ -1431,7 +1447,59 @@ func (s *Server) ChatHandler(c *gin.Context) { return } - streamResponse(c, ch) + // stream response + streamCh := make(chan any) + for rr := range ch { + switch t := rr.(type) { + case api.ChatResponse: + go func() { + // slog.Warn("reassign chat response", "content", t.Message.Content) + streamCh <- t + if t.Done { + // slog.Warn("close stream channel") + close(streamCh) + } + }() + case gin.H: + msg, ok := t["error"].(string) + if !ok { + msg = "unexpected error format in response" + } + c.JSON(http.StatusInternalServerError, gin.H{"error": msg}) + return + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"}) + return + } + } + + // if tools are required + if toolsRequired { + toolCalls := <-toolCallsCh + // if tool calls are present, use different channel respose + hasToolCalls := len(toolCalls) > 0 + if hasToolCalls { + // reset the channel + toolCallsCh := make(chan any, 1) + res := api.ChatResponse{ + Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant", ToolCalls: toolCalls}, + Done: true, + DoneReason: "tool_calls", + } + toolCallsCh <- res + close(toolCallsCh) + slog.Info("[tools] stream response") + streamResponse(c, toolCallsCh) + return + } else { + slog.Info("[tools] no call") + } + } + + slog.Info("stream response") + streamResponse(c, streamCh) } func handleScheduleError(c *gin.Context, name string, err error) {