Allow the repository to specify a list of evaluation tasks
Part of #165
ahumenberger committed Jun 11, 2024
1 parent 167da7f commit f1f0536
Showing 12 changed files with 567 additions and 283 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -152,6 +152,24 @@ Multiple concrete cases (or candidates) exist for a given task that each represe

Completing a task-case rewards points depending on the quality of the result. Which criteria make the solution to a task a "good" solution varies by task, but the general rule is: the more points, the better. For example, the unit tests generated by a model might actually compile, yielding points that set the model apart from models that only generate non-compiling code.
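
As a rough, hypothetical illustration of such point-based scoring (the metric names and the one-point-per-criterion weighting below are assumptions, not the benchmark's actual scoring logic):

```go
// Hypothetical sketch of point-based assessment; metric names and weights are
// assumptions and do not reflect the benchmark's real scoring.
package main

import "fmt"

// Assessment maps a metric name to the points awarded for it.
type Assessment map[string]uint64

// score awards points for a single model response: responses that contain
// code earn a point, and responses whose generated tests actually compile and
// execute earn another, setting them apart from non-compiling solutions.
func score(responseHasCode bool, testsExecute bool) Assessment {
	assessment := Assessment{}
	if responseHasCode {
		assessment["response-with-code"]++
	}
	if testsExecute {
		assessment["files-executed"]++
	}
	return assessment
}

func main() {
	fmt.Println(score(true, true))  // map[files-executed:1 response-with-code:1]
	fmt.Println(score(true, false)) // map[response-with-code:1]
}
```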

### Running specific tasks

#### Via repository

Each repository can contain a configuration file `repository.json` in its root directory that specifies the list of tasks to run for this repository.

```json
{
"tasks": [
"write-tests"
]
}
```

When the repository is evaluated, only the specified tasks are executed. If no `repository.json` file exists, all tasks are executed.
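
A minimal sketch of how such a configuration could be consumed, assuming a loader that falls back to a default task list when the file is absent (the type and function names here are illustrative, not the repository's actual code):

```go
// Illustrative sketch only: the configuration type and helper below are
// assumptions about how "repository.json" could be read.
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"log"
	"os"
	"path/filepath"
)

type repositoryConfiguration struct {
	Tasks []string `json:"tasks"`
}

// loadRepositoryTasks returns the tasks listed in "repository.json" of the
// given repository, or the default tasks if no such file exists.
func loadRepositoryTasks(repositoryPath string, defaultTasks []string) ([]string, error) {
	data, err := os.ReadFile(filepath.Join(repositoryPath, "repository.json"))
	if errors.Is(err, fs.ErrNotExist) {
		return defaultTasks, nil
	} else if err != nil {
		return nil, err
	}

	var configuration repositoryConfiguration
	if err := json.Unmarshal(data, &configuration); err != nil {
		return nil, err
	}

	return configuration.Tasks, nil
}

func main() {
	tasks, err := loadRepositoryTasks(".", []string{"write-tests"})
	if err != nil {
		log.Fatalf("cannot load repository configuration: %s", err)
	}
	fmt.Println(tasks)
}
```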

## Tasks

### Task: Test Generation

Test generation is the task of generating a test suite for a given source code example.
2 changes: 1 addition & 1 deletion cmd/eval-dev-quality/cmd/evaluate.go
@@ -329,7 +329,7 @@ func (command *Evaluate) Execute(args []string) (err error) {
}

// WriteCSVs writes the various CSV reports to disk.
func writeCSVs(resultPath string, assessments report.AssessmentPerModelPerLanguagePerRepository) (err error) {
func writeCSVs(resultPath string, assessments report.AssessmentPerModelPerLanguagePerRepositoryPerTask) (err error) {
// Write the "evaluation.csv" containing all data.
csv, err := report.GenerateCSV(assessments)
if err != nil {
19 changes: 10 additions & 9 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
@@ -19,6 +19,7 @@ import (
metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing"
"github.com/symflower/eval-dev-quality/log"
providertesting "github.com/symflower/eval-dev-quality/provider/testing"
"github.com/symflower/eval-dev-quality/task"
"github.com/symflower/eval-dev-quality/tools"
toolstesting "github.com/symflower/eval-dev-quality/tools/testing"
)
@@ -259,7 +260,7 @@ func TestEvaluateExecute(t *testing.T) {
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
validate(t, &testCase{
@@ -363,8 +364,8 @@
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join("symflower_symbolic-execution", "java", "java", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "java", "java", "plain.log"): nil,
},
})
})
@@ -449,7 +450,7 @@ func TestEvaluateExecute(t *testing.T) {
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
validate(t, &testCase{
@@ -517,7 +518,7 @@
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
})
@@ -583,7 +584,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
"ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
string(task.IdentifierWriteTests) + "/ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
},
})
}
@@ -631,7 +632,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
"custom-ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
string(task.IdentifierWriteTests) + "/custom-ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
},
})
}
@@ -689,7 +690,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): func(t *testing.T, filePath, data string) {
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): func(t *testing.T, filePath, data string) {
assert.Equal(t, 3, strings.Count(data, `Evaluating model "symflower/symbolic-execution"`))
},
},
@@ -726,7 +727,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
}
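
The changed expectations above reflect that evaluation logs are now grouped under a directory named after the task identifier. As a hedged illustration of that layout (the helper below is hypothetical, not code from the repository):

```go
package main

import (
	"fmt"
	"path/filepath"
)

// resultLogPath is a hypothetical helper mirroring the expectation in the
// tests above: the log path now starts with the task identifier, followed by
// the model, the language and the repository path, e.g.
// "write-tests/symflower_symbolic-execution/golang/golang/plain.log".
func resultLogPath(taskIdentifier, modelID, languageID, repositoryPath string) string {
	return filepath.Join(taskIdentifier, modelID, languageID, repositoryPath+".log")
}

func main() {
	fmt.Println(resultLogPath("write-tests", "symflower_symbolic-execution", "golang", filepath.Join("golang", "plain")))
}
```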
86 changes: 45 additions & 41 deletions evaluate/evaluate.go
@@ -63,12 +63,12 @@ func (ctx *Context) runsAtModelLevel() uint {
const RepositoryPlainName = "plain"

// Evaluate runs an evaluation on the given context and returns its results.
func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePerRepository, totalScore uint64) {
func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePerRepositoryPerTask, totalScore uint64) {
// Check that models and languages can be evaluated by executing the "plain" repositories.
modelSucceededBasicChecksOfLanguage := map[evalmodel.Model]map[evallanguage.Language]bool{}
ctx.Log.Printf("Checking that models and languages can be used for evaluation")
// Ensure we report metrics for every model even if they are excluded.
assessments = report.NewAssessmentPerModelPerLanguagePerRepository(ctx.Models, ctx.Languages, ctx.RepositoryPaths)
assessments = report.NewAssessmentPerModelPerLanguagePerRepositoryPerTask()
problemsPerModel := map[string][]error{}

{
@@ -106,29 +106,31 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer
r.SetQueryAttempts(ctx.QueryAttempts)
}

withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
for _, taskIdentifier := range temporaryRepository.Tasks {
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := temporaryRepository.Reset(ctx.Log); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := temporaryRepository.Evaluate(ctx.Log, ctx.ResultPath, model, language, taskIdentifier)
assessments.Add(model, language, repositoryPath, taskIdentifier, assessment)
if err != nil {
ps = append(ps, err)
}
if len(ps) > 0 {
ctx.Log.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
} else {
modelSucceededBasicChecksOfLanguage[model][language] = true
}
}

if err := temporaryRepository.Reset(ctx.Log); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := temporaryRepository.Evaluate(ctx.Log, ctx.ResultPath, model, language)
assessments[model][language][repositoryPath].Add(assessment)
if err != nil {
ps = append(ps, err)
}
if len(ps) > 0 {
ctx.Log.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
} else {
modelSucceededBasicChecksOfLanguage[model][language] = true
}
}
})
})
}
}
}
}
@@ -196,24 +198,26 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer

continue
}
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := temporaryRepository.Reset(ctx.Log); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := temporaryRepository.Evaluate(ctx.Log, ctx.ResultPath, model, language)
assessments[model][language][repositoryPath].Add(assessment)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
if err != nil {
ctx.Log.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err)
for _, taskIdentifier := range temporaryRepository.Tasks {
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := temporaryRepository.Reset(ctx.Log); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := temporaryRepository.Evaluate(ctx.Log, ctx.ResultPath, model, language, taskIdentifier)
assessments.Add(model, language, repositoryPath, taskIdentifier, assessment)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
if err != nil {
ctx.Log.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err)
}
}
}
})
})
}
}
}
}
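
The diff above replaces direct indexing into a three-level map with an `Add` call on the new per-task assessment collection from the report package. A sketch of what such a nested store could look like, purely as an assumption about its shape (the real `AssessmentPerModelPerLanguagePerRepositoryPerTask` type and its `Add` method are defined in the benchmark's report and metrics packages and may differ):

```go
// Illustrative sketch of a nested assessment store keyed by model, language,
// repository and task. The type layout and method are assumptions.
package main

import "fmt"

type Assessment map[string]uint64

type AssessmentStore map[string]map[string]map[string]map[string]Assessment

// Add merges an assessment into the store, creating nested maps on demand so
// callers never have to pre-initialize intermediate levels.
func (s AssessmentStore) Add(model, language, repository, task string, assessment Assessment) {
	if s[model] == nil {
		s[model] = map[string]map[string]map[string]Assessment{}
	}
	if s[model][language] == nil {
		s[model][language] = map[string]map[string]Assessment{}
	}
	if s[model][language][repository] == nil {
		s[model][language][repository] = map[string]Assessment{}
	}
	if s[model][language][repository][task] == nil {
		s[model][language][repository][task] = Assessment{}
	}
	for metric, count := range assessment {
		s[model][language][repository][task][metric] += count
	}
}

func main() {
	assessments := AssessmentStore{}
	assessments.Add("symflower/symbolic-execution", "golang", "golang/plain", "write-tests", Assessment{"files-executed": 1})
	fmt.Println(assessments)
}
```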