Skip to content

Commit

Permalink
refactor pr retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
Dawid Ciepiela committed Aug 30, 2024
1 parent 6002923 commit 20223c1
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 58 deletions.
36 changes: 31 additions & 5 deletions pkg/commands/command_pr.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
color "github.com/fatih/color"
configfile "github.com/sarumaj/gh-gr/v2/pkg/configfile"
restclient "github.com/sarumaj/gh-gr/v2/pkg/restclient"
"github.com/sarumaj/gh-gr/v2/pkg/restclient/resources"
util "github.com/sarumaj/gh-gr/v2/pkg/util"
logrus "github.com/sirupsen/logrus"
cobra "github.com/spf13/cobra"
Expand Down Expand Up @@ -93,66 +94,82 @@ var prCmd = func() *cobra.Command {
type pullRequestAction func(*restclient.RESTClient) func(context.Context, string, string, int) error

// buildPullSearchQuery builds a search query for pull requests.
func buildPullSearchQuery() string {
func buildPullSearchQuery() map[string]string {
var fragments []string
var useCustomQuery bool
filter := make(map[string]string)

if prFlags.base != "" {
filter["base"] = prFlags.base
fragments = append(fragments, fmt.Sprintf("base:%s", prFlags.base))
}

if prFlags.head != "" {
filter["head"] = prFlags.head
fragments = append(fragments, fmt.Sprintf("head:%s", prFlags.head))
}

if prFlags.closedInLast > 0 {
useCustomQuery = true
fragments = append(fragments, fmt.Sprintf("closed:>=%s", time.Now().Add(-prFlags.closedInLast).Format("2006-01-02T15:04:05Z")))
}

if prFlags.closedAfterLast > 0 {
useCustomQuery = true
fragments = append(fragments, fmt.Sprintf("closed:<=%s", time.Now().Add(-prFlags.closedAfterLast).Format("2006-01-02T15:04:05Z")))
}

for _, assignee := range prFlags.assignees {
useCustomQuery = true
if util.IsGlobMatch(assignee) {
continue
}
fragments = append(fragments, fmt.Sprintf("assignee:%s", assignee))
}

for _, author := range prFlags.authors {
useCustomQuery = true
if util.IsGlobMatch(author) {
continue
}
fragments = append(fragments, fmt.Sprintf("author:%s", author))
}

for _, label := range prFlags.labels {
useCustomQuery = true
if util.IsGlobMatch(label) {
continue
}
fragments = append(fragments, fmt.Sprintf("label:%s", label))
}

for _, title := range prFlags.titles {
useCustomQuery = true
if util.IsGlobMatch(title) || util.IsRegex(title) {
continue
}
fragments = append(fragments, fmt.Sprintf("%s in:title", title))
}

if prFlags.state != "" {
filter["state"] = prFlags.state
fragments = append(fragments, fmt.Sprintf("state:%s", prFlags.state))
}

if prFlags.customQuery != "" {
useCustomQuery = true
fragments = append(fragments, prFlags.customQuery)
}

return strings.Join(fragments, " ")
if useCustomQuery {
return map[string]string{"q": strings.Join(fragments, " ")}
}

return filter
}

// listPullRequests initializes pull requests.
func listPullRequests(conf *configfile.Configuration, filter string, list *configfile.PullRequestList, flush bool) {
func listPullRequests(conf *configfile.Configuration, filter map[string]string, list *configfile.PullRequestList, flush bool) {
operationLoop[configfile.Repository](prListOperation, "PRs list", operationContextMap{
"filter": filter,
"cache": make(map[string]*restclient.RESTClient),
Expand All @@ -179,6 +196,7 @@ func listPullRequests(conf *configfile.Configuration, filter string, list *confi
}
}

// prDoOperation performs an operation on a pull request.
func prDoOperation(_ pool.WorkUnit, args operationContext) {
conf := unwrapOperationContext[*configfile.Configuration](args, "conf")
pr := unwrapOperationContext[configfile.PullRequest](args, "object")
Expand Down Expand Up @@ -239,12 +257,14 @@ func prDoOperation(_ pool.WorkUnit, args operationContext) {
status.appendRow(pr.Title, pr.Number, pr.Status(), pr.Author, pr.Assignees, pr.Labels)
}

// prListOperation lists pull requests.
// Depending on the filter, it will use either pull requests API endpoint or the search API endpoint.
func prListOperation(_ pool.WorkUnit, args operationContext) {
conf := unwrapOperationContext[*configfile.Configuration](args, "conf")
repo := unwrapOperationContext[configfile.Repository](args, "object")
status := unwrapOperationContext[*operationStatus](args, "status")
keep := unwrapOperationContext[func(configfile.PullRequest) bool](args, "keep")
filter := unwrapOperationContext[string](args, "filter")
filter := unwrapOperationContext[map[string]string](args, "filter")
cache := unwrapOperationContext[map[string]*restclient.RESTClient](args, "cache")
list := unwrapOperationContext[*configfile.PullRequestList](args, "list")

Expand Down Expand Up @@ -279,7 +299,13 @@ func prListOperation(_ pool.WorkUnit, args operationContext) {
slug := configfile.GetRepositorySlugFromURL(repo)
owner, repoName, _ := strings.Cut(slug, "/")

pulls, err := client.GetOrgRepoPulls(args.Context, owner, repoName, filter)
var pulls []resources.PullRequest
var err error
if query := filter["q"]; query != "" {
pulls, err = client.SearchOrgRepoPulls(args.Context, owner, repoName, query)
} else {
pulls, err = client.GetOrgRepoPulls(args.Context, owner, repoName, filter)
}
if err != nil {
status.appendRow("", "", err, repo.Directory, "", []string{}, []string{})
return
Expand Down
71 changes: 48 additions & 23 deletions pkg/restclient/rest_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ func (c *RESTClient) GetOrg(ctx context.Context, name string) (org *resources.Or
// Get all repositories for given organization.
func (c *RESTClient) GetOrgRepos(ctx context.Context, name string) ([]resources.Repository, error) {
c.Progressbar.Describe("Retrieving repositories for GitHub organization: %s...", name)
return getPaged[resources.Repository](c, orgReposEp.Format(map[string]any{"owner": name}), ctx)
return getPaged[resources.Repository, []resources.Repository](c, orgReposEp.Format(map[string]any{"owner": name}), ctx)
}

// Get organizations.
func (c *RESTClient) GetOrgs(ctx context.Context) ([]resources.Organization, error) {
c.Progressbar.Describe("Retrieving GitHub organizations...")
return getPaged[resources.Organization](c, orgsEp, ctx)
return getPaged[resources.Organization, []resources.Organization](c, orgsEp, ctx)
}

// Get rate limit information.
Expand All @@ -124,34 +124,31 @@ func (c *RESTClient) GetRateLimit(ctx context.Context) (rate *resources.RateLimi
}

// Get all pull requests for given organization and repository.
func (c *RESTClient) GetOrgRepoPulls(ctx context.Context, name, repo string, filter string) (out []resources.PullRequest, err error) {
func (c *RESTClient) GetOrgRepoPulls(ctx context.Context, name, repo string, filter map[string]string) (out []resources.PullRequest, err error) {
c.Describe(fmt.Sprintf("Retrieving pull requests for GitHub repository: %s/%s...", name, repo))
searchQuery := fmt.Sprintf("is:pr repo:%s/%s", name, repo)
if filter != "" {
searchQuery += " " + filter
}

resp, err := c.RequestWithContext(ctx, http.MethodGet, newRequestPath(searchIssuesEp).Set("q", searchQuery).String(), nil)
pulls, err := getPaged[resources.PullRequest, []resources.PullRequest](c, pullsEp.Format(map[string]any{"owner": name, "repo": repo}), ctx, func(params *requestPath) {
params.
Register("state", "open", "closed", "all").
Register("sort", "created", "updated", "popularity", "long-running")
for k, v := range filter {
if v == "" {
continue
}
params.Set(k, v)
}
})
if err != nil {
return nil, err
}

var searchResults resources.SearchResult[resources.PullRequest]
if err := json.NewDecoder(resp.Body).Decode(&searchResults); err != nil {
return nil, err
}

for _, item := range searchResults.Items {
var pr resources.PullRequest
if err := c.DoWithContext(ctx, http.MethodGet, item.URL, nil, &pr); err != nil {
for i, pull := range pulls {
if err := c.DoWithContext(ctx, http.MethodGet, pull.URL, nil, &pulls[i]); err != nil {
return nil, err
}

pr.Repository = name + "/" + repo
out = append(out, pr)
pulls[i].Repository = name + "/" + repo
}

return out, nil
return pulls, nil
}

// Get GitHub user.
Expand All @@ -163,13 +160,13 @@ func (c *RESTClient) GetUser(ctx context.Context) (user *resources.User, err err
// Get all repositories for given user.
func (c *RESTClient) GetUserRepos(ctx context.Context) ([]resources.Repository, error) {
c.Progressbar.Describe("Retrieving repositories for current user...")
return getPaged[resources.Repository](c, userReposEp, ctx)
return getPaged[resources.Repository, []resources.Repository](c, userReposEp, ctx)
}

// get all organizations for given user.
func (c *RESTClient) GetUserOrgs(ctx context.Context) ([]resources.Organization, error) {
c.Progressbar.Describe("Retrieving GitHub organizations for current user...")
return getPaged[resources.Organization](c, userOrgsEp, ctx)
return getPaged[resources.Organization, []resources.Organization](c, userOrgsEp, ctx)
}

// Reopen a pull request.
Expand Down Expand Up @@ -217,6 +214,34 @@ func (c *RESTClient) RequestWithContext(ctx context.Context, method, path string
return resp, nil
}

// Search for pull requests in a repository.
func (c *RESTClient) SearchOrgRepoPulls(ctx context.Context, name, repo string, filter string) (out []resources.PullRequest, err error) {
c.Describe(fmt.Sprintf("Retrieving pull requests for GitHub repository: %s/%s...", name, repo))
searchQuery := fmt.Sprintf("is:pr repo:%s/%s", name, repo)
if filter != "" {
searchQuery += " " + filter
}

searchResults, err := getPaged[resources.PullRequest, resources.SearchResult[resources.PullRequest]](c, searchIssuesEp, ctx, func(rp *requestPath) {
rp.Set("q", searchQuery)
})
if err != nil {
return nil, err
}

for _, item := range searchResults.Items {
var pr resources.PullRequest
if err := c.DoWithContext(ctx, http.MethodGet, item.URL, nil, &pr); err != nil {
return nil, err
}

pr.Repository = name + "/" + repo
out = append(out, pr)
}

return out, nil
}

// Create new REST API client.
// The rate limit of the API will be checked upfront.
func NewRESTClient(conf *configfile.Configuration, options ClientOptions, retry bool) (*RESTClient, error) {
Expand Down
Loading

0 comments on commit 20223c1

Please sign in to comment.