Skip to content

Commit

Permalink
Dial only application servers that serve the requested application (#…
Browse files Browse the repository at this point in the history
…12217)

* feat(reversetunnel): add a dial counter on fake remote site

* refactor(app): change `MatchHealthy` to close connections

* refactor(app): change `MatchAll` order to not dial all servers

* refactor(app): only close connection when dial succeeds
  • Loading branch information
gabrielcorado committed Apr 28, 2022
1 parent f955f36 commit ecd78ee
Show file tree
Hide file tree
Showing 5 changed files with 257 additions and 20 deletions.
9 changes: 9 additions & 0 deletions lib/reversetunnel/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package reversetunnel

import (
"net"
"sync/atomic"

"github.com/gravitational/teleport/lib/auth"

Expand Down Expand Up @@ -57,6 +58,8 @@ type FakeRemoteSite struct {
AccessPoint auth.RemoteProxyAccessPoint
// OfflineTunnels is a list of server IDs that will return connection error.
OfflineTunnels map[string]struct{}
// connCounter count how many connection requests the remote received.
connCounter int64
}

// CachingAccessPoint returns caching auth server client.
Expand All @@ -71,6 +74,8 @@ func (s *FakeRemoteSite) GetName() string {

// Dial returns the connection to the remote site.
func (s *FakeRemoteSite) Dial(params DialParams) (net.Conn, error) {
atomic.AddInt64(&s.connCounter, 1)

if _, ok := s.OfflineTunnels[params.ServerID]; ok {
return nil, trace.ConnectionProblem(nil, "server %v tunnel is offline",
params.ServerID)
Expand All @@ -83,3 +88,7 @@ func (s *FakeRemoteSite) Dial(params DialParams) (net.Conn, error) {
func (s *FakeRemoteSite) Close() error {
return nil
}

func (s *FakeRemoteSite) DialCount() int64 {
return atomic.LoadInt64(&s.connCounter)
}
238 changes: 222 additions & 16 deletions lib/web/app/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,29 @@ import (
"bytes"
"context"
"crypto/tls"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"

Expand All @@ -45,6 +56,7 @@ func TestAuthPOST(t *testing.T) {
cookieValue = "5588e2be54a2834b4f152c56bafcd789f53b15477129d2ab4044e9a3c1bf0f3b"
)

fakeClock := clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC))
tests := []struct {
desc string
stateInRequest string
Expand Down Expand Up @@ -77,33 +89,105 @@ func TestAuthPOST(t *testing.T) {

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
p := setup(t, test.sessionError)
p := setup(t, fakeClock, mockAuthClient{sessionError: test.sessionError}, nil)

req, err := json.Marshal(fragmentRequest{
StateValue: test.stateInRequest,
CookieValue: cookieValue,
})
require.NoError(t, err)

status := p.makeRequest(t, "POST", "/x-teleport-auth", test.stateInCookie, req)
status, _ := p.makeRequest(t, "POST", "/x-teleport-auth", AuthStateCookieName, test.stateInCookie, req)
require.Equal(t, test.outStatusCode, status)
})
}
}

type testServer struct {
serverURL *url.URL
}
func TestMatchApplicationServers(t *testing.T) {
clusterName := "test-cluster"
publicAddr := "app.example.com"

// Generate CA TLS key and cert with the cluster and application DNS.
key, cert, err := tlsca.GenerateSelfSignedCA(
pkix.Name{CommonName: clusterName},
[]string{publicAddr, apiutils.EncodeClusterName(clusterName)},
defaults.CATTL,
)
require.NoError(t, err)

func setup(t *testing.T, sessionError error) *testServer {
fakeClock := clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC))
authClient := mockAuthClient{
sessionError: sessionError,
clusterName: clusterName,
appSession: createAppSession(t, fakeClock, key, cert, clusterName, publicAddr),
// Three app servers with same public addr from our session, and three
// that won't match.
appServers: []types.AppServer{
createAppServer(t, publicAddr),
createAppServer(t, publicAddr),
createAppServer(t, publicAddr),
createAppServer(t, "random.example.com"),
createAppServer(t, "random2.example.com"),
createAppServer(t, "random3.example.com"),
},
caKey: key,
caCert: cert,
}

// Create a fake remote site and tunnel.
fakeRemoteSiteConnCh := make(chan net.Conn)
fakeRemoteSite := &reversetunnel.FakeRemoteSite{
Name: clusterName,
ConnCh: fakeRemoteSiteConnCh,
AccessPoint: authClient,
}
tunnel := &reversetunnel.FakeServer{
Sites: []reversetunnel.RemoteSite{
fakeRemoteSite,
},
}

// Create a httptest server to serve the application requests. It must serve
// TLS content with the generated certificate.
tlsCert, err := tls.X509KeyPair(cert, key)
require.NoError(t, err)
expectedContent := "Hello from application"
server := &httptest.Server{
TLS: &tls.Config{
Certificates: []tls.Certificate{tlsCert},
},
Listener: &fakeRemoteListener{fakeRemoteSite},
Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, expectedContent)
})},
}
server.StartTLS()

