Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added github autocompletion #84

Merged
merged 2 commits into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cmd/close.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func CloseCmd() *cobra.Command {
}

cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))

return cmd
Expand All @@ -30,7 +30,7 @@ func close(cmd *cobra.Command, args []string) error {

branchName, _ := flag.GetString("branch")

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func MergeCmd() *cobra.Command {

cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.")
cmd.Flags().StringSliceP("merge-type", "", []string{"merge", "squash", "rebase"}, "The type of merge that should be done (GitHub). Multiple types can be used as backup strategies if the first one is not allowed.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))

return cmd
Expand All @@ -31,7 +31,7 @@ func merge(cmd *cobra.Command, args []string) error {

branchName, _ := flag.GetString("branch")

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func PrintCmd() *cobra.Command {
cmd.Flags().IntP("concurrent", "C", 1, "The maximum number of concurrent runs")
cmd.Flags().IntP("fetch-depth", "f", 1, "Limit fetching to the specified number of commits. Set to 0 for no limit")
cmd.Flags().StringP("error-output", "E", "-", `The file that the output of the script should be outputted to. "-" means stderr`)
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags(""))
cmd.Flags().AddFlagSet(outputFlag())

Expand Down Expand Up @@ -76,7 +76,7 @@ func print(cmd *cobra.Command, args []string) error {
return errors.New("could not get the working directory")
}

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
97 changes: 87 additions & 10 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -41,8 +42,8 @@ func init() {
rand.Seed(time.Now().UTC().UnixNano())
}

func platformFlags() *flag.FlagSet {
flags := flag.NewFlagSet("platform", flag.ExitOnError)
func configurePlatform(cmd *cobra.Command) {
flags := cmd.Flags()

flags.StringP("base-url", "g", "", "Base URL of the (v3) GitHub API, needs to be changed if GitHub enterprise is used. Or the url to a self-hosted GitLab instance.")
flags.StringP("token", "T", "", "The GitHub/GitLab personal access token. Can also be set using the GITHUB_TOKEN/GITLAB_TOKEN environment variable.")
Expand All @@ -55,7 +56,77 @@ func platformFlags() *flag.FlagSet {

flags.StringP("platform", "p", "github", "The platform that is used. Available values: github, gitlab")

return flags
// Autocompletion for organizations
_ = cmd.RegisterFlagCompletionFunc("org", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
vc, err := getVersionController(cmd.Flags(), false)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

type getOrger interface {
GetAutocompleteOrganizations(ctx context.Context, _ string) ([]string, error)
}

g, ok := vc.(getOrger)
if !ok {
return nil, cobra.ShellCompDirectiveError
}

orgs, err := g.GetAutocompleteOrganizations(cmd.Root().Context(), toComplete)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

return orgs, cobra.ShellCompDirectiveDefault
})

// Autocompletion for users
_ = cmd.RegisterFlagCompletionFunc("user", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
vc, err := getVersionController(cmd.Flags(), false)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

type getUserser interface {
GetAutocompleteUsers(ctx context.Context, _ string) ([]string, error)
}

g, ok := vc.(getUserser)
if !ok {
return nil, cobra.ShellCompDirectiveError
}

users, err := g.GetAutocompleteUsers(cmd.Root().Context(), toComplete)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

return users, cobra.ShellCompDirectiveDefault
})

// Autocompletion for repositories
_ = cmd.RegisterFlagCompletionFunc("repo", func(cmd *cobra.Command, _ []string, toComplete string) ([]string, cobra.ShellCompDirective) {
vc, err := getVersionController(cmd.Flags(), false)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

type getRepositorieser interface {
GetAutocompleteRepositories(ctx context.Context, _ string) ([]string, error)
}

g, ok := vc.(getRepositorieser)
if !ok {
return nil, cobra.ShellCompDirectiveError
}

users, err := g.GetAutocompleteRepositories(cmd.Root().Context(), toComplete)
if err != nil {
return nil, cobra.ShellCompDirectiveError
}

return users, cobra.ShellCompDirectiveDefault
})
}

func logFlags(logFile string) *flag.FlagSet {
Expand Down Expand Up @@ -119,7 +190,9 @@ func outputFlag() *flag.FlagSet {
// This is used to override the version controller with a mock, to be used during testing
var OverrideVersionController multigitter.VersionController = nil

func getVersionController(flag *flag.FlagSet) (multigitter.VersionController, error) {
// getVersionController gets the complete version controller
// the verifyFlags parameter can be set to false if a complete vc is not required (during autocompletion)
func getVersionController(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) {
if OverrideVersionController != nil {
return OverrideVersionController, nil
}
Expand All @@ -129,21 +202,21 @@ func getVersionController(flag *flag.FlagSet) (multigitter.VersionController, er
default:
return nil, fmt.Errorf("unknown platform: %s", platform)
case "github":
return createGithubClient(flag)
return createGithubClient(flag, verifyFlags)
case "gitlab":
return createGitlabClient(flag)
return createGitlabClient(flag, verifyFlags)
}
}

func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, error) {
func createGithubClient(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) {
gitBaseURL, _ := flag.GetString("base-url")
orgs, _ := flag.GetStringSlice("org")
users, _ := flag.GetStringSlice("user")
repos, _ := flag.GetStringSlice("repo")
mergeTypeStrs, _ := flag.GetStringSlice("merge-type") // Only used for the merge command

if len(orgs) == 0 && len(users) == 0 && len(repos) == 0 {
return nil, errors.New("no organization or user set")
if verifyFlags && len(orgs) == 0 && len(users) == 0 && len(repos) == 0 {
return nil, errors.New("no organization, user or repo set")
}

token, err := getToken(flag)
Expand Down Expand Up @@ -180,12 +253,16 @@ func createGithubClient(flag *flag.FlagSet) (multigitter.VersionController, erro
return vc, nil
}

func createGitlabClient(flag *flag.FlagSet) (multigitter.VersionController, error) {
func createGitlabClient(flag *flag.FlagSet, verifyFlags bool) (multigitter.VersionController, error) {
gitBaseURL, _ := flag.GetString("base-url")
groups, _ := flag.GetStringSlice("group")
users, _ := flag.GetStringSlice("user")
projects, _ := flag.GetStringSlice("project")

if verifyFlags && len(groups) == 0 && len(users) == 0 && len(projects) == 0 {
return nil, errors.New("no group user or project set")
}

token, err := getToken(flag)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func RunCmd() *cobra.Command {
cmd.Flags().BoolP("dry-run", "d", false, "Run without pushing changes or creating pull requests")
cmd.Flags().StringP("author-name", "", "", "Name of the committer. If not set, the global git config setting will be used.")
cmd.Flags().StringP("author-email", "", "", "Email of the committer. If not set, the global git config setting will be used.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))
cmd.Flags().AddFlagSet(outputFlag())

Expand Down Expand Up @@ -121,7 +121,7 @@ func run(cmd *cobra.Command, args []string) error {
return errors.New("could not get the working directory")
}

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions cmd/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func StatusCmd() *cobra.Command {
}

cmd.Flags().StringP("branch", "B", "multi-gitter-branch", "The name of the branch where changes are committed.")
cmd.Flags().AddFlagSet(platformFlags())
configurePlatform(cmd)
cmd.Flags().AddFlagSet(logFlags("-"))
cmd.Flags().AddFlagSet(outputFlag())

Expand All @@ -33,7 +33,7 @@ func status(cmd *cobra.Command, args []string) error {
branchName, _ := flag.GetString("branch")
strOutput, _ := flag.GetString("output")

vc, err := getVersionController(flag)
vc, err := getVersionController(flag, true)
if err != nil {
return err
}
Expand Down
58 changes: 58 additions & 0 deletions internal/github/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,61 @@ func (g Github) ClosePullRequest(ctx context.Context, pullReq domain.PullRequest
_, err = g.ghClient.Git.DeleteRef(ctx, pr.ownerName, pr.repoName, fmt.Sprintf("heads/%s", pr.branchName))
return err
}

// GetAutocompleteOrganizations gets organizations for autocompletion
func (g Github) GetAutocompleteOrganizations(ctx context.Context, _ string) ([]string, error) {
orgs, _, err := g.ghClient.Organizations.List(ctx, "", nil)
if err != nil {
return nil, err
}

ret := make([]string, len(orgs))
for i, org := range orgs {
ret[i] = org.GetLogin()
}

return ret, nil
}

// GetAutocompleteUsers gets users for autocompletion
func (g Github) GetAutocompleteUsers(ctx context.Context, str string) ([]string, error) {
users, _, err := g.ghClient.Search.Users(ctx, str, nil)
if err != nil {
return nil, err
}

ret := make([]string, len(users.Users))
for i, user := range users.Users {
ret[i] = user.GetLogin()
}

return ret, nil
}

// GetAutocompleteRepositories gets repositories for autocompletion
func (g Github) GetAutocompleteRepositories(ctx context.Context, str string) ([]string, error) {
var q string

// If the user has already provided a org/user, it's much more effective to search based on that
// comparared to a complete freetext search
splitted := strings.SplitN(str, "/", 2)
switch {
case len(splitted) == 2:
// Search set the user or org (user/org in the search can be used interchangeable)
q = fmt.Sprintf("user:%s %s in:name", splitted[0], splitted[1])
default:
q = fmt.Sprintf("%s in:name", str)
}

repos, _, err := g.ghClient.Search.Repositories(ctx, q, nil)
if err != nil {
return nil, err
}

ret := make([]string, len(repos.Repositories))
for i, repositories := range repos.Repositories {
ret[i] = repositories.GetFullName()
}

return ret, nil
}
52 changes: 48 additions & 4 deletions tests/table_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"bytes"
"fmt"
"io/ioutil"
"os"
Expand All @@ -17,6 +18,7 @@ import (
type runData struct {
out string
logOut string
cmdOut string
took time.Duration
}

Expand Down Expand Up @@ -323,6 +325,42 @@ Repositories with a successful run:
assert.Equal(t, "i like bananas", readTestFile(t, vcMock.Repositories[0].Path))
},
},

{
name: "autocomplete org",
vc: &vcmock.VersionController{},
args: []string{
"__complete", "run",
"--org", "dynamic-org",
},
verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) {
assert.Equal(t, "static-org\ndynamic-org\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut)
},
},

{
name: "autocomplete user",
vc: &vcmock.VersionController{},
args: []string{
"__complete", "run",
"--user", "dynamic-user",
},
verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) {
assert.Equal(t, "static-user\ndynamic-user\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut)
},
},

{
name: "autocomplete repo",
vc: &vcmock.VersionController{},
args: []string{
"__complete", "run",
"--repo", "dynamic-repo",
},
verify: func(t *testing.T, vcMock *vcmock.VersionController, runData runData) {
assert.Equal(t, "static-repo\ndynamic-repo\n:0\nCompletion ended with directive: ShellCompDirectiveDefault\n", runData.cmdOut)
},
},
}

for _, test := range tests {
Expand All @@ -343,12 +381,17 @@ Repositories with a successful run:
}
cmd.OverrideVersionController = vc

command := cmd.RootCmd()
command.SetArgs(append(
test.args,
cobraBuf := &bytes.Buffer{}

staticArgs := []string{
"--log-file", logFile.Name(),
"--output", outFile.Name(),
))
}

command := cmd.RootCmd()
command.SetOut(cobraBuf)
command.SetErr(cobraBuf)
command.SetArgs(append(staticArgs, test.args...))
before := time.Now()
err = command.Execute()
took := time.Since(before)
Expand All @@ -367,6 +410,7 @@ Repositories with a successful run:
test.verify(t, vc, runData{
logOut: string(logData),
out: string(outData),
cmdOut: cobraBuf.String(),
took: took,
})
})
Expand Down
Loading