Skip to content

Commit

Permalink
Shut down the Ollama service if there are no Ollama models to evaluat…
Browse files Browse the repository at this point in the history
…e, to avoid unnecessary processes being opened

Closes #255
  • Loading branch information
ruiAzevedo19 committed Jul 15, 2024
1 parent 279901e commit 8e5de7d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
32 changes: 26 additions & 6 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
{
models := map[string]model.Model{}
modelsSelected := map[string]model.Model{}
providersSelected := map[provider.Provider]bool{}
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
providersShutdown := map[provider.Provider]func() error{}
for _, p := range provider.Providers {
command.logger.Printf("Checking provider %q for models", p.ID())

Expand All @@ -325,11 +327,7 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
if err != nil {
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
}
defer func() {
if err := shutdown(); err != nil {
command.logger.Panicf("ERROR: could not shutdown services of provider %q: %s", p, err)
}
}()
providersShutdown[p] = shutdown
}

ms, err := p.Models()
Expand All @@ -355,7 +353,29 @@ func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.
}
sort.Strings(command.Models)
for _, modelID := range command.Models {
modelsSelected[modelID] = models[modelID]
model := models[modelID]
modelsSelected[modelID] = model
if provider, ok := evaluationContext.ProviderForModel[model]; ok {
providersSelected[provider] = true
}
}

// Shut down services that are no longer needed.
for provider, shutdown := range providersShutdown {
if _, ok := providersSelected[provider]; !ok {
if err := shutdown(); err != nil {
command.logger.Panicf("ERROR: could not shutdown services of provider %q: %s", provider, err)
} else {
command.logger.Printf("Shutting down %s services since they are no longer needed", provider.ID())
}
} else {
currentProvider, currentShutdown := provider, shutdown
defer func() {
if err := currentShutdown(); err != nil {
command.logger.Panicf("ERROR: could not shutdown services of provider %q: %s", currentProvider, err)
}
}()
}
}

// Make the resolved selected models available in the command.
Expand Down
28 changes: 28 additions & 0 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,31 @@ func TestEvaluateExecute(t *testing.T) {
filepath.Join("result-directory", "README.md"): nil,
filepath.Join("result-directory", string(evaluatetask.IdentifierWriteTests), "ollama_"+log.CleanModelNameForFileSystem(providertesting.OllamaTestModel), "golang", "golang", "plain.log"): nil,
},
ExpectedOutputValidate: func(t *testing.T, output, resultPath string) {
assert.NotContains(t, output, "Shutting down ollama services since they are no longer needed")
},
})
}
{
validate(t, &testCase{
Name: "Shutdown Ollama if no model needs it",

Arguments: []string{
"--language", "golang",
"--model", "symflower/symbolic-execution",
"--repository", filepath.Join("golang", "plain"),
},

ExpectedResultFiles: map[string]func(t *testing.T, filePath string, data string){
filepath.Join("result-directory", "categories.svg"): nil,
filepath.Join("result-directory", "evaluation.csv"): nil,
filepath.Join("result-directory", "evaluation.log"): nil,
filepath.Join("result-directory", "README.md"): nil,
filepath.Join("result-directory", string(evaluatetask.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
ExpectedOutputValidate: func(t *testing.T, output, resultPath string) {
assert.Contains(t, output, "Shutting down ollama services since they are no longer needed")
},
})
}
})
Expand Down Expand Up @@ -571,6 +596,9 @@ func TestEvaluateExecute(t *testing.T) {
filepath.Join("result-directory", "README.md"): nil,
filepath.Join("result-directory", string(evaluatetask.IdentifierWriteTests), "custom-ollama_"+log.CleanModelNameForFileSystem(providertesting.OllamaTestModel), "golang", "golang", "plain.log"): nil,
},
ExpectedOutputValidate: func(t *testing.T, output, resultPath string) {
assert.NotContains(t, output, "Shutting down ollama services since they are no longer needed")
},
})
}
})
Expand Down

0 comments on commit 8e5de7d

Please sign in to comment.