diff --git a/pkg/cli/fix_command.go b/pkg/cli/fix_command.go index 225a328ed9..27829ad0db 100644 --- a/pkg/cli/fix_command.go +++ b/pkg/cli/fix_command.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "os" "path/filepath" "strings" @@ -28,6 +29,12 @@ func RunFix(config FixConfig) error { return runFixCommand(config.WorkflowIDs, config.Write, config.Verbose, config.WorkflowDir) } +// RunFixWithWriter runs the fix command with the given configuration, +// writing output to the provided writer +func RunFixWithWriter(w io.Writer, config FixConfig) error { + return runFixCommandWithWriter(w, config.WorkflowIDs, config.Write, config.Verbose, config.WorkflowDir) +} + // NewFixCommand creates the fix command func NewFixCommand() *cobra.Command { cmd := &cobra.Command{ @@ -116,6 +123,12 @@ func listAvailableCodemods() error { // runFixCommand runs the fix command on specified or all workflows func runFixCommand(workflowIDs []string, write bool, verbose bool, workflowDir string) error { + return runFixCommandWithWriter(os.Stderr, workflowIDs, write, verbose, workflowDir) +} + +// runFixCommandWithWriter runs the fix command on specified or all workflows, +// writing output to the provided writer +func runFixCommandWithWriter(w io.Writer, workflowIDs []string, write bool, verbose bool, workflowDir string) error { fixLog.Printf("Running fix command: workflowIDs=%v, write=%v, verbose=%v, workflowDir=%s", workflowIDs, write, verbose, workflowDir) // Set up workflow directory (using default if not specified) @@ -149,7 +162,7 @@ func runFixCommand(workflowIDs []string, write bool, verbose bool, workflowDir s } if len(files) == 0 { - fmt.Fprintln(os.Stderr, console.FormatInfoMessage("No workflow files found.")) + fmt.Fprintln(w, console.FormatInfoMessage("No workflow files found.")) return nil } @@ -165,9 +178,9 @@ func runFixCommand(workflowIDs []string, write bool, verbose bool, workflowDir s for _, file := range files { fixLog.Printf("Processing file: %s", file) - fixed, appliedFixes, err := processWorkflowFileWithInfo(file, codemods, write, verbose) + fixed, appliedFixes, err := processWorkflowFileWithInfo(w, file, codemods, write, verbose) if err != nil { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatErrorMessage(fmt.Sprintf("Error processing %s: %v", filepath.Base(file), err))) + fmt.Fprintf(w, "%s\n", console.FormatErrorMessage(fmt.Sprintf("Error processing %s: %v", filepath.Base(file), err))) continue } @@ -190,49 +203,49 @@ func runFixCommand(workflowIDs []string, write bool, verbose bool, workflowDir s // Update copilot instructions if err := ensureCopilotInstructions(verbose, false); err != nil { fixLog.Printf("Failed to update copilot instructions: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update copilot instructions: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update copilot instructions: %v", err))) } // Update dispatcher agent if err := ensureAgenticWorkflowsDispatcher(verbose, false); err != nil { fixLog.Printf("Failed to update dispatcher agent: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update dispatcher agent: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update dispatcher agent: %v", err))) } // Update create workflow prompt if err := ensureCreateWorkflowPrompt(verbose, false); err != nil { fixLog.Printf("Failed to update create workflow prompt: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update workflow creation prompt: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update workflow creation prompt: %v", err))) } // Update update workflow prompt if err := ensureUpdateWorkflowPrompt(verbose, false); err != nil { fixLog.Printf("Failed to update update workflow prompt: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update workflow update prompt: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update workflow update prompt: %v", err))) } // Update create shared agentic workflow prompt if err := ensureCreateSharedAgenticWorkflowPrompt(verbose, false); err != nil { fixLog.Printf("Failed to update create shared workflow prompt: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update shared workflow creation prompt: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update shared workflow creation prompt: %v", err))) } // Update debug workflow prompt if err := ensureDebugWorkflowPrompt(verbose, false); err != nil { fixLog.Printf("Failed to update debug workflow prompt: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update debug workflow prompt: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update debug workflow prompt: %v", err))) } // Update upgrade agentic workflows prompt if err := ensureUpgradeAgenticWorkflowsPrompt(verbose, false); err != nil { fixLog.Printf("Failed to update upgrade workflows prompt: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update upgrade workflow prompt: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update upgrade workflow prompt: %v", err))) } // Update Serena tool documentation if err := ensureSerenaTool(verbose, false); err != nil { fixLog.Printf("Failed to update Serena tool documentation: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update Serena tool documentation: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to update Serena tool documentation: %v", err))) } // Delete old template files from pkg/cli/templates/ (only with --write) @@ -240,7 +253,7 @@ func runFixCommand(workflowIDs []string, write bool, verbose bool, workflowDir s fixLog.Print("Cleaning up old template files") if err := deleteOldTemplateFiles(verbose); err != nil { fixLog.Printf("Failed to delete old template files: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to delete old template files: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to delete old template files: %v", err))) } } @@ -249,7 +262,7 @@ func runFixCommand(workflowIDs []string, write bool, verbose bool, workflowDir s fixLog.Print("Deleting old agent files") if err := deleteOldAgentFiles(verbose); err != nil { fixLog.Printf("Failed to delete old agent files: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to delete old agent files: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to delete old agent files: %v", err))) } } @@ -260,43 +273,43 @@ func runFixCommand(workflowIDs []string, write bool, verbose bool, workflowDir s if write { if err := os.Remove(schemaPath); err != nil { fixLog.Printf("Failed to delete schema file: %v", err) - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to delete deprecated schema file: %v", err))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("Warning: Failed to delete deprecated schema file: %v", err))) } else { fixLog.Print("Deleted deprecated schema file") if verbose { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatSuccessMessage("Deleted deprecated .github/aw/schemas/agentic-workflow.json")) + fmt.Fprintf(w, "%s\n", console.FormatSuccessMessage("Deleted deprecated .github/aw/schemas/agentic-workflow.json")) } } } else { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatInfoMessage("Would delete deprecated .github/aw/schemas/agentic-workflow.json")) + fmt.Fprintf(w, "%s\n", console.FormatInfoMessage("Would delete deprecated .github/aw/schemas/agentic-workflow.json")) } } // Print summary - fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(w, "") if write { if totalFixed > 0 { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatSuccessMessage(fmt.Sprintf("✓ Fixed %d of %d workflow files", totalFixed, totalFiles))) + fmt.Fprintf(w, "%s\n", console.FormatSuccessMessage(fmt.Sprintf("✓ Fixed %d of %d workflow files", totalFixed, totalFiles))) } else { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatInfoMessage("✓ No fixes needed")) + fmt.Fprintf(w, "%s\n", console.FormatInfoMessage("✓ No fixes needed")) } } else { if totalFixed > 0 { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatInfoMessage(fmt.Sprintf("Would fix %d of %d workflow files", totalFixed, totalFiles))) - fmt.Fprintln(os.Stderr, "") + fmt.Fprintf(w, "%s\n", console.FormatInfoMessage(fmt.Sprintf("Would fix %d of %d workflow files", totalFixed, totalFiles))) + fmt.Fprintln(w, "") // Output as agent prompt - fmt.Fprintln(os.Stderr, console.FormatInfoMessage("To fix these issues, run:")) - fmt.Fprintln(os.Stderr, "") - fmt.Fprintln(os.Stderr, " gh aw fix --write") - fmt.Fprintln(os.Stderr, "") - fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Or fix them individually:")) - fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(w, console.FormatInfoMessage("To fix these issues, run:")) + fmt.Fprintln(w, "") + fmt.Fprintln(w, " gh aw fix --write") + fmt.Fprintln(w, "") + fmt.Fprintln(w, console.FormatInfoMessage("Or fix them individually:")) + fmt.Fprintln(w, "") for _, wf := range workflowsNeedingFixes { - fmt.Fprintf(os.Stderr, " gh aw fix %s --write\n", strings.TrimSuffix(wf.File, ".md")) + fmt.Fprintf(w, " gh aw fix %s --write\n", strings.TrimSuffix(wf.File, ".md")) } } else { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatInfoMessage("✓ No fixes needed")) + fmt.Fprintf(w, "%s\n", console.FormatInfoMessage("✓ No fixes needed")) } } @@ -311,12 +324,12 @@ type workflowFixInfo struct { // processWorkflowFile processes a single workflow file with all codemods func processWorkflowFile(filePath string, codemods []Codemod, write bool, verbose bool) (bool, error) { - fixed, _, err := processWorkflowFileWithInfo(filePath, codemods, write, verbose) + fixed, _, err := processWorkflowFileWithInfo(os.Stderr, filePath, codemods, write, verbose) return fixed, err } // processWorkflowFileWithInfo processes a single workflow file and returns detailed fix information -func processWorkflowFileWithInfo(filePath string, codemods []Codemod, write bool, verbose bool) (bool, []string, error) { +func processWorkflowFileWithInfo(w io.Writer, filePath string, codemods []Codemod, write bool, verbose bool) (bool, []string, error) { fixLog.Printf("Processing workflow file: %s", filePath) // Read the file @@ -360,7 +373,7 @@ func processWorkflowFileWithInfo(filePath string, codemods []Codemod, write bool // If no changes, report and return if !hasChanges { if verbose { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatInfoMessage(fmt.Sprintf(" %s - no fixes needed", filepath.Base(filePath)))) + fmt.Fprintf(w, "%s\n", console.FormatInfoMessage(fmt.Sprintf(" %s - no fixes needed", filepath.Base(filePath)))) } return false, nil, nil } @@ -373,14 +386,14 @@ func processWorkflowFileWithInfo(filePath string, codemods []Codemod, write bool return false, nil, fmt.Errorf("failed to write file: %w", err) } - fmt.Fprintf(os.Stderr, "%s\n", console.FormatSuccessMessage(fmt.Sprintf("✓ %s", fileName))) + fmt.Fprintf(w, "%s\n", console.FormatSuccessMessage(fmt.Sprintf("✓ %s", fileName))) for _, codemodName := range appliedCodemods { - fmt.Fprintf(os.Stderr, " • %s\n", codemodName) + fmt.Fprintf(w, " • %s\n", codemodName) } } else { - fmt.Fprintf(os.Stderr, "%s\n", console.FormatWarningMessage(fmt.Sprintf("⚠ %s", fileName))) + fmt.Fprintf(w, "%s\n", console.FormatWarningMessage(fmt.Sprintf("⚠ %s", fileName))) for _, codemodName := range appliedCodemods { - fmt.Fprintf(os.Stderr, " • %s\n", codemodName) + fmt.Fprintf(w, " • %s\n", codemodName) } } diff --git a/pkg/cli/mcp_server.go b/pkg/cli/mcp_server.go index fc954e7a46..4bebaaf20c 100644 --- a/pkg/cli/mcp_server.go +++ b/pkg/cli/mcp_server.go @@ -1,6 +1,7 @@ package cli import ( + "bytes" "context" "encoding/json" "fmt" @@ -870,35 +871,48 @@ Returns formatted text output showing: default: } - // Build command arguments - cmdArgs := []string{"update"} - - // Add workflow IDs if specified - cmdArgs = append(cmdArgs, args.Workflows...) - - // Add optional flags - if args.Major { - cmdArgs = append(cmdArgs, "--major") - } - if args.Force { - cmdArgs = append(cmdArgs, "--force") - } - - // Execute the CLI command - cmd := execCmd(ctx, cmdArgs...) - output, err := cmd.CombinedOutput() + mcpLog.Printf("Executing update tool: workflows=%v, major=%v, force=%v", args.Workflows, args.Major, args.Force) + + // Call the update function directly instead of spawning subprocess + // Use a bytes.Buffer to capture output for the MCP response + var outputBuf bytes.Buffer + + // Note: The update function requires full parameters, so we set defaults for unspecified ones: + // - verbose: false (MCP server doesn't need verbose output) + // - engineOverride: "" (no engine override) + // - createPR: false (don't create PR from MCP server) + // - workflowsDir: "" (use default .github/workflows) + // - noStopAfter: false (keep stop-after fields) + // - stopAfter: "" (don't override stop-after) + // - merge: false (don't merge, just update) + // - noActions: false (update actions by default) + outputStr, err := UpdateWorkflowsWithExtensionCheckContext( + ctx, + &outputBuf, + args.Workflows, + args.Major, // allowMajor + args.Force, // force + false, // verbose + "", // engineOverride + false, // createPR + "", // workflowsDir + false, // noStopAfter + "", // stopAfter + false, // merge + false, // noActions + ) if err != nil { return nil, nil, &jsonrpc.Error{ Code: jsonrpc.CodeInternalError, Message: "failed to update workflows", - Data: mcpErrorData(map[string]any{"error": err.Error(), "output": string(output)}), + Data: mcpErrorData(map[string]any{"error": err.Error(), "output": outputStr}), } } return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: string(output)}, + &mcp.TextContent{Text: outputStr}, }, }, nil, nil }) diff --git a/pkg/cli/update_actions.go b/pkg/cli/update_actions.go index 226cf831de..e27c156d25 100644 --- a/pkg/cli/update_actions.go +++ b/pkg/cli/update_actions.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -152,6 +153,131 @@ func UpdateActions(allowMajor, verbose bool) error { return nil } +// UpdateActionsWithWriter updates GitHub Actions versions in .github/aw/actions-lock.json, +// writing output to the provided writer +func UpdateActionsWithWriter(w io.Writer, allowMajor, verbose bool) error { + updateLog.Print("Starting action updates with writer") + + if verbose { + fmt.Fprintln(w, console.FormatInfoMessage("Checking for GitHub Actions updates...")) + } + + // Get the path to actions-lock.json + actionsLockPath := filepath.Join(".github", "aw", "actions-lock.json") + + // Check if the file exists + if _, err := os.Stat(actionsLockPath); os.IsNotExist(err) { + if verbose { + fmt.Fprintln(w, console.FormatVerboseMessage(fmt.Sprintf("Actions lock file not found: %s", actionsLockPath))) + } + return nil // Not an error, just skip + } + + // Load the current actions lock file + data, err := os.ReadFile(actionsLockPath) + if err != nil { + return fmt.Errorf("failed to read actions lock file: %w", err) + } + + var actionsLock actionsLockFile + if err := json.Unmarshal(data, &actionsLock); err != nil { + return fmt.Errorf("failed to parse actions lock file: %w", err) + } + + updateLog.Printf("Loaded %d action entries from actions-lock.json", len(actionsLock.Entries)) + + // Track updates + var updatedActions []string + var failedActions []string + var skippedActions []string + + // Update each action + for key, entry := range actionsLock.Entries { + updateLog.Printf("Checking action: %s@%s", entry.Repo, entry.Version) + + // Check for latest release + latestVersion, latestSHA, err := getLatestActionRelease(entry.Repo, entry.Version, allowMajor, verbose) + if err != nil { + if verbose { + fmt.Fprintln(w, console.FormatWarningMessage(fmt.Sprintf("Failed to check %s: %v", entry.Repo, err))) + } + failedActions = append(failedActions, entry.Repo) + continue + } + + // Check if update is available + if latestVersion == entry.Version && latestSHA == entry.SHA { + if verbose { + fmt.Fprintln(w, console.FormatVerboseMessage(fmt.Sprintf("%s@%s is up to date", entry.Repo, entry.Version))) + } + skippedActions = append(skippedActions, entry.Repo) + continue + } + + // Update the entry + updateLog.Printf("Updating %s from %s (%s) to %s (%s)", entry.Repo, entry.Version, entry.SHA[:7], latestVersion, latestSHA[:7]) + fmt.Fprintln(w, console.FormatSuccessMessage(fmt.Sprintf("Updated %s from %s to %s", entry.Repo, entry.Version, latestVersion))) + + // Delete the old key (which has the old version) + delete(actionsLock.Entries, key) + + // Create a new key with the new version + newKey := entry.Repo + "@" + latestVersion + actionsLock.Entries[newKey] = actionsLockEntry{ + Repo: entry.Repo, + Version: latestVersion, + SHA: latestSHA, + } + + updatedActions = append(updatedActions, entry.Repo) + } + + // Show summary + fmt.Fprintln(w, "") + + if len(updatedActions) > 0 { + fmt.Fprintln(w, console.FormatSuccessMessage(fmt.Sprintf("Updated %d action(s):", len(updatedActions)))) + for _, action := range updatedActions { + fmt.Fprintln(w, console.FormatListItem(action)) + } + fmt.Fprintln(w, "") + } + + if len(skippedActions) > 0 && verbose { + fmt.Fprintln(w, console.FormatInfoMessage(fmt.Sprintf("%d action(s) already up to date", len(skippedActions)))) + fmt.Fprintln(w, "") + } + + if len(failedActions) > 0 { + fmt.Fprintln(w, console.FormatWarningMessage(fmt.Sprintf("Failed to check %d action(s):", len(failedActions)))) + for _, action := range failedActions { + fmt.Fprintf(w, " %s\n", action) + } + fmt.Fprintln(w, "") + } + + // Save the updated actions lock file if there were any updates + if len(updatedActions) > 0 { + // Marshal with sorted keys and pretty printing + updatedData, err := marshalActionsLockSorted(&actionsLock) + if err != nil { + return fmt.Errorf("failed to marshal updated actions lock: %w", err) + } + + // Add trailing newline for prettier compliance + updatedData = append(updatedData, '\n') + + if err := os.WriteFile(actionsLockPath, updatedData, 0644); err != nil { + return fmt.Errorf("failed to write updated actions lock file: %w", err) + } + + updateLog.Printf("Successfully wrote updated actions-lock.json with %d updates", len(updatedActions)) + fmt.Fprintln(w, console.FormatInfoMessage("Updated actions-lock.json file")) + } + + return nil +} + // getLatestActionRelease gets the latest release for an action repository // It respects semantic versioning and the allowMajor flag func getLatestActionRelease(repo, currentVersion string, allowMajor, verbose bool) (string, string, error) { diff --git a/pkg/cli/update_command.go b/pkg/cli/update_command.go index a46d338233..d44c9439ea 100644 --- a/pkg/cli/update_command.go +++ b/pkg/cli/update_command.go @@ -1,7 +1,10 @@ package cli import ( + "bytes" + "context" "fmt" + "io" "os" "github.com/github/gh-aw/pkg/console" @@ -185,3 +188,122 @@ func UpdateWorkflowsWithExtensionCheck(workflowNames []string, allowMajor, force return nil } + +// UpdateWorkflowsWithExtensionCheckContext performs the complete update process with context support. +// It accepts a context for cancellation and an io.Writer for output capture. +// This function is designed for concurrent usage and does not modify global state. +// +// Parameters: +// - ctx: Context for cancellation support +// - output: Writer for capturing console output +// - workflowNames: List of workflow IDs to update (empty for all) +// - allowMajor: Allow major version updates +// - force: Force update even if no changes detected +// - verbose: Enable verbose output +// - engineOverride: Override the AI engine +// - createPR: Create a pull request with changes +// - workflowsDir: Workflow directory (empty for default) +// - noStopAfter: Remove stop-after field +// - stopAfter: Override stop-after value +// - merge: Merge local changes with upstream +// - noActions: Skip GitHub Actions updates +// +// Returns: +// - Output string containing formatted results +// - Error if the update process fails +func UpdateWorkflowsWithExtensionCheckContext( + ctx context.Context, + output io.Writer, + workflowNames []string, + allowMajor, force, verbose bool, + engineOverride string, + createPR bool, + workflowsDir string, + noStopAfter bool, + stopAfter string, + merge bool, + noActions bool, +) (string, error) { + updateLog.Printf("Starting update process with context: workflows=%v, allowMajor=%v, force=%v, createPR=%v, merge=%v, noActions=%v", workflowNames, allowMajor, force, createPR, merge, noActions) + + // Create a buffer to capture output + var buf bytes.Buffer + // Use a multi-writer to write to both the provided output and our buffer + w := io.MultiWriter(output, &buf) + + // Check for context cancellation + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + + // Step 1: Check for gh-aw extension updates + if err := checkExtensionUpdateWithWriter(w, verbose); err != nil { + return buf.String(), fmt.Errorf("extension update check failed: %w", err) + } + + // Check for context cancellation + select { + case <-ctx.Done(): + return buf.String(), ctx.Err() + default: + } + + // Step 2: Update GitHub Actions versions (unless disabled) + if !noActions { + if err := UpdateActionsWithWriter(w, allowMajor, verbose); err != nil { + return buf.String(), fmt.Errorf("action update failed: %w", err) + } + } + + // Check for context cancellation + select { + case <-ctx.Done(): + return buf.String(), ctx.Err() + default: + } + + // Step 3: Update workflows from source repositories + // Note: Each workflow is compiled immediately after update + if err := UpdateWorkflowsWithWriter(w, workflowNames, allowMajor, force, verbose, engineOverride, workflowsDir, noStopAfter, stopAfter, merge); err != nil { + return buf.String(), fmt.Errorf("workflow update failed: %w", err) + } + + // Check for context cancellation + select { + case <-ctx.Done(): + return buf.String(), ctx.Err() + default: + } + + // Step 4: Apply automatic fixes to updated workflows + fixConfig := FixConfig{ + WorkflowIDs: workflowNames, + Write: true, + Verbose: verbose, + } + if err := RunFixWithWriter(w, fixConfig); err != nil { + updateLog.Printf("Fix command failed (non-fatal): %v", err) + // Don't fail the update if fix fails - this is non-critical + if verbose { + fmt.Fprintln(w, console.FormatWarningMessage(fmt.Sprintf("Warning: automatic fixes failed: %v", err))) + } + } + + // Check for context cancellation + select { + case <-ctx.Done(): + return buf.String(), ctx.Err() + default: + } + + // Step 5: Optionally create PR if flag is set + if createPR { + if err := createUpdatePRWithWriter(w, verbose); err != nil { + return buf.String(), fmt.Errorf("failed to create PR: %w", err) + } + } + + return buf.String(), nil +} diff --git a/pkg/cli/update_command_context_test.go b/pkg/cli/update_command_context_test.go new file mode 100644 index 0000000000..0a4e9cb6a1 --- /dev/null +++ b/pkg/cli/update_command_context_test.go @@ -0,0 +1,185 @@ +//go:build !integration + +package cli + +import ( + "bytes" + "context" + "os" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestUpdateWorkflowsWithExtensionCheckContext_Cancellation tests that the function respects context cancellation +func TestUpdateWorkflowsWithExtensionCheckContext_Cancellation(t *testing.T) { + // Create a context that is already cancelled + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + var buf bytes.Buffer + + // Call the function with the cancelled context + output, err := UpdateWorkflowsWithExtensionCheckContext( + ctx, + &buf, + nil, // workflowNames + false, // allowMajor + false, // force + false, // verbose + "", // engineOverride + false, // createPR + "", // workflowsDir + false, // noStopAfter + "", // stopAfter + false, // merge + false, // noActions + ) + + // Should return context.Canceled error + require.Error(t, err, "Expected error when context is cancelled") + assert.Equal(t, context.Canceled, err, "Expected context.Canceled error") + assert.Empty(t, output, "Expected empty output when cancelled immediately") +} + +// TestUpdateWorkflowsWithExtensionCheckContext_Timeout tests that the function respects context timeout +func TestUpdateWorkflowsWithExtensionCheckContext_Timeout(t *testing.T) { + // Skip if running in CI without sufficient time + if testing.Short() { + t.Skip("Skipping timeout test in short mode") + } + + // Create a context with a very short timeout + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + // Wait a moment to ensure context times out + time.Sleep(10 * time.Millisecond) + + var buf bytes.Buffer + + // Call the function with the timed-out context + _, err := UpdateWorkflowsWithExtensionCheckContext( + ctx, + &buf, + nil, // workflowNames + false, // allowMajor + false, // force + false, // verbose + "", // engineOverride + false, // createPR + "", // workflowsDir + false, // noStopAfter + "", // stopAfter + false, // merge + false, // noActions + ) + + // Should return context.DeadlineExceeded error + require.Error(t, err, "Expected error when context times out") + assert.Contains(t, []error{context.DeadlineExceeded, context.Canceled}, err, "Expected context deadline/canceled error") +} + +// TestUpdateWorkflowsWithExtensionCheckContext_OutputCapture tests that output is captured correctly +func TestUpdateWorkflowsWithExtensionCheckContext_OutputCapture(t *testing.T) { + // Skip if we don't have gh CLI or not in a git repo + if !isGHCLIAvailable() { + t.Skip("Skipping test: gh CLI not available") + } + + ctx := context.Background() + var buf bytes.Buffer + + // Create a temporary directory to avoid "no such file" errors + tmpDir := t.TempDir() + originalDir, _ := os.Getwd() + defer os.Chdir(originalDir) + + // Change to temp dir but don't create workflows directory (test will fail gracefully) + os.Chdir(tmpDir) + + // Call with minimal parameters - this should at least try to check for updates + _, err := UpdateWorkflowsWithExtensionCheckContext( + ctx, + &buf, + []string{}, // workflowNames (empty means all workflows with source field) + false, // allowMajor + false, // force + false, // verbose + "", // engineOverride + false, // createPR + "", // workflowsDir + false, // noStopAfter + "", // stopAfter + false, // merge + true, // noActions (skip action updates to make test faster) + ) + + // The function may fail if there are no workflows directory/files, which is expected + // The important thing is that we tested the function can be called + if err != nil { + // Expected - likely no workflows directory or no workflows with source field + assert.True(t, strings.Contains(err.Error(), "no workflows found") || strings.Contains(err.Error(), "no such file"), + "Expected 'no workflows found' or 'no such file' error, got: %v", err) + } + + // Note: Buffer may be empty if error occurs early (before any output) + // That's OK - we're testing that the function works with a buffer +} + +// TestUpdateWorkflowsWithExtensionCheckContext_BufferWriting tests that the function writes to the provided buffer +func TestUpdateWorkflowsWithExtensionCheckContext_BufferWriting(t *testing.T) { + ctx := context.Background() + var buf bytes.Buffer + + // Create a custom writer that tracks writes + writeCount := 0 + trackingWriter := &trackingWriter{ + Writer: &buf, + onWrite: func(p []byte) { + writeCount++ + }, + } + + // Call with parameters that will trigger some output + _, err := UpdateWorkflowsWithExtensionCheckContext( + ctx, + trackingWriter, + []string{}, // Empty workflow names + false, // allowMajor + false, // force + true, // verbose (should generate more output) + "", // engineOverride + false, // createPR + "", // workflowsDir + false, // noStopAfter + "", // stopAfter + false, // merge + true, // noActions (skip action updates) + ) + + // We expect an error (no workflows with source) but also some output + if err == nil || !strings.Contains(err.Error(), "no workflows found") { + t.Logf("Unexpected error: %v", err) + } + + // Verify that writes occurred to our tracking writer + assert.Positive(t, writeCount, "Expected at least one write to the output buffer") + assert.NotEmpty(t, buf.String(), "Expected buffer to contain output") +} + +// trackingWriter wraps an io.Writer and tracks write operations +type trackingWriter struct { + Writer *bytes.Buffer + onWrite func([]byte) +} + +func (tw *trackingWriter) Write(p []byte) (n int, err error) { + if tw.onWrite != nil { + tw.onWrite(p) + } + return tw.Writer.Write(p) +} diff --git a/pkg/cli/update_display.go b/pkg/cli/update_display.go index b16833bb5e..aca28c7dce 100644 --- a/pkg/cli/update_display.go +++ b/pkg/cli/update_display.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "os" "github.com/github/gh-aw/pkg/console" @@ -9,23 +10,29 @@ import ( // showUpdateSummary displays a summary of workflow updates using console helpers func showUpdateSummary(successfulUpdates []string, failedUpdates []updateFailure) { - fmt.Fprintln(os.Stderr, "") + showUpdateSummaryWithWriter(os.Stderr, successfulUpdates, failedUpdates) +} + +// showUpdateSummaryWithWriter displays a summary of workflow updates using console helpers, +// writing output to the provided writer +func showUpdateSummaryWithWriter(w io.Writer, successfulUpdates []string, failedUpdates []updateFailure) { + fmt.Fprintln(w, "") // Show successful updates if len(successfulUpdates) > 0 { - fmt.Fprintln(os.Stderr, console.FormatSuccessMessage(fmt.Sprintf("Successfully updated and compiled %d workflow(s):", len(successfulUpdates)))) + fmt.Fprintln(w, console.FormatSuccessMessage(fmt.Sprintf("Successfully updated and compiled %d workflow(s):", len(successfulUpdates)))) for _, name := range successfulUpdates { - fmt.Fprintln(os.Stderr, console.FormatListItem(name)) + fmt.Fprintln(w, console.FormatListItem(name)) } - fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(w, "") } // Show failed updates if len(failedUpdates) > 0 { - fmt.Fprintln(os.Stderr, console.FormatErrorMessage(fmt.Sprintf("Failed to update %d workflow(s):", len(failedUpdates)))) + fmt.Fprintln(w, console.FormatErrorMessage(fmt.Sprintf("Failed to update %d workflow(s):", len(failedUpdates)))) for _, failure := range failedUpdates { - fmt.Fprintf(os.Stderr, " %s: %s\n", failure.Name, failure.Error) + fmt.Fprintf(w, " %s: %s\n", failure.Name, failure.Error) } - fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(w, "") } } diff --git a/pkg/cli/update_extension_check.go b/pkg/cli/update_extension_check.go index 9e779287dc..d5a9cf8cb4 100644 --- a/pkg/cli/update_extension_check.go +++ b/pkg/cli/update_extension_check.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "os" "strings" @@ -11,22 +12,28 @@ import ( // checkExtensionUpdate checks if a newer version of gh-aw is available func checkExtensionUpdate(verbose bool) error { + return checkExtensionUpdateWithWriter(os.Stderr, verbose) +} + +// checkExtensionUpdateWithWriter checks if a newer version of gh-aw is available, +// writing output to the provided writer +func checkExtensionUpdateWithWriter(w io.Writer, verbose bool) error { if verbose { - fmt.Fprintln(os.Stderr, console.FormatVerboseMessage("Checking for gh-aw extension updates...")) + fmt.Fprintln(w, console.FormatVerboseMessage("Checking for gh-aw extension updates...")) } // Run gh extension upgrade --dry-run to check for updates output, err := workflow.RunGHCombined("Checking for extension updates...", "extension", "upgrade", "github/gh-aw", "--dry-run") if err != nil { if verbose { - fmt.Fprintln(os.Stderr, console.FormatWarningMessage(fmt.Sprintf("Failed to check for extension updates: %v", err))) + fmt.Fprintln(w, console.FormatWarningMessage(fmt.Sprintf("Failed to check for extension updates: %v", err))) } return nil // Don't fail the whole command if update check fails } outputStr := strings.TrimSpace(string(output)) if verbose { - fmt.Fprintln(os.Stderr, console.FormatVerboseMessage(fmt.Sprintf("Extension update check output: %s", outputStr))) + fmt.Fprintln(w, console.FormatVerboseMessage(fmt.Sprintf("Extension update check output: %s", outputStr))) } // Parse the output to see if an update is available @@ -34,15 +41,15 @@ func checkExtensionUpdate(verbose bool) error { lines := strings.Split(outputStr, "\n") for _, line := range lines { if strings.Contains(line, "[agentics]: would have upgraded from") { - fmt.Fprintln(os.Stderr, console.FormatInfoMessage(line)) - fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Run 'gh extension upgrade github/gh-aw' to update")) + fmt.Fprintln(w, console.FormatInfoMessage(line)) + fmt.Fprintln(w, console.FormatInfoMessage("Run 'gh extension upgrade github/gh-aw' to update")) return nil } } if strings.Contains(outputStr, "✓ Successfully checked extension upgrades") { if verbose { - fmt.Fprintln(os.Stderr, console.FormatSuccessMessage("gh-aw extension is up to date")) + fmt.Fprintln(w, console.FormatSuccessMessage("gh-aw extension is up to date")) } } diff --git a/pkg/cli/update_git.go b/pkg/cli/update_git.go index 042de8050f..8eb21f5933 100644 --- a/pkg/cli/update_git.go +++ b/pkg/cli/update_git.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "math/rand" "os" "os/exec" @@ -33,6 +34,12 @@ func runGitCommand(args ...string) error { // createUpdatePR creates a pull request with the workflow changes func createUpdatePR(verbose bool) error { + return createUpdatePRWithWriter(os.Stderr, verbose) +} + +// createUpdatePRWithWriter creates a pull request with the workflow changes, +// writing output to the provided writer +func createUpdatePRWithWriter(w io.Writer, verbose bool) error { // Check if GitHub CLI is available if !isGHCLIAvailable() { return fmt.Errorf("GitHub CLI (gh) is required for PR creation but not found in PATH") @@ -45,12 +52,12 @@ func createUpdatePR(verbose bool) error { } if !hasChanges { - fmt.Fprintln(os.Stderr, console.FormatInfoMessage("No changes to create PR for")) + fmt.Fprintln(w, console.FormatInfoMessage("No changes to create PR for")) return nil } if verbose { - fmt.Fprintln(os.Stderr, console.FormatVerboseMessage("Creating pull request with workflow updates...")) + fmt.Fprintln(w, console.FormatVerboseMessage("Creating pull request with workflow updates...")) } // Create a branch name with timestamp @@ -86,8 +93,8 @@ func createUpdatePR(verbose bool) error { return fmt.Errorf("failed to create PR: %w\nOutput: %s", err, string(output)) } - fmt.Fprintln(os.Stderr, console.FormatSuccessMessage("Successfully created pull request")) - fmt.Fprintln(os.Stderr, console.FormatInfoMessage(strings.TrimSpace(string(output)))) + fmt.Fprintln(w, console.FormatSuccessMessage("Successfully created pull request")) + fmt.Fprintln(w, console.FormatInfoMessage(strings.TrimSpace(string(output)))) return nil } diff --git a/pkg/cli/update_workflows.go b/pkg/cli/update_workflows.go index 26bafbde08..a50d6c8fa5 100644 --- a/pkg/cli/update_workflows.go +++ b/pkg/cli/update_workflows.go @@ -2,6 +2,7 @@ package cli import ( "fmt" + "io" "os" "path/filepath" "strings" @@ -13,6 +14,12 @@ import ( // UpdateWorkflows updates workflows from their source repositories func UpdateWorkflows(workflowNames []string, allowMajor, force, verbose bool, engineOverride string, workflowsDir string, noStopAfter bool, stopAfter string, merge bool) error { + return UpdateWorkflowsWithWriter(os.Stderr, workflowNames, allowMajor, force, verbose, engineOverride, workflowsDir, noStopAfter, stopAfter, merge) +} + +// UpdateWorkflowsWithWriter updates workflows from their source repositories, +// writing output to the provided writer +func UpdateWorkflowsWithWriter(w io.Writer, workflowNames []string, allowMajor, force, verbose bool, engineOverride string, workflowsDir string, noStopAfter bool, stopAfter string, merge bool) error { updateLog.Printf("Scanning for workflows with source field: dir=%s, filter=%v, merge=%v", workflowsDir, workflowNames, merge) // Use provided workflows directory or default @@ -35,7 +42,7 @@ func UpdateWorkflows(workflowNames []string, allowMajor, force, verbose bool, en return fmt.Errorf("no workflows found with source field") } - fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Found %d workflow(s) to update", len(workflows)))) + fmt.Fprintln(w, console.FormatInfoMessage(fmt.Sprintf("Found %d workflow(s) to update", len(workflows)))) // Track update results var successfulUpdates []string @@ -54,7 +61,7 @@ func UpdateWorkflows(workflowNames []string, allowMajor, force, verbose bool, en } // Show summary - showUpdateSummary(successfulUpdates, failedUpdates) + showUpdateSummaryWithWriter(w, successfulUpdates, failedUpdates) if len(successfulUpdates) == 0 { return fmt.Errorf("no workflows were successfully updated")