Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[branch/v9] Create remote site cache based on remote auth version (#12130) #12251

Merged
merged 3 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,16 +354,25 @@ func (a *Agent) handleGlobalRequests(ctx context.Context, requestCh <-chan *ssh.

switch r.Type {
case versionRequest:
err := r.Reply(true, []byte(teleport.Version))
// reply with the auth server version
pong, err := a.Client.Ping(ctx)
if err != nil {
log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.log.WithError(err).Warnf("Failed to ping auth server in response to %v request.", r.Type)
if err := r.Reply(false, []byte("Failed to retrieve auth version")); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
continue
}
}

if err := r.Reply(true, []byte(pong.ServerVersion)); err != nil {
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
continue
}
default:
// This handles keep-alive messages and matches the behaviour of OpenSSH.
err := r.Reply(false, nil)
if err != nil {
log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
a.log.Debugf("Failed to reply to %v request: %v.", r.Type, err)
continue
}
}
Expand Down
70 changes: 30 additions & 40 deletions lib/reversetunnel/srv.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ import (
"github.com/gravitational/teleport/lib/sshca"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"

"github.com/coreos/go-semver/semver"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -639,7 +637,7 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
// nodes it's a node dialing back.
val, ok := sconn.Permissions.Extensions[extCertRole]
if !ok {
log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.log.Errorf("Failed to accept connection, missing %q extension", extCertRole)
s.rejectRequest(nch, ssh.ConnectionFailed, "unknown role")
return
}
Expand All @@ -665,22 +663,22 @@ func (s *server) handleHeartbeat(conn net.Conn, sconn *ssh.ServerConn, nch ssh.N
s.handleNewService(role, conn, sconn, nch, types.WindowsDesktopTunnel)
// Unknown role.
default:
log.Errorf("Unsupported role attempting to connect: %v", val)
s.log.Errorf("Unsupported role attempting to connect: %v", val)
s.rejectRequest(nch, ssh.ConnectionFailed, fmt.Sprintf("unsupported role %v", val))
}
}

func (s *server) handleNewService(role types.SystemRole, conn net.Conn, sconn *ssh.ServerConn, nch ssh.NewChannel, connType types.TunnelType) {
cluster, rconn, err := s.upsertServiceConn(conn, sconn, connType)
if err != nil {
log.Errorf("Failed to upsert %s: %v.", role, err)
s.log.Errorf("Failed to upsert %s: %v.", role, err)
sconn.Close()
return
}

ch, req, err := nch.Accept()
if err != nil {
log.Errorf("Failed to accept on channel: %v.", err)
s.log.Errorf("Failed to accept on channel: %v.", err)
sconn.Close()
return
}
Expand All @@ -692,14 +690,14 @@ func (s *server) handleNewCluster(conn net.Conn, sshConn *ssh.ServerConn, nch ss
// add the incoming site (cluster) to the list of active connections:
site, remoteConn, err := s.upsertRemoteCluster(conn, sshConn)
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
s.rejectRequest(nch, ssh.ConnectionFailed, "failed to accept incoming cluster connection")
return
}
// accept the request and start the heartbeat on it:
ch, req, err := nch.Accept()
if err != nil {
log.Error(trace.Wrap(err))
s.log.Error(trace.Wrap(err))
sshConn.Close()
return
}
Expand Down Expand Up @@ -1062,25 +1060,12 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
}
remoteSite.remoteClient = clt

// Check if the cluster that is connecting is a pre-v8 cluster. If it is,
// don't assume the newer organization of cluster configuration resources
// (RFD 28) because older proxy servers will reject that causing the cache
// to go into a re-sync loop.
var accessPointFunc auth.NewRemoteProxyCachingAccessPoint
ok, err := isPreV8Cluster(closeContext, sconn)
remoteVersion, err := getRemoteAuthVersion(closeContext, sconn)
if err != nil {
return nil, trace.Wrap(err)
}
if ok {
log.Debugf("Pre-v8 cluster connecting, loading old cache policy.")
accessPointFunc = srv.Config.NewCachingAccessPointOldProxy
} else {
accessPointFunc = srv.newAccessPoint
}

// Configure access to the cached subset of the Auth Server API of the remote
// cluster this remote site provides access to.
accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName})
accessPoint, err := createRemoteAccessPoint(srv, clt, remoteVersion, domainName)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -1125,31 +1110,35 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite,
return remoteSite, nil
}

