From b736126e38f316110f19c667507586be905709cb Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Thu, 23 Jun 2022 16:57:51 +0200 Subject: [PATCH 1/2] Open a new remote client when the remote site has changed in a web session --- lib/web/sessions.go | 94 ++++++++++++++++++++++++++++++--------------- 1 file changed, 62 insertions(+), 32 deletions(-) diff --git a/lib/web/sessions.go b/lib/web/sessions.go index 152dfb9220721..b5d4d4a51e5e7 100644 --- a/lib/web/sessions.go +++ b/lib/web/sessions.go @@ -61,6 +61,9 @@ type SessionContext struct { // clt holds a connection to the root auth. Note that requests made using this // client are made with the identity of the user and are NOT cached. clt *auth.Client + // remoteClientCache holds the remote clients that have been used in this + // session. + remoteClientCache // unsafeCachedAuthClient holds a read-only cache to root auth. Note this access // point cache is authenticated with the identity of the node, not of the @@ -77,9 +80,6 @@ type SessionContext struct { resources *sessionResources // session refers the web session created for the user. session types.WebSession - - mu sync.Mutex - remoteClt map[string]auth.ClientI } // String returns the text representation of this context @@ -127,19 +127,6 @@ func (c *SessionContext) validateBearerToken(ctx context.Context, token string) return nil } -func (c *SessionContext) addRemoteClient(siteName string, remoteClient auth.ClientI) { - c.mu.Lock() - defer c.mu.Unlock() - c.remoteClt[siteName] = remoteClient -} - -func (c *SessionContext) getRemoteClient(siteName string) (auth.ClientI, bool) { - c.mu.Lock() - defer c.mu.Unlock() - remoteClt, ok := c.remoteClt[siteName] - return remoteClt, ok -} - // GetClient returns the client connected to the auth server func (c *SessionContext) GetClient() (auth.ClientI, error) { return c.clt, nil @@ -167,7 +154,7 @@ func (c *SessionContext) GetUserClient(site reversetunnel.RemoteSite) (auth.Clie } // check if we already have a connection to this cluster - remoteClt, ok := c.getRemoteClient(site.GetName()) + remoteClt, ok := c.getRemoteClient(site) if !ok { rClt, err := c.newRemoteClient(site) if err != nil { @@ -177,7 +164,10 @@ func (c *SessionContext) GetUserClient(site reversetunnel.RemoteSite) (auth.Clie // we'll save the remote client in our session context so we don't have to // build a new connection next time. all remote clients will be closed when // the session context is closed. - c.addRemoteClient(site.GetName(), rClt) + err = c.addRemoteClient(site, rClt) + if err != nil { + c.log.WithError(err).Info("Failed closing stale remote client for site: ", site.GetName()) + } return rClt, nil } @@ -211,7 +201,7 @@ func (c *SessionContext) tryRemoteTLSClient(cluster reversetunnel.RemoteSite) (a } _, err = clt.GetDomainName() if err != nil { - return clt, trace.Wrap(err) + return nil, trace.NewAggregate(err, clt.Close()) } return clt, nil } @@ -394,18 +384,7 @@ func (c *SessionContext) GetSessionID() string { // Close cleans up resources associated with this context and removes it // from the user context func (c *SessionContext) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - var errors []error - for _, clt := range c.remoteClt { - if err := clt.Close(); err != nil { - errors = append(errors, err) - } - } - if err := c.clt.Close(); err != nil { - errors = append(errors, err) - } - return trace.NewAggregate(errors...) + return trace.NewAggregate(c.remoteClientCache.Close(), c.clt.Close()) } // getToken returns the bearer token associated with the underlying @@ -825,7 +804,6 @@ func (s *sessionCache) newSessionContextFromSession(session types.WebSession) (* ctx := &SessionContext{ clt: userClient, unsafeCachedAuthClient: s.accessPoint, - remoteClt: make(map[string]auth.ClientI), user: session.GetUser(), session: session, parent: s, @@ -986,3 +964,55 @@ func (h *Handler) waitForWebSession(ctx context.Context, req types.GetWebSession } return trace.Wrap(err) } + +// remoteClientCache stores remote clients keyed by site name while also keeping +// track of the actual remote site associated with the client (in case the +// remote site has changed). Safe for concurrent access. Closes all clients and +// wipes the cache on Close. +type remoteClientCache struct { + sync.Mutex + clients map[string]struct { + auth.ClientI + reversetunnel.RemoteSite + } +} + +func (c *remoteClientCache) addRemoteClient(site reversetunnel.RemoteSite, remoteClient auth.ClientI) error { + c.Lock() + defer c.Unlock() + if c.clients == nil { + c.clients = make(map[string]struct { + auth.ClientI + reversetunnel.RemoteSite + }) + } + var err error + if c.clients[site.GetName()].ClientI != nil { + err = c.clients[site.GetName()].ClientI.Close() + } + c.clients[site.GetName()] = struct { + auth.ClientI + reversetunnel.RemoteSite + }{remoteClient, site} + return err +} + +func (c *remoteClientCache) getRemoteClient(site reversetunnel.RemoteSite) (auth.ClientI, bool) { + c.Lock() + defer c.Unlock() + remoteClt, ok := c.clients[site.GetName()] + return remoteClt.ClientI, ok && remoteClt.RemoteSite == site +} + +func (c *remoteClientCache) Close() error { + c.Lock() + defer c.Unlock() + + errors := make([]error, 0, len(c.clients)) + for _, clt := range c.clients { + errors = append(errors, clt.ClientI.Close()) + } + c.clients = nil + + return trace.NewAggregate(errors...) +} From c394d7b74f163eadd8ac31c60872fd0efbd30eea Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Tue, 28 Jun 2022 13:42:15 +0200 Subject: [PATCH 2/2] Test coverage for remoteClientCache --- lib/web/sessions_test.go | 85 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 lib/web/sessions_test.go diff --git a/lib/web/sessions_test.go b/lib/web/sessions_test.go new file mode 100644 index 0000000000000..419401e5eeefc --- /dev/null +++ b/lib/web/sessions_test.go @@ -0,0 +1,85 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package web + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/reversetunnel" +) + +func TestRemoteClientCache(t *testing.T) { + t.Parallel() + + openCount := 0 + cache := remoteClientCache{} + + sa1 := newMockRemoteSite("a") + sa2 := newMockRemoteSite("a") + sb := newMockRemoteSite("b") + + err1 := errors.New("c1") + err2 := errors.New("c2") + + require.NoError(t, cache.addRemoteClient(sa1, newMockClientI(&openCount, err1))) + require.Equal(t, 1, openCount) + + require.ErrorIs(t, cache.addRemoteClient(sa2, newMockClientI(&openCount, nil)), err1) + require.Equal(t, 1, openCount) + + require.NoError(t, cache.addRemoteClient(sb, newMockClientI(&openCount, err2))) + require.Equal(t, 2, openCount) + + var aggrErr trace.Aggregate + require.ErrorAs(t, cache.Close(), &aggrErr) + require.ElementsMatch(t, []error{err2}, aggrErr.Errors()) + + require.Zero(t, openCount) +} + +func newMockRemoteSite(name string) reversetunnel.RemoteSite { + return &mockRemoteSite{name: name} +} + +type mockRemoteSite struct { + reversetunnel.RemoteSite + name string +} + +func (m *mockRemoteSite) GetName() string { + return m.name +} + +func newMockClientI(openCount *int, closeErr error) auth.ClientI { + *openCount++ + return &mockClientI{openCount: openCount, closeErr: closeErr} +} + +type mockClientI struct { + auth.ClientI + openCount *int + closeErr error +} + +func (m *mockClientI) Close() error { + *m.openCount-- + return m.closeErr +}