Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split auth.AccessPoint into variant specific interfaces #8471

Merged
merged 3 commits into from
Nov 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
909 changes: 728 additions & 181 deletions lib/auth/api.go

Large diffs are not rendered by default.

28 changes: 13 additions & 15 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2644,32 +2644,32 @@ func (a *Server) GetToken(ctx context.Context, token string) (types.ProvisionTok
return a.GetCache().GetToken(ctx, token)
}

// GetRoles is a part of auth.AccessPoint implementation
// GetRoles returns roles from the cache
func (a *Server) GetRoles(ctx context.Context) ([]types.Role, error) {
return a.GetCache().GetRoles(ctx)
}

// GetRole is a part of auth.AccessPoint implementation
// GetRole returns a role from the cache
func (a *Server) GetRole(ctx context.Context, name string) (types.Role, error) {
return a.GetCache().GetRole(ctx, name)
}

// GetNamespace returns namespace
// GetNamespace returns a namespace from the cache
func (a *Server) GetNamespace(name string) (*types.Namespace, error) {
return a.GetCache().GetNamespace(name)
}

// GetNamespaces is a part of auth.AccessPoint implementation
// GetNamespaces returns namespaces from the cache
func (a *Server) GetNamespaces() ([]types.Namespace, error) {
return a.GetCache().GetNamespaces()
}

// GetNodes is a part of auth.AccessPoint implementation
// GetNodes returns nodes from the cache
func (a *Server) GetNodes(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.Server, error) {
return a.GetCache().GetNodes(ctx, namespace, opts...)
}

// ListNodes is a part of auth.AccessPoint implementation
// ListNodes lists nodes from the cache
func (a *Server) ListNodes(ctx context.Context, req proto.ListNodesRequest) ([]types.Server, string, error) {
return a.GetCache().ListNodes(ctx, req)
}
Expand Down Expand Up @@ -2700,34 +2700,32 @@ func (a *Server) IterateNodePages(ctx context.Context, req proto.ListNodesReques
}
}

// GetReverseTunnels is a part of auth.AccessPoint implementation
// GetReverseTunnels returns reverse tunnels from the cache
func (a *Server) GetReverseTunnels(opts ...services.MarshalOption) ([]types.ReverseTunnel, error) {
return a.GetCache().GetReverseTunnels(opts...)
}

// GetProxies is a part of auth.AccessPoint implementation
// GetProxies returns proxies from the cache
func (a *Server) GetProxies() ([]types.Server, error) {
return a.GetCache().GetProxies()
}

// GetUser is a part of auth.AccessPoint implementation.
// GetUser returns a user from the cache
func (a *Server) GetUser(name string, withSecrets bool) (user types.User, err error) {
return a.GetCache().GetUser(name, withSecrets)
}

// GetUsers is a part of auth.AccessPoint implementation
// GetUsers returns users from the cache
func (a *Server) GetUsers(withSecrets bool) (users []types.User, err error) {
return a.GetCache().GetUsers(withSecrets)
}

// GetTunnelConnections is a part of auth.AccessPoint implementation
// GetTunnelConnections are not using recent cache as they are designed
// to be called periodically and always return fresh data
func (a *Server) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) {
return a.GetCache().GetTunnelConnections(clusterName, opts...)
}

// GetAllTunnelConnections is a part of auth.AccessPoint implementation
// GetAllTunnelConnections are not using recent cache, as they are designed
// to be called periodically and always return fresh data
func (a *Server) GetAllTunnelConnections(opts ...services.MarshalOption) (conns []types.TunnelConnection, err error) {
Expand Down Expand Up @@ -2770,12 +2768,12 @@ func (a *Server) modeStreamer(ctx context.Context) (events.Streamer, error) {
return a.streamer, nil
}

// GetAppServers is a part of the auth.AccessPoint implementation.
// GetAppServers returns app servers from the cache
func (a *Server) GetAppServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.Server, error) {
return a.GetCache().GetAppServers(ctx, namespace, opts...)
}

