From c45432014bf4dc47bf5c10b23fbb5985bf4ca901 Mon Sep 17 00:00:00 2001 From: Hans Hasselberg Date: Thu, 28 May 2020 09:48:34 +0200 Subject: [PATCH 1/3] pool: remove version The version field has been used to decide which multiplexing to use. It was introduced in 2457293dceec95ecd12ef4f01442e13710ea131a. But this is 6y ago and there is no need for this differentiation anymore. --- agent/consul/auto_encrypt.go | 2 +- agent/consul/client.go | 2 +- agent/consul/client_test.go | 2 +- agent/consul/rpc.go | 4 ++-- agent/consul/server_serf.go | 2 +- agent/consul/server_test.go | 2 +- agent/consul/stats_fetcher.go | 2 +- agent/pool/pool.go | 33 ++++++++++----------------- agent/router/manager.go | 4 ++-- agent/router/manager_internal_test.go | 4 ++-- agent/router/manager_test.go | 2 +- 11 files changed, 25 insertions(+), 34 deletions(-) diff --git a/agent/consul/auto_encrypt.go b/agent/consul/auto_encrypt.go index e5dd7be16e26..3beace59840b 100644 --- a/agent/consul/auto_encrypt.go +++ b/agent/consul/auto_encrypt.go @@ -109,7 +109,7 @@ func (c *Client) RequestAutoEncryptCerts(servers []string, port int, token strin for _, ip := range ips { addr := net.TCPAddr{IP: ip, Port: port} - if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, 0, "AutoEncrypt.Sign", &args, &reply); err == nil { + if err = c.connPool.RPC(c.config.Datacenter, c.config.NodeName, &addr, "AutoEncrypt.Sign", &args, &reply); err == nil { return &reply, pkPEM, nil } else { c.logger.Warn("AutoEncrypt failed", "error", err) diff --git a/agent/consul/client.go b/agent/consul/client.go index bcaf5aac1967..c7a36293bae7 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -308,7 +308,7 @@ TRY: } // Make the request. - rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, server.Version, method, args, reply) + rpcErr := c.connPool.RPC(c.config.Datacenter, server.ShortName, server.Addr, method, args, reply) if rpcErr == nil { return nil } diff --git a/agent/consul/client_test.go b/agent/consul/client_test.go index 8b037b18b4fb..5cb6b684386d 100644 --- a/agent/consul/client_test.go +++ b/agent/consul/client_test.go @@ -425,7 +425,7 @@ func TestClient_RPC_ConsulServerPing(t *testing.T) { for range servers { time.Sleep(200 * time.Millisecond) s := c.routers.FindServer() - ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr, s.Version) + ok, err := c.connPool.Ping(s.Datacenter, s.ShortName, s.Addr) if !ok { t.Errorf("Unable to ping server %v: %s", s.String(), err) } diff --git a/agent/consul/rpc.go b/agent/consul/rpc.go index f00b14ea82aa..6edaa54a4dde 100644 --- a/agent/consul/rpc.go +++ b/agent/consul/rpc.go @@ -552,7 +552,7 @@ CHECK_LEADER: rpcErr := structs.ErrNoLeader if leader != nil { rpcErr = s.connPool.RPC(s.config.Datacenter, leader.ShortName, leader.Addr, - leader.Version, method, args, reply) + method, args, reply) if rpcErr != nil && canRetry(info, rpcErr) { goto RETRY } @@ -617,7 +617,7 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{ metrics.IncrCounterWithLabels([]string{"rpc", "cross-dc"}, 1, []metrics.Label{{Name: "datacenter", Value: dc}}) - if err := s.connPool.RPC(dc, server.ShortName, server.Addr, server.Version, method, args, reply); err != nil { + if err := s.connPool.RPC(dc, server.ShortName, server.Addr, method, args, reply); err != nil { manager.NotifyFailedServer(server) s.rpcLogger().Error("RPC failed to server in DC", "server", server.Addr, diff --git a/agent/consul/server_serf.go b/agent/consul/server_serf.go index 9c717a600827..00f88091aab9 100644 --- a/agent/consul/server_serf.go +++ b/agent/consul/server_serf.go @@ -355,7 +355,7 @@ func (s *Server) maybeBootstrap() { // Retry with exponential backoff to get peer status from this server for attempt := uint(0); attempt < maxPeerRetries; attempt++ { - if err := s.connPool.RPC(s.config.Datacenter, server.ShortName, server.Addr, server.Version, + if err := s.connPool.RPC(s.config.Datacenter, server.ShortName, server.Addr, "Status.Peers", &structs.DCSpecificRequest{Datacenter: s.config.Datacenter}, &peers); err != nil { nextRetry := (1 << attempt) * time.Second s.logger.Error("Failed to confirm peer status for server (will retry).", diff --git a/agent/consul/server_test.go b/agent/consul/server_test.go index a49d18e9265f..9cd5b4c7cec0 100644 --- a/agent/consul/server_test.go +++ b/agent/consul/server_test.go @@ -1277,7 +1277,7 @@ func testVerifyRPC(s1, s2 *Server, t *testing.T) (bool, error) { if leader == nil { t.Fatal("no leader") } - return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr, leader.Version) + return s2.connPool.Ping(leader.Datacenter, leader.ShortName, leader.Addr) } func TestServer_TLSToNoTLS(t *testing.T) { diff --git a/agent/consul/stats_fetcher.go b/agent/consul/stats_fetcher.go index 1635126d5590..bd283f9e83e5 100644 --- a/agent/consul/stats_fetcher.go +++ b/agent/consul/stats_fetcher.go @@ -43,7 +43,7 @@ func NewStatsFetcher(logger hclog.Logger, pool *pool.ConnPool, datacenter string func (f *StatsFetcher) fetch(server *metadata.Server, replyCh chan *autopilot.ServerStats) { var args struct{} var reply autopilot.ServerStats - err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, server.Version, "Status.RaftStats", &args, &reply) + err := f.pool.RPC(f.datacenter, server.ShortName, server.Addr, "Status.RaftStats", &args, &reply) if err != nil { f.logger.Warn("error getting server health from server", "server", server.Name, diff --git a/agent/pool/pool.go b/agent/pool/pool.go index 4ce7c1e460d1..16dcb7a91c20 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -46,7 +46,6 @@ type Conn struct { addr net.Addr session muxSession lastUsed time.Time - version int pool *ConnPool @@ -209,7 +208,7 @@ func (p *ConnPool) Shutdown() error { // wait for an existing connection attempt to finish, if one if in progress, // and will return that one if it succeeds. If all else fails, it will return a // newly-created connection and add it to the pool. -func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (*Conn, error) { +func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.acquire requires a node name") } @@ -244,7 +243,7 @@ func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, version in // If we are the lead thread, make the new connection and then wake // everybody else up to see if we got it. if isLeadThread { - c, err := p.getNewConn(dc, nodeName, addr, version, useTLS) + c, err := p.getNewConn(dc, nodeName, addr, useTLS) p.Lock() delete(p.limiter, addrStr) close(wait) @@ -497,17 +496,11 @@ func DialTimeoutWithRPCTypeViaMeshGateway( } // getNewConn is used to return a new connection -func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (*Conn, error) { +func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.getNewConn requires a node name") } - // Switch the multiplexing based on version - var session muxSession - if version < 2 { - return nil, fmt.Errorf("cannot make client connection, unsupported protocol version %d", version) - } - // Get a new, raw connection and write the Consul multiplex byte to set the mode conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, useTLS, RPCMultiplexV2) if err != nil { @@ -519,7 +512,7 @@ func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, version conf.LogOutput = p.LogOutput // Create a multiplexed session - session, _ = yamux.Client(conn, conf) + session, _ := yamux.Client(conn, conf) // Wrap the connection c := &Conn{ @@ -529,7 +522,6 @@ func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, version session: session, clients: list.New(), lastUsed: time.Now(), - version: version, pool: p, } return c, nil @@ -567,12 +559,12 @@ func (p *ConnPool) releaseConn(conn *Conn) { } } -// getClient is used to get a usable client for an address and protocol version -func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr, version int, useTLS bool) (*Conn, *StreamClient, error) { +// getClient is used to get a usable client for an address +func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, *StreamClient, error) { retries := 0 START: // Try to get a conn first - conn, err := p.acquire(dc, nodeName, addr, version, useTLS) + conn, err := p.acquire(dc, nodeName, addr, useTLS) if err != nil { return nil, nil, fmt.Errorf("failed to get conn: %v", err) } @@ -598,7 +590,6 @@ func (p *ConnPool) RPC( dc string, nodeName string, addr net.Addr, - version int, method string, args interface{}, reply interface{}, @@ -610,7 +601,7 @@ func (p *ConnPool) RPC( if method == "AutoEncrypt.Sign" { return p.rpcInsecure(dc, nodeName, addr, method, args, reply) } else { - return p.rpc(dc, nodeName, addr, version, method, args, reply) + return p.rpc(dc, nodeName, addr, method, args, reply) } } @@ -636,12 +627,12 @@ func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method return nil } -func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error { +func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, method string, args interface{}, reply interface{}) error { p.once.Do(p.init) // Get a usable client useTLS := p.TLSConfigurator.UseTLS(dc) - conn, sc, err := p.getClient(dc, nodeName, addr, version, useTLS) + conn, sc, err := p.getClient(dc, nodeName, addr, useTLS) if err != nil { return fmt.Errorf("rpc error getting client: %v", err) } @@ -671,9 +662,9 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, version int, m // Ping sends a Status.Ping message to the specified server and // returns true if healthy, false if an error occurred -func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error) { +func (p *ConnPool) Ping(dc string, nodeName string, addr net.Addr) (bool, error) { var out struct{} - err := p.RPC(dc, nodeName, addr, version, "Status.Ping", struct{}{}, &out) + err := p.RPC(dc, nodeName, addr, "Status.Ping", struct{}{}, &out) return err == nil, err } diff --git a/agent/router/manager.go b/agent/router/manager.go index 7944e48d264a..3715d85b4ffa 100644 --- a/agent/router/manager.go +++ b/agent/router/manager.go @@ -61,7 +61,7 @@ type ManagerSerfCluster interface { // Pinger is an interface wrapping client.ConnPool to prevent a cyclic import // dependency. type Pinger interface { - Ping(dc, nodeName string, addr net.Addr, version int) (bool, error) + Ping(dc, nodeName string, addr net.Addr) (bool, error) } // serverList is a local copy of the struct used to maintain the list of @@ -350,7 +350,7 @@ func (m *Manager) RebalanceServers() { if m.serverName != "" && srv.Name == m.serverName { continue } - ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr, srv.Version) + ok, err := m.connPoolPinger.Ping(srv.Datacenter, srv.ShortName, srv.Addr) if ok { foundHealthyServer = true break diff --git a/agent/router/manager_internal_test.go b/agent/router/manager_internal_test.go index b06ccc98d5fd..10f39fbf9c5d 100644 --- a/agent/router/manager_internal_test.go +++ b/agent/router/manager_internal_test.go @@ -33,7 +33,7 @@ type fauxConnPool struct { failPct float64 } -func (cp *fauxConnPool) Ping(string, string, net.Addr, int) (bool, error) { +func (cp *fauxConnPool) Ping(string, string, net.Addr) (bool, error) { var success bool successProb := rand.Float64() if successProb > cp.failPct { @@ -179,7 +179,7 @@ func test_reconcileServerList(maxServers int) (bool, error) { // failPct of the servers for the reconcile. This // allows for the selected server to no longer be // healthy for the reconcile below. - if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr, node.Version); ok { + if ok, _ := m.connPoolPinger.Ping(node.Datacenter, node.ShortName, node.Addr); ok { // Will still be present healthyServers = append(healthyServers, node) } else { diff --git a/agent/router/manager_test.go b/agent/router/manager_test.go index 3b99bfe65457..6888891c099e 100644 --- a/agent/router/manager_test.go +++ b/agent/router/manager_test.go @@ -32,7 +32,7 @@ type fauxConnPool struct { failAddr net.Addr } -func (cp *fauxConnPool) Ping(dc string, nodeName string, addr net.Addr, version int) (bool, error) { +func (cp *fauxConnPool) Ping(dc string, nodeName string, addr net.Addr) (bool, error) { var success bool successProb := rand.Float64() From ad03f863ff21ab4e1d3d5d92153c54c6604024dd Mon Sep 17 00:00:00 2001 From: Hans Hasselberg Date: Thu, 28 May 2020 10:18:30 +0200 Subject: [PATCH 2/3] pool: remove useTLS and ForceTLS In the past TLS usage was enforced with these variables, but these days this decision is made by TLSConfigurator and there is no reason to keep using the variables. --- agent/consul/client.go | 3 +- agent/consul/server.go | 1 - agent/consul/snapshot_endpoint.go | 7 ++- agent/consul/snapshot_endpoint_test.go | 16 +++--- agent/consul/status_endpoint_test.go | 27 ++++++---- agent/pool/pool.go | 72 ++++++-------------------- 6 files changed, 45 insertions(+), 81 deletions(-) diff --git a/agent/consul/client.go b/agent/consul/client.go index c7a36293bae7..446ba9962363 100644 --- a/agent/consul/client.go +++ b/agent/consul/client.go @@ -137,7 +137,6 @@ func NewClientLogger(config *Config, logger hclog.InterceptLogger, tlsConfigurat MaxTime: clientRPCConnMaxIdle, MaxStreams: clientMaxStreams, TLSConfigurator: tlsConfigurator, - ForceTLS: config.VerifyOutgoing, Datacenter: config.Datacenter, } @@ -356,7 +355,7 @@ func (c *Client) SnapshotRPC(args *structs.SnapshotRequest, in io.Reader, out io // Request the operation. var reply structs.SnapshotResponse - snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, &reply) + snap, err := SnapshotRPC(c.connPool, c.config.Datacenter, server.ShortName, server.Addr, args, in, &reply) if err != nil { return err } diff --git a/agent/consul/server.go b/agent/consul/server.go index fec1efd6c10c..eb8d1cadd899 100644 --- a/agent/consul/server.go +++ b/agent/consul/server.go @@ -374,7 +374,6 @@ func NewServerLogger(config *Config, logger hclog.InterceptLogger, tokens *token MaxTime: serverRPCCache, MaxStreams: serverMaxStreams, TLSConfigurator: tlsConfigurator, - ForceTLS: config.VerifyOutgoing, Datacenter: config.Datacenter, } diff --git a/agent/consul/snapshot_endpoint.go b/agent/consul/snapshot_endpoint.go index 233354c0e7e7..5f6e3e0f8dba 100644 --- a/agent/consul/snapshot_endpoint.go +++ b/agent/consul/snapshot_endpoint.go @@ -37,7 +37,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re return nil, structs.ErrNoDCPath } - snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, server.UseTLS, args, in, reply) + snap, err := SnapshotRPC(s.connPool, dc, server.ShortName, server.Addr, args, in, reply) if err != nil { manager.NotifyFailedServer(server) return nil, err @@ -52,7 +52,7 @@ func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Re if server == nil { return nil, structs.ErrNoLeader } - return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, server.UseTLS, args, in, reply) + return SnapshotRPC(s.connPool, args.Datacenter, server.ShortName, server.Addr, args, in, reply) } } @@ -194,14 +194,13 @@ func SnapshotRPC( dc string, nodeName string, addr net.Addr, - useTLS bool, args *structs.SnapshotRequest, in io.Reader, reply *structs.SnapshotResponse, ) (io.ReadCloser, error) { // Write the snapshot RPC byte to set the mode, then perform the // request. - conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, useTLS, pool.RPCSnapshot) + conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, pool.RPCSnapshot) if err != nil { return nil, err } diff --git a/agent/consul/snapshot_endpoint_test.go b/agent/consul/snapshot_endpoint_test.go index 9073fa01e62c..e0cd31a1dafd 100644 --- a/agent/consul/snapshot_endpoint_test.go +++ b/agent/consul/snapshot_endpoint_test.go @@ -46,7 +46,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + snap, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err != nil { t.Fatalf("err: %v", err) @@ -121,7 +121,7 @@ func verifySnapshot(t *testing.T, s *Server, dc, token string) { // Restore the snapshot. args.Op = structs.SnapshotRestore - restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + restore, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, snap, &reply) if err != nil { t.Fatalf("err: %v", err) @@ -196,7 +196,7 @@ func TestSnapshot_LeaderState(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + snap, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err != nil { t.Fatalf("err: %v", err) @@ -229,7 +229,7 @@ func TestSnapshot_LeaderState(t *testing.T) { // Restore the snapshot. args.Op = structs.SnapshotRestore - restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + restore, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, snap, &reply) if err != nil { t.Fatalf("err: %v", err) @@ -268,7 +268,7 @@ func TestSnapshot_ACLDeny(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if !acl.IsErrPermissionDenied(err) { t.Fatalf("err: %v", err) @@ -282,7 +282,7 @@ func TestSnapshot_ACLDeny(t *testing.T) { Op: structs.SnapshotRestore, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, false, + _, err := SnapshotRPC(s1.connPool, s1.config.Datacenter, s1.config.NodeName, s1.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if !acl.IsErrPermissionDenied(err) { t.Fatalf("err: %v", err) @@ -391,7 +391,7 @@ func TestSnapshot_AllowStale(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err == nil || !strings.Contains(err.Error(), structs.ErrNoLeader.Error()) { t.Fatalf("err: %v", err) @@ -408,7 +408,7 @@ func TestSnapshot_AllowStale(t *testing.T) { Op: structs.SnapshotSave, } var reply structs.SnapshotResponse - _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, false, + _, err := SnapshotRPC(s.connPool, s.config.Datacenter, s.config.NodeName, s.config.RPCAddr, &args, bytes.NewReader([]byte("")), &reply) if err == nil || !strings.Contains(err.Error(), "Raft error when taking snapshot") { t.Fatalf("err: %v", err) diff --git a/agent/consul/status_endpoint_test.go b/agent/consul/status_endpoint_test.go index a9cc158fac48..4ef010830b8b 100644 --- a/agent/consul/status_endpoint_test.go +++ b/agent/consul/status_endpoint_test.go @@ -37,20 +37,25 @@ func insecureRPCClient(s *Server, c tlsutil.Config) (rpc.ClientCodec, error) { if wrapper == nil { return nil, err } - conn, _, err := pool.DialTimeoutWithRPCTypeDirectly( - s.config.Datacenter, - s.config.NodeName, - addr, - nil, - time.Second, - true, - wrapper, - pool.RPCTLSInsecure, - pool.RPCTLSInsecure, - ) + d := &net.Dialer{Timeout: time.Second} + conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, err } + // Switch the connection into TLS mode + if _, err = conn.Write([]byte{byte(pool.RPCTLSInsecure)}); err != nil { + conn.Close() + return nil, err + } + + // Wrap the connection in a TLS client + tlsConn, err := wrapper(s.config.Datacenter, conn) + if err != nil { + conn.Close() + return nil, err + } + conn = tlsConn + return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle), nil } diff --git a/agent/pool/pool.go b/agent/pool/pool.go index 16dcb7a91c20..6255f4f4e43b 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -146,9 +146,6 @@ type ConnPool struct { // Datacenter is the datacenter of the current agent. Datacenter string - // ForceTLS is used to enforce outgoing TLS verification - ForceTLS bool - // Server should be set to true if this connection pool is configured in a // server instead of a client. Server bool @@ -208,7 +205,7 @@ func (p *ConnPool) Shutdown() error { // wait for an existing connection attempt to finish, if one if in progress, // and will return that one if it succeeds. If all else fails, it will return a // newly-created connection and add it to the pool. -func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { +func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.acquire requires a node name") } @@ -243,7 +240,7 @@ func (p *ConnPool) acquire(dc string, nodeName string, addr net.Addr, useTLS boo // If we are the lead thread, make the new connection and then wake // everybody else up to see if we got it. if isLeadThread { - c, err := p.getNewConn(dc, nodeName, addr, useTLS) + c, err := p.getNewConn(dc, nodeName, addr) p.Lock() delete(p.limiter, addrStr) close(wait) @@ -290,7 +287,6 @@ func (p *ConnPool) DialTimeout( nodeName string, addr net.Addr, timeout time.Duration, - useTLS bool, actualRPCType RPCType, ) (net.Conn, HalfCloser, error) { p.once.Do(p.init) @@ -314,64 +310,26 @@ func (p *ConnPool) DialTimeout( ) } - return DialTimeoutWithRPCTypeDirectly( + return p.dial( dc, nodeName, addr, - p.SrcAddr, timeout, - useTLS || p.ForceTLS, - p.TLSConfigurator.OutgoingRPCWrapper(), actualRPCType, RPCTLS, ) } -// DialTimeoutInsecure is used to establish a raw connection to the given -// server, with given connection timeout. It also writes RPCTLSInsecure as the -// first byte to indicate that the client cannot provide a certificate. This is -// so far only used for AutoEncrypt.Sign. -func (p *ConnPool) DialTimeoutInsecure( - dc string, - nodeName string, - addr net.Addr, - timeout time.Duration, - wrapper tlsutil.DCWrapper, -) (net.Conn, HalfCloser, error) { - p.once.Do(p.init) - - if wrapper == nil { - return nil, nil, fmt.Errorf("wrapper cannot be nil") - } else if dc != p.Datacenter { - return nil, nil, fmt.Errorf("insecure dialing prohibited between datacenters") - } - - return DialTimeoutWithRPCTypeDirectly( - dc, - nodeName, - addr, - p.SrcAddr, - timeout, - true, - wrapper, - RPCTLSInsecure, - RPCTLSInsecure, - ) -} - -func DialTimeoutWithRPCTypeDirectly( +func (p *ConnPool) dial( dc string, nodeName string, addr net.Addr, - src *net.TCPAddr, timeout time.Duration, - useTLS bool, - wrapper tlsutil.DCWrapper, actualRPCType RPCType, tlsRPCType RPCType, ) (net.Conn, HalfCloser, error) { // Try to dial the conn - d := &net.Dialer{LocalAddr: src, Timeout: timeout} + d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: timeout} conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, nil, err @@ -388,7 +346,8 @@ func DialTimeoutWithRPCTypeDirectly( } // Check if TLS is enabled - if useTLS && wrapper != nil { + if p.TLSConfigurator.UseTLS(dc) { + wrapper := p.TLSConfigurator.OutgoingRPCWrapper() // Switch the connection into TLS mode if _, err := conn.Write([]byte{byte(tlsRPCType)}); err != nil { conn.Close() @@ -496,13 +455,13 @@ func DialTimeoutWithRPCTypeViaMeshGateway( } // getNewConn is used to return a new connection -func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, error) { +func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr) (*Conn, error) { if nodeName == "" { return nil, fmt.Errorf("pool: ConnPool.getNewConn requires a node name") } // Get a new, raw connection and write the Consul multiplex byte to set the mode - conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, useTLS, RPCMultiplexV2) + conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, RPCMultiplexV2) if err != nil { return nil, err } @@ -560,11 +519,11 @@ func (p *ConnPool) releaseConn(conn *Conn) { } // getClient is used to get a usable client for an address -func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr, useTLS bool) (*Conn, *StreamClient, error) { +func (p *ConnPool) getClient(dc string, nodeName string, addr net.Addr) (*Conn, *StreamClient, error) { retries := 0 START: // Try to get a conn first - conn, err := p.acquire(dc, nodeName, addr, useTLS) + conn, err := p.acquire(dc, nodeName, addr) if err != nil { return nil, nil, fmt.Errorf("failed to get conn: %v", err) } @@ -611,8 +570,12 @@ func (p *ConnPool) RPC( // AutoEncrypt.Sign is a one-off call and it doesn't make sense to pool that // connection if it is not being reused. func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method string, args interface{}, reply interface{}) error { + if dc != p.Datacenter { + return fmt.Errorf("insecure dialing prohibited between datacenters") + } + var codec rpc.ClientCodec - conn, _, err := p.DialTimeoutInsecure(dc, nodeName, addr, 1*time.Second, p.TLSConfigurator.OutgoingRPCWrapper()) + conn, _, err := p.dial(dc, nodeName, addr, 1*time.Second, 0, RPCTLSInsecure) if err != nil { return fmt.Errorf("rpcinsecure error establishing connection: %v", err) } @@ -631,8 +594,7 @@ func (p *ConnPool) rpc(dc string, nodeName string, addr net.Addr, method string, p.once.Do(p.init) // Get a usable client - useTLS := p.TLSConfigurator.UseTLS(dc) - conn, sc, err := p.getClient(dc, nodeName, addr, useTLS) + conn, sc, err := p.getClient(dc, nodeName, addr) if err != nil { return fmt.Errorf("rpc error getting client: %v", err) } From 1fbc1d4777d26ba944e7755234db942fad276ff5 Mon Sep 17 00:00:00 2001 From: Hans Hasselberg Date: Thu, 28 May 2020 10:56:10 +0200 Subject: [PATCH 3/3] pool: remove timeout parameter Timeout was never used in a meaningful way by callers, which is why it is now entirely internal to the pool. --- agent/consul/snapshot_endpoint.go | 2 +- agent/pool/pool.go | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/agent/consul/snapshot_endpoint.go b/agent/consul/snapshot_endpoint.go index 5f6e3e0f8dba..fb2f581d801f 100644 --- a/agent/consul/snapshot_endpoint.go +++ b/agent/consul/snapshot_endpoint.go @@ -200,7 +200,7 @@ func SnapshotRPC( ) (io.ReadCloser, error) { // Write the snapshot RPC byte to set the mode, then perform the // request. - conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, 10*time.Second, pool.RPCSnapshot) + conn, hc, err := connPool.DialTimeout(dc, nodeName, addr, pool.RPCSnapshot) if err != nil { return nil, err } diff --git a/agent/pool/pool.go b/agent/pool/pool.go index 6255f4f4e43b..7f14faa38e27 100644 --- a/agent/pool/pool.go +++ b/agent/pool/pool.go @@ -286,7 +286,6 @@ func (p *ConnPool) DialTimeout( dc string, nodeName string, addr net.Addr, - timeout time.Duration, actualRPCType RPCType, ) (net.Conn, HalfCloser, error) { p.once.Do(p.init) @@ -298,7 +297,6 @@ func (p *ConnPool) DialTimeout( nodeName, addr, p.SrcAddr, - timeout, p.TLSConfigurator.OutgoingALPNRPCWrapper(), actualRPCType, RPCTLS, @@ -314,7 +312,6 @@ func (p *ConnPool) DialTimeout( dc, nodeName, addr, - timeout, actualRPCType, RPCTLS, ) @@ -324,12 +321,11 @@ func (p *ConnPool) dial( dc string, nodeName string, addr net.Addr, - timeout time.Duration, actualRPCType RPCType, tlsRPCType RPCType, ) (net.Conn, HalfCloser, error) { // Try to dial the conn - d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: timeout} + d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: defaultDialTimeout} conn, err := d.Dial("tcp", addr.String()) if err != nil { return nil, nil, err @@ -393,7 +389,6 @@ func DialTimeoutWithRPCTypeViaMeshGateway( nodeName string, addr net.Addr, src *net.TCPAddr, - timeout time.Duration, wrapper tlsutil.ALPNWrapper, actualRPCType RPCType, tlsRPCType RPCType, @@ -425,7 +420,7 @@ func DialTimeoutWithRPCTypeViaMeshGateway( return nil, nil, structs.ErrDCNotAvailable } - dialer := &net.Dialer{LocalAddr: src, Timeout: timeout} + dialer := &net.Dialer{LocalAddr: src, Timeout: defaultDialTimeout} rawConn, err := dialer.Dial("tcp", gwAddr) if err != nil { @@ -461,7 +456,7 @@ func (p *ConnPool) getNewConn(dc string, nodeName string, addr net.Addr) (*Conn, } // Get a new, raw connection and write the Consul multiplex byte to set the mode - conn, _, err := p.DialTimeout(dc, nodeName, addr, defaultDialTimeout, RPCMultiplexV2) + conn, _, err := p.DialTimeout(dc, nodeName, addr, RPCMultiplexV2) if err != nil { return nil, err } @@ -575,7 +570,7 @@ func (p *ConnPool) rpcInsecure(dc string, nodeName string, addr net.Addr, method } var codec rpc.ClientCodec - conn, _, err := p.dial(dc, nodeName, addr, 1*time.Second, 0, RPCTLSInsecure) + conn, _, err := p.dial(dc, nodeName, addr, 0, RPCTLSInsecure) if err != nil { return fmt.Errorf("rpcinsecure error establishing connection: %v", err) }