Skip to content

Commit

Permalink
pkg/transport, rafthttp: CancelableTransport
Browse files Browse the repository at this point in the history
A transport that is tied to a context. Lets the consumer cancel all requests
and dials with a single Cancel.
  • Loading branch information
Anthony Romano committed Mar 27, 2017
1 parent e6f72b4 commit ae4e6f5
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 27 deletions.
2 changes: 1 addition & 1 deletion integration/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,7 +840,7 @@ func mustNewTransport(t *testing.T, tlsInfo transport.TLSInfo) *http.Transport {
if err != nil {
t.Fatal(err)
}
return tr
return tr.Transport
}

type SortableMemberSliceByPeerURLs []client.Member
Expand Down
61 changes: 61 additions & 0 deletions pkg/transport/cancelable_transport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// Copyright 2017 The etcd Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package transport

import (
"fmt"
"net/http"
"strings"
"testing"
"time"
)

func TestCancelableTransportCancel(t *testing.T) {
sock := "whatever:123"
l, err := NewUnixListener(sock)
if err != nil {
t.Fatal(err)
}
defer l.Close()

tr, trerr := NewTransport(TLSInfo{}, time.Second)
if trerr != nil {
t.Fatal(trerr)
}
tr.Cancel()

errc := make(chan error, 1)
go func() {
defer close(errc)
req, reqerr := http.NewRequest("GET", "unix://"+sock, strings.NewReader("abc"))
if reqerr != nil {
errc <- reqerr
return
}
resp, rerr := tr.RoundTrip(req)
if rerr == nil {
errc <- fmt.Errorf("round trip succeeded with %+v, expected error", resp)
}
}()

select {
case err := <-errc:
if err != nil {
t.Fatal(err)
}
case <-time.After(5 * time.Second):
t.Fatalf("timed out waiting for roundtrip to cancel")
}
}
7 changes: 6 additions & 1 deletion pkg/transport/timeout_dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package transport

import (
"context"
"net"
"time"
)
Expand All @@ -26,7 +27,11 @@ type rwTimeoutDialer struct {
}

func (d *rwTimeoutDialer) Dial(network, address string) (net.Conn, error) {
conn, err := d.Dialer.Dial(network, address)
return d.DialContext(context.Background(), network, address)
}

func (d *rwTimeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := d.Dialer.DialContext(ctx, network, address)
tconn := &timeoutConn{
rdtimeoutd: d.rdtimeoutd,
wtimeoutd: d.wtimeoutd,
Expand Down
11 changes: 7 additions & 4 deletions pkg/transport/timeout_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ package transport

import (
"net"
"net/http"
"time"
)

// NewTimeoutTransport returns a transport created using the given TLS info.
// If read/write on the created connection blocks longer than its time limit,
// it will return timeout error.
// If read/write timeout is set, transport will not be able to reuse connection.
func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.Duration) (*http.Transport, error) {
func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.Duration) (*CancelableTransport, error) {
tr, err := NewTransport(info, dialtimeoutd)
if err != nil {
return nil, err
Expand All @@ -39,13 +38,17 @@ func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.
tr.MaxIdleConnsPerHost = 1024
}

tr.Dial = (&rwTimeoutDialer{
dialer := &rwTimeoutDialer{
Dialer: net.Dialer{
Timeout: dialtimeoutd,
KeepAlive: 30 * time.Second,
},
rdtimeoutd: rdtimeoutd,
wtimeoutd: wtimeoutd,
}).Dial
}
tr.Dial = func(net, addr string) (net.Conn, error) {
return dialer.DialContext(tr.Ctx(), net, addr)
}

return tr, nil
}
50 changes: 39 additions & 11 deletions pkg/transport/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,43 +15,71 @@
package transport

import (
"context"
"net"
"net/http"
"strings"
"time"
)

type CancelableTransport struct {
*http.Transport
ctx context.Context
cancel context.CancelFunc
}

func (c *CancelableTransport) Ctx() context.Context { return c.ctx }
func (c *CancelableTransport) Cancel() { c.cancel() }

func (c *CancelableTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if req.Context() != context.Background() {
// request defaults to context.Background; override
return c.Transport.RoundTrip(req)
}
return c.Transport.RoundTrip(req.WithContext(c.ctx))
}

type unixTransport struct{ *http.Transport }

func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) {
func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*CancelableTransport, error) {
cfg, err := info.ClientConfig()
if err != nil {
return nil, err
}

ctx, cancel := context.WithCancel(context.TODO())

t := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: dialtimeoutd,
// value taken from http.DefaultTransport
KeepAlive: 30 * time.Second,
}).Dial,
// value taken from http.DefaultTransport
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: cfg,
}
ct := &CancelableTransport{
Transport: t,
ctx: ctx,
cancel: cancel,
}
tdialer := &net.Dialer{
Timeout: dialtimeoutd,
// value taken from http.DefaultTransport
KeepAlive: 30 * time.Second,
}
tdial := func(net, addr string) (net.Conn, error) {
return tdialer.DialContext(ctx, net, addr)
}
t.Dial = tdial