// isPreV8Cluster checks if the cluster is older than 8.0.0.
func isPreV8Cluster(ctx context.Context, conn ssh.Conn) (bool, error) {
version, err := sendVersionRequest(ctx, conn)
// createRemoteAccessPoint creates a new access point for the remote cluster.
// Checks if the cluster that is connecting is a pre-v8 cluster. If it is,
// don't assume the newer organization of cluster configuration resources
// (RFD 28) because older proxy servers will reject that causing the cache
// to go into a re-sync loop.
func createRemoteAccessPoint(srv *server, clt auth.ClientI, version, domainName string) (auth.RemoteProxyAccessPoint, error) {
ok, err := utils.MinVerWithoutPreRelease(version, utils.VersionBeforeAlpha("8.0.0"))
if err != nil {
return false, trace.Wrap(err)
return nil, trace.Wrap(err)
}

remoteClusterVersion, err := semver.NewVersion(version)
if err != nil {
return false, trace.Wrap(err)
accessPointFunc := srv.Config.NewCachingAccessPoint
if !ok {
srv.log.Debugf("cluster %q running %q is connecting, loading old cache policy.", domainName, version)
accessPointFunc = srv.Config.NewCachingAccessPointOldProxy
}
minClusterVersion, err := semver.NewVersion(utils.VersionBeforeAlpha("8.0.0"))

// Configure access to the cached subset of the Auth Server API of the remote
// cluster this remote site provides access to.
accessPoint, err := accessPointFunc(clt, []string{"reverse", domainName})
if err != nil {
return false, trace.Wrap(err)
}
// Return true if the version is older than 8.0.0
if remoteClusterVersion.LessThan(*minClusterVersion) {
return true, nil
return nil, trace.Wrap(err)
}

return false, nil
return accessPoint, nil
}

// sendVersionRequest sends a request for the version remote Teleport cluster.
func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
// getRemoteAuthVersion sends a version request to the remote agent.
func getRemoteAuthVersion(ctx context.Context, sconn ssh.Conn) (string, error) {
errorCh := make(chan error, 1)
versionCh := make(chan string, 1)

Expand All @@ -1163,6 +1152,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
errorCh <- trace.BadParameter("no response to %v request", versionRequest)
return
}

versionCh <- string(payload)
}()

Expand All @@ -1174,7 +1164,7 @@ func sendVersionRequest(ctx context.Context, sconn ssh.Conn) (string, error) {
case <-time.After(defaults.WaitCopyTimeout):
return "", trace.BadParameter("timeout waiting for version")
case <-ctx.Done():
return "", ctx.Err()
return "", trace.Wrap(ctx.Err())
}
}

Expand Down
68 changes: 68 additions & 0 deletions lib/reversetunnel/srv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package reversetunnel

import (
"context"
"errors"
"net"
"testing"
"time"
Expand Down Expand Up @@ -159,3 +160,70 @@ type mockAccessPoint struct {
func (ap mockAccessPoint) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) {
return ap.ca, nil
}

func TestCreateRemoteAccessPoint(t *testing.T) {
cases := []struct {
name string
version string
assertion require.ErrorAssertionFunc
oldRemoteProxy bool
}{
{
name: "invalid version",
assertion: require.Error,
},
{
name: "remote running 9.0.0",
assertion: require.NoError,
version: "9.0.0",
},
{
name: "remote running 8.0.0",
assertion: require.NoError,
version: "8.0.0",
},
{
name: "remote running 7.0.0",
assertion: require.NoError,
version: "7.0.0",
oldRemoteProxy: true,
},
{
name: "remote running 6.0.0",
assertion: require.NoError,
version: "6.0.0",
oldRemoteProxy: true,
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
newProxyFn := func(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) {
if tt.oldRemoteProxy {
return nil, errors.New("expected to create an old remote proxy")
}

return nil, nil
}

oldProxyFn := func(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) {
if !tt.oldRemoteProxy {
return nil, errors.New("expected to create an new remote proxy")
}

return nil, nil
}

clt := &mockAuthClient{}
srv := &server{
log: utils.NewLoggerForTests(),
Config: Config{
NewCachingAccessPoint: newProxyFn,
NewCachingAccessPointOldProxy: oldProxyFn,
},
}
_, err := createRemoteAccessPoint(srv, clt, tt.version, "test")
tt.assertion(t, err)
})
}
}
29 changes: 29 additions & 0 deletions lib/utils/ver.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,32 @@ func CheckVersion(currentVersion, minVersion string) error {
func VersionBeforeAlpha(version string) string {
return version + "-aa"
}

// MinVerWithoutPreRelease compares semver strings, but skips prerelease. This allows to compare
// two versions and ignore dev,alpha,beta, etc. strings.
func MinVerWithoutPreRelease(currentVersion, minVersion string) (bool, error) {
currentSemver, minSemver, err := versionStringToSemver(currentVersion, minVersion)
if err != nil {
return false, trace.Wrap(err)
}

// Erase pre-release string, so only version is compared.
currentSemver.PreRelease = ""
minSemver.PreRelease = ""

return !currentSemver.LessThan(*minSemver), nil
}

func versionStringToSemver(ver1, ver2 string) (*semver.Version, *semver.Version, error) {
v1Semver, err := semver.NewVersion(ver1)
if err != nil {
return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", v1Semver)
}

v2Semver, err := semver.NewVersion(ver2)
if err != nil {
return nil, nil, trace.Wrap(err, "unsupported version format, need semver format: %q, e.g 1.0.0", v2Semver)
}

return v1Semver, v2Semver, nil
}