Skip to content

Commit

Permalink
Merge pull request #182 from symflower/163-refactor-fix-command-initi…
Browse files Browse the repository at this point in the history
…alization

fix, Default to all repositories if none are selected in CLI
  • Loading branch information
ahumenberger authored Jun 18, 2024
2 parents 631b0b2 + 5d354f6 commit 670b709
Show file tree
Hide file tree
Showing 4 changed files with 364 additions and 98 deletions.
232 changes: 141 additions & 91 deletions cmd/eval-dev-quality/cmd/evaluate.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ type Evaluate struct {

// logger holds the logger of the command.
logger *log.Logger
// timestamp holds the timestamp of the command execution.
timestamp time.Time
}

var _ SetLogger = (*Evaluate)(nil)
Expand All @@ -79,29 +81,50 @@ func (command *Evaluate) SetLogger(logger *log.Logger) {
command.logger = logger
}

// Execute executes the command.
func (command *Evaluate) Execute(args []string) (err error) {
evaluationTimestamp := time.Now()
command.ResultPath = strings.ReplaceAll(command.ResultPath, "%datetime%", evaluationTimestamp.Format("2006-01-02-15:04:05")) // REMARK Use a datetime format with a dash, so directories can be easily marked because they are only one group.
command.ResultPath, err = util.UniqueDirectory(command.ResultPath)
if err != nil {
return err
// Initialize initializes the command according to the arguments.
func (command *Evaluate) Initialize(args []string) (evaluationContext *evaluate.Context, cleanup func()) {
// Ensure the cleanup always runs in case there is a panic.
defer func() {
if r := recover(); r != nil {
if cleanup != nil {
cleanup()
}
panic(r)
}
}()
evaluationContext = &evaluate.Context{}

// Setup evaluation result directory.
{
command.ResultPath = strings.ReplaceAll(command.ResultPath, "%datetime%", command.timestamp.Format("2006-01-02-15:04:05")) // REMARK Use a datetime format with a dash, so directories can be easily marked because they are only one group.
uniqueResultPath, err := util.UniqueDirectory(command.ResultPath)
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
command.ResultPath = uniqueResultPath
evaluationContext.ResultPath = uniqueResultPath
command.logger.Printf("Writing results to %s", command.ResultPath)
}
command.logger.Printf("Writing results to %s", command.ResultPath)

log, logClose, err := log.WithFile(command.logger, filepath.Join(command.ResultPath, "evaluation.log"))
if err != nil {
return err
// Initialize logging within result directory.
{
log, logClose, err := log.WithFile(command.logger, filepath.Join(command.ResultPath, "evaluation.log"))
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
cleanup = logClose
command.logger = log
evaluationContext.Log = log
}
defer logClose()

// Check common options.
// Check and validate common options.
{
if command.InstallToolsPath == "" {
command.InstallToolsPath, err = tools.InstallPathDefault()
installToolsPath, err := tools.InstallPathDefault()
if err != nil {
log.Panicf("ERROR: %s", err)
command.logger.Panicf("ERROR: %s", err)
}
command.InstallToolsPath = installToolsPath
}

if command.OllamaBinaryPath != "" {
Expand All @@ -116,18 +139,35 @@ func (command *Evaluate) Execute(args []string) (err error) {
}

if command.QueryAttempts == 0 {
log.Panicf("number of configured query attempts must be greater than zero")
command.logger.Panicf("number of configured query attempts must be greater than zero")
}
evaluationContext.QueryAttempts = command.QueryAttempts

if command.ExecutionTimeout == 0 {
log.Panicf("execution timeout for compilation and tests must be greater than zero")
command.logger.Panicf("execution timeout for compilation and tests must be greater than zero")
} else {
language.DefaultExecutionTimeout = time.Duration(command.ExecutionTimeout) * time.Minute
}

if command.Runs == 0 {
log.Panicf("number of configured runs must be greater than zero")
command.logger.Panicf("number of configured runs must be greater than zero")
}
evaluationContext.Runs = command.Runs

evaluationContext.NoDisqualification = command.NoDisqualification
}

// Ensure the "testdata" path exists and make it absolute.
{
if err := osutil.DirExists(command.TestdataPath); err != nil {
command.logger.Panicf("ERROR: testdata path %q cannot be accessed: %s", command.TestdataPath, err)
}
testdataPath, err := filepath.Abs(command.TestdataPath)
if err != nil {
command.logger.Panicf("ERROR: could not resolve testdata path %q to an absolute path: %s", command.TestdataPath, err)
}
command.TestdataPath = testdataPath
evaluationContext.TestdataPath = testdataPath
}

// Register custom OpenAI API providers and models.
Expand All @@ -149,11 +189,11 @@ func (command *Evaluate) Execute(args []string) (err error) {

providerID, _, ok := strings.Cut(model, provider.ProviderModelSeparator)
if !ok {
log.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
command.logger.Panicf("ERROR: cannot split %q into provider and model name by %q", model, provider.ProviderModelSeparator)
}
modelProvider, ok := customProviders[providerID]
if !ok {
log.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
command.logger.Panicf("ERROR: unknown custom provider %q for model %q", providerID, model)
}

modelProvider.AddModel(llm.NewModel(modelProvider, model))
Expand All @@ -174,86 +214,109 @@ func (command *Evaluate) Execute(args []string) (err error) {
ls := maps.Keys(language.Languages)
sort.Strings(ls)

log.Panicf("ERROR: language %s does not exist. Valid languages are: %s", languageID, strings.Join(ls, ", "))
command.logger.Panicf("ERROR: language %s does not exist. Valid languages are: %s", languageID, strings.Join(ls, ", "))
}

languages[languageID] = l
}
}

sort.Strings(command.Languages)
for _, languageID := range command.Languages {
languagesSelected[languageID] = languages[languageID]
}
}

commandRepositories := map[string]bool{}
commandRepositoriesLanguages := map[string]bool{}
for _, r := range command.Repositories {
languageIDOfRepository := strings.SplitN(r, string(os.PathSeparator), 2)[0]
commandRepositoriesLanguages[languageIDOfRepository] = true
}

if _, ok := languagesSelected[languageIDOfRepository]; ok {
commandRepositories[r] = true
// Gather repositories and update language selection accordingly.
{
if len(command.Repositories) == 0 {
for _, l := range command.Languages {
repositories, err := language.RepositoriesForLanguage(language.Languages[l], command.TestdataPath)
if err != nil {
command.logger.Panicf("ERROR: %s", err)
}
command.Repositories = append(command.Repositories, repositories...)
}
sort.Strings(command.Repositories)
} else {
log.Printf("Excluded repository %s because its language %q is not enabled for this evaluation", r, languageIDOfRepository)
commandRepositories := map[string]bool{}
commandRepositoriesLanguages := map[string]bool{}
for _, r := range command.Repositories {
languageIDOfRepository := strings.SplitN(r, string(os.PathSeparator), 2)[0]
if _, ok := languagesSelected[languageIDOfRepository]; ok {
commandRepositories[r] = true
commandRepositoriesLanguages[languageIDOfRepository] = true
} else {
command.logger.Printf("Excluded repository %s because its language %q is not enabled for this evaluation", r, languageIDOfRepository)
}
}
for languageID := range languagesSelected {
if commandRepositoriesLanguages[languageID] { // Also add the plain repository in case we already have repositories for this language.
commandRepositories[filepath.Join(languageID, evaluate.RepositoryPlainName)] = true
} else {
command.Languages = slices.DeleteFunc(command.Languages, func(l string) bool {
return l == languageID
})
delete(languagesSelected, languageID)
command.logger.Printf("Excluded language %q because it is not part of the selected repositories", languageID)
}
}

command.Repositories = maps.Keys(commandRepositories)
sort.Strings(command.Repositories)
}
evaluationContext.RepositoryPaths = command.Repositories
}
for languageID := range languagesSelected {
if len(command.Repositories) == 0 || commandRepositoriesLanguages[languageID] {
commandRepositories[filepath.Join(languageID, evaluate.RepositoryPlainName)] = true
} else {
command.Languages = slices.DeleteFunc(command.Languages, func(l string) bool {
return l == languageID
})
delete(languagesSelected, languageID)
log.Printf("Excluded language %q because it is not part of the selected repositories", languageID)
}

// Make the resolved selected languages available in the command.
evaluationContext.Languages = make([]language.Language, len(command.Languages))
for i, languageID := range command.Languages {
evaluationContext.Languages[i] = languagesSelected[languageID]
}
command.Repositories = maps.Keys(commandRepositories)
sort.Strings(command.Repositories)

// Gather models.
modelsSelected := map[string]model.Model{}
providerForModel := map[model.Model]provider.Provider{}
{
models := map[string]model.Model{}
modelsSelected := map[string]model.Model{}
evaluationContext.ProviderForModel = map[model.Model]provider.Provider{}
for _, p := range provider.Providers {
log.Printf("Checking provider %q for models", p.ID())
command.logger.Printf("Checking provider %q for models", p.ID())

if t, ok := p.(provider.InjectToken); ok {
token, ok := command.ProviderTokens[p.ID()]
if ok {
t.SetToken(token)
}
}
if err := p.Available(log); err != nil {
log.Printf("Skipping unavailable provider %q cause: %s", p.ID(), err)
if err := p.Available(command.logger); err != nil {
command.logger.Printf("Skipping unavailable provider %q cause: %s", p.ID(), err)

continue
}

// Start services of providers.
if service, ok := p.(provider.Service); ok {
log.Printf("Starting services for provider %q", p.ID())
shutdown, err := service.Start(log)
command.logger.Printf("Starting services for provider %q", p.ID())
shutdown, err := service.Start(command.logger)
if err != nil {
log.Panicf("ERROR: could not start services for provider %q: %s", p, err)
command.logger.Panicf("ERROR: could not start services for provider %q: %s", p, err)
}
defer func() {
if err := shutdown(); err != nil {
log.Panicf("ERROR: could not shutdown services of provider %q: %s", p, err)
command.logger.Panicf("ERROR: could not shutdown services of provider %q: %s", p, err)
}
}()
}

ms, err := p.Models()
if err != nil {
log.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
command.logger.Panicf("ERROR: could not query models for provider %q: %s", p.ID(), err)
}

for _, m := range ms {
models[m.ID()] = m
providerForModel[m] = p
evaluationContext.ProviderForModel[m] = p
}
}
modelIDs := maps.Keys(models)
Expand All @@ -263,58 +326,45 @@ func (command *Evaluate) Execute(args []string) (err error) {
} else {
for _, modelID := range command.Models {
if _, ok := models[modelID]; !ok {
log.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
command.logger.Panicf("ERROR: model %s does not exist. Valid models are: %s", modelID, strings.Join(modelIDs, ", "))
}
}
}
sort.Strings(command.Models)
for _, modelID := range command.Models {
modelsSelected[modelID] = models[modelID]
}
}

if err := osutil.DirExists(command.TestdataPath); err != nil {
log.Panicf("ERROR: testdata path %q cannot be accessed: %s", command.TestdataPath, err)
}
command.TestdataPath, err = filepath.Abs(command.TestdataPath)
if err != nil {
log.Panicf("ERROR: could not resolve testdata path %q to an absolute path: %s", command.TestdataPath, err)
}

// Install required tools for the basic evaluation.
if err := tools.InstallEvaluation(log, command.InstallToolsPath); err != nil {
log.Panicf("ERROR: %s", err)
// Make the resolved selected models available in the command.
evaluationContext.Models = make([]model.Model, len(command.Models))
for i, modelID := range command.Models {
evaluationContext.Models[i] = modelsSelected[modelID]
}
}

ls := make([]language.Language, len(command.Languages))
for i, languageID := range command.Languages {
ls[i] = languagesSelected[languageID]
}
ms := make([]model.Model, len(command.Models))
for i, modelID := range command.Models {
ms[i] = modelsSelected[modelID]
}
assessments, totalScore := evaluate.Evaluate(&evaluate.Context{
Log: log,
return evaluationContext, cleanup
}

Languages: ls,
// Execute executes the command.
func (command *Evaluate) Execute(args []string) (err error) {
command.timestamp = time.Now()

Models: ms,
ProviderForModel: providerForModel,
QueryAttempts: command.QueryAttempts,
evaluationContext, cleanup := command.Initialize(args)
defer cleanup()
if evaluationContext == nil {
command.logger.Panic("ERROR: empty evaluation context")
}

RepositoryPaths: command.Repositories,
ResultPath: command.ResultPath,
TestdataPath: command.TestdataPath,
// Install required tools for the basic evaluation.
if err := tools.InstallEvaluation(command.logger, command.InstallToolsPath); err != nil {
command.logger.Panicf("ERROR: %s", err)
}

Runs: command.Runs,
RunsSequential: command.RunsSequential,
NoDisqualification: command.NoDisqualification,
})
assessments, totalScore := evaluate.Evaluate(evaluationContext)

assessmentsPerModel := assessments.CollapseByModel()
if err := (report.Markdown{
DateTime: evaluationTimestamp,
DateTime: command.timestamp,
Version: evaluate.Version,

CSVPath: "./evaluation.csv",
Expand All @@ -325,17 +375,17 @@ func (command *Evaluate) Execute(args []string) (err error) {
AssessmentPerModel: assessmentsPerModel,
TotalScore: totalScore,
}).WriteToFile(filepath.Join(command.ResultPath, "README.md")); err != nil {
return err
command.logger.Panicf("ERROR: %s", err)
}

_ = assessmentsPerModel.WalkByScore(func(model model.Model, assessment metrics.Assessments, score uint64) (err error) {
log.Printf("Evaluation score for %q (%q): %s", model.ID(), assessment.Category(totalScore).ID, assessment)
command.logger.Printf("Evaluation score for %q (%q): %s", model.ID(), assessment.Category(totalScore).ID, assessment)

return nil
})

if err := writeCSVs(command.ResultPath, assessments); err != nil {
log.Panicf("ERROR: %s", err)
command.logger.Panicf("ERROR: %s", err)
}

return nil
Expand Down
Loading

0 comments on commit 670b709

Please sign in to comment.