diff --git a/pkg/cli/git.go b/pkg/cli/git.go index 9a9ce64860..61d039f0b3 100644 --- a/pkg/cli/git.go +++ b/pkg/cli/git.go @@ -7,6 +7,8 @@ import ( "path/filepath" "strings" + "github.com/charmbracelet/huh" + "github.com/githubnext/gh-aw/pkg/console" "github.com/githubnext/gh-aw/pkg/logger" ) @@ -646,3 +648,121 @@ func commitAndPushChanges(commitMessage string, verbose bool) error { gitLog.Print("Commit and push workflow completed successfully") return nil } + +// getDefaultBranch gets the default branch name for the repository +func getDefaultBranch() (string, error) { + gitLog.Print("Getting default branch name") + + // Get repository slug (owner/repo) + repoSlug := getRepositorySlugFromRemote() + if repoSlug == "" { + gitLog.Print("No remote repository configured, cannot determine default branch") + return "", fmt.Errorf("no remote repository configured") + } + + // Parse owner and repo from slug + parts := strings.Split(repoSlug, "/") + if len(parts) != 2 { + gitLog.Printf("Invalid repository slug format: %s", repoSlug) + return "", fmt.Errorf("invalid repository slug format: %s", repoSlug) + } + + owner, repo := parts[0], parts[1] + + // Use gh CLI to get default branch from GitHub API + cmd := exec.Command("gh", "api", fmt.Sprintf("/repos/%s/%s", owner, repo), "--jq", ".default_branch") + output, err := cmd.Output() + if err != nil { + gitLog.Printf("Failed to get default branch: %v", err) + return "", fmt.Errorf("failed to get default branch: %w", err) + } + + defaultBranch := strings.TrimSpace(string(output)) + if defaultBranch == "" { + gitLog.Print("Empty default branch returned") + return "", fmt.Errorf("could not determine default branch") + } + + gitLog.Printf("Default branch: %s", defaultBranch) + return defaultBranch, nil +} + +// checkOnDefaultBranch checks if the current branch is the default branch +// Returns an error if no remote is configured or if not on the default branch +func checkOnDefaultBranch(verbose bool) error { + gitLog.Print("Checking if on default branch") + + // Get current branch + currentBranch, err := getCurrentBranch() + if err != nil { + return fmt.Errorf("failed to get current branch: %w", err) + } + + // Get default branch + defaultBranch, err := getDefaultBranch() + if err != nil { + // If no remote is configured, fail the push operation + if strings.Contains(err.Error(), "no remote repository configured") { + gitLog.Print("No remote configured, cannot push") + return fmt.Errorf("--push requires a remote repository to be configured") + } + return fmt.Errorf("failed to get default branch: %w", err) + } + + // Compare branches + if currentBranch != defaultBranch { + gitLog.Printf("Not on default branch: current=%s, default=%s", currentBranch, defaultBranch) + return fmt.Errorf("not on default branch: current branch is '%s', default branch is '%s'", currentBranch, defaultBranch) + } + + gitLog.Printf("On default branch: %s", currentBranch) + if verbose { + fmt.Fprintln(os.Stderr, console.FormatSuccessMessage(fmt.Sprintf("✓ On default branch: %s", currentBranch))) + } + return nil +} + +// confirmPushOperation prompts the user to confirm push operation (skips in CI) +func confirmPushOperation(verbose bool) error { + gitLog.Print("Checking if user confirmation is needed for push operation") + + // Skip confirmation in CI environments + if IsRunningInCI() { + gitLog.Print("Running in CI, skipping user confirmation") + if verbose { + fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Running in CI - skipping confirmation prompt")) + } + return nil + } + + // Prompt user for confirmation + gitLog.Print("Prompting user for push confirmation") + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, console.FormatWarningMessage("This will commit and push changes to the remote repository.")) + + var confirmed bool + form := huh.NewForm( + huh.NewGroup( + huh.NewConfirm(). + Title("Do you want to proceed with commit and push?"). + Description("This will stage all changes, commit them, and push to the remote repository"). + Value(&confirmed), + ), + ).WithAccessible(console.IsAccessibleMode()) + + if err := form.Run(); err != nil { + gitLog.Printf("Confirmation prompt failed: %v", err) + return fmt.Errorf("confirmation prompt failed: %w", err) + } + + if !confirmed { + gitLog.Print("User declined push operation") + return fmt.Errorf("push operation cancelled by user") + } + + gitLog.Print("User confirmed push operation") + if verbose { + fmt.Fprintln(os.Stderr, console.FormatSuccessMessage("✓ Push operation confirmed")) + } + return nil +} diff --git a/pkg/cli/git_push_test.go b/pkg/cli/git_push_test.go new file mode 100644 index 0000000000..6182396f37 --- /dev/null +++ b/pkg/cli/git_push_test.go @@ -0,0 +1,84 @@ +package cli + +import ( + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestGetDefaultBranch tests the getDefaultBranch function +func TestGetDefaultBranch(t *testing.T) { + t.Run("no remote configured", func(t *testing.T) { + // Create a temporary directory for test + tmpDir := t.TempDir() + originalDir, _ := os.Getwd() + defer os.Chdir(originalDir) + + // Initialize git repository without remote + os.Chdir(tmpDir) + exec.Command("git", "init").Run() + exec.Command("git", "config", "user.email", "test@example.com").Run() + exec.Command("git", "config", "user.name", "Test User").Run() + + // Should fail because no remote is configured + _, err := getDefaultBranch() + require.Error(t, err, "Should fail when no remote is configured") + assert.Contains(t, err.Error(), "no remote repository configured") + }) +} + +// TestCheckOnDefaultBranch tests the checkOnDefaultBranch function +func TestCheckOnDefaultBranch(t *testing.T) { + t.Run("no remote configured - should fail", func(t *testing.T) { + // Create a temporary directory for test + tmpDir := t.TempDir() + originalDir, _ := os.Getwd() + defer os.Chdir(originalDir) + + // Initialize git repository without remote + os.Chdir(tmpDir) + exec.Command("git", "init").Run() + exec.Command("git", "config", "user.email", "test@example.com").Run() + exec.Command("git", "config", "user.name", "Test User").Run() + + // Create an initial commit + testFile := filepath.Join(tmpDir, "test.txt") + err := os.WriteFile(testFile, []byte("test"), 0644) + require.NoError(t, err) + exec.Command("git", "add", "test.txt").Run() + exec.Command("git", "commit", "-m", "initial commit").Run() + + // Should fail when no remote is configured + err = checkOnDefaultBranch(false) + require.Error(t, err, "Should fail when no remote is configured") + assert.Contains(t, err.Error(), "--push requires a remote repository to be configured") + }) +} + +// TestConfirmPushOperation tests the confirmPushOperation function +func TestConfirmPushOperation(t *testing.T) { + t.Run("skips confirmation in CI", func(t *testing.T) { + // Set CI environment variable + origCI := os.Getenv("CI") + os.Setenv("CI", "true") + defer func() { + if origCI == "" { + os.Unsetenv("CI") + } else { + os.Setenv("CI", origCI) + } + }() + + // Should succeed without prompting user + err := confirmPushOperation(false) + assert.NoError(t, err, "Should skip confirmation in CI") + }) + + // Note: Testing the interactive prompt outside CI is not feasible in automated tests + // as it requires user interaction. The function behavior in non-CI environments + // should be tested manually. +} diff --git a/pkg/cli/init.go b/pkg/cli/init.go index 59b574c9f5..c68075f4db 100644 --- a/pkg/cli/init.go +++ b/pkg/cli/init.go @@ -693,6 +693,20 @@ func InitRepository(verbose bool, mcp bool, campaign bool, tokens bool, engine s if push { initLog.Print("Push enabled - preparing to commit and push changes") fmt.Fprintln(os.Stderr, "") + + // Check if we're on the default branch + fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Checking current branch...")) + if err := checkOnDefaultBranch(verbose); err != nil { + initLog.Printf("Default branch check failed: %v", err) + return fmt.Errorf("cannot push: %w", err) + } + + // Confirm with user (skip in CI) + if err := confirmPushOperation(verbose); err != nil { + initLog.Printf("Push operation not confirmed: %v", err) + return fmt.Errorf("push operation cancelled: %w", err) + } + fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Preparing to commit and push changes...")) // Use the helper function to orchestrate the full workflow diff --git a/pkg/cli/upgrade_command.go b/pkg/cli/upgrade_command.go index 3105ac471d..9fe27358b0 100644 --- a/pkg/cli/upgrade_command.go +++ b/pkg/cli/upgrade_command.go @@ -182,6 +182,20 @@ func runUpgradeCommand(verbose bool, workflowDir string, noFix bool, noCompile b if push { upgradeLog.Print("Push enabled - preparing to commit and push changes") fmt.Fprintln(os.Stderr, "") + + // Check if we're on the default branch + fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Checking current branch...")) + if err := checkOnDefaultBranch(verbose); err != nil { + upgradeLog.Printf("Default branch check failed: %v", err) + return fmt.Errorf("cannot push: %w", err) + } + + // Confirm with user (skip in CI) + if err := confirmPushOperation(verbose); err != nil { + upgradeLog.Printf("Push operation not confirmed: %v", err) + return fmt.Errorf("push operation cancelled: %w", err) + } + fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Preparing to commit and push changes...")) // Use the helper function to orchestrate the full workflow diff --git a/pkg/cli/upgrade_command_test.go b/pkg/cli/upgrade_command_test.go index abb961397c..7046657472 100644 --- a/pkg/cli/upgrade_command_test.go +++ b/pkg/cli/upgrade_command_test.go @@ -433,7 +433,7 @@ This workflow is already up to date. exec.Command("git", "commit", "-m", "Add agent files").Run() } - // Run upgrade command with --push (should succeed but not create a new commit) + // Run upgrade command with --push (should fail because no remote is configured) config := UpgradeConfig{ Verbose: false, NoFix: true, // Skip codemods to avoid changes @@ -442,6 +442,7 @@ This workflow is already up to date. } err = RunUpgrade(config) - // Should succeed even if no changes to commit - require.NoError(t, err, "Upgrade with --push should succeed when working directory is clean") + // Should fail because no remote is configured + require.Error(t, err, "Upgrade with --push should fail when no remote is configured") + assert.Contains(t, err.Error(), "--push requires a remote repository to be configured") }