diff --git a/commands/cmdutils/factory.go b/commands/cmdutils/factory.go index cfb7a085..b17f8ea5 100644 --- a/commands/cmdutils/factory.go +++ b/commands/cmdutils/factory.go @@ -3,6 +3,7 @@ package cmdutils import ( "fmt" + "strings" "github.com/profclems/glab/internal/config" "github.com/profclems/glab/internal/git" @@ -55,11 +56,15 @@ func LabClientFunc(repoHost string, cfg config.Config, isGraphQL bool) (*gitlab. } func remotesFunc() (glrepo.Remotes, error) { + hostOverride := "" + if !strings.EqualFold(glinstance.Default(), glinstance.OverridableDefault()) { + hostOverride = glinstance.OverridableDefault() + } rr := &remoteResolver{ readRemotes: git.Remotes, getConfig: configFunc, } - fn := rr.Resolver() + fn := rr.Resolver(hostOverride) return fn() } diff --git a/commands/cmdutils/remote_resolver.go b/commands/cmdutils/remote_resolver.go index 5df0c1cb..8870da0d 100644 --- a/commands/cmdutils/remote_resolver.go +++ b/commands/cmdutils/remote_resolver.go @@ -4,6 +4,7 @@ import ( "errors" "net/url" "sort" + "strings" "github.com/profclems/glab/internal/config" "github.com/profclems/glab/internal/git" @@ -17,7 +18,7 @@ type remoteResolver struct { urlTranslator func(*url.URL) *url.URL } -func (rr *remoteResolver) Resolver() func() (glrepo.Remotes, error) { +func (rr *remoteResolver) Resolver(hostOverride string) func() (glrepo.Remotes, error) { var cachedRemotes glrepo.Remotes var remotesError error @@ -59,6 +60,22 @@ func (rr *remoteResolver) Resolver() func() (glrepo.Remotes, error) { var hostname string cachedRemotes = glrepo.Remotes{} sort.Sort(resolvedRemotes) + + if hostOverride != "" { + for _, r := range resolvedRemotes { + if strings.EqualFold(r.RepoHost(), hostOverride) { + cachedRemotes = append(cachedRemotes, r) + } + } + + if len(cachedRemotes) == 0 { + remotesError = errors.New("none of the git remotes configured for this repository correspond to the GITLAB_HOST environment variable. Try adding a matching remote or unsetting the variable") + return nil, remotesError + } + + return cachedRemotes, nil + } + for _, r := range resolvedRemotes { if hostname == "" { if !knownHosts[r.RepoHost()] { diff --git a/commands/cmdutils/remote_resolver_test.go b/commands/cmdutils/remote_resolver_test.go index c0868d95..421778c2 100644 --- a/commands/cmdutils/remote_resolver_test.go +++ b/commands/cmdutils/remote_resolver_test.go @@ -33,7 +33,7 @@ func Test_remoteResolver(t *testing.T) { }, } - resolver := rr.Resolver() + resolver := rr.Resolver("") remotes, err := resolver() require.NoError(t, err) require.Equal(t, 2, len(remotes)) @@ -41,3 +41,32 @@ func Test_remoteResolver(t *testing.T) { assert.Equal(t, "upstream", remotes[0].Name) assert.Equal(t, "fork", remotes[1].Name) } + +func Test_remoteResolverOverride(t *testing.T) { + rr := &remoteResolver{ + readRemotes: func() (git.RemoteSet, error) { + return git.RemoteSet{ + git.NewRemote("fork", "https://example.org/ghe-owner/ghe-fork.git"), + git.NewRemote("origin", "https://gitlab.com/owner/repo.git"), + git.NewRemote("upstream", "https://example.org/ghe-owner/ghe-repo.git"), + }, nil + }, + getConfig: func() (config.Config, error) { + return config.NewFromString(heredoc.Doc(` + hosts: + example.org: + oauth_token: GHETOKEN + `)), nil + }, + urlTranslator: func(u *url.URL) *url.URL { + return u + }, + } + + resolver := rr.Resolver("gitlab.com") + remotes, err := resolver() + require.NoError(t, err) + require.Equal(t, 1, len(remotes)) + + assert.Equal(t, "origin", remotes[0].Name) +}