From d93976441969a6d5939c523af134cacce7d26b1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20M=C3=BCllner?= Date: Thu, 3 Dec 2020 23:30:49 +0100 Subject: [PATCH] mr: Default to remote tracking branch We currently assume that the local branch name matches the remote one, which isn't necessarily the case. Address this by using the remote tracking name where appropriate. --- cmd/mr_create.go | 14 ++++++++++++-- cmd/util.go | 9 +++++++-- internal/git/git.go | 27 +++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/cmd/mr_create.go b/cmd/mr_create.go index 827ac064..47fba264 100644 --- a/cmd/mr_create.go +++ b/cmd/mr_create.go @@ -105,15 +105,25 @@ func runMRCreate(cmd *cobra.Command, args []string) { log.Fatal(err) } + remoteBranch, err := git.CurrentUpstreamBranch() + if remoteBranch == "" { + // Fall back to local branch + remoteBranch, err = git.CurrentBranch() + } + + if err != nil { + log.Fatal(err) + } + p, err := lab.FindProject(sourceProjectName) if err != nil { log.Fatal(err) } - if _, err := lab.GetCommit(p.ID, branch); err != nil { + if _, err := lab.GetCommit(p.ID, remoteBranch); err != nil { err = errors.Wrapf( err, "aborting MR, source branch %s not present on remote %s. did you forget to push?", - branch, sourceRemote) + remoteBranch, sourceRemote) log.Fatal(err) } diff --git a/cmd/util.go b/cmd/util.go index ab18266b..53a6105c 100644 --- a/cmd/util.go +++ b/cmd/util.go @@ -69,9 +69,14 @@ func flagConfig(fs *flag.FlagSet) { func getCurrentBranchMR(rn string) int { var num int = 0 - currentBranch, err := git.CurrentBranch() + currentBranch, err := git.CurrentUpstreamBranch() + if currentBranch == "" { + // Fall back to local branch + currentBranch, err = git.CurrentBranch() + } + if err != nil { - log.Fatal(err) + return 0 } mrs, err := lab.MRList(rn, gitlab.ListProjectMergeRequestsOptions{ diff --git a/internal/git/git.go b/internal/git/git.go index 7ac7e0c3..f8a1b569 100644 --- a/internal/git/git.go +++ b/internal/git/git.go @@ -131,6 +131,33 @@ func CurrentBranch() (string, error) { return strings.TrimSpace(string(branch)), nil } +// CurrentUpstreamBranch returns the upstream of the currently checked out branch +func CurrentUpstreamBranch() (string, error) { + localBranch, err := CurrentBranch() + if err != nil { + return "", err + } + + branch, err := UpstreamBranch(localBranch) + if err != nil { + return "", err + } + return branch, nil +} + +// UpstreamBranch returns the upstream of the specified branch +func UpstreamBranch(branch string) (string, error) { + cmd := New("rev-parse", "--abbrev-ref", branch+"@{upstream}") + cmd.Stdout = nil + cmd.Stderr = nil + ref, err := cmd.Output() + if err != nil { + return "", errors.Errorf("No upstream for branch '%s'", branch) + } + upstreamBranch := strings.SplitN(string(ref), "/", 2)[1] + return strings.TrimSpace(upstreamBranch), nil +} + // PathWithNameSpace returns the owner/repository for the current repo // Such as zaquestion/lab // Respects GitLab subgroups (https://docs.gitlab.com/ce/user/group/subgroups/)