Skip to content

Commit

Permalink
k8s vault auth
Browse files Browse the repository at this point in the history
  • Loading branch information
Kryvchun committed May 12, 2022
1 parent 7ecad12 commit aa8c510
Show file tree
Hide file tree
Showing 7 changed files with 458 additions and 60 deletions.
56 changes: 56 additions & 0 deletions dependency/client_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
rootcerts "github.com/hashicorp/go-rootcerts"
nomadapi "github.com/hashicorp/nomad/api"
vaultapi "github.com/hashicorp/vault/api"
vaultkubernetesauth "github.com/hashicorp/vault/api/auth/kubernetes"
)

// ClientSet is a collection of clients that dependencies use to communicate
Expand Down Expand Up @@ -94,6 +95,11 @@ type CreateVaultClientInput struct {
SSLCAPath string
ServerName string

K8SAuthRoleName string
K8SServiceAccountMountPath string
K8SServiceAccountToken string
K8SServiceMountPath string

TransportCustomDialer TransportDialer
TransportDialKeepAlive time.Duration
TransportDialTimeout time.Duration
Expand Down Expand Up @@ -333,6 +339,14 @@ func (c *ClientSet) CreateVaultClient(i *CreateVaultClientInput) error {
client.SetNamespace(i.Namespace)
}

// Set token using k8s auth method.
if i.K8SAuthRoleName != "" && i.Token == "" {
err = prepareK8SServiceTokenAuth(i, client)
if err != nil {
return fmt.Errorf("client set: vault: %w", err)
}
}

// Set the token if given
if i.Token != "" {
client.SetToken(i.Token)
Expand Down Expand Up @@ -521,3 +535,45 @@ func (c *ClientSet) Stop() {
c.nomad.httpClient.Transport.(*http.Transport).CloseIdleConnections()
}
}

func prepareK8SServiceTokenAuth(
i *CreateVaultClientInput,
client *vaultapi.Client,
) (err error) {
opts := make([]vaultkubernetesauth.LoginOption, 0, 1)

switch {
case i.K8SServiceAccountToken != "":
opts = append(opts, vaultkubernetesauth.WithServiceAccountToken(
i.K8SServiceAccountToken,
))
case i.K8SServiceAccountMountPath != "":
opts = append(opts, vaultkubernetesauth.WithServiceAccountTokenPath(
i.K8SServiceAccountMountPath,
))
default:
// The Kubernetes service account token JWT will be retrieved
// from /var/run/secrets/kubernetes.io/serviceaccount/token.
}

if i.K8SServiceMountPath != "" {
opts = append(opts, vaultkubernetesauth.WithMountPath(
i.K8SServiceMountPath,
))
}

k8sAuth, err := vaultkubernetesauth.NewKubernetesAuth(i.K8SAuthRoleName, opts...)
if err != nil {
return fmt.Errorf("k8s auth: new kubernetes auth: %w", err)
}

ctx := context.TODO()
sec, err := client.Auth().Login(ctx, k8sAuth)
if err != nil {
return fmt.Errorf("k8s auth: login: %w", err)
}

i.Token = sec.Auth.ClientToken

return nil
}
228 changes: 228 additions & 0 deletions dependency/client_set_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
package dependency

import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"

"github.com/hashicorp/consul-template/test"
"github.com/hashicorp/vault/api"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestClientSet_unwrapVaultToken(t *testing.T) {
Expand Down Expand Up @@ -36,3 +44,223 @@ func TestClientSet_unwrapVaultToken(t *testing.T) {
t.Fatal(err)
}
}

func TestClientSet_K8SServiceTokenAuth(t *testing.T) {
t.Parallel()

validSecret := &api.Secret{Auth: &api.SecretAuth{ClientToken: vaultToken}}
invalidSecret := &api.Secret{Auth: &api.SecretAuth{ClientToken: "invalid"}}
require.NotEqual(t, validSecret, invalidSecret)

k8sLoginPathCond := func(mountPath string) func(r *http.Request) bool {
return func(r *http.Request) bool {
return r.URL.Path == "/v1/auth/"+mountPath+"/login"
}
}

t.Run("service_token", func(t *testing.T) {
t.Parallel()

testServerAddr := newVaultMockReversedProxy(t, vaultMock{
HandleCond: k8sLoginPathCond("kubernetes"),
HandleJSON: func(_ *http.Request, data map[string]interface{}) interface{} {
assert.Equal(t, data["jwt"], "service_token", data)
assert.Equal(t, data["role"], "default", data)

return validSecret
},
})

clientSet := NewClientSet()
err := clientSet.CreateVaultClient(&CreateVaultClientInput{
Address: testServerAddr,
K8SAuthRoleName: "default",
K8SServiceAccountToken: "service_token",
})
if err != nil {
t.Fatal(err)
}

_, err = clientSet.Vault().Logical().List("/entities")
require.NoError(t, err)
})

t.Run("service_token_from_file", func(t *testing.T) {
t.Parallel()

testServerAddr := newVaultMockReversedProxy(t, vaultMock{
HandleCond: k8sLoginPathCond("kubernetes"),
HandleJSON: func(_ *http.Request, data map[string]interface{}) interface{} {
assert.Equal(t, data["jwt"], "service_token", data)
assert.Equal(t, data["role"], "default_file", data)

return validSecret
},
})

f := test.CreateTempfile(t, []byte("service_token"))

clientSet := NewClientSet()
err := clientSet.CreateVaultClient(&CreateVaultClientInput{
Address: testServerAddr,
K8SAuthRoleName: "default_file",
K8SServiceAccountMountPath: f.Name(),
})
if err != nil {
t.Fatal(err)
}

_, err = clientSet.Vault().Logical().List("/entities")
require.NoError(t, err)
})

t.Run("service_token_file_value_priority", func(t *testing.T) {
t.Parallel()

testServerAddr := newVaultMockReversedProxy(t, vaultMock{
HandleCond: k8sLoginPathCond("kubernetes"),
HandleJSON: func(_ *http.Request, data map[string]interface{}) interface{} {
assert.Equal(t, data["jwt"], "service_token_value", data)
assert.Equal(t, data["role"], "default", data)

return validSecret
},
})

f := test.CreateTempfile(t, []byte("service_token_file"))

clientSet := NewClientSet()
err := clientSet.CreateVaultClient(&CreateVaultClientInput{
Address: testServerAddr,
K8SAuthRoleName: "default",
K8SServiceAccountMountPath: f.Name(),
K8SServiceAccountToken: "service_token_value",
})
if err != nil {
t.Fatal(err)
}

_, err = clientSet.Vault().Logical().List("/entities")
require.NoError(t, err)
})

t.Run("mount_path", func(t *testing.T) {
t.Parallel()

testServerAddr := newVaultMockReversedProxy(t, vaultMock{
HandleCond: k8sLoginPathCond("mount_path"),
HandleJSON: func(r *http.Request, data map[string]interface{}) interface{} {
return validSecret
},
})

clientSet := NewClientSet()
err := clientSet.CreateVaultClient(&CreateVaultClientInput{
Address: testServerAddr,
K8SAuthRoleName: "default",
K8SServiceAccountToken: "service_token",
K8SServiceMountPath: "mount_path",
})
if err != nil {
t.Fatal(err)
}

_, err = clientSet.Vault().Logical().List("/entities")
require.NoError(t, err)
})

t.Run("token_already_set", func(t *testing.T) {
t.Parallel()

testServerAddr := newVaultMockReversedProxy(t)

clientSet := NewClientSet()
err := clientSet.CreateVaultClient(&CreateVaultClientInput{
Address: testServerAddr,
Token: vaultToken,
K8SAuthRoleName: "default",
K8SServiceAccountToken: "service_token",
})
require.NoError(t, err)

_, err = clientSet.Vault().Logical().List("/entities")
require.NoError(t, err)
})

t.Run("auth_failed", func(t *testing.T) {
t.Parallel()

testServerAddr := newVaultMockReversedProxy(t, vaultMock{
HandleCond: k8sLoginPathCond("kubernetes"),
HandleJSON: func(*http.Request, map[string]interface{}) interface{} {
return invalidSecret
},
})

clientSet := NewClientSet()
err := clientSet.CreateVaultClient(&CreateVaultClientInput{
Address: testServerAddr,
K8SAuthRoleName: "default",
K8SServiceAccountToken: "service_token",
})
require.NoError(t, err)

_, err = clientSet.Vault().Logical().List("/entities")
require.Error(t, err)
})
}

type vaultMock struct {
HandleCond func(r *http.Request) bool
HandleJSON func(r *http.Request, data map[string]interface{}) interface{}
}

func (m vaultMock) processReq(tb testing.TB, w http.ResponseWriter, r *http.Request) {
if m.HandleJSON == nil {
return
}

var data map[string]interface{}
err := json.NewDecoder(r.Body).Decode(&data)
if !assert.NoError(tb, err) {
http.Error(w, err.Error(), http.StatusInternalServerError)

return
}

tb.Logf("%s: %s: %+v", r.Method, r.URL, data)

w.Header().Set("Content-Type", "application/json")

err = json.NewEncoder(w).Encode(m.HandleJSON(r, data))
assert.NoError(tb, err)
}

// newVaultMockReversedProxy mocks some calls and proxies others to Vault.
func newVaultMockReversedProxy(tb testing.TB, mocks ...vaultMock) string {
tb.Helper()

vaultURL, err := url.Parse(vaultAddr)
require.NoError(tb, err)

vaultReverseProxy := httputil.NewSingleHostReverseProxy(vaultURL)

testServer := httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
for _, m := range mocks {
if !m.HandleCond(r) {
continue
}

m.processReq(tb, w, r)

return
}

vaultReverseProxy.ServeHTTP(w, r)
}),
)
tb.Cleanup(testServer.Close)

