Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions pkg/cli/git.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"path/filepath"
"strings"

"github.com/charmbracelet/huh"
"github.com/githubnext/gh-aw/pkg/console"
"github.com/githubnext/gh-aw/pkg/logger"
)

Expand Down Expand Up @@ -646,3 +648,121 @@ func commitAndPushChanges(commitMessage string, verbose bool) error {
gitLog.Print("Commit and push workflow completed successfully")
return nil
}

// getDefaultBranch gets the default branch name for the repository
func getDefaultBranch() (string, error) {
gitLog.Print("Getting default branch name")

// Get repository slug (owner/repo)
repoSlug := getRepositorySlugFromRemote()
if repoSlug == "" {
gitLog.Print("No remote repository configured, cannot determine default branch")
return "", fmt.Errorf("no remote repository configured")
}

// Parse owner and repo from slug
parts := strings.Split(repoSlug, "/")
if len(parts) != 2 {
gitLog.Printf("Invalid repository slug format: %s", repoSlug)
return "", fmt.Errorf("invalid repository slug format: %s", repoSlug)
}

owner, repo := parts[0], parts[1]

// Use gh CLI to get default branch from GitHub API
cmd := exec.Command("gh", "api", fmt.Sprintf("/repos/%s/%s", owner, repo), "--jq", ".default_branch")
output, err := cmd.Output()
if err != nil {
gitLog.Printf("Failed to get default branch: %v", err)
return "", fmt.Errorf("failed to get default branch: %w", err)
}

defaultBranch := strings.TrimSpace(string(output))
if defaultBranch == "" {
gitLog.Print("Empty default branch returned")
return "", fmt.Errorf("could not determine default branch")
}

gitLog.Printf("Default branch: %s", defaultBranch)
return defaultBranch, nil
}

// checkOnDefaultBranch checks if the current branch is the default branch
// Returns an error if no remote is configured or if not on the default branch
func checkOnDefaultBranch(verbose bool) error {
gitLog.Print("Checking if on default branch")

// Get current branch
currentBranch, err := getCurrentBranch()
if err != nil {
return fmt.Errorf("failed to get current branch: %w", err)
}

// Get default branch
defaultBranch, err := getDefaultBranch()
if err != nil {
// If no remote is configured, fail the push operation
if strings.Contains(err.Error(), "no remote repository configured") {
gitLog.Print("No remote configured, cannot push")
return fmt.Errorf("--push requires a remote repository to be configured")
}
return fmt.Errorf("failed to get default branch: %w", err)
}

// Compare branches
if currentBranch != defaultBranch {
gitLog.Printf("Not on default branch: current=%s, default=%s", currentBranch, defaultBranch)
return fmt.Errorf("not on default branch: current branch is '%s', default branch is '%s'", currentBranch, defaultBranch)
}

gitLog.Printf("On default branch: %s", currentBranch)
if verbose {
fmt.Fprintln(os.Stderr, console.FormatSuccessMessage(fmt.Sprintf("✓ On default branch: %s", currentBranch)))
}
return nil
}

// confirmPushOperation prompts the user to confirm push operation (skips in CI)
func confirmPushOperation(verbose bool) error {
gitLog.Print("Checking if user confirmation is needed for push operation")

// Skip confirmation in CI environments
if IsRunningInCI() {
gitLog.Print("Running in CI, skipping user confirmation")
if verbose {
fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Running in CI - skipping confirmation prompt"))
}
return nil
}

// Prompt user for confirmation
gitLog.Print("Prompting user for push confirmation")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, console.FormatWarningMessage("This will commit and push changes to the remote repository."))

var confirmed bool
form := huh.NewForm(
huh.NewGroup(
huh.NewConfirm().
Title("Do you want to proceed with commit and push?").
Description("This will stage all changes, commit them, and push to the remote repository").
Value(&confirmed),
),
).WithAccessible(console.IsAccessibleMode())

if err := form.Run(); err != nil {
gitLog.Printf("Confirmation prompt failed: %v", err)
return fmt.Errorf("confirmation prompt failed: %w", err)
}

if !confirmed {
gitLog.Print("User declined push operation")
return fmt.Errorf("push operation cancelled by user")
}

gitLog.Print("User confirmed push operation")
if verbose {
fmt.Fprintln(os.Stderr, console.FormatSuccessMessage("✓ Push operation confirmed"))
}
return nil
}
84 changes: 84 additions & 0 deletions pkg/cli/git_push_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package cli

