Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions core/application/startup.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}

if err := coreStartup.InstallModels(application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}

for _, backend := range options.ExternalBackends {
if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err
Expand Down
3 changes: 2 additions & 1 deletion core/cli/backends.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cli

import (
"context"
"encoding/json"
"fmt"

Expand Down Expand Up @@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}

modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion core/cli/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
}

modelLoader := model.NewModelLoader(systemState, true)
err = startup.InstallModels(galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}
Expand Down
3 changes: 2 additions & 1 deletion core/cli/worker/worker_llamacpp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package worker

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
log.Error().Err(err).Msg("failed loading galleries")
return "", err
}
err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err
Expand Down
19 changes: 13 additions & 6 deletions core/gallery/backends.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package gallery

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -69,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}

// InstallBackendFromGallery installs a backend from the gallery.
func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(systemState)
Expand Down Expand Up @@ -109,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")

// Then, let's install the best backend
if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
return err
}

Expand All @@ -134,10 +135,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
return nil
}

return InstallBackend(systemState, modelLoader, backend, downloadStatus)
return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
}

func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
Expand All @@ -164,11 +165,17 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
}
} else {
uri := downloader.URI(config.URI)
if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
success := false
// Try to download from mirrors
for _, mirror := range config.Mirrors {
if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
// Check for cancellation before trying next mirror
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true
break
}
Expand Down
23 changes: 12 additions & 11 deletions core/gallery/backends_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gallery

import (
"context"
"encoding/json"
"os"
"path/filepath"
Expand Down Expand Up @@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
)
must(err)
sysDefault.GPUVendor = "" // force default selection
backs, err := ListSystemBackends(sysDefault)
backs, err := ListSystemBackends(sysDefault)
must(err)
aliasBack, ok := backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
Expand All @@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
must(err)
sysNvidia.GPUVendor = "nvidia"
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
backs, err = ListSystemBackends(sysNvidia)
backs, err = ListSystemBackends(sysNvidia)
must(err)
aliasBack, ok = backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
Expand Down Expand Up @@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {

Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})

It("should install backend from gallery", func() {
err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
Expand Down Expand Up @@ -298,7 +299,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())

metaBackendPath := filepath.Join(tempDir, "meta-backend")
Expand Down Expand Up @@ -378,7 +379,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())

metaBackendPath := filepath.Join(tempDir, "meta-backend")
Expand Down Expand Up @@ -462,7 +463,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())

metaBackendPath := filepath.Join(tempDir, "meta-backend")
Expand Down Expand Up @@ -561,7 +562,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
Expand Down Expand Up @@ -593,7 +594,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
Expand Down Expand Up @@ -626,7 +627,7 @@ var _ = Describe("Gallery Backends", func() {

Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())

err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
Expand All @@ -647,7 +648,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
err = InstallBackend(systemState, ml, &backend, nil)
err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())

Expand Down
14 changes: 14 additions & 0 deletions core/gallery/gallery.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gallery

import (
"context"
"fmt"
"os"
"path/filepath"
Expand Down Expand Up @@ -28,6 +29,19 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
return config, nil
}

func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
var config T
uri := downloader.URI(url)
err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
return config, err
}
return config, nil
}

func ReadConfigFile[T any](filePath string) (*T, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)
Expand Down
19 changes: 14 additions & 5 deletions core/gallery/models.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gallery

import (
"context"
"errors"
"fmt"
"os"
Expand Down Expand Up @@ -72,6 +73,7 @@ type PromptTemplate struct {

// InstallModelFromGallery installs a model from the gallery.
func InstallModelFromGallery(
ctx context.Context,
modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState,
modelLoader *model.ModelLoader,
Expand All @@ -84,7 +86,7 @@ func InstallModelFromGallery(

if len(model.URL) > 0 {
var err error
config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
Expand Down Expand Up @@ -125,15 +127,15 @@ func InstallModelFromGallery(
return err
}

installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
log.Debug().Msgf("Installed model %q", installedModel.Name)
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)

if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
Expand All @@ -154,7 +156,7 @@ func InstallModelFromGallery(
return applyModel(model)
}

func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
Expand All @@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *

// Download files and verify their SHA
for i, file := range config.Files {
// Check for cancellation before each file
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}

log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

if err := utils.VerifyPath(file.Filename, basePath); err != nil {
Expand All @@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
}
}
uri := downloader.URI(file.URI)
if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
return nil, err
}
}
Expand Down
11 changes: 6 additions & 5 deletions core/gallery/models_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gallery_test

import (
"context"
"errors"
"os"
"path/filepath"
Expand Down Expand Up @@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())

for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
Expand Down Expand Up @@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())

err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())

dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
Expand Down Expand Up @@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())

for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
Expand All @@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())

for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
Expand Down Expand Up @@ -179,7 +180,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
_, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
_, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred())
})
})
Expand Down
2 changes: 1 addition & 1 deletion core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
Expand Down
Loading
Loading