// GetAppSession is a part of the auth.AccessPoint implementation.
// GetAppSession returns app sessions from the cache
func (a *Server) GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) {
return a.GetCache().GetAppSession(ctx, req)
}
Expand Down Expand Up @@ -3663,7 +3661,7 @@ func isHTTPS(u string) error {

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
// TLS config with client CAs pool of the specified cluster.
func WithClusterCAs(tlsConfig *tls.Config, ap AccessPoint, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
var clusterName string
var err error
Expand Down
37 changes: 32 additions & 5 deletions lib/auth/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewBuiltinRoleContext(role types.SystemRole) (*Context, error) {
}

// NewAuthorizer returns new authorizer using backends
func NewAuthorizer(clusterName string, accessPoint ReadAccessPoint, lockWatcher *services.LockWatcher) (Authorizer, error) {
func NewAuthorizer(clusterName string, accessPoint AuthorizerAccessPoint, lockWatcher *services.LockWatcher) (Authorizer, error) {
if clusterName == "" {
return nil, trace.BadParameter("missing parameter clusterName")
}
Expand All @@ -68,16 +68,43 @@ type Authorizer interface {
Authorize(ctx context.Context) (*Context, error)
}

// AuthorizerAccessPoint is the access point contract required by an Authorizer
type AuthorizerAccessPoint interface {
// GetAuthPreference returns the cluster authentication configuration.
GetAuthPreference(ctx context.Context) (types.AuthPreference, error)

// GetRole returns role by name
GetRole(ctx context.Context, name string) (types.Role, error)

// GetUser returns a services.User for this cluster.
GetUser(name string, withSecrets bool) (types.User, error)

// GetCertAuthority returns cert authority by id
GetCertAuthority(id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error)

// GetCertAuthorities returns a list of cert authorities
GetCertAuthorities(caType types.CertAuthType, loadKeys bool, opts ...services.MarshalOption) ([]types.CertAuthority, error)

// GetClusterAuditConfig returns cluster audit configuration.
GetClusterAuditConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterAuditConfig, error)

// GetClusterNetworkingConfig returns cluster networking configuration.
GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error)

// GetSessionRecordingConfig returns session recording configuration.
GetSessionRecordingConfig(ctx context.Context, opts ...services.MarshalOption) (types.SessionRecordingConfig, error)
}

// authorizer creates new local authorizer
type authorizer struct {
clusterName string
accessPoint ReadAccessPoint
accessPoint AuthorizerAccessPoint
lockWatcher *services.LockWatcher
}

// Context is authorization context
type Context struct {
// User is the user name
// User is the username
User types.User
// Checker is access checker
Checker services.AccessChecker
Expand Down Expand Up @@ -669,7 +696,7 @@ func contextForBuiltinRole(r BuiltinRole, recConfig types.SessionRecordingConfig
}, nil
}

func contextForLocalUser(u LocalUser, accessPoint ReadAccessPoint) (*Context, error) {
func contextForLocalUser(u LocalUser, accessPoint AuthorizerAccessPoint) (*Context, error) {
// User has to be fetched to check if it's a blocked username
user, err := accessPoint.GetUser(u.Username, false)
if err != nil {
Expand All @@ -684,7 +711,7 @@ func contextForLocalUser(u LocalUser, accessPoint ReadAccessPoint) (*Context, er
return nil, trace.Wrap(err)
}
// Override roles and traits from the local user based on the identity roles
// and traits, this is done to prevent potential conflict. Imagine a scenairo
// and traits, this is done to prevent potential conflict. Imagine a scenario
// when SSO user has left the company, but local user entry remained with old
// privileged roles. New user with the same name has been onboarded and would
// have derived the roles from the stale user entry. This code prevents
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (s *Server) CreateAppSession(ctx context.Context, req types.CreateAppSessio

// WaitForAppSession will block until the requested application session shows up in the
// cache or a timeout occurs.
func WaitForAppSession(ctx context.Context, sessionID, user string, ap AccessPoint) error {
func WaitForAppSession(ctx context.Context, sessionID, user string, ap ReadProxyAccessPoint) error {
_, err := ap.GetAppSession(ctx, types.GetAppSessionRequest{SessionID: sessionID})
if err == nil {
return nil
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ func (s *TLSSuite) TestWebSessionWithApprovedAccessRequestAndSwitchback(c *check
c.Assert(err, check.IsNil)

// Roles extracted from cert should contain the initial role and the role assigned with access request.
roles, _, err := services.ExtractFromCertificate(clt, sshcert)
roles, _, err := services.ExtractFromCertificate(sshcert)
c.Assert(err, check.IsNil)
c.Assert(roles, check.HasLen, 2)

Expand Down Expand Up @@ -1560,7 +1560,7 @@ func (s *TLSSuite) TestWebSessionWithApprovedAccessRequestAndSwitchback(c *check
sshcert, err = sshutils.ParseCertificate(sess2.GetPub())
c.Assert(err, check.IsNil)

roles, _, err = services.ExtractFromCertificate(clt, sshcert)
roles, _, err = services.ExtractFromCertificate(sshcert)
c.Assert(err, check.IsNil)
c.Assert(roles, check.DeepEquals, []string{initialRole})
}
Expand Down
31 changes: 13 additions & 18 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ func ForNode(cfg Config) Config {
{Kind: types.KindClusterNetworkingConfig},
{Kind: types.KindClusterAuthPreference},
{Kind: types.KindSessionRecordingConfig},
{Kind: types.KindUser},
{Kind: types.KindRole},
// Node only needs to "know" about default
// namespace events to avoid matching too much
Expand Down Expand Up @@ -282,7 +281,7 @@ func ForWindowsDesktop(cfg Config) Config {
// for cache
type SetupConfigFn func(c Config) Config

// Cache implements auth.AccessPoint interface and remembers
// Cache implements auth.Cache interface and remembers
// the previously returned upstream value for each API call.
//
// This which can be used if the upstream AccessPoint goes offline
Expand Down Expand Up @@ -1197,7 +1196,7 @@ func (c *Cache) GetClusterName(opts ...services.MarshalOption) (types.ClusterNam
return rg.clusterConfig.GetClusterName(opts...)
}

// GetRoles is a part of auth.AccessPoint implementation
// GetRoles is a part of auth.Cache implementation
func (c *Cache) GetRoles(ctx context.Context) ([]types.Role, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1207,7 +1206,7 @@ func (c *Cache) GetRoles(ctx context.Context) ([]types.Role, error) {
return rg.access.GetRoles(ctx)
}

// GetRole is a part of auth.AccessPoint implementation
// GetRole is a part of auth.Cache implementation
func (c *Cache) GetRole(ctx context.Context, name string) (types.Role, error) {
rg, err := c.read()
if err != nil {
Expand Down Expand Up @@ -1237,7 +1236,7 @@ func (c *Cache) GetNamespace(name string) (*types.Namespace, error) {
return rg.presence.GetNamespace(name)
}

// GetNamespaces is a part of auth.AccessPoint implementation
// GetNamespaces is a part of auth.Cache implementation
func (c *Cache) GetNamespaces() ([]types.Namespace, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1263,7 +1262,7 @@ type getNodesCacheKey struct {

var _ map[getNodesCacheKey]struct{} // compile-time hashability check

// GetNodes is a part of auth.AccessPoint implementation
// GetNodes is a part of auth.Cache implementation
func (c *Cache) GetNodes(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.Server, error) {
rg, err := c.read()
if err != nil {
Expand Down Expand Up @@ -1295,7 +1294,7 @@ func (c *Cache) GetNodes(ctx context.Context, namespace string, opts ...services
return rg.presence.GetNodes(ctx, namespace, opts...)
}

// ListNodes is a part of auth.AccessPoint implementation
// ListNodes is a part of auth.Cache implementation
func (c *Cache) ListNodes(ctx context.Context, req proto.ListNodesRequest) ([]types.Server, string, error) {
// NOTE: we "fake" the ListNodes API here in order to take advantage of TTL-based caching of
// the GetNodes endpoint, since performing TTL-based caching on a paginated endpoint is nightmarish.
Expand Down Expand Up @@ -1353,7 +1352,7 @@ func (c *Cache) GetAuthServers() ([]types.Server, error) {
return rg.presence.GetAuthServers()
}

// GetReverseTunnels is a part of auth.AccessPoint implementation
// GetReverseTunnels is a part of auth.Cache implementation
func (c *Cache) GetReverseTunnels(opts ...services.MarshalOption) ([]types.ReverseTunnel, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1363,7 +1362,7 @@ func (c *Cache) GetReverseTunnels(opts ...services.MarshalOption) ([]types.Rever
return rg.presence.GetReverseTunnels(opts...)
}

// GetProxies is a part of auth.AccessPoint implementation
// GetProxies is a part of auth.Cache implementation
func (c *Cache) GetProxies() ([]types.Server, error) {
rg, err := c.read()
if err != nil {
Expand Down Expand Up @@ -1431,7 +1430,7 @@ func (c *Cache) GetRemoteCluster(clusterName string) (types.RemoteCluster, error
return rg.presence.GetRemoteCluster(clusterName)
}

// GetUser is a part of auth.AccessPoint implementation.
// GetUser is a part of auth.Cache implementation.
func (c *Cache) GetUser(name string, withSecrets bool) (user types.User, err error) {
if withSecrets { // cache never tracks user secrets
return c.Config.Users.GetUser(name, withSecrets)
Expand All @@ -1455,7 +1454,7 @@ func (c *Cache) GetUser(name string, withSecrets bool) (user types.User, err err
return user, trace.Wrap(err)
}

// GetUsers is a part of auth.AccessPoint implementation
// GetUsers is a part of auth.Cache implementation
func (c *Cache) GetUsers(withSecrets bool) (users []types.User, err error) {
if withSecrets { // cache never tracks user secrets
return c.Users.GetUsers(withSecrets)
Expand All @@ -1468,9 +1467,7 @@ func (c *Cache) GetUsers(withSecrets bool) (users []types.User, err error) {
return rg.users.GetUsers(withSecrets)
}

// GetTunnelConnections is a part of auth.AccessPoint implementation
// GetTunnelConnections are not using recent cache as they are designed
// to be called periodically and always return fresh data
// GetTunnelConnections is a part of auth.Cache implementation
func (c *Cache) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1480,9 +1477,7 @@ func (c *Cache) GetTunnelConnections(clusterName string, opts ...services.Marsha
return rg.presence.GetTunnelConnections(clusterName, opts...)
}

// GetAllTunnelConnections is a part of auth.AccessPoint implementation
// GetAllTunnelConnections are not using recent cache, as they are designed
// to be called periodically and always return fresh data
// GetAllTunnelConnections is a part of auth.Cache implementation
func (c *Cache) GetAllTunnelConnections(opts ...services.MarshalOption) (conns []types.TunnelConnection, err error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1492,7 +1487,7 @@ func (c *Cache) GetAllTunnelConnections(opts ...services.MarshalOption) (conns [
return rg.presence.GetAllTunnelConnections(opts...)
}

// GetKubeServices is a part of auth.AccessPoint implementation
// GetKubeServices is a part of auth.Cache implementation
func (c *Cache) GetKubeServices(ctx context.Context) ([]types.Server, error) {
rg, err := c.read()
if err != nil {
Expand Down
12 changes: 0 additions & 12 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1167,18 +1167,6 @@ func (tc *TeleportClient) LoadKeyForClusterWithReissue(ctx context.Context, clus
return nil
}

// accessPoint returns access point based on the cache policy
func (tc *TeleportClient) accessPoint(clt auth.AccessPoint, proxyHostPort string, clusterName string) (auth.AccessPoint, error) {
// If no caching policy was set or on Windows (where Teleport does not
// support file locking at the moment), return direct access to the access
// point.
if tc.CachePolicy == nil || runtime.GOOS == constants.WindowsOS {
log.Debugf("not using caching access point")
return clt, nil
}
return clt, nil
}

// LocalAgent is a getter function for the client's local agent
func (tc *TeleportClient) LocalAgent() *LocalKeyAgent {
return tc.localAgent
Expand Down
6 changes: 3 additions & 3 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ func (proxy *ProxyClient) GetDatabaseServers(ctx context.Context, namespace stri
// CurrentClusterAccessPoint returns cluster access point to the currently
// selected cluster and is used for discovery
// and could be cached based on the access policy
func (proxy *ProxyClient) CurrentClusterAccessPoint(ctx context.Context, quiet bool) (auth.AccessPoint, error) {
func (proxy *ProxyClient) CurrentClusterAccessPoint(ctx context.Context, quiet bool) (auth.ClientI, error) {
// get the current cluster:
cluster, err := proxy.currentCluster()
if err != nil {
Expand All @@ -635,15 +635,15 @@ func (proxy *ProxyClient) CurrentClusterAccessPoint(ctx context.Context, quiet b

// ClusterAccessPoint returns cluster access point used for discovery
// and could be cached based on the access policy
func (proxy *ProxyClient) ClusterAccessPoint(ctx context.Context, clusterName string, quiet bool) (auth.AccessPoint, error) {
func (proxy *ProxyClient) ClusterAccessPoint(ctx context.Context, clusterName string, quiet bool) (auth.ClientI, error) {
if clusterName == "" {
return nil, trace.BadParameter("parameter clusterName is missing")
}
clt, err := proxy.ConnectToCluster(ctx, clusterName, quiet)
if err != nil {
return nil, trace.Wrap(err)
}
return proxy.teleportClient.accessPoint(clt, proxy.proxyAddress, clusterName)
return clt, nil
}

// ConnectToCurrentCluster connects to the auth server of the currently selected
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ type ForwarderConfig struct {
// AuthClient is a auth server client.
AuthClient auth.ClientI
// CachingAuthClient is a caching auth server client for read-only access.
CachingAuthClient auth.AccessPoint
CachingAuthClient auth.ReadKubernetesAccessPoint
// StreamEmitter is used to create audit streams
// and emit audit events
StreamEmitter events.StreamEmitter
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ type mockRemoteSite struct {
func (s mockRemoteSite) GetName() string { return s.name }

type mockAccessPoint struct {
auth.AccessPoint
auth.KubernetesAccessPoint

netConfig types.ClusterNetworkingConfig
recordingConfig types.SessionRecordingConfig
Expand Down
Loading