diff --git a/core/application/startup.go b/core/application/startup.go
index 8e2387b9226f..eb387d06debd 100644
--- a/core/application/startup.go
+++ b/core/application/startup.go
@@ -62,12 +62,12 @@ func New(opts ...config.AppOption) (*Application, error) {
}
}
- if err := coreStartup.InstallModels(application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
+ if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil {
log.Error().Err(err).Msg("error installing models")
}
for _, backend := range options.ExternalBackends {
- if err := coreStartup.InstallExternalBackends(options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
+ if err := coreStartup.InstallExternalBackends(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil {
log.Error().Err(err).Msg("error installing external backend")
}
}
diff --git a/core/backend/llm.go b/core/backend/llm.go
index d6c7bc736e93..3cd74d9a4953 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -45,7 +45,7 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
if !slices.Contains(modelNames, c.Name) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
- err := gallery.InstallModelFromGallery(o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
+ err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries)
if err != nil {
log.Error().Err(err).Msgf("failed to install model %q from gallery", modelFile)
//return nil, err
diff --git a/core/cli/backends.go b/core/cli/backends.go
index 666f1eb0be11..0528d76678d8 100644
--- a/core/cli/backends.go
+++ b/core/cli/backends.go
@@ -1,6 +1,7 @@
package cli
import (
+ "context"
"encoding/json"
"fmt"
@@ -102,7 +103,7 @@ func (bi *BackendsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
- err = startup.InstallExternalBackends(galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
+ err = startup.InstallExternalBackends(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias)
if err != nil {
return err
}
diff --git a/core/cli/models.go b/core/cli/models.go
index dd5273317bfc..bcbb60d48828 100644
--- a/core/cli/models.go
+++ b/core/cli/models.go
@@ -135,7 +135,7 @@ func (mi *ModelsInstall) Run(ctx *cliContext.Context) error {
}
modelLoader := model.NewModelLoader(systemState, true)
- err = startup.InstallModels(galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
+ err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName)
if err != nil {
return err
}
diff --git a/core/cli/worker/worker_llamacpp.go b/core/cli/worker/worker_llamacpp.go
index 8a55f2345eee..1b4be6736637 100644
--- a/core/cli/worker/worker_llamacpp.go
+++ b/core/cli/worker/worker_llamacpp.go
@@ -1,6 +1,7 @@
package worker
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -42,7 +43,7 @@ func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (str
log.Error().Err(err).Msg("failed loading galleries")
return "", err
}
- err := gallery.InstallBackendFromGallery(gals, systemState, ml, llamaCPPGalleryName, nil, true)
+ err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true)
if err != nil {
log.Error().Err(err).Msg("llama-cpp backend not found, failed to install it")
return "", err
diff --git a/core/gallery/backends.go b/core/gallery/backends.go
index 34b175aa78c9..aee4b2d93928 100644
--- a/core/gallery/backends.go
+++ b/core/gallery/backends.go
@@ -3,6 +3,7 @@
package gallery
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -69,7 +70,7 @@ func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
}
// InstallBackendFromGallery installs a backend from the gallery.
-func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
+func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
if !force {
// check if we already have the backend installed
backends, err := ListSystemBackends(systemState)
@@ -109,7 +110,7 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
log.Debug().Str("name", name).Str("bestBackend", bestBackend.Name).Msg("Installing backend from meta backend")
// Then, let's install the best backend
- if err := InstallBackend(systemState, modelLoader, bestBackend, downloadStatus); err != nil {
+ if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
return err
}
@@ -134,10 +135,10 @@ func InstallBackendFromGallery(galleries []config.Gallery, systemState *system.S
return nil
}
- return InstallBackend(systemState, modelLoader, backend, downloadStatus)
+ return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
}
-func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
+func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
// Create base path if it doesn't exist
err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
if err != nil {
@@ -164,11 +165,17 @@ func InstallBackend(systemState *system.SystemState, modelLoader *model.ModelLoa
}
} else {
uri := downloader.URI(config.URI)
- if err := uri.DownloadFile(backendPath, "", 1, 1, downloadStatus); err != nil {
+ if err := uri.DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err != nil {
success := false
// Try to download from mirrors
for _, mirror := range config.Mirrors {
- if err := downloader.URI(mirror).DownloadFile(backendPath, "", 1, 1, downloadStatus); err == nil {
+ // Check for cancellation before trying next mirror
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+ if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
success = true
break
}
diff --git a/core/gallery/backends_test.go b/core/gallery/backends_test.go
index 26652caadd05..15900d25018b 100644
--- a/core/gallery/backends_test.go
+++ b/core/gallery/backends_test.go
@@ -1,6 +1,7 @@
package gallery
import (
+ "context"
"encoding/json"
"os"
"path/filepath"
@@ -55,7 +56,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
)
must(err)
sysDefault.GPUVendor = "" // force default selection
- backs, err := ListSystemBackends(sysDefault)
+ backs, err := ListSystemBackends(sysDefault)
must(err)
aliasBack, ok := backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -77,7 +78,7 @@ var _ = Describe("Runtime capability-based backend selection", func() {
must(err)
sysNvidia.GPUVendor = "nvidia"
sysNvidia.VRAM = 8 * 1024 * 1024 * 1024
- backs, err = ListSystemBackends(sysNvidia)
+ backs, err = ListSystemBackends(sysNvidia)
must(err)
aliasBack, ok = backs.Get("llama-cpp")
Expect(ok).To(BeTrue())
@@ -116,13 +117,13 @@ var _ = Describe("Gallery Backends", func() {
Describe("InstallBackendFromGallery", func() {
It("should return error when backend is not found", func() {
- err := InstallBackendFromGallery(galleries, systemState, ml, "non-existent", nil, true)
+ err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true)
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\""))
})
It("should install backend from gallery", func() {
- err := InstallBackendFromGallery(galleries, systemState, ml, "test-backend", nil, true)
+ err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile())
})
@@ -298,7 +299,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
- err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+ err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -378,7 +379,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
- err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+ err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -462,7 +463,7 @@ var _ = Describe("Gallery Backends", func() {
VRAM: 1000000000000,
Backend: system.Backend{BackendsPath: tempDir},
}
- err = InstallBackendFromGallery([]config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
+ err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true)
Expect(err).NotTo(HaveOccurred())
metaBackendPath := filepath.Join(tempDir, "meta-backend")
@@ -561,7 +562,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(newPath),
)
Expect(err).NotTo(HaveOccurred())
- err = InstallBackend(systemState, ml, &backend, nil)
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created
Expect(newPath).To(BeADirectory())
})
@@ -593,7 +594,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
- err = InstallBackend(systemState, ml, &backend, nil)
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json"))
@@ -626,7 +627,7 @@ var _ = Describe("Gallery Backends", func() {
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile())
- err = InstallBackend(systemState, ml, &backend, nil)
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
})
@@ -647,7 +648,7 @@ var _ = Describe("Gallery Backends", func() {
system.WithBackendPath(tempDir),
)
Expect(err).NotTo(HaveOccurred())
- err = InstallBackend(systemState, ml, &backend, nil)
+ err = InstallBackend(context.TODO(), systemState, ml, &backend, nil)
Expect(err).ToNot(HaveOccurred())
Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile())
diff --git a/core/gallery/gallery.go b/core/gallery/gallery.go
index d8dc3100f1a9..62362148ecef 100644
--- a/core/gallery/gallery.go
+++ b/core/gallery/gallery.go
@@ -1,6 +1,7 @@
package gallery
import (
+ "context"
"fmt"
"os"
"path/filepath"
@@ -28,6 +29,19 @@ func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) {
return config, nil
}
+func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) {
+ var config T
+ uri := downloader.URI(url)
+ err := uri.DownloadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error {
+ return yaml.Unmarshal(d, &config)
+ })
+ if err != nil {
+ log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url")
+ return config, err
+ }
+ return config, nil
+}
+
func ReadConfigFile[T any](filePath string) (*T, error) {
// Read the YAML file
yamlFile, err := os.ReadFile(filePath)
diff --git a/core/gallery/models.go b/core/gallery/models.go
index a1abe0ee1182..7205886b633c 100644
--- a/core/gallery/models.go
+++ b/core/gallery/models.go
@@ -1,6 +1,7 @@
package gallery
import (
+ "context"
"errors"
"fmt"
"os"
@@ -72,6 +73,7 @@ type PromptTemplate struct {
// Installs a model from the gallery
func InstallModelFromGallery(
+ ctx context.Context,
modelGalleries, backendGalleries []config.Gallery,
systemState *system.SystemState,
modelLoader *model.ModelLoader,
@@ -84,7 +86,7 @@ func InstallModelFromGallery(
if len(model.URL) > 0 {
var err error
- config, err = GetGalleryConfigFromURL[ModelConfig](model.URL, systemState.Model.ModelsPath)
+ config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
@@ -125,7 +127,7 @@ func InstallModelFromGallery(
return err
}
- installedModel, err := InstallModel(systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
+ installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
@@ -133,7 +135,7 @@ func InstallModelFromGallery(
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
- if err := InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
+ if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -154,7 +156,7 @@ func InstallModelFromGallery(
return applyModel(model)
}
-func InstallModel(systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
+func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) (*lconfig.ModelConfig, error) {
basePath := systemState.Model.ModelsPath
// Create base path if it doesn't exist
err := os.MkdirAll(basePath, 0750)
@@ -168,6 +170,13 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
// Download files and verify their SHA
for i, file := range config.Files {
+ // Check for cancellation before each file
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ default:
+ }
+
log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
if err := utils.VerifyPath(file.Filename, basePath); err != nil {
@@ -185,7 +194,7 @@ func InstallModel(systemState *system.SystemState, nameOverride string, config *
}
}
uri := downloader.URI(file.URI)
- if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
+ if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
return nil, err
}
}
diff --git a/core/gallery/models_test.go b/core/gallery/models_test.go
index 3ae76c203c32..df0bee06ce8e 100644
--- a/core/gallery/models_test.go
+++ b/core/gallery/models_test.go
@@ -1,6 +1,7 @@
package gallery_test
import (
+ "context"
"errors"
"os"
"path/filepath"
@@ -34,7 +35,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
- _, err = InstallModel(systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+ _, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
@@ -88,7 +89,7 @@ var _ = Describe("Model test", func() {
Expect(models[0].URL).To(Equal(bertEmbeddingsURL))
Expect(models[0].Installed).To(BeFalse())
- err = InstallModelFromGallery(galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
+ err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true)
Expect(err).ToNot(HaveOccurred())
dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml"))
@@ -129,7 +130,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
- _, err = InstallModel(systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+ _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -149,7 +150,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
- _, err = InstallModel(systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
+ _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true)
Expect(err).ToNot(HaveOccurred())
for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} {
@@ -179,7 +180,7 @@ var _ = Describe("Model test", func() {
system.WithModelPath(tempdir),
)
Expect(err).ToNot(HaveOccurred())
- _, err = InstallModel(systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
+ _, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true)
Expect(err).To(HaveOccurred())
})
})
diff --git a/core/http/app_test.go b/core/http/app_test.go
index c9c752df578b..2d4ff6d06571 100644
--- a/core/http/app_test.go
+++ b/core/http/app_test.go
@@ -85,7 +85,7 @@ func getModels(url string) ([]gallery.GalleryModel, error) {
response := []gallery.GalleryModel{}
uri := downloader.URI(url)
// TODO: No tests currently seem to exercise file:// urls. Fix?
- err := uri.DownloadWithAuthorizationAndCallback("", bearerKey, func(url string, i []byte) error {
+ err := uri.DownloadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
diff --git a/core/http/routes/ui_api.go b/core/http/routes/ui_api.go
index 15a4769d372c..3ea4852e08dd 100644
--- a/core/http/routes/ui_api.go
+++ b/core/http/routes/ui_api.go
@@ -1,6 +1,7 @@
package routes
import (
+ "context"
"fmt"
"math"
"net/url"
@@ -35,23 +36,35 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
progress := 0
isDeletion := false
isQueued := false
+ isCancelled := false
+ isCancellable := false
message := ""
if status != nil {
- // Skip completed operations
- if status.Processed {
+ // Skip completed operations (unless cancelled and not yet cleaned up)
+ if status.Processed && !status.Cancelled {
+ continue
+ }
+ // Skip cancelled operations that are processed (they're done, no need to show)
+ if status.Processed && status.Cancelled {
continue
}
progress = int(status.Progress)
isDeletion = status.Deletion
+ isCancelled = status.Cancelled
+ isCancellable = status.Cancellable
message = status.Message
if isDeletion {
taskType = "deletion"
}
+ if isCancelled {
+ taskType = "cancelled"
+ }
} else {
// Job is queued but hasn't started
isQueued = true
+ isCancellable = true
message = "Operation queued"
}
@@ -76,16 +89,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
}
operations = append(operations, fiber.Map{
- "id": galleryID,
- "name": displayName,
- "fullName": galleryID,
- "jobID": jobID,
- "progress": progress,
- "taskType": taskType,
- "isDeletion": isDeletion,
- "isBackend": isBackend,
- "isQueued": isQueued,
- "message": message,
+ "id": galleryID,
+ "name": displayName,
+ "fullName": galleryID,
+ "jobID": jobID,
+ "progress": progress,
+ "taskType": taskType,
+ "isDeletion": isDeletion,
+ "isBackend": isBackend,
+ "isQueued": isQueued,
+ "isCancelled": isCancelled,
+ "cancellable": isCancellable,
+ "message": message,
})
}
@@ -108,6 +123,28 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
})
+ // Cancel operation endpoint
+ app.Post("/api/operations/:jobID/cancel", func(c *fiber.Ctx) error {
+ jobID := strings.Clone(c.Params("jobID"))
+ log.Debug().Msgf("API request to cancel operation: %s", jobID)
+
+ err := galleryService.CancelOperation(jobID)
+ if err != nil {
+ log.Error().Err(err).Msgf("Failed to cancel operation: %s", jobID)
+ return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
+ "error": err.Error(),
+ })
+ }
+
+ // Clean up opcache for cancelled operation
+ opcache.DeleteUUID(jobID)
+
+ return c.JSON(fiber.Map{
+ "success": true,
+ "message": "Operation cancelled",
+ })
+ })
+
// Model Gallery APIs
app.Get("/api/models", func(c *fiber.Ctx) error {
term := c.Query("term")
@@ -248,12 +285,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
uid := id.String()
opcache.Set(galleryID, uid)
+ ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uid,
GalleryElementName: galleryID,
Galleries: appConfig.Galleries,
BackendGalleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
}
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.ModelGalleryChannel <- op
}()
@@ -291,13 +333,18 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
opcache.Set(galleryID, uid)
+ ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{
ID: uid,
Delete: true,
GalleryElementName: galleryName,
Galleries: appConfig.Galleries,
BackendGalleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
}
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.ModelGalleryChannel <- op
cl.RemoveModelConfig(galleryName)
@@ -341,7 +388,7 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
})
}
- _, err = gallery.InstallModel(appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
+ _, err = gallery.InstallModel(context.Background(), appConfig.SystemState, model.Name, &config, model.Overrides, nil, false)
if err != nil {
return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{
"error": err.Error(),
@@ -526,11 +573,16 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
uid := id.String()
opcache.Set(backendID, uid)
+ ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryBackend, any]{
ID: uid,
GalleryElementName: backendID,
Galleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
}
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.BackendGalleryChannel <- op
}()
@@ -568,12 +620,17 @@ func RegisterUIAPIRoutes(app *fiber.App, cl *config.ModelConfigLoader, appConfig
opcache.Set(backendID, uid)
+ ctx, cancelFunc := context.WithCancel(context.Background())
op := services.GalleryOp[gallery.GalleryBackend, any]{
ID: uid,
Delete: true,
GalleryElementName: backendName,
Galleries: appConfig.BackendGalleries,
+ Context: ctx,
+ CancelFunc: cancelFunc,
}
+ // Store cancellation function immediately so queued operations can be cancelled
+ galleryService.StoreCancellation(uid, cancelFunc)
go func() {
galleryService.BackendGalleryChannel <- op
}()
diff --git a/core/http/views/partials/inprogress.html b/core/http/views/partials/inprogress.html
index 7ebe04927d1b..8c2dce3baebd 100644
--- a/core/http/views/partials/inprogress.html
+++ b/core/http/views/partials/inprogress.html
@@ -71,15 +71,34 @@
Queued
-
+
+
+
+ Cancelling...
+
+
+
-
-
+
+
+
+
+
+
+
+
+ Cancelled
+
+
@@ -88,8 +107,8 @@
@@ -141,6 +160,57 @@
}
},
+ async cancelOperation(jobID, operationID) {
+ // Check if operation is already cancelled
+ const operation = this.operations.find(op => op.jobID === jobID);
+ if (operation && operation.isCancelled) {
+ // Already cancelled, no need to do anything
+ return;
+ }
+
+ try {
+ const response = await fetch(`/api/operations/${jobID}/cancel`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ });
+
+ if (!response.ok) {
+ const error = await response.json();
+ const errorMessage = error.error || 'Failed to cancel operation';
+
+ // Don't show alert for "already cancelled" - just update UI silently
+ if (errorMessage.includes('already cancelled')) {
+ if (operation) {
+ operation.isCancelled = true;
+ operation.cancellable = false;
+ }
+ this.fetchOperations();
+ return;
+ }
+
+ throw new Error(errorMessage);
+ }
+
+ // Update the operation status immediately
+ if (operation) {
+ operation.isCancelled = true;
+ operation.cancellable = false;
+ operation.message = 'Cancelling...';
+ }
+
+ // Refresh operations to get updated status
+ this.fetchOperations();
+ } catch (error) {
+ console.error('Error cancelling operation:', error);
+ // Only show alert if it's not an "already cancelled" error
+ if (!error.message.includes('already cancelled')) {
+ alert('Failed to cancel operation: ' + error.message);
+ }
+ }
+ },
+
destroy() {
if (this.pollInterval) {
clearInterval(this.pollInterval);
diff --git a/core/services/backends.go b/core/services/backends.go
index 7ffca2ac1030..6eb69bbc1346 100644
--- a/core/services/backends.go
+++ b/core/services/backends.go
@@ -1,6 +1,10 @@
package services
import (
+ "context"
+ "errors"
+ "fmt"
+
"github.com/mudler/LocalAI/core/gallery"
"github.com/mudler/LocalAI/pkg/system"
@@ -10,14 +14,43 @@ import (
func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, any], systemState *system.SystemState) error {
utils.ResetDownloadTimers()
- g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
+
+ // Check if already cancelled
+ if op.Context != nil {
+ select {
+ case <-op.Context.Done():
+ g.UpdateStatus(op.ID, &GalleryOpStatus{
+ Cancelled: true,
+ Processed: true,
+ Message: "cancelled",
+ GalleryElementName: op.GalleryElementName,
+ })
+ return op.Context.Err()
+ default:
+ }
+ }
+
+ g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf("processing backend: %s", op.GalleryElementName), Progress: 0, Cancellable: true})
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
- g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
+ // Check for cancellation during progress updates
+ if op.Context != nil {
+ select {
+ case <-op.Context.Done():
+ return
+ default:
+ }
+ }
+ g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf(processingMessage, fileName, total, current), FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current, Cancellable: true})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
+ ctx := op.Context
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
var err error
if op.Delete {
err = gallery.DeleteBackendFromSystem(g.appConfig.SystemState, op.GalleryElementName)
@@ -25,9 +58,19 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, an
} else {
log.Warn().Msgf("installing backend %s", op.GalleryElementName)
log.Debug().Msgf("backend galleries: %v", g.appConfig.BackendGalleries)
- err = gallery.InstallBackendFromGallery(g.appConfig.BackendGalleries, systemState, g.modelLoader, op.GalleryElementName, progressCallback, true)
+ err = gallery.InstallBackendFromGallery(ctx, g.appConfig.BackendGalleries, systemState, g.modelLoader, op.GalleryElementName, progressCallback, true)
}
if err != nil {
+ // Check if error is due to cancellation
+ if op.Context != nil && errors.Is(err, op.Context.Err()) {
+ g.UpdateStatus(op.ID, &GalleryOpStatus{
+ Cancelled: true,
+ Processed: true,
+ Message: "cancelled",
+ GalleryElementName: op.GalleryElementName,
+ })
+ return err
+ }
log.Error().Err(err).Msgf("error installing backend %s", op.GalleryElementName)
if !op.Delete {
// If we didn't install the backend, we need to make sure we don't have a leftover directory
@@ -42,6 +85,7 @@ func (g *GalleryService) backendHandler(op *GalleryOp[gallery.GalleryBackend, an
Processed: true,
GalleryElementName: op.GalleryElementName,
Message: "completed",
- Progress: 100})
+ Progress: 100,
+ Cancellable: false})
return nil
}
diff --git a/core/services/gallery.go b/core/services/gallery.go
index 2290c450d9ed..8b24be00c9e6 100644
--- a/core/services/gallery.go
+++ b/core/services/gallery.go
@@ -1,88 +1,166 @@
-package services
-
-import (
- "context"
- "fmt"
- "sync"
-
- "github.com/mudler/LocalAI/core/config"
- "github.com/mudler/LocalAI/core/gallery"
- "github.com/mudler/LocalAI/pkg/model"
- "github.com/mudler/LocalAI/pkg/system"
-)
-
-type GalleryService struct {
- appConfig *config.ApplicationConfig
- sync.Mutex
- ModelGalleryChannel chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]
- BackendGalleryChannel chan GalleryOp[gallery.GalleryBackend, any]
-
- modelLoader *model.ModelLoader
- statuses map[string]*GalleryOpStatus
-}
-
-func NewGalleryService(appConfig *config.ApplicationConfig, ml *model.ModelLoader) *GalleryService {
- return &GalleryService{
- appConfig: appConfig,
- ModelGalleryChannel: make(chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]),
- BackendGalleryChannel: make(chan GalleryOp[gallery.GalleryBackend, any]),
- modelLoader: ml,
- statuses: make(map[string]*GalleryOpStatus),
- }
-}
-
-func (g *GalleryService) UpdateStatus(s string, op *GalleryOpStatus) {
- g.Lock()
- defer g.Unlock()
- g.statuses[s] = op
-}
-
-func (g *GalleryService) GetStatus(s string) *GalleryOpStatus {
- g.Lock()
- defer g.Unlock()
-
- return g.statuses[s]
-}
-
-func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
- g.Lock()
- defer g.Unlock()
-
- return g.statuses
-}
-
-func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
- // updates the status with an error
- var updateError func(id string, e error)
- if !g.appConfig.OpaqueErrors {
- updateError = func(id string, e error) {
- g.UpdateStatus(id, &GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
- }
- } else {
- updateError = func(id string, _ error) {
- g.UpdateStatus(id, &GalleryOpStatus{Error: fmt.Errorf("an error occurred"), Processed: true})
- }
- }
-
- go func() {
- for {
- select {
- case <-c.Done():
- return
- case op := <-g.BackendGalleryChannel:
- err := g.backendHandler(&op, systemState)
- if err != nil {
- updateError(op.ID, err)
- }
-
- case op := <-g.ModelGalleryChannel:
- err := g.modelHandler(&op, cl, systemState)
- if err != nil {
- updateError(op.ID, err)
- }
- }
- }
- }()
-
- return nil
-}
+package services
+
+import (
+ "context"
+ "fmt"
+ "sync"
+
+ "github.com/mudler/LocalAI/core/config"
+ "github.com/mudler/LocalAI/core/gallery"
+ "github.com/mudler/LocalAI/pkg/model"
+ "github.com/mudler/LocalAI/pkg/system"
+)
+
+type GalleryService struct {
+ appConfig *config.ApplicationConfig
+ sync.Mutex
+ ModelGalleryChannel chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]
+ BackendGalleryChannel chan GalleryOp[gallery.GalleryBackend, any]
+
+ modelLoader *model.ModelLoader
+ statuses map[string]*GalleryOpStatus
+ cancellations map[string]context.CancelFunc
+}
+
+func NewGalleryService(appConfig *config.ApplicationConfig, ml *model.ModelLoader) *GalleryService {
+ return &GalleryService{
+ appConfig: appConfig,
+ ModelGalleryChannel: make(chan GalleryOp[gallery.GalleryModel, gallery.ModelConfig]),
+ BackendGalleryChannel: make(chan GalleryOp[gallery.GalleryBackend, any]),
+ modelLoader: ml,
+ statuses: make(map[string]*GalleryOpStatus),
+ cancellations: make(map[string]context.CancelFunc),
+ }
+}
+
+func (g *GalleryService) UpdateStatus(s string, op *GalleryOpStatus) {
+ g.Lock()
+ defer g.Unlock()
+ g.statuses[s] = op
+}
+
+func (g *GalleryService) GetStatus(s string) *GalleryOpStatus {
+ g.Lock()
+ defer g.Unlock()
+
+ return g.statuses[s]
+}
+
+func (g *GalleryService) GetAllStatus() map[string]*GalleryOpStatus {
+ g.Lock()
+ defer g.Unlock()
+
+ return g.statuses
+}
+
+// CancelOperation cancels an in-progress operation by its ID
+func (g *GalleryService) CancelOperation(id string) error {
+ g.Lock()
+ defer g.Unlock()
+
+ // Check if operation is already cancelled
+ if status, ok := g.statuses[id]; ok && status.Cancelled {
+ return fmt.Errorf("operation %q is already cancelled", id)
+ }
+
+ cancelFunc, exists := g.cancellations[id]
+ if !exists {
+ return fmt.Errorf("operation %q not found or already completed", id)
+ }
+
+ // Cancel the operation
+ cancelFunc()
+
+ // Update status to reflect cancellation
+ if status, ok := g.statuses[id]; ok {
+ status.Cancelled = true
+ status.Processed = true
+ status.Message = "cancelled"
+ } else {
+ // Create status for queued operations that haven't started yet
+ g.statuses[id] = &GalleryOpStatus{
+ Cancelled: true,
+ Processed: true,
+ Message: "cancelled",
+ Cancellable: false,
+ }
+ }
+
+ // Clean up cancellation function
+ delete(g.cancellations, id)
+
+ return nil
+}
+
+// storeCancellation stores a cancellation function for an operation
+func (g *GalleryService) storeCancellation(id string, cancelFunc context.CancelFunc) {
+ g.Lock()
+ defer g.Unlock()
+ g.cancellations[id] = cancelFunc
+}
+
+// StoreCancellation is a public method to store a cancellation function for an operation
+// This allows cancellation functions to be stored immediately when operations are created,
+// enabling cancellation of queued operations that haven't started processing yet.
+func (g *GalleryService) StoreCancellation(id string, cancelFunc context.CancelFunc) {
+ g.storeCancellation(id, cancelFunc)
+}
+
+// removeCancellation removes a cancellation function when operation completes
+func (g *GalleryService) removeCancellation(id string) {
+ g.Lock()
+ defer g.Unlock()
+ delete(g.cancellations, id)
+}
+
+func (g *GalleryService) Start(c context.Context, cl *config.ModelConfigLoader, systemState *system.SystemState) error {
+ // updates the status with an error
+ var updateError func(id string, e error)
+ if !g.appConfig.OpaqueErrors {
+ updateError = func(id string, e error) {
+ g.UpdateStatus(id, &GalleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()})
+ }
+ } else {
+ updateError = func(id string, _ error) {
+ g.UpdateStatus(id, &GalleryOpStatus{Error: fmt.Errorf("an error occurred"), Processed: true})
+ }
+ }
+
+ go func() {
+ for {
+ select {
+ case <-c.Done():
+ return
+ case op := <-g.BackendGalleryChannel:
+ // Create context if not provided
+ if op.Context == nil {
+ op.Context, op.CancelFunc = context.WithCancel(c)
+ g.storeCancellation(op.ID, op.CancelFunc)
+ } else if op.CancelFunc != nil {
+ g.storeCancellation(op.ID, op.CancelFunc)
+ }
+ err := g.backendHandler(&op, systemState)
+ if err != nil {
+ updateError(op.ID, err)
+ }
+ g.removeCancellation(op.ID)
+
+ case op := <-g.ModelGalleryChannel:
+ // Create context if not provided
+ if op.Context == nil {
+ op.Context, op.CancelFunc = context.WithCancel(c)
+ g.storeCancellation(op.ID, op.CancelFunc)
+ } else if op.CancelFunc != nil {
+ g.storeCancellation(op.ID, op.CancelFunc)
+ }
+ err := g.modelHandler(&op, cl, systemState)
+ if err != nil {
+ updateError(op.ID, err)
+ }
+ g.removeCancellation(op.ID)
+ }
+ }
+ }()
+
+ return nil
+}
diff --git a/core/services/models.go b/core/services/models.go
index b22999f6b977..40ebbc98ee63 100644
--- a/core/services/models.go
+++ b/core/services/models.go
@@ -1,7 +1,10 @@
package services
import (
+ "context"
"encoding/json"
+ "errors"
+ "fmt"
"os"
"github.com/mudler/LocalAI/core/config"
@@ -13,22 +16,74 @@ import (
"gopkg.in/yaml.v2"
)
+const (
+ processingMessage = "processing file: %s. Total: %s. Current: %s"
+)
+
func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel, gallery.ModelConfig], cl *config.ModelConfigLoader, systemState *system.SystemState) error {
utils.ResetDownloadTimers()
- g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", Progress: 0})
+ // Check if already cancelled
+ if op.Context != nil {
+ select {
+ case <-op.Context.Done():
+ g.UpdateStatus(op.ID, &GalleryOpStatus{
+ Cancelled: true,
+ Processed: true,
+ Message: "cancelled",
+ GalleryElementName: op.GalleryElementName,
+ })
+ return op.Context.Err()
+ default:
+ }
+ }
+
+ g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf("processing model: %s", op.GalleryElementName), Progress: 0, Cancellable: true})
// displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) {
- g.UpdateStatus(op.ID, &GalleryOpStatus{Message: "processing", FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
+ // Check for cancellation during progress updates
+ if op.Context != nil {
+ select {
+ case <-op.Context.Done():
+ return
+ default:
+ }
+ }
+ g.UpdateStatus(op.ID, &GalleryOpStatus{Message: fmt.Sprintf(processingMessage, fileName, total, current), FileName: fileName, Progress: percentage, TotalFileSize: total, DownloadedFileSize: current, Cancellable: true})
utils.DisplayDownloadFunction(fileName, current, total, percentage)
}
err := processModelOperation(op, systemState, g.modelLoader, g.appConfig.EnforcePredownloadScans, g.appConfig.AutoloadBackendGalleries, progressCallback)
if err != nil {
+ // Check if error is due to cancellation
+ if op.Context != nil && errors.Is(err, op.Context.Err()) {
+ g.UpdateStatus(op.ID, &GalleryOpStatus{
+ Cancelled: true,
+ Processed: true,
+ Message: "cancelled",
+ GalleryElementName: op.GalleryElementName,
+ })
+ return err
+ }
return err
}
+ // Check for cancellation before final steps
+ if op.Context != nil {
+ select {
+ case <-op.Context.Done():
+ g.UpdateStatus(op.ID, &GalleryOpStatus{
+ Cancelled: true,
+ Processed: true,
+ Message: "cancelled",
+ GalleryElementName: op.GalleryElementName,
+ })
+ return op.Context.Err()
+ default:
+ }
+ }
+
// Reload models
err = cl.LoadModelConfigsFromPath(systemState.Model.ModelsPath)
if err != nil {
@@ -46,26 +101,27 @@ func (g *GalleryService) modelHandler(op *GalleryOp[gallery.GalleryModel, galler
Processed: true,
GalleryElementName: op.GalleryElementName,
Message: "completed",
- Progress: 100})
+ Progress: 100,
+ Cancellable: false})
return nil
}
-func installModelFromRemoteConfig(systemState *system.SystemState, modelLoader *model.ModelLoader, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
- config, err := gallery.GetGalleryConfigFromURL[gallery.ModelConfig](req.URL, systemState.Model.ModelsPath)
+func installModelFromRemoteConfig(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, req gallery.GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool, backendGalleries []config.Gallery) error {
+ config, err := gallery.GetGalleryConfigFromURLWithContext[gallery.ModelConfig](ctx, req.URL, systemState.Model.ModelsPath)
if err != nil {
return err
}
config.Files = append(config.Files, req.AdditionalFiles...)
- installedModel, err := gallery.InstallModel(systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
+ installedModel, err := gallery.InstallModel(ctx, systemState, req.Name, &config, req.Overrides, downloadStatus, enforceScan)
if err != nil {
return err
}
if automaticallyInstallBackend && installedModel.Backend != "" {
- if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
+ if err := gallery.InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil {
return err
}
}
@@ -79,15 +135,16 @@ type galleryModel struct {
}
func processRequests(systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, automaticallyInstallBackend bool, galleries []config.Gallery, backendGalleries []config.Gallery, requests []galleryModel) error {
+ ctx := context.Background()
var err error
for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" {
- err = installModelFromRemoteConfig(systemState, modelLoader, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
+ err = installModelFromRemoteConfig(ctx, systemState, modelLoader, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend, backendGalleries)
} else {
err = gallery.InstallModelFromGallery(
- galleries, backendGalleries, systemState, modelLoader, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
+ ctx, galleries, backendGalleries, systemState, modelLoader, r.ID, r.GalleryModel, utils.DisplayDownloadFunction, enforceScan, automaticallyInstallBackend)
}
}
return err
@@ -126,25 +183,40 @@ func processModelOperation(
automaticallyInstallBackend bool,
progressCallback func(string, string, string, float64),
) error {
+ ctx := op.Context
+ if ctx == nil {
+ ctx = context.Background()
+ }
+
+ // Check for cancellation before starting
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
switch {
case op.Delete:
return gallery.DeleteModelFromSystem(systemState, op.GalleryElementName)
case op.GalleryElement != nil:
installedModel, err := gallery.InstallModel(
- systemState, op.GalleryElement.Name,
+ ctx, systemState, op.GalleryElement.Name,
op.GalleryElement,
op.Req.Overrides,
progressCallback, enforcePredownloadScans)
+ if err != nil {
+ return err
+ }
if automaticallyInstallBackend && installedModel.Backend != "" {
log.Debug().Msgf("Installing backend %q", installedModel.Backend)
- if err := gallery.InstallBackendFromGallery(op.BackendGalleries, systemState, modelLoader, installedModel.Backend, progressCallback, false); err != nil {
+ if err := gallery.InstallBackendFromGallery(ctx, op.BackendGalleries, systemState, modelLoader, installedModel.Backend, progressCallback, false); err != nil {
return err
}
}
- return err
+ return nil
case op.GalleryElementName != "":
- return gallery.InstallModelFromGallery(op.Galleries, op.BackendGalleries, systemState, modelLoader, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
+ return gallery.InstallModelFromGallery(ctx, op.Galleries, op.BackendGalleries, systemState, modelLoader, op.GalleryElementName, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend)
default:
- return installModelFromRemoteConfig(systemState, modelLoader, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
+ return installModelFromRemoteConfig(ctx, systemState, modelLoader, op.Req, progressCallback, enforcePredownloadScans, automaticallyInstallBackend, op.BackendGalleries)
}
}
diff --git a/core/services/operation.go b/core/services/operation.go
index d0ba76ceb1a9..0b79f0dcbd83 100644
--- a/core/services/operation.go
+++ b/core/services/operation.go
@@ -1,6 +1,8 @@
package services
import (
+ "context"
+
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/xsync"
)
@@ -17,6 +19,10 @@ type GalleryOp[T any, E any] struct {
Galleries []config.Gallery
BackendGalleries []config.Gallery
+
+ // Context for cancellation support
+ Context context.Context
+ CancelFunc context.CancelFunc
}
type GalleryOpStatus struct {
@@ -29,6 +35,8 @@ type GalleryOpStatus struct {
TotalFileSize string `json:"file_size"`
DownloadedFileSize string `json:"downloaded_size"`
GalleryElementName string `json:"gallery_element_name"`
+ Cancelled bool `json:"cancelled"` // Cancelled is true if the operation was cancelled
+ Cancellable bool `json:"cancellable"` // Cancellable is true if the operation can be cancelled
}
type OpCache struct {
diff --git a/core/startup/backend_preload.go b/core/startup/backend_preload.go
index bf9c4ec90764..835a0bc005f6 100644
--- a/core/startup/backend_preload.go
+++ b/core/startup/backend_preload.go
@@ -1,6 +1,7 @@
package startup
import (
+ "context"
"fmt"
"path/filepath"
"strings"
@@ -13,7 +14,7 @@ import (
"github.com/rs/zerolog/log"
)
-func InstallExternalBackends(galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
+func InstallExternalBackends(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), backend, name, alias string) error {
uri := downloader.URI(backend)
switch {
case uri.LooksLikeDir():
@@ -21,7 +22,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
name = filepath.Base(backend)
}
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from path")
- if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
+ if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -35,7 +36,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
return fmt.Errorf("specifying a name is required for OCI images")
}
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
- if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
+ if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -53,7 +54,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
name = strings.TrimSuffix(name, filepath.Ext(name))
log.Info().Str("backend", backend).Str("name", name).Msg("Installing backend from OCI image")
- if err := gallery.InstallBackend(systemState, modelLoader, &gallery.GalleryBackend{
+ if err := gallery.InstallBackend(ctx, systemState, modelLoader, &gallery.GalleryBackend{
Metadata: gallery.Metadata{
Name: name,
},
@@ -66,7 +67,7 @@ func InstallExternalBackends(galleries []config.Gallery, systemState *system.Sys
if name != "" || alias != "" {
return fmt.Errorf("specifying a name or alias is not supported for this backend")
}
- err := gallery.InstallBackendFromGallery(galleries, systemState, modelLoader, backend, downloadStatus, true)
+ err := gallery.InstallBackendFromGallery(ctx, galleries, systemState, modelLoader, backend, downloadStatus, true)
if err != nil {
return fmt.Errorf("error installing backend %s: %w", backend, err)
}
diff --git a/core/startup/model_preload.go b/core/startup/model_preload.go
index 193aad6a2e52..9377830a4df1 100644
--- a/core/startup/model_preload.go
+++ b/core/startup/model_preload.go
@@ -1,6 +1,7 @@
package startup
import (
+ "context"
"encoding/json"
"errors"
"fmt"
@@ -30,7 +31,7 @@ const (
// InstallModels will preload models from the given list of URLs and galleries
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
-func InstallModels(galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
+func InstallModels(ctx context.Context, galleryService *services.GalleryService, galleries, backendGalleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, enforceScan, autoloadBackendGalleries bool, downloadStatus func(string, string, string, float64), models ...string) error {
// create an error that groups all errors
var err error
@@ -53,7 +54,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
return nil
}
- if err := gallery.InstallBackendFromGallery(backendGalleries, systemState, modelLoader, model.Backend, downloadStatus, false); err != nil {
+ if err := gallery.InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, model.Backend, downloadStatus, false); err != nil {
log.Error().Err(err).Str("backend", model.Backend).Msg("error installing backend")
return err
}
@@ -153,7 +154,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
}
} else {
// Check if it's a model gallery, or print a warning
- e, found := installModel(galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries)
+ e, found := installModel(ctx, galleries, backendGalleries, url, systemState, modelLoader, downloadStatus, enforceScan, autoloadBackendGalleries)
if e != nil && found {
log.Error().Err(err).Msgf("[startup] failed installing model '%s'", url)
err = errors.Join(err, e)
@@ -210,7 +211,7 @@ func InstallModels(galleryService *services.GalleryService, galleries, backendGa
return err
}
-func installModel(galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
+func installModel(ctx context.Context, galleries, backendGalleries []config.Gallery, modelName string, systemState *system.SystemState, modelLoader *model.ModelLoader, downloadStatus func(string, string, string, float64), enforceScan, autoloadBackendGalleries bool) (error, bool) {
models, err := gallery.AvailableGalleryModels(galleries, systemState)
if err != nil {
return err, false
@@ -226,7 +227,7 @@ func installModel(galleries, backendGalleries []config.Gallery, modelName string
}
log.Info().Str("model", modelName).Str("license", model.License).Msg("installing model")
- err = gallery.InstallModelFromGallery(galleries, backendGalleries, systemState, modelLoader, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
+ err = gallery.InstallModelFromGallery(ctx, galleries, backendGalleries, systemState, modelLoader, modelName, gallery.GalleryModel{}, downloadStatus, enforceScan, autoloadBackendGalleries)
if err != nil {
return err, true
}
diff --git a/core/startup/model_preload_test.go b/core/startup/model_preload_test.go
index 255c3dc4e65a..54dc5507392f 100644
--- a/core/startup/model_preload_test.go
+++ b/core/startup/model_preload_test.go
@@ -1,6 +1,7 @@
package startup_test
import (
+ "context"
"fmt"
"os"
"path/filepath"
@@ -33,7 +34,7 @@ var _ = Describe("Preload test", func() {
url := "https://raw.githubusercontent.com/mudler/LocalAI-examples/main/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", "phi-2")
- InstallModels(nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
+ InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
resultFile := filepath.Join(tmpdir, fileName)
@@ -46,7 +47,7 @@ var _ = Describe("Preload test", func() {
url := "huggingface://TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"
fileName := fmt.Sprintf("%s.gguf", "tinyllama-1.1b-chat-v0.3.Q2_K")
- err := InstallModels(nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
+ err := InstallModels(context.TODO(), nil, []config.Gallery{}, []config.Gallery{}, systemState, ml, true, true, nil, url)
Expect(err).ToNot(HaveOccurred())
resultFile := filepath.Join(tmpdir, fileName)
diff --git a/pkg/downloader/progress.go b/pkg/downloader/progress.go
index 6cd6132b26bc..05cd07a9867b 100644
--- a/pkg/downloader/progress.go
+++ b/pkg/downloader/progress.go
@@ -1,6 +1,9 @@
package downloader
-import "hash"
+import (
+ "context"
+ "hash"
+)
type progressWriter struct {
fileName string
@@ -10,23 +13,45 @@ type progressWriter struct {
written int64
downloadStatus func(string, string, string, float64)
hash hash.Hash
+ ctx context.Context
}
func (pw *progressWriter) Write(p []byte) (n int, err error) {
+ // Check for cancellation before writing
+ if pw.ctx != nil {
+ select {
+ case <-pw.ctx.Done():
+ return 0, pw.ctx.Err()
+ default:
+ }
+ }
+
n, err = pw.hash.Write(p)
+ if err != nil {
+ return n, err
+ }
pw.written += int64(n)
+ // Check for cancellation after writing chunk
+ if pw.ctx != nil {
+ select {
+ case <-pw.ctx.Done():
+ return n, pw.ctx.Err()
+ default:
+ }
+ }
+
if pw.total > 0 {
percentage := float64(pw.written) / float64(pw.total) * 100
if pw.totalFiles > 1 {
// This is a multi-file download
// so we need to adjust the percentage
// to reflect the progress of the whole download
- // This is the file pw.fileNo of pw.totalFiles files. We assume that
+ // This is the file pw.fileNo (0-indexed) of pw.totalFiles files. We assume that
// the files before successfully downloaded.
percentage = percentage / float64(pw.totalFiles)
- if pw.fileNo > 1 {
- percentage += float64(pw.fileNo-1) * 100 / float64(pw.totalFiles)
+ if pw.fileNo > 0 {
+ percentage += float64(pw.fileNo) * 100 / float64(pw.totalFiles)
}
}
//log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%)", pw.fileName, formatBytes(pw.written), formatBytes(pw.total), percentage)
diff --git a/pkg/downloader/uri.go b/pkg/downloader/uri.go
index b6739498da9c..ea1631f4fffa 100644
--- a/pkg/downloader/uri.go
+++ b/pkg/downloader/uri.go
@@ -1,6 +1,7 @@
package downloader
import (
+ "context"
"crypto/sha256"
"errors"
"fmt"
@@ -49,10 +50,10 @@ func loadConfig() string {
}
func (uri URI) DownloadWithCallback(basePath string, f func(url string, i []byte) error) error {
- return uri.DownloadWithAuthorizationAndCallback(basePath, "", f)
+ return uri.DownloadWithAuthorizationAndCallback(context.Background(), basePath, "", f)
}
-func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorization string, f func(url string, i []byte) error) error {
+func (uri URI) DownloadWithAuthorizationAndCallback(ctx context.Context, basePath string, authorization string, f func(url string, i []byte) error) error {
url := uri.ResolveURL()
if strings.HasPrefix(url, LocalPrefix) {
@@ -83,8 +84,7 @@ func (uri URI) DownloadWithAuthorizationAndCallback(basePath string, authorizati
}
// Send a GET request to the URL
-
- req, err := http.NewRequest("GET", url, nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return err
}
@@ -264,6 +264,10 @@ func (uri URI) checkSeverSupportsRangeHeader() (bool, error) {
}
func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
+ return uri.DownloadFileWithContext(context.Background(), filePath, sha, fileN, total, downloadStatus)
+}
+
+func (uri URI) DownloadFileWithContext(ctx context.Context, filePath, sha string, fileN, total int, downloadStatus func(string, string, string, float64)) error {
url := uri.ResolveURL()
if uri.LooksLikeOCI() {
@@ -285,7 +289,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
}
if url, ok := strings.CutPrefix(url, OllamaPrefix); ok {
- return oci.OllamaFetchModel(url, filePath, progressStatus)
+ return oci.OllamaFetchModel(ctx, url, filePath, progressStatus)
}
if url, ok := strings.CutPrefix(url, OCIFilePrefix); ok {
@@ -295,7 +299,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
return fmt.Errorf("failed to open tarball: %s", err.Error())
}
- return oci.ExtractOCIImage(img, url, filePath, downloadStatus)
+ return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
}
url = strings.TrimPrefix(url, OCIPrefix)
@@ -304,7 +308,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
return fmt.Errorf("failed to get image %q: %v", url, err)
}
- return oci.ExtractOCIImage(img, url, filePath, downloadStatus)
+ return oci.ExtractOCIImage(ctx, img, url, filePath, downloadStatus)
}
// We need to check if url looks like an URL or bail out
@@ -312,6 +316,13 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
return fmt.Errorf("url %q does not look like an HTTP URL", url)
}
+ // Check for cancellation before starting
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ }
+
// Check if the file already exists
_, err := os.Stat(filePath)
if err == nil {
@@ -346,7 +357,7 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
log.Info().Msgf("Downloading %q", url)
- req, err := http.NewRequest("GET", url, nil)
+ req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request for %q: %v", filePath, err)
}
@@ -375,6 +386,12 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
// Start the request
resp, err := http.DefaultClient.Do(req)
if err != nil {
+ // Check if error is due to context cancellation
+ if errors.Is(err, context.Canceled) {
+ // Clean up partial file on cancellation
+ removePartialFile(tmpFilePath)
+ return err
+ }
return fmt.Errorf("failed to download file %q: %v", filePath, err)
}
defer resp.Body.Close()
@@ -406,12 +423,27 @@ func (uri URI) DownloadFile(filePath, sha string, fileN, total int, downloadStat
fileNo: fileN,
totalFiles: total,
downloadStatus: downloadStatus,
+ ctx: ctx,
}
_, err = io.Copy(io.MultiWriter(outFile, progress), resp.Body)
if err != nil {
+ // Check if error is due to context cancellation
+ if errors.Is(err, context.Canceled) {
+ // Clean up partial file on cancellation
+ removePartialFile(tmpFilePath)
+ return err
+ }
return fmt.Errorf("failed to write file %q: %v", filePath, err)
}
+ // Check for cancellation before finalizing
+ select {
+ case <-ctx.Done():
+ removePartialFile(tmpFilePath)
+ return ctx.Err()
+ default:
+ }
+
err = os.Rename(tmpFilePath, filePath)
if err != nil {
return fmt.Errorf("failed to rename temporary file %s -> %s: %v", tmpFilePath, filePath, err)
diff --git a/pkg/oci/blob.go b/pkg/oci/blob.go
index f0df27309993..0f5a2cf66be0 100644
--- a/pkg/oci/blob.go
+++ b/pkg/oci/blob.go
@@ -6,13 +6,14 @@ import (
"io"
"os"
+ "github.com/mudler/LocalAI/pkg/xio"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
oras "oras.land/oras-go/v2"
"oras.land/oras-go/v2/registry/remote"
)
-func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descriptor) io.Writer) error {
+func FetchImageBlob(ctx context.Context, r, reference, dst string, statusReader func(ocispec.Descriptor) io.Writer) error {
// 0. Create a file store for the output
fs, err := os.Create(dst)
if err != nil {
@@ -21,7 +22,6 @@ func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descript
defer fs.Close()
// 1. Connect to a remote repository
- ctx := context.Background()
repo, err := remote.NewRepository(r)
if err != nil {
return fmt.Errorf("failed to create repository: %v", err)
@@ -37,12 +37,12 @@ func FetchImageBlob(r, reference, dst string, statusReader func(ocispec.Descript
if statusReader != nil {
// 3. Write the file to the file store
- _, err = io.Copy(io.MultiWriter(fs, statusReader(desc)), reader)
+ _, err = xio.Copy(ctx, io.MultiWriter(fs, statusReader(desc)), reader)
if err != nil {
return err
}
} else {
- _, err = io.Copy(fs, reader)
+ _, err = xio.Copy(ctx, fs, reader)
if err != nil {
return err
}
diff --git a/pkg/oci/blob_test.go b/pkg/oci/blob_test.go
index 7be12ac15d5b..cef29a972228 100644
--- a/pkg/oci/blob_test.go
+++ b/pkg/oci/blob_test.go
@@ -1,6 +1,7 @@
package oci_test
import (
+ "context"
"os"
. "github.com/mudler/LocalAI/pkg/oci" // Update with your module path
@@ -14,7 +15,7 @@ var _ = Describe("OCI", func() {
f, err := os.CreateTemp("", "ollama")
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(f.Name())
- err = FetchImageBlob("registry.ollama.ai/library/gemma", "sha256:c1864a5eb19305c40519da12cc543519e48a0697ecd30e15d5ac228644957d12", f.Name(), nil)
+ err = FetchImageBlob(context.TODO(), "registry.ollama.ai/library/gemma", "sha256:c1864a5eb19305c40519da12cc543519e48a0697ecd30e15d5ac228644957d12", f.Name(), nil)
Expect(err).NotTo(HaveOccurred())
})
})
diff --git a/pkg/oci/image.go b/pkg/oci/image.go
index 0b85ec8eb90f..90d433a05b0f 100644
--- a/pkg/oci/image.go
+++ b/pkg/oci/image.go
@@ -23,6 +23,7 @@ import (
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/remote/transport"
"github.com/google/go-containerregistry/pkg/v1/tarball"
+ "github.com/mudler/LocalAI/pkg/xio"
)
// ref: https://github.com/mudler/luet/blob/master/pkg/helpers/docker/docker.go#L117
@@ -97,7 +98,7 @@ func (pw *progressWriter) Write(p []byte) (int, error) {
}
// ExtractOCIImage will extract a given targetImage into a given targetDestination
-func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error {
+func ExtractOCIImage(ctx context.Context, img v1.Image, imageRef string, targetDestination string, downloadStatus func(string, string, string, float64)) error {
// Create a temporary tar file
tmpTarFile, err := os.CreateTemp("", "localai-oci-*.tar")
if err != nil {
@@ -107,13 +108,13 @@ func ExtractOCIImage(img v1.Image, imageRef string, targetDestination string, do
defer tmpTarFile.Close()
// Download the image as tar with progress tracking
- err = DownloadOCIImageTar(img, imageRef, tmpTarFile.Name(), downloadStatus)
+ err = DownloadOCIImageTar(ctx, img, imageRef, tmpTarFile.Name(), downloadStatus)
if err != nil {
return fmt.Errorf("failed to download image tar: %v", err)
}
// Extract the tar file to the target destination
- err = ExtractOCIImageFromTar(tmpTarFile.Name(), imageRef, targetDestination, downloadStatus)
+ err = ExtractOCIImageFromTar(ctx, tmpTarFile.Name(), imageRef, targetDestination, downloadStatus)
if err != nil {
return fmt.Errorf("failed to extract image tar: %v", err)
}
@@ -207,7 +208,7 @@ func GetOCIImageSize(targetImage, targetPlatform string, auth *registrytypes.Aut
// DownloadOCIImageTar downloads the compressed layers of an image and then creates an uncompressed tar
// This provides accurate size estimation and allows for later extraction
-func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error {
+func DownloadOCIImageTar(ctx context.Context, img v1.Image, imageRef string, tarFilePath string, downloadStatus func(string, string, string, float64)) error {
// Get layers to calculate total compressed size for estimation
layers, err := img.Layers()
if err != nil {
@@ -267,7 +268,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
return fmt.Errorf("failed to get compressed layer: %v", err)
}
- _, err = io.Copy(writer, layerReader)
+ _, err = xio.Copy(ctx, writer, layerReader)
file.Close()
if err != nil {
return fmt.Errorf("failed to download layer %d: %v", i, err)
@@ -298,7 +299,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
// Extract uncompressed tar from local image
extractReader := mutate.Extract(localImg)
- _, err = io.Copy(tarFile, extractReader)
+ _, err = xio.Copy(ctx, tarFile, extractReader)
if err != nil {
return fmt.Errorf("failed to extract uncompressed tar: %v", err)
}
@@ -307,7 +308,7 @@ func DownloadOCIImageTar(img v1.Image, imageRef string, tarFilePath string, down
}
// ExtractOCIImageFromTar extracts an image from a previously downloaded tar file
-func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error {
+func ExtractOCIImageFromTar(ctx context.Context, tarFilePath, imageRef, targetDestination string, downloadStatus func(string, string, string, float64)) error {
// Open the tar file
tarFile, err := os.Open(tarFilePath)
if err != nil {
@@ -331,7 +332,7 @@ func ExtractOCIImageFromTar(tarFilePath, imageRef, targetDestination string, dow
}
// Extract the tar file
- _, err = archive.Apply(context.Background(),
+ _, err = archive.Apply(ctx,
targetDestination, reader,
archive.WithNoSameOwner())
diff --git a/pkg/oci/image_test.go b/pkg/oci/image_test.go
index 1e59d762f2f8..8b26c2b87655 100644
--- a/pkg/oci/image_test.go
+++ b/pkg/oci/image_test.go
@@ -1,6 +1,7 @@
package oci_test
import (
+ "context"
"os"
"runtime"
@@ -30,7 +31,7 @@ var _ = Describe("OCI", func() {
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(dir)
- err = ExtractOCIImage(img, imageName, dir, nil)
+ err = ExtractOCIImage(context.TODO(), img, imageName, dir, nil)
Expect(err).NotTo(HaveOccurred())
})
})
diff --git a/pkg/oci/ollama.go b/pkg/oci/ollama.go
index 79b152918c2b..b9092c18cb5a 100644
--- a/pkg/oci/ollama.go
+++ b/pkg/oci/ollama.go
@@ -1,6 +1,7 @@
package oci
import (
+ "context"
"encoding/json"
"fmt"
"io"
@@ -76,7 +77,7 @@ func OllamaModelBlob(image string) (string, error) {
return "", nil
}
-func OllamaFetchModel(image string, output string, statusWriter func(ocispec.Descriptor) io.Writer) error {
+func OllamaFetchModel(ctx context.Context, image string, output string, statusWriter func(ocispec.Descriptor) io.Writer) error {
_, repository, imageNoTag := ParseImageParts(image)
blobID, err := OllamaModelBlob(image)
@@ -84,5 +85,5 @@ func OllamaFetchModel(image string, output string, statusWriter func(ocispec.Des
return err
}
- return FetchImageBlob(fmt.Sprintf("registry.ollama.ai/%s/%s", repository, imageNoTag), blobID, output, statusWriter)
+ return FetchImageBlob(ctx, fmt.Sprintf("registry.ollama.ai/%s/%s", repository, imageNoTag), blobID, output, statusWriter)
}
diff --git a/pkg/oci/ollama_test.go b/pkg/oci/ollama_test.go
index e2144aa74546..fbda69e6b40e 100644
--- a/pkg/oci/ollama_test.go
+++ b/pkg/oci/ollama_test.go
@@ -1,6 +1,7 @@
package oci_test
import (
+ "context"
"os"
. "github.com/mudler/LocalAI/pkg/oci" // Update with your module path
@@ -14,7 +15,7 @@ var _ = Describe("OCI", func() {
f, err := os.CreateTemp("", "ollama")
Expect(err).NotTo(HaveOccurred())
defer os.RemoveAll(f.Name())
- err = OllamaFetchModel("gemma:2b", f.Name(), nil)
+ err = OllamaFetchModel(context.TODO(), "gemma:2b", f.Name(), nil)
Expect(err).NotTo(HaveOccurred())
})
})
diff --git a/pkg/xio/copy.go b/pkg/xio/copy.go
new file mode 100644
index 000000000000..93aaee38a751
--- /dev/null
+++ b/pkg/xio/copy.go
@@ -0,0 +1,21 @@
+package xio
+
+import (
+ "context"
+ "io"
+)
+
+type readerFunc func(p []byte) (n int, err error)
+
+func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) }
+
+func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
+ return io.Copy(dst, readerFunc(func(p []byte) (int, error) {
+ select {
+ case <-ctx.Done():
+ return 0, ctx.Err()
+ default:
+ return src.Read(p)
+ }
+ }))
+}