Skip to content

Commit

Permalink
Pull ollama models during evaluation
Browse files Browse the repository at this point in the history
Closes #283
  • Loading branch information
Munsio committed Jul 25, 2024
1 parent f2a556b commit 1968501
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
17 changes: 17 additions & 0 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
} else {
for _, model := range command.Models {
p := strings.SplitN(model, provider.ProviderModelSeparator, 2)[0]

if _, ok := providersSelected[p]; ok {
continue
}

if provider, ok := provider.Providers[p]; !ok {
command.logger.Panicf("Provider %q does not exist", p)
} else {
Expand Down Expand Up @@ -361,6 +366,18 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
}()
}

// Check if a provider has the ability to pull models and do so if necessary.
if puller, ok := p.(provider.Puller); ok {
command.logger.Printf("Pulling available models for provider %q", p.ID())
for _, modelID := range command.Models {
if strings.HasPrefix(modelID, p.ID()) {
if err := puller.Pull(command.logger, modelID); err != nil {
command.logger.Panicf("ERROR: could not pull model %q: %s", modelID, err)
}
}
}
}

ms, err := p.Models()
if err != nil {
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
Expand Down
14 changes: 0 additions & 14 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,23 +491,9 @@ func TestEvaluateExecute(t *testing.T) {
}

{
var shutdown func() (err error)
defer func() { // Defer the shutdown in case there is a panic.
if shutdown != nil {
require.NoError(t, shutdown())
}
}()
validate(t, &testCase{
Name: "Pulled Model",

Before: func(t *testing.T, logger *log.Logger, resultPath string) {
var err error
shutdown, err = tools.OllamaStart(logger, tools.OllamaPath, tools.OllamaURL)
require.NoError(t, err)

require.NoError(t, tools.OllamaPull(logger, tools.OllamaPath, tools.OllamaURL, providertesting.OllamaTestModel))
},

Arguments: []string{
"--language", "golang",
"--model", "ollama/" + providertesting.OllamaTestModel,
Expand Down

0 comments on commit 1968501

Please sign in to comment.