Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

connect: Enable renewing the intermediate cert in the primary DC #8784

Merged
merged 6 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions agent/connect/ca/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ import (
// on servers and CA provider.
var ErrRateLimited = errors.New("operation rate limited by CA provider")

// PrimaryIntermediateProviders is a list of CA providers that make use use of an
// intermediate cert in the primary datacenter as well as the secondary. This is used
// when determining whether to run the intermediate renewal routine in the primary.
var PrimaryIntermediateProviders = map[string]struct{}{
"vault": struct{}{},
}

// ProviderConfig encapsulates all the data Consul passes to `Configure` on a
// new provider instance. The provider must treat this as read-only and make
// copies of any map or slice if it might modify them internally.
Expand Down
7 changes: 3 additions & 4 deletions agent/connect/ca/provider_consul.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,13 @@ import (
"github.com/hashicorp/go-hclog"
)

const (

var (
// NotBefore will be CertificateTimeDriftBuffer in the past to account for
// time drift between different servers.
CertificateTimeDriftBuffer = time.Minute
)

var ErrNotInitialized = errors.New("provider not initialized")
ErrNotInitialized = errors.New("provider not initialized")
)

type ConsulProvider struct {
Delegate ConsulProviderStateDelegate
Expand Down
38 changes: 26 additions & 12 deletions agent/connect/ca/provider_vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io/ioutil"
"net/http"
"strings"
"time"

"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
Expand Down Expand Up @@ -92,48 +93,47 @@ func (v *VaultProvider) Configure(cfg ProviderConfig) error {

// Set up a renewer to renew the token automatically, if supported.
if token.Renewable {
lifetimeWatcher, err := client.NewLifetimeWatcher(&vaultapi.LifetimeWatcherInput{
renewer, err := client.NewRenewer(&vaultapi.RenewerInput{
Secret: &vaultapi.Secret{
Auth: &vaultapi.SecretAuth{
ClientToken: config.Token,
Renewable: token.Renewable,
LeaseDuration: secret.LeaseDuration,
},
},
Increment: token.TTL,
RenewBehavior: vaultapi.RenewBehaviorIgnoreErrors,
Increment: token.TTL,
})
if err != nil {
return fmt.Errorf("Error beginning Vault provider token renewal: %v", err)
}

ctx, cancel := context.WithCancel(context.TODO())
v.shutdown = cancel
go v.renewToken(ctx, lifetimeWatcher)
go v.renewToken(ctx, renewer)
}

return nil
}

// renewToken uses a vaultapi.Renewer to repeatedly renew our token's lease.
func (v *VaultProvider) renewToken(ctx context.Context, watcher *vaultapi.LifetimeWatcher) {
go watcher.Start()
defer watcher.Stop()
func (v *VaultProvider) renewToken(ctx context.Context, renewer *vaultapi.Renewer) {
go renewer.Renew()
defer renewer.Stop()

for {
select {
case <-ctx.Done():
return

case err := <-watcher.DoneCh():
case err := <-renewer.DoneCh():
if err != nil {
v.logger.Error("Error renewing token for Vault provider", "error", err)
}

// Watcher routine has finished, so start it again.
go watcher.Start()
// Renewer routine has finished, so start it again.
go renewer.Renew()

case <-watcher.RenewCh():
case <-renewer.RenewCh():
v.logger.Error("Successfully renewed token for Vault provider")
}
}
Expand Down Expand Up @@ -384,6 +384,7 @@ func (v *VaultProvider) GenerateIntermediate() (string, error) {
"csr": csr,
"use_csr_values": true,
"format": "pem_bundle",
"ttl": v.config.IntermediateCertTTL.String(),
})
if err != nil {
return "", err
Expand Down Expand Up @@ -456,6 +457,7 @@ func (v *VaultProvider) SignIntermediate(csr *x509.CertificateRequest) (string,
"use_csr_values": true,
"format": "pem_bundle",
"max_path_length": 0,
"ttl": v.config.IntermediateCertTTL.String(),
})
if err != nil {
return "", err
Expand All @@ -475,8 +477,20 @@ func (v *VaultProvider) SignIntermediate(csr *x509.CertificateRequest) (string,
// CrossSignCA takes a CA certificate and cross-signs it to form a trust chain
// back to our active root.
func (v *VaultProvider) CrossSignCA(cert *x509.Certificate) (string, error) {
rootPEM, err := v.ActiveRoot()
if err != nil {
return "", err
}
rootCert, err := connect.ParseCert(rootPEM)
if err != nil {
return "", fmt.Errorf("error parsing root cert: %v", err)
}
if rootCert.NotAfter.Before(time.Now()) {
return "", fmt.Errorf("root certificate is expired")
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that means that Vault would happily sign a cert with an expired root CA?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, it's because the sign-self-issued endpoint does minimal validation (only checks for CA cert and self-issued) so it was possible for the cert to expire and the provider would still cross-sign with it.


var pemBuf bytes.Buffer
err := pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
err = pem.Encode(&pemBuf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
if err != nil {
return "", err
}
Expand Down
131 changes: 4 additions & 127 deletions agent/connect/ca/provider_vault_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@ import (
"io/ioutil"
"os"
"os/exec"
"sync"
"testing"
"time"

"github.com/hashicorp/consul/agent/connect"
"github.com/hashicorp/consul/agent/structs"
"github.com/hashicorp/consul/sdk/freeport"
"github.com/hashicorp/consul/sdk/testutil/retry"
"github.com/hashicorp/go-hclog"
vaultapi "github.com/hashicorp/vault/api"
Expand Down Expand Up @@ -70,7 +68,7 @@ func TestVaultCAProvider_RenewToken(t *testing.T) {
require.NoError(t, err)
providerToken := secret.Auth.ClientToken

_, err = createVaultProvider(t, true, testVault.addr, providerToken, nil)
_, err = createVaultProvider(t, true, testVault.Addr, providerToken, nil)
require.NoError(t, err)

// Check the last renewal time.
Expand Down Expand Up @@ -382,19 +380,19 @@ func getIntermediateCertTTL(t *testing.T, caConf *structs.CAConfiguration) time.
return dur
}

func testVaultProvider(t *testing.T) (*VaultProvider, *testVaultServer) {
func testVaultProvider(t *testing.T) (*VaultProvider, *TestVaultServer) {
return testVaultProviderWithConfig(t, true, nil)
}

func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *testVaultServer) {
func testVaultProviderWithConfig(t *testing.T, isPrimary bool, rawConf map[string]interface{}) (*VaultProvider, *TestVaultServer) {
testVault, err := runTestVault(t)
if err != nil {
t.Fatalf("err: %v", err)
}

testVault.WaitUntilReady(t)

provider, err := createVaultProvider(t, isPrimary, testVault.addr, testVault.rootToken, rawConf)
provider, err := createVaultProvider(t, isPrimary, testVault.Addr, testVault.RootToken, rawConf)
if err != nil {
testVault.Stop()
t.Fatalf("err: %v", err)
Expand Down Expand Up @@ -459,124 +457,3 @@ func skipIfVaultNotPresent(t *testing.T) {
t.Skipf("%q not found on $PATH - download and install to run this test", vaultBinaryName)
}
}

func runTestVault(t *testing.T) (*testVaultServer, error) {
vaultBinaryName := os.Getenv("VAULT_BINARY_NAME")
if vaultBinaryName == "" {
vaultBinaryName = "vault"
}

path, err := exec.LookPath(vaultBinaryName)
if err != nil || path == "" {
return nil, fmt.Errorf("%q not found on $PATH", vaultBinaryName)
}

ports := freeport.MustTake(2)
returnPortsFn := func() {
freeport.Return(ports)
}

var (
clientAddr = fmt.Sprintf("127.0.0.1:%d", ports[0])
clusterAddr = fmt.Sprintf("127.0.0.1:%d", ports[1])
)

const token = "root"

client, err := vaultapi.NewClient(&vaultapi.Config{
Address: "http://" + clientAddr,
})
if err != nil {
returnPortsFn()
return nil, err
}
client.SetToken(token)

args := []string{
"server",
"-dev",
"-dev-root-token-id",
token,
"-dev-listen-address",
clientAddr,
"-address",
clusterAddr,
}

cmd := exec.Command(vaultBinaryName, args...)
cmd.Stdout = ioutil.Discard
cmd.Stderr = ioutil.Discard
if err := cmd.Start(); err != nil {
returnPortsFn()
return nil, err
}

testVault := &testVaultServer{
rootToken: token,
addr: "http://" + clientAddr,
cmd: cmd,
client: client,
returnPortsFn: returnPortsFn,
}
t.Cleanup(func() {
testVault.Stop()
})
return testVault, nil
}

type testVaultServer struct {
rootToken string
addr string
cmd *exec.Cmd
client *vaultapi.Client

// returnPortsFn will put the ports claimed for the test back into the
returnPortsFn func()
}

var printedVaultVersion sync.Once

func (v *testVaultServer) WaitUntilReady(t *testing.T) {
var version string
retry.Run(t, func(r *retry.R) {
resp, err := v.client.Sys().Health()
if err != nil {
r.Fatalf("err: %v", err)
}
if !resp.Initialized {
r.Fatalf("vault server is not initialized")
}
if resp.Sealed {
r.Fatalf("vault server is sealed")
}
version = resp.Version
})
printedVaultVersion.Do(func() {
fmt.Fprintf(os.Stderr, "[INFO] agent/connect/ca: testing with vault server version: %s\n", version)
})
}

func (v *testVaultServer) Stop() error {
// There was no process
if v.cmd == nil {
return nil
}

if v.cmd.Process != nil {
if err := v.cmd.Process.Signal(os.Interrupt); err != nil {
return fmt.Errorf("failed to kill vault server: %v", err)
}
}

// wait for the process to exit to be sure that the data dir can be
// deleted on all platforms.
if err := v.cmd.Wait(); err != nil {
return err
}

if v.returnPortsFn != nil {
v.returnPortsFn()
}

return nil
}
Loading