diff --git a/lib/cache/cache.go b/lib/cache/cache.go index 2477b67b5d81d..9bf30b21e0820 100644 --- a/lib/cache/cache.go +++ b/lib/cache/cache.go @@ -48,6 +48,7 @@ func ForAuth(cfg Config) Config { {Kind: services.KindProxy}, {Kind: services.KindReverseTunnel}, {Kind: services.KindTunnelConnection}, + {Kind: services.KindAccessRequest}, } cfg.QueueSize = defaults.AuthQueueSize return cfg @@ -118,6 +119,7 @@ type Cache struct { provisionerCache services.Provisioner usersCache services.UsersService accessCache services.Access + dynamicAccessCache services.DynamicAccessExt presenceCache services.Presence eventsCache services.Events @@ -145,6 +147,8 @@ type Config struct { Users services.UsersService // Access is an access service Access services.Access + // DynamicAccess is a dynamic access service + DynamicAccess services.DynamicAccess // Presence is a presence service Presence services.Presence // Backend is a backend for local cache @@ -274,6 +278,7 @@ func New(config Config) (*Cache, error) { provisionerCache: local.NewProvisioningService(wrapper), usersCache: local.NewIdentityService(wrapper), accessCache: local.NewAccessService(wrapper), + dynamicAccessCache: local.NewDynamicAccessService(wrapper), presenceCache: local.NewPresenceService(wrapper), eventsCache: local.NewEventsService(config.Backend), Entry: log.WithFields(log.Fields{ diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index bcfaa44379c98..097016281e10b 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -68,9 +68,10 @@ type testPack struct { provisionerS services.Provisioner clusterConfigS services.ClusterConfiguration - usersS services.UsersService - accessS services.Access - presenceS services.Presence + usersS services.UsersService + accessS services.Access + dynamicAccessS services.DynamicAccess + presenceS services.Presence } func (t *testPack) Close() { @@ -127,6 +128,7 @@ func (s *CacheSuite) newPackWithoutCache(c *check.C, setupConfig SetupConfigFn) p.presenceS = local.NewPresenceService(p.backend) p.usersS = local.NewIdentityService(p.backend) p.accessS = local.NewAccessService(p.backend) + p.dynamicAccessS = local.NewDynamicAccessService(p.backend) return p } @@ -143,6 +145,7 @@ func (s *CacheSuite) newPack(c *check.C, setupConfig func(c Config) Config) *tes Trust: p.trustS, Users: p.usersS, Access: p.accessS, + DynamicAccess: p.dynamicAccessS, Presence: p.presenceS, RetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, @@ -206,6 +209,7 @@ func (s *CacheSuite) TestOnlyRecentInit(c *check.C) { Trust: p.trustS, Users: p.usersS, Access: p.accessS, + DynamicAccess: p.dynamicAccessS, Presence: p.presenceS, RetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, @@ -355,6 +359,7 @@ func (s *CacheSuite) preferRecent(c *check.C) { Trust: p.trustS, Users: p.usersS, Access: p.accessS, + DynamicAccess: p.dynamicAccessS, Presence: p.presenceS, RetryPeriod: 200 * time.Millisecond, EventsC: p.eventsC, diff --git a/lib/cache/collections.go b/lib/cache/collections.go index 533ef07e4f86d..207cc5eb68d27 100644 --- a/lib/cache/collections.go +++ b/lib/cache/collections.go @@ -17,6 +17,8 @@ limitations under the License. package cache import ( + "context" + "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" @@ -108,6 +110,11 @@ func setupCollections(c *Cache, watches []services.WatchKind) (map[string]collec return nil, trace.BadParameter("missing parameter Presence") } collections[watch.Kind] = &tunnelConnection{watch: watch, Cache: c} + case services.KindAccessRequest: + if c.DynamicAccess == nil { + return nil, trace.BadParameter("missing parameter DynamicAccess") + } + collections[watch.Kind] = &accessRequest{watch: watch, Cache: c} default: return nil, trace.BadParameter("resource %q is not supported", watch.Kind) } @@ -115,6 +122,69 @@ func setupCollections(c *Cache, watches []services.WatchKind) (map[string]collec return collections, nil } +type accessRequest struct { + *Cache + watch services.WatchKind +} + +// erase erases all data in the collection +func (r *accessRequest) erase() error { + if err := r.dynamicAccessCache.DeleteAllAccessRequests(context.TODO()); err != nil { + if !trace.IsNotFound(err) { + return trace.Wrap(err) + } + } + return nil +} + +func (r *accessRequest) fetch() error { + resources, err := r.DynamicAccess.GetAccessRequests(context.TODO(), services.AccessRequestFilter{}) + if err != nil { + return trace.Wrap(err) + } + if err := r.erase(); err != nil { + return trace.Wrap(err) + } + for _, resource := range resources { + if err := r.dynamicAccessCache.UpsertAccessRequest(context.TODO(), resource); err != nil { + return trace.Wrap(err) + } + } + return nil +} + +func (r *accessRequest) processEvent(event services.Event) error { + switch event.Type { + case backend.OpDelete: + err := r.dynamicAccessCache.DeleteAccessRequest(context.TODO(), event.Resource.GetName()) + if err != nil { + // resource could be missing in the cache + // expired or not created, if the first consumed + // event is delete + if !trace.IsNotFound(err) { + r.Warningf("Failed to delete resource %v.", err) + return trace.Wrap(err) + } + } + case backend.OpPut: + resource, ok := event.Resource.(*services.AccessRequestV3) + if !ok { + return trace.BadParameter("unexpected type %T", event.Resource) + } + r.setTTL(resource) + if err := r.dynamicAccessCache.UpsertAccessRequest(context.TODO(), resource); err != nil { + return trace.Wrap(err) + } + default: + r.Warningf("Skipping unsupported event type %v.", event.Type) + } + return nil +} + +func (r *accessRequest) watchKind() services.WatchKind { + return r.watch +} + type tunnelConnection struct { *Cache watch services.WatchKind diff --git a/lib/service/service.go b/lib/service/service.go index d753b0cf64341..e360c039c20ca 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1314,6 +1314,7 @@ func (process *TeleportProcess) newAccessCache(cfg accessCacheConfig) (*cache.Ca Trust: cfg.services, Users: cfg.services, Access: cfg.services, + DynamicAccess: cfg.services, Presence: cfg.services, Component: teleport.Component(append(cfg.cacheName, process.id, teleport.ComponentCache)...), MetricComponent: teleport.Component(append(cfg.cacheName, teleport.ComponentCache)...), diff --git a/lib/services/access_request.go b/lib/services/access_request.go index 91b5b60bb9a67..c62b86155b5a2 100644 --- a/lib/services/access_request.go +++ b/lib/services/access_request.go @@ -152,6 +152,16 @@ type DynamicAccess interface { DeleteAccessRequest(ctx context.Context, reqID string) error } +// DynamicAccessExt is an extended dynamic access interface +// used to implement some auth server internals. +type DynamicAccessExt interface { + DynamicAccess + // UpsertAccessRequest creates or updates an access request. + UpsertAccessRequest(ctx context.Context, req AccessRequest) error + // DeleteAllAccessRequests deletes all existant access requests. + DeleteAllAccessRequests(ctx context.Context) error +} + // AccessRequest is a request for temporarily granted roles type AccessRequest interface { Resource diff --git a/lib/services/local/dynamic_access.go b/lib/services/local/dynamic_access.go index 6e5ac6cc9ec6f..58f597453b1ac 100644 --- a/lib/services/local/dynamic_access.go +++ b/lib/services/local/dynamic_access.go @@ -32,11 +32,11 @@ type DynamicAccessService struct { } // NewDynamicAccessService returns new dynamic access service instance -func NewDynamicAccessService(backend backend.Backend) *AccessService { - return &AccessService{Backend: backend} +func NewDynamicAccessService(backend backend.Backend) *DynamicAccessService { + return &DynamicAccessService{Backend: backend} } -func (s *AccessService) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error { +func (s *DynamicAccessService) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error { if err := req.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) } @@ -50,7 +50,7 @@ func (s *AccessService) CreateAccessRequest(ctx context.Context, req services.Ac return nil } -func (s *AccessService) SetAccessRequestState(ctx context.Context, name string, state services.RequestState) error { +func (s *DynamicAccessService) SetAccessRequestState(ctx context.Context, name string, state services.RequestState) error { item, err := s.Get(ctx, accessRequestKey(name)) if err != nil { if trace.IsNotFound(err) { @@ -80,7 +80,7 @@ func (s *AccessService) SetAccessRequestState(ctx context.Context, name string, return nil } -func (s *AccessService) GetAccessRequest(ctx context.Context, name string) (services.AccessRequest, error) { +func (s *DynamicAccessService) GetAccessRequest(ctx context.Context, name string) (services.AccessRequest, error) { item, err := s.Get(ctx, accessRequestKey(name)) if err != nil { if trace.IsNotFound(err) { @@ -95,7 +95,7 @@ func (s *AccessService) GetAccessRequest(ctx context.Context, name string) (serv return req, nil } -func (s *AccessService) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) { +func (s *DynamicAccessService) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) { // Filters which specify ID are a special case since they will match exactly zero or one // possible requests. if filter.ID != "" { @@ -129,7 +129,7 @@ func (s *AccessService) GetAccessRequests(ctx context.Context, filter services.A return requests, nil } -func (s *AccessService) DeleteAccessRequest(ctx context.Context, name string) error { +func (s *DynamicAccessService) DeleteAccessRequest(ctx context.Context, name string) error { err := s.Delete(ctx, accessRequestKey(name)) if err != nil { if trace.IsNotFound(err) { @@ -140,6 +140,24 @@ func (s *AccessService) DeleteAccessRequest(ctx context.Context, name string) er return nil } +func (s *DynamicAccessService) DeleteAllAccessRequests(ctx context.Context) error { + return trace.Wrap(s.DeleteRange(ctx, backend.Key(accessRequestsPrefix), backend.RangeEnd(backend.Key(accessRequestsPrefix)))) +} + +func (s *DynamicAccessService) UpsertAccessRequest(ctx context.Context, req services.AccessRequest) error { + if err := req.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + item, err := itemFromAccessRequest(req) + if err != nil { + return trace.Wrap(err) + } + if _, err := s.Put(ctx, item); err != nil { + return trace.Wrap(err) + } + return nil +} + func itemFromAccessRequest(req services.AccessRequest) (backend.Item, error) { value, err := services.GetAccessRequestMarshaler().MarshalAccessRequest(req) if err != nil { diff --git a/lib/services/services.go b/lib/services/services.go index 9d2dd714e7bed..076408a61ce3f 100644 --- a/lib/services/services.go +++ b/lib/services/services.go @@ -24,5 +24,6 @@ type Services interface { Events ClusterConfiguration Access + DynamicAccess Presence } diff --git a/tool/tctl/common/access_request_command.go b/tool/tctl/common/access_request_command.go index 1f0831d11e066..beb1bbc0f565d 100644 --- a/tool/tctl/common/access_request_command.go +++ b/tool/tctl/common/access_request_command.go @@ -84,6 +84,8 @@ func (c *AccessRequestCommand) TryRun(cmd string, client auth.ClientI) (match bo err = c.Deny(client) case c.requestCreate.FullCommand(): err = c.Create(client) + case c.requestDelete.FullCommand(): + err = c.Delete(client) default: return false, nil } diff --git a/tool/tsh/tsh.go b/tool/tsh/tsh.go index 467b31f2cebf1..b9cd5348ecad2 100644 --- a/tool/tsh/tsh.go +++ b/tool/tsh/tsh.go @@ -1327,20 +1327,32 @@ Loop: for { select { case event := <-watcher.Events(): - if event.Type != backend.OpPut { + switch event.Type { + case backend.OpInit: + log.Infof("Access-request watcher initialized...") continue Loop + case backend.OpPut: + r, ok := event.Resource.(*services.AccessRequestV3) + if !ok { + return trace.BadParameter("unexpected resource type %T", event.Resource) + } + if r.GetName() != req.GetName() || r.GetState().IsPending() { + log.Infof("Skipping put event id=%s,state=%s.", r.GetName(), r.GetState()) + continue Loop + } + if !r.GetState().IsApproved() { + return trace.Errorf("request %s has been set to %s", r.GetName(), r.GetState().String()) + } + return nil + case backend.OpDelete: + if event.Resource.GetName() != req.GetName() { + log.Infof("Skipping delete event id=%s", event.Resource.GetName()) + continue Loop + } + return trace.Errorf("request %s has expired or been deleted...", event.Resource.GetName()) + default: + log.Warnf("Skipping unknown event type %s", event.Type) } - r, ok := event.Resource.(*services.AccessRequestV3) - if !ok { - return trace.Errorf("unexpected resource type %T", event.Resource) - } - if r.GetName() != req.GetName() || r.GetState().IsPending() { - continue Loop - } - if !r.GetState().IsApproved() { - return trace.Errorf("request %s has been set to %s", r.GetName(), r.GetState().String()) - } - return nil case <-watcher.Done(): utils.FatalError(watcher.Error()) }