From e44efcd98365897d132e18f754c79c1601c8cbce Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Thu, 28 Apr 2022 16:04:46 -0300 Subject: [PATCH] Dial only application servers that serve the requested application (#12217) (#12300) * feat(reversetunnel): add a dial counter on fake remote site * refactor(app): change `MatchHealthy` to close connections * refactor(app): change `MatchAll` order to not dial all servers * refactor(app): only close connection when dial succeeds --- lib/reversetunnel/fake.go | 9 ++ lib/web/app/handler_test.go | 238 +++++++++++++++++++++++++++++++++--- lib/web/app/match.go | 9 +- lib/web/app/match_test.go | 14 ++- lib/web/app/session.go | 7 +- 5 files changed, 257 insertions(+), 20 deletions(-) diff --git a/lib/reversetunnel/fake.go b/lib/reversetunnel/fake.go index ced1352d6f2c9..536ccaca3cf58 100644 --- a/lib/reversetunnel/fake.go +++ b/lib/reversetunnel/fake.go @@ -18,6 +18,7 @@ package reversetunnel import ( "net" + "sync/atomic" "github.com/gravitational/teleport/lib/auth" @@ -57,6 +58,8 @@ type FakeRemoteSite struct { AccessPoint auth.RemoteProxyAccessPoint // OfflineTunnels is a list of server IDs that will return connection error. OfflineTunnels map[string]struct{} + // connCounter count how many connection requests the remote received. + connCounter int64 } // CachingAccessPoint returns caching auth server client. @@ -71,6 +74,8 @@ func (s *FakeRemoteSite) GetName() string { // Dial returns the connection to the remote site. func (s *FakeRemoteSite) Dial(params DialParams) (net.Conn, error) { + atomic.AddInt64(&s.connCounter, 1) + if _, ok := s.OfflineTunnels[params.ServerID]; ok { return nil, trace.ConnectionProblem(nil, "server %v tunnel is offline", params.ServerID) @@ -83,3 +88,7 @@ func (s *FakeRemoteSite) Dial(params DialParams) (net.Conn, error) { func (s *FakeRemoteSite) Close() error { return nil } + +func (s *FakeRemoteSite) DialCount() int64 { + return atomic.LoadInt64(&s.connCounter) +} diff --git a/lib/web/app/handler_test.go b/lib/web/app/handler_test.go index d04d565c5ba30..613c8d098e684 100644 --- a/lib/web/app/handler_test.go +++ b/lib/web/app/handler_test.go @@ -20,7 +20,11 @@ import ( "bytes" "context" "crypto/tls" + "crypto/x509/pkix" "encoding/json" + "fmt" + "io" + "net" "net/http" "net/http/httptest" "net/url" @@ -28,10 +32,17 @@ import ( "time" "github.com/gravitational/teleport/api/types" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -45,6 +56,7 @@ func TestAuthPOST(t *testing.T) { cookieValue = "5588e2be54a2834b4f152c56bafcd789f53b15477129d2ab4044e9a3c1bf0f3b" ) + fakeClock := clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC)) tests := []struct { desc string stateInRequest string @@ -77,7 +89,7 @@ func TestAuthPOST(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - p := setup(t, test.sessionError) + p := setup(t, fakeClock, mockAuthClient{sessionError: test.sessionError}, nil) req, err := json.Marshal(fragmentRequest{ StateValue: test.stateInRequest, @@ -85,25 +97,97 @@ func TestAuthPOST(t *testing.T) { }) require.NoError(t, err) - status := p.makeRequest(t, "POST", "/x-teleport-auth", test.stateInCookie, req) + status, _ := p.makeRequest(t, "POST", "/x-teleport-auth", AuthStateCookieName, test.stateInCookie, req) require.Equal(t, test.outStatusCode, status) }) } } -type testServer struct { - serverURL *url.URL -} +func TestMatchApplicationServers(t *testing.T) { + clusterName := "test-cluster" + publicAddr := "app.example.com" + + // Generate CA TLS key and cert with the cluster and application DNS. + key, cert, err := tlsca.GenerateSelfSignedCA( + pkix.Name{CommonName: clusterName}, + []string{publicAddr, apiutils.EncodeClusterName(clusterName)}, + defaults.CATTL, + ) + require.NoError(t, err) -func setup(t *testing.T, sessionError error) *testServer { fakeClock := clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC)) authClient := mockAuthClient{ - sessionError: sessionError, + clusterName: clusterName, + appSession: createAppSession(t, fakeClock, key, cert, clusterName, publicAddr), + // Three app servers with same public addr from our session, and three + // that won't match. + appServers: []types.AppServer{ + createAppServer(t, publicAddr), + createAppServer(t, publicAddr), + createAppServer(t, publicAddr), + createAppServer(t, "random.example.com"), + createAppServer(t, "random2.example.com"), + createAppServer(t, "random3.example.com"), + }, + caKey: key, + caCert: cert, + } + + // Create a fake remote site and tunnel. + fakeRemoteSiteConnCh := make(chan net.Conn) + fakeRemoteSite := &reversetunnel.FakeRemoteSite{ + Name: clusterName, + ConnCh: fakeRemoteSiteConnCh, + AccessPoint: authClient, } + tunnel := &reversetunnel.FakeServer{ + Sites: []reversetunnel.RemoteSite{ + fakeRemoteSite, + }, + } + + // Create a httptest server to serve the application requests. It must serve + // TLS content with the generated certificate. + tlsCert, err := tls.X509KeyPair(cert, key) + require.NoError(t, err) + expectedContent := "Hello from application" + server := &httptest.Server{ + TLS: &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + }, + Listener: &fakeRemoteListener{fakeRemoteSite}, + Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprint(w, expectedContent) + })}, + } + server.StartTLS() + + // Teardown the remote site and the httptest server. + t.Cleanup(func() { + close(fakeRemoteSiteConnCh) + server.Close() + }) + + p := setup(t, fakeClock, authClient, tunnel) + status, content := p.makeRequest(t, "GET", "/", CookieName, "abc", []byte{}) + require.Equal(t, http.StatusOK, status) + // Remote site should receive only 4 connection requests: 3 from the + // MatchHealthy and 1 from the transport. + require.Equal(t, int64(4), fakeRemoteSite.DialCount()) + // Guarantee the request was returned by the httptest server. + require.Equal(t, expectedContent, content) +} + +type testServer struct { + serverURL *url.URL +} + +func setup(t *testing.T, clock clockwork.FakeClock, authClient auth.ClientI, proxyClient reversetunnel.Tunnel) *testServer { appHandler, err := NewHandler(context.Background(), &HandlerConfig{ - Clock: fakeClock, + Clock: clock, AuthClient: authClient, AccessPoint: authClient, + ProxyClient: proxyClient, CipherSuites: utils.DefaultCipherSuites(), }) require.NoError(t, err) @@ -119,7 +203,7 @@ func setup(t *testing.T, sessionError error) *testServer { } } -func (p *testServer) makeRequest(t *testing.T, method, endpoint, stateInCookie string, reqBody []byte) int { +func (p *testServer) makeRequest(t *testing.T, method, endpoint, cookieName, cookieValue string, reqBody []byte) (int, string) { u := url.URL{ Scheme: p.serverURL.Scheme, Host: p.serverURL.Host, @@ -128,10 +212,10 @@ func (p *testServer) makeRequest(t *testing.T, method, endpoint, stateInCookie s req, err := http.NewRequest(http.MethodPost, u.String(), bytes.NewBuffer(reqBody)) require.NoError(t, err) - // Attach state token cookie. + // Attach the cookie. req.AddCookie(&http.Cookie{ - Name: AuthStateCookieName, - Value: stateInCookie, + Name: cookieName, + Value: cookieValue, }) // Issue request. @@ -145,29 +229,151 @@ func (p *testServer) makeRequest(t *testing.T, method, endpoint, stateInCookie s return http.ErrUseLastResponse }, } + resp, err := client.Do(req) require.NoError(t, err) + + content, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) - return resp.StatusCode + return resp.StatusCode, string(content) } type mockAuthClient struct { auth.ClientI + clusterName string + appSession types.WebSession sessionError error + appServers []types.AppServer + caKey []byte + caCert []byte } type mockClusterName struct { types.ClusterName + name string } -func (c mockAuthClient) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { - return mockClusterName{}, nil +func (c mockAuthClient) GetClusterName(_ ...services.MarshalOption) (types.ClusterName, error) { + return mockClusterName{name: c.clusterName}, nil } func (n mockClusterName) GetClusterName() string { + if n.name != "" { + return n.name + } + return "local-cluster" } func (c mockAuthClient) GetAppSession(context.Context, types.GetAppSessionRequest) (types.WebSession, error) { - return nil, c.sessionError + return c.appSession, c.sessionError +} + +func (c mockAuthClient) GetApplicationServers(_ context.Context, _ string) ([]types.AppServer, error) { + return c.appServers, nil +} + +func (c mockAuthClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error) { + ca, err := types.NewCertAuthority(types.CertAuthoritySpecV2{ + Type: types.HostCA, + ClusterName: c.clusterName, + ActiveKeys: types.CAKeySet{ + TLS: []*types.TLSKeyPair{{ + Cert: c.caCert, + Key: c.caKey, + }}, + }, + }) + if err != nil { + return nil, err + } + + return ca, nil +} + +// fakeRemoteListener Implements a `net.Listener` that return `net.Conn` from +// the `FakeRemoteSite`. +type fakeRemoteListener struct { + fakeRemote *reversetunnel.FakeRemoteSite +} + +func (r *fakeRemoteListener) Accept() (net.Conn, error) { + conn, ok := <-r.fakeRemote.ConnCh + if !ok { + return nil, fmt.Errorf("remote closed") + } + + return conn, nil + +} + +func (r *fakeRemoteListener) Close() error { + return nil +} + +func (r *fakeRemoteListener) Addr() net.Addr { + return &net.IPAddr{} +} + +// createAppSession generates a WebSession for an application. +func createAppSession(t *testing.T, clock clockwork.FakeClock, caKey, caCert []byte, clusterName, publicAddr string) types.WebSession { + tlsCA, err := tlsca.FromKeys(caCert, caKey) + require.NoError(t, err) + + // Generate the identity with a `RouteToApp` option. + subj, err := (&tlsca.Identity{ + Username: "testuser", + Groups: []string{"access"}, + RouteToApp: tlsca.RouteToApp{ + PublicAddr: publicAddr, + ClusterName: clusterName, + Name: "testapp", + }, + }).Subject() + require.NoError(t, err) + + // Generate public and private keys for the application request certificate. + priv, pub, err := testauthority.New().GetNewKeyPairFromPool() + require.NoError(t, err) + cryptoPubKey, err := sshutils.CryptoPublicKey(pub) + require.NoError(t, err) + + cert, err := tlsCA.GenerateCertificate(tlsca.CertificateRequest{ + Clock: clock, + PublicKey: cryptoPubKey, + Subject: subj, + NotAfter: clock.Now().Add(5 * time.Minute), + }) + require.NoError(t, err) + + appSession, err := types.NewWebSession(uuid.New().String(), types.KindAppSession, types.WebSessionSpecV2{ + User: "testuser", + Priv: priv, + TLSCert: cert, + Expires: clock.Now().Add(5 * time.Minute), + }) + require.NoError(t, err) + + return appSession +} + +func createAppServer(t *testing.T, publicAddr string) types.AppServer { + appName := uuid.New().String() + appServer, err := types.NewAppServerV3( + types.Metadata{Name: appName}, + types.AppServerSpecV3{ + HostID: uuid.New().String(), + App: &types.AppV3{ + Metadata: types.Metadata{Name: appName}, + Spec: types.AppSpecV3{ + URI: "localhost", + PublicAddr: publicAddr, + }, + }, + }, + ) + require.NoError(t, err) + return appServer } diff --git a/lib/web/app/match.go b/lib/web/app/match.go index 1be2446f34a90..968ee98acc0ae 100644 --- a/lib/web/app/match.go +++ b/lib/web/app/match.go @@ -85,8 +85,13 @@ func MatchName(name string) Matcher { // doesn't return any error. func MatchHealthy(proxyClient reversetunnel.Tunnel, identity *tlsca.Identity) Matcher { return func(appServer types.AppServer) bool { - _, err := dialAppServer(proxyClient, identity, appServer) - return err == nil + conn, err := dialAppServer(proxyClient, identity, appServer) + if err != nil { + return false + } + + conn.Close() + return true } } diff --git a/lib/web/app/match_test.go b/lib/web/app/match_test.go index fb2eda2fe416b..b4045863e35e3 100644 --- a/lib/web/app/match_test.go +++ b/lib/web/app/match_test.go @@ -93,5 +93,17 @@ type mockRemoteSite struct { } func (r *mockRemoteSite) Dial(_ reversetunnel.DialParams) (net.Conn, error) { - return nil, r.dialErr + if r.dialErr != nil { + return nil, r.dialErr + } + + return &mockDialConn{}, nil +} + +type mockDialConn struct { + net.Conn +} + +func (c *mockDialConn) Close() error { + return nil } diff --git a/lib/web/app/session.go b/lib/web/app/session.go index 736ce113dfd5f..49b8fe954a579 100644 --- a/lib/web/app/session.go +++ b/lib/web/app/session.go @@ -69,7 +69,12 @@ func (h *Handler) newSession(ctx context.Context, ws types.WebSession) (*session // server (in cases where there are no healthy servers). This process might // take an additional time to execute, but since it is cached, only a few // requests need to perform it. - servers, err := Match(ctx, accessPoint, MatchAll(MatchHealthy(h.c.ProxyClient, identity), MatchPublicAddr(identity.RouteToApp.PublicAddr))) + servers, err := Match(ctx, accessPoint, MatchAll( + MatchPublicAddr(identity.RouteToApp.PublicAddr), + // NOTE: Try to leave this matcher as the last one to dial only the + // application servers that match the requested application. + MatchHealthy(h.c.ProxyClient, identity), + )) if err != nil { return nil, trace.Wrap(err) }