Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the concept of "tasks" to prepare for different evaluation tasks like "write tests" and "repair code" #166

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,24 @@ Multiple concrete cases (or candidates) exist for a given task that each represe

Completing a task-case rewards points depending on the quality of the result. This, of course, depends on which criteria make the solution to a task a "good" solution, but the general rule is: the more points, the better. For example, the unit tests generated by a model might actually be compiling, yielding points that set the model apart from other models that generate only non-compiling code.

### Running specific tasks

#### Via repository

Each repository can contain a configuration file `repository.json` in its root directory specifying the list of tasks that should be run for this repository.

```json
{
"tasks": [
"write-tests"
]
}
```

For the evaluation of the repository, only the specified tasks are executed. If no `repository.json` file exists, all tasks are executed.

## Tasks

### Task: Test Generation

Test generation is the task of generating a test suite for a given source code example.
Expand Down
2 changes: 1 addition & 1 deletion cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ func (command *Evaluate) Execute(args []string) (err error) {
}

// WriteCSVs writes the various CSV reports to disk.
func writeCSVs(resultPath string, assessments report.AssessmentPerModelPerLanguagePerRepository) (err error) {
func writeCSVs(resultPath string, assessments report.AssessmentPerModelPerLanguagePerRepositoryPerTask) (err error) {
// Write the "evaluation.csv" containing all data.
csv, err := report.GenerateCSV(assessments)
if err != nil {
Expand Down
19 changes: 10 additions & 9 deletions cmd/eval-dev-quality/cmd/evaluate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing"
"github.com/symflower/eval-dev-quality/log"
providertesting "github.com/symflower/eval-dev-quality/provider/testing"
"github.com/symflower/eval-dev-quality/task"
"github.com/symflower/eval-dev-quality/tools"
toolstesting "github.com/symflower/eval-dev-quality/tools/testing"
)
Expand Down Expand Up @@ -259,7 +260,7 @@ func TestEvaluateExecute(t *testing.T) {
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
validate(t, &testCase{
Expand Down Expand Up @@ -363,8 +364,8 @@ func TestEvaluateExecute(t *testing.T) {
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join("symflower_symbolic-execution", "java", "java", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "java", "java", "plain.log"): nil,
},
})
})
Expand Down Expand Up @@ -449,7 +450,7 @@ func TestEvaluateExecute(t *testing.T) {
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
validate(t, &testCase{
Expand Down Expand Up @@ -517,7 +518,7 @@ func TestEvaluateExecute(t *testing.T) {
"README.md": func(t *testing.T, filePath, data string) {
validateReportLinks(t, data, []string{"symflower_symbolic-execution"})
},
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
})
Expand Down Expand Up @@ -583,7 +584,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
"ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
string(task.IdentifierWriteTests) + "/ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
},
})
}
Expand Down Expand Up @@ -631,7 +632,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
"custom-ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
string(task.IdentifierWriteTests) + "/custom-ollama_" + providertesting.OllamaTestModel + "/golang/golang/plain.log": nil,
},
})
}
Expand Down Expand Up @@ -689,7 +690,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): func(t *testing.T, filePath, data string) {
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): func(t *testing.T, filePath, data string) {
assert.Equal(t, 3, strings.Count(data, `Evaluating model "symflower/symbolic-execution"`))
},
},
Expand Down Expand Up @@ -726,7 +727,7 @@ func TestEvaluateExecute(t *testing.T) {
"golang-summed.csv": nil,
"models-summed.csv": nil,
"README.md": nil,
filepath.Join("symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
filepath.Join(string(task.IdentifierWriteTests), "symflower_symbolic-execution", "golang", "golang", "plain.log"): nil,
},
})
}
102 changes: 53 additions & 49 deletions evaluate/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,27 +63,27 @@ func (ctx *Context) runsAtModelLevel() uint {
const RepositoryPlainName = "plain"

// Evaluate runs an evaluation on the given context and returns its results.
func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePerRepository, totalScore uint64) {
func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePerRepositoryPerTask, totalScore uint64) {
// Check that models and languages can be evaluated by executing the "plain" repositories.
modelSucceededBasicChecksOfLanguage := map[evalmodel.Model]map[evallanguage.Language]bool{}
ctx.Log.Printf("Checking that models and languages can be used for evaluation")
// Ensure we report metrics for every model even if they are excluded.
assessments = report.NewAssessmentPerModelPerLanguagePerRepository(ctx.Models, ctx.Languages, ctx.RepositoryPaths)
assessments = report.NewAssessmentPerModelPerLanguagePerRepositoryPerTask()
problemsPerModel := map[string][]error{}

{
// Create temporary repositories for each language so the repository is copied only once per language.
temporaryRepositories := map[string]string{}
temporaryRepositories := map[string]*Repository{}
for _, language := range ctx.Languages {
repositoryPath := filepath.Join(language.ID(), RepositoryPlainName)
temporaryRepositoryPath, cleanup, err := TemporaryRepository(ctx.Log, filepath.Join(ctx.TestdataPath, repositoryPath))
temporaryRepository, cleanup, err := TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath)
if err != nil {
ctx.Log.Panicf("ERROR: unable to create temporary repository path: %+v", err)
}

defer cleanup()

temporaryRepositories[repositoryPath] = temporaryRepositoryPath
temporaryRepositories[repositoryPath] = temporaryRepository
}
for rl := uint(0); rl < ctx.runsAtLanguageLevel(); rl++ {
if ctx.Runs > 1 && !ctx.RunsSequential {
Expand All @@ -93,7 +93,7 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer
for _, language := range ctx.Languages {
languageID := language.ID()
repositoryPath := filepath.Join(language.ID(), RepositoryPlainName)
temporaryRepositoryPath := temporaryRepositories[repositoryPath]
temporaryRepository := temporaryRepositories[repositoryPath]

for _, model := range ctx.Models {
modelID := model.ID()
Expand All @@ -106,29 +106,31 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer
r.SetQueryAttempts(ctx.QueryAttempts)
}

withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
for _, taskIdentifier := range temporaryRepository.Tasks {
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := temporaryRepository.Reset(ctx.Log); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := temporaryRepository.Evaluate(ctx.Log, ctx.ResultPath, model, language, taskIdentifier)
assessments.Add(model, language, repositoryPath, taskIdentifier, assessment)
if err != nil {
ps = append(ps, err)
}
if len(ps) > 0 {
ctx.Log.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
} else {
modelSucceededBasicChecksOfLanguage[model][language] = true
}
}

if err := ResetTemporaryRepository(ctx.Log, temporaryRepositoryPath); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := Repository(ctx.Log, ctx.ResultPath, model, language, temporaryRepositoryPath, repositoryPath)
assessments[model][language][repositoryPath].Add(assessment)
if err != nil {
ps = append(ps, err)
}
if len(ps) > 0 {
ctx.Log.Printf("Model %q was not able to solve the %q repository for language %q: %+v", modelID, repositoryPath, languageID, ps)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
} else {
modelSucceededBasicChecksOfLanguage[model][language] = true
}
}
})
})
}
}
}
}
Expand All @@ -142,7 +144,7 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer
// Evaluating models and languages.
ctx.Log.Printf("Evaluating models and languages")
// Create temporary repositories for each language so the repository is copied only once per language.
temporaryRepositories := map[string]string{}
temporaryRepositories := map[string]*Repository{}
for _, language := range ctx.Languages {
languagePath := filepath.Join(ctx.TestdataPath, language.ID())
repositories, err := os.ReadDir(languagePath)
Expand All @@ -151,14 +153,14 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer
}
for _, repository := range repositories {
repositoryPath := filepath.Join(language.ID(), repository.Name())
temporaryRepositoryPath, cleanup, err := TemporaryRepository(ctx.Log, filepath.Join(ctx.TestdataPath, repositoryPath))
temporaryRepository, cleanup, err := TemporaryRepository(ctx.Log, ctx.TestdataPath, repositoryPath)
if err != nil {
ctx.Log.Panicf("ERROR: unable to create temporary repository path: %s", err)
}

defer cleanup()

temporaryRepositories[repositoryPath] = temporaryRepositoryPath
temporaryRepositories[repositoryPath] = temporaryRepository
}
}
for rl := uint(0); rl < ctx.runsAtLanguageLevel(); rl++ {
Expand All @@ -177,7 +179,7 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer

for _, repository := range repositories {
repositoryPath := filepath.Join(languageID, repository.Name())
temporaryRepositoryPath := temporaryRepositories[repositoryPath]
temporaryRepository := temporaryRepositories[repositoryPath]

if !repository.IsDir() || (len(ctx.RepositoryPaths) > 0 && !repositoriesLookup[repositoryPath]) {
continue
Expand All @@ -196,24 +198,26 @@ func Evaluate(ctx *Context) (assessments report.AssessmentPerModelPerLanguagePer

continue
}
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := ResetTemporaryRepository(ctx.Log, temporaryRepositoryPath); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := Repository(ctx.Log, ctx.ResultPath, model, language, temporaryRepositoryPath, repositoryPath)
assessments[model][language][repositoryPath].Add(assessment)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
if err != nil {
ctx.Log.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err)
for _, taskIdentifier := range temporaryRepository.Tasks {
withLoadedModel(ctx.Log, model, ctx.ProviderForModel[model], func() {
for rm := uint(0); rm < ctx.runsAtModelLevel(); rm++ {
if ctx.Runs > 1 && ctx.RunsSequential {
ctx.Log.Printf("Run %d/%d for model %q", rm+1, ctx.Runs, modelID)
}

if err := temporaryRepository.Reset(ctx.Log); err != nil {
ctx.Log.Panicf("ERROR: unable to reset temporary repository path: %s", err)
}

assessment, ps, err := temporaryRepository.Evaluate(ctx.Log, ctx.ResultPath, model, language, taskIdentifier)
assessments.Add(model, language, repositoryPath, taskIdentifier, assessment)
problemsPerModel[modelID] = append(problemsPerModel[modelID], ps...)
if err != nil {
ctx.Log.Printf("ERROR: Model %q encountered a hard error for language %q, repository %q: %+v", modelID, languageID, repositoryPath, err)
}
}
}
})
})
}
}
}
}
Expand Down
Loading
Loading