From 2656b594bbf1bd10b3b2a0c2328cbd12dddcb4b9 Mon Sep 17 00:00:00 2001 From: sharat Date: Wed, 1 Feb 2017 17:08:40 +0530 Subject: [PATCH] rafthttp: use http.Request.WithContext instead of Cancel --- rafthttp/fake_roundtripper_test.go | 3 +++ rafthttp/pipeline.go | 5 +++-- rafthttp/snapshot_sender.go | 6 ++++-- rafthttp/stream.go | 6 +++++- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/rafthttp/fake_roundtripper_test.go b/rafthttp/fake_roundtripper_test.go index 498570ec84b..4e17dee818b 100644 --- a/rafthttp/fake_roundtripper_test.go +++ b/rafthttp/fake_roundtripper_test.go @@ -24,11 +24,14 @@ func (t *roundTripperBlocker) RoundTrip(req *http.Request) (*http.Response, erro t.mu.Lock() t.cancel[req] = c t.mu.Unlock() + ctx := req.Context() select { case <-t.unblockc: return &http.Response{StatusCode: http.StatusNoContent, Body: &nopReadCloser{}}, nil case <-req.Cancel: return nil, errors.New("request canceled") + case <-ctx.Done(): + return nil, errors.New("request canceled") case <-c: return nil, errors.New("request canceled") } diff --git a/rafthttp/pipeline.go b/rafthttp/pipeline.go index ccd9eb78698..d9f07c3479d 100644 --- a/rafthttp/pipeline.go +++ b/rafthttp/pipeline.go @@ -16,13 +16,13 @@ package rafthttp import ( "bytes" + "context" "errors" "io/ioutil" "sync" "time" "github.com/coreos/etcd/etcdserver/stats" - "github.com/coreos/etcd/pkg/httputil" "github.com/coreos/etcd/pkg/pbutil" "github.com/coreos/etcd/pkg/types" "github.com/coreos/etcd/raft" @@ -118,7 +118,8 @@ func (p *pipeline) post(data []byte) (err error) { req := createPostRequest(u, RaftPrefix, bytes.NewBuffer(data), "application/protobuf", p.tr.URLs, p.tr.ID, p.tr.ClusterID) done := make(chan struct{}, 1) - cancel := httputil.RequestCanceler(req) + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) go func() { select { case <-done: diff --git a/rafthttp/snapshot_sender.go b/rafthttp/snapshot_sender.go index 105b330728e..52273c9d195 100644 --- a/rafthttp/snapshot_sender.go +++ b/rafthttp/snapshot_sender.go @@ -16,6 +16,7 @@ package rafthttp import ( "bytes" + "context" "io" "io/ioutil" "net/http" @@ -104,7 +105,9 @@ func (s *snapshotSender) send(merged snap.Message) { // post posts the given request. // It returns nil when request is sent out and processed successfully. func (s *snapshotSender) post(req *http.Request) (err error) { - cancel := httputil.RequestCanceler(req) + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) + defer cancel() type responseAndError struct { resp *http.Response @@ -130,7 +133,6 @@ func (s *snapshotSender) post(req *http.Request) (err error) { select { case <-s.stopc: - cancel() return errStopped case r := <-result: if r.err != nil { diff --git a/rafthttp/stream.go b/rafthttp/stream.go index e10056731f5..40d66687387 100644 --- a/rafthttp/stream.go +++ b/rafthttp/stream.go @@ -15,6 +15,7 @@ package rafthttp import ( + "context" "fmt" "io" "io/ioutil" @@ -427,14 +428,17 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) { setPeerURLsHeader(req, cr.tr.URLs) + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) + cr.mu.Lock() + cr.cancel = cancel select { case <-cr.stopc: cr.mu.Unlock() return nil, fmt.Errorf("stream reader is stopped") default: } - cr.cancel = httputil.RequestCanceler(req) cr.mu.Unlock() resp, err := cr.tr.streamRt.RoundTrip(req)