Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔧 #195 #196: Set router ctx in stream chunks & handle end of stream in case of some errors #203

Merged
merged 4 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 5 additions & 13 deletions pkg/api/http/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,28 +122,20 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
wg sync.WaitGroup
)

chunkResultC := make(chan *schemas.ChatStreamResult)
chatStreamC := make(chan *schemas.ChatStreamMessage)

router, _ := routerManager.GetLangRouter(routerID)

defer close(chunkResultC)
defer close(chatStreamC)
defer c.Conn.Close()

wg.Add(1)

go func() {
defer wg.Done()

for chunkResult := range chunkResultC {
if chunkResult.Error() != nil {
if err = c.WriteJSON(chunkResult.Error()); err != nil {
break
}

continue
}

if err = c.WriteJSON(chunkResult.Chunk()); err != nil {
for chatStreamMsg := range chatStreamC {
if err = c.WriteJSON(chatStreamMsg); err != nil {
break
}
}
Expand All @@ -168,7 +160,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout
go func(chatRequest schemas.ChatStreamRequest) {
defer wg.Done()

router.ChatStream(context.Background(), &chatRequest, chunkResultC)
router.ChatStream(context.Background(), &chatRequest, chatStreamC)
}(chatRequest)
}

Expand Down
99 changes: 63 additions & 36 deletions pkg/api/schemas/chat_stream.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,34 @@
package schemas

import "time"

type (
// Metadata is free-form key-value context attached to requests and responses.
Metadata = map[string]any
// EventType labels a kind of stream event (string alias).
EventType = string
// FinishReason says why the model stopped generating (see the values below).
FinishReason = string
// ErrorCode classifies a streaming-chat error for clients (see the values below).
ErrorCode = string
)

var (
Complete FinishReason = "complete"
MaxTokens FinishReason = "max_tokens"
ContentFiltered FinishReason = "content_filtered"
Other FinishReason = "other"
ErrorReason FinishReason = "error"
OtherReason FinishReason = "other"
)

// Error codes reported to clients when a streaming chat request fails.
var (
NoModelConfigured ErrorCode = "no_model_configured"
ModelUnavailable ErrorCode = "model_unavailable"
AllModelsUnavailable ErrorCode = "all_models_unavailable"
UnknownError ErrorCode = "unknown_error"
)

type StreamRequestID = string

// ChatStreamRequest defines a message that requests a new streaming chat
type ChatStreamRequest struct {
ID string `json:"id" validate:"required"`
ID StreamRequestID `json:"id" validate:"required"`
Message ChatMessage `json:"message" validate:"required"`
MessageHistory []ChatMessage `json:"messageHistory" validate:"required"`
Override *OverrideChatRequest `json:"overrideMessage,omitempty"`
Expand All @@ -32,54 +46,67 @@ func NewChatStreamFromStr(message string) *ChatStreamRequest {
}

type ModelChunkResponse struct {
Metadata *Metadata `json:"metadata,omitempty"`
Message ChatMessage `json:"message"`
FinishReason *FinishReason `json:"finishReason,omitempty"`
Metadata *Metadata `json:"metadata,omitempty"`
Message ChatMessage `json:"message"`
}

// ChatStreamMessage is the envelope written to clients on a chat stream.
// The constructors below populate either Chunk (a piece of model output)
// or Error (a stream failure), never both.
type ChatStreamMessage struct {
ID StreamRequestID `json:"id"`
// CreatedAt is the message creation time in Unix seconds.
CreatedAt int `json:"createdAt"`
// RouterID identifies the router that served the originating request.
RouterID string `json:"routerId"`
Metadata *Metadata `json:"metadata,omitempty"`
Chunk *ChatStreamChunk `json:"chunk,omitempty"`
Error *ChatStreamError `json:"error,omitempty"`
}

// ChatStreamChunk defines a message for a chunk of streaming chat response
type ChatStreamChunk struct {
ID string `json:"id"`
CreatedAt int `json:"createdAt"`
Provider string `json:"providerId"`
RouterID string `json:"routerId"`
ModelID string `json:"modelId"`
Cached bool `json:"cached"`
Provider string `json:"providerName"`
ModelName string `json:"modelName"`
Metadata *Metadata `json:"metadata,omitempty"`
Cached bool `json:"cached"`
ModelResponse ModelChunkResponse `json:"modelResponse"`
FinishReason *FinishReason `json:"finishReason,omitempty"`
}

type ChatStreamError struct {
ID string `json:"id"`
ErrCode string `json:"errCode"`
Message string `json:"message"`
Metadata *Metadata `json:"metadata,omitempty"`
}

type ChatStreamResult struct {
chunk *ChatStreamChunk
err *ChatStreamError
}

func (r *ChatStreamResult) Chunk() *ChatStreamChunk {
return r.chunk
}

func (r *ChatStreamResult) Error() *ChatStreamError {
return r.err
ErrCode ErrorCode `json:"errCode"`
Message string `json:"message"`
FinishReason *FinishReason `json:"finishReason,omitempty"`
}

func NewChatStreamResult(chunk *ChatStreamChunk) *ChatStreamResult {
return &ChatStreamResult{
chunk: chunk,
err: nil,
// NewChatStreamChunk wraps one piece of model output into a ChatStreamMessage
// addressed to the request (reqID) and router (routerID) that produced it,
// propagating the request metadata and stamping the current time.
func NewChatStreamChunk(
	reqID StreamRequestID,
	routerID string,
	reqMetadata *Metadata,
	chunk *ChatStreamChunk,
) *ChatStreamMessage {
	// Creation time is reported in Unix seconds.
	now := int(time.Now().UTC().Unix())

	msg := &ChatStreamMessage{
		ID:        reqID,
		CreatedAt: now,
		RouterID:  routerID,
		Metadata:  reqMetadata,
		Chunk:     chunk,
	}

	return msg
}

func NewChatStreamErrorResult(err *ChatStreamError) *ChatStreamResult {
return &ChatStreamResult{
chunk: nil,
err: err,
// NewChatStreamError wraps a stream failure into a ChatStreamMessage for the
// request (reqID) and router (routerID) it belongs to. errCode classifies the
// failure, errMsg is the human-readable detail, and finishReason (optional)
// records how the stream terminated.
func NewChatStreamError(
	reqID StreamRequestID,
	routerID string,
	errCode ErrorCode,
	errMsg string,
	reqMetadata *Metadata,
	finishReason *FinishReason,
) *ChatStreamMessage {
	// Build the error payload first, then the envelope around it.
	streamErr := ChatStreamError{
		ErrCode:      errCode,
		Message:      errMsg,
		FinishReason: finishReason,
	}

	return &ChatStreamMessage{
		ID:        reqID,
		CreatedAt: int(time.Now().UTC().Unix()),
		RouterID:  routerID,
		Metadata:  reqMetadata,
		Error:     &streamErr,
	}
}
14 changes: 2 additions & 12 deletions pkg/providers/azureopenai/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ type ChatStream struct {
tel *telemetry.Telemetry
client *http.Client
req *http.Request
reqID string
reqMetadata *schemas.Metadata
resp *http.Response
reader *sse.EventStreamReader
finishReasonMapper *openai.FinishReasonMapper
Expand All @@ -37,17 +35,13 @@ func NewChatStream(
tel *telemetry.Telemetry,
client *http.Client,
req *http.Request,
reqID string,
reqMetadata *schemas.Metadata,
finishReasonMapper *openai.FinishReasonMapper,
errMapper *ErrorMapper,
) *ChatStream {
return &ChatStream{
tel: tel,
client: client,
req: req,
reqID: reqID,
reqMetadata: reqMetadata,
finishReasonMapper: finishReasonMapper,
errMapper: errMapper,
}
Expand Down Expand Up @@ -129,11 +123,9 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: completionChunk.ModelName,
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"response_id": completionChunk.ID,
Expand All @@ -143,8 +135,8 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
Role: responseChunk.Delta.Role,
Content: responseChunk.Delta.Content,
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
}, nil
}
}
Expand Down Expand Up @@ -172,8 +164,6 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
c.tel,
c.httpClient,
httpRequest,
req.ID,
req.Metadata,
c.finishReasonMapper,
c.errMapper,
), nil
Expand Down
11 changes: 1 addition & 10 deletions pkg/providers/cohere/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
c.tel.Logger.Error(
"cohere chat request failed",
zap.Int("status_code", resp.StatusCode),
zap.String("response", string(bodyBytes)),
zap.ByteString("response", bodyBytes),
zap.Any("headers", resp.Header),
)

Expand All @@ -127,15 +127,6 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
return nil, err
}

// Parse the response JSON
var responseJSON map[string]interface{}

err = json.Unmarshal(bodyBytes, &responseJSON)
if err != nil {
c.tel.Logger.Error("failed to parse cohere chat response", zap.Error(err))
return nil, err
}

// Parse the response JSON
var cohereCompletion ChatCompletion

Expand Down
24 changes: 6 additions & 18 deletions pkg/providers/cohere/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ var (
type ChatStream struct {
client *http.Client
req *http.Request
reqID string
modelName string
reqMetadata *schemas.Metadata
resp *http.Response
generationID string
streamFinished bool
Expand All @@ -46,19 +44,15 @@ func NewChatStream(
tel *telemetry.Telemetry,
client *http.Client,
req *http.Request,
reqID string,
modelName string,
reqMetadata *schemas.Metadata,
errMapper *ErrorMapper,
finishReasonMapper *FinishReasonMapper,
) *ChatStream {
return &ChatStream{
tel: tel,
client: client,
req: req,
reqID: reqID,
modelName: modelName,
reqMetadata: reqMetadata,
errMapper: errMapper,
streamFinished: false,
finishReasonMapper: finishReasonMapper,
Expand Down Expand Up @@ -136,35 +130,31 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: s.modelName,
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": s.generationID,
"responseId": responseChunk.Response.ResponseID,
"generation_id": s.generationID,
"response_id": responseChunk.Response.ResponseID,
},
Message: schemas.ChatMessage{
Role: "model",
Content: responseChunk.Text,
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
},
FinishReason: s.finishReasonMapper.Map(responseChunk.FinishReason),
}, nil
}

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
ID: s.reqID,
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: s.modelName,
Metadata: s.reqMetadata,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": s.generationID,
"generation_id": s.generationID,
},
Message: schemas.ChatMessage{
Role: "model",
Expand Down Expand Up @@ -198,9 +188,7 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
c.tel,
c.httpClient,
httpRequest,
req.ID,
c.chatRequestTemplate.Model,
req.Metadata,
c.errMapper,
c.finishReasonMapper,
), nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/cohere/finish_reason.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (m *FinishReasonMapper) Map(finishReason *string) *schemas.FinishReason {
zap.String("unknown_reason", *finishReason),
)

reason = &schemas.Other
reason = &schemas.OtherReason
}

return reason
Expand Down
18 changes: 10 additions & 8 deletions pkg/providers/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,19 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage {

func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
startedAt := time.Now()
resp, err := m.client.Chat(ctx, request)

if err == nil {
// record latency per token to normalize measurements
m.chatLatency.Add(float64(time.Since(startedAt)) / float64(resp.ModelResponse.TokenUsage.ResponseTokens))

// successful response
resp.ModelID = m.modelID
resp, err := m.client.Chat(ctx, request)
if err != nil {
m.healthTracker.TrackErr(err)

return resp, err
}

m.healthTracker.TrackErr(err)
// record latency per token to normalize measurements
m.chatLatency.Add(float64(time.Since(startedAt)) / float64(resp.ModelResponse.TokenUsage.ResponseTokens))

// successful response
resp.ModelID = m.modelID

return resp, err
}
Expand Down Expand Up @@ -151,6 +151,8 @@ func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatStreamR
return
}

chunk.ModelID = m.modelID

streamResultC <- clients.NewChatStreamResult(chunk, nil)

if chunkLatency > 1*time.Millisecond {
Expand Down
Loading
Loading