Skip to content

Commit

Permalink
Merge pull request #587 from hashicorp/f/custom-paging
Browse files Browse the repository at this point in the history
Base Client: support custom paging
  • Loading branch information
manicminer authored Aug 9, 2023
2 parents 4efc5c2 + a85ba58 commit f0fed54
Show file tree
Hide file tree
Showing 11 changed files with 298 additions and 32 deletions.
63 changes: 41 additions & 22 deletions sdk/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@ type Request struct {
ValidStatusFunc ValidStatusFunc

Client BaseClient
Pager odata.CustomPager

// Embed *http.Request so that we can send this to an *http.Client
*http.Request
}

// Marshal serializes a payload body and adds it to the *Request
func (r *Request) Marshal(payload interface{}) error {
contentType := strings.ToLower(r.Header.Get("Content-Type"))

Expand Down Expand Up @@ -120,14 +122,17 @@ func (r *Request) Marshal(payload interface{}) error {
return fmt.Errorf("internal-error: unimplemented marshal function for content type %q", contentType)
}

// Execute invokes the Execute method for the Request's Client
func (r *Request) Execute(ctx context.Context) (*Response, error) {
return r.Client.Execute(ctx, r)
}

// ExecutePaged invokes the ExecutePaged method for the Request's Client
func (r *Request) ExecutePaged(ctx context.Context) (*Response, error) {
return r.Client.ExecutePaged(ctx, r)
}

// IsIdempotent determines whether a Request can be safely retried when encountering a connection failure
func (r *Request) IsIdempotent() bool {
switch strings.ToUpper(r.Method) {
case http.MethodGet, http.MethodHead, http.MethodOptions:
Expand All @@ -144,6 +149,7 @@ type Response struct {
*http.Response
}

// Unmarshal deserializes a response body into the provided model
func (r *Response) Unmarshal(model interface{}) error {
if model == nil {
return fmt.Errorf("model was nil")
Expand Down Expand Up @@ -302,6 +308,7 @@ func (c *Client) NewRequest(ctx context.Context, input RequestOptions) (*Request
ret := Request{
Client: c,
Request: req,
Pager: input.Pager,
ValidStatusCodes: input.ExpectedStatusCodes,
}

Expand Down Expand Up @@ -455,50 +462,62 @@ func (c *Client) ExecutePaged(ctx context.Context, req *Request) (*Response, err
return resp, fmt.Errorf("unsupported content-type %q received, only application/json is supported for paged results", contentType)
}

// Read the response body and close it
respBody, err := io.ReadAll(resp.Body)
// Unmarshal the response
firstOdata, err := odata.FromResponse(resp.Response)
if err != nil {
return resp, fmt.Errorf("could not parse response body")
return resp, err
}
resp.Body.Close()

// Unmarshal firstOdata
var firstOdata odata.OData
if err := json.Unmarshal(respBody, &firstOdata); err != nil {
return resp, err
if firstOdata == nil {
// No results, return early
return resp, nil
}

// Get results from this page
firstValue, ok := firstOdata.Value.([]interface{})
if firstOdata.NextLink == nil || firstValue == nil || !ok {
// No more pages, reassign response body and return
resp.Body = io.NopCloser(bytes.NewBuffer(respBody))
if !ok || firstValue == nil {
// No more results on this page
return resp, nil
}

// Get the next page, recursively
// TODO: may have to accommodate APIs with nonstandard paging
// Get a Link for the next results page
var nextLink *odata.Link
if req.Pager == nil {
nextLink = firstOdata.NextLink
} else {
nextLink, err = odata.NextLinkFromCustomPager(resp.Response, req.Pager)
if err != nil {
return resp, err
}
}
if nextLink == nil {
// This is the last page
return resp, nil
}

// Build request for the next page
nextReq := req
u, err := url.Parse(string(*firstOdata.NextLink))
u, err := url.Parse(string(*nextLink))
if err != nil {
return resp, err
}
nextReq.URL = u

// Retrieve the next page, descend recursively
nextResp, err := c.ExecutePaged(ctx, req)
if err != nil {
return resp, err
}

// Read the next page response body and close it
nextRespBody, err := io.ReadAll(nextResp.Body)
// Unmarshal nextOdata from the next page
nextOdata, err := odata.FromResponse(nextResp.Response)
if err != nil {
return resp, fmt.Errorf("could not parse response body")
return nextResp, err
}
nextResp.Body.Close()

// Unmarshal nextOdata from the next page
var nextOdata odata.OData
if err := json.Unmarshal(nextRespBody, &nextOdata); err != nil {
return nextResp, err
if nextOdata == nil {
// No more results, return early
return resp, nil
}

// When next page has results, append to current page
Expand Down
171 changes: 169 additions & 2 deletions sdk/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,59 @@ import (
"encoding/xml"
"fmt"
"io"
"log"
"net/http"
"net/url"
"reflect"
"testing"

"github.com/hashicorp/go-azure-helpers/lang/pointer"
"github.com/hashicorp/go-azure-sdk/sdk/internal/test"
"github.com/hashicorp/go-azure-sdk/sdk/odata"
)

var _ BaseClient = &testClient{}

type testClient struct {
*Client
}

func (c *testClient) NewRequest(ctx context.Context, input RequestOptions) (*Request, error) {
req, err := c.Client.NewRequest(ctx, input)
if err != nil {
return nil, fmt.Errorf("building %s request: %+v", input.HttpMethod, err)
}

req.Client = c
query := url.Values{}

if input.OptionsObject != nil {
if h := input.OptionsObject.ToHeaders(); h != nil {
for k, v := range h.Headers() {
req.Header[k] = v
}
}

if q := input.OptionsObject.ToQuery(); q != nil {
for k, v := range q.Values() {
// we intentionally only add one of each type
query.Del(k)
query.Add(k, v[0])
}
}

if o := input.OptionsObject.ToOData(); o != nil {
req.Header = o.AppendHeaders(req.Header)
query = o.AppendValues(query)
}
}

req.URL.RawQuery = query.Encode()
req.ValidStatusCodes = input.ExpectedStatusCodes

return req, nil
}

func TestAccClient(t *testing.T) {
test.AccTest(t)

Expand All @@ -30,7 +75,9 @@ func TestAccClient(t *testing.T) {
}
conn.Authorize(ctx, t, api)

c := NewClient(*endpoint, "example", "2020-01-01")
c := &testClient{
Client: NewClient(*endpoint, "example", "2020-01-01"),
}
c.Authorizer = conn.Authorizer

path := fmt.Sprintf("/v1.0/servicePrincipals/%s", conn.Claims.ObjectId)
Expand All @@ -48,8 +95,128 @@ func TestAccClient(t *testing.T) {
t.Fatal(err)
}

_, err = req.ExecutePaged(ctx)
resp, err := req.Execute(ctx)
if err != nil {
t.Fatalf("Execute(): %v", err)
}

fmt.Printf("%#v", resp)
}

var _ Options = &requestOptions{}

type requestOptions struct {
query *odata.Query
}

func (r *requestOptions) ToHeaders() *Headers { return nil }
func (r *requestOptions) ToOData() *odata.Query { return r.query }
func (r *requestOptions) ToQuery() *QueryParams { return nil }

func TestAccClient_Paged(t *testing.T) {
test.AccTest(t)

ctx := context.TODO()
conn := test.NewConnection(t)
api := conn.AuthConfig.Environment.MicrosoftGraph
endpoint, ok := api.Endpoint()
if !ok {
t.Fatalf("missing endpoint for microsoft graph for this environment")
}
conn.Authorize(ctx, t, api)

c := &testClient{
Client: NewClient(*endpoint, "example", "2020-01-01"),
}
c.Authorizer = conn.Authorizer

path := "/v1.0/applications"
reqOpts := RequestOptions{
ContentType: "application/json",
ExpectedStatusCodes: []int{
http.StatusOK,
},
HttpMethod: http.MethodGet,
OptionsObject: &requestOptions{
query: &odata.Query{
Filter: "startsWith(displayName,'acctest')",
Select: []string{"appId", "displayName"},
Top: 10,
},
},
Path: path,
}
req, err := c.NewRequest(ctx, reqOpts)
if err != nil {
t.Fatal(err)
}

if _, err = req.ExecutePaged(ctx); err != nil {
t.Fatalf("ExecutePaged(): %v", err)
}
}

var _ odata.CustomPager = &pager{}

type pager struct {
NextLink *odata.Link `json:"@odata.nextLink"`
}

func (p *pager) NextPageLink() *odata.Link {
if p == nil {
log.Fatalf("pager: p was nil")
}
if p.NextLink == nil {
log.Printf("[DEBUG] pager: nextLink was nil")
} else {
log.Printf("[DEBUG] pager: found custom nextLink %q", *p.NextLink)
}
defer func() {
p.NextLink = nil
}()
return p.NextLink
}

func TestAccClient_CustomPaged(t *testing.T) {
test.AccTest(t)

ctx := context.TODO()
conn := test.NewConnection(t)
api := conn.AuthConfig.Environment.MicrosoftGraph
endpoint, ok := api.Endpoint()
if !ok {
t.Fatalf("missing endpoint for microsoft graph for this environment")
}
conn.Authorize(ctx, t, api)

c := &testClient{
Client: NewClient(*endpoint, "example", "2020-01-01"),
}
c.Authorizer = conn.Authorizer

path := "/v1.0/applications"
reqOpts := RequestOptions{
ContentType: "application/json",
ExpectedStatusCodes: []int{
http.StatusOK,
},
HttpMethod: http.MethodGet,
OptionsObject: &requestOptions{
query: &odata.Query{
Filter: "startsWith(displayName,'acctest')",
Select: []string{"appId", "displayName"},
Top: 10,
},
},
Pager: &pager{},
Path: path,
}
req, err := c.NewRequest(ctx, reqOpts)
if err != nil {
t.Fatal(err)
}

if _, err = req.ExecutePaged(ctx); err != nil {
t.Fatalf("ExecutePaged(): %v", err)
}
}
Expand Down
5 changes: 5 additions & 0 deletions sdk/client/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ import (
)

type BaseClient interface {
// Execute invokes a non-paginated API request and returns a populated *Response
Execute(ctx context.Context, req *Request) (*Response, error)

// ExecutePaged invokes a paginated API request, merges the results from all pages and returns a populated *Response with all results
ExecutePaged(ctx context.Context, req *Request) (*Response, error)

// NewRequest constructs a *Request that can be passed to Execute or ExecutePaged
NewRequest(ctx context.Context, input RequestOptions) (*Request, error)
}

Expand Down
5 changes: 5 additions & 0 deletions sdk/client/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,13 @@ import (
)

type Options interface {
// ToHeaders yields a custom Headers struct to be appended to the request
ToHeaders() *Headers

// ToOData yields a custom *odata.Query struct to be appended to the request
ToOData() *odata.Query

// ToQuery yields a custom *QueryParams struct to be appended to the request
ToQuery() *QueryParams
}

Expand Down
27 changes: 22 additions & 5 deletions sdk/client/request_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,31 @@

package client

import "fmt"
import (
"fmt"

"github.com/hashicorp/go-azure-sdk/sdk/odata"
)

type RequestOptions struct {
ContentType string
// ContentType is the content type of the request and should include the charset
ContentType string

// ExpectedStatusCodes is a slice of HTTP response codes considered valid for this request
ExpectedStatusCodes []int
HttpMethod string
OptionsObject Options
Path string

// HttpMethod is the capitalized method verb for this request
HttpMethod string

// OptionsObject is used for dynamically modifying the request at runtime
OptionsObject Options

// Pager is an optional struct for handling custom pagination for this request. OData 4.0 compliant paging
// is already handled implicitly and does not require a custom pager.
Pager odata.CustomPager

// Path is the absolute URI for this request, with a leading slash.
Path string
}

func (ro RequestOptions) Validate() error {
Expand Down
Loading

0 comments on commit f0fed54

Please sign in to comment.