Skip to content

Commit

Permalink
Periodically resync proxies to agents
Browse files Browse the repository at this point in the history
Prior to #14262, resource watchers would periodically close their watcher,
create a new one and refetch the current set of resources. It turns out
that the reverse tunnel subsytem relied on this behavior to periodically
broadcast the list of proxies to agents during steady state. Now that
watchers are persistent and no longer perform a refetch, agents that are
unable to connect to a proxy expire them after a period of time, and
since they never receive the periodic refresh, they never attempt to
connect to said proxy again.

To remedy this, a new ticker is added to the `localsite` that grabs
the current set of proxies from its proxy watcher and sends a discovery
request to the agent. The frequency of the ticker is set to fire
prior to the tracker would expire the proxy so that if a proxy exists
in the cluster, then the agent will continually try to connect to it.
  • Loading branch information
rosstimothy authored and github-actions committed Nov 4, 2022
1 parent ce0c5b1 commit 4597860
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 32 deletions.
30 changes: 19 additions & 11 deletions lib/reversetunnel/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,18 @@ type discoveryRequestRaw struct {
}

func marshalDiscoveryRequest(req discoveryRequest) ([]byte, error) {
var out discoveryRequestRaw
out := discoveryRequestRaw{
Proxies: make([]json.RawMessage, 0, len(req.Proxies)),
}
for _, p := range req.Proxies {
// Clone the server value to avoid a potential race
// since the proxies are shared.
// Marshaling attempts to enforce defaults which modifies
// the original value.
p = p.DeepCopy()
data, err := services.MarshalServer(p)
// create a new server that clones only the id and kind as that's all we need
// to propagate
srv, err := types.NewServer(p.GetName(), p.GetKind(), types.ServerSpecV2{})
if err != nil {
return nil, trace.Wrap(err)
}

data, err := utils.FastMarshal(srv)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -81,17 +85,21 @@ func unmarshalDiscoveryRequest(data []byte) (*discoveryRequest, error) {
if len(data) == 0 {
return nil, trace.BadParameter("missing payload in discovery request")
}

var raw discoveryRequestRaw
err := utils.FastUnmarshal(data, &raw)
if err != nil {
if err := utils.FastUnmarshal(data, &raw); err != nil {
return nil, trace.Wrap(err)
}
var out discoveryRequest

out := discoveryRequest{
Proxies: make([]types.Server, 0, len(raw.Proxies)),
}
for _, bytes := range raw.Proxies {
proxy, err := services.UnmarshalServer([]byte(bytes), types.KindProxy)
proxy, err := services.UnmarshalServer(bytes, types.KindProxy)
if err != nil {
return nil, trace.Wrap(err)
}

out.Proxies = append(out.Proxies, proxy)
}
out.ClusterName = raw.ClusterName
Expand Down
76 changes: 62 additions & 14 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/proxy"
"github.com/gravitational/teleport/lib/reversetunnel/track"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/forward"
"github.com/gravitational/teleport/lib/utils"
Expand All @@ -43,10 +44,31 @@ import (
"golang.org/x/exp/slices"
)

// periodicFunctionInterval is the interval at which periodic stats are calculated.
var periodicFunctionInterval = 3 * time.Minute
const (
// periodicFunctionInterval is the interval at which periodic stats are calculated.
periodicFunctionInterval = 3 * time.Minute

func newlocalSite(srv *server, domainName string, authServers []string) (*localSite, error) {
// proxySyncInterval is the interval at which the current proxies are synchronized to
// connected agents via a discovery request. It is a function of track.DefaultProxyExpiry
// to ensure that the proxies are always synced before the tracker expiry.
proxySyncInterval = track.DefaultProxyExpiry * 2 / 3
)

// withPeriodicFunctionInterval adjusts the periodic function interval
func withPeriodicFunctionInterval(interval time.Duration) func(site *localSite) {
return func(site *localSite) {
site.periodicFunctionInterval = interval
}
}

// withProxySyncInterval adjusts the proxy sync interval
func withProxySyncInterval(interval time.Duration) func(site *localSite) {
return func(site *localSite) {
site.proxySyncInterval = interval
}
}

func newlocalSite(srv *server, domainName string, authServers []string, opts ...func(*localSite)) (*localSite, error) {
err := metrics.RegisterPrometheusCollectors(localClusterCollectors...)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -76,8 +98,14 @@ func newlocalSite(srv *server, domainName string, authServers []string) (*localS
"cluster": domainName,
},
}),
offlineThreshold: srv.offlineThreshold,
peerClient: srv.PeerClient,
offlineThreshold: srv.offlineThreshold,
peerClient: srv.PeerClient,
periodicFunctionInterval: periodicFunctionInterval,
proxySyncInterval: proxySyncInterval,
}

for _, opt := range opts {
opt(s)
}

// Start periodic functions for the local cluster in the background.
Expand Down Expand Up @@ -118,7 +146,15 @@ type localSite struct {
// marking a reverse tunnel connection as invalid.
offlineThreshold time.Duration

// peerClient is the proxy peering client
peerClient *proxy.Client

// periodicFunctionInterval defines the interval period functions run at
periodicFunctionInterval time.Duration

// proxySyncInterval defines the interval at which discovery requests are
// sent to keep agents in sync
proxySyncInterval time.Duration
}

// GetTunnelsCount always the number of tunnel connections to this cluster.
Expand Down Expand Up @@ -494,13 +530,16 @@ func (s *localSite) fanOutProxies(proxies []types.Server) {
}
}

// handleHearbeat receives heartbeat messages from the connected agent
// handleHeartbeat receives heartbeat messages from the connected agent
// if the agent has missed several heartbeats in a row, Proxy marks
// the connection as invalid.
func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-chan *ssh.Request) {
proxyResyncTicker := s.clock.NewTicker(s.proxySyncInterval)

defer func() {
s.log.Debugf("Cluster connection closed.")
rconn.Close()
proxyResyncTicker.Stop()
}()

firstHeartbeat := true
Expand All @@ -509,14 +548,23 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
case <-s.srv.ctx.Done():
s.log.Infof("closing")
return
case <-proxyResyncTicker.Chan():
req := discoveryRequest{
Proxies: s.srv.proxyWatcher.GetCurrent(),
}

if err := rconn.sendDiscoveryRequest(req); err != nil {
s.log.WithError(err).Debugf("Marking connection invalid on error")
rconn.markInvalid(err)
return
}
case proxies := <-rconn.newProxiesC:
req := discoveryRequest{
ClusterName: s.srv.ClusterName,
Type: rconn.tunnelType,
Proxies: proxies,
Proxies: proxies,
}

if err := rconn.sendDiscoveryRequest(req); err != nil {
s.log.Debugf("Marking connection invalid on error: %v.", err)
s.log.WithError(err).Debugf("Marking connection invalid on error")
rconn.markInvalid(err)
return
}
Expand Down Expand Up @@ -549,10 +597,10 @@ func (s *localSite) handleHeartbeat(rconn *remoteConn, ch ssh.Channel, reqC <-ch
} else {
s.log.WithFields(log.Fields{"nodeID": rconn.nodeID}).Debugf("Ping <- %v", rconn.conn.RemoteAddr())
}
tm := time.Now().UTC()
tm := s.clock.Now().UTC()
rconn.setLastHeartbeat(tm)
// Note that time.After is re-created everytime a request is processed.
case <-time.After(s.offlineThreshold):
case <-s.clock.After(s.offlineThreshold):
rconn.markInvalid(trace.ConnectionProblem(nil, "no heartbeats for %v", s.offlineThreshold))
}
}
Expand Down Expand Up @@ -611,14 +659,14 @@ func (s *localSite) chanTransportConn(rconn *remoteConn, dreq *sshutils.DialReq)

// periodicFunctions runs functions periodic functions for the local cluster.
func (s *localSite) periodicFunctions() {
ticker := time.NewTicker(periodicFunctionInterval)
ticker := s.clock.NewTicker(s.periodicFunctionInterval)
defer ticker.Stop()

for {
select {
case <-s.srv.ctx.Done():
return
case <-ticker.C:
case <-ticker.Chan():
if err := s.sshTunnelStats(); err != nil {
s.log.Warningf("Failed to report SSH tunnel statistics for: %v: %v.", s.domainName, err)
}
Expand Down
Loading

0 comments on commit 4597860

Please sign in to comment.