// Teardown the remote site and the httptest server.
t.Cleanup(func() {
close(fakeRemoteSiteConnCh)
server.Close()
})

p := setup(t, fakeClock, authClient, tunnel)
status, content := p.makeRequest(t, "GET", "/", CookieName, "abc", []byte{})
require.Equal(t, http.StatusOK, status)
// Remote site should receive only 4 connection requests: 3 from the
// MatchHealthy and 1 from the transport.
require.Equal(t, int64(4), fakeRemoteSite.DialCount())
// Guarantee the request was returned by the httptest server.
require.Equal(t, expectedContent, content)
}

type testServer struct {
serverURL *url.URL
}

func setup(t *testing.T, clock clockwork.FakeClock, authClient auth.ClientI, proxyClient reversetunnel.Tunnel) *testServer {
appHandler, err := NewHandler(context.Background(), &HandlerConfig{
Clock: fakeClock,
Clock: clock,
AuthClient: authClient,
AccessPoint: authClient,
ProxyClient: proxyClient,
CipherSuites: utils.DefaultCipherSuites(),
})
require.NoError(t, err)
Expand All @@ -119,7 +203,7 @@ func setup(t *testing.T, sessionError error) *testServer {
}
}

func (p *testServer) makeRequest(t *testing.T, method, endpoint, stateInCookie string, reqBody []byte) int {
func (p *testServer) makeRequest(t *testing.T, method, endpoint, cookieName, cookieValue string, reqBody []byte) (int, string) {
u := url.URL{
Scheme: p.serverURL.Scheme,
Host: p.serverURL.Host,
Expand All @@ -128,10 +212,10 @@ func (p *testServer) makeRequest(t *testing.T, method, endpoint, stateInCookie s
req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewBuffer(reqBody))
require.NoError(t, err)

// Attach state token cookie.
// Attach the cookie.
req.AddCookie(&http.Cookie{
Name: AuthStateCookieName,
Value: stateInCookie,
Name: cookieName,
Value: cookieValue,
})

// Issue request.
Expand All @@ -145,29 +229,151 @@ func (p *testServer) makeRequest(t *testing.T, method, endpoint, stateInCookie s
return http.ErrUseLastResponse
},
}

resp, err := client.Do(req)
require.NoError(t, err)

content, err := io.ReadAll(resp.Body)
require.NoError(t, err)

require.NoError(t, resp.Body.Close())
return resp.StatusCode
return resp.StatusCode, string(content)
}

type mockAuthClient struct {
auth.ClientI
clusterName string
appSession types.WebSession
sessionError error
appServers []types.AppServer
caKey []byte
caCert []byte
}

type mockClusterName struct {
types.ClusterName
name string
}

func (c mockAuthClient) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) {
return mockClusterName{}, nil
func (c mockAuthClient) GetClusterName(_ ...services.MarshalOption) (types.ClusterName, error) {
return mockClusterName{name: c.clusterName}, nil
}

func (n mockClusterName) GetClusterName() string {
if n.name != "" {
return n.name
}

return "local-cluster"
}

func (c mockAuthClient) GetAppSession(context.Context, types.GetAppSessionRequest) (types.WebSession, error) {
return nil, c.sessionError
return c.appSession, c.sessionError
}

func (c mockAuthClient) GetApplicationServers(_ context.Context, _ string) ([]types.AppServer, error) {
return c.appServers, nil
}

func (c mockAuthClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) {
ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{
Type: types.HostCA,
ClusterName: c.clusterName,
ActiveKeys: types.CAKeySet{
TLS: []*types.TLSKeyPair{{
Cert: c.caCert,
Key: c.caKey,
}},
},
})
if err != nil {
return nil, err
}

return ca, nil
}

