Skip to content

Commit

Permalink
#51: Build routers & models based on provided config
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-glushko committed Jan 1, 2024
1 parent ae3cf1b commit 65405ef
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 91 deletions.
81 changes: 81 additions & 0 deletions pkg/providers/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package providers

import (
"errors"
"fmt"
"glide/pkg/providers/openai"
"glide/pkg/telemetry"
"time"
)

var (
ErrProviderNotFound = errors.New("provider not found")
)

type LangModelConfig struct {
ID string `yaml:"id"`
Enabled bool `yaml:"enabled"`
Timeout *time.Duration `yaml:"timeout,omitempty"`
OpenAI *openai.Config
// Add other providers like
// Cohere *cohere.Config
// Anthropic *anthropic.Config
}

func DefaultLangModelConfig() *LangModelConfig {
defaultTimeout := 10 * time.Second

return &LangModelConfig{
Enabled: true,
Timeout: &defaultTimeout,
}
}

func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (LanguageModel, error) {
if c.OpenAI != nil {
client, err := openai.NewClient(c.OpenAI, tel)

if err != nil {
return nil, fmt.Errorf("error initing openai client: %v", err)
}

return client, nil
}

return nil, ErrProviderNotFound
}

func (m *LangModelConfig) validateOneProvider() error {
providersConfigured := 0

if m.OpenAI != nil {
providersConfigured++
}

// check other providers here
if providersConfigured == 0 {
return fmt.Errorf("exactly one provider must be cofigured for model \"%v\", none is configured", m.ID)
}

if providersConfigured > 1 {
return fmt.Errorf(
"exactly one provider must be cofigured for model \"%v\", %v are configured",
m.ID,
providersConfigured,
)
}

return nil
}

func (m *LangModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
*m = *DefaultLangModelConfig()

type plain LangModelConfig // to avoid recursion

if err := unmarshal((*plain)(m)); err != nil {
return err
}

return m.validateOneProvider()
}
12 changes: 0 additions & 12 deletions pkg/providers/language.go

This file was deleted.

11 changes: 10 additions & 1 deletion pkg/providers/provider.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
package providers

import "errors"
import (
"context"
"errors"
"glide/pkg/api/schemas"
)

var ErrProviderUnavailable = errors.New("provider is not available")

// ModelProvider defines an interface all model providers should support
type ModelProvider interface {
Provider() string
}

// LanguageModel defines the interface a provider should fulfill to be able to serve language chat requests
type LanguageModel interface {
Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error)
}
54 changes: 5 additions & 49 deletions pkg/routers/config.go
Original file line number Diff line number Diff line change
@@ -1,63 +1,19 @@
package routers

import (
"fmt"

"glide/pkg/providers/openai"
"glide/pkg/providers"
"glide/pkg/routers/strategy"
)

type Config struct {
LanguageRouters []LangRouterConfig `yaml:"language"`
}

type LangModel struct {
ID string `yaml:"id"`
TimeoutMs *int `yaml:"timeout_ms,omitempty"` // TODO: try to use Duration to bring more flexibility
OpenAI *openai.Config
// Add other providers like
// Cohere *cohere.Config
// Anthropic *anthropic.Config
}

func (m *LangModel) validateOneProvider() error {
providersConfigured := 0

if m.OpenAI != nil {
providersConfigured++
}

// check other providers here
if providersConfigured == 0 {
return fmt.Errorf("exactly one provider must be cofigured for model \"%v\", none is configured", m.ID)
}

if providersConfigured > 1 {
return fmt.Errorf(
"exactly one provider must be cofigured for model \"%v\", %v are configured",
m.ID,
providersConfigured,
)
}

return nil
}

func (m *LangModel) UnmarshalYAML(unmarshal func(interface{}) error) error {
type plain LangModel // to avoid recursion

if err := unmarshal((*plain)(m)); err != nil {
return err
}

return m.validateOneProvider()
}

type LangRouterConfig struct {
ID string `yaml:"id"`
Enabled bool `yaml:"enabled"`
RoutingStrategy strategy.RoutingStrategy `yaml:"strategy"`
Models []LangModel `yaml:"models"`
ID string `yaml:"id"`
Enabled bool `yaml:"enabled"`
RoutingStrategy strategy.RoutingStrategy `yaml:"strategy"`
Models []providers.LangModelConfig `yaml:"models"`
}

