Skip to content

Commit

Permalink
🔧 #195 #196: Set router ctx in stream chunks & handle end of stream in case of some errors (#203)

Browse files Browse the repository at this point in the history

- Passed RouterID and ModelID information in the chat stream messages
- Introduced a new ChatStreamMessage type that joins both chunk and error messages. Removed unneeded context from provider chatStream structs
- Defined a set of possible error codes during chat streaming
- Started simplifying logging by using context-based loggers
- Introduced finish_reason on the error schema
  • Loading branch information
roma-glushko authored Apr 16, 2024
1 parent de3677e commit 4a9735c
Show file tree
Hide file tree
Showing 15 changed files with 198 additions and 191 deletions.
18 changes: 5 additions & 13 deletions pkg/api/http/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,28 +122,20 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
wg sync.WaitGroup
)

chunkResultC := make(chan *schemas.ChatStreamResult)
chatStreamC := make(chan *schemas.ChatStreamMessage)

router, _ := routerManager.GetLangRouter(routerID)

defer close(chunkResultC)
defer close(chatStreamC)
defer c.Conn.Close()

wg.Add(1)

go func() {
defer wg.Done()

for chunkResult := range chunkResultC {
if chunkResult.Error() != nil {
if err = c.WriteJSON(chunkResult.Error()); err != nil {
break
}

continue
}

if err = c.WriteJSON(chunkResult.Chunk()); err != nil {
for chatStreamMsg := range chatStreamC {
if err = c.WriteJSON(chatStreamMsg); err != nil {
break
}
}
Expand All @@ -168,7 +160,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
go func(chatRequest schemas.ChatStreamRequest) {
defer wg.Done()

router.ChatStream(context.Background(), &chatRequest, chunkResultC)
router.ChatStream(context.Background(), &chatRequest, chatStreamC)
}(chatRequest)
}

Expand Down
99 changes: 63 additions & 36 deletions pkg/api/schemas/chat_stream.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
package schemas

import "time"

type (
Metadata = map[string]any
EventType = string
FinishReason = string
ErrorCode = string
)

var (
Complete FinishReason = "complete"
MaxTokens FinishReason = "max_tokens"
ContentFiltered FinishReason = "content_filtered"
Other FinishReason = "other"
ErrorReason FinishReason = "error"
OtherReason FinishReason = "other"
)

// Error codes that may be reported to stream clients while chat streaming
// (carried in ChatStreamError.ErrCode).
var (
// NoModelConfigured: the router has no models configured to serve the request.
NoModelConfigured ErrorCode = "no_model_configured"
// ModelUnavailable: the selected model could not serve the request.
ModelUnavailable ErrorCode = "model_unavailable"
// AllModelsUnavailable: every model in the routing pool failed or is unhealthy.
AllModelsUnavailable ErrorCode = "all_models_unavailable"
// UnknownError: an unclassified failure occurred during streaming.
UnknownError ErrorCode = "unknown_error"
)

type StreamRequestID = string

// ChatStreamRequest defines a message that requests a new streaming chat
type ChatStreamRequest struct {
ID string `json:"id" validate:"required"`
ID StreamRequestID `json:"id" validate:"required"`
Message ChatMessage `json:"message" validate:"required"`
MessageHistory []ChatMessage `json:"messageHistory" validate:"required"`
Override *OverrideChatRequest `json:"overrideMessage,omitempty"`
Expand All @@ -32,54 +46,67 @@ func NewChatStreamFromStr(message string) *ChatStreamRequest {
}

type ModelChunkResponse struct {
Metadata *Metadata `json:"metadata,omitempty"`
Message ChatMessage `json:"message"`
FinishReason *FinishReason `json:"finishReason,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
Message ChatMessage `json:"message"`
}

// ChatStreamMessage is the envelope for every message sent over a chat stream.
// Only one of Chunk or Error is populated per message (see NewChatStreamChunk
// and NewChatStreamError, which each set exactly one of them).
type ChatStreamMessage struct {
// ID echoes the originating ChatStreamRequest ID so clients can correlate messages.
ID StreamRequestID `json:"id"`
// CreatedAt is a UTC Unix timestamp in seconds (set by the constructors).
CreatedAt int `json:"createdAt"`
// RouterID identifies the router that produced this message.
RouterID string `json:"routerId"`
// Metadata carries the request metadata back to the client, if any.
Metadata *Metadata `json:"metadata,omitempty"`
// Chunk holds a piece of the streamed model response (nil on error messages).
Chunk *ChatStreamChunk `json:"chunk,omitempty"`
// Error describes a streaming failure (nil on chunk messages).
Error *ChatStreamError `json:"error,omitempty"`
}

// ChatStreamChunk defines a message for a chunk of streaming chat response
type ChatStreamChunk struct {
ID string `json:"id"`
CreatedAt int `json:"createdAt"`
Provider string `json:"providerId"`
RouterID string `json:"routerId"`
ModelID string `json:"modelId"`
Cached bool `json:"cached"`
Provider string `json:"providerName"`
ModelName string `json:"modelName"`
Metadata *Metadata `json:"metadata,omitempty"`
Cached bool `json:"cached"`
ModelResponse ModelChunkResponse `json:"modelResponse"`
FinishReason *FinishReason `json:"finishReason,omitempty"`
}

type ChatStreamError struct {
ID string `json:"id"`
ErrCode string `json:"errCode"`
Message string `json:"message"`
Metadata *Metadata `json:"metadata,omitempty"`
}

type ChatStreamResult struct {
chunk *ChatStreamChunk
err *ChatStreamError
}

func (r *ChatStreamResult) Chunk() *ChatStreamChunk {
return r.chunk
}

func (r *ChatStreamResult) Error() *ChatStreamError {
return r.err
ErrCode ErrorCode `json:"errCode"`
Message string `json:"message"`
FinishReason *FinishReason `json:"finishReason,omitempty"`
}

func NewChatStreamResult(chunk *ChatStreamChunk) *ChatStreamResult {
return &ChatStreamResult{
chunk: chunk,
err: nil,
// NewChatStreamChunk wraps a provider response chunk into a ChatStreamMessage
// envelope, stamping it with the originating request ID, the router ID, the
// request metadata, and the current UTC Unix time (seconds).
func NewChatStreamChunk(
	reqID StreamRequestID,
	routerID string,
	reqMetadata *Metadata,
	chunk *ChatStreamChunk,
) *ChatStreamMessage {
	msg := ChatStreamMessage{
		ID:        reqID,
		RouterID:  routerID,
		CreatedAt: int(time.Now().UTC().Unix()),
		Metadata:  reqMetadata,
		Chunk:     chunk,
	}

	return &msg
}

func NewChatStreamErrorResult(err *ChatStreamError) *ChatStreamResult {
return &ChatStreamResult{
chunk: nil,
err: err,
// NewChatStreamError builds a ChatStreamMessage that carries a streaming
// error (code, human-readable message, and optional finish reason) instead
// of a chunk, stamped with the request ID, router ID, request metadata, and
// the current UTC Unix time (seconds).
func NewChatStreamError(
	reqID StreamRequestID,
	routerID string,
	errCode ErrorCode,
	errMsg string,
	reqMetadata *Metadata,
	finishReason *FinishReason,
) *ChatStreamMessage {
	streamErr := ChatStreamError{
		ErrCode:      errCode,
		Message:      errMsg,
		FinishReason: finishReason,
	}

	return &ChatStreamMessage{
		ID:        reqID,
		RouterID:  routerID,
		CreatedAt: int(time.Now().UTC().Unix()),
		Metadata:  reqMetadata,
		Error:     &streamErr,
	}
}
14 changes: 2 additions & 12 deletions pkg/providers/azureopenai/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ type ChatStream struct {
tel *telemetry.Telemetry
client *http.Client
req *http.Request
reqID string
reqMetadata *schemas.Metadata
resp *http.Response
reader *sse.EventStreamReader
finishReasonMapper *openai.FinishReasonMapper
Expand All @@ -37,17 +35,13 @@ func NewChatStream(
tel *telemetry.Telemetry,
client *http.Client,
req *http.Request,
reqID string,
reqMetadata *schemas.Metadata,
finishReasonMapper *openai.FinishReasonMapper,
errMapper *ErrorMapper,
) *ChatStream {
return &ChatStream{
tel: tel,
client: client,
req: req,
reqID: reqID,
reqMetadata: reqMetadata,
finishReasonMapper: finishReasonMapper,
errMapper: errMapper,
}
Expand Down Expand Up @@ -129,11 +123,9 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: completionChunk.ModelName,
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"response_id": completionChunk.ID,
Expand All @@ -143,8 +135,8 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
Role: responseChunk.Delta.Role,
Content: responseChunk.Delta.Content,
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
}, nil
}
}
Expand Down Expand Up @@ -172,8 +164,6 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
c.tel,
c.httpClient,
httpRequest,
req.ID,
req.Metadata,
c.finishReasonMapper,
c.errMapper,
), nil
Expand Down
11 changes: 1 addition & 10 deletions pkg/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
c.tel.Logger.Error(
"cohere chat request failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(bodyBytes)),
zap.ByteString("response", bodyBytes),
zap.Any("headers", resp.Header),
)

Expand All @@ -127,15 +127,6 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
return nil, err
}

// Parse the response JSON
var responseJSON map[string]interface{}

err = json.Unmarshal(bodyBytes, &responseJSON)
if err != nil {
c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err))
return nil, err
}

// Parse the response JSON
var cohereCompletion ChatCompletion

Expand Down
24 changes: 6 additions & 18 deletions pkg/providers/cohere/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ var (
type ChatStream struct {
client *http.Client
req *http.Request
reqID string
modelName string
reqMetadata *schemas.Metadata
resp *http.Response
generationID string
streamFinished bool
Expand All @@ -46,19 +44,15 @@ func NewChatStream(
tel *telemetry.Telemetry,
client *http.Client,
req *http.Request,
reqID string,
modelName string,
reqMetadata *schemas.Metadata,
errMapper *ErrorMapper,
finishReasonMapper *FinishReasonMapper,
) *ChatStream {
return &ChatStream{
tel: tel,
client: client,
req: req,
reqID: reqID,
modelName: modelName,
reqMetadata: reqMetadata,
errMapper: errMapper,
streamFinished: false,
finishReasonMapper: finishReasonMapper,
Expand Down Expand Up @@ -136,35 +130,31 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: s.modelName,
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": s.generationID,
"responseId": responseChunk.Response.ResponseID,
"generation_id": s.generationID,
"response_id": responseChunk.Response.ResponseID,
},
Message: schemas.ChatMessage{
Role: "model",
Content: responseChunk.Text,
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
}, nil
}

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: s.modelName,
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": s.generationID,
"generation_id": s.generationID,
},
Message: schemas.ChatMessage{
Role: "model",
Expand Down Expand Up @@ -198,9 +188,7 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
c.tel,
c.httpClient,
httpRequest,
req.ID,
c.chatRequestTemplate.Model,
req.Metadata,
c.errMapper,
c.finishReasonMapper,
), nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/cohere/finish_reason.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (m *FinishReasonMapper) Map(finishReason *string) *schemas.FinishReason {
zap.String("unknown_reason", *finishReason),
)

reason = &schemas.Other
reason = &schemas.OtherReason
}

return reason
Expand Down
18 changes: 10 additions & 8 deletions pkg/providers/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,19 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage {

func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
startedAt := time.Now()
resp, err := m.client.Chat(ctx, request)

if err == nil {
// record latency per token to normalize measurements
m.chatLatency.Add(float64(time.Since(startedAt)) / float64(resp.ModelResponse.TokenUsage.ResponseTokens))

// successful response
resp.ModelID = m.modelID
resp, err := m.client.Chat(ctx, request)
if err != nil {
m.healthTracker.TrackErr(err)

return resp, err
}

m.healthTracker.TrackErr(err)
// record latency per token to normalize measurements
m.chatLatency.Add(float64(time.Since(startedAt)) / float64(resp.ModelResponse.TokenUsage.ResponseTokens))

// successful response
resp.ModelID = m.modelID

return resp, err
}
Expand Down Expand Up @@ -151,6 +151,8 @@ func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatStreamR
return
}

chunk.ModelID = m.modelID

streamResultC <- clients.NewChatStreamResult(chunk, nil)

if chunkLatency > 1*time.Millisecond {
Expand Down
Loading

0 comments on commit 4a9735c

Please sign in to comment.