import (
"os"
"os/exec"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestGetDefaultBranch tests the getDefaultBranch function
func TestGetDefaultBranch(t *testing.T) {
t.Run("no remote configured", func(t *testing.T) {
// Create a temporary directory for test
tmpDir := t.TempDir()
originalDir, _ := os.Getwd()
defer os.Chdir(originalDir)

// Initialize git repository without remote
os.Chdir(tmpDir)
exec.Command("git", "init").Run()
exec.Command("git", "config", "user.email", "test@example.com").Run()
exec.Command("git", "config", "user.name", "Test User").Run()

// Should fail because no remote is configured
_, err := getDefaultBranch()
require.Error(t, err, "Should fail when no remote is configured")
assert.Contains(t, err.Error(), "no remote repository configured")
})
}

// TestCheckOnDefaultBranch tests the checkOnDefaultBranch function
func TestCheckOnDefaultBranch(t *testing.T) {
t.Run("no remote configured - should fail", func(t *testing.T) {
// Create a temporary directory for test
tmpDir := t.TempDir()
originalDir, _ := os.Getwd()
defer os.Chdir(originalDir)

// Initialize git repository without remote
os.Chdir(tmpDir)
exec.Command("git", "init").Run()
exec.Command("git", "config", "user.email", "test@example.com").Run()
exec.Command("git", "config", "user.name", "Test User").Run()

// Create an initial commit
testFile := filepath.Join(tmpDir, "test.txt")
err := os.WriteFile(testFile, []byte("test"), 0644)
require.NoError(t, err)
exec.Command("git", "add", "test.txt").Run()
exec.Command("git", "commit", "-m", "initial commit").Run()

// Should fail when no remote is configured
err = checkOnDefaultBranch(false)
require.Error(t, err, "Should fail when no remote is configured")
assert.Contains(t, err.Error(), "--push requires a remote repository to be configured")
})
}

// TestConfirmPushOperation tests the confirmPushOperation function
func TestConfirmPushOperation(t *testing.T) {
t.Run("skips confirmation in CI", func(t *testing.T) {
// Set CI environment variable
origCI := os.Getenv("CI")
os.Setenv("CI", "true")
defer func() {
if origCI == "" {
os.Unsetenv("CI")
} else {
os.Setenv("CI", origCI)
}
}()

// Should succeed without prompting user
err := confirmPushOperation(false)
assert.NoError(t, err, "Should skip confirmation in CI")
})

// Note: Testing the interactive prompt outside CI is not feasible in automated tests
// as it requires user interaction. The function behavior in non-CI environments
// should be tested manually.
}
14 changes: 14 additions & 0 deletions pkg/cli/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,20 @@ func InitRepository(verbose bool, mcp bool, campaign bool, tokens bool, engine s
if push {
initLog.Print("Push enabled - preparing to commit and push changes")
fmt.Fprintln(os.Stderr, "")

// Check if we're on the default branch
fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Checking current branch..."))
if err := checkOnDefaultBranch(verbose); err != nil {
initLog.Printf("Default branch check failed: %v", err)
return fmt.Errorf("cannot push: %w", err)
}

// Confirm with user (skip in CI)
if err := confirmPushOperation(verbose); err != nil {
initLog.Printf("Push operation not confirmed: %v", err)
return fmt.Errorf("push operation cancelled: %w", err)
}

fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Preparing to commit and push changes..."))

// Use the helper function to orchestrate the full workflow
Expand Down
14 changes: 14 additions & 0 deletions pkg/cli/upgrade_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,20 @@ func runUpgradeCommand(verbose bool, workflowDir string, noFix bool, noCompile b
if push {
upgradeLog.Print("Push enabled - preparing to commit and push changes")
fmt.Fprintln(os.Stderr, "")

// Check if we're on the default branch
fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Checking current branch..."))
if err := checkOnDefaultBranch(verbose); err != nil {
upgradeLog.Printf("Default branch check failed: %v", err)
return fmt.Errorf("cannot push: %w", err)
}

// Confirm with user (skip in CI)
if err := confirmPushOperation(verbose); err != nil {
upgradeLog.Printf("Push operation not confirmed: %v", err)
return fmt.Errorf("push operation cancelled: %w", err)
}

fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Preparing to commit and push changes..."))

// Use the helper function to orchestrate the full workflow
Expand Down
7 changes: 4 additions & 3 deletions pkg/cli/upgrade_command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ This workflow is already up to date.
exec.Command("git", "commit", "-m", "Add agent files").Run()
}

// Run upgrade command with --push (should succeed but not create a new commit)
// Run upgrade command with --push (should fail because no remote is configured)
config := UpgradeConfig{
Verbose: false,
NoFix: true, // Skip codemods to avoid changes
Expand All @@ -442,6 +442,7 @@ This workflow is already up to date.
}

err = RunUpgrade(config)
// Should succeed even if no changes to commit
require.NoError(t, err, "Upgrade with --push should succeed when working directory is clean")
// Should fail because no remote is configured
require.Error(t, err, "Upgrade with --push should fail when no remote is configured")
assert.Contains(t, err.Error(), "--push requires a remote repository to be configured")
}