From bf2b89113a944236b1a6d17546f0d910ed18307d Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Wed, 2 Oct 2024 11:44:34 +0200 Subject: [PATCH 1/3] Apply "symflower fix" in any case as it will repair non-compile errors in the future Part of #350 --- evaluate/task/task-transpile.go | 56 ++++++++++----------------- evaluate/task/task-write-test.go | 54 +++++++++----------------- evaluate/task/task-write-test_test.go | 24 ------------ 3 files changed, 39 insertions(+), 95 deletions(-) diff --git a/evaluate/task/task-transpile.go b/evaluate/task/task-transpile.go index fc24385e..dd3c8a55 100644 --- a/evaluate/task/task-transpile.go +++ b/evaluate/task/task-transpile.go @@ -1,8 +1,6 @@ package task import ( - "context" - "errors" "fmt" "os" "path/filepath" @@ -75,7 +73,7 @@ func (t *TaskTranspile) Run(ctx evaltask.Context) (repositoryAssessment map[eval } for originFilePath, originLanguage := range originFilePathsWithLanguage { modelAssessmentsForFile := metrics.NewAssessments() - withSymflowerAssessmentsForFile := modelAssessmentsForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until a failure actually happens. + withSymflowerAssessmentsForFile := modelAssessmentsForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until we successfully applied "symflower fix". if err := ctx.Repository.Reset(ctx.Logger); err != nil { ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) @@ -110,39 +108,6 @@ func (t *TaskTranspile) Run(ctx evaltask.Context) (repositoryAssessment map[eval problems = append(problems, ps...) if err != nil { problems = append(problems, pkgerrors.WithMessage(err, originFilePath)) - - // If there is an execution timeout do not run "symflower fix" because the code itself is correct. 
- if errors.Is(err, context.DeadlineExceeded) { - modelAssessments.Add(modelAssessmentsForFile) - withSymflowerAssessments.Add(withSymflowerAssessmentsForFile) - - continue - } - - // Run "symflower fix" if the model response fails to execute. - if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". - withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, filepath.Join(ctx.Repository.DataPath(), packagePath)) - problems = append(problems, ps...) - if err != nil { - problems = append(problems, err) - - modelAssessments.Add(modelAssessmentsForFile) - withSymflowerAssessments.Add(withSymflowerAssessmentsForFile) - - continue - } else { - testsPassing := withSymflowerFixTestResult.TestsPass - taskLogger.Printf("with symflower repair: Executes tests with %d tests passing", testsPassing) - - // Symflower was able to fix a failure so now update the assessment with the improved results. - withSymflowerFixAssessments := metrics.NewAssessments() - withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime - withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted) - withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyTestsPassing, uint64(testsPassing)) - - withSymflowerAssessmentsForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentsForFile, withSymflowerFixAssessments) - } - } } else { testsPassing := testResult.TestsPass taskLogger.Printf("Executes tests with %d tests passing", testsPassing) @@ -150,6 +115,25 @@ func (t *TaskTranspile) Run(ctx evaltask.Context) (repositoryAssessment map[eval modelAssessmentsForFile.AwardPoints(metrics.AssessmentKeyTestsPassing, uint64(testsPassing)) } + if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". 
+ withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, filepath.Join(ctx.Repository.DataPath(), packagePath)) + problems = append(problems, ps...) + if err != nil { + problems = append(problems, err) + } else { + testsPassing := withSymflowerFixTestResult.TestsPass + taskLogger.Printf("with symflower repair: Executes tests with %d tests passing", testsPassing) + + // Symflower was able to fix a failure so now update the assessment with the improved results. + withSymflowerFixAssessments := metrics.NewAssessments() + withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime + withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted) + withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyTestsPassing, uint64(testsPassing)) + + withSymflowerAssessmentsForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentsForFile, withSymflowerFixAssessments) + } + } + modelAssessments.Add(modelAssessmentsForFile) withSymflowerAssessments.Add(withSymflowerAssessmentsForFile) } diff --git a/evaluate/task/task-write-test.go b/evaluate/task/task-write-test.go index 5330cf44..20d83c0b 100644 --- a/evaluate/task/task-write-test.go +++ b/evaluate/task/task-write-test.go @@ -1,8 +1,6 @@ package task import ( - "context" - "errors" "fmt" "strings" @@ -55,7 +53,7 @@ func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[eva for _, filePath := range filePaths { modelAssessmentForFile := metrics.NewAssessments() - withSymflowerFixAssessmentForFile := modelAssessmentForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until a failure actually happens. 
+ withSymflowerFixAssessmentForFile := modelAssessmentForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until we successfully applied "symflower fix". if err := ctx.Repository.Reset(ctx.Logger); err != nil { ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) @@ -85,44 +83,30 @@ func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[eva problems = append(problems, ps...) if err != nil { problems = append(problems, pkgerrors.WithMessage(err, filePath)) - - // If there is an execution timeout do not run "symflower fix" because the code itself is correct. - if errors.Is(err, context.DeadlineExceeded) { - modelAssessment.Add(modelAssessmentForFile) - withSymflowerFixAssessment.Add(withSymflowerFixAssessmentForFile) - - continue - } - - // Run "symflower fix" if the model response fails to execute. - if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". - withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, ctx.Repository.DataPath()) - problems = append(problems, ps...) - if err != nil { - problems = append(problems, err) - - modelAssessment.Add(modelAssessmentForFile) - withSymflowerFixAssessment.Add(withSymflowerFixAssessmentForFile) - - continue - } else { - ctx.Logger.Printf("with symflower repair: Executes tests with %d coverage objects", withSymflowerFixTestResult.Coverage) - - // Symflower was able to fix a failure so now update the assessment with the improved results. 
- withSymflowerFixAssessments := metrics.NewAssessments() - withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime - withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted) - withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyCoverage, withSymflowerFixTestResult.Coverage) - - withSymflowerFixAssessmentForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentForFile, withSymflowerFixAssessments) - } - } } else { taskLogger.Printf("Executes tests with %d coverage objects", testResult.Coverage) modelAssessmentForFile.Award(metrics.AssessmentKeyFilesExecuted) modelAssessmentForFile.AwardPoints(metrics.AssessmentKeyCoverage, testResult.Coverage) } + if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". + withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, ctx.Repository.DataPath()) + problems = append(problems, ps...) + if err != nil { + problems = append(problems, err) + } else { + ctx.Logger.Printf("with symflower repair: Executes tests with %d coverage objects", withSymflowerFixTestResult.Coverage) + + // Symflower was able to fix a failure so now update the assessment with the improved results. 
+ withSymflowerFixAssessments := metrics.NewAssessments() + withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime + withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted) + withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyCoverage, withSymflowerFixTestResult.Coverage) + + withSymflowerFixAssessmentForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentForFile, withSymflowerFixAssessments) + } + } + modelAssessment.Add(modelAssessmentForFile) withSymflowerFixAssessment.Add(withSymflowerFixAssessmentForFile) } diff --git a/evaluate/task/task-write-test_test.go b/evaluate/task/task-write-test_test.go index f6a0bcf4..ebd7182d 100644 --- a/evaluate/task/task-write-test_test.go +++ b/evaluate/task/task-write-test_test.go @@ -1,14 +1,12 @@ package task import ( - "context" "fmt" "os" "path/filepath" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/symflower/eval-dev-quality/evaluate/metrics" metricstesting "github.com/symflower/eval-dev-quality/evaluate/metrics/testing" @@ -17,7 +15,6 @@ import ( "github.com/symflower/eval-dev-quality/language/golang" "github.com/symflower/eval-dev-quality/language/java" "github.com/symflower/eval-dev-quality/language/ruby" - languagetesting "github.com/symflower/eval-dev-quality/language/testing" "github.com/symflower/eval-dev-quality/log" modeltesting "github.com/symflower/eval-dev-quality/model/testing" "github.com/symflower/eval-dev-quality/task" @@ -191,27 +188,6 @@ func TestTaskWriteTestsRun(t *testing.T) { this is not valid go code `), expectedAssessments, expectedProblems, false) } - { - expectedAssessments := map[task.Identifier]metrics.Assessments{ - IdentifierWriteTests: metrics.Assessments{ - metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, - metrics.AssessmentKeyResponseNoError: 1, - }, - IdentifierWriteTestsSymflowerFix: metrics.Assessments{ - 
metrics.AssessmentKeyFilesExecutedMaximumReachable: 1, - metrics.AssessmentKeyResponseNoError: 1, - }, - } - expectedProblems := []string{ - "context deadline exceeded", - } - - languageMock := languagetesting.NewMockLanguageNamed(t, "golang") - languageMock.On("Files", mock.Anything, mock.Anything).Return([]string{filepath.Join("golang", "plain")}, nil).Once() - languageMock.On("ExecuteTests", mock.Anything, mock.Anything).Return(nil, nil, context.DeadlineExceeded).Once() - - validateGo(t, "Execution timeout", languageMock, "", expectedAssessments, expectedProblems, false) - } }) }) From 3c368f8edb322f3ba45738bdfdf9141ef3e8c061 Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Wed, 2 Oct 2024 15:56:00 +0200 Subject: [PATCH 2/3] refactor, Extract "write test" core task logic so it can be applied twice (with and without template) Part of #350 --- evaluate/task/task-write-test.go | 107 ++++++++++++++++--------------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/evaluate/task/task-write-test.go b/evaluate/task/task-write-test.go index 20d83c0b..7d6c9526 100644 --- a/evaluate/task/task-write-test.go +++ b/evaluate/task/task-write-test.go @@ -52,63 +52,17 @@ func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[eva withSymflowerFixAssessment[metrics.AssessmentKeyFilesExecutedMaximumReachable] = maximumReachableFiles for _, filePath := range filePaths { - modelAssessmentForFile := metrics.NewAssessments() - withSymflowerFixAssessmentForFile := modelAssessmentForFile // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until we successfully applied "symflower fix".
- if err := ctx.Repository.Reset(ctx.Logger); err != nil { ctx.Logger.Panicf("ERROR: unable to reset temporary repository path: %s", err) } - modelContext := model.Context{ - Language: ctx.Language, - - RepositoryPath: dataPath, - FilePath: filePath, - - Logger: taskLogger.Logger, - } - assessments, err := modelCapability.WriteTests(modelContext) - if err != nil { - problems = append(problems, pkgerrors.WithMessage(err, filePath)) - - continue - } - if assessments[metrics.AssessmentKeyProcessingTime] == 0 { - return nil, nil, pkgerrors.Errorf("no model response time measurement present for %q at repository %q", ctx.Model.ID(), ctx.Repository.Name()) - } - modelAssessmentForFile.Add(assessments) - modelAssessmentForFile.Award(metrics.AssessmentKeyResponseNoError) - - testResult, ps, err := ctx.Language.ExecuteTests(taskLogger.Logger, dataPath) + modelAssessmentFile, withSymflowerFixAssessmentFile, ps, err := runModelAndSymflowerFix(ctx, taskLogger, modelCapability, dataPath, filePath) problems = append(problems, ps...) if err != nil { - problems = append(problems, pkgerrors.WithMessage(err, filePath)) - } else { - taskLogger.Printf("Executes tests with %d coverage objects", testResult.Coverage) - modelAssessmentForFile.Award(metrics.AssessmentKeyFilesExecuted) - modelAssessmentForFile.AwardPoints(metrics.AssessmentKeyCoverage, testResult.Coverage) - } - - if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". - withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, ctx.Repository.DataPath()) - problems = append(problems, ps...) - if err != nil { - problems = append(problems, err) - } else { - ctx.Logger.Printf("with symflower repair: Executes tests with %d coverage objects", withSymflowerFixTestResult.Coverage) - - // Symflower was able to fix a failure so now update the assessment with the improved results. 
- withSymflowerFixAssessments := metrics.NewAssessments() - withSymflowerFixAssessments[metrics.AssessmentKeyProcessingTime] = processingTime - withSymflowerFixAssessments.Award(metrics.AssessmentKeyFilesExecuted) - withSymflowerFixAssessments.AwardPoints(metrics.AssessmentKeyCoverage, withSymflowerFixTestResult.Coverage) - - withSymflowerFixAssessmentForFile = metrics.CombineWithSymflowerFixAssessments(modelAssessmentForFile, withSymflowerFixAssessments) - } + return nil, problems, err } - - modelAssessment.Add(modelAssessmentForFile) - withSymflowerFixAssessment.Add(withSymflowerFixAssessmentForFile) + modelAssessment.Add(modelAssessmentFile) + withSymflowerFixAssessment.Add(withSymflowerFixAssessmentFile) } repositoryAssessment = map[evaltask.Identifier]metrics.Assessments{ @@ -119,6 +73,59 @@ func (t *TaskWriteTests) Run(ctx evaltask.Context) (repositoryAssessment map[eva return repositoryAssessment, problems, nil } +func runModelAndSymflowerFix(ctx evaltask.Context, taskLogger *taskLogger, modelCapability model.CapabilityWriteTests, dataPath string, filePath string) (modelAssessment metrics.Assessments, withSymflowerFixAssessment metrics.Assessments, problems []error, err error) { + modelAssessment = metrics.NewAssessments() + withSymflowerFixAssessment = modelAssessment // The symflower assessment tracks how the model result can be improved in case of a failure, so just link to the model assessment until we successfully applied "symflower fix". 
+ + modelContext := model.Context{ + Language: ctx.Language, + + RepositoryPath: dataPath, + FilePath: filePath, + + Logger: taskLogger.Logger, + } + assessments, err := modelCapability.WriteTests(modelContext) + if err != nil { + return nil, nil, append(problems, pkgerrors.WithMessage(err, filePath)), nil + } + if assessments[metrics.AssessmentKeyProcessingTime] == 0 { + return nil, nil, problems, pkgerrors.Errorf("no model response time measurement present for %q at repository %q", ctx.Model.ID(), ctx.Repository.Name()) + } + modelAssessment.Add(assessments) + modelAssessment.Award(metrics.AssessmentKeyResponseNoError) + + testResult, ps, err := ctx.Language.ExecuteTests(taskLogger.Logger, dataPath) + problems = append(problems, ps...) + if err != nil { + problems = append(problems, pkgerrors.WithMessage(err, filePath)) + } else { + taskLogger.Printf("Executes tests with %d coverage objects", testResult.Coverage) + modelAssessment.Award(metrics.AssessmentKeyFilesExecuted) + modelAssessment.AwardPoints(metrics.AssessmentKeyCoverage, testResult.Coverage) + } + + if ctx.Language.ID() == "golang" { // Currently we only support Go for "symflower fix". + withSymflowerFixTestResult, processingTime, ps, err := ExecuteWithSymflowerFix(ctx, taskLogger.Logger, ctx.Repository.DataPath()) + problems = append(problems, ps...) + if err != nil { + problems = append(problems, err) + } else { + ctx.Logger.Printf("with symflower repair: Executes tests with %d coverage objects", withSymflowerFixTestResult.Coverage) + + // Symflower was able to fix a failure so now update the assessment with the improved results. 
+ withSymflowerFix := metrics.NewAssessments() + withSymflowerFix[metrics.AssessmentKeyProcessingTime] = processingTime + withSymflowerFix.Award(metrics.AssessmentKeyFilesExecuted) + withSymflowerFix.AwardPoints(metrics.AssessmentKeyCoverage, withSymflowerFixTestResult.Coverage) + + withSymflowerFixAssessment = metrics.CombineWithSymflowerFixAssessments(modelAssessment, withSymflowerFix) + } + } + + return modelAssessment, withSymflowerFixAssessment, problems, nil +} + // validateWriteTestsRepository checks if the repository for the "write-tests" task is well-formed. func validateWriteTestsRepository(logger *log.Logger, repositoryPath string, language language.Language) (err error) { logger.Printf("validating repository %q", repositoryPath) From fcb2104ccff8f20ef38f9e58b0cfc1833f079a43 Mon Sep 17 00:00:00 2001 From: Simon Bauer Date: Thu, 3 Oct 2024 09:10:42 +0200 Subject: [PATCH 3/3] refactor, Separate LLM prompt template for "write test" task so we can add a template Part of #350 --- model/llm/llm.go | 66 +++++++++++++++++------------- model/llm/llm_test.go | 94 +++++++++++++++---------------------------- 2 files changed, 69 insertions(+), 91 deletions(-) diff --git a/model/llm/llm.go b/model/llm/llm.go index 4bd52fee..ef0a9072 100644 --- a/model/llm/llm.go +++ b/model/llm/llm.go @@ -63,7 +63,7 @@ func (m *Model) MetaInformation() (metaInformation *model.MetaInformation) { return m.metaInformation } -// llmSourceFilePromptContext is the context for template for generating an LLM test generation prompt. +// llmSourceFilePromptContext is the base template context for an LLM generation prompt. type llmSourceFilePromptContext struct { // Language holds the programming language name. Language language.Language @@ -76,8 +76,14 @@ type llmSourceFilePromptContext struct { ImportPath string } -// llmGenerateTestForFilePromptTemplate is the template for generating an LLM test generation prompt. 
-var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm-generate-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(` +// llmWriteTestSourceFilePromptContext is the template context for a write test LLM prompt. +type llmWriteTestSourceFilePromptContext struct { + // llmSourceFilePromptContext holds the context for a source file prompt. + llmSourceFilePromptContext +} + +// llmWriteTestForFilePromptTemplate is the template for generating an LLM test generation prompt. +var llmWriteTestForFilePromptTemplate = template.Must(template.New("model-llm-write-test-for-file-prompt").Parse(bytesutil.StringTrimIndentations(` Given the following {{ .Language.Name }} code file "{{ .FilePath }}" with package "{{ .ImportPath }}", provide a test file for this code{{ with $testFramework := .Language.TestFramework }} with {{ $testFramework }} as a test framework{{ end }}. The tests should produce 100 percent code coverage and must compile. The response must contain only the test code in a fenced code block and nothing else. @@ -87,14 +93,14 @@ var llmGenerateTestForFilePromptTemplate = template.Must(template.New("model-llm ` + "```" + ` `))) -// llmGenerateTestForFilePrompt returns the prompt for generating an LLM test generation. -func llmGenerateTestForFilePrompt(data *llmSourceFilePromptContext) (message string, err error) { +// Format returns the prompt for generating an LLM test generation. +func (ctx *llmWriteTestSourceFilePromptContext) Format() (message string, err error) { // Use Linux paths even when running the evaluation on Windows to ensure consistency in prompting. 
- data.FilePath = filepath.ToSlash(data.FilePath) - data.Code = strings.TrimSpace(data.Code) + ctx.FilePath = filepath.ToSlash(ctx.FilePath) + ctx.Code = strings.TrimSpace(ctx.Code) var b strings.Builder - if err := llmGenerateTestForFilePromptTemplate.Execute(&b, data); err != nil { + if err := llmWriteTestForFilePromptTemplate.Execute(&b, ctx); err != nil { return "", pkgerrors.WithStack(err) } @@ -123,14 +129,14 @@ var llmCodeRepairSourceFilePromptTemplate = template.Must(template.New("model-ll - {{.}}{{ end }} `))) -// llmCodeRepairSourceFilePrompt returns the prompt to code repair a source file. -func llmCodeRepairSourceFilePrompt(data *llmCodeRepairSourceFilePromptContext) (message string, err error) { +// Format returns the prompt to code repair a source file. +func (ctx *llmCodeRepairSourceFilePromptContext) Format() (message string, err error) { // Use Linux paths even when running the evaluation on Windows to ensure consistency in prompting. - data.FilePath = filepath.ToSlash(data.FilePath) - data.Code = strings.TrimSpace(data.Code) + ctx.FilePath = filepath.ToSlash(ctx.FilePath) + ctx.Code = strings.TrimSpace(ctx.Code) var b strings.Builder - if err := llmCodeRepairSourceFilePromptTemplate.Execute(&b, data); err != nil { + if err := llmCodeRepairSourceFilePromptTemplate.Execute(&b, ctx); err != nil { return "", pkgerrors.WithStack(err) } @@ -164,15 +170,15 @@ var llmTranspileSourceFilePromptTemplate = template.Must(template.New("model-llm ` + "```" + ` `))) -// llmTranspileSourceFilePrompt returns the prompt to transpile a source file. -func llmTranspileSourceFilePrompt(data *llmTranspileSourceFilePromptContext) (message string, err error) { +// Format returns the prompt to transpile a source file. +func (ctx *llmTranspileSourceFilePromptContext) Format() (message string, err error) { // Use Linux paths even when running the evaluation on Windows to ensure consistency in prompting. 
- data.FilePath = filepath.ToSlash(data.FilePath) - data.Code = strings.TrimSpace(data.Code) - data.OriginFileContent = strings.TrimSpace(data.OriginFileContent) + ctx.FilePath = filepath.ToSlash(ctx.FilePath) + ctx.Code = strings.TrimSpace(ctx.Code) + ctx.OriginFileContent = strings.TrimSpace(ctx.OriginFileContent) var b strings.Builder - if err := llmTranspileSourceFilePromptTemplate.Execute(&b, data); err != nil { + if err := llmTranspileSourceFilePromptTemplate.Execute(&b, ctx); err != nil { return "", pkgerrors.WithStack(err) } @@ -198,13 +204,15 @@ func (m *Model) WriteTests(ctx model.Context) (assessment metrics.Assessments, e importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath) - request, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{ - Language: ctx.Language, + request, err := (&llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: ctx.Language, - Code: fileContent, - FilePath: ctx.FilePath, - ImportPath: importPath, - }) + Code: fileContent, + FilePath: ctx.FilePath, + ImportPath: importPath, + }, + }).Format() if err != nil { return nil, err } @@ -280,7 +288,7 @@ func (m *Model) RepairCode(ctx model.Context) (assessment metrics.Assessments, e importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath) - request, err := llmCodeRepairSourceFilePrompt(&llmCodeRepairSourceFilePromptContext{ + request, err := (&llmCodeRepairSourceFilePromptContext{ llmSourceFilePromptContext: llmSourceFilePromptContext{ Language: ctx.Language, @@ -290,7 +298,7 @@ func (m *Model) RepairCode(ctx model.Context) (assessment metrics.Assessments, e }, Mistakes: codeRepairArguments.Mistakes, - }) + }).Format() if err != nil { return nil, err } @@ -339,7 +347,7 @@ func (m *Model) Transpile(ctx model.Context) (assessment metrics.Assessments, er importPath := ctx.Language.ImportPath(ctx.RepositoryPath, ctx.FilePath) - request, err := 
llmTranspileSourceFilePrompt(&llmTranspileSourceFilePromptContext{ + request, err := (&llmTranspileSourceFilePromptContext{ llmSourceFilePromptContext: llmSourceFilePromptContext{ Language: ctx.Language, @@ -350,7 +358,7 @@ func (m *Model) Transpile(ctx model.Context) (assessment metrics.Assessments, er OriginLanguage: transpileArguments.OriginLanguage, OriginFileContent: originFileContent, - }) + }).Format() if err != nil { return nil, err } diff --git a/model/llm/llm_test.go b/model/llm/llm_test.go index 6963282b..2f519e5d 100644 --- a/model/llm/llm_test.go +++ b/model/llm/llm_test.go @@ -84,13 +84,15 @@ func TestModelGenerateTestsForFile(t *testing.T) { func main() {} ` sourceFilePath := "simple.go" - promptMessage, err := llmGenerateTestForFilePrompt(&llmSourceFilePromptContext{ - Language: &golang.Language{}, + promptMessage, err := (&llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &golang.Language{}, - Code: bytesutil.StringTrimIndentations(sourceFileContent), - FilePath: sourceFilePath, - ImportPath: "native", - }) + Code: bytesutil.StringTrimIndentations(sourceFileContent), + FilePath: sourceFilePath, + ImportPath: "native", + }, + }).Format() require.NoError(t, err) validate(t, &testCase{ Name: "Simple", @@ -287,18 +289,22 @@ func TestModelRepairSourceCodeFile(t *testing.T) { }) } -func TestLLMGenerateTestForFilePrompt(t *testing.T) { +type promptContext interface { + Format() (string, error) +} + +func TestFormatPromptContext(t *testing.T) { type testCase struct { Name string - Data *llmSourceFilePromptContext + Context promptContext ExpectedMessage string } validate := func(t *testing.T, tc *testCase) { t.Run(tc.Name, func(t *testing.T) { - actualMessage, actualErr := llmGenerateTestForFilePrompt(tc.Data) + actualMessage, actualErr := tc.Context.Format() require.NoError(t, actualErr) assert.Equal(t, tc.ExpectedMessage, actualMessage) @@ -306,20 +312,22 @@ func TestLLMGenerateTestForFilePrompt(t 
*testing.T) { } validate(t, &testCase{ - Name: "Plain", + Name: "Write Test", - Data: &llmSourceFilePromptContext{ - Language: &golang.Language{}, + Context: &llmWriteTestSourceFilePromptContext{ + llmSourceFilePromptContext: llmSourceFilePromptContext{ + Language: &golang.Language{}, - Code: bytesutil.StringTrimIndentations(` - package increment + Code: bytesutil.StringTrimIndentations(` + package increment - func increment(i int) int - return i + 1 - } - `), - FilePath: filepath.Join("path", "to", "increment.go"), - ImportPath: "increment", + func increment(i int) int + return i + 1 + } + `), + FilePath: filepath.Join("path", "to", "increment.go"), + ImportPath: "increment", + }, }, ExpectedMessage: bytesutil.StringTrimIndentations(` @@ -336,30 +344,11 @@ func TestLLMGenerateTestForFilePrompt(t *testing.T) { ` + "```" + ` `), }) -} - -func TestLLMCodeRepairSourceFilePrompt(t *testing.T) { - type testCase struct { - Name string - - Data *llmCodeRepairSourceFilePromptContext - - ExpectedMessage string - } - - validate := func(t *testing.T, tc *testCase) { - t.Run(tc.Name, func(t *testing.T) { - actualMessage, actualErr := llmCodeRepairSourceFilePrompt(tc.Data) - require.NoError(t, actualErr) - - assert.Equal(t, tc.ExpectedMessage, actualMessage) - }) - } validate(t, &testCase{ - Name: "Plain", + Name: "Code Repair", - Data: &llmCodeRepairSourceFilePromptContext{ + Context: &llmCodeRepairSourceFilePromptContext{ llmSourceFilePromptContext: llmSourceFilePromptContext{ Language: &golang.Language{}, @@ -398,30 +387,11 @@ func TestLLMCodeRepairSourceFilePrompt(t *testing.T) { - path/to/increment.go: missing return `), }) -} - -func TestLLMTranspileSourceFilePrompt(t *testing.T) { - type testCase struct { - Name string - - Data *llmTranspileSourceFilePromptContext - - ExpectedMessage string - } - - validate := func(t *testing.T, tc *testCase) { - t.Run(tc.Name, func(t *testing.T) { - actualMessage, actualErr := llmTranspileSourceFilePrompt(tc.Data) - require.NoError(t, 
actualErr) - - assert.Equal(t, tc.ExpectedMessage, actualMessage) - }) - } validate(t, &testCase{ - Name: "Transpile Go into Java", + Name: "Transpile", - Data: &llmTranspileSourceFilePromptContext{ + Context: &llmTranspileSourceFilePromptContext{ llmSourceFilePromptContext: llmSourceFilePromptContext{ Language: &java.Language{},