func DefaultLangRouterConfig() LangRouterConfig {
Expand Down
64 changes: 47 additions & 17 deletions pkg/routers/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,68 @@ package routers
import (
"errors"

"go.uber.org/multierr"
"go.uber.org/zap"

"glide/pkg/telemetry"
)

var ErrRouterNotFound = errors.New("no router found with given ID")

type RouterManager struct {
config *Config
telemetry *telemetry.Telemetry
myRouter *LangRouter // TODO: replace by list of routers
config *Config
telemetry *telemetry.Telemetry
langRouters *map[string]*LangRouter
}

// NewManager creates a new instance of Router Manager that creates, holds and returns all routers
func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) {
// TODO: init routers by config
router, err := NewLangRouter(tel)
if err != nil {
return nil, err
}

return &RouterManager{
manager := RouterManager{
config: cfg,
telemetry: tel,
myRouter: router,
}, nil
}

err := manager.BuildRouters(cfg.LanguageRouters)

return &manager, err
}

func (r *RouterManager) BuildRouters(routerConfigs []LangRouterConfig) error {
routers := make(map[string]*LangRouter, len(routerConfigs))

var errs error

for _, routerConfig := range routerConfigs {
if !routerConfig.Enabled {
r.telemetry.Logger.Info("router is disabled, skipping", zap.String("routerID", routerConfig.ID))
continue
}

r.telemetry.Logger.Debug("init router", zap.String("routerID", routerConfig.ID))
router, err := NewLangRouter(&routerConfig, r.telemetry)

if err != nil {
errs = multierr.Append(errs, err)
continue
}

routers[routerConfig.ID] = router
}

if errs != nil {
return errs
}

r.langRouters = &routers

return nil
}

// Get returns a router by type and ID
// GetLangRouter returns a router by type and ID
func (r *RouterManager) GetLangRouter(routerID string) (*LangRouter, error) {
// TODO: implement actual logic
if routerID != "myrouter" {
return nil, ErrRouterNotFound
if router, found := (*r.langRouters)[routerID]; found {
return router, nil
}

return r.myRouter, nil
return nil, ErrRouterNotFound
}
67 changes: 55 additions & 12 deletions pkg/routers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,73 @@ package routers

import (
"context"
"errors"
"glide/pkg/providers"
"glide/pkg/providers/factory"
"go.uber.org/multierr"
"go.uber.org/zap"

"glide/pkg/api/schemas"
"glide/pkg/providers/openai"
"glide/pkg/telemetry"
)

var (
ErrNoModels = errors.New("no models configured for router")
)

type LangRouter struct {
openAIClient *openai.Client // TODO: replace by actual model list
telemetry *telemetry.Telemetry
config *LangRouterConfig
models []providers.LanguageModel
telemetry *telemetry.Telemetry
}

func NewLangRouter(tel *telemetry.Telemetry) (*LangRouter, error) {
openAIClient, err := openai.NewClient(openai.DefaultConfig(), tel)
if err != nil {
return nil, err
func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter, error) {
router := &LangRouter{
config: cfg,
telemetry: tel,
}

return &LangRouter{
openAIClient: openAIClient,
telemetry: tel,
}, nil
err := router.BuildModels(cfg.Models)

return router, err
}

func (r *LangRouter) BuildModels(modelConfigs []providers.LangModelConfig) error {
var errs error
models := make([]providers.LanguageModel, 0, len(modelConfigs))

for _, modelConfig := range modelConfigs {
if !modelConfig.Enabled {
r.telemetry.Logger.Info("model is disabled, skipping", zap.String("modelID", modelConfig.ID))
continue
}

r.telemetry.Logger.Debug("init lang model", zap.String("modelID", modelConfig.ID))

model, err := factory.NewModelFromConfig(modelConfig, r.telemetry)

if err != nil {
errs = multierr.Append(errs, err)
continue
}

models = append(models, model)
}

if errs != nil {
return errs
}

r.models = models

return nil
}

func (r *LangRouter) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) (*schemas.UnifiedChatResponse, error) {
if len(r.models) == 0 {
return nil, ErrNoModels
}

// TODO: implement actual routing & fallback logic
return r.openAIClient.Chat(ctx, request)
return r.models[0].Chat(ctx, request)
}

0 comments on commit 65405ef

Please sign in to comment.