Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v9] backport #13310 (use auth_servers when proxying) #13399

Merged
merged 6 commits into from
Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ type AgentConfig struct {
ReverseTunnelServer Server
// LocalClusterName is the name of the cluster this agent is running in.
LocalClusterName string
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// Component is the teleport component that this agent runs in.
// It's important for routing incoming requests for local services (like an
// IoT node or kubernetes service).
Expand Down Expand Up @@ -516,6 +519,7 @@ func (a *Agent) processRequests(conn *ssh.Client) error {
log: a.log,
closeContext: a.ctx,
authClient: a.Client,
authServers: a.LocalAuthAddresses,
kubeDialAddr: a.KubeDialAddr,
channel: ch,
requestCh: req,
Expand Down
4 changes: 4 additions & 0 deletions lib/reversetunnel/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type AgentPoolConfig struct {
HostUUID string
// LocalCluster is a cluster name this client is a member of.
LocalCluster string
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// Clock is a clock used to get time, if not set,
// system clock is used
Clock clockwork.Clock
Expand Down Expand Up @@ -296,6 +299,7 @@ func (m *AgentPool) addAgent(lease track.Lease) error {
Server: m.cfg.Server,
ReverseTunnelServer: m.cfg.ReverseTunnelServer,
LocalClusterName: m.cfg.LocalCluster,
LocalAuthAddresses: m.cfg.LocalAuthAddresses,
Component: m.cfg.Component,
Tracker: m.proxyTracker,
Lease: lease,
Expand Down
33 changes: 13 additions & 20 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
"golang.org/x/crypto/ssh"
)

func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSite, error) {
func newlocalSite(srv *server, domainName string, authServers []string, client auth.ClientI) (*localSite, error) {
err := utils.RegisterPrometheusCollectors(localClusterCollectors...)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -66,6 +66,7 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
accessPoint: accessPoint,
certificateCache: certificateCache,
domainName: domainName,
authServers: authServers,
remoteConns: make(map[connKey][]*remoteConn),
clock: srv.Clock,
log: log.WithFields(log.Fields{
Expand All @@ -88,9 +89,10 @@ func newlocalSite(srv *server, domainName string, client auth.ClientI) (*localSi
//
// it implements RemoteSite interface
type localSite struct {
log log.FieldLogger
domainName string
srv *server
log log.FieldLogger
domainName string
authServers []string
srv *server

// client provides access to the Auth Server API of the local cluster.
client auth.ClientI
Expand Down Expand Up @@ -159,27 +161,18 @@ func (s *localSite) GetLastConnected() time.Time {
return s.clock.Now()
}

func (s *localSite) DialAuthServer() (conn net.Conn, err error) {
// get list of local auth servers
authServers, err := s.client.GetAuthServers()
if err != nil {
return nil, trace.Wrap(err)
}

if len(authServers) < 1 {
func (s *localSite) DialAuthServer() (net.Conn, error) {
if len(s.authServers) == 0 {
return nil, trace.ConnectionProblem(nil, "no auth servers available")
}

// try and dial to one of them, as soon as we are successful, return the net.Conn
for _, authServer := range authServers {
conn, err = net.DialTimeout("tcp", authServer.GetAddr(), apidefaults.DefaultDialTimeout)
if err == nil {
return conn, nil
}
addr := utils.ChooseRandomString(s.authServers)
conn, err := net.DialTimeout("tcp", addr, apidefaults.DefaultDialTimeout)
if err != nil {
return nil, trace.ConnectionProblem(err, "unable to connect to auth server")
}

// return the last error
return nil, trace.ConnectionProblem(err, "unable to connect to auth server")
return conn, nil
}

func (s *localSite) Dial(params DialParams) (net.Conn, error) {
Expand Down
2 changes: 1 addition & 1 deletion lib/reversetunnel/localsite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func TestLocalSiteOverlap(t *testing.T) {
},
}

site, err := newlocalSite(srv, "clustername", &mockLocalSiteClient{})
site, err := newlocalSite(srv, "clustername", nil /* authServers */, &mockLocalSiteClient{})
require.NoError(t, err)

nodeID := uuid.NewString()
Expand Down
4 changes: 4 additions & 0 deletions lib/reversetunnel/rc_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ type RemoteClusterTunnelManagerConfig struct {
HostUUID string
// LocalCluster is a cluster name this client is a member of.
LocalCluster string
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// Local ReverseTunnelServer to reach other cluster members connecting to
// this proxy over a tunnel.
ReverseTunnelServer Server
Expand Down Expand Up @@ -216,6 +219,7 @@ func realNewAgentPool(ctx context.Context, cfg RemoteClusterTunnelManagerConfig,
HostSigner: cfg.HostSigner,
HostUUID: cfg.HostUUID,
LocalCluster: cfg.LocalCluster,
LocalAuthAddresses: cfg.LocalAuthAddresses,
Clock: cfg.Clock,
KubeDialAddr: cfg.KubeDialAddr,
ReverseTunnelServer: cfg.ReverseTunnelServer,
Expand Down
28 changes: 9 additions & 19 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,6 @@ type server struct {
offlineThreshold time.Duration
}

// DirectCluster is used to access cluster directly
type DirectCluster struct {
// Name is a cluster name
Name string
// Client is a client to the cluster
Client auth.ClientI
}

// Config is a reverse tunnel server configuration
type Config struct {
// ID is the ID of this server proxy
Expand All @@ -146,11 +138,12 @@ type Config struct {
// AccessPoint caches values and can still return results during connection
// problems.
LocalAccessPoint auth.ProxyAccessPoint
// LocalAuthAddresses is a list of auth servers to use when dialing back to
// the local cluster.
LocalAuthAddresses []string
// NewCachingAccessPoint returns new caching access points
// per remote cluster
NewCachingAccessPoint auth.NewRemoteProxyCachingAccessPoint
// DirectClusters is a list of clusters accessed directly
DirectClusters []DirectCluster
// Context is a signalling context
Context context.Context
// Clock is a clock used in the server, set up to
Expand Down Expand Up @@ -303,8 +296,6 @@ func NewServer(cfg Config) (Server, error) {

srv := &server{
Config: cfg,
localSites: []*localSite{},
remoteSites: []*remoteSite{},
localAuthClient: cfg.LocalAuthClient,
localAccessPoint: cfg.LocalAccessPoint,
newAccessPoint: cfg.NewCachingAccessPoint,
Expand All @@ -317,15 +308,13 @@ func NewServer(cfg Config) (Server, error) {
offlineThreshold: offlineThreshold,
}

for _, clusterInfo := range cfg.DirectClusters {
cluster, err := newlocalSite(srv, clusterInfo.Name, clusterInfo.Client)
if err != nil {
return nil, trace.Wrap(err)
}

srv.localSites = append(srv.localSites, cluster)
localSite, err := newlocalSite(srv, cfg.ClusterName, cfg.LocalAuthAddresses, cfg.LocalAuthClient)
if err != nil {
return nil, trace.Wrap(err)
}

srv.localSites = append(srv.localSites, localSite)

s, err := sshutils.NewServer(
teleport.ComponentReverseTunnelServer,
// TODO(klizhentas): improve interface, use struct instead of parameter list
Expand Down Expand Up @@ -629,6 +618,7 @@ func (s *server) handleTransport(sconn *ssh.ServerConn, nch ssh.NewChannel) {
log: s.log,
closeContext: s.ctx,
authClient: s.LocalAccessPoint,
authServers: s.LocalAuthAddresses,
channel: channel,
requestCh: requestCh,
component: teleport.ComponentReverseTunnelServer,
Expand Down
64 changes: 32 additions & 32 deletions lib/reversetunnel/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ type transport struct {
log logrus.FieldLogger
closeContext context.Context
authClient auth.ProxyAccessPoint
authServers []string
channel ssh.Channel
requestCh <-chan *ssh.Request

Expand Down Expand Up @@ -196,8 +197,6 @@ func (p *transport) start() {
return
}

var servers []string

// Parse and extract the dial request from the client.
dreq := parseDialReq(req.Payload)
if err := dreq.CheckAndSetDefaults(); err != nil {
Expand All @@ -206,23 +205,22 @@ func (p *transport) start() {
}
p.log.Debugf("Received out-of-band proxy transport request for %v [%v].", dreq.Address, dreq.ServerID)

// directAddress will hold the address of the node to dial to, if we don't
// have a tunnel for it.
var directAddress string

// Handle special non-resolvable addresses first.
switch dreq.Address {
// Connect to an Auth Server.
case RemoteAuthServer:
authServers, err := p.authClient.GetAuthServers()
if err != nil {
p.reply(req, false, []byte("connection rejected: failed to connect to auth server"))
return
}
if len(authServers) == 0 {
p.log.Warn("No auth servers registered in the cluster.")
p.reply(req, false, []byte("connection rejected: failed to connect to auth server"))
if len(p.authServers) == 0 {
p.log.Errorf("connection rejected: no auth servers configured")
p.reply(req, false, []byte("no auth servers configured"))

return
}
for _, as := range authServers {
servers = append(servers, as.GetAddr())
}

directAddress = utils.ChooseRandomString(p.authServers)
// Connect to the Kubernetes proxy.
case LocalKubernetes:
switch p.component {
Expand Down Expand Up @@ -256,7 +254,7 @@ func (p *transport) start() {
return
}
p.log.Debugf("Forwarding connection to %q", p.kubeDialAddr.Addr)
servers = append(servers, p.kubeDialAddr.Addr)
directAddress = p.kubeDialAddr.Addr
}

// LocalNode requests are for the single server running in the agent pool.
Expand Down Expand Up @@ -287,15 +285,16 @@ func (p *transport) start() {
// If this is a proxy and not an SSH node, try finding an inbound
// tunnel from the SSH node by dreq.ServerID. We'll need to forward
// dreq.Address as well.
fallthrough
directAddress = dreq.Address
default:
servers = append(servers, dreq.Address)
// Not a special address; could be empty.
directAddress = dreq.Address
}

// Get a connection to the target address. If a tunnel exists with matching
// search names, connection over the tunnel is returned. Otherwise a direct
// net.Dial is performed.
conn, useTunnel, err := p.getConn(servers, dreq)
conn, useTunnel, err := p.getConn(directAddress, dreq)
if err != nil {
errorMessage := fmt.Sprintf("connection rejected: %v", err)
fmt.Fprint(p.channel.Stderr(), errorMessage)
Expand Down Expand Up @@ -369,7 +368,7 @@ func (p *transport) handleChannelRequests(closeContext context.Context, useTunne
// getConn checks if the local site holds a connection to the target host,
// and if it does, attempts to dial through the tunnel. Otherwise directly
// dials to host.
func (p *transport) getConn(servers []string, r *sshutils.DialReq) (net.Conn, bool, error) {
func (p *transport) getConn(addr string, r *sshutils.DialReq) (net.Conn, bool, error) {
// This function doesn't attempt to dial if a host with one of the
// search names is not registered. It's a fast check.
p.log.Debugf("Attempting to dial through tunnel with server ID %q.", r.ServerID)
Expand All @@ -389,13 +388,13 @@ func (p *transport) getConn(servers []string, r *sshutils.DialReq) (net.Conn, bo
}

errTun := err
p.log.Debugf("Attempting to dial directly %v.", servers)
conn, err = p.directDial(servers, r.ServerID)
p.log.Debugf("Attempting to dial directly %q.", addr)
conn, err = p.directDial(addr)
if err != nil {
return nil, false, trace.ConnectionProblem(err, "failed dialing through tunnel (%v) or directly (%v)", errTun, err)
}

p.log.Debugf("Returning direct dialed connection to %v.", servers)
p.log.Debugf("Returning direct dialed connection to %q.", addr)
return conn, false, nil
}

Expand Down Expand Up @@ -438,18 +437,19 @@ func (p *transport) reply(req *ssh.Request, ok bool, msg []byte) {
}
}

// directDial attempst to directly dial to the target host.
func (p *transport) directDial(servers []string, serverID string) (net.Conn, error) {
var errors []error

for _, addr := range servers {
conn, err := net.Dial("tcp", addr)
if err == nil {
return conn, nil
}
// directDial attempts to directly dial to the target host.
func (p *transport) directDial(addr string) (net.Conn, error) {
if addr == "" {
return nil, trace.BadParameter("no address to dial")
}

errors = append(errors, err)
d := net.Dialer{
Timeout: apidefaults.DefaultDialTimeout,
}
conn, err := d.DialContext(p.closeContext, "tcp", addr)
if err != nil {
return nil, trace.Wrap(err)
}

return nil, trace.NewAggregate(errors...)
return conn, nil
}
34 changes: 15 additions & 19 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -2970,27 +2970,22 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
HostSigners: []ssh.Signer{conn.ServerIdentity.KeySigner},
LocalAuthClient: conn.Client,
LocalAccessPoint: accessPoint,
LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServers),
NewCachingAccessPoint: process.newLocalCacheForRemoteProxy,
NewCachingAccessPointOldProxy: process.newLocalCacheForOldRemoteProxy,
Limiter: reverseTunnelLimiter,
DirectClusters: []reversetunnel.DirectCluster{
{
Name: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority],
Client: conn.Client,
},
},
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
KeyGen: cfg.Keygen,
Ciphers: cfg.Ciphers,
KEXAlgorithms: cfg.KEXAlgorithms,
MACAlgorithms: cfg.MACAlgorithms,
DataDir: process.Config.DataDir,
PollingPeriod: process.Config.PollingPeriod,
FIPS: cfg.FIPS,
Emitter: streamEmitter,
Log: process.log,
LockWatcher: lockWatcher,
NodeWatcher: nodeWatcher,
CertAuthorityWatcher: caWatcher,
})
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -3152,6 +3147,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
AccessPoint: accessPoint,
HostSigner: conn.ServerIdentity.KeySigner,
LocalCluster: conn.ServerIdentity.Cert.Extensions[utils.CertExtensionAuthority],
LocalAuthAddresses: utils.NetAddrsToStrings(process.Config.AuthServers),
KubeDialAddr: utils.DialAddrFromListenAddr(kubeDialAddr(cfg.Proxy, clusterNetworkConfig.GetProxyListenerMode())),
ReverseTunnelServer: tsrv,
FIPS: process.Config.FIPS,
Expand Down Expand Up @@ -3350,7 +3346,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {

var alpnServer *alpnproxy.Proxy
if !cfg.Proxy.DisableTLS && !cfg.Proxy.DisableALPNSNIListener && listeners.web != nil {
authDialerService := alpnproxyauth.NewAuthProxyDialerService(tsrv, accessPoint)
authDialerService := alpnproxyauth.NewAuthProxyDialerService(tsrv, clusterName, utils.NetAddrsToStrings(process.Config.AuthServers))
alpnRouter.Add(alpnproxy.HandlerDecs{
MatchFunc: alpnproxy.MatchByALPNPrefix(string(alpncommon.ProtocolAuth)),
HandlerWithConnInfo: authDialerService.HandleConnection,
Expand Down
Loading