Skip to content

Commit

Permalink
Support cloud auto-join
Browse files Browse the repository at this point in the history
  • Loading branch information
ishustava committed Mar 5, 2020
1 parent 921afda commit 354e9d6
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 18 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ require (
github.com/hashicorp/consul v1.7.1
github.com/hashicorp/consul/api v1.4.0
github.com/hashicorp/consul/sdk v0.4.0
github.com/hashicorp/go-discover v0.0.0-20191202160150-7ec2cfbda7a2
github.com/hashicorp/go-hclog v0.12.0
github.com/hashicorp/go-multierror v1.0.0
github.com/hashicorp/golang-lru v0.5.3 // indirect
Expand Down
100 changes: 85 additions & 15 deletions subcommand/get-consul-client-ca/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ import (
"fmt"
"io/ioutil"
"os"
"strings"
"sync"
"time"

"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/command/flags"
"github.com/hashicorp/consul/lib"
"github.com/hashicorp/go-discover"
discoverk8s "github.com/hashicorp/go-discover/provider/k8s"
"github.com/hashicorp/go-hclog"
"github.com/mitchellh/cli"
)
Expand All @@ -20,21 +24,24 @@ type Command struct {
flags *flag.FlagSet

flagOutputFile string
flagHttpAddr string
flagServerAddr string
flagCAFile string
flagTLSServerName string
flagLogLevel string

once sync.Once
help string

providers map[string]discover.Provider
}

func (c *Command) init() {
c.flags = flag.NewFlagSet("", flag.ContinueOnError)
c.flags.StringVar(&c.flagOutputFile, "output-file", "",
"The path to the file where to put the Consul client's CA certificate.")
c.flags.StringVar(&c.flagHttpAddr, "http-addr", "",
"The HTTP address of the Consul server. This can also be provided via the CONSUL_HTTP_ADDR environment variable.")
c.flags.StringVar(&c.flagServerAddr, "server-addr", "",
"The address of the Consul server or the Cloud auto-join string. The server must be running with TLS enabled."+
"The default HTTPS port 8501 will be used if port is not provided.")
c.flags.StringVar(&c.flagCAFile, "ca-file", "",
"The path to the CA file to use when making requests to the Consul server. This can also be provided via the CONSUL_CACERT environment variable")
c.flags.StringVar(&c.flagTLSServerName, "tls-server-name", "",
Expand All @@ -61,13 +68,6 @@ func (c *Command) Run(args []string) int {
return 1
}

// create Consul client
consulClient, err := c.consulClient()
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing Consul client: %s", err))
return 1
}

// create a logger
level := hclog.LevelFromString(c.flagLogLevel)
if level == hclog.NoLevel {
Expand All @@ -79,6 +79,13 @@ func (c *Command) Run(args []string) int {
Output: os.Stderr,
})

// create Consul client
consulClient, err := c.consulClient(logger)
if err != nil {
c.UI.Error(fmt.Sprintf("Error initializing Consul client: %s", err))
return 1
}

// Get the active CA root from Consul
// Wait until it gets a successful response
var activeRoot string
Expand All @@ -90,7 +97,7 @@ func (c *Command) Run(args []string) int {
continue
}

activeRoot, err = c.getActiveRoot(caRoots)
activeRoot, err = getActiveRoot(caRoots)
if err != nil {
logger.Info("Could not get an active root", "err", err)
time.Sleep(1 * time.Second)
Expand All @@ -104,14 +111,56 @@ func (c *Command) Run(args []string) int {
return 1
}

c.UI.Info(fmt.Sprintf("Successfully written Consul client CA to: %s", c.flagOutputFile))
return 0
}

func (c *Command) consulClient() (*api.Client, error) {
func (c *Command) consulClient(logger hclog.Logger) (*api.Client, error) {
cfg := api.DefaultConfig()
if c.flagHttpAddr != "" {
cfg.Address = c.flagHttpAddr

// First, check if the server address is a cloud auto-join string.
// If it is, discover server addresses through the cloud provider.
if strings.Contains(c.flagServerAddr, "provider=") {
disco, err := c.newDiscover()
if err != nil {
return nil, err
}
logger.Debug("using cloud auto-join with", c.flagServerAddr)
servers, err := disco.Addrs(c.flagServerAddr, logger.StandardLogger(&hclog.StandardLoggerOptions{
InferLevels: true,
}))
if err != nil {
return nil, err
}

// check if we discovered any servers
if len(servers) == 0 {
return nil, fmt.Errorf("could not discover any Consul servers with %q", c.flagServerAddr)
}

logger.Debug("discovered servers", strings.Join(servers, " "))

// Pick the first server from the list,
// ignoring the port since we need to use HTTP API
firstServer := strings.SplitN(servers[0], ":", 2)[0]
cfg.Address = fmt.Sprintf("%s:8501", firstServer)
cfg.Scheme = "https"
} else {
// check if the server URL is missing a port
host := strings.TrimPrefix(c.flagServerAddr, "https://")
host = strings.TrimPrefix(c.flagServerAddr, "http://")
parts := strings.SplitN(host, ":", 2)

// Use the default HTTPS port if port is missing.
// Otherwise, use the address the user has provided.
if len(parts) == 1 {
cfg.Address = fmt.Sprintf("%s:8501", c.flagServerAddr)
cfg.Scheme = "https"
} else {
cfg.Address = c.flagServerAddr
}
}

if c.flagCAFile != "" {
cfg.TLSConfig.CAFile = c.flagCAFile
}
Expand All @@ -122,7 +171,28 @@ func (c *Command) consulClient() (*api.Client, error) {
return api.NewClient(cfg)
}

func (c *Command) getActiveRoot(roots *api.CARootList) (string, error) {
// newDiscover initializes the new Discover object
// set up with all predefined providers, as well as
// the k8s provider.
func (c *Command) newDiscover() (*discover.Discover, error) {
if c.providers == nil {
c.providers = make(map[string]discover.Provider)
}

for k, v := range discover.Providers {
c.providers[k] = v
}
c.providers["k8s"] = &discoverk8s.Provider{}

return discover.New(
discover.WithUserAgent(lib.UserAgent()),
discover.WithProviders(c.providers),
)
}

// getActiveRoot returns the currently active root
// from the roots list, otherwise returns error.
func getActiveRoot(roots *api.CARootList) (string, error) {
if roots == nil {
return "", fmt.Errorf("ca roots is nil")
}
Expand Down
155 changes: 152 additions & 3 deletions subcommand/get-consul-client-ca/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package getconsulclientca

import (
"crypto"
"crypto/x509"
"fmt"
"github.com/hashicorp/go-discover"
"io/ioutil"
"log"
"net"
"os"
"testing"
"time"

Expand Down Expand Up @@ -54,7 +59,7 @@ func TestRun(t *testing.T) {

// run the command
exitCode := cmd.Run([]string{
"-http-addr", a.HTTPAddr,
"-server-addr", a.HTTPAddr,
"-output-file", outputFile.Name(),
})
require.Equal(t, 0, exitCode)
Expand Down Expand Up @@ -97,7 +102,7 @@ func TestRun_ConsulServerAvailableLater(t *testing.T) {
exitCode := -1
go func() {
exitCode = cmd.Run([]string{
"-http-addr", fmt.Sprintf("http://127.0.0.1:%d", randomPorts[1]),
"-server-addr", fmt.Sprintf("http://127.0.0.1:%d", randomPorts[1]),
"-output-file", outputFile.Name(),
})
require.Equal(t, 0, exitCode)
Expand Down Expand Up @@ -193,7 +198,7 @@ func TestRun_GetsOnlyActiveRoot(t *testing.T) {
})

exitCode := cmd.Run([]string{
"-http-addr", a.HTTPAddr,
"-server-addr", a.HTTPAddr,
"-output-file", outputFile.Name(),
})
require.Equal(t, 0, exitCode)
Expand All @@ -219,6 +224,80 @@ func TestRun_GetsOnlyActiveRoot(t *testing.T) {
require.Equal(t, expectedCARoot, string(actualCARoot))
}

// Test that when using cloud auto-join
// it uses the provider to get the address of the server
func TestRun_WithProvider(t *testing.T) {
t.Parallel()
outputFile, err := ioutil.TempFile("", "ca")
require.NoError(t, err)

ui := cli.NewMockUi()
provider := &fakeProvider{}
cmd := Command{
UI: ui,
providers: map[string]discover.Provider{"fake": provider},
}

caFile, certFile, keyFile, cleanup := generateServerCerts(t)
defer cleanup()

randomPorts := freeport.MustTake(5)
// start the test server
a, err := testutil.NewTestServerConfigT(t, func(c *testutil.TestServerConfig) {
c.Connect = map[string]interface{}{
"enabled": true,
}
c.CAFile = caFile
c.CertFile = certFile
c.KeyFile = keyFile
c.Ports = &testutil.TestPortConfig{
DNS: randomPorts[0],
HTTP: randomPorts[1],
HTTPS: 8501,
SerfLan: randomPorts[2],
SerfWan: randomPorts[3],
Server: randomPorts[4],
}
})
require.NoError(t, err)
defer a.Stop()

// run the command
exitCode := cmd.Run([]string{
"-server-addr", "provider=fake",
"-tls-server-name", "localhost",
"-output-file", outputFile.Name(),
"-ca-file", caFile,
})
require.Equal(t, 0, exitCode, ui.ErrorWriter.String())

// check that the provider has been called
require.Equal(t, 1, provider.addrsNumCalls, "provider's Addrs method was not called")

client, err := api.NewClient(&api.Config{
Address: a.HTTPSAddr,
Scheme: "https",
TLSConfig: api.TLSConfig{
CAFile: caFile,
},
})
require.NoError(t, err)

// get the actual root ca cert from consul
roots, _, err := client.Agent().ConnectCARoots(nil)
require.NoError(t, err)
require.NotNil(t, roots)
require.NotNil(t, roots.Roots)
require.Len(t, roots.Roots, 1)
require.True(t, roots.Roots[0].Active)
expectedCARoot := roots.Roots[0].RootCertPEM

// read the file contents
actualCARoot, err := ioutil.ReadFile(outputFile.Name())
require.NoError(t, err)
require.Equal(t, expectedCARoot, string(actualCARoot))
}

// generateCA generates Consul CA
// and returns cert and key as pem strings.
func generateCA(t *testing.T) (caPem, keyPem string) {
Expand All @@ -237,3 +316,73 @@ func generateCA(t *testing.T) (caPem, keyPem string) {

return
}

// generateServerCerts generates Consul CA
// and a server certificate and saves them to temp files.
// It returns file names in this order:
// CA certificate, server certificate, and server key.
// Note that it's the responsibility of the caller to
// remove the temporary files created by this function.
func generateServerCerts(t *testing.T) (string, string, string, func()) {
require := require.New(t)

caFile, err := ioutil.TempFile("", "ca")
require.NoError(err)

certFile, err := ioutil.TempFile("", "cert")
require.NoError(err)

certKeyFile, err := ioutil.TempFile("", "key")
require.NoError(err)

// Generate CA
sn, err := tlsutil.GenerateSerialNumber()
require.NoError(err)

s, _, err := tlsutil.GeneratePrivateKey()
require.NoError(err)

constraints := []string{"consul", "localhost"}
ca, err := tlsutil.GenerateCA(s, sn, 1, constraints)
require.NoError(err)

// Generate Server Cert
name := fmt.Sprintf("server.%s.%s", "dc1", "consul")
DNSNames := []string{name, "localhost"}
IPAddresses := []net.IP{net.ParseIP("127.0.0.1")}
extKeyUsage := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}

sn, err = tlsutil.GenerateSerialNumber()
require.NoError(err)

pub, priv, err := tlsutil.GenerateCert(s, ca, sn, name, 1, DNSNames, IPAddresses, extKeyUsage)
require.NoError(err)

// Write certs and key to files
_, err = caFile.WriteString(ca)
require.NoError(err)
_, err = certFile.WriteString(pub)
require.NoError(err)
_, err = certKeyFile.WriteString(priv)
require.NoError(err)

cleanupFunc := func() {
os.Remove(caFile.Name())
os.Remove(certFile.Name())
os.Remove(certKeyFile.Name())
}
return caFile.Name(), certFile.Name(), certKeyFile.Name(), cleanupFunc
}

type fakeProvider struct {
addrsNumCalls int
}

func (p *fakeProvider) Addrs(args map[string]string, l *log.Logger) ([]string, error) {
p.addrsNumCalls++
return []string{"127.0.0.1"}, nil
}

func (p *fakeProvider) Help() string {
return "fake-provider help"
}

0 comments on commit 354e9d6

Please sign in to comment.