// fakeRemoteListener Implements a `net.Listener` that return `net.Conn` from
// the `FakeRemoteSite`.
type fakeRemoteListener struct {
fakeRemote *reversetunnel.FakeRemoteSite
}

func (r *fakeRemoteListener) Accept() (net.Conn, error) {
conn, ok := <-r.fakeRemote.ConnCh
if !ok {
return nil, fmt.Errorf("remote closed")
}

return conn, nil

}

func (r *fakeRemoteListener) Close() error {
return nil
}

func (r *fakeRemoteListener) Addr() net.Addr {
return &net.IPAddr{}
}

// createAppSession generates a WebSession for an application.
func createAppSession(t *testing.T, clock clockwork.FakeClock, caKey, caCert []byte, clusterName, publicAddr string) types.WebSession {
tlsCA, err := tlsca.FromKeys(caCert, caKey)
require.NoError(t, err)

// Generate the identity with a `RouteToApp` option.
subj, err := (&tlsca.Identity{
Username: "testuser",
Groups: []string{"access"},
RouteToApp: tlsca.RouteToApp{
PublicAddr: publicAddr,
ClusterName: clusterName,
Name: "testapp",
},
}).Subject()
require.NoError(t, err)

// Generate public and private keys for the application request certificate.
priv, pub, err := testauthority.New().GetNewKeyPairFromPool()
require.NoError(t, err)
cryptoPubKey, err := sshutils.CryptoPublicKey(pub)
require.NoError(t, err)

cert, err := tlsCA.GenerateCertificate(tlsca.CertificateRequest{
Clock: clock,
PublicKey: cryptoPubKey,
Subject: subj,
NotAfter: clock.Now().Add(5 * time.Minute),
})
require.NoError(t, err)

appSession, err := types.NewWebSession(uuid.New().String(), types.KindAppSession, types.WebSessionSpecV2{
User: "testuser",
Priv: priv,
TLSCert: cert,
Expires: clock.Now().Add(5 * time.Minute),
})
require.NoError(t, err)

return appSession
}

func createAppServer(t *testing.T, publicAddr string) types.AppServer {
appName := uuid.New().String()
appServer, err := types.NewAppServerV3(
types.Metadata{Name: appName},
types.AppServerSpecV3{
HostID: uuid.New().String(),
App: &types.AppV3{
Metadata: types.Metadata{Name: appName},
Spec: types.AppSpecV3{
URI: "localhost",
PublicAddr: publicAddr,
},
},
},
)
require.NoError(t, err)
return appServer
}
9 changes: 7 additions & 2 deletions lib/web/app/match.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ func MatchName(name string) Matcher {
// doesn't return any error.
func MatchHealthy(proxyClient reversetunnel.Tunnel, identity *tlsca.Identity) Matcher {
return func(appServer types.AppServer) bool {
_, err := dialAppServer(proxyClient, identity, appServer)
return err == nil
conn, err := dialAppServer(proxyClient, identity, appServer)
if err != nil {
return false
}

conn.Close()
return true
}
}

Expand Down
14 changes: 13 additions & 1 deletion lib/web/app/match_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,17 @@ type mockRemoteSite struct {
}

func (r *mockRemoteSite) Dial(_ reversetunnel.DialParams) (net.Conn, error) {
return nil, r.dialErr
if r.dialErr != nil {
return nil, r.dialErr
}

return &mockDialConn{}, nil
}

type mockDialConn struct {
net.Conn
}

func (c *mockDialConn) Close() error {
return nil
}
Loading

0 comments on commit ecd78ee

Please sign in to comment.