Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API client Profile credential loader #5993

Merged
merged 1 commit into from
Mar 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,11 @@ func connect(ctx context.Context, cfg Config) (*Client, error) {
}

// Connect with dialer provided in creds.
if dialer, err := creds.Dialer(); err == nil {
if dialer, err := creds.Dialer(cfg.KeepAlivePeriod, cfg.DialTimeout); err != nil {
if !trace.IsNotImplemented(err) {
sendError(trace.Wrap(err))
}
} else {
syncConnect(constants.APIDomain, &Client{
c: cfg,
tlsConfig: tlsConfig,
Expand Down Expand Up @@ -255,6 +259,10 @@ func connect(ctx context.Context, cfg Config) (*Client, error) {
}
// errChan is closed, return errors.
if len(errs) == 0 {
if len(cfg.Addrs) == 0 && cfg.Dialer == nil {
// Some credentials don't require these fields. If no errors propogate, then they need to provide these fields.
return nil, trace.Errorf("all auth methods failed: try providing Addrs or Dialer in config")
}
return nil, trace.Errorf("all auth methods failed")
}
return nil, trace.Wrap(trace.NewAggregate(errs...), "all auth methods failed")
Expand Down
10 changes: 5 additions & 5 deletions api/client/contextdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ func (f ContextDialerFunc) DialContext(ctx context.Context, network, addr string
}

// NewDialer makes a new dialer.
func NewDialer(keepAliveInterval, dialTimeout time.Duration) ContextDialer {
func NewDialer(keepAlivePeriod, dialTimeout time.Duration) ContextDialer {
return &net.Dialer{
Timeout: dialTimeout,
KeepAlive: keepAliveInterval,
KeepAlive: keepAlivePeriod,
}
}

// NewTunnelDialer make a new ssh tunnel dialer
func NewTunnelDialer(ssh ssh.ClientConfig, keepAliveInterval, dialTimeout time.Duration) ContextDialer {
dialer := NewDialer(keepAliveInterval, dialTimeout)
// NewTunnelDialer make a new ssh tunnel dialer.
func NewTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration) ContextDialer {
dialer := NewDialer(keepAlivePeriod, dialTimeout)
return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
conn, err = dialer.DialContext(ctx, network, addr)
if err != nil {
Expand Down
98 changes: 93 additions & 5 deletions api/client/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ limitations under the License.
package client

import (
"context"
"crypto/tls"
"crypto/x509"
"io/ioutil"
"net"
"time"

"github.com/gravitational/teleport/api/constants"

Expand All @@ -30,8 +33,8 @@ import (

// Credentials are used to authenticate to Auth.
type Credentials interface {
// Dialer is used to dial a connection to Auth.
Dialer() (ContextDialer, error)
// Dialer is used to create a dialer used to connect to Auth.
Dialer(keepAliveInterval, dialTimeout time.Duration) (ContextDialer, error)
// TLSConfig returns TLS configuration used to connect to Auth.
TLSConfig() (*tls.Config, error)
// SSHClientConfig returns SSH configuration used to connect to Proxy through tunnel.
Expand All @@ -52,7 +55,7 @@ type TLSConfigCreds struct {
}

// Dialer is used to dial a connection to Auth.
func (c *TLSConfigCreds) Dialer() (ContextDialer, error) {
func (c *TLSConfigCreds) Dialer(keepAliveInterval, dialTimeout time.Duration) (ContextDialer, error) {
return nil, trace.NotImplemented("no dialer")
}

Expand Down Expand Up @@ -87,7 +90,7 @@ type KeyPairCreds struct {
}

// Dialer is used to dial a connection to Auth.
func (c *KeyPairCreds) Dialer() (ContextDialer, error) {
func (c *KeyPairCreds) Dialer(keepAliveInterval, dialTimeout time.Duration) (ContextDialer, error) {
return nil, trace.NotImplemented("no dialer")
}

Expand Down Expand Up @@ -134,7 +137,7 @@ type IdentityCreds struct {
}

// Dialer is used to dial a connection to Auth.
func (c *IdentityCreds) Dialer() (ContextDialer, error) {
func (c *IdentityCreds) Dialer(keepAliveInterval, dialTimeout time.Duration) (ContextDialer, error) {
return nil, trace.NotImplemented("no dialer")
}

Expand Down Expand Up @@ -166,6 +169,8 @@ func (c *IdentityCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// load is used to lazy load the identity file from persistent storage.
// This allows LoadIdentity to avoid possible errors for UX purposes.
func (c *IdentityCreds) load() error {
if c.identityFile != nil {
return nil
Expand All @@ -177,6 +182,89 @@ func (c *IdentityCreds) load() error {
return nil
}

// LoadProfile is used to load credentials from a tsh Profile.
// If dir is not specified, the default profile path will be used.
// If name is not specified, the current profile name will be used.
func LoadProfile(dir, name string) *ProfileCreds {
return &ProfileCreds{
dir: dir,
name: name,
}
}

// ProfileCreds are used to authenticate the client
// with a tsh profile with the given directory and name.
type ProfileCreds struct {
dir string
name string
profile *Profile
}

// Dialer is used to dial a connection to Auth.
func (c *ProfileCreds) Dialer(keepAliveInterval, dialTimeout time.Duration) (ContextDialer, error) {
sshConfig, err := c.SSHClientConfig()
if err != nil {
return nil, trace.Wrap(err)
}

dialer := NewTunnelDialer(*sshConfig, keepAliveInterval, dialTimeout)
return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
// Ping web proxy to retrieve tunnel proxy address.
pr, err := Find(ctx, c.profile.WebProxyAddr, false, nil)
if err != nil {
return nil, trace.Wrap(err)
}

conn, err = dialer.DialContext(ctx, network, pr.Proxy.SSH.TunnelPublicAddr)
if err != nil {
// not wrapping on purpose to preserve the original error
return nil, err
}
return conn, nil
}), nil
}

// TLSConfig returns TLS configuration used to connect to Auth.
func (c *ProfileCreds) TLSConfig() (*tls.Config, error) {
if err := c.load(); err != nil {
return nil, trace.Wrap(err)
}

tlsConfig, err := c.profile.TLSConfig()
if err != nil {
return nil, trace.Wrap(err)
}

return configure(tlsConfig), nil
}

// SSHClientConfig returns SSH configuration used to connect to Proxy.
func (c *ProfileCreds) SSHClientConfig() (*ssh.ClientConfig, error) {
if err := c.load(); err != nil {
return nil, trace.Wrap(err)
}

sshConfig, err := c.profile.SSHClientConfig()
if err != nil {
return nil, trace.Wrap(err)
}

return sshConfig, nil
}

// load is used to lazy load the profile from persistent storage.
// This allows LoadProfile to avoid possible errors for UX purposes.
func (c *ProfileCreds) load() error {
if c.profile != nil {
return nil
}
var err error
if c.profile, err = ProfileFromDir(c.dir, c.name); err != nil {
return trace.BadParameter("profile could not be decoded: %v", err)
}
return nil
}

func configure(c *tls.Config) *tls.Config {
tlsConfig := c.Clone()

Expand Down
115 changes: 96 additions & 19 deletions api/client/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@ import (
"crypto/tls"
"crypto/x509"
"io/ioutil"
"os"
"path/filepath"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils/sshutils"

"github.com/stretchr/testify/require"
Expand All @@ -34,15 +38,12 @@ func TestLoadTLS(t *testing.T) {

// Load expected tls.Config.
expectedTLSConfig := getExpectedTLSConfig(t)

// Load TLSConfigCreds.
creds := LoadTLS(expectedTLSConfig)

// Build tls.Config and compare to expected tls.Config.
tlsConfig, err := creds.TLSConfig()
require.NoError(t, err)
require.Equal(t, expectedTLSConfig.Certificates, tlsConfig.Certificates)
require.Equal(t, expectedTLSConfig.RootCAs.Subjects(), tlsConfig.RootCAs.Subjects())
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)

// Load invalid tls.Config.
invalidTLSCreds := LoadTLS(nil)
Expand All @@ -64,11 +65,11 @@ func TestLoadIdentityFile(t *testing.T) {
idFile := &IdentityFile{
PrivateKey: keyPEM,
Certs: Certs{
TLS: certPEM,
TLS: tlsCert,
SSH: sshCert,
},
CACerts: CACerts{
TLS: [][]byte{caCertPEM},
TLS: [][]byte{tlsCACert},
SSH: [][]byte{sshCACert},
},
}
Expand All @@ -77,17 +78,15 @@ func TestLoadIdentityFile(t *testing.T) {

// Load identity file from disk.
creds := LoadIdentityFile(path)

// Build tls.Config and compare to expected tls.Config.
tlsConfig, err := creds.TLSConfig()
require.NoError(t, err)
require.Equal(t, expectedTLSConfig.Certificates, tlsConfig.Certificates)
require.Equal(t, expectedTLSConfig.RootCAs.Subjects(), tlsConfig.RootCAs.Subjects())
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)

// Build ssh.ClientConfig and compare to expected ssh.ClientConfig.
sshConfig, err := creds.SSHClientConfig()
require.NoError(t, err)
require.Equal(t, expectedSSHConfig.User, sshConfig.User)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)

// Load invalid identity.
creds = LoadIdentityFile("invalid_path")
Expand All @@ -106,34 +105,99 @@ func TestLoadKeyPair(t *testing.T) {
// Write key pair and CAs files from bytes.
path := t.TempDir() + "username"
certPath, keyPath, caPath := path+".crt", path+".key", path+".cas"
err := ioutil.WriteFile(certPath, certPEM, 0600)
err := ioutil.WriteFile(certPath, tlsCert, 0600)
require.NoError(t, err)
err = ioutil.WriteFile(keyPath, keyPEM, 0600)
require.NoError(t, err)
err = ioutil.WriteFile(caPath, caCertPEM, 0600)
err = ioutil.WriteFile(caPath, tlsCACert, 0600)
require.NoError(t, err)

// Load key pair from disk.
creds := LoadKeyPair(certPath, keyPath, caPath)

// Build tls.Config and compare to expected tls.Config.
tlsConfig, err := creds.TLSConfig()
require.NoError(t, err)
require.Equal(t, expectedTLSConfig.Certificates, tlsConfig.Certificates)
require.Equal(t, expectedTLSConfig.RootCAs.Subjects(), tlsConfig.RootCAs.Subjects())
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)

// Load invalid keypairs.
invalidIdentityCreds := LoadKeyPair("invalid_path", "invalid_path", "invalid_path")
_, err = invalidIdentityCreds.TLSConfig()
require.Error(t, err)
}

func TestLoadProfile(t *testing.T) {
t.Parallel()

// Load expected tls.Config and ssh.ClientConfig.
expectedTLSConfig := getExpectedTLSConfig(t)
expectedSSHConfig := getExpectedSSHConfig(t)

// Write identity file to disk.
dir := t.TempDir()
name := "proxy"
p := &Profile{
WebProxyAddr: "proxy:3080",
Username: "testUser",
Dir: dir,
}

// Save profile to a file.
err := p.SaveToDir(dir, true)
require.NoError(t, err)

// Write keys to disk.
keyDir := filepath.Join(dir, constants.SessionKeyDir)
err = os.MkdirAll(keyDir, 0700)
require.NoError(t, err)
userKeyDir := filepath.Join(keyDir, p.Name())
os.MkdirAll(userKeyDir, 0700)
require.NoError(t, err)
keyPath := filepath.Join(userKeyDir, p.Username)
err = ioutil.WriteFile(keyPath, []byte(keyPEM), 0600)
require.NoError(t, err)
tlsCertPath := filepath.Join(userKeyDir, p.Username+constants.FileExtTLSCert)
err = ioutil.WriteFile(tlsCertPath, []byte(tlsCert), 0600)
require.NoError(t, err)
tlsCasPath := filepath.Join(userKeyDir, constants.FileNameTLSCerts)
err = ioutil.WriteFile(tlsCasPath, []byte(tlsCACert), 0600)
require.NoError(t, err)
sshCertPath := filepath.Join(userKeyDir, p.Username+constants.FileExtSSHCert)
err = ioutil.WriteFile(sshCertPath, []byte(sshCert), 0600)
require.NoError(t, err)
Comment on lines +164 to +166
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI this will need to be updated after #5938

sshCasPath := filepath.Join(dir, constants.FileNameKnownHosts)
err = ioutil.WriteFile(sshCasPath, []byte(sshCACert), 0600)
require.NoError(t, err)

// Load profile from disk.
creds := LoadProfile(dir, name)
// Build tls.Config and compare to expected tls.Config.
tlsConfig, err := creds.TLSConfig()
require.NoError(t, err)
requireEqualTLSConfig(t, expectedTLSConfig, tlsConfig)
// Build ssh.ClientConfig and compare to expected ssh.ClientConfig.
sshConfig, err := creds.SSHClientConfig()
require.NoError(t, err)
requireEqualSSHConfig(t, expectedSSHConfig, sshConfig)
// Build Dialer
_, err = creds.Dialer(0, 0)
require.NoError(t, err)

// Load invalid profile.
creds = LoadProfile("invalid_dir", "invalid_name")
_, err = creds.TLSConfig()
require.Error(t, err)
_, err = creds.SSHClientConfig()
require.Error(t, err)
_, err = creds.Dialer(0, 0)
require.Error(t, err)
}

func getExpectedTLSConfig(t *testing.T) *tls.Config {
cert, err := tls.X509KeyPair(certPEM, keyPEM)
cert, err := tls.X509KeyPair(tlsCert, keyPEM)
require.NoError(t, err)

pool := x509.NewCertPool()
require.True(t, pool.AppendCertsFromPEM(caCertPEM))
require.True(t, pool.AppendCertsFromPEM(tlsCACert))

return configure(&tls.Config{
Certificates: []tls.Certificate{cert},
Expand All @@ -148,8 +212,21 @@ func getExpectedSSHConfig(t *testing.T) *ssh.ClientConfig {
return config
}

func requireEqualTLSConfig(t *testing.T, expected *tls.Config, actual *tls.Config) {
require.Empty(t, cmp.Diff(expected, actual,
cmpopts.IgnoreFields(tls.Config{}, "GetClientCertificate"),
cmpopts.IgnoreUnexported(tls.Config{}, x509.CertPool{}),
))
}

func requireEqualSSHConfig(t *testing.T, expected *ssh.ClientConfig, actual *ssh.ClientConfig) {
require.Empty(t, cmp.Diff(expected, actual,
cmpopts.IgnoreFields(ssh.ClientConfig{}, "Auth", "HostKeyCallback"),
))
}

var (
certPEM = []byte(`-----BEGIN CERTIFICATE-----
tlsCert = []byte(`-----BEGIN CERTIFICATE-----
MIIDyzCCArOgAwIBAgIQD3MiJ2Au8PicJpCNFbvcETANBgkqhkiG9w0BAQsFADBe
MRQwEgYDVQQKEwtleGFtcGxlLmNvbTEUMBIGA1UEAxMLZXhhbXBsZS5jb20xMDAu
BgNVBAUTJzIwNTIxNzE3NzMzMTIxNzQ2ODMyNjA5NjAxODEwODc0NTAzMjg1ODAe
Expand Down Expand Up @@ -201,7 +278,7 @@ pr5VAoGBAJBhNjs9wAu+ZoPcMZcjIXT/BAj2tQYiHoRnNpvQjDYbQueUBeI0Ry8d
90Ns/9SamlBo9j8ETm9g9D3EVir9zF5XvoR13OdN9gabGy1GuubT
-----END RSA PRIVATE KEY-----`)

caCertPEM = []byte(`-----BEGIN CERTIFICATE-----
tlsCACert = []byte(`-----BEGIN CERTIFICATE-----
MIIDiTCCAnGgAwIBAgIRAJlp/39yg8U604bjsxgcoC0wDQYJKoZIhvcNAQELBQAw
XjEUMBIGA1UEChMLZXhhbXBsZS5jb20xFDASBgNVBAMTC2V4YW1wbGUuY29tMTAw
LgYDVQQFEycyMDM5MjIyNTY2MzcxMDQ0NDc3MzYxNjA0MTk0NjU2MTgzMDA5NzMw
Expand Down
Loading