Skip to content

Commit

Permalink
#196: Fixed tests after major refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Apr 15, 2024
1 parent 5e1adc3 commit 81a9169
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 53 deletions.
2 changes: 1 addition & 1 deletion pkg/providers/azureopenai/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,8 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: completionChunk.ModelName,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
Expand Down
20 changes: 8 additions & 12 deletions pkg/providers/cohere/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import (
"context"
"encoding/json"
"fmt"
"glide/pkg/providers/clients"
"glide/pkg/telemetry"
"io"
"net/http"

"glide/pkg/providers/clients"
"glide/pkg/telemetry"

"go.uber.org/zap"

"glide/pkg/api/schemas"
Expand All @@ -29,9 +30,7 @@ var (
type ChatStream struct {
client *http.Client
req *http.Request
reqID string
modelName string
reqMetadata *schemas.Metadata
resp *http.Response
generationID string
streamFinished bool
Expand All @@ -46,7 +45,6 @@ func NewChatStream(
client *http.Client,
req *http.Request,
modelName string,
reqMetadata *schemas.Metadata,
errMapper *ErrorMapper,
finishReasonMapper *FinishReasonMapper,
) *ChatStream {
Expand All @@ -55,7 +53,6 @@ func NewChatStream(
client: client,
req: req,
modelName: modelName,
reqMetadata: reqMetadata,
errMapper: errMapper,
streamFinished: false,
finishReasonMapper: finishReasonMapper,
Expand Down Expand Up @@ -133,13 +130,13 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: s.modelName,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": s.generationID,
"responseId": responseChunk.Response.ResponseID,
"generation_id": s.generationID,
"response_id": responseChunk.Response.ResponseID,
},
Message: schemas.ChatMessage{
Role: "model",
Expand All @@ -152,12 +149,12 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: s.modelName,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
"generationId": s.generationID,
"generation_id": s.generationID,
},
Message: schemas.ChatMessage{
Role: "model",
Expand Down Expand Up @@ -192,7 +189,6 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
c.httpClient,
httpRequest,
c.chatRequestTemplate.Model,
req.Metadata,
c.errMapper,
c.finishReasonMapper,
), nil
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage {

func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
startedAt := time.Now()
resp, err := m.client.Chat(ctx, request)

resp, err := m.client.Chat(ctx, request)
if err != nil {
m.healthTracker.TrackErr(err)

Expand Down
10 changes: 4 additions & 6 deletions pkg/providers/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", string(c.config.APIKey)))

// TODO: this could leak information from messages which may not be a desired thing to have
c.tel.Logger.Debug(
c.logger.Debug(
"Chat Request",
zap.String("provider", c.Provider()),
zap.String("chatURL", c.chatURL),
zap.Any("payload", payload),
)
Expand All @@ -105,9 +104,9 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
// Read the response body into a byte slice
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
c.tel.Logger.Error(
c.logger.Error(
"Failed to read chat response",
zap.String("provider", c.Provider()), zap.Error(err),
zap.Error(err),
zap.ByteString("rawResponse", bodyBytes),
)

Expand All @@ -119,9 +118,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche

err = json.Unmarshal(bodyBytes, &chatCompletion)
if err != nil {
c.tel.Logger.Error(
c.logger.Error(
"Failed to unmarshal chat response",
zap.String("provider", c.Provider()),
zap.ByteString("rawResponse", bodyBytes),
zap.Error(err),
)
Expand Down
25 changes: 9 additions & 16 deletions pkg/providers/openai/chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import (

"github.com/r3labs/sse/v2"
"glide/pkg/providers/clients"
"glide/pkg/telemetry"

"go.uber.org/zap"

"glide/pkg/api/schemas"
Expand All @@ -21,30 +19,28 @@ var StreamDoneMarker = []byte("[DONE]")

// ChatStream represents OpenAI chat stream for a specific request
type ChatStream struct {
tel *telemetry.Telemetry
client *http.Client
req *http.Request
reqID schemas.StreamRequestID
reqMetadata *schemas.Metadata
resp *http.Response
reader *sse.EventStreamReader
finishReasonMapper *FinishReasonMapper
errMapper *ErrorMapper
logger *zap.Logger
}

func NewChatStream(
tel *telemetry.Telemetry,
client *http.Client,
req *http.Request,
finishReasonMapper *FinishReasonMapper,
errMapper *ErrorMapper,
logger *zap.Logger,
) *ChatStream {
return &ChatStream{
tel: tel,
client: client,
req: req,
finishReasonMapper: finishReasonMapper,
errMapper: errMapper,
logger: logger,
}
}

Expand All @@ -70,9 +66,8 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
for {
rawEvent, err := s.reader.ReadEvent()
if err != nil {
s.tel.L().Warn(
s.logger.Warn(
"Chat stream is unexpectedly disconnected",
zap.String("provider", providerName),
zap.Error(err),
)

Expand All @@ -82,9 +77,8 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
return nil, clients.ErrProviderUnavailable
}

s.tel.L().Debug(
s.logger.Debug(
"Raw chat stream chunk",
zap.String("provider", providerName),
zap.ByteString("rawChunk", rawEvent),
)

Expand All @@ -99,9 +93,8 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {
}

if !event.HasContent() {
s.tel.L().Debug(
s.logger.Debug(
"Received an empty message in chat stream, skipping it",
zap.String("provider", providerName),
zap.Any("msg", event),
)

Expand All @@ -117,8 +110,8 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) {

// TODO: use objectpool here
return &schemas.ChatStreamChunk{
Provider: providerName,
Cached: false,
Provider: providerName,
ModelName: completionChunk.ModelName,
ModelResponse: schemas.ModelChunkResponse{
Metadata: &schemas.Metadata{
Expand Down Expand Up @@ -156,11 +149,11 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
}

return NewChatStream(
c.tel,
c.httpClient,
httpRequest,
c.finishReasonMapper,
c.errMapper,
c.logger,
), nil
}

Expand Down Expand Up @@ -202,7 +195,7 @@ func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamReque
request.Header.Set("Connection", "keep-alive")

// TODO: this could leak information from messages which may not be a desired thing to have
c.tel.L().Debug(
c.logger.Debug(
"Stream chat request",
zap.String("chatURL", c.chatURL),
zap.Any("payload", chatRequest),
Expand Down
10 changes: 9 additions & 1 deletion pkg/providers/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net/http"
"net/url"

"go.uber.org/zap"

"glide/pkg/providers/clients"
"glide/pkg/telemetry"
)
Expand All @@ -28,6 +30,7 @@ type Client struct {
config *Config
httpClient *http.Client
tel *telemetry.Telemetry
logger *zap.Logger
}

// NewClient creates a new OpenAI client for the OpenAI API.
Expand All @@ -37,6 +40,10 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
return nil, err
}

logger := tel.L().With(
zap.String("provider", providerName),
)

c := &Client{
baseURL: providerConfig.BaseURL,
chatURL: chatURL,
Expand All @@ -52,7 +59,8 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
MaxIdleConnsPerHost: 2,
},
},
tel: tel,
tel: tel,
logger: logger,
}

return c, nil
Expand Down
1 change: 0 additions & 1 deletion pkg/providers/testing/lang.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ func (m *RespMock) Resp() *schemas.ChatResponse {

func (m *RespMock) RespChunk() *schemas.ChatStreamChunk {
return &schemas.ChatStreamChunk{
ID: "rsp0001",
ModelResponse: schemas.ModelChunkResponse{
Message: schemas.ChatMessage{
Content: m.Msg,
Expand Down
Loading

0 comments on commit 81a9169

Please sign in to comment.