Skip to content

Commit

Permalink
feat: auto load into memory on startup
Browse files Browse the repository at this point in the history
Signed-off-by: Sertac Ozercan <sozercan@gmail.com>
  • Loading branch information
sozercan committed Sep 22, 2024
1 parent 1f43678 commit 0af60a5
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 213 deletions.
2 changes: 1 addition & 1 deletion core/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {
modelFile := backendConfig.Model

grpcOpts := gRPCModelOpts(backendConfig)
grpcOpts := GRPCModelOpts(backendConfig)

var inferenceModel interface{}
var err error
Expand Down
2 changes: 1 addition & 1 deletion core/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
if *threads == 0 && appConfig.Threads != 0 {
threads = &appConfig.Threads
}
gRPCOpts := gRPCModelOpts(backendConfig)
gRPCOpts := GRPCModelOpts(backendConfig)
opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(backendConfig.Backend),
model.WithAssetDir(appConfig.AssetsDestination),
Expand Down
2 changes: 1 addition & 1 deletion core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
if *threads == 0 && o.Threads != 0 {
threads = &o.Threads
}
grpcOpts := gRPCModelOpts(c)
grpcOpts := GRPCModelOpts(c)

var inferenceModel grpc.Backend
var err error
Expand Down
2 changes: 1 addition & 1 deletion core/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func getSeed(c config.BackendConfig) int32 {
return seed
}

func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
func GRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
b := 512
if c.Batch != 0 {
b = c.Batch
Expand Down
2 changes: 1 addition & 1 deletion core/backend/rerank.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func Rerank(backend, modelFile string, request *proto.RerankRequest, loader *mod
return nil, fmt.Errorf("backend is required")
}

grpcOpts := gRPCModelOpts(backendConfig)
grpcOpts := GRPCModelOpts(backendConfig)

opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
Expand Down
2 changes: 1 addition & 1 deletion core/backend/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func SoundGeneration(
return "", nil, fmt.Errorf("backend is a required parameter")
}

grpcOpts := gRPCModelOpts(backendConfig)
grpcOpts := GRPCModelOpts(backendConfig)
opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(backend),
model.WithModel(modelFile),
Expand Down
2 changes: 1 addition & 1 deletion core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func ModelTTS(
bb = model.PiperBackend
}

grpcOpts := gRPCModelOpts(backendConfig)
grpcOpts := GRPCModelOpts(backendConfig)

opts := modelOpts(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
Expand Down
2 changes: 2 additions & 0 deletions core/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type RunCMD struct {
WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"`
Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"`
DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"`
LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"`
}

func (r *RunCMD) Run(ctx *cliContext.Context) error {
Expand Down Expand Up @@ -104,6 +105,7 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error {
config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet),
config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints),
config.WithP2PNetworkID(r.Peer2PeerNetworkID),
config.WithLoadToMemory(r.LoadToMemory),
}

token := ""
Expand Down
7 changes: 7 additions & 0 deletions core/config/application_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type ApplicationConfig struct {
DisableApiKeyRequirementForHttpGet bool
HttpGetExemptedEndpoints []*regexp.Regexp
DisableGalleryEndpoint bool
LoadToMemory []string

ModelLibraryURL string

Expand Down Expand Up @@ -331,6 +332,12 @@ func WithOpaqueErrors(opaque bool) AppOption {
}
}

// WithLoadToMemory sets the list of model names that should be loaded
// into memory when the application starts, storing them on the
// ApplicationConfig for later use at startup.
func WithLoadToMemory(models []string) AppOption {
return func(o *ApplicationConfig) {
o.LoadToMemory = models
}
}

func WithSubtleKeyComparison(subtle bool) AppOption {
return func(o *ApplicationConfig) {
o.UseSubtleKeyComparison = subtle
Expand Down
Loading

0 comments on commit 0af60a5

Please sign in to comment.