Skip to content

Commit

Permalink
Merge pull request #927 from hashicorp/f-tls
Browse files Browse the repository at this point in the history
Add new `verify_server_hostname` to mitigate possibility of MITM
  • Loading branch information
armon committed May 12, 2015
2 parents d3dee0d + 90d6204 commit ebf961e
Show file tree
Hide file tree
Showing 16 changed files with 389 additions and 52 deletions.
2 changes: 2 additions & 0 deletions command/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,12 @@ func (a *Agent) consulConfig() *consul.Config {
// Copy the TLS configuration
base.VerifyIncoming = a.config.VerifyIncoming
base.VerifyOutgoing = a.config.VerifyOutgoing
base.VerifyServerHostname = a.config.VerifyServerHostname
base.CAFile = a.config.CAFile
base.CertFile = a.config.CertFile
base.KeyFile = a.config.KeyFile
base.ServerName = a.config.ServerName
base.Domain = a.config.Domain

// Setup the ServerUp callback
base.ServerUp = a.state.ConsulServerUp
Expand Down
11 changes: 11 additions & 0 deletions command/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ type Config struct {
// certificate authority. This is used to verify authenticity of server nodes.
VerifyOutgoing bool `mapstructure:"verify_outgoing"`

// VerifyServerHostname is used to enable hostname verification of servers. This
// ensures that the certificate presented is valid for server.<datacenter>.<domain>.
// This prevents a compromised client from being restarted as a server, and then
// intercepting request traffic as well as being added as a raft peer. This should be
// enabled by default with VerifyOutgoing, but for legacy reasons we cannot break
// existing clients.
VerifyServerHostname bool `mapstructure:"verify_server_hostname"`

// CAFile is a path to a certificate authority file. This is used with VerifyIncoming
// or VerifyOutgoing to verify the TLS connection.
CAFile string `mapstructure:"ca_file"`
Expand Down Expand Up @@ -838,6 +846,9 @@ func MergeConfig(a, b *Config) *Config {
if b.VerifyOutgoing {
result.VerifyOutgoing = true
}
if b.VerifyServerHostname {
result.VerifyServerHostname = true
}
if b.CAFile != "" {
result.CAFile = b.CAFile
}
Expand Down
6 changes: 5 additions & 1 deletion command/agent/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ func TestDecodeConfig(t *testing.T) {
}

// TLS
input = `{"verify_incoming": true, "verify_outgoing": true}`
input = `{"verify_incoming": true, "verify_outgoing": true, "verify_server_hostname": true}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
if err != nil {
t.Fatalf("err: %s", err)
Expand All @@ -259,6 +259,10 @@ func TestDecodeConfig(t *testing.T) {
t.Fatalf("bad: %#v", config)
}

if config.VerifyServerHostname != true {
t.Fatalf("bad: %#v", config)
}

// TLS keys
input = `{"ca_file": "my/ca/file", "cert_file": "my.cert", "key_file": "key.pem", "server_name": "example.com"}`
config, err = DecodeConfig(bytes.NewReader([]byte(input)))
Expand Down
12 changes: 5 additions & 7 deletions consul/client.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package consul

import (
"crypto/tls"
"fmt"
"log"
"math/rand"
Expand Down Expand Up @@ -91,10 +90,9 @@ func NewClient(config *Config) (*Client, error) {
config.LogOutput = os.Stderr
}

// Create the tlsConfig
var tlsConfig *tls.Config
var err error
if tlsConfig, err = config.tlsConfig().OutgoingTLSConfig(); err != nil {
// Create the tls Wrapper
tlsWrap, err := config.tlsConfig().OutgoingTLSWrapper()
if err != nil {
return nil, err
}

Expand All @@ -104,7 +102,7 @@ func NewClient(config *Config) (*Client, error) {
// Create server
c := &Client{
config: config,
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsConfig),
connPool: NewPool(config.LogOutput, clientRPCCache, clientMaxStreams, tlsWrap),
eventCh: make(chan serf.Event, 256),
logger: logger,
shutdownCh: make(chan struct{}),
Expand Down Expand Up @@ -357,7 +355,7 @@ func (c *Client) RPC(method string, args interface{}, reply interface{}) error {

// Forward to remote Consul
TRY_RPC:
if err := c.connPool.RPC(server.Addr, server.Version, method, args, reply); err != nil {
if err := c.connPool.RPC(c.config.Datacenter, server.Addr, server.Version, method, args, reply); err != nil {
c.lastServer = nil
c.lastRPCTime = time.Time{}
return err
Expand Down
29 changes: 21 additions & 8 deletions consul/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ type Config struct {
// Node name is the name we use to advertise. Defaults to hostname.
NodeName string

// Domain is the DNS domain for the records. Defaults to "consul."
Domain string

// RaftConfig is the configuration used for Raft in the local DC
RaftConfig *raft.Config

Expand Down Expand Up @@ -100,6 +103,14 @@ type Config struct {
// server nodes.
VerifyOutgoing bool

// VerifyServerHostname is used to enable hostname verification of servers. This
// ensures that the certificate presented is valid for server.<datacenter>.<domain>.
// This prevents a compromised client from being restarted as a server, and then
// intercepting request traffic as well as being added as a raft peer. This should be
// enabled by default with VerifyOutgoing, but for legacy reasons we cannot break
// existing clients.
VerifyServerHostname bool

// CAFile is a path to a certificate authority file. This is used with VerifyIncoming
// or VerifyOutgoing to verify the TLS connection.
CAFile string
Expand Down Expand Up @@ -267,13 +278,15 @@ func DefaultConfig() *Config {

func (c *Config) tlsConfig() *tlsutil.Config {
tlsConf := &tlsutil.Config{
VerifyIncoming: c.VerifyIncoming,
VerifyOutgoing: c.VerifyOutgoing,
CAFile: c.CAFile,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
NodeName: c.NodeName,
ServerName: c.ServerName}

VerifyIncoming: c.VerifyIncoming,
VerifyOutgoing: c.VerifyOutgoing,
VerifyServerHostname: c.VerifyServerHostname,
CAFile: c.CAFile,
CertFile: c.CertFile,
KeyFile: c.KeyFile,
NodeName: c.NodeName,
ServerName: c.ServerName,
Domain: c.Domain,
}
return tlsConf
}
27 changes: 13 additions & 14 deletions consul/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package consul

import (
"container/list"
"crypto/tls"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -135,8 +134,8 @@ type ConnPool struct {
// Pool maps an address to a open connection
pool map[string]*Conn

// TLS settings
tlsConfig *tls.Config
// TLS wrapper
tlsWrap tlsutil.DCWrapper

// Used to indicate the pool is shutdown
shutdown bool
Expand All @@ -148,13 +147,13 @@ type ConnPool struct {
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided outgoing connections use TLS.
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsConfig *tls.Config) *ConnPool {
func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.DCWrapper) *ConnPool {
pool := &ConnPool{
logOutput: logOutput,
maxTime: maxTime,
maxStreams: maxStreams,
pool: make(map[string]*Conn),
tlsConfig: tlsConfig,
tlsWrap: tlsWrap,
shutdownCh: make(chan struct{}),
}
if maxTime > 0 {
Expand Down Expand Up @@ -183,14 +182,14 @@ func (p *ConnPool) Shutdown() error {

// Acquire is used to get a connection that is
// pooled or to return a new connection
func (p *ConnPool) acquire(addr net.Addr, version int) (*Conn, error) {
func (p *ConnPool) acquire(dc string, addr net.Addr, version int) (*Conn, error) {
// Check for a pooled ocnn
if conn := p.getPooled(addr, version); conn != nil {
return conn, nil
}

// Create a new connection
return p.getNewConn(addr, version)
return p.getNewConn(dc, addr, version)
}

// getPooled is used to return a pooled connection
Expand All @@ -206,7 +205,7 @@ func (p *ConnPool) getPooled(addr net.Addr, version int) *Conn {
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int) (*Conn, error) {
// Try to dial the conn
conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
if err != nil {
Expand All @@ -220,15 +219,15 @@ func (p *ConnPool) getNewConn(addr net.Addr, version int) (*Conn, error) {
}

// Check if TLS is enabled
if p.tlsConfig != nil {
if p.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
return nil, err
}

// Wrap the connection in a TLS client
tlsConn, err := tlsutil.WrapTLSClient(conn, p.tlsConfig)
tlsConn, err := p.tlsWrap(dc, conn)
if err != nil {
conn.Close()
return nil, err
Expand Down Expand Up @@ -314,11 +313,11 @@ func (p *ConnPool) releaseConn(conn *Conn) {
}

// getClient is used to get a usable client for an address and protocol version
func (p *ConnPool) getClient(addr net.Addr, version int) (*Conn, *StreamClient, error) {
func (p *ConnPool) getClient(dc string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
retries := 0
START:
// Try to get a conn first
conn, err := p.acquire(addr, version)
conn, err := p.acquire(dc, addr, version)
if err != nil {
return nil, nil, fmt.Errorf("failed to get conn: %v", err)
}
Expand All @@ -340,9 +339,9 @@ START:
}

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
// Get a usable client
conn, sc, err := p.getClient(addr, version)
conn, sc, err := p.getClient(dc, addr, version)
if err != nil {
return fmt.Errorf("rpc error: %v", err)
}
Expand Down
22 changes: 11 additions & 11 deletions consul/raft_rpc.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package consul

import (
"crypto/tls"
"fmt"
"github.com/hashicorp/consul/tlsutil"
"net"
"sync"
"time"

"github.com/hashicorp/consul/tlsutil"
)

// RaftLayer implements the raft.StreamLayer interface,
Expand All @@ -18,8 +18,8 @@ type RaftLayer struct {
// connCh is used to accept connections
connCh chan net.Conn

// TLS configuration
tlsConfig *tls.Config
// TLS wrapper
tlsWrap tlsutil.Wrapper

// Tracks if we are closed
closed bool
Expand All @@ -30,12 +30,12 @@ type RaftLayer struct {
// NewRaftLayer is used to initialize a new RaftLayer which can
// be used as a StreamLayer for Raft. If a tlsConfig is provided,
// then the connection will use TLS.
func NewRaftLayer(addr net.Addr, tlsConfig *tls.Config) *RaftLayer {
func NewRaftLayer(addr net.Addr, tlsWrap tlsutil.Wrapper) *RaftLayer {
layer := &RaftLayer{
addr: addr,
connCh: make(chan net.Conn),
tlsConfig: tlsConfig,
closeCh: make(chan struct{}),
addr: addr,
connCh: make(chan net.Conn),
tlsWrap: tlsWrap,
closeCh: make(chan struct{}),
}
return layer
}
Expand Down Expand Up @@ -87,15 +87,15 @@ func (l *RaftLayer) Dial(address string, timeout time.Duration) (net.Conn, error
}

// Check for tls mode
if l.tlsConfig != nil {
if l.tlsWrap != nil {
// Switch the connection into TLS mode
if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
conn.Close()
return nil, err
}

// Wrap the connection in a TLS client
conn, err = tlsutil.WrapTLSClient(conn, l.tlsConfig)
conn, err = l.tlsWrap(conn)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions consul/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ func (s *Server) forwardLeader(method string, args interface{}, reply interface{
if server == nil {
return structs.ErrNoLeader
}
return s.connPool.RPC(server.Addr, server.Version, method, args, reply)
return s.connPool.RPC(s.config.Datacenter, server.Addr, server.Version, method, args, reply)
}

// forwardDC is used to forward an RPC call to a remote DC, or fail if no servers
Expand All @@ -229,7 +229,7 @@ func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{

// Forward to remote Consul
metrics.IncrCounter([]string{"consul", "rpc", "cross-dc", dc}, 1)
return s.connPool.RPC(server.Addr, server.Version, method, args, reply)
return s.connPool.RPC(dc, server.Addr, server.Version, method, args, reply)
}

// globalRPC is used to forward an RPC request to one server in each datacenter.
Expand Down
16 changes: 10 additions & 6 deletions consul/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/hashicorp/consul/acl"
"github.com/hashicorp/consul/tlsutil"
"github.com/hashicorp/golang-lru"
"github.com/hashicorp/raft"
"github.com/hashicorp/raft-boltdb"
Expand Down Expand Up @@ -182,9 +183,9 @@ func NewServer(config *Config) (*Server, error) {
config.LogOutput = os.Stderr
}

// Create the tlsConfig for outgoing connections
// Create the tls wrapper for outgoing connections
tlsConf := config.tlsConfig()
tlsConfig, err := tlsConf.OutgoingTLSConfig()
tlsWrap, err := tlsConf.OutgoingTLSWrapper()
if err != nil {
return nil, err
}
Expand All @@ -207,7 +208,7 @@ func NewServer(config *Config) (*Server, error) {
// Create server
s := &Server{
config: config,
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsConfig),
connPool: NewPool(config.LogOutput, serverRPCCache, serverMaxStreams, tlsWrap),
eventChLAN: make(chan serf.Event, 256),
eventChWAN: make(chan serf.Event, 256),
localConsuls: make(map[string]*serverParts),
Expand Down Expand Up @@ -242,7 +243,7 @@ func NewServer(config *Config) (*Server, error) {
}

// Initialize the RPC layer
if err := s.setupRPC(tlsConfig); err != nil {
if err := s.setupRPC(tlsWrap); err != nil {
s.Shutdown()
return nil, fmt.Errorf("Failed to start RPC layer: %v", err)
}
Expand Down Expand Up @@ -410,7 +411,7 @@ func (s *Server) setupRaft() error {
}

// setupRPC is used to setup the RPC listener
func (s *Server) setupRPC(tlsConfig *tls.Config) error {
func (s *Server) setupRPC(tlsWrap tlsutil.DCWrapper) error {
// Create endpoints
s.endpoints.Status = &Status{s}
s.endpoints.Catalog = &Catalog{s}
Expand Down Expand Up @@ -453,7 +454,10 @@ func (s *Server) setupRPC(tlsConfig *tls.Config) error {
return fmt.Errorf("RPC advertise address is not advertisable: %v", addr)
}

s.raftLayer = NewRaftLayer(advertise, tlsConfig)
// Provide a DC specific wrapper. Raft replication is only
// ever done in the same datacenter, so we can provide it as a constant.
wrapper := tlsutil.SpecificDC(s.config.Datacenter, tlsWrap)
s.raftLayer = NewRaftLayer(advertise, wrapper)
return nil
}

Expand Down
Loading

0 comments on commit ebf961e

Please sign in to comment.