Skip to content

Commit

Permalink
fix: pkg/downloader should respect basePath for file:// urls (#2481)
Browse files Browse the repository at this point in the history
* pass basePath down to pkg/downloader

Signed-off-by: Dave Lee <dave@gray101.com>

* enforce

Signed-off-by: Dave Lee <dave@gray101.com>

---------

Signed-off-by: Dave Lee <dave@gray101.com>
  • Loading branch information
dave-gray101 authored Jun 4, 2024
1 parent bdd6769 commit 2fc6fe8
Show file tree
Hide file tree
Showing 9 changed files with 23 additions and 17 deletions.
3 changes: 2 additions & 1 deletion core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ func getModelStatus(url string) (response map[string]interface{}) {
}

func getModels(url string) (response []gallery.GalleryModel) {
downloader.GetURI(url, func(url string, i []byte) error {
// TODO: No tests currently seem to exercise file:// urls. Fix?
downloader.GetURI(url, "", func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
Expand Down
2 changes: 1 addition & 1 deletion core/services/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewGalleryService(modelPath string) *GalleryService {

func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error {

config, err := gallery.GetGalleryConfigFromURL(req.URL)
config, err := gallery.GetGalleryConfigFromURL(req.URL, modelPath)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions embedded/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ func init() {
}
}

func GetRemoteLibraryShorteners(url string) (map[string]string, error) {
func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
remoteLibrary := map[string]string{}

err := downloader.GetURI(url, func(_ string, i []byte) error {
err := downloader.GetURI(url, basePath, func(_ string, i []byte) error {
return yaml.Unmarshal(i, &remoteLibrary)
})
if err != nil {
Expand Down
7 changes: 6 additions & 1 deletion pkg/downloader/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const (
GithubURI2 = "github://"
)

func GetURI(url string, f func(url string, i []byte) error) error {
func GetURI(url string, basePath string, f func(url string, i []byte) error) error {
url = ConvertURL(url)

if strings.HasPrefix(url, "file://") {
Expand All @@ -33,6 +33,11 @@ func GetURI(url string, f func(url string, i []byte) error) error {
if err != nil {
return err
}
// Check if the local file is rooted in basePath
err = utils.VerifyPath(resolvedFile, basePath)
if err != nil {
return err
}
// Read the response body
body, err := os.ReadFile(resolvedFile)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/downloader/uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ var _ = Describe("Gallery API tests", func() {
Context("URI", func() {
It("parses github with a branch", func() {
Expect(
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error {
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", "", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
).ToNot(HaveOccurred())
})
It("parses github without a branch", func() {
Expect(
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error {
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", "", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
).ToNot(HaveOccurred())
})
It("parses github with urls", func() {
Expect(
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error {
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", "", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand Down
10 changes: 5 additions & 5 deletions pkg/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,

if len(model.URL) > 0 {
var err error
config, err = GetGalleryConfigFromURL(model.URL)
config, err = GetGalleryConfigFromURL(model.URL, basePath)
if err != nil {
return err
}
Expand Down Expand Up @@ -142,9 +142,9 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
return models, nil
}

func findGalleryURLFromReferenceURL(url string) (string, error) {
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string
err := downloader.GetURI(url, func(url string, d []byte) error {
err := downloader.GetURI(url, basePath, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
Expand All @@ -161,13 +161,13 @@ func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error)

if strings.HasSuffix(gallery.URL, ".ref") {
var err error
gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL)
gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL, basePath)
if err != nil {
return models, err
}
}

err := downloader.GetURI(gallery.URL, func(url string, d []byte) error {
err := downloader.GetURI(gallery.URL, basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ type PromptTemplate struct {
Content string `yaml:"content"`
}

func GetGalleryConfigFromURL(url string) (Config, error) {
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
var config Config
err := downloader.GetURI(url, func(url string, d []byte) error {
err := downloader.GetURI(url, basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/gallery/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() {
Context("requests", func() {
It("parses github with a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
e, err := GetGalleryConfigFromURL(req.URL)
e, err := GetGalleryConfigFromURL(req.URL, "")
Expect(err).ToNot(HaveOccurred())
Expect(e.Name).To(Equal("gpt4all-j"))
})
Expand Down
2 changes: 1 addition & 1 deletion pkg/startup/model_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model
// As a best effort, try to resolve the model from the remote library
// if it's not resolved we try with the other method below
if modelLibraryURL != "" {
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL)
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath)
if err == nil {
if lib[url] != "" {
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
Expand Down

0 comments on commit 2fc6fe8

Please sign in to comment.