diff --git a/cmd/api/main.go b/cmd/api/main.go index d9dd0ae..c40cb91 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -2,130 +2,143 @@ package main import ( "context" + "database/sql" + "flag" "fmt" "net/http" "os" "os/signal" "syscall" - "time" - - "sybil-api/internal/metrics" - auth "sybil-api/internal/middleware" - "sybil-api/internal/routes/inference" - "sybil-api/internal/routes/search" - "sybil-api/internal/routes/targon" - "sybil-api/internal/setup" + + "sybil-api/internal/middleware" + "sybil-api/internal/routers" "sybil-api/internal/shared" - "github.com/aidarkhanov/nanoid" _ "github.com/go-sql-driver/mysql" "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + emw "github.com/labstack/echo/v4/middleware" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" + "github.com/manifold-inc/manifold-sdk/lib/eflag" "github.com/prometheus/client_golang/prometheus/promhttp" ) func main() { - core, errs := setup.CreateCore() - if errs != nil { - panic(fmt.Sprintf("Failed creating core: %s", errs)) - } - defer core.Shutdown() + // Flags / ENV Variables + writeDSN := flag.String("dsn", "", "Write vitess DSN") + readDSN := flag.String("read-dsn", "", "Write vitess DSN") + metricsAPIKey := flag.String("metrics-api-key", "", "Metrics api key") + redisAddr := flag.String("redis-addr", "", "Redis host:port") + debug := flag.Bool("debug", false, "Debug enabled") - server := echo.New() - e := server.Group("") - e.Use(middleware.CORS()) - e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - reqID, _ := nanoid.Generate("0123456789abcdefghijklmnopqrstuvwxyz", 28) - logger := core.Log.With( - "request_id", "req_"+reqID, - ) - logger = logger.With("externalid", c.Request().Header.Get("X-Dippy-Request-Id")) - - cc := &setup.Context{Context: c, Log: logger, Reqid: reqID} - start := time.Now() - err := next(cc) - duration := time.Since(start) - cc.Log.Infow("end_of_request", "status_code", fmt.Sprintf("%d", 
cc.Response().Status), "duration", duration.String()) - metrics.ResponseCodes.WithLabelValues(cc.Path(), fmt.Sprintf("%d", cc.Response().Status)).Inc() - return err - } - }) - e.Use(middleware.RecoverWithConfig(middleware.RecoverConfig{ - StackSize: 1 << 10, // 1 KB - LogErrorFunc: func(c echo.Context, err error, stack []byte) error { - defer func() { - _ = core.Log.Sync() - }() - core.Log.Errorw("Api Panic", "error", err.Error()) - return c.String(500, shared.ErrInternalServerError.Err.Error()) - }, - })) + // Leaving these here, as we will need them when we re-add gsearch + //googleSearchEngineID := flag.String("google-search-engine-id", "", "Google search engine id") + //googleAPIKey := flag.String("google-api-key", "", "Google search api key") + //googleACURL := flag.String("google-ac-url", "", "Google AC URL") - e.GET(("/ping"), func(c echo.Context) error { - return c.String(200, "") - }) - userManager := auth.NewUserManager(core.RedisClient, core.RDB, core.Log.With("manager", "user_manager")) - withUser := e.Group("", userManager.ExtractUser) - requiredUser := withUser.Group("", userManager.RequireUser) + err := eflag.SetFlagsFromEnvironment() + if err != nil { + panic(err) + } + flag.Parse() - inferenceGroup := requiredUser.Group("/v1") - inferenceManager, inferenceErr := inference.NewInferenceManager(core.WDB, core.RDB, core.RedisClient, core.Log, core.Debug) - if inferenceErr != nil { - panic(inferenceErr) + // Write DB init + writeDB, err := sql.Open("mysql", *writeDSN) + if err != nil { + panic(fmt.Sprintf("failed initializing sqlClient: %s", err)) } - defer inferenceManager.ShutDown() - - inferenceGroup.GET("/models", inferenceManager.Models) - inferenceGroup.POST("/chat/completions", inferenceManager.ChatRequest) - inferenceGroup.POST("/completions", inferenceManager.CompletionRequest) - inferenceGroup.POST("/embeddings", inferenceManager.EmbeddingRequest) - inferenceGroup.POST("/responses", inferenceManager.ResponsesRequest) - 
inferenceGroup.POST("/chat/history/new", inferenceManager.CompletionRequestNewHistory) - inferenceGroup.PATCH("/chat/history/:history_id", inferenceManager.UpdateHistory) - - searchGroup := requiredUser.Group("/search") - searchManager, err := search.NewSearchManager(inferenceManager.ProcessOpenaiRequest) + err = writeDB.Ping() if err != nil { - panic(err) + panic(fmt.Sprintf("failed ping to sql db: %s", err)) } - searchGroup.POST("/images", searchManager.GetImages) - searchGroup.POST("", searchManager.Search) - searchGroup.GET("/autocomplete", searchManager.GetAutocomplete) - searchGroup.POST("/sources", searchManager.GetSources) + // Read db init + readDB, err := sql.Open("mysql", *readDSN) + if err != nil { + panic(fmt.Sprintf("failed initializing readSqlClient: %s", err)) + } + err = readDB.Ping() + if err != nil { + panic(fmt.Sprintf("failed to ping read replica sql db: %s", err)) + } - requiredAdmin := requiredUser.Group("", userManager.RequireAdmin) - targonGroup := requiredAdmin.Group("/models") - targonManager, targonErr := targon.NewTargonManager(core.WDB, core.RDB, core.RedisClient, core.Log) - if targonErr != nil { - panic(targonErr) + // Load Redis connection + redisClient := redis.NewClient(&redis.Options{ + Addr: *redisAddr, + Password: "", + DB: 0, + }) + if err := redisClient.Ping(context.Background()).Err(); err != nil { + panic(fmt.Sprintf("failed ping to redis db: %s", err)) } - targonGroup.POST("", targonManager.CreateModel) - targonGroup.DELETE("/:uid", targonManager.DeleteModel) - targonGroup.PATCH("", targonManager.UpdateModel) - metricsGroup := server.Group("/metrics") - metricsGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + defer func() { + if redisClient != nil { + _ = redisClient.Close() + } + if writeDB != nil { + _ = writeDB.Close() + } + if readDB != nil { + _ = readDB.Close() + } + }() + + var logger *zap.Logger + if !*debug { + logger, err = zap.NewProduction() + if err != nil { + panic("Failed init logger") + } + } + if 
*debug { + logger, err = zap.NewDevelopment() + if err != nil { + panic("Failed init logger") + } + } + log := logger.Sugar() + + e := echo.New() + e.GET(("/ping"), func(c echo.Context) error { + return c.String(200, "") + }) + e.GET("/metrics", echo.WrapHandler(promhttp.Handler()), func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { apiKey, err := shared.ExtractAPIKey(c) if err != nil { return c.String(401, "Missing or invalid API key") } - if apiKey != core.Env.MetricsAPIKey { + if apiKey != *metricsAPIKey { return c.String(401, "Unauthorized API key") } return next(c) } }) - metricsGroup.GET("", echo.WrapHandler(promhttp.Handler())) + base := e.Group("") + base.Use(emw.CORS()) + base.Use(middleware.NewRecoverMiddleware(log)) + base.Use(middleware.NewTrackMiddleware(log)) + + middleware.InitUserMiddleware(redisClient, readDB, log) + + // Register routes + err = routers.RegisterAdminRoutes(base, writeDB, readDB, redisClient, log) + if err != nil { + panic(err) + } + shutdown, err := routers.RegisterInferenceRoutes(base, writeDB, readDB, redisClient, log, *debug) + if err != nil { + panic(err) + } + defer shutdown() go func() { - if err := server.Start(":80"); err != nil && err != http.ErrServerClosed { - server.Logger.Fatal("shutting down the server") + if err := e.Start(":80"); err != nil && err != http.ErrServerClosed { + e.Logger.Fatal("shutting down the server") } }() ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) @@ -135,7 +148,7 @@ func main() { ctx, cancel := context.WithTimeout(context.Background(), shared.DefaultShutdownTimeout) defer cancel() - if err := server.Shutdown(ctx); err != nil { - server.Logger.Fatal(err) + if err := e.Shutdown(ctx); err != nil { + e.Logger.Fatal(err) } } diff --git a/go.mod b/go.mod index 4eaf51b..9569d55 100644 --- a/go.mod +++ b/go.mod @@ -1,12 +1,13 @@ module sybil-api -go 1.24.2 +go 1.25.1 require ( github.com/aidarkhanov/nanoid v1.0.8 
github.com/go-sql-driver/mysql v1.8.0 github.com/google/uuid v1.6.0 github.com/labstack/echo/v4 v4.11.4 + github.com/manifold-inc/manifold-sdk v0.0.2 github.com/prometheus/client_golang v1.23.2 github.com/redis/go-redis/v9 v9.5.1 go.uber.org/zap v1.27.0 diff --git a/go.sum b/go.sum index 444ac66..1b658b0 100644 --- a/go.sum +++ b/go.sum @@ -55,6 +55,8 @@ github.com/labstack/echo/v4 v4.11.4 h1:vDZmA+qNeh1pd/cCkEicDMrjtrnMGQ1QFI9gWN1zG github.com/labstack/echo/v4 v4.11.4/go.mod h1:noh7EvLwqDsmh/X/HWKPUl1AjzJrhyptRyEbQJfxen8= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= +github.com/manifold-inc/manifold-sdk v0.0.2 h1:wSm9L67ocMkoCpHsCgC597M9+VvE+lLykDDnyrqYDVs= +github.com/manifold-inc/manifold-sdk v0.0.2/go.mod h1:N8LYSdOvfOGzfAGzcoxDzGbWMd9bkkbfPDFftrFcSCU= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= diff --git a/internal/handlers/inference/history.go b/internal/handlers/inference/history.go new file mode 100644 index 0000000..18a4167 --- /dev/null +++ b/internal/handlers/inference/history.go @@ -0,0 +1,530 @@ +package inference + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "maps" + "strings" + "sybil-api/internal/shared" + "time" + + "github.com/aidarkhanov/nanoid" + "go.uber.org/zap" +) + +// NewHistoryInput contains all inputs needed to create a new chat history entry with inference +type NewHistoryInput struct { + Body []byte + User shared.UserMetadata + RequestID string + Ctx context.Context + LogFields map[string]string + StreamWriter func(token string) error // Optional callback for real-time streaming +} + +// NewHistoryOutput contains the results of creating a new history entry 
and running inference +type NewHistoryOutput struct { + HistoryID string + HistoryIDJSON string // SSE event for history ID + Stream bool + FinalResponse []byte + Error *HistoryError +} + +// UpdateHistoryInput contains all inputs needed to update an existing chat history +type UpdateHistoryInput struct { + HistoryID string + Messages []shared.ChatMessage + UserID uint64 + Ctx context.Context + LogFields map[string]string +} + +// UpdateHistoryOutput contains the result of updating a history entry +type UpdateHistoryOutput struct { + HistoryID string + UserID uint64 + Message string + Error *HistoryError +} + +// HistoryError represents a structured error for history operations +type HistoryError struct { + StatusCode int + Message string + Err error +} + +func (im *InferenceHandler) CompletionRequestNewHistoryLogic(input *NewHistoryInput) (*NewHistoryOutput, error) { + log := logWithFields(im.Log, input.LogFields) + + // Parse request body + var payload shared.InferenceBody + if err := json.Unmarshal(input.Body, &payload); err != nil { + log.Errorw("Failed to parse request body", "error", err.Error()) + return &NewHistoryOutput{ + Error: &HistoryError{ + StatusCode: 400, + Message: "invalid JSON format", + Err: err, + }, + }, nil + } + + if len(payload.Messages) == 0 { + return &NewHistoryOutput{ + Error: &HistoryError{ + StatusCode: 400, + Message: "messages are required", + Err: errors.New("messages are required"), + }, + }, nil + } + + messages := payload.Messages + + // Generate history ID + historyIDNano, err := nanoid.Generate("0123456789abcdefghijklmnopqrstuvwxyz", 11) + if err != nil { + log.Errorw("Failed to generate history nanoid", "error", err) + return &NewHistoryOutput{ + Error: &HistoryError{ + StatusCode: 500, + Message: "failed to generate history ID", + Err: err, + }, + }, nil + } + historyID := "chat-" + historyIDNano + + // Extract title from first user message + var title *string + for _, msg := range messages { + if msg.Role == "user" && 
msg.Content != "" { + titleStr := msg.Content + if len(titleStr) > 32 { + titleStr = titleStr[:32] + } + title = &titleStr + break + } + } + + // Marshal messages for DB insert + messagesJSON, err := json.Marshal(messages) + if err != nil { + log.Errorw("Failed to marshal initial messages", "error", err) + return &NewHistoryOutput{ + Error: &HistoryError{ + StatusCode: 500, + Message: "failed to prepare history", + Err: err, + }, + }, nil + } + + // Insert into database + insertQuery := ` + INSERT INTO chat_history ( + user_id, + history_id, + messages, + title, + icon + ) VALUES (?, ?, ?, ?, ?) + ` + + _, err = im.WDB.Exec(insertQuery, + input.User.UserID, + historyID, + string(messagesJSON), + title, + nil, // icon + ) + if err != nil { + log.Errorw("Failed to insert history into database", "error", err) + return &NewHistoryOutput{ + Error: &HistoryError{ + StatusCode: 500, + Message: "failed to create history", + Err: err, + }, + }, nil + } + + log.Infow("Chat history created", "history_id", historyID, "user_id", input.User.UserID) + + // Prepare history ID SSE event + historyIDEvent := map[string]any{ + "type": "history_id", + "id": historyID, + } + historyIDJSON, _ := json.Marshal(historyIDEvent) + + // Build logfields for inference + inferenceLogFields := map[string]string{} + if input.LogFields != nil { + maps.Copy(inferenceLogFields, input.LogFields) + } + inferenceLogFields["history_id"] = historyID + + // Run preprocessing + reqInfo, preErr := im.Preprocess(PreprocessInput{ + Body: input.Body, + User: input.User, + Endpoint: shared.ENDPOINTS.CHAT, + RequestID: input.RequestID, + LogFields: inferenceLogFields, + }) + + if preErr != nil { + log.Warnw("Preprocessing failed", "error", preErr.Err) + return &NewHistoryOutput{ + HistoryID: historyID, + HistoryIDJSON: string(historyIDJSON), + Error: &HistoryError{ + StatusCode: preErr.StatusCode, + Message: "inference error", + Err: preErr.Err, + }, + }, nil + } + + // Run inference with streaming callback + out, 
reqErr := im.DoInference(InferenceInput{ + Req: reqInfo, + User: input.User, + Ctx: input.Ctx, + LogFields: inferenceLogFields, + StreamWriter: input.StreamWriter, // Pass through the streaming callback + }) + + if reqErr != nil { + if reqErr.StatusCode >= 500 && reqErr.Err != nil { + log.Warnw("Inference error", "error", reqErr.Err.Error()) + } + return &NewHistoryOutput{ + HistoryID: historyID, + HistoryIDJSON: string(historyIDJSON), + Error: &HistoryError{ + StatusCode: reqErr.StatusCode, + Message: "inference error", + Err: reqErr.Err, + }, + }, nil + } + + if out == nil { + return &NewHistoryOutput{ + HistoryID: historyID, + HistoryIDJSON: string(historyIDJSON), + }, nil + } + + // Extract assistant message content from inference output + var assistantContent string + if out.Stream { + assistantContent = extractContentFromInferenceOutput(out) + } else { + assistantContent = extractContentFromFinalResponse(out.FinalResponse) + } + + // Update history with assistant response asynchronously + if assistantContent != "" { + var allMessages []shared.ChatMessage + allMessages = append(allMessages, messages...) + allMessages = append(allMessages, shared.ChatMessage{ + Role: "assistant", + Content: assistantContent, + }) + + allMessagesJSON, err := json.Marshal(allMessages) + if err != nil { + log.Errorw("Failed to marshal complete messages", "error", err) + } else { + go func(userID uint64, historyID string, messagesJSON []byte, log *zap.SugaredLogger) { + updateQuery := ` + UPDATE chat_history + SET messages = ?, updated_at = NOW() + WHERE history_id = ? 
+ ` + + _, err := im.WDB.Exec(updateQuery, string(messagesJSON), historyID) + if err != nil { + log.Errorw("Failed to update history in database", "error", err, "history_id", historyID) + return + } + + log.Infow("Chat history updated with assistant response", "history_id", historyID, "user_id", userID) + + if err := im.updateUserStreak(userID, log); err != nil { + log.Errorw("Failed to update user streak", "error", err, "user_id", userID) + } + }(input.User.UserID, historyID, allMessagesJSON, log) + } + } + + return &NewHistoryOutput{ + HistoryID: historyID, + HistoryIDJSON: string(historyIDJSON), + Stream: out.Stream, + FinalResponse: out.FinalResponse, + }, nil +} + +func (im *InferenceHandler) UpdateHistoryLogic(input *UpdateHistoryInput) (*UpdateHistoryOutput, error) { + log := logWithFields(im.Log, input.LogFields) + + // Check if history exists and get owner user ID + var ownerUserID uint64 + checkQuery := `SELECT user_id FROM chat_history WHERE history_id = ?` + err := im.RDB.QueryRowContext(input.Ctx, checkQuery, input.HistoryID).Scan(&ownerUserID) + if err != nil { + if err == sql.ErrNoRows { + log.Errorw("History not found", "error", err.Error(), "history_id", input.HistoryID) + return &UpdateHistoryOutput{ + Error: &HistoryError{ + StatusCode: 404, + Message: "history not found", + Err: err, + }, + }, nil + } + log.Errorw("Failed to check history", "error", err.Error(), "history_id", input.HistoryID) + return &UpdateHistoryOutput{ + Error: &HistoryError{ + StatusCode: 500, + Message: "internal server error", + Err: err, + }, + }, nil + } + + // Check authorization + if ownerUserID != input.UserID { + log.Errorw("Unauthorized access to history", "history_id", input.HistoryID, "user_id", input.UserID, "owner_id", ownerUserID) + return &UpdateHistoryOutput{ + Error: &HistoryError{ + StatusCode: 403, + Message: "unauthorized", + Err: errors.New("unauthorized access"), + }, + }, nil + } + + log.Infow("Updating chat history", + "history_id", input.HistoryID, 
+ "user_id", input.UserID) + + // Validate messages + if len(input.Messages) == 0 { + return &UpdateHistoryOutput{ + Error: &HistoryError{ + StatusCode: 400, + Message: "messages cannot be empty", + Err: errors.New("messages cannot be empty"), + }, + }, nil + } + + // Marshal messages + messagesJSON, err := json.Marshal(input.Messages) + if err != nil { + log.Errorw("Failed to marshal messages", "error", err) + return &UpdateHistoryOutput{ + Error: &HistoryError{ + StatusCode: 500, + Message: "internal server error", + Err: err, + }, + }, nil + } + + // Update database + updateQuery := ` + UPDATE chat_history + SET messages = ?, updated_at = NOW() + WHERE history_id = ? + ` + + _, err = im.WDB.ExecContext(input.Ctx, updateQuery, string(messagesJSON), input.HistoryID) + if err != nil { + log.Errorw("Failed to update history in database", + "error", err.Error(), + "history_id", input.HistoryID) + return &UpdateHistoryOutput{ + Error: &HistoryError{ + StatusCode: 500, + Message: "internal server error", + Err: err, + }, + }, nil + } + + log.Infow("Successfully updated chat history", + "history_id", input.HistoryID, + "user_id", input.UserID) + + // Update user streak asynchronously + go func(userID uint64, log *zap.SugaredLogger) { + if err := im.updateUserStreak(userID, log); err != nil { + log.Errorw("Failed to update user streak", "error", err, "user_id", userID) + } + }(input.UserID, log) + + return &UpdateHistoryOutput{ + HistoryID: input.HistoryID, + UserID: input.UserID, + Message: "History updated successfully", + }, nil +} + +// extractContentFromInferenceOutput extracts assistant content from inference output +func extractContentFromInferenceOutput(out *InferenceOutput) string { + if out == nil || len(out.FinalResponse) == 0 { + return "" + } + + // FinalResponse contains the marshaled array of response chunks + var chunks []json.RawMessage + if err := json.Unmarshal(out.FinalResponse, &chunks); err != nil { + return "" + } + + var fullContent 
strings.Builder + for _, chunkData := range chunks { + var chunk shared.Response + if err := json.Unmarshal(chunkData, &chunk); err != nil { + continue + } + + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0] + if choice.Delta == nil { + continue + } + + if choice.Delta.Content != "" { + fullContent.WriteString(choice.Delta.Content) + } + } + + return fullContent.String() +} + +// extractContentFromFinalResponse extracts assistant content from non-streaming response +func extractContentFromFinalResponse(finalResponse []byte) string { + if len(finalResponse) == 0 { + return "" + } + + var response shared.Response + if err := json.Unmarshal(finalResponse, &response); err != nil { + return "" + } + + if len(response.Choices) == 0 { + return "" + } + + choice := response.Choices[0] + if choice.Message != nil { + return choice.Message.Content + } + + return "" +} + +func (im *InferenceHandler) updateUserStreak(userID uint64, log *zap.SugaredLogger) error { + var lastChatStr sql.NullString + var currentStreak uint64 + + err := im.RDB.QueryRow(` + SELECT last_chat, streak + FROM user + WHERE id = ? 
+ `, userID).Scan(&lastChatStr, ¤tStreak) + if err != nil { + return fmt.Errorf("failed to get user streak data: %w", err) + } + + var lastChat sql.NullTime + if lastChatStr.Valid && lastChatStr.String != "" { + formats := []string{ + "2006-01-02 15:04:05", + time.RFC3339, + "2006-01-02T15:04:05Z07:00", + "2006-01-02 15:04:05.000000", + } + + var parsedTime time.Time + var parseErr error + for _, format := range formats { + parsedTime, parseErr = time.Parse(format, lastChatStr.String) + if parseErr == nil { + lastChat = sql.NullTime{Time: parsedTime, Valid: true} + break + } + } + + if parseErr != nil { + log.Warnw("Failed to parse last_chat timestamp", "error", parseErr, "value", lastChatStr.String) + } + } + + now := time.Now() + todayMidnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + + var newStreak uint64 + updateStreak := false + + if lastChat.Valid { + lastChatDate := lastChat.Time + lastChatMidnight := time.Date(lastChatDate.Year(), lastChatDate.Month(), lastChatDate.Day(), 0, 0, 0, 0, lastChatDate.Location()) + + if !todayMidnight.Equal(lastChatMidnight) { + updateStreak = true + expectedDate := lastChatMidnight.AddDate(0, 0, 1) + if todayMidnight.Equal(expectedDate) { + newStreak = currentStreak + 1 + } else { + newStreak = 1 + } + } else { + newStreak = currentStreak + } + } else { + updateStreak = true + newStreak = 1 + } + + if updateStreak { + _, err = im.WDB.Exec(` + UPDATE user + SET streak = ?, last_chat = ? + WHERE id = ? + `, newStreak, now, userID) + if err != nil { + return fmt.Errorf("failed to update user streak: %w", err) + } + + log.Infow("Updated user streak", "user_id", userID, "streak", newStreak, "last_chat", now) + } else { + _, err = im.WDB.Exec(` + UPDATE user + SET last_chat = ? + WHERE id = ? 
+ `, now, userID) + if err != nil { + return fmt.Errorf("failed to update last_chat: %w", err) + } + } + + return nil +} diff --git a/internal/handlers/inference/inference.go b/internal/handlers/inference/inference.go new file mode 100644 index 0000000..0288bc4 --- /dev/null +++ b/internal/handlers/inference/inference.go @@ -0,0 +1,188 @@ +// Package inference +package inference + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "maps" + "sync" + "time" + + "sybil-api/internal/shared" +) + +type InferenceInput struct { + Req *shared.RequestInfo + User shared.UserMetadata + Ctx context.Context + LogFields map[string]string + StreamWriter func(token string) error // callback for real-time streaming +} + +type InferenceOutput struct { + Stream bool + FinalResponse []byte + ModelID string + TimeToFirstToken time.Duration + TotalTime time.Duration + Usage *shared.Usage + Completed bool + Canceled bool +} + +func (im *InferenceHandler) DoInference(input InferenceInput) (*InferenceOutput, *shared.RequestError) { + if input.Req == nil { + return nil, &shared.RequestError{ + StatusCode: 500, + Err: errors.New("request info missing"), + } + } + reqInfo := input.Req + + im.usageCache.AddInFlightToBucket(reqInfo.UserID) + mu := sync.Mutex{} + mu.Lock() + defer func() { + im.usageCache.RemoveInFlightFromBucket(reqInfo.UserID) + mu.Unlock() + }() + + queryLogFields := map[string]string{} + if input.LogFields != nil { + maps.Copy(queryLogFields, input.LogFields) + } + queryLogFields["model"] = reqInfo.Model + queryLogFields["stream"] = fmt.Sprintf("%t", reqInfo.Stream) + + queryInput := QueryInput{ + Ctx: input.Ctx, + Req: reqInfo, + LogFields: queryLogFields, + StreamWriter: input.StreamWriter, // Pass the callback through + } + + resInfo, qerr := im.QueryModels(queryInput) + if qerr != nil { + return nil, qerr + } + + output := &InferenceOutput{ + Stream: reqInfo.Stream, + ModelID: fmt.Sprintf("%d", resInfo.ModelID), + TimeToFirstToken: resInfo.TimeToFirstToken, + 
TotalTime: resInfo.TotalTime, + Usage: resInfo.Usage, + Completed: resInfo.Completed, + Canceled: resInfo.Canceled, + FinalResponse: []byte(resInfo.ResponseContent), + } + + log := im.Log.With( + "endpoint", reqInfo.Endpoint, + "user_id", input.User.UserID, + "request_id", reqInfo.ID, + ) + + if resInfo.ResponseContent == "" || !resInfo.Completed { + log.Errorw("No response or incomplete response from model", + "response_content_length", len(resInfo.ResponseContent), + "completed", resInfo.Completed, + "canceled", resInfo.Canceled, + "ttft", resInfo.TimeToFirstToken, + "total_time", resInfo.TotalTime) + return nil, &shared.RequestError{ + StatusCode: 500, + Err: errors.New("no response from model"), + } + } + + go func() { + switch true { + case !resInfo.Completed: + break + case reqInfo.Stream: + var chunks []map[string]any + err := json.Unmarshal([]byte(resInfo.ResponseContent), &chunks) + if err != nil { + log.Errorw( + "Failed to unmarshal streaming ResponseContent as JSON array of chunks", + "error", + err, + "raw_response_content", + resInfo.ResponseContent, + ) + break + } + for i := len(chunks) - 1; i >= 0; i-- { + usageData, usageFieldExists := chunks[i]["usage"] + if usageFieldExists && usageData != nil { + if extractedUsage, extractErr := extractUsageData(chunks[i], reqInfo.Endpoint); extractErr == nil { + resInfo.Usage = extractedUsage + break + } + log.Warnw( + "Failed to extract usage data from a response chunk that had a non-null usage field", + "chunk_index", + i, + ) + break + } + } + case !reqInfo.Stream: + var singleResponse map[string]any + err := json.Unmarshal([]byte(resInfo.ResponseContent), &singleResponse) + if err != nil { + log.Errorw( + "Failed to unmarshal non-streaming ResponseContent as single JSON object", + "error", + err, + "raw_response_content", + resInfo.ResponseContent, + ) + break + } + usageData, usageFieldExists := singleResponse["usage"] + if usageFieldExists && usageData != nil { + if extractedUsage, extractErr := 
extractUsageData(singleResponse, reqInfo.Endpoint); extractErr == nil { + resInfo.Usage = extractedUsage + break + } + log.Warnw( + "Failed to extract usage data from single response object that had a non-null usage field", + ) + } + default: + break + } + + if resInfo.Usage == nil { + resInfo.Usage = &shared.Usage{IsCanceled: resInfo.Canceled} + } + + totalCredits := shared.CalculateCredits(resInfo.Usage, resInfo.Cost.InputCredits, resInfo.Cost.OutputCredits, resInfo.Cost.CanceledCredits) + + pqi := &shared.ProcessedQueryInfo{ + UserID: reqInfo.UserID, + Model: reqInfo.Model, + ModelID: resInfo.ModelID, + Endpoint: reqInfo.Endpoint, + TotalTime: resInfo.TotalTime, + TimeToFirstToken: resInfo.TimeToFirstToken, + Usage: resInfo.Usage, + Cost: resInfo.Cost, + TotalCredits: totalCredits, + ResponseContent: resInfo.ResponseContent, + RequestContent: reqInfo.Body, + CreatedAt: time.Now(), + ID: reqInfo.ID, + } + + mu.Lock() + im.usageCache.AddRequestToBucket(reqInfo.UserID, pqi, reqInfo.ID) + mu.Unlock() + }() + + return output, nil +} diff --git a/internal/routes/inference/discovery.go b/internal/handlers/inference/inference_discovery.go similarity index 98% rename from internal/routes/inference/discovery.go rename to internal/handlers/inference/inference_discovery.go index c7e48b6..2353fff 100644 --- a/internal/routes/inference/discovery.go +++ b/internal/handlers/inference/inference_discovery.go @@ -18,7 +18,7 @@ type InferenceService struct { Modality string `json:"modality"` } -func (im *InferenceManager) DiscoverModels(ctx context.Context, userID uint64, modelName string) (*InferenceService, error) { +func (im *InferenceHandler) DiscoverModels(ctx context.Context, userID uint64, modelName string) (*InferenceService, error) { cacheKey := fmt.Sprintf("sybil:v1:model:service:%d:%s", userID, modelName) cached, err := im.RedisClient.Get(ctx, cacheKey).Result() if err == nil && cached != "" { diff --git a/internal/routes/inference/inference.go 
b/internal/handlers/inference/inference_handler.go similarity index 77% rename from internal/routes/inference/inference.go rename to internal/handlers/inference/inference_handler.go index 3d4c4c1..5c799b4 100644 --- a/internal/routes/inference/inference.go +++ b/internal/handlers/inference/inference_handler.go @@ -1,4 +1,3 @@ -// Package inference includes all routes and functionality for Sybil Inference package inference import ( @@ -17,7 +16,7 @@ import ( "go.uber.org/zap" ) -type InferenceManager struct { +type InferenceHandler struct { WDB *sql.DB RDB *sql.DB RedisClient *redis.Client @@ -28,7 +27,7 @@ type InferenceManager struct { usageCache *buckets.UsageCache } -func NewInferenceManager(wdb *sql.DB, rdb *sql.DB, redisClient *redis.Client, log *zap.SugaredLogger, debug bool) (*InferenceManager, error) { +func NewInferenceHandler(wdb *sql.DB, rdb *sql.DB, redisClient *redis.Client, log *zap.SugaredLogger, debug bool) (*InferenceHandler, error) { // check if the databases are connected err := wdb.Ping() if err != nil { @@ -47,7 +46,7 @@ func NewInferenceManager(wdb *sql.DB, rdb *sql.DB, redisClient *redis.Client, lo usageCache := buckets.NewUsageCache(log, wdb) - return &InferenceManager{ + return &InferenceHandler{ WDB: wdb, RDB: rdb, RedisClient: redisClient, @@ -58,7 +57,7 @@ func NewInferenceManager(wdb *sql.DB, rdb *sql.DB, redisClient *redis.Client, lo }, nil } -func (im *InferenceManager) getHTTPClient(modelURL string) *http.Client { +func (im *InferenceHandler) getHTTPClient(modelURL string) *http.Client { parsedURL, err := url.Parse(modelURL) if err != nil { im.Log.Warnw("Failed to parse model URL, using full URL as key", "url", modelURL, "error", err) @@ -95,8 +94,19 @@ func (im *InferenceManager) getHTTPClient(modelURL string) *http.Client { return client } -func (im *InferenceManager) ShutDown() { +func (im *InferenceHandler) ShutDown() { if im.usageCache != nil { im.usageCache.Shutdown() } } + +func logWithFields(logger *zap.SugaredLogger, fields 
map[string]string) *zap.SugaredLogger { + if len(fields) == 0 { + return logger + } + args := make([]interface{}, 0, len(fields)*2) + for k, v := range fields { + args = append(args, k, v) + } + return logger.With(args...) +} diff --git a/internal/handlers/inference/inference_preprocess.go b/internal/handlers/inference/inference_preprocess.go new file mode 100644 index 0000000..175b564 --- /dev/null +++ b/internal/handlers/inference/inference_preprocess.go @@ -0,0 +1,185 @@ +package inference + +import ( + "encoding/json" + "errors" + "time" + + "sybil-api/internal/shared" +) + +type PreprocessInput struct { + Body []byte + User shared.UserMetadata + Endpoint string + RequestID string + LogFields map[string]string +} + +func (im *InferenceHandler) Preprocess(input PreprocessInput) (*shared.RequestInfo, *shared.RequestError) { + startTime := time.Now() + + newlog := logWithFields(im.Log, input.LogFields) + + // Unmarshal to generic map to set defaults + var payload map[string]any + err := json.Unmarshal(input.Body, &payload) + if err != nil { + newlog.Warnw("failed json unmarshal to payload map", "error", err.Error()) + return nil, &shared.RequestError{StatusCode: 400, Err: errors.New("malformed request")} + } + + // validate models and set defaults + model, ok := payload["model"] + if !ok { + newlog.Infow("missing model parameter", "error", "model is required") + return nil, &shared.RequestError{StatusCode: 400, Err: errors.New("model is required")} + } + + modelName := model.(string) + + newlog = newlog.With("model", modelName, "endpoint", input.Endpoint) + + if input.Endpoint == shared.ENDPOINTS.EMBEDDING { + + inputField, ok := payload["input"] + if !ok { + return nil, &shared.RequestError{ + StatusCode: 400, + Err: errors.New("input is required for embeddings"), + } + } + + switch v := inputField.(type) { + case string: + if v == "" { + return nil, &shared.RequestError{ + StatusCode: 400, + Err: errors.New("input cannot be empty"), + } + } + case []any: + if 
len(v) == 0 { + return nil, &shared.RequestError{ + StatusCode: 400, + Err: errors.New("input array cannot be empty"), + } + } + default: + return nil, &shared.RequestError{ + StatusCode: 400, + Err: errors.New("input must be string or array of strings"), + } + } + + if (input.User.Credits == 0 && input.User.PlanRequests == 0) && !input.User.AllowOverspend { + newlog.Infow("No credits available", "user_id", input.User.UserID) + return nil, &shared.RequestError{ + StatusCode: 402, + Err: errors.New("insufficient credits"), + } + } + + body, err := json.Marshal(payload) + if err != nil { + newlog.Errorw("Failed to marshal request body", "error", err.Error()) + return nil, &shared.RequestError{StatusCode: 500, Err: errors.New("internal server error")} + } + + return &shared.RequestInfo{ + Body: body, + UserID: input.User.UserID, + Credits: input.User.Credits, + ID: input.RequestID, + StartTime: startTime, + Endpoint: input.Endpoint, + Model: modelName, + Stream: false, + }, nil + } + + if input.Endpoint == shared.ENDPOINTS.RESPONSES { + inputField, ok := payload["input"] + if !ok { + return nil, &shared.RequestError{ + StatusCode: 400, + Err: errors.New("input is required for responses"), + } + } + + inputArray, ok := inputField.([]any) + if !ok { + return nil, &shared.RequestError{ + StatusCode: 400, + Err: errors.New("input must be an array"), + } + } + + if len(inputArray) == 0 { + return nil, &shared.RequestError{ + StatusCode: 400, + Err: errors.New("input array cannot be empty"), + } + } + } + + if (input.User.Credits == 0 && input.User.PlanRequests == 0) && !input.User.AllowOverspend { + newlog.Warnw("Insufficient credits or requests", + "credits", input.User.Credits, + "plan_requests", input.User.PlanRequests, + "allow_overspend", input.User.AllowOverspend) + return nil, &shared.RequestError{ + StatusCode: 402, + Err: errors.New("insufficient requests or credits"), + } + } + + // Set stream default if not specified + if val, ok := payload["stream"]; !ok || val 
== nil { + payload["stream"] = shared.DefaultStreamOption + } + + stream := payload["stream"].(bool) + + // Add stream to logger context + newlog = newlog.With("stream", stream) + + // If streaming is enabled (either by default or explicitly), include usage data + if stream { + payload["stream_options"] = map[string]any{ + "include_usage": true, + } + } + + // Log user id 3's request parameters + if input.User.UserID == 3 { + newlog.Infow("User 3 request payload", + "model", modelName, + "stream", stream, + "max_tokens", payload["max_tokens"], + "temperature", payload["temperature"], + "top_p", payload["top_p"], + "frequency_penalty", payload["frequency_penalty"], + "presence_penalty", payload["presence_penalty"]) + } + + // repackage body + body, err := json.Marshal(payload) + if err != nil { + newlog.Errorw("Failed to marshal request body", "error", err.Error()) + return nil, &shared.RequestError{StatusCode: 500, Err: errors.New("internal server error")} + } + + reqInfo := &shared.RequestInfo{ + Body: body, + UserID: input.User.UserID, + Credits: input.User.Credits, + ID: input.RequestID, + StartTime: startTime, + Endpoint: input.Endpoint, + Model: modelName, + Stream: stream, + } + + return reqInfo, nil +} diff --git a/internal/routes/inference/querymodels.go b/internal/handlers/inference/inference_query.go similarity index 57% rename from internal/routes/inference/querymodels.go rename to internal/handlers/inference/inference_query.go index 3b28743..a555960 100644 --- a/internal/routes/inference/querymodels.go +++ b/internal/handlers/inference/inference_query.go @@ -10,21 +10,28 @@ import ( "io" "net/http" "strings" - "sybil-api/internal/metrics" - "sybil-api/internal/setup" - "sybil-api/internal/shared" "sync/atomic" "time" - "github.com/labstack/echo/v4" + "sybil-api/internal/metrics" + "sybil-api/internal/shared" ) +type QueryInput struct { + Ctx context.Context + Req *shared.RequestInfo + LogFields map[string]string + StreamWriter func(token string) error 
// Optional callback for real-time streaming +} + // QueryModels forwards the request to the appropriate model -func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInfo) (*shared.ResponseInfo, *shared.RequestError) { +func (im *InferenceHandler) QueryModels(input QueryInput) (*shared.ResponseInfo, *shared.RequestError) { + newlog := logWithFields(im.Log, input.LogFields) + // Discover inference service - modelMetadata, err := im.DiscoverModels(c.Request().Context(), req.UserID, req.Model) + modelMetadata, err := im.DiscoverModels(input.Ctx, input.Req.UserID, input.Req.Model) if err != nil { - c.Log.Errorw("Service discovery failed", "error", err) + newlog.Errorw("Service discovery failed", "error", err) return nil, &shared.RequestError{ StatusCode: 404, Err: fmt.Errorf("service not found: %w", err), @@ -32,13 +39,13 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf } // Add model metadata to logger context for all subsequent logs - c.Log = c.Log.With("model_id", modelMetadata.ModelID, "model_url", modelMetadata.URL) + newlog = newlog.With("model_id", modelMetadata.ModelID, "model_url", modelMetadata.URL) // Initialize http request - route := shared.ROUTES[req.Endpoint] - r, err := http.NewRequest("POST", modelMetadata.URL+route, bytes.NewBuffer(req.Body)) + route := shared.ROUTES[input.Req.Endpoint] + r, err := http.NewRequest("POST", modelMetadata.URL+route, bytes.NewBuffer(input.Req.Body)) if err != nil { - c.Log.Warnw("Failed building request", "error", err.Error()) + newlog.Warnw("Failed building request", "error", err.Error()) return nil, &shared.RequestError{ StatusCode: 400, Err: errors.New("failed building request"), @@ -59,11 +66,11 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf var timeoutOccurred atomic.Bool ctx, cancel := context.WithTimeout(context.Background(), shared.DefaultStreamRequestTimeout) timer := time.AfterFunc(shared.DefaultStreamRequestTimeout, 
func() { - if req.Stream { - c.Log.Warnw("Stream request timeout triggered", + if input.Req.Stream { + newlog.Warnw("Stream request timeout triggered", "timeout_seconds", shared.DefaultStreamRequestTimeout.Seconds(), - "model", req.Model, - "user_id", req.UserID) + "model", input.Req.Model, + "user_id", input.Req.UserID) timeoutOccurred.Store(true) cancel() } @@ -74,12 +81,12 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf }() r = r.WithContext(ctx) - preprocessingTime := time.Since(req.StartTime) + preprocessingTime := time.Since(input.Req.StartTime) httpStart := time.Now() - if c.Request().Context().Err() != nil { - c.Log.Warnw("Client already disconnected before HTTP request", - "context_error", c.Request().Context().Err()) + if input.Ctx.Err() != nil { + newlog.Warnw("Client already disconnected before HTTP request", + "context_error", input.Ctx.Err()) } httpClient := im.getHTTPClient(modelMetadata.URL) @@ -90,16 +97,16 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf defer func() { if res != nil && res.Body != nil { if closeErr := res.Body.Close(); closeErr != nil { - c.Log.Warnw("Failed to close response body", "error", closeErr) + newlog.Warnw("Failed to close response body", "error", closeErr) } } }() - canceled := c.Request().Context().Err() == context.Canceled - modelLabel := fmt.Sprintf("%d-%s", modelMetadata.ModelID, req.Model) + canceled := input.Ctx.Err() == context.Canceled + modelLabel := fmt.Sprintf("%d-%s", modelMetadata.ModelID, input.Req.Model) if err != nil { - c.Log.Errorw("HTTP request failed", + newlog.Errorw("HTTP request failed", "http_duration_ms", httpDuration.Milliseconds(), "error", err.Error(), "canceled", canceled, @@ -107,37 +114,37 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf } if err != nil && timeoutOccurred.Load() { - c.Log.Warnw("Request timed out - likely due to model cold start") - 
metrics.ErrorCount.WithLabelValues(modelLabel, req.Endpoint, fmt.Sprintf("%d", req.UserID), "cold_start").Inc() + newlog.Warnw("Request timed out - likely due to model cold start") + metrics.ErrorCount.WithLabelValues(modelLabel, input.Req.Endpoint, fmt.Sprintf("%d", input.Req.UserID), "cold_start").Inc() return nil, &shared.RequestError{StatusCode: 503, Err: errors.New("cold start detected, please try again in a few minutes")} } if canceled { - c.Log.Warnw("Request canceled by client", + newlog.Warnw("Request canceled by client", "http_duration_ms", httpDuration.Milliseconds(), - "elapsed_since_start_ms", time.Since(req.StartTime).Milliseconds(), + "elapsed_since_start_ms", time.Since(input.Req.StartTime).Milliseconds(), "had_error", err != nil, "will_continue_processing", true) - metrics.ErrorCount.WithLabelValues(modelLabel, req.Endpoint, fmt.Sprintf("%d", req.UserID), "client_canceled").Inc() + metrics.ErrorCount.WithLabelValues(modelLabel, input.Req.Endpoint, fmt.Sprintf("%d", input.Req.UserID), "client_canceled").Inc() // Don't return error, let it process gracefully } if err != nil && !canceled { - c.Log.Warnw("Failed to send request", + newlog.Warnw("Failed to send request", "error", err, "http_duration_ms", httpDuration.Milliseconds(), - "elapsed_since_start_ms", time.Since(req.StartTime).Milliseconds()) - metrics.ErrorCount.WithLabelValues(modelLabel, req.Endpoint, fmt.Sprintf("%d", req.UserID), "request_failed").Inc() + "elapsed_since_start_ms", time.Since(input.Req.StartTime).Milliseconds()) + metrics.ErrorCount.WithLabelValues(modelLabel, input.Req.Endpoint, fmt.Sprintf("%d", input.Req.UserID), "request_failed").Inc() return nil, &shared.RequestError{StatusCode: 502, Err: errors.New("request failed")} } if res != nil && res.StatusCode != http.StatusOK && !canceled { - c.Log.Warnw("Request failed with non-200 status", + newlog.Warnw("Request failed with non-200 status", "status_code", res.StatusCode, "status", res.Status, "http_duration_ms", 
httpDuration.Milliseconds(), - "elapsed_since_start_ms", time.Since(req.StartTime).Milliseconds(), + "elapsed_since_start_ms", time.Since(input.Req.StartTime).Milliseconds(), "returning_early", true) - metrics.ErrorCount.WithLabelValues(modelLabel, req.Endpoint, fmt.Sprintf("%d", req.UserID), "request_failed_from_error_code").Inc() + metrics.ErrorCount.WithLabelValues(modelLabel, input.Req.Endpoint, fmt.Sprintf("%d", input.Req.UserID), "request_failed_from_error_code").Inc() return nil, &shared.RequestError{StatusCode: res.StatusCode, Err: errors.New("request failed")} } @@ -148,8 +155,7 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf var ttftRecorded bool hasDone := false - if req.Stream && !canceled { // Check if the request is streaming - c.Response().Header().Set("Content-Type", "text/event-stream") + if input.Req.Stream && !canceled { reader := bufio.NewScanner(res.Body) var currentEvent string @@ -158,11 +164,11 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf for reader.Scan() { select { case <-ctx.Done(): - c.Log.Warnw("Inference engine request timeout during streaming") + newlog.Warnw("Inference engine request timeout during streaming") break scanner - case <-c.Request().Context().Done(): + case <-input.Ctx.Done(): if !clientDisconnected { - c.Log.Warnw("Client disconnected during streaming, continuing to read from inference engine") + newlog.Warnw("Client disconnected during streaming, continuing to read from inference engine") clientDisconnected = true } default: @@ -173,15 +179,17 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf continue } - // Only write to client if they're still connected - if c.Request().Context().Err() == nil { - _, _ = fmt.Fprint(c.Response(), token+"\n\n") - c.Response().Flush() + // Stream token to client immediately via callback (if provided and client still connected) + if input.StreamWriter != nil && !clientDisconnected { 
+ if err := input.StreamWriter(token); err != nil { + newlog.Warnw("Stream writer returned error, client likely disconnected", "error", err) + clientDisconnected = true + } } // Handle Responses API event format - if strings.HasPrefix(token, "event: ") { - currentEvent = strings.TrimPrefix(token, "event: ") + if ce, found := strings.CutPrefix(token, "event: "); found { + currentEvent = ce // Check for completion event if currentEvent == "response.completed" { hasDone = true @@ -194,93 +202,84 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf } if !ttftRecorded { - ttft = time.Since(req.StartTime) + ttft = time.Since(input.Req.StartTime) ttftRecorded = true timer.Stop() // Time from HTTP completion to first token = actual model processing/queue time modelProcessingTime := time.Since(httpCompletedAt) - c.Log.Infow("First token received", + newlog.Infow("First token received", "ttft_ms", ttft.Milliseconds(), "preprocessing_ms", preprocessingTime.Milliseconds(), "http_duration_ms", httpDuration.Milliseconds(), "model_processing_ms", modelProcessingTime.Milliseconds()) } - // Handle Chat/Completions [DONE] - if token == "data: [DONE]" { + jsonData := strings.TrimPrefix(token, "data: ") + + if jsonData == "[DONE]" { hasDone = true break scanner } - // Extract the JSON part - jsonData := strings.TrimPrefix(token, "data: ") var rawMessage json.RawMessage err := json.Unmarshal([]byte(jsonData), &rawMessage) if err != nil { - c.Log.Warnw("failed unmarshaling streamed data", "error", err, "token", token) + newlog.Warnw("failed unmarshaling streamed data", "error", err, "token", token) continue } responses = append(responses, rawMessage) } } - // Always collect response content since saving decision is made in ProcessOpenaiRequest responseJSON, err := json.Marshal(responses) if err == nil { responseContent = string(responseJSON) } if !hasDone && ctx.Err() == nil { - c.Log.Errorw("encountered streaming error - no [DONE] marker", + 
newlog.Errorw("encountered streaming error - no [DONE] marker", "error", errors.New("[DONE] not found"), "responses_received", len(responses), "ttft_recorded", ttftRecorded, "timeout_occurred", timeoutOccurred.Load()) - metrics.ErrorCount.WithLabelValues(modelLabel, req.Endpoint, fmt.Sprintf("%d", req.UserID), "streaming_no_done").Inc() + metrics.ErrorCount.WithLabelValues(modelLabel, input.Req.Endpoint, fmt.Sprintf("%d", input.Req.UserID), "streaming_no_done").Inc() } if !hasDone && ctx.Err() != nil { - c.Log.Warnw("streaming incomplete due to context cancellation", + newlog.Warnw("streaming incomplete due to context cancellation", "context_error", ctx.Err(), "responses_received", len(responses), "ttft_recorded", ttftRecorded, "timeout_occurred", timeoutOccurred.Load(), - "total_elapsed_ms", time.Since(req.StartTime).Milliseconds(), + "total_elapsed_ms", time.Since(input.Req.StartTime).Milliseconds(), "time_spent_in_http_ms", httpDuration.Milliseconds(), "time_spent_streaming_ms", time.Since(httpCompletedAt).Milliseconds()) } if err := reader.Err(); err != nil && !errors.Is(err, context.Canceled) { - c.Log.Errorw("encountered streaming error", "error", err) - metrics.ErrorCount.WithLabelValues(modelLabel, req.Endpoint, fmt.Sprintf("%d", req.UserID), "streaming").Inc() + newlog.Errorw("encountered streaming error", "error", err) + metrics.ErrorCount.WithLabelValues(modelLabel, input.Req.Endpoint, fmt.Sprintf("%d", input.Req.UserID), "streaming").Inc() } } - if !req.Stream && !canceled { // Handle non-streaming response + if !input.Req.Stream && !canceled { // Handle non-streaming response bodyBytes, err := io.ReadAll(res.Body) hasDone = true if err != nil { hasDone = false } if err != nil && ctx.Err() == nil { - c.Log.Warnw("Failed to read non-streaming response body", "error", err) - metrics.ErrorCount.WithLabelValues(modelLabel, req.Endpoint, fmt.Sprintf("%d", req.UserID), "query_model").Inc() + newlog.Warnw("Failed to read non-streaming response body", "error", 
err) + metrics.ErrorCount.WithLabelValues(modelLabel, input.Req.Endpoint, fmt.Sprintf("%d", input.Req.UserID), "query_model").Inc() return nil, &shared.RequestError{StatusCode: 500, Err: errors.New("failed to read response body")} } responseContent = string(bodyBytes) - // For non-streaming, write the entire response body at once and set Content-Type. - c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSON) - if ctx.Err() == nil { - if _, err := c.Response().Write(bodyBytes); err != nil { - c.Log.Errorw("Failed to write non-streaming response to client", "error", err) - } - } // Calculate timing breakdown - ttft = time.Since(req.StartTime) + ttft = time.Since(input.Req.StartTime) } resInfo := &shared.ResponseInfo{ - Canceled: c.Request().Context().Err() == context.Canceled, + Canceled: input.Ctx.Err() == context.Canceled, Completed: hasDone, - TotalTime: time.Since(req.StartTime), + TotalTime: time.Since(input.Req.StartTime), TimeToFirstToken: ttft, ResponseContent: responseContent, ModelID: modelMetadata.ModelID, @@ -292,7 +291,7 @@ func (im *InferenceManager) QueryModels(c *setup.Context, req *shared.RequestInf } // Log final request state - c.Log.Infow("Request completed", + newlog.Infow("Request completed", "completed", resInfo.Completed, "canceled", resInfo.Canceled, "ttft_ms", ttft.Milliseconds(), diff --git a/internal/handlers/inference/inference_usage.go b/internal/handlers/inference/inference_usage.go new file mode 100644 index 0000000..ed30939 --- /dev/null +++ b/internal/handlers/inference/inference_usage.go @@ -0,0 +1,74 @@ +package inference + +import ( + "errors" + "fmt" + + "sybil-api/internal/shared" +) + +// Helper function to safely extract float64 values from a map +func getTokenCount(usageData map[string]any, field string) (uint64, error) { + value, ok := usageData[field] + if !ok { + return 0, fmt.Errorf("missing %s field", field) + } + floatVal, ok := value.(float64) + if !ok { + return 0, fmt.Errorf("invalid type for %s 
field", field) + } + return uint64(floatVal), nil +} + +// Helper function to safely extract usage data from response +func extractUsageData(response map[string]any, endpoint string) (*shared.Usage, error) { + usageData, ok := response["usage"].(map[string]any) + if !ok { + return nil, errors.New("missing or invalid usage data") + } + + var promptTokens, completionTokens, totalTokens uint64 + var err error + + // Handle Responses API format (input_tokens, output_tokens) + if endpoint == shared.ENDPOINTS.RESPONSES { + promptTokens, err = getTokenCount(usageData, "input_tokens") + if err != nil { + return nil, fmt.Errorf("error getting input tokens: %w", err) + } + + completionTokens, err = getTokenCount(usageData, "output_tokens") + if err != nil { + return nil, fmt.Errorf("error getting output tokens: %w", err) + } + + totalTokens = promptTokens + completionTokens + } else { + // Handle Chat/Completions format (prompt_tokens, completion_tokens) + promptTokens, err = getTokenCount(usageData, "prompt_tokens") + if err != nil { + return nil, fmt.Errorf("error getting prompt tokens: %w", err) + } + + completionTokens = uint64(0) + if endpoint != shared.ENDPOINTS.EMBEDDING { + completionTokens, err = getTokenCount(usageData, "completion_tokens") + if err != nil { + return nil, fmt.Errorf("error getting completion tokens: %w", err) + } + } + + totalTokens, err = getTokenCount(usageData, "total_tokens") + if err != nil { + return nil, fmt.Errorf("error getting total tokens: %w", err) + } + } + + return &shared.Usage{ + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: totalTokens, + IsCanceled: false, + }, nil +} + diff --git a/internal/routes/inference/models.go b/internal/handlers/inference/models.go similarity index 78% rename from internal/routes/inference/models.go rename to internal/handlers/inference/models.go index 0ca5b05..19a3a3f 100644 --- a/internal/routes/inference/models.go +++ b/internal/handlers/inference/models.go @@ -5,11 
+5,9 @@ import ( "database/sql" "encoding/json" "fmt" - "sybil-api/internal/setup" - "sybil-api/internal/shared" "time" - "github.com/labstack/echo/v4" + "sybil-api/internal/shared" ) type Model struct { @@ -62,70 +60,55 @@ type ModelMetadata struct { MaxInputLength *int `json:"max_input_length,omitempty"` } -func (im *InferenceManager) Models(cc echo.Context) error { - c := cc.(*setup.Context) - - ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second) - defer cancel() - - models, err := fetchModels(ctx, im.RDB, c.User, c) - if err != nil { - c.Log.Errorw("Failed to get models", "error", err.Error()) - return cc.String(500, "Failed to get models") - } - - return c.JSON(200, ModelList{ - Data: models, - }) -} +func (im *InferenceHandler) ListModels(ctx context.Context, userID *uint64, logfields map[string]string) ([]Model, error) { + log := logWithFields(im.Log, logfields) -func fetchModels(ctx context.Context, db *sql.DB, user *shared.UserMetadata, c *setup.Context) ([]Model, error) { - baseQuery := ` - ` - switch true { - case user != nil: - if models, err := queryModels(ctx, db, c, baseQuery+` - SELECT name, DATE_FORMAT(created_at, '%Y-%m-%d %H:%i:%s') as created, - icpt, ocpt, crc, metadata, modality, supported_endpoints - FROM model - WHERE enabled = true AND allowed_user_id = ? - ORDER BY name ASC`, - user.UserID); err == nil && len(models) > 0 { - return models, nil - } else if err != nil { - c.Log.Warnw("Error querying user-specific models, falling back to public", "error", err.Error()) - } - fallthrough - default: - // public models - return queryModels(ctx, db, c, ` + if userID != nil { + userModels, err := im.queryModels(ctx, logfields, ` SELECT name, DATE_FORMAT(created_at, '%Y-%m-%d %H:%i:%s') as created, icpt, ocpt, crc, metadata, modality, supported_endpoints FROM model - WHERE enabled = true AND allowed_user_id is NULL - ORDER BY name ASC`) + WHERE enabled = true AND allowed_user_id = ? 
+ ORDER BY name ASC`, *userID) + + if err != nil { + log.Warnw("Error querying user-specific models, falling back to public", "error", err.Error()) + } else if len(userModels) > 0 { + return userModels, nil + } } + + return im.queryModels(ctx, logfields, ` + SELECT name, DATE_FORMAT(created_at, '%Y-%m-%d %H:%i:%s') as created, + icpt, ocpt, crc, metadata, modality, supported_endpoints + FROM model + WHERE enabled = true AND allowed_user_id is NULL + ORDER BY name ASC`) } -func queryModels(ctx context.Context, db *sql.DB, c *setup.Context, query string, args ...any) ([]Model, error) { - rows, err := db.QueryContext(ctx, query, args...) +func (im *InferenceHandler) queryModels(ctx context.Context, logfields map[string]string, query string, args ...any) ([]Model, error) { + log := logWithFields(im.Log, logfields) + + rows, err := im.RDB.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer rows.Close() + defer func() { + _ = rows.Close() + }() var models []Model for rows.Next() { model, err := scanModel(rows) if err != nil { - c.Log.Warnw("Failed to scan model row", "error", err.Error()) + log.Warnw("Failed to scan model row", "error", err.Error()) continue } models = append(models, model) } if err := rows.Err(); err != nil { - c.Log.Errorw("Error iterating over rows", "error", err.Error()) + log.Errorw("Error iterating over rows", "error", err.Error()) return nil, err } diff --git a/internal/routes/search/autocomplete.go b/internal/handlers/search/autocomplete.go similarity index 100% rename from internal/routes/search/autocomplete.go rename to internal/handlers/search/autocomplete.go diff --git a/internal/routes/search/google.go b/internal/handlers/search/google.go similarity index 87% rename from internal/routes/search/google.go rename to internal/handlers/search/google.go index c67d3a2..34524c0 100644 --- a/internal/routes/search/google.go +++ b/internal/handlers/search/google.go @@ -4,7 +4,6 @@ package search import ( "context" 
"encoding/json" - "errors" "fmt" "net/url" "strconv" @@ -28,40 +27,19 @@ type SearchManager struct { QueryInference InferenceFunc } -func NewSearchManager(queryInference InferenceFunc) (*SearchManager, []error) { - var errs []error - - googleSearchEngineID, err := shared.SafeEnv("GOOGLE_SEARCH_ENGINE_ID") - if err != nil { - errs = append(errs, err) - } - - googleAPIKey, err := shared.SafeEnv("GOOGLE_API_KEY") - if err != nil { - errs = append(errs, err) - } - - googleACURL, err := shared.SafeEnv("GOOGLE_AC_URL") - if err != nil { - errs = append(errs, err) - } - googleService, err := customsearch.NewService(context.Background(), option.WithAPIKey(googleAPIKey)) +func NewSearchManager(queryInference InferenceFunc, gseid, gapikey, gacurl string) (*SearchManager, error) { + googleService, err := customsearch.NewService(context.Background(), option.WithAPIKey(gapikey)) if err != nil { - return nil, []error{errors.New("failed to connect to google service")} - } - - if len(errs) != 0 { - return nil, errs + return nil, fmt.Errorf("failed to connect to google service: %s", err) } return &SearchManager{ - GoogleSearchEngineID: googleSearchEngineID, - GoogleAPIKey: googleAPIKey, + GoogleSearchEngineID: gseid, + GoogleAPIKey: gapikey, GoogleService: googleService, - GoogleACURL: googleACURL, + GoogleACURL: gacurl, QueryInference: queryInference, }, nil - } func QueryGoogleSearch(googleService *customsearch.Service, Log *zap.SugaredLogger, googleSearchEngineID string, query string, page int, searchType ...string) (*shared.SearchResponseBody, error) { diff --git a/internal/routes/search/images.go b/internal/handlers/search/images.go similarity index 100% rename from internal/routes/search/images.go rename to internal/handlers/search/images.go diff --git a/internal/routes/search/search.go b/internal/handlers/search/search.go similarity index 100% rename from internal/routes/search/search.go rename to internal/handlers/search/search.go diff --git 
a/internal/routes/search/sources.go b/internal/handlers/search/sources.go similarity index 100% rename from internal/routes/search/sources.go rename to internal/handlers/search/sources.go diff --git a/internal/routes/targon/create.go b/internal/handlers/targon/create.go similarity index 100% rename from internal/routes/targon/create.go rename to internal/handlers/targon/create.go diff --git a/internal/routes/targon/delete.go b/internal/handlers/targon/delete.go similarity index 100% rename from internal/routes/targon/delete.go rename to internal/handlers/targon/delete.go diff --git a/internal/routes/targon/targon.go b/internal/handlers/targon/targon.go similarity index 100% rename from internal/routes/targon/targon.go rename to internal/handlers/targon/targon.go diff --git a/internal/routes/targon/update.go b/internal/handlers/targon/update.go similarity index 100% rename from internal/routes/targon/update.go rename to internal/handlers/targon/update.go diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index aaec2f6..1d9668e 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -1,8 +1,10 @@ -// Package auth defines middleware route based authentication -package auth +// Package middleware defines middleware route based authentication +package middleware import ( "database/sql" + "errors" + "sync" "sybil-api/internal/setup" "sybil-api/internal/shared" @@ -12,21 +14,42 @@ import ( "go.uber.org/zap" ) -type UserManager struct { +type UserMiddleware struct { redis *redis.Client rdb *sql.DB log *zap.SugaredLogger } -func NewUserManager(r *redis.Client, rdb *sql.DB, log *zap.SugaredLogger) *UserManager { - return &UserManager{ +var ( + userManager *UserMiddleware + userManagerMutex sync.Mutex +) + +func InitUserMiddleware(r *redis.Client, rdb *sql.DB, log *zap.SugaredLogger) { + userManagerMutex.Lock() + defer userManagerMutex.Unlock() + um := NewUserMiddleware(r, rdb, log) + userManager = um +} + +func GetUserMiddleware() 
(*UserMiddleware, error) { + userManagerMutex.Lock() + defer userManagerMutex.Unlock() + if userManager == nil { + return nil, errors.New("user manager not initialized") + } + return userManager, nil +} + +func NewUserMiddleware(r *redis.Client, rdb *sql.DB, log *zap.SugaredLogger) *UserMiddleware { + return &UserMiddleware{ redis: r, rdb: rdb, log: log, } } -func (u *UserManager) ExtractUser(next echo.HandlerFunc) echo.HandlerFunc { +func (u *UserMiddleware) ExtractUser(next echo.HandlerFunc) echo.HandlerFunc { return func(cc echo.Context) error { c := cc.(*setup.Context) c.User = nil @@ -45,7 +68,7 @@ func (u *UserManager) ExtractUser(next echo.HandlerFunc) echo.HandlerFunc { } } -func (u *UserManager) RequireUser(next echo.HandlerFunc) echo.HandlerFunc { +func (u *UserMiddleware) RequireUser(next echo.HandlerFunc) echo.HandlerFunc { return func(cc echo.Context) error { c := cc.(*setup.Context) if c.User == nil { @@ -55,7 +78,7 @@ func (u *UserManager) RequireUser(next echo.HandlerFunc) echo.HandlerFunc { } } -func (u *UserManager) RequireAdmin(next echo.HandlerFunc) echo.HandlerFunc { +func (u *UserMiddleware) RequireAdmin(next echo.HandlerFunc) echo.HandlerFunc { return func(cc echo.Context) error { c := cc.(*setup.Context) if c.User == nil || c.User.Role != "ADMIN" { diff --git a/internal/middleware/metrics.go b/internal/middleware/metrics.go new file mode 100644 index 0000000..c4ee4bf --- /dev/null +++ b/internal/middleware/metrics.go @@ -0,0 +1,48 @@ +package middleware + +import ( + "fmt" + "time" + + "sybil-api/internal/metrics" + "sybil-api/internal/setup" + "sybil-api/internal/shared" + + "github.com/aidarkhanov/nanoid" + "github.com/labstack/echo/v4" + emw "github.com/labstack/echo/v4/middleware" + "go.uber.org/zap" +) + +func NewTrackMiddleware(log *zap.SugaredLogger) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + reqID, _ := nanoid.Generate("0123456789abcdefghijklmnopqrstuvwxyz", 
28) + logger := log.With( + "request_id", "req_"+reqID, + ) + logger = logger.With("externalid", c.Request().Header.Get("X-Dippy-Request-Id")) + + cc := &setup.Context{Context: c, Log: logger, Reqid: reqID} + start := time.Now() + err := next(cc) + duration := time.Since(start) + cc.Log.Infow("end_of_request", "status_code", fmt.Sprintf("%d", cc.Response().Status), "duration", duration.String()) + metrics.ResponseCodes.WithLabelValues(cc.Path(), fmt.Sprintf("%d", cc.Response().Status)).Inc() + return err + } + } +} + +func NewRecoverMiddleware(log *zap.SugaredLogger) echo.MiddlewareFunc { + return emw.RecoverWithConfig(emw.RecoverConfig{ + StackSize: 1 << 10, // 1 KB + LogErrorFunc: func(c echo.Context, err error, stack []byte) error { + defer func() { + _ = log.Sync() + }() + log.Errorw("Api Panic", "error", err.Error()) + return c.String(500, shared.ErrInternalServerError.Err.Error()) + }, + }) +} diff --git a/internal/middleware/users.go b/internal/middleware/users.go index 2916a49..2bbf32a 100644 --- a/internal/middleware/users.go +++ b/internal/middleware/users.go @@ -1,4 +1,4 @@ -package auth +package middleware import ( "context" @@ -9,7 +9,7 @@ import ( "sybil-api/internal/shared" ) -func (u *UserManager) getUserMetadataFromKey(apiKey string, ctx context.Context) (*shared.UserMetadata, error) { +func (u *UserMiddleware) getUserMetadataFromKey(apiKey string, ctx context.Context) (*shared.UserMetadata, error) { var userMetadata shared.UserMetadata userMetadata.APIKey = apiKey diff --git a/internal/routers/admin.go b/internal/routers/admin.go new file mode 100644 index 0000000..f69af94 --- /dev/null +++ b/internal/routers/admin.go @@ -0,0 +1,31 @@ +package routers + +import ( + "database/sql" + + "sybil-api/internal/middleware" + "sybil-api/internal/handlers/targon" + + "github.com/labstack/echo/v4" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +func RegisterAdminRoutes(e *echo.Group, wdb *sql.DB, rdb *sql.DB, redisClient *redis.Client, log 
*zap.SugaredLogger) error { + targonManager, err := targon.NewTargonManager(wdb, rdb, redisClient, log) + if err != nil { + return err + } + umw, err := middleware.GetUserMiddleware() + if err != nil { + return err + } + + requireAdmin := e.Group("", umw.ExtractUser, umw.RequireAdmin) + + requireAdmin.POST("/models", targonManager.CreateModel) + requireAdmin.DELETE("/models/:uid", targonManager.DeleteModel) + requireAdmin.PATCH("/models", targonManager.UpdateModel) + + return nil +} diff --git a/internal/routers/inference.go b/internal/routers/inference.go new file mode 100644 index 0000000..05ce15c --- /dev/null +++ b/internal/routers/inference.go @@ -0,0 +1,278 @@ +// Package routers +package routers + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "net/http" + "time" + + inferenceRoute "sybil-api/internal/handlers/inference" + "sybil-api/internal/middleware" + "sybil-api/internal/setup" + "sybil-api/internal/shared" + + "github.com/labstack/echo/v4" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +type InferenceRouter struct { + ih *inferenceRoute.InferenceHandler +} + +func RegisterInferenceRoutes(e *echo.Group, wdb *sql.DB, rdb *sql.DB, redisClient *redis.Client, log *zap.SugaredLogger, debug bool) (func(), error) { + inferenceManager, inferenceErr := inferenceRoute.NewInferenceHandler(wdb, rdb, redisClient, log, debug) + if inferenceErr != nil { + return nil, inferenceErr + } + umw, err := middleware.GetUserMiddleware() + if err != nil { + return nil, err + } + + inferenceRouter := InferenceRouter{ih: inferenceManager} + + v1 := e.Group("v1") + extractUser := v1.Group("", umw.ExtractUser) + requireUser := v1.Group("", umw.ExtractUser, umw.RequireUser) + + extractUser.GET("/models", inferenceRouter.GetModels) + requireUser.POST("/chat/completions", inferenceRouter.ChatRequest) + requireUser.POST("/completions", inferenceRouter.CompletionRequest) + requireUser.POST("/embeddings", 
inferenceRouter.EmbeddingRequest) + requireUser.POST("/responses", inferenceRouter.ResponsesRequest) + requireUser.POST("/chat/history/new", inferenceRouter.CompletionRequestNewHistory) + requireUser.PATCH("/chat/history/:history_id", inferenceRouter.UpdateHistory) + return inferenceManager.ShutDown, nil +} + +type ModelList struct { + Data []inferenceRoute.Model `json:"data"` +} + +func (ir *InferenceRouter) GetModels(cc echo.Context) error { + c := cc.(*setup.Context) + + ctx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second) + defer cancel() + + logfields := map[string]string{ + "endpoint": "models", + } + if c.User != nil { + logfields["user_id"] = fmt.Sprintf("%d", c.User.UserID) + } + + var userID *uint64 + if c.User != nil { + userID = &c.User.UserID + } + + models, err := ir.ih.ListModels(ctx, userID, logfields) + if err != nil { + c.Log.Errorw("Failed to get models", "error", err.Error()) + return cc.String(500, "Failed to get models") + } + + return c.JSON(200, ModelList{ + Data: models, + }) +} + +func (ir *InferenceRouter) ChatRequest(cc echo.Context) error { + _, err := ir.Inference(cc, shared.ENDPOINTS.CHAT) + return err +} + +func (ir *InferenceRouter) CompletionRequest(cc echo.Context) error { + _, err := ir.Inference(cc, shared.ENDPOINTS.COMPLETION) + return err +} + +func (ir *InferenceRouter) EmbeddingRequest(cc echo.Context) error { + _, err := ir.Inference(cc, shared.ENDPOINTS.EMBEDDING) + return err +} + +func (ir *InferenceRouter) ResponsesRequest(cc echo.Context) error { + _, err := ir.Inference(cc, shared.ENDPOINTS.RESPONSES) + return err +} + +func (ir *InferenceRouter) Inference(cc echo.Context, endpoint string) (string, error) { + c := cc.(*Context) + body, err := readRequestBody(c) + if err != nil { + return "", c.JSON(http.StatusBadRequest, shared.OpenAIError{ + Message: "failed to read request body", + Object: "error", + Type: "BadRequest", + Code: http.StatusBadRequest, + }) + } + + logfields := buildLogFields(c, 
endpoint, nil) + + reqInfo, preErr := ir.ih.Preprocess(inferenceRoute.PreprocessInput{ + Body: body, + User: *c.User, + Endpoint: endpoint, + RequestID: c.Reqid, + LogFields: logfields, + }) + if preErr != nil { + message := "inference error" + if preErr.Err != nil { + message = preErr.Err.Error() + } + return "", c.JSON(preErr.StatusCode, shared.OpenAIError{ + Message: message, + Object: "error", + Type: "InternalError", + Code: preErr.StatusCode, + }) + } + + var streamCallback func(token string) error + if reqInfo.Stream { + setupSSEHeaders(c) + streamCallback = createStreamCallback(c) + } + + out, reqErr := ir.ih.DoInference(inferenceRoute.InferenceInput{ + Req: reqInfo, + User: *c.User, + Ctx: c.Request().Context(), + LogFields: logfields, + StreamWriter: streamCallback, // Pass the callback for real-time streaming + }) + + if reqErr != nil { + if reqErr.StatusCode >= 500 && reqErr.Err != nil { + c.Log.Warnw("Inference error", "error", reqErr.Err.Error()) + } + message := "inference error" + if reqErr.Err != nil { + message = reqErr.Err.Error() + } + + if reqInfo.Stream { + c.Log.Errorw("Error after streaming started", "error", message) + return "", nil + } + + return "", c.JSON(reqErr.StatusCode, shared.OpenAIError{ + Message: message, + Object: "error", + Type: "InternalError", + Code: reqErr.StatusCode, + }) + } + + if out == nil { + return "", nil + } + + if out.Stream { + return string(out.FinalResponse), nil + } + + c.Response().Header().Set("Content-Type", "application/json") + c.Response().WriteHeader(http.StatusOK) + if _, err := c.Response().Write(out.FinalResponse); err != nil { + c.Log.Errorw("Failed to write response", "error", err) + return "", err + } + return string(out.FinalResponse), nil +} + +func (ir *InferenceRouter) CompletionRequestNewHistory(cc echo.Context) error { + c := cc.(*Context) + + body, err := readRequestBody(c) + if err != nil { + return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to read request body"}) 
+ } + + logfields := buildLogFields(c, shared.ENDPOINTS.CHAT, nil) + + setupSSEHeaders(c) + streamCallback := createStreamCallback(c) + + output, err := ir.ih.CompletionRequestNewHistoryLogic(&inferenceRoute.NewHistoryInput{ + Body: body, + User: *c.User, + RequestID: c.Reqid, + Ctx: c.Request().Context(), + LogFields: logfields, + StreamWriter: streamCallback, // Pass callback for real-time streaming + }) + if err != nil { + c.Log.Errorw("History creation failed", "error", err) + return nil + } + + if output.Error != nil { + c.Log.Errorw("History logic error", "error", output.Error.Message) + return nil + } + + _, _ = fmt.Fprintf(c.Response(), "data: %s\n\n", output.HistoryIDJSON) + c.Response().Flush() + + if !output.Stream && len(output.FinalResponse) > 0 { + _, _ = fmt.Fprintf(c.Response(), "data: %s\n\n", string(output.FinalResponse)) + c.Response().Flush() + } + + return nil +} + +type UpdateHistoryRequest struct { + Messages []shared.ChatMessage `json:"messages,omitempty"` +} + +// UpdateHistory is the HTTP handler wrapper for the history update logic +func (ir *InferenceRouter) UpdateHistory(cc echo.Context) error { + c := cc.(*Context) + + body, err := readRequestBody(c) + if err != nil { + return c.JSON(http.StatusInternalServerError, shared.ErrInternalServerError) + } + + var req UpdateHistoryRequest + if err := json.Unmarshal(body, &req); err != nil { + c.Log.Errorw("Failed to unmarshal request body", "error", err.Error()) + return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid JSON format"}) + } + + historyID := c.Param("history_id") + + logfields := buildLogFields(c, shared.ENDPOINTS.CHAT, map[string]string{"history_id": historyID}) + + output, err := ir.ih.UpdateHistoryLogic(&inferenceRoute.UpdateHistoryInput{ + HistoryID: historyID, + Messages: req.Messages, + UserID: c.User.UserID, + Ctx: c.Request().Context(), + LogFields: logfields, + }) + if err != nil { + c.Log.Errorw("History update failed", "error", err) + return 
c.JSON(http.StatusInternalServerError, shared.ErrInternalServerError) + } + + if output.Error != nil { + return c.JSON(output.Error.StatusCode, map[string]string{"error": output.Error.Message}) + } + + return c.JSON(http.StatusOK, map[string]any{ + "message": output.Message, + "id": output.HistoryID, + "user_id": output.UserID, + }) +} diff --git a/internal/routers/routers.go b/internal/routers/routers.go new file mode 100644 index 0000000..cde6169 --- /dev/null +++ b/internal/routers/routers.go @@ -0,0 +1,60 @@ +package routers + +import ( + "fmt" + "io" + "maps" + "net/http" + + "sybil-api/internal/shared" + + "github.com/labstack/echo/v4" + "go.uber.org/zap" +) + +type Context struct { + echo.Context + Log *zap.SugaredLogger + Reqid string + User *shared.UserMetadata +} + +func readRequestBody(c *Context) ([]byte, error) { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + c.Log.Errorw("Failed to read request body", "error", err.Error()) + return nil, err + } + return body, nil +} + +func buildLogFields(c *Context, endpoint string, extras map[string]string) map[string]string { + fields := map[string]string{ + "endpoint": endpoint, + "user_id": fmt.Sprintf("%d", c.User.UserID), + "request_id": c.Reqid, + } + maps.Copy(fields, extras) + return fields +} + +func setupSSEHeaders(c *Context) { + c.Response().Header().Set("Content-Type", "text/event-stream") + c.Response().Header().Set("Cache-Control", "no-cache") + c.Response().Header().Set("Connection", "keep-alive") + c.Response().WriteHeader(http.StatusOK) +} + +func createStreamCallback(c *Context) func(token string) error { + return func(token string) error { + if c.Request().Context().Err() != nil { + return c.Request().Context().Err() + } + _, err := fmt.Fprintf(c.Response(), "%s\n\n", token) + if err != nil { + return err + } + c.Response().Flush() + return nil + } +} diff --git a/internal/routes/inference/history.go b/internal/routes/inference/history.go deleted file mode 100644 index 
23f6917..0000000 --- a/internal/routes/inference/history.go +++ /dev/null @@ -1,365 +0,0 @@ -package inference - -import ( - "database/sql" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "sybil-api/internal/setup" - "sybil-api/internal/shared" - "time" - - "github.com/aidarkhanov/nanoid" - "github.com/labstack/echo/v4" - "go.uber.org/zap" -) - -type CreateHistoryRequest struct { - Messages []shared.ChatMessage `json:"messages"` -} - -type UpdateHistoryRequest struct { - Messages []shared.ChatMessage `json:"messages,omitempty"` -} - -func (im *InferenceManager) CompletionRequestNewHistory(cc echo.Context) error { - c := cc.(*setup.Context) - - body, err := io.ReadAll(c.Request().Body) - if err != nil { - c.Log.Errorw("Failed to read request body", "error", err.Error()) - return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to read request body"}) - } - - var payload shared.InferenceBody - if err := json.Unmarshal(body, &payload); err != nil { - c.Log.Errorw("Failed to parse request body", "error", err.Error()) - return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid JSON format"}) - } - - if len(payload.Messages) == 0 { - return c.JSON(http.StatusBadRequest, map[string]string{"error": "messages are required"}) - } - - messages := payload.Messages - - historyIDNano, err := nanoid.Generate("0123456789abcdefghijklmnopqrstuvwxyz", 11) - if err != nil { - c.Log.Errorw("Failed to generate history nanoid", "error", err) - return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to generate history ID"}) - } - historyID := "chat-" + historyIDNano - - var title *string - for _, msg := range messages { - if msg.Role == "user" && msg.Content != "" { - titleStr := msg.Content - if len(titleStr) > 32 { - titleStr = titleStr[:32] - } - title = &titleStr - break - } - } - - messagesJSON, err := json.Marshal(messages) - if err != nil { - c.Log.Errorw("Failed to marshal initial messages", "error", err) - 
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to prepare history"}) - } - - insertQuery := ` - INSERT INTO chat_history ( - user_id, - history_id, - messages, - title, - icon - ) VALUES (?, ?, ?, ?, ?) - ` - - _, err = im.WDB.Exec(insertQuery, - c.User.UserID, - historyID, - string(messagesJSON), - title, - nil, // icon - ) - if err != nil { - c.Log.Errorw("Failed to insert history into database", "error", err) - return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create history"}) - } - - c.Log.Infow("Chat history created", "history_id", historyID, "user_id", c.User.UserID) - - c.Response().Header().Set("Content-Type", "text/event-stream") - historyIDEvent := map[string]any{ - "type": "history_id", - "id": historyID, - } - historyIDJSON, _ := json.Marshal(historyIDEvent) - fmt.Fprintf(c.Response(), "data: %s\n\n", string(historyIDJSON)) - c.Response().Flush() - - c.Request().Body = io.NopCloser(strings.NewReader(string(body))) - - responseContent, err := im.CompletionRequestHistory(c) - - statusCode := c.Response().Status - if statusCode >= 400 { - c.Log.Warnw("Not updating history due to error status code", "status_code", statusCode) - if c.Response().Committed { - return nil - } - if err != nil { - return err - } - return nil - } - - if err != nil { - if c.Response().Committed { - return nil - } - return err - } - - var allMessages []shared.ChatMessage - allMessages = append(allMessages, messages...) 
- - if content := extractContentFromResponse(responseContent); content != "" { - allMessages = append(allMessages, shared.ChatMessage{ - Role: "assistant", - Content: content, - }) - } - - allMessagesJSON, err := json.Marshal(allMessages) - if err != nil { - c.Log.Errorw("Failed to marshal complete messages", "error", err) - return nil - } - - go func(userID uint64, historyID string, messagesJSON []byte, log *zap.SugaredLogger) { - updateQuery := ` - UPDATE chat_history - SET messages = ?, updated_at = NOW() - WHERE history_id = ? - ` - - _, err := im.WDB.Exec(updateQuery, string(messagesJSON), historyID) - if err != nil { - log.Errorw("Failed to update history in database", "error", err, "history_id", historyID) - return - } - - log.Infow("Chat history updated with assistant response", "history_id", historyID, "user_id", userID) - - if err := im.updateUserStreak(userID, log); err != nil { - log.Errorw("Failed to update user streak", "error", err, "user_id", userID) - } - }(c.User.UserID, historyID, allMessagesJSON, c.Log) - - return nil -} - -func (im *InferenceManager) UpdateHistory(cc echo.Context) error { - c := cc.(*setup.Context) - - historyIDStr := c.Param("history_id") - - var userID uint64 - checkQuery := `SELECT user_id FROM chat_history WHERE history_id = ?` - err := im.RDB.QueryRowContext(c.Request().Context(), checkQuery, historyIDStr).Scan(&userID) - if err != nil { - if err == sql.ErrNoRows { - c.Log.Errorw("History not found", "error", err.Error(), "history_id", historyIDStr) - return c.JSON(http.StatusNotFound, map[string]string{"error": "history not found"}) - } - c.Log.Errorw("Failed to check history", "error", err.Error(), "history_id", historyIDStr) - return c.JSON(http.StatusInternalServerError, shared.ErrInternalServerError) - } - - if userID != c.User.UserID { - c.Log.Errorw("Unauthorized access to history", "history_id", historyIDStr, "user_id", c.User.UserID, "owner_id", userID) - return c.JSON(http.StatusForbidden, 
map[string]string{"error": "unauthorized"}) - } - - var req UpdateHistoryRequest - body, err := io.ReadAll(c.Request().Body) - if err != nil { - c.Log.Errorw("Failed to read request body", "error", err.Error()) - return c.JSON(http.StatusInternalServerError, shared.ErrInternalServerError) - } - - if err := json.Unmarshal(body, &req); err != nil { - c.Log.Errorw("Failed to unmarshal request body", "error", err.Error()) - return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid JSON format"}) - } - - c.Log.Infow("Updating chat history", - "history_id", historyIDStr, - "user_id", c.User.UserID) - - if len(req.Messages) == 0 { - return c.JSON(http.StatusBadRequest, map[string]string{"error": "messages cannot be empty"}) - } - - messagesJSON, err := json.Marshal(req.Messages) - if err != nil { - c.Log.Errorw("Failed to marshal messages", "error", err) - return c.JSON(http.StatusInternalServerError, shared.ErrInternalServerError) - } - - updateQuery := ` - UPDATE chat_history - SET messages = ?, updated_at = NOW() - WHERE history_id = ? 
- ` - - _, err = im.WDB.ExecContext(c.Request().Context(), updateQuery, string(messagesJSON), historyIDStr) - if err != nil { - c.Log.Errorw("Failed to update history in database", - "error", err.Error(), - "history_id", historyIDStr) - return c.JSON(http.StatusInternalServerError, shared.ErrInternalServerError) - } - - c.Log.Infow("Successfully updated chat history", - "history_id", historyIDStr, - "user_id", c.User.UserID) - - go func(userID uint64, log *zap.SugaredLogger) { - if err := im.updateUserStreak(userID, log); err != nil { - log.Errorw("Failed to update user streak", "error", err, "user_id", userID) - } - }(c.User.UserID, c.Log) - - return c.JSON(http.StatusOK, map[string]any{ - "message": "History updated successfully", - "id": historyIDStr, - "user_id": c.User.UserID, - }) -} - -func extractContentFromResponse(responseContent string) string { - if responseContent == "" { - return "" - } - return extractContentFromStreamingResponse(responseContent) -} - -func extractContentFromStreamingResponse(responseContent string) string { - var chunks []shared.Response - if err := json.Unmarshal([]byte(responseContent), &chunks); err != nil { - return "" - } - - var fullContent strings.Builder - for _, chunk := range chunks { - if len(chunk.Choices) == 0 { - continue - } - - choice := chunk.Choices[0] - if choice.Delta == nil { - continue - } - - if choice.Delta.Content != "" { - fullContent.WriteString(choice.Delta.Content) - } - } - - return fullContent.String() -} - -func (im *InferenceManager) updateUserStreak(userID uint64, log *zap.SugaredLogger) error { - var lastChatStr sql.NullString - var currentStreak uint64 - - err := im.RDB.QueryRow(` - SELECT last_chat, streak - FROM user - WHERE id = ? 
- `, userID).Scan(&lastChatStr, ¤tStreak) - if err != nil { - return fmt.Errorf("failed to get user streak data: %w", err) - } - - var lastChat sql.NullTime - if lastChatStr.Valid && lastChatStr.String != "" { - formats := []string{ - "2006-01-02 15:04:05", - time.RFC3339, - "2006-01-02T15:04:05Z07:00", - "2006-01-02 15:04:05.000000", - } - - var parsedTime time.Time - var parseErr error - for _, format := range formats { - parsedTime, parseErr = time.Parse(format, lastChatStr.String) - if parseErr == nil { - lastChat = sql.NullTime{Time: parsedTime, Valid: true} - break - } - } - - if parseErr != nil { - log.Warnw("Failed to parse last_chat timestamp", "error", parseErr, "value", lastChatStr.String) - } - } - - now := time.Now() - todayMidnight := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) - - var newStreak uint64 - updateStreak := false - - if lastChat.Valid { - lastChatDate := lastChat.Time - lastChatMidnight := time.Date(lastChatDate.Year(), lastChatDate.Month(), lastChatDate.Day(), 0, 0, 0, 0, lastChatDate.Location()) - - if !todayMidnight.Equal(lastChatMidnight) { - updateStreak = true - expectedDate := lastChatMidnight.AddDate(0, 0, 1) - if todayMidnight.Equal(expectedDate) { - newStreak = currentStreak + 1 - } else { - newStreak = 1 - } - } else { - newStreak = currentStreak - } - } else { - updateStreak = true - newStreak = 1 - } - - if updateStreak { - _, err = im.WDB.Exec(` - UPDATE user - SET streak = ?, last_chat = ? - WHERE id = ? - `, newStreak, now, userID) - if err != nil { - return fmt.Errorf("failed to update user streak: %w", err) - } - - log.Infow("Updated user streak", "user_id", userID, "streak", newStreak, "last_chat", now) - } else { - _, err = im.WDB.Exec(` - UPDATE user - SET last_chat = ? - WHERE id = ? 
- `, now, userID) - if err != nil { - return fmt.Errorf("failed to update last_chat: %w", err) - } - } - - return nil -} diff --git a/internal/routes/inference/preprocess.go b/internal/routes/inference/preprocess.go deleted file mode 100644 index ffe5007..0000000 --- a/internal/routes/inference/preprocess.go +++ /dev/null @@ -1,446 +0,0 @@ -package inference - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "slices" - "sync" - "time" - - "sybil-api/internal/setup" - "sybil-api/internal/shared" - - "github.com/labstack/echo/v4" -) - -func (im *InferenceManager) ChatRequest(c echo.Context) error { - _, err := im.ProcessOpenaiRequest(c, shared.ENDPOINTS.CHAT) - return err -} - -func (im *InferenceManager) CompletionRequest(c echo.Context) error { - _, err := im.ProcessOpenaiRequest(c, shared.ENDPOINTS.COMPLETION) - return err -} - -func (im *InferenceManager) CompletionRequestHistory(c echo.Context) (string, error) { - return im.ProcessOpenaiRequest(c, shared.ENDPOINTS.CHAT) -} - -func (im *InferenceManager) EmbeddingRequest(c echo.Context) error { - _, err := im.ProcessOpenaiRequest(c, shared.ENDPOINTS.EMBEDDING) - return err -} - -func (im *InferenceManager) ResponsesRequest(c echo.Context) error { - _, err := im.ProcessOpenaiRequest(c, shared.ENDPOINTS.RESPONSES) - return err -} - -func (im *InferenceManager) preprocessOpenAIRequest( - c *setup.Context, - endpoint string, -) (*shared.RequestInfo, *shared.RequestError) { - startTime := time.Now() - - userInfo := c.User - - // Ensure properly formatted request - body, _ := io.ReadAll(c.Request().Body) - - // Unmarshal to generic map to set defaults - var payload map[string]any - err := json.Unmarshal(body, &payload) - if err != nil { - c.Log.Warnw("failed json unmarshal to payload map", "error", err.Error()) - return nil, &shared.RequestError{StatusCode: 400, Err: errors.New("malformed request")} - } - - // validate models and set defaults - model, ok := payload["model"] - if !ok { - c.Log.Infow("missing 
model parameter", "error", "model is required") - return nil, &shared.RequestError{StatusCode: 400, Err: errors.New("model is required")} - } - - modelName := model.(string) - - // Add model and endpoint to logger context for all subsequent logs - c.Log = c.Log.With("model", modelName, "endpoint", endpoint) - - if endpoint == shared.ENDPOINTS.EMBEDDING { - - input, ok := payload["input"] - if !ok { - return nil, &shared.RequestError{ - StatusCode: 400, - Err: errors.New("input is required for embeddings"), - } - } - - switch v := input.(type) { - case string: - if v == "" { - return nil, &shared.RequestError{ - StatusCode: 400, - Err: errors.New("input cannot be empty"), - } - } - case []any: - if len(v) == 0 { - return nil, &shared.RequestError{ - StatusCode: 400, - Err: errors.New("input array cannot be empty"), - } - } - default: - return nil, &shared.RequestError{ - StatusCode: 400, - Err: errors.New("input must be string or array of strings"), - } - } - - if (userInfo.Credits == 0 && userInfo.PlanRequests == 0) && !userInfo.AllowOverspend { - c.Log.Infow("No credits available", "user_id", userInfo.UserID) - return nil, &shared.RequestError{ - StatusCode: 402, - Err: errors.New("insufficient credits"), - } - } - - body, err = json.Marshal(payload) - if err != nil { - c.Log.Errorw("Failed to marshal request body", "error", err.Error()) - return nil, &shared.RequestError{StatusCode: 500, Err: errors.New("internal server error")} - } - - return &shared.RequestInfo{ - Body: body, - UserID: userInfo.UserID, - Credits: userInfo.Credits, - ID: c.Reqid, - StartTime: startTime, - Endpoint: endpoint, - Model: modelName, - Stream: false, - }, nil - } - - if endpoint == shared.ENDPOINTS.RESPONSES { - input, ok := payload["input"] - if !ok { - return nil, &shared.RequestError{ - StatusCode: 400, - Err: errors.New("input is required for responses"), - } - } - - inputArray, ok := input.([]any) - if !ok { - return nil, &shared.RequestError{ - StatusCode: 400, - Err: 
errors.New("input must be an array"), - } - } - - if len(inputArray) == 0 { - return nil, &shared.RequestError{ - StatusCode: 400, - Err: errors.New("input array cannot be empty"), - } - } - } - - if (userInfo.Credits == 0 && userInfo.PlanRequests == 0) && !userInfo.AllowOverspend { - c.Log.Warnw("Insufficient credits or requests", - "credits", userInfo.Credits, - "plan_requests", userInfo.PlanRequests, - "allow_overspend", userInfo.AllowOverspend) - return nil, &shared.RequestError{ - StatusCode: 402, - Err: errors.New("insufficient requests or credits"), - } - } - - // Set stream default if not specified - if val, ok := payload["stream"]; !ok || val == nil { - payload["stream"] = shared.DefaultStreamOption - } - - stream := payload["stream"].(bool) - - // Add stream to logger context - c.Log = c.Log.With("stream", stream) - - // If streaming is enabled (either by default or explicitly), include usage data - if stream { - payload["stream_options"] = map[string]any{ - "include_usage": true, - } - } - - // Log user id 3's request parameters - if userInfo.UserID == 3 { - c.Log.Infow("User 3 request payload", - "model", modelName, - "stream", stream, - "max_tokens", payload["max_tokens"], - "temperature", payload["temperature"], - "top_p", payload["top_p"], - "frequency_penalty", payload["frequency_penalty"], - "presence_penalty", payload["presence_penalty"]) - } - - // repackage body - body, err = json.Marshal(payload) - if err != nil { - c.Log.Errorw("Failed to marshal request body", "error", err.Error()) - return nil, &shared.RequestError{StatusCode: 500, Err: errors.New("internal server error")} - } - - reqInfo := &shared.RequestInfo{ - Body: body, - UserID: userInfo.UserID, - Credits: userInfo.Credits, - ID: c.Reqid, - StartTime: startTime, - Endpoint: endpoint, - Model: modelName, - Stream: stream, - } - - return reqInfo, nil -} - -func (im *InferenceManager) ProcessOpenaiRequest(cc echo.Context, endpoint string) (string, error) { - c := cc.(*setup.Context) - - 
// Add endpoint to logger context - c.Log = c.Log.With("endpoint", endpoint) - - reqInfo, preprocessError := im.preprocessOpenAIRequest(c, endpoint) - if preprocessError != nil { - if preprocessError.StatusCode >= 500 { - c.Log.Warnw("Preprocess error", "error", preprocessError.Err.Error()) - } - return "", c.String(preprocessError.StatusCode, preprocessError.Error()) - } - - im.usageCache.AddInFlightToBucket(reqInfo.UserID) - - // ensure we remove inflight BEFORE we add this to a bucket - mu := sync.Mutex{} - mu.Lock() - defer func() { - im.usageCache.RemoveInFlightFromBucket(reqInfo.UserID) - mu.Unlock() - }() - - resInfo, qerr := im.QueryModels(c, reqInfo) - if qerr != nil { - c.Log.Warnw("QueryModels failed", - "error", qerr.Error(), - "status_code", qerr.StatusCode) - - /* TODO: Revisit overload logic - if qerr.StatusCode == 502 { - overload.TrackTPS( - c.Core, - c.ModelDNS, - 1, - ) - } */ - - return "", c.JSON(qerr.StatusCode, shared.OpenAIError{ - Message: qerr.Error(), - Object: "error", - Type: "InternalError", - Code: qerr.StatusCode, - }) - } - - // Extract usage data from the response content - if resInfo.ResponseContent == "" || !resInfo.Completed { - c.Log.Errorw("No response or incomplete response from model", - "response_content_length", len(resInfo.ResponseContent), - "completed", resInfo.Completed, - "canceled", resInfo.Canceled, - "ttft", resInfo.TimeToFirstToken, - "total_time", resInfo.TotalTime) - _ = c.JSON(500, shared.OpenAIError{ - Message: "no response from model", - Object: "error", - Type: "InternalError", - Code: 500, - }) - } - - // Asynchronously process request and return to the user - log := c.Log - go func() { - switch true { - case !resInfo.Completed: - break - case reqInfo.Stream: - var chunks []map[string]any - err := json.Unmarshal([]byte(resInfo.ResponseContent), &chunks) - if err != nil { - log.Errorw( - "Failed to unmarshal streaming ResponseContent as JSON array of chunks", - "error", - err, - "raw_response_content", - 
resInfo.ResponseContent, - ) - break - } - slices.Reverse(chunks) - for i, chunk := range chunks { - usageData, usageFieldExists := chunk["usage"] - if usageFieldExists && usageData != nil { - if extractedUsage, extractErr := extractUsageData(chunk, endpoint); extractErr == nil { - resInfo.Usage = extractedUsage - break - } - log.Warnw( - "Failed to extract usage data from a response chunk that had a non-null usage field", - "chunk_index", - i, - ) - break - } - } - case !reqInfo.Stream: - // Not a streaming request, expect a single JSON object - var singleResponse map[string]any - err := json.Unmarshal([]byte(resInfo.ResponseContent), &singleResponse) - if err != nil { - log.Errorw( - "Failed to unmarshal non-streaming ResponseContent as single JSON object", - "error", - err, - "raw_response_content", - resInfo.ResponseContent, - ) - break - } - usageData, usageFieldExists := singleResponse["usage"] - if usageFieldExists && usageData != nil { - if extractedUsage, extractErr := extractUsageData(singleResponse, endpoint); extractErr == nil { - resInfo.Usage = extractedUsage - break - } - log.Warnw( - "Failed to extract usage data from single response object that had a non-null usage field", - ) - } - default: - break - } - - // Ensure resInfo.Usage is not nil before saving (this is a good fallback) - if resInfo.Usage == nil { - resInfo.Usage = &shared.Usage{IsCanceled: resInfo.Canceled} - } - - totalCredits := shared.CalculateCredits(resInfo.Usage, resInfo.Cost.InputCredits, resInfo.Cost.OutputCredits, resInfo.Cost.CanceledCredits) - - pqi := &shared.ProcessedQueryInfo{ - UserID: reqInfo.UserID, - Model: reqInfo.Model, - ModelID: resInfo.ModelID, - Endpoint: reqInfo.Endpoint, - TotalTime: resInfo.TotalTime, - TimeToFirstToken: resInfo.TimeToFirstToken, - Usage: resInfo.Usage, - Cost: resInfo.Cost, - TotalCredits: totalCredits, - ResponseContent: resInfo.ResponseContent, - RequestContent: reqInfo.Body, - CreatedAt: time.Now(), - ID: reqInfo.ID, - } - - /* TODO: ditto 
- if resInfo.Completed { - overload.TrackTPS( - core, - modelDNS, - float64(resInfo.Usage.CompletionTokens)/resInfo.TotalTime.Seconds(), - ) - } - */ - mu.Lock() - im.usageCache.AddRequestToBucket(reqInfo.UserID, pqi, reqInfo.ID) - mu.Unlock() - }() - - return resInfo.ResponseContent, nil -} - -// Helper function to safely extract float64 values from a map -func getTokenCount(usageData map[string]any, field string) (uint64, error) { - value, ok := usageData[field] - if !ok { - return 0, fmt.Errorf("missing %s field", field) - } - floatVal, ok := value.(float64) - if !ok { - return 0, fmt.Errorf("invalid type for %s field", field) - } - return uint64(floatVal), nil -} - -// Helper function to safely extract usage data from response -func extractUsageData(response map[string]any, endpoint string) (*shared.Usage, error) { - usageData, ok := response["usage"].(map[string]any) - if !ok { - return nil, errors.New("missing or invalid usage data") - } - - var promptTokens, completionTokens, totalTokens uint64 - var err error - - // Handle Responses API format (input_tokens, output_tokens) - if endpoint == shared.ENDPOINTS.RESPONSES { - promptTokens, err = getTokenCount(usageData, "input_tokens") - if err != nil { - return nil, fmt.Errorf("error getting input tokens: %w", err) - } - - completionTokens, err = getTokenCount(usageData, "output_tokens") - if err != nil { - return nil, fmt.Errorf("error getting output tokens: %w", err) - } - - totalTokens = promptTokens + completionTokens - } else { - // Handle Chat/Completions format (prompt_tokens, completion_tokens) - promptTokens, err = getTokenCount(usageData, "prompt_tokens") - if err != nil { - return nil, fmt.Errorf("error getting prompt tokens: %w", err) - } - - completionTokens = uint64(0) - if endpoint != shared.ENDPOINTS.EMBEDDING { - completionTokens, err = getTokenCount(usageData, "completion_tokens") - if err != nil { - return nil, fmt.Errorf("error getting completion tokens: %w", err) - } - } - - totalTokens, err 
= getTokenCount(usageData, "total_tokens") - if err != nil { - return nil, fmt.Errorf("error getting total tokens: %w", err) - } - } - - return &shared.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: totalTokens, - IsCanceled: false, - }, nil -} diff --git a/internal/setup/setup.go b/internal/setup/setup.go index 00bfa3d..93bd7a4 100644 --- a/internal/setup/setup.go +++ b/internal/setup/setup.go @@ -2,17 +2,9 @@ package setup import ( - "context" - "database/sql" - "errors" - "fmt" - "strconv" - "sybil-api/internal/shared" - "github.com/google/uuid" "github.com/labstack/echo/v4" - "github.com/redis/go-redis/v9" "go.uber.org/zap" ) @@ -22,117 +14,3 @@ type Context struct { Reqid string User *shared.UserMetadata } - -type Core struct { - Env Environment - RedisClient *redis.Client - WDB *sql.DB - RDB *sql.DB - Log *zap.SugaredLogger - Debug bool -} - -type Environment struct { - InstanceUUID string - MetricsAPIKey string -} - -func (c *Core) Shutdown() { - if c.RedisClient != nil { - _ = c.RedisClient.Close() - } - if c.WDB != nil { - _ = c.WDB.Close() - } - if c.RDB != nil { - _ = c.RDB.Close() - } -} - -func CreateCore() (*Core, []error) { - var errs []error - - DSN, err := shared.SafeEnv("DSN") - if err != nil { - errs = append(errs, err) - } - readDSN, err := shared.SafeEnv("READ_DSN") - if err != nil { - errs = append(errs, err) - } - - metricsAPIKey, err := shared.SafeEnv("METRICS_API_KEY") - if err != nil { - errs = append(errs, err) - } - - redisHost := shared.GetEnv("REDIS_HOST", "cache") - redisPort := shared.GetEnv("REDIS_PORT", "6379") - - instanceUUID := uuid.New().String() - DEBUG, err := strconv.ParseBool(shared.GetEnv("DEBUG", "false")) - if err != nil { - errs = append(errs, err) - } - - if len(errs) != 0 { - return nil, errs - } - - // Load PrimaryDB connections - sqlClient, err := sql.Open("mysql", DSN) - if err != nil { - return nil, []error{errors.New("failed initializing sqlClient"), err} - } - err = 
sqlClient.Ping() - if err != nil { - return nil, []error{errors.New("failed ping to sql db"), err} - } - - // Load Read Replica DB connection - readSQLClient, err := sql.Open("mysql", readDSN) - if err != nil { - return nil, []error{errors.New("failed initializing readSqlClient"), err} - } - err = readSQLClient.Ping() - if err != nil { - return nil, []error{errors.New("failed to ping read replica sql db"), err} - } - - // Load Redis connection - redisClient := redis.NewClient(&redis.Options{ - Addr: fmt.Sprintf("%s:%s", redisHost, redisPort), - Password: "", - DB: 0, - }) - if err := redisClient.Ping(context.Background()).Err(); err != nil { - return nil, []error{errors.New("failed ping to redis db"), err} - } - - var logger *zap.Logger - if !DEBUG { - logger, err = zap.NewProduction() - if err != nil { - panic("Failed init logger") - } - } - if DEBUG { - logger, err = zap.NewDevelopment() - if err != nil { - panic("Failed init logger") - } - } - log := logger.Sugar() - - return &Core{ - Debug: DEBUG, - Log: log, - Env: Environment{ - InstanceUUID: instanceUUID, - MetricsAPIKey: metricsAPIKey, - }, - RedisClient: redisClient, - RDB: readSQLClient, - WDB: sqlClient, - }, nil -}