Skip to content

Commit

Permalink
[v8] backport #13310 (use auth_servers when proxying) (#13402)
Browse files Browse the repository at this point in the history
* Don't GetAuthServers in transport.start

* Don't GetAuthServers in AuthProxyDialerService

* Don't GetAuthServers in localSite

* Fix lib/web tests

* Review comments

Co-authored-by: Alan Parra <alan.parra@goteleport.com>
Co-authored-by: rosstimothy <39066650+rosstimothy@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 13, 2022
1 parent ca77c7c commit c8f8d93
Show file tree
Hide file tree
Showing 13 changed files with 137 additions and 166 deletions.
4 changes: 4 additions & 0 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ type AgentConfig struct {
ReverseTunnelServer Server
// LocalClusterName is the name of the cluster this agent is running in.
LocalClusterName string
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// Component is the teleport component that this agent runs in.
// It's important for routing incoming requests for local services (like an
// IoT node or kubernetes service).
Expand Down Expand Up @@ -516,6 +519,7 @@ func (a *Agent) processRequests(conn *ssh.Client) error {
log: a.log,
closeContext: a.ctx,
authClient: a.Client,
authServers: a.LocalAuthAddresses,
kubeDialAddr: a.KubeDialAddr,
channel: ch,
requestCh: req,
Expand Down
4 changes: 4 additions & 0 deletions lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type AgentPoolConfig struct {
HostUUID string
// LocalCluster is a cluster name this client is a member of.
LocalCluster string
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// Clock is a clock used to get time, if not set,
// system clock is used
Clock clockwork.Clock
Expand Down Expand Up @@ -295,6 +298,7 @@ func (m *AgentPool) addAgent(lease track.Lease) error {
Server: m.cfg.Server,
ReverseTunnelServer: m.cfg.ReverseTunnelServer,
LocalClusterName: m.cfg.LocalCluster,
LocalAuthAddresses: m.cfg.LocalAuthAddresses,
Component: m.cfg.Component,
Tracker: m.proxyTracker,
Lease: lease,
Expand Down
33 changes: 13 additions & 20 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
"golang.org/x/crypto/ssh"
)

func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) {
func newlocalSite(srv *server, domainName string, authServers []string, client auth.ClientI) (*localSite, error) {
err := utils.RegisterPrometheusCollectors(localClusterCollectors...)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -66,6 +66,7 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
accessPoint: accessPoint,
certificateCache: certificateCache,
domainName: domainName,
authServers: authServers,
remoteConns: make(map[connKey][]*remoteConn),
clock: srv.Clock,
log: log.WithFields(log.Fields{
Expand All @@ -88,9 +89,10 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
//
// it implements RemoteSite interface
type localSite struct {
log log.FieldLogger
domainName string
srv *server
log log.FieldLogger
domainName string
authServers []string
srv *server

// client provides access to the Auth Server API of the local cluster.
client auth.ClientI
Expand Down Expand Up @@ -159,27 +161,18 @@ func (s *localSite) GetLastConnected() time.Time {
return s.clock.Now()
}

func (s *localSite) DialAuthServer() (conn net.Conn, err error) {
// get list of local auth servers
authServers, err := s.client.GetAuthServers()
if err != nil {
return nil, trace.Wrap(err)
}

if len(authServers) < 1 {
func (s *localSite) DialAuthServer() (net.Conn, error) {
if len(s.authServers) == 0 {
return nil, trace.ConnectionProblem(nil, "no auth servers available")
}

// try and dial to one of them, as soon as we are successful, return the net.Conn
for _, authServer := range authServers {
conn, err = net.DialTimeout("tcp", authServer.GetAddr(), apidefaults.DefaultDialTimeout)
if err == nil {
return conn, nil
}
addr := utils.ChooseRandomString(s.authServers)
conn, err := net.DialTimeout("tcp", addr, apidefaults.DefaultDialTimeout)
if err != nil {
return nil, trace.ConnectionProblem(err, "unable to connect to auth server")
}

// return the last error
return nil, trace.ConnectionProblem(err, "unable to connect to auth server")
return conn, nil
}

func (s *localSite) Dial(params DialParams) (net.Conn, error) {
Expand Down
2 changes: 1 addition & 1 deletion lib/reversetunnel/localsite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestLocalSiteOverlap(t *testing.T) {
},
}

site, err := newlocalSite(srv, "clustername", &mockLocalSiteClient{})
site, err := newlocalSite(srv, "clustername", nil /* authServers */, &mockLocalSiteClient{})
require.NoError(t, err)

nodeID := uuid.NewString()
Expand Down
4 changes: 4 additions & 0 deletions lib/reversetunnel/rc_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type RemoteClusterTunnelManagerConfig struct {
HostUUID string
// LocalCluster is a cluster name this client is a member of.
LocalCluster string
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// Local ReverseTunnelServer to reach other cluster members connecting to
// this proxy over a tunnel.
ReverseTunnelServer Server
Expand Down Expand Up @@ -216,6 +219,7 @@ func realNewAgentPool(ctx context.Context, cfg RemoteClusterTunnelManagerConfig,
HostSigner: cfg.HostSigner,
HostUUID: cfg.HostUUID,
LocalCluster: cfg.LocalCluster,
LocalAuthAddresses: cfg.LocalAuthAddresses,
Clock: cfg.Clock,
KubeDialAddr: cfg.KubeDialAddr,
ReverseTunnelServer: cfg.ReverseTunnelServer,
Expand Down
28 changes: 9 additions & 19 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,6 @@ type server struct {
offlineThreshold time.Duration
}

// DirectCluster is used to access cluster directly
type DirectCluster struct {
// Name is a cluster name
Name string
// Client is a client to the cluster
Client auth.ClientI
}

// Config is a reverse tunnel server configuration
type Config struct {
// ID is the ID of this server proxy
Expand All @@ -146,11 +138,12 @@ type Config struct {
// AccessPoint caches values and can still return results during connection
// problems.
LocalAccessPoint auth.ProxyAccessPoint
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// NewCachingAccessPoint returns new caching access points
// per remote cluster
NewCachingAccessPoint auth.NewRemoteProxyCachingAccessPoint
// DirectClusters is a list of clusters accessed directly
DirectClusters []DirectCluster
// Context is a signalling context
Context context.Context
// Clock is a clock used in the server, set up to
Expand Down Expand Up @@ -303,8 +296,6 @@ func NewServer(cfg Config) (Server, error) {

srv := &server{
Config: cfg,
localSites: []*localSite{},
remoteSites: []*remoteSite{},
localAuthClient: cfg.LocalAuthClient,
localAccessPoint: cfg.LocalAccessPoint,
newAccessPoint: cfg.NewCachingAccessPoint,
Expand All @@ -317,15 +308,13 @@ func NewServer(cfg Config) (Server, error) {
offlineThreshold: offlineThreshold,
}

for _, clusterInfo := range cfg.DirectClusters {
cluster, err := newlocalSite(srv, clusterInfo.Name, clusterInfo.Client)
if err != nil {
return nil, trace.Wrap(err)
}

srv.localSites = append(srv.localSites, cluster)
localSite, err := newlocalSite(srv, cfg.ClusterName, cfg.LocalAuthAddresses, cfg.LocalAuthClient)
if err != nil {
return nil, trace.Wrap(err)
}

srv.localSites = append(srv.localSites, localSite)

s, err := sshutils.NewServer(
teleport.ComponentReverseTunnelServer,
// TODO(klizhentas): improve interface, use struct instead of parameter list
Expand Down Expand Up @@ -629,6 +618,7 @@ func (s *server) handleTransport(sconn *ssh.ServerConn, nch ssh.NewChannel) {
log: s.log,
closeContext: s.ctx,
authClient: s.LocalAccessPoint,
authServers: s.LocalAuthAddresses,
channel: channel,
requestCh: requestCh,
component: teleport.ComponentReverseTunnelServer,
Expand Down
64 changes: 32 additions & 32 deletions lib/reversetunnel/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ type transport struct {
log logrus.FieldLogger
closeContext context.Context
authClient auth.ProxyAccessPoint
authServers []string
channel ssh.Channel
requestCh <-chan *ssh.Request

Expand Down Expand Up @@ -192,8 +193,6 @@ func (p *transport) start() {
return
}

var servers []string

// Parse and extract the dial request from the client.
dreq := parseDialReq(req.Payload)
if err := dreq.CheckAndSetDefaults(); err != nil {
Expand All @@ -202,23 +201,22 @@ func (p *transport) start() {
}
p.log.Debugf("Received out-of-band proxy transport request for %v [%v].", dreq.Address, dreq.ServerID)

// directAddress will hold the address of the node to dial to, if we don't
// have a tunnel for it.
var directAddress string

// Handle special non-resolvable addresses first.
switch dreq.Address {
// Connect to an Auth Server.
case RemoteAuthServer:
authServers, err := p.authClient.GetAuthServers()
if err != nil {
p.reply(req, false, []byte("connection rejected: failed to connect to auth server"))
return
}
if len(authServers) == 0 {
p.log.Warn("No auth servers registered in the cluster.")
p.reply(req, false, []byte("connection rejected: failed to connect to auth server"))
if len(p.authServers) == 0 {
p.log.Errorf("connection rejected: no auth servers configured")
p.reply(req, false, []byte("no auth servers configured"))

return
}
for _, as := range authServers {
servers = append(servers, as.GetAddr())
}

directAddress = utils.ChooseRandomString(p.authServers)
// Connect to the Kubernetes proxy.
case LocalKubernetes:
switch p.component {
Expand Down Expand Up @@ -252,7 +250,7 @@ func (p *transport) start() {
return
}
p.log.Debugf("Forwarding connection to %q", p.kubeDialAddr.Addr)
servers = append(servers, p.kubeDialAddr.Addr)
directAddress = p.kubeDialAddr.Addr
}

// LocalNode requests are for the single server running in the agent pool.
Expand Down Expand Up @@ -283,15 +281,16 @@ func (p *transport) start() {
// If this is a proxy and not an SSH node, try finding an inbound
// tunnel from the SSH node by dreq.ServerID. We'll need to forward
// dreq.Address as well.
fallthrough
directAddress = dreq.Address
default:
servers = append(servers, dreq.Address)
// Not a special address; could be empty.
directAddress = dreq.Address
}

// Get a connection to the target address. If a tunnel exists with matching
// search names, connection over the tunnel is returned. Otherwise a direct
// net.Dial is performed.
conn, useTunnel, err := p.getConn(servers, dreq)
conn, useTunnel, err := p.getConn(directAddress, dreq)
if err != nil {
errorMessage := fmt.Sprintf("connection rejected: %v", err)
fmt.Fprint(p.channel.Stderr(), errorMessage)
Expand Down Expand Up @@ -365,7 +364,7 @@ func (p *transport) handleChannelRequests(closeContext context.Context, useTunne
// getConn checks if the local site holds a connection to the target host,
// and if it does, attempts to dial through the tunnel. Otherwise directly
// dials to host.
func (p *transport) getConn(servers []string, r *sshutils.DialReq) (net.Conn, bool, error) {
func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, error) {
// This function doesn't attempt to dial if a host with one of the
// search names is not registered. It's a fast check.
p.log.Debugf("Attempting to dial through tunnel with server ID %q.", r.ServerID)
Expand All @@ -385,13 +384,13 @@ func (p *transport) getConn(servers []string, r *sshutils.DialReq) (net.Conn, bo
}

errTun := err
p.log.Debugf("Attempting to dial directly %v.", servers)
conn, err = directDial(servers)
p.log.Debugf("Attempting to dial directly %q.", addr)
conn, err = p.directDial(addr)
if err != nil {
return nil, false, trace.ConnectionProblem(err, "failed dialing through tunnel (%v) or directly (%v)", errTun, err)
}

p.log.Debugf("Returning direct dialed connection to %v.", servers)
p.log.Debugf("Returning direct dialed connection to %q.", addr)
return conn, false, nil
}

Expand Down Expand Up @@ -434,18 +433,19 @@ func (p *transport) reply(req *ssh.Request, ok bool, msg []byte) {
}
}

// directDial attempst to directly dial to the target host.
func directDial(servers []string) (net.Conn, error) {
var errors []error

for _, addr := range servers {
conn, err := net.Dial("tcp", addr)
if err == nil {
return conn, nil
}
// directDial attempts to directly dial to the target host.
func (p *transport) directDial(addr string) (net.Conn, error) {
if addr == "" {
return nil, trace.BadParameter("no address to dial")
}

errors = append(errors, err)
d := net.Dialer{
Timeout: apidefaults.DefaultDialTimeout,
}
conn, err := d.DialContext(p.closeContext, "tcp", addr)
if err != nil {
return nil, trace.Wrap(err)
}

return nil, trace.NewAggregate(errors...)
return conn, nil
}
34 changes: 15 additions & 19 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2861,27 +2861,22 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
HostSigners: []ssh.Signer{conn.ServerIdentity.KeySigner},
LocalAuthClient: conn.Client,
LocalAccessPoint: accessPoint,
LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServers),
NewCachingAccessPoint: process.newLocalCacheForRemoteProxy,
NewCachingAccessPointOldProxy: process.newLocalCacheForOldRemoteProxy,
Limiter: reverseTunnelLimiter,
DirectClusters: []reversetunnel.DirectCluster{
{
Name: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority],
Client: conn.Client,
},
},
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -3042,6 +3037,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
AccessPoint: accessPoint,
HostSigner: conn.ServerIdentity.KeySigner,
LocalCluster: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority],
LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServers),
KubeDialAddr: utils.DialAddrFromListenAddr(kubeDialAddr(cfg.Proxy, clusterNetworkConfig.GetProxyListenerMode())),
ReverseTunnelServer: tsrv,
FIPS: process.Config.FIPS,
Expand Down Expand Up @@ -3228,7 +3224,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

var alpnServer *alpnproxy.Proxy
if !cfg.Proxy.DisableTLS && !cfg.Proxy.DisableALPNSNIListener && listeners.web != nil {
authDialerService := alpnproxyauth.NewAuthProxyDialerService(tsrv, accessPoint)
authDialerService := alpnproxyauth.NewAuthProxyDialerService(tsrv, clusterName, utils.NetAddrsToStrings(process.Config.AuthServers))
alpnRouter.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByALPNPrefix(string(alpncommon.ProtocolAuth)),
HandlerWithConnInfo: authDialerService.HandleConnection,
Expand Down
Loading

0 comments on commit c8f8d93

Please sign in to comment.