Skip to content

Commit

Permalink
Refactor Claude response handling to use structured JSON decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
swuecho committed Sep 14, 2024
1 parent b258d8b commit 1e36946
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
18 changes: 9 additions & 9 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (

"github.com/gorilla/mux"

"github.com/swuecho/chat_backend/llm/claude"
claude "github.com/swuecho/chat_backend/llm/claude"
"github.com/swuecho/chat_backend/llm/gemini"
"github.com/swuecho/chat_backend/models"
"github.com/swuecho/chat_backend/sqlc_queries"
Expand Down Expand Up @@ -949,14 +949,14 @@ func (h *ChatHandler) chatStreamClaude3(w http.ResponseWriter, chatSession sqlc_
}

if !stream {
body, err := io.ReadAll(resp.Body)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_read_response", err)
// Unmarshal directly from resp.Body
var message claude.Response
if err := json.NewDecoder(resp.Body).Decode(&message); err != nil {
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_unmarshal_response", err)
return "", "", true
}
log.Printf("%+v", string(body))
uuid := NewUUID()
answer := constructChatCompletionStreamReponse(uuid, string(body))
uuid := message.ID
answer := constructChatCompletionStreamReponse(uuid, message.Content[0].Text)
data, _ := json.Marshal(answer)
fmt.Fprint(w, string(data))
return string(data), uuid, false
Expand Down Expand Up @@ -1426,14 +1426,14 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
// Handle non-streaming response
body, err := io.ReadAll(resp.Body)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "Error reading response body").Error(), err)
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_read_response", err)
return "", "", true
}
// body to GeminiResponse
var geminiResp gemini.ResponseBody
err = json.Unmarshal(body, &geminiResp)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "Error unmarshalling response body").Error(), err)
RespondWithError(w, http.StatusInternalServerError, "error.fail_to_unmarshal_response", err)
return "", "", true
}
answer := geminiResp.Candidates[0].Content.Parts[0].Text
Expand Down
25 changes: 25 additions & 0 deletions api/llm/claude/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,28 @@ func FormatClaudePrompt(chat_compeletion_messages []models.Message) string {
prompt := sb.String()
return prompt
}


// response (not stream)

type Response struct {
ID string `json:"id"`
Type string `json:"type"`
Role string `json:"role"`
Model string `json:"model"`
Content []Content `json:"content"`
StopReason string `json:"stop_reason"`
StopSequence interface{} `json:"stop_sequence"`
Usage Usage `json:"usage"`
}

type Content struct {
Type string `json:"type"`
Text string `json:"text"`
}

type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

3 changes: 3 additions & 0 deletions api/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ func (m *Message) SetTokenCount(tokenCount int32) *Message {
m.tokenCount = tokenCount
return m
}



0 comments on commit 1e36946

Please sign in to comment.