Skip to content

Commit

Permalink
Add context.Context to all services
Browse files Browse the repository at this point in the history
Services in Elastic will now accept and honour `context.Context`. To do
that, all services now have a `Do`/`DoC` pair of methods. The latter
accepts a `context.Context`.

`Client` implements this via `PerformRequest` and `PerformRequestC`.
The latter will accept a `context.Context` as its first parameter.
If a `context.Context` is passed, Elastic uses the context-aware
`golang.org/x/net/context/ctxhttp` package to perform HTTP requests.

See e.g. #239
  • Loading branch information
olivere committed Aug 30, 2016
1 parent a6cb826 commit 73e440e
Show file tree
Hide file tree
Showing 57 changed files with 576 additions and 86 deletions.
11 changes: 10 additions & 1 deletion bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"fmt"
"net/url"

"golang.org/x/net/context"

"gopkg.in/olivere/elastic.v3/uritemplates"
)

Expand Down Expand Up @@ -145,6 +147,13 @@ func (s *BulkService) bodyAsString() (string, error) {
// you can reuse the BulkService for the next batch as the list of bulk
// requests is cleared on success.
func (s *BulkService) Do() (*BulkResponse, error) {
return s.DoC(nil)
}

// DoC sends the batched requests to Elasticsearch. Note that, when successful,
// you can reuse the BulkService for the next batch as the list of bulk
// requests is cleared on success.
func (s *BulkService) DoC(ctx context.Context) (*BulkResponse, error) {
// No actions?
if s.NumberOfActions() == 0 {
return nil, errors.New("elastic: No bulk actions to commit")
Expand Down Expand Up @@ -191,7 +200,7 @@ func (s *BulkService) Do() (*BulkResponse, error) {
}

// Get response
res, err := s.client.PerformRequest("POST", path, params, body)
res, err := s.client.PerformRequestC(ctx, "POST", path, params, body)
if err != nil {
return nil, err
}
Expand Down
9 changes: 8 additions & 1 deletion clear_scroll.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"net/url"
"strings"

"golang.org/x/net/context"
)

// ClearScrollService clears one or more scroll contexts by their ids.
Expand Down Expand Up @@ -68,6 +70,11 @@ func (s *ClearScrollService) Validate() error {

// Do executes the operation.
func (s *ClearScrollService) Do() (*ClearScrollResponse, error) {
return s.DoC(nil)
}

// DoC executes the operation.
func (s *ClearScrollService) DoC(ctx context.Context) (*ClearScrollResponse, error) {
// Check pre-conditions
if err := s.Validate(); err != nil {
return nil, err
Expand All @@ -83,7 +90,7 @@ func (s *ClearScrollService) Do() (*ClearScrollResponse, error) {
body := strings.Join(s.scrollId, ",")

// Get HTTP response
res, err := s.client.PerformRequest("DELETE", path, params, body)
res, err := s.client.PerformRequestC(ctx, "DELETE", path, params, body)
if err != nil {
return nil, err
}
Expand Down
23 changes: 22 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import (
"strings"
"sync"
"time"

"golang.org/x/net/context"
"golang.org/x/net/context/ctxhttp"
)

const (
Expand Down Expand Up @@ -1039,6 +1042,19 @@ func (c *Client) mustActiveConn() error {
// This is necessary for services that expect e.g. HTTP status 404 as a
// valid outcome (Exists, IndicesExists, IndicesTypeExists).
func (c *Client) PerformRequest(method, path string, params url.Values, body interface{}, ignoreErrors ...int) (*Response, error) {
return c.PerformRequestC(nil, method, path, params, body, ignoreErrors...)
}

// PerformRequestC does a HTTP request to Elasticsearch.
// It returns a response and an error on failure.
//
// Optionally, a list of HTTP error codes to ignore can be passed.
// This is necessary for services that expect e.g. HTTP status 404 as a
// valid outcome (Exists, IndicesExists, IndicesTypeExists).
//
// If ctx is not nil, it uses the ctxhttp to do the request,
// enabling both request cancelation as well as timeout.
func (c *Client) PerformRequestC(ctx context.Context, method, path string, params url.Values, body interface{}, ignoreErrors ...int) (*Response, error) {
start := time.Now().UTC()

c.mu.RLock()
Expand Down Expand Up @@ -1116,7 +1132,12 @@ func (c *Client) PerformRequest(method, path string, params url.Values, body int
c.dumpRequest((*http.Request)(req))

// Get response
res, err := c.c.Do((*http.Request)(req))
var res *http.Response
if ctx == nil {
res, err = c.c.Do((*http.Request)(req))
} else {
res, err = ctxhttp.Do(ctx, c.c, (*http.Request)(req))
}
if err != nil {
retries -= 1
if retries <= 0 {
Expand Down
102 changes: 101 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2012-2015 Oliver Eilhard. All rights reserved.
// Copyright 2012-present Oliver Eilhard. All rights reserved.
// Use of this source code is governed by a MIT-license.
// See http://olivere.mit-license.org/license.txt for details.

Expand All @@ -15,6 +15,8 @@ import (
"strings"
"testing"
"time"

"golang.org/x/net/context"
)

func findConn(s string, slice ...*conn) (int, bool) {
Expand Down Expand Up @@ -702,6 +704,28 @@ func TestPerformRequest(t *testing.T) {
}
}

func TestPerformRequestC(t *testing.T) {
client, err := NewClient()
if err != nil {
t.Fatal(err)
}
res, err := client.PerformRequestC(context.Background(), "GET", "/", nil, nil)
if err != nil {
t.Fatal(err)
}
if res == nil {
t.Fatal("expected response to be != nil")
}

ret := new(PingResult)
if err := json.Unmarshal(res.Body, ret); err != nil {
t.Fatalf("expected no error on decode; got: %v", err)
}
if ret.ClusterName == "" {
t.Errorf("expected cluster name; got: %q", ret.ClusterName)
}
}

func TestPerformRequestWithSimpleClient(t *testing.T) {
client, err := NewSimpleClient()
if err != nil {
Expand Down Expand Up @@ -942,3 +966,79 @@ func TestPerformRequestWithSetBodyError(t *testing.T) {
t.Fatal("expected no response")
}
}

// sleepingTransport will sleep before doing a request.
type sleepingTransport struct {
timeout time.Duration
}

// RoundTrip implements a "sleepy" transport.
func (tr *sleepingTransport) RoundTrip(r *http.Request) (*http.Response, error) {
time.Sleep(tr.timeout)
return http.DefaultTransport.RoundTrip(r)
}

func TestPerformRequestCWithCancel(t *testing.T) {
tr := &sleepingTransport{timeout: 3 * time.Second}
httpClient := &http.Client{Transport: tr}

client, err := NewSimpleClient(SetHttpClient(httpClient), SetMaxRetries(0))
if err != nil {
t.Fatal(err)
}

type result struct {
res *Response
err error
}
ctx, cancel := context.WithCancel(context.Background())

resc := make(chan result, 1)
go func() {
res, err := client.PerformRequestC(ctx, "GET", "/", nil, nil)
resc <- result{res: res, err: err}
}()
select {
case <-time.After(1 * time.Second):
cancel()
case res := <-resc:
t.Fatalf("expected response before cancel, got %v", res)
case <-ctx.Done():
t.Fatalf("expected no early termination, got ctx.Done(): %v", ctx.Err())
}
err = ctx.Err()
if err != context.Canceled {
t.Fatalf("expected error context.Canceled, got: %v", err)
}
}

func TestPerformRequestCWithTimeout(t *testing.T) {
tr := &sleepingTransport{timeout: 3 * time.Second}
httpClient := &http.Client{Transport: tr}

client, err := NewSimpleClient(SetHttpClient(httpClient), SetMaxRetries(0))
if err != nil {
t.Fatal(err)
}

type result struct {
res *Response
err error
}
ctx, _ := context.WithTimeout(context.Background(), 1*time.Second)

resc := make(chan result, 1)
go func() {
res, err := client.PerformRequestC(ctx, "GET", "/", nil, nil)
resc <- result{res: res, err: err}
}()
select {
case res := <-resc:
t.Fatalf("expected timeout before response, got %v", res)
case <-ctx.Done():
err := ctx.Err()
if err != context.DeadlineExceeded {
t.Fatalf("expected error context.DeadlineExceeded, got: %v", err)
}
}
}
9 changes: 8 additions & 1 deletion cluster_health.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"strings"

"golang.org/x/net/context"

"gopkg.in/olivere/elastic.v3/uritemplates"
)

Expand Down Expand Up @@ -166,6 +168,11 @@ func (s *ClusterHealthService) Validate() error {

// Do executes the operation.
func (s *ClusterHealthService) Do() (*ClusterHealthResponse, error) {
return s.DoC(nil)
}

// DoC executes the operation.
func (s *ClusterHealthService) DoC(ctx context.Context) (*ClusterHealthResponse, error) {
// Check pre-conditions
if err := s.Validate(); err != nil {
return nil, err
Expand All @@ -178,7 +185,7 @@ func (s *ClusterHealthService) Do() (*ClusterHealthResponse, error) {
}

// Get HTTP response
res, err := s.client.PerformRequest("GET", path, params, nil)
res, err := s.client.PerformRequestC(ctx, "GET", path, params, nil)
if err != nil {
return nil, err
}
Expand Down
9 changes: 8 additions & 1 deletion cluster_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"strings"

"golang.org/x/net/context"

"gopkg.in/olivere/elastic.v3/uritemplates"
)

Expand Down Expand Up @@ -152,6 +154,11 @@ func (s *ClusterStateService) Validate() error {

// Do executes the operation.
func (s *ClusterStateService) Do() (*ClusterStateResponse, error) {
return s.DoC(nil)
}

// DoC executes the operation.
func (s *ClusterStateService) DoC(ctx context.Context) (*ClusterStateResponse, error) {
// Check pre-conditions
if err := s.Validate(); err != nil {
return nil, err
Expand All @@ -164,7 +171,7 @@ func (s *ClusterStateService) Do() (*ClusterStateResponse, error) {
}

// Get HTTP response
res, err := s.client.PerformRequest("GET", path, params, nil)
res, err := s.client.PerformRequestC(ctx, "GET", path, params, nil)
if err != nil {
return nil, err
}
Expand Down
9 changes: 8 additions & 1 deletion cluster_stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"strings"

"golang.org/x/net/context"

"gopkg.in/olivere/elastic.v3/uritemplates"
)

Expand Down Expand Up @@ -94,6 +96,11 @@ func (s *ClusterStatsService) Validate() error {

// Do executes the operation.
func (s *ClusterStatsService) Do() (*ClusterStatsResponse, error) {
return s.DoC(nil)
}

// DoC executes the operation.
func (s *ClusterStatsService) DoC(ctx context.Context) (*ClusterStatsResponse, error) {
// Check pre-conditions
if err := s.Validate(); err != nil {
return nil, err
Expand All @@ -106,7 +113,7 @@ func (s *ClusterStatsService) Do() (*ClusterStatsResponse, error) {
}

// Get HTTP response
res, err := s.client.PerformRequest("GET", path, params, nil)
res, err := s.client.PerformRequestC(ctx, "GET", path, params, nil)
if err != nil {
return nil, err
}
Expand Down
8 changes: 7 additions & 1 deletion count.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"net/url"
"strings"

"golang.org/x/net/context"

"gopkg.in/olivere/elastic.v3/uritemplates"
)

Expand Down Expand Up @@ -257,6 +259,10 @@ func (s *CountService) Validate() error {

// Do executes the operation.
func (s *CountService) Do() (int64, error) {
return s.DoC(nil)
}

func (s *CountService) DoC(ctx context.Context) (int64, error) {
// Check pre-conditions
if err := s.Validate(); err != nil {
return 0, err
Expand Down Expand Up @@ -285,7 +291,7 @@ func (s *CountService) Do() (int64, error) {
}

// Get HTTP response
res, err := s.client.PerformRequest("POST", path, params, body)
res, err := s.client.PerformRequestC(ctx, "POST", path, params, body)
if err != nil {
return 0, err
}
Expand Down
9 changes: 8 additions & 1 deletion delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"fmt"
"net/url"

"golang.org/x/net/context"

"gopkg.in/olivere/elastic.v3/uritemplates"
)

Expand Down Expand Up @@ -175,6 +177,11 @@ func (s *DeleteService) Validate() error {

// Do executes the operation.
func (s *DeleteService) Do() (*DeleteResponse, error) {
return s.DoC(nil)
}

// DoC executes the operation.
func (s *DeleteService) DoC(ctx context.Context) (*DeleteResponse, error) {
// Check pre-conditions
if err := s.Validate(); err != nil {
return nil, err
Expand All @@ -187,7 +194,7 @@ func (s *DeleteService) Do() (*DeleteResponse, error) {
}

// Get HTTP response
res, err := s.client.PerformRequest("DELETE", path, params, nil)
res, err := s.client.PerformRequestC(ctx, "DELETE", path, params, nil)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 73e440e

Please sign in to comment.