dialer := (&net.Dialer{
Timeout: dialtimeoutd,
KeepAlive: 30 * time.Second,
})
dial := func(net, addr string) (net.Conn, error) {
return dialer.Dial("unix", addr)
udial := func(net, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "unix", addr)
}

tu := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: dial,
Dial: udial,
TLSHandshakeTimeout: 10 * time.Second,
TLSClientConfig: cfg,
}
Expand All @@ -60,7 +88,7 @@ func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, er
t.RegisterProtocol("unix", ut)
t.RegisterProtocol("unixs", ut)

return t, nil
return ct, nil
}

func (urt *unixTransport) RoundTrip(req *http.Request) (*http.Response, error) {
Expand Down
14 changes: 6 additions & 8 deletions rafthttp/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ type Transport struct {
// machine and thus stop the Transport.
ErrorC chan error

streamRt http.RoundTripper // roundTripper used by streams
pipelineRt http.RoundTripper // roundTripper used by pipelines
streamRt *transport.CancelableTransport // roundTripper used by streams
pipelineRt *transport.CancelableTransport // roundTripper used by pipelines

mu sync.RWMutex // protect the remote and peer map
remotes map[types.ID]*remote // remotes map that helps newly joined member to catch up
Expand Down Expand Up @@ -189,19 +189,17 @@ func (t *Transport) Send(msgs []raftpb.Message) {
func (t *Transport) Stop() {
t.mu.Lock()
defer t.mu.Unlock()
t.streamRt.Cancel()
t.pipelineRt.Cancel()
for _, r := range t.remotes {
r.stop()
}
for _, p := range t.peers {
p.stop()
}
t.prober.RemoveAll()
if tr, ok := t.streamRt.(*http.Transport); ok {
tr.CloseIdleConnections()
}
if tr, ok := t.pipelineRt.(*http.Transport); ok {
tr.CloseIdleConnections()
}
t.streamRt.CloseIdleConnections()
t.pipelineRt.CloseIdleConnections()
t.peers = nil
t.remotes = nil
}
Expand Down
4 changes: 2 additions & 2 deletions rafthttp/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func NewListener(u url.URL, tlscfg *tls.Config) (net.Listener, error) {

// NewRoundTripper returns a roundTripper used to send requests
// to rafthttp listener of remote peers.
func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) {
func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (*transport.CancelableTransport, error) {
// It uses timeout transport to pair with remote timeout listeners.
// It sets no read/write timeout, because message in requests may
// take long time to write out before reading out the response.
Expand All @@ -57,7 +57,7 @@ func NewRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http
// Read/write timeout is set for stream roundTripper to promptly
// find out broken status, which minimizes the number of messages
// sent on broken connection.
func newStreamRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (http.RoundTripper, error) {
func newStreamRoundTripper(tlsInfo transport.TLSInfo, dialTimeout time.Duration) (*transport.CancelableTransport, error) {
return transport.NewTimeoutTransport(tlsInfo, dialTimeout, ConnReadTimeout, ConnWriteTimeout)
}

Expand Down

0 comments on commit ae4e6f5

Please sign in to comment.