return testServer.URL
}
3 changes: 1 addition & 2 deletions dependency/connect_ca_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ import (
)

func TestConnectCAQuery_Fetch(t *testing.T) {

d := NewConnectCAQuery()
raw, _, err := d.Fetch(testClients, nil)
assert.NoError(t, err)
act := raw.([]*api.CARoot)
if assert.Len(t, act, 1) {
root := act[0]
assert.Equal(t, root.Name, "Consul CA Root Cert")
assert.Contains(t, root.Name, "Consul CA")
assert.True(t, root.Active)
assert.NotEmpty(t, root.RootCertPEM)
}
Expand Down
6 changes: 4 additions & 2 deletions dependency/vault_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -535,9 +535,11 @@ func TestVaultReadQuery_Fetch_PKI_Anonymous(t *testing.T) {
Token: "",
})
_, err = anonClient.vault.client.Auth().Token().LookupSelf()
if err == nil || !strings.Contains(err.Error(), "missing client token") {
if err == nil ||
!(strings.Contains(err.Error(), "missing client token") ||
strings.Contains(err.Error(), "permission denied")) {
// check environment for VAULT_TOKEN
t.Fatalf("expected a missing client token error but found: %v", err)
t.Fatalf("expected a missing client token error but found: %q", err)
}

d, err := NewVaultReadQuery("pki/cert/ca")
Expand Down
Loading

0 comments on commit aa8c510

Please sign in to comment.