Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: pkg/downloader should respect basePath for file:// urls #2481

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/http/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ func getModelStatus(url string) (response map[string]interface{}) {
}

func getModels(url string) (response []gallery.GalleryModel) {
downloader.GetURI(url, func(url string, i []byte) error {
// TODO: No tests currently seem to exercise file:// urls. Fix?
downloader.GetURI(url, "", func(url string, i []byte) error {
// Unmarshal YAML data into a struct
return json.Unmarshal(i, &response)
})
Expand Down
2 changes: 1 addition & 1 deletion core/services/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewGalleryService(modelPath string) *GalleryService {

func prepareModel(modelPath string, req gallery.GalleryModel, cl *config.BackendConfigLoader, downloadStatus func(string, string, string, float64)) error {

config, err := gallery.GetGalleryConfigFromURL(req.URL)
config, err := gallery.GetGalleryConfigFromURL(req.URL, modelPath)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions embedded/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ func init() {
}
}

func GetRemoteLibraryShorteners(url string) (map[string]string, error) {
func GetRemoteLibraryShorteners(url string, basePath string) (map[string]string, error) {
remoteLibrary := map[string]string{}

err := downloader.GetURI(url, func(_ string, i []byte) error {
err := downloader.GetURI(url, basePath, func(_ string, i []byte) error {
return yaml.Unmarshal(i, &remoteLibrary)
})
if err != nil {
Expand Down
7 changes: 6 additions & 1 deletion pkg/downloader/uri.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const (
GithubURI2 = "github://"
)

func GetURI(url string, f func(url string, i []byte) error) error {
func GetURI(url string, basePath string, f func(url string, i []byte) error) error {
url = ConvertURL(url)

if strings.HasPrefix(url, "file://") {
Expand All @@ -33,6 +33,11 @@ func GetURI(url string, f func(url string, i []byte) error) error {
if err != nil {
return err
}
// Check if the local file is rooted in basePath
err = utils.VerifyPath(resolvedFile, basePath)
if err != nil {
return err
}
// Read the response body
body, err := os.ReadFile(resolvedFile)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/downloader/uri_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ var _ = Describe("Gallery API tests", func() {
Context("URI", func() {
It("parses github with a branch", func() {
Expect(
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error {
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", "", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
).ToNot(HaveOccurred())
})
It("parses github without a branch", func() {
Expect(
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error {
GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", "", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
).ToNot(HaveOccurred())
})
It("parses github with urls", func() {
Expect(
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error {
GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", "", func(url string, i []byte) error {
Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"))
return nil
}),
Expand Down
10 changes: 5 additions & 5 deletions pkg/gallery/gallery.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,

if len(model.URL) > 0 {
var err error
config, err = GetGalleryConfigFromURL(model.URL)
config, err = GetGalleryConfigFromURL(model.URL, basePath)
if err != nil {
return err
}
Expand Down Expand Up @@ -142,9 +142,9 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod
return models, nil
}

func findGalleryURLFromReferenceURL(url string) (string, error) {
func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) {
var refFile string
err := downloader.GetURI(url, func(url string, d []byte) error {
err := downloader.GetURI(url, basePath, func(url string, d []byte) error {
refFile = string(d)
if len(refFile) == 0 {
return fmt.Errorf("invalid reference file at url %s: %s", url, d)
Expand All @@ -161,13 +161,13 @@ func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error)

if strings.HasSuffix(gallery.URL, ".ref") {
var err error
gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL)
gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL, basePath)
if err != nil {
return models, err
}
}

err := downloader.GetURI(gallery.URL, func(url string, d []byte) error {
err := downloader.GetURI(gallery.URL, basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &models)
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions pkg/gallery/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ type PromptTemplate struct {
Content string `yaml:"content"`
}

func GetGalleryConfigFromURL(url string) (Config, error) {
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) {
var config Config
err := downloader.GetURI(url, func(url string, d []byte) error {
err := downloader.GetURI(url, basePath, func(url string, d []byte) error {
return yaml.Unmarshal(d, &config)
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/gallery/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ var _ = Describe("Gallery API tests", func() {
Context("requests", func() {
It("parses github with a branch", func() {
req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"}
e, err := GetGalleryConfigFromURL(req.URL)
e, err := GetGalleryConfigFromURL(req.URL, "")
Expect(err).ToNot(HaveOccurred())
Expect(e.Name).To(Equal("gpt4all-j"))
})
Expand Down
2 changes: 1 addition & 1 deletion pkg/startup/model_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, model
// As a best effort, try to resolve the model from the remote library
// if it's not resolved we try with the other method below
if modelLibraryURL != "" {
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL)
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL, modelPath)
if err == nil {
if lib[url] != "" {
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
Expand Down