From ae4e6f5cd96f34b63a05e8cb2511fa45ac471376 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Mon, 6 Mar 2017 17:00:32 -0800 Subject: [PATCH] pkg/transport, rafthttp: CancelableTransport A transport that is tied to a context. Lets the consumer cancel all requests and dials with a single Cancel. --- integration/cluster.go | 2 +- pkg/transport/cancelable_transport_test.go | 61 ++++++++++++++++++++++ pkg/transport/timeout_dialer.go | 7 ++- pkg/transport/timeout_transport.go | 11 ++-- pkg/transport/transport.go | 50 ++++++++++++++---- rafthttp/transport.go | 14 +++-- rafthttp/util.go | 4 +- 7 files changed, 122 insertions(+), 27 deletions(-) create mode 100644 pkg/transport/cancelable_transport_test.go diff --git a/integration/cluster.go b/integration/cluster.go index b9bc6b964f67..24fd95425e62 100644 --- a/integration/cluster.go +++ b/integration/cluster.go @@ -840,7 +840,7 @@ func mustNewTransport(t *testing.T, tlsInfo transport.TLSInfo) *http.Transport { if err != nil { t.Fatal(err) } - return tr + return tr.Transport } type SortableMemberSliceByPeerURLs []client.Member diff --git a/pkg/transport/cancelable_transport_test.go b/pkg/transport/cancelable_transport_test.go new file mode 100644 index 000000000000..2239915e9ec5 --- /dev/null +++ b/pkg/transport/cancelable_transport_test.go @@ -0,0 +1,61 @@ +// Copyright 2017 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package transport + +import ( + "fmt" + "net/http" + "strings" + "testing" + "time" +) + +func TestCancelableTransportCancel(t *testing.T) { + sock := "whatever:123" + l, err := NewUnixListener(sock) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + tr, trerr := NewTransport(TLSInfo{}, time.Second) + if trerr != nil { + t.Fatal(trerr) + } + tr.Cancel() + + errc := make(chan error, 1) + go func() { + defer close(errc) + req, reqerr := http.NewRequest("GET", "unix://"+sock, strings.NewReader("abc")) + if reqerr != nil { + errc <- reqerr + return + } + resp, rerr := tr.RoundTrip(req) + if rerr == nil { + errc <- fmt.Errorf("round trip succeeded with %+v, expected error", resp) + } + }() + + select { + case err := <-errc: + if err != nil { + t.Fatal(err) + } + case <-time.After(5 * time.Second): + t.Fatalf("timed out waiting for roundtrip to cancel") + } +} diff --git a/pkg/transport/timeout_dialer.go b/pkg/transport/timeout_dialer.go index 6ae39ecfc9b3..5d986e6f42e6 100644 --- a/pkg/transport/timeout_dialer.go +++ b/pkg/transport/timeout_dialer.go @@ -15,6 +15,7 @@ package transport import ( + "context" "net" "time" ) @@ -26,7 +27,11 @@ type rwTimeoutDialer struct { } func (d *rwTimeoutDialer) Dial(network, address string) (net.Conn, error) { - conn, err := d.Dialer.Dial(network, address) + return d.DialContext(context.Background(), network, address) +} + +func (d *rwTimeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := d.Dialer.DialContext(ctx, network, address) tconn := &timeoutConn{ rdtimeoutd: d.rdtimeoutd, wtimeoutd: d.wtimeoutd, diff --git a/pkg/transport/timeout_transport.go b/pkg/transport/timeout_transport.go index ea16b4c0f869..3de92116f022 100644 --- a/pkg/transport/timeout_transport.go +++ b/pkg/transport/timeout_transport.go @@ -16,7 +16,6 @@ package transport import ( "net" - "net/http" "time" ) @@ -24,7 +23,7 @@ import ( // If read/write on the created connection blocks longer than its time limit, // it will return timeout error. // If read/write timeout is set, transport will not be able to reuse connection. -func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.Duration) (*http.Transport, error) { +func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.Duration) (*CancelableTransport, error) { tr, err := NewTransport(info, dialtimeoutd) if err != nil { return nil, err @@ -39,13 +38,17 @@ func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time. tr.MaxIdleConnsPerHost = 1024 } - tr.Dial = (&rwTimeoutDialer{ + dialer := &rwTimeoutDialer{ Dialer: net.Dialer{ Timeout: dialtimeoutd, KeepAlive: 30 * time.Second, }, rdtimeoutd: rdtimeoutd, wtimeoutd: wtimeoutd, - }).Dial + } + tr.Dial = func(net, addr string) (net.Conn, error) { + return dialer.DialContext(tr.Ctx(), net, addr) + } + return tr, nil } diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index 4a7fe69d2e19..8f3bfd1cb4cb 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -15,43 +15,71 @@ package transport import ( + "context" "net" "net/http" "strings" "time" ) +type CancelableTransport struct { + *http.Transport + ctx context.Context + cancel context.CancelFunc +} + +func (c *CancelableTransport) Ctx() context.Context { return c.ctx } +func (c *CancelableTransport) Cancel() { c.cancel() } + +func (c *CancelableTransport) RoundTrip(req *http.Request) (*http.Response, error) { + if req.Context() != context.Background() { + // request defaults to context.Background; override + return c.Transport.RoundTrip(req) + } + return c.Transport.RoundTrip(req.WithContext(c.ctx)) +} + type unixTransport struct{ *http.Transport } -func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) { +func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*CancelableTransport, error) { cfg, err := info.ClientConfig() if err != nil { return nil, err } + ctx, cancel := context.WithCancel(context.TODO()) + t := &http.Transport{ Proxy: http.ProxyFromEnvironment, - Dial: (&net.Dialer{ - Timeout: dialtimeoutd, - // value taken from http.DefaultTransport - KeepAlive: 30 * time.Second, - }).Dial, // value taken from http.DefaultTransport TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: cfg, } + ct := &CancelableTransport{ + Transport: t, + ctx: ctx, + cancel: cancel, + } + tdialer := &net.Dialer{ + Timeout: dialtimeoutd, + // value taken from http.DefaultTransport + KeepAlive: 30 * time.Second, + } + tdial := func(net, addr string) (net.Conn, error) { + return tdialer.DialContext(ctx, net, addr) + } + t.Dial = tdial dialer := (&net.Dialer{ Timeout: dialtimeoutd, KeepAlive: 30 * time.Second, }) - dial := func(net, addr string) (net.Conn, error) { - return dialer.Dial("unix", addr) + udial := func(net, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", addr) } - tu := &http.Transport{ Proxy: http.ProxyFromEnvironment, - Dial: dial, + Dial: udial, TLSHandshakeTimeout: 10 * time.Second, TLSClientConfig: cfg, } @@ -60,7 +88,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er t.RegisterProtocol("unix", ut) t.RegisterProtocol("unixs", ut) - return t, nil + return ct, nil } func (urt *unixTransport) RoundTrip(req *http.Request) (*http.Response, error) { diff --git a/rafthttp/transport.go b/rafthttp/transport.go index 1f0b46836e66..5bd87615d437 100644 --- a/rafthttp/transport.go +++ b/rafthttp/transport.go @@ -112,8 +112,8 @@ type Transport struct { // machine and thus stop the Transport. ErrorC chan error - streamRt http.RoundTripper // roundTripper used by streams - pipelineRt http.RoundTripper // roundTripper used by pipelines + streamRt *transport.CancelableTransport // roundTripper used by streams + pipelineRt *transport.CancelableTransport // roundTripper used by pipelines mu sync.RWMutex // protect the remote and peer map remotes map[types.ID]*remote // remotes map that helps newly joined member to catch up @@ -189,6 +189,8 @@ func (t *Transport) Send(msgs []raftpb.Message) { func (t *Transport) Stop() { t.mu.Lock() defer t.mu.Unlock() + t.streamRt.Cancel() + t.pipelineRt.Cancel() for _, r := range t.remotes { r.stop() } @@ -196,12 +198,8 @@ func (t *Transport) Stop() { p.stop() } t.prober.RemoveAll() - if tr, ok := t.streamRt.(*http.Transport); ok { - tr.CloseIdleConnections() - } - if tr, ok := t.pipelineRt.(*http.Transport); ok { - tr.CloseIdleConnections() - } + t.streamRt.CloseIdleConnections() + t.pipelineRt.CloseIdleConnections() t.peers = nil t.remotes = nil } diff --git a/rafthttp/util.go b/rafthttp/util.go index 61855c52a60c..d06dabe153fe 100644 --- a/rafthttp/util.go +++ b/rafthttp/util.go @@ -45,7 +45,7 @@ func NewListener(u url.URL, tlscfg *tls.Config) (net.Listener, error) { // NewRoundTripper returns a roundTripper used to send requests // to rafthttp listener of remote peers. -func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) { +func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (*transport.CancelableTransport, error) { // It uses timeout transport to pair with remote timeout listeners. // It sets no read/write timeout, because message in requests may // take long time to write out before reading out the response. @@ -57,7 +57,7 @@ func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http // Read/write timeout is set for stream roundTripper to promptly // find out broken status, which minimizes the number of messages // sent on broken connection. -func newStreamRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) { +func newStreamRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (*transport.CancelableTransport, error) { return transport.NewTimeoutTransport(tlsInfo, dialTimeout, ConnReadTimeout, ConnWriteTimeout) }