Skip to content

Commit

Permalink
Periodically resync proxies to agents (#18050)
Browse files Browse the repository at this point in the history
* Periodically resync proxies to agents

Prior to #14262, resource watchers would periodically close their watcher,
create a new one and refetch the current set of resources. It turns out
that the reverse tunnel subsystem relied on this behavior to periodically
broadcast the list of proxies to agents during steady state. Now that
watchers are persistent and no longer perform a refetch, agents that are
unable to connect to a proxy expire them after a period of time, and
since they never receive the periodic refresh, they never attempt to
connect to said proxy again.

To remedy this, a new ticker is added to the `localsite` that grabs
the current set of proxies from its proxy watcher and sends a discovery
request to the agent. The frequency of the ticker is set to fire
prior to the tracker would expire the proxy so that if a proxy exists
in the cluster, then the agent will continually try to connect to it.
  • Loading branch information
rosstimothy authored Nov 4, 2022
1 parent 92c9429 commit 3b4c144
Show file tree
Hide file tree
Showing 9 changed files with 506 additions and 91 deletions.
14 changes: 6 additions & 8 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package reversetunnel

import (
"context"
"encoding/json"
"fmt"
"io"
"strings"
Expand Down Expand Up @@ -637,17 +638,14 @@ func (a *agent) handleDiscovery(ch ssh.Channel, reqC <-chan *ssh.Request) {
a.log.Infof("Connection closed, returning")
return
}
r, err := unmarshalDiscoveryRequest(req.Payload)
if err != nil {
a.log.Warningf("Bad payload: %v.", err)
return
}

var proxies []string
for _, proxy := range r.Proxies {
proxies = append(proxies, proxy.GetName())
var r discoveryRequest
if err := json.Unmarshal(req.Payload, &r); err != nil {
a.log.WithError(err).Warningf("Bad payload")
return
}

proxies := r.ProxyNames()
a.log.Debugf("Received discovery request: %v", proxies)
a.tracker.TrackExpected(proxies...)
}
Expand Down
13 changes: 4 additions & 9 deletions lib/reversetunnel/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package reversetunnel

import (
"encoding/json"
"fmt"
"net"
"sync"
Expand Down Expand Up @@ -225,22 +226,16 @@ func (c *remoteConn) sendDiscoveryRequest(req discoveryRequest) error {

// Marshal and send the request. If the connection failed, mark the
// connection as invalid so it will be removed later.
payload, err := marshalDiscoveryRequest(req)
payload, err := json.Marshal(req)
if err != nil {
return trace.Wrap(err)
}

// Log the discovery request being sent. Useful for debugging to know what
// proxies the tunnel server thinks exist.
names := make([]string, 0, len(req.Proxies))
for _, proxy := range req.Proxies {
names = append(names, proxy.GetName())
}
c.log.Debugf("Sending %v discovery request with proxies %q to %v.",
req.Type, names, c.sconn.RemoteAddr())
c.log.Debugf("Sending discovery request with proxies %q to %v.", req.ProxyNames(), c.sconn.RemoteAddr())

_, err = discoveryCh.SendRequest(chanDiscoveryReq, false, payload)
if err != nil {
if _, err := discoveryCh.SendRequest(chanDiscoveryReq, false, payload); err != nil {
c.markInvalid(err)
return trace.Wrap(err)
}
Expand Down
121 changes: 71 additions & 50 deletions lib/reversetunnel/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package reversetunnel

import (
"encoding/json"
"fmt"
"strings"

"github.com/gravitational/trace"

Expand All @@ -30,71 +28,94 @@ import (

// discoveryRequest is a request sent from a connected proxy with the missing proxies.
type discoveryRequest struct {
// ClusterName is the name of the cluster that sends the discovery request.
ClusterName string `json:"cluster_name"`

// Type is the type of tunnel, is either node or proxy.
Type string `json:"type"`

// ClusterAddr is the address of the cluster.
ClusterAddr utils.NetAddr `json:"-"`

// Proxies is a list of proxies in the cluster sending the discovery request.
Proxies []types.Server `json:"proxies"`
}

func (r discoveryRequest) String() string {
proxyNames := make([]string, 0, len(r.Proxies))
// ProxyNames returns the names of all proxies carried in the request
func (r *discoveryRequest) ProxyNames() []string {
names := make([]string, 0, len(r.Proxies))
for _, p := range r.Proxies {
proxyNames = append(proxyNames, p.GetName())
names = append(names, p.GetName())
}
return fmt.Sprintf("discovery request, cluster name: %v, address: %v, proxies: %v",
r.ClusterName, r.ClusterAddr, strings.Join(proxyNames, ","))
}

type discoveryRequestRaw struct {
ClusterName string `json:"cluster_name"`
Type string `json:"type"`
Proxies []json.RawMessage `json:"proxies"`
return names
}

func marshalDiscoveryRequest(req discoveryRequest) ([]byte, error) {
var out discoveryRequestRaw
for _, p := range req.Proxies {
// Clone the server value to avoid a potential race
// since the proxies are shared.
// Marshaling attempts to enforce defaults which modifies
// the original value.
p = p.DeepCopy()
data, err := services.MarshalServer(p)
if err != nil {
return nil, trace.Wrap(err)
}
out.Proxies = append(out.Proxies, data)
// MarshalJSON creates a minimal JSON representation of a discoveryRequest
// by converting the Proxies from types.Server to discoveryProxy.
// The minification is useful since only the Proxy ID is to be consumed
// by the agents. This is needed to maintain backward compatibility
// but should be replaced in the future by a message which
// only contains the Proxy IDs.
func (r *discoveryRequest) MarshalJSON() ([]byte, error) {
var out struct {
Proxies []discoveryProxy `json:"proxies"`
}
out.ClusterName = req.ClusterName
out.Type = req.Type

out.Proxies = make([]discoveryProxy, 0, len(r.Proxies))

for _, p := range r.Proxies {
out.Proxies = append(out.Proxies, discoveryProxy(p.GetName()))
}

return json.Marshal(out)
}

func unmarshalDiscoveryRequest(data []byte) (*discoveryRequest, error) {
func (r *discoveryRequest) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil, trace.BadParameter("missing payload in discovery request")
return trace.BadParameter("missing payload in discovery request")
}
var raw discoveryRequestRaw
err := utils.FastUnmarshal(data, &raw)
if err != nil {
return nil, trace.Wrap(err)

var in struct {
Proxies []json.RawMessage `json:"proxies"`
}
var out discoveryRequest
for _, bytes := range raw.Proxies {
proxy, err := services.UnmarshalServer([]byte(bytes), types.KindProxy)

if err := utils.FastUnmarshal(data, &in); err != nil {
return trace.Wrap(err)
}

d := discoveryRequest{
Proxies: make([]types.Server, 0, len(in.Proxies)),
}

for _, bytes := range in.Proxies {
proxy, err := services.UnmarshalServer(bytes, types.KindProxy)
if err != nil {
return nil, trace.Wrap(err)
return trace.Wrap(err)
}
out.Proxies = append(out.Proxies, proxy)

d.Proxies = append(d.Proxies, proxy)
}

*r = d
return nil
}

// discoveryProxy is a wrapper around a Proxy ID that
// can be marshaled to json in the minimal representation
// of a types.Server that will still be correctly unmarshalled
// as a types.Server. Backwards compatibility requires a types.Server
// to be included in a discoveryRequest when in reality only
// the Proxy ID needs to be communicated to agents.
//
// This should eventually be replaced by a newer version of
// messages used by agents to indicate they can support discovery
// requests which only contain Proxy IDs.
type discoveryProxy string

// MarshalJSON creates a minimum representation of types.Server
// such that (*discoveryRequest) UnmarshalJSON will successfully
// unmarshal this as a types.Server. This allows the discoveryRequest
// to be four and a half times smaller when marshaled.
func (s discoveryProxy) MarshalJSON() ([]byte, error) {
var p struct {
Version string `json:"version"`
Metadata struct {
Name string `json:"name"`
} `json:"metadata"`
}
out.ClusterName = raw.ClusterName
out.Type = raw.Type
return &out, nil
p.Version = types.V2
p.Metadata.Name = string(s)
return json.Marshal(p)
}
147 changes: 147 additions & 0 deletions lib/reversetunnel/discovery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package reversetunnel

import (
"encoding/json"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/utils"
)

// discoveryRequestRaw is the legacy type that was used
// as the payload for discoveryRequests. It exists
// here for the sake of ensuring backward compatibility.
type discoveryRequestRaw struct {
ClusterName string `json:"cluster_name"`
Type string `json:"type"`
Proxies []json.RawMessage `json:"proxies"`
}

// marshalDiscoveryRequest is the legacy method of marshaling a discoveryRequest
func marshalDiscoveryRequest(req discoveryRequest) ([]byte, error) {
out := discoveryRequestRaw{
Proxies: make([]json.RawMessage, 0, len(req.Proxies)),
}
for _, p := range req.Proxies {
// Clone the server value to avoid a potential race
// since the proxies are shared.
// Marshaling attempts to enforce defaults which modifies
// the original value.
p = p.DeepCopy()
data, err := services.MarshalServer(p)
if err != nil {
return nil, trace.Wrap(err)
}
out.Proxies = append(out.Proxies, data)
}

return json.Marshal(out)
}

// unmarshalDiscoveryRequest is the legacy method of unmarshaling a discoveryRequest
func unmarshalDiscoveryRequest(data []byte) (*discoveryRequest, error) {
if len(data) == 0 {
return nil, trace.BadParameter("missing payload in discovery request")
}

var raw discoveryRequestRaw
if err := utils.FastUnmarshal(data, &raw); err != nil {
return nil, trace.Wrap(err)
}

out := discoveryRequest{
Proxies: make([]types.Server, 0, len(raw.Proxies)),
}
for _, bytes := range raw.Proxies {
proxy, err := services.UnmarshalServer(bytes, types.KindProxy)
if err != nil {
return nil, trace.Wrap(err)
}

out.Proxies = append(out.Proxies, proxy)
}

return &out, nil
}

func TestDiscoveryRequestMarshalling(t *testing.T) {
const proxyCount = 10

// create a discovery request
req := discoveryRequest{
Proxies: make([]types.Server, 0, proxyCount),
}

// populate the proxies
for i := 0; i < proxyCount; i++ {
p, err := types.NewServer(uuid.New().String(), types.KindProxy, types.ServerSpecV2{})
require.NoError(t, err)
req.Proxies = append(req.Proxies, p)
}

// test marshaling the request with the legacy mechanism and unmarshaling
// with the new mechanism
t.Run("marshal=legacy unmarshal=new", func(t *testing.T) {
payload, err := marshalDiscoveryRequest(req)
require.NoError(t, err)

var got discoveryRequest
require.NoError(t, json.Unmarshal(payload, &got))

require.Empty(t, cmp.Diff(req.ProxyNames(), got.ProxyNames()))
})

// test marshaling the request with the new mechanism and unmarshaling
// with the legacy mechanism
t.Run("marshal=new unmarshal=legacy", func(t *testing.T) {
payload, err := json.Marshal(req)
require.NoError(t, err)

got, err := unmarshalDiscoveryRequest(payload)
require.NoError(t, err)

require.Empty(t, cmp.Diff(req.ProxyNames(), got.ProxyNames()))
})

// test marshaling and unmarshaling the request with the new mechanism
t.Run("marshal=new unmarshal=new", func(t *testing.T) {
payload, err := json.Marshal(req)
require.NoError(t, err)

var got discoveryRequest
require.NoError(t, json.Unmarshal(payload, &got))

require.Empty(t, cmp.Diff(req.ProxyNames(), got.ProxyNames()))
})

// test marshaling and unmarshaling the request with the legacy mechanism
t.Run("marshal=legacy unmarshal=legacy", func(t *testing.T) {
payload, err := marshalDiscoveryRequest(req)
require.NoError(t, err)

got, err := unmarshalDiscoveryRequest(payload)
require.NoError(t, err)

require.Empty(t, cmp.Diff(req.ProxyNames(), got.ProxyNames()))
})
}
Loading

0 comments on commit 3b4c144

Please sign in to comment.