Skip to content

Commit

Permalink
Add support for access request resource to cache (#3213)
Browse files Browse the repository at this point in the history
Cache was missing support for access requests, causing
watchers to hang indefinitely without receiving events
when cache was in use.
  • Loading branch information
fspmarshall committed Dec 18, 2019
1 parent f240b71 commit ab425c1
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 22 deletions.
5 changes: 5 additions & 0 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func ForAuth(cfg Config) Config {
{Kind: services.KindProxy},
{Kind: services.KindReverseTunnel},
{Kind: services.KindTunnelConnection},
{Kind: services.KindAccessRequest},
}
cfg.QueueSize = defaults.AuthQueueSize
return cfg
Expand Down Expand Up @@ -118,6 +119,7 @@ type Cache struct {
provisionerCache services.Provisioner
usersCache services.UsersService
accessCache services.Access
dynamicAccessCache services.DynamicAccessExt
presenceCache services.Presence
eventsCache services.Events

Expand Down Expand Up @@ -145,6 +147,8 @@ type Config struct {
Users services.UsersService
// Access is an access service
Access services.Access
// DynamicAccess is a dynamic access service
DynamicAccess services.DynamicAccess
// Presence is a presence service
Presence services.Presence
// Backend is a backend for local cache
Expand Down Expand Up @@ -274,6 +278,7 @@ func New(config Config) (*Cache, error) {
provisionerCache: local.NewProvisioningService(wrapper),
usersCache: local.NewIdentityService(wrapper),
accessCache: local.NewAccessService(wrapper),
dynamicAccessCache: local.NewDynamicAccessService(wrapper),
presenceCache: local.NewPresenceService(wrapper),
eventsCache: local.NewEventsService(config.Backend),
Entry: log.WithFields(log.Fields{
Expand Down
11 changes: 8 additions & 3 deletions lib/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ type testPack struct {
provisionerS services.Provisioner
clusterConfigS services.ClusterConfiguration

usersS services.UsersService
accessS services.Access
presenceS services.Presence
usersS services.UsersService
accessS services.Access
dynamicAccssS services.DynamicAccess
presenceS services.Presence
}

func (t *testPack) Close() {
Expand Down Expand Up @@ -127,6 +128,7 @@ func (s *CacheSuite) newPackWithoutCache(c *check.C, setupConfig SetupConfigFn)
p.presenceS = local.NewPresenceService(p.backend)
p.usersS = local.NewIdentityService(p.backend)
p.accessS = local.NewAccessService(p.backend)
p.dynamicAccssS = local.NewDynamicAccessService(p.backend)
return p
}

Expand All @@ -143,6 +145,7 @@ func (s *CacheSuite) newPack(c *check.C, setupConfig func(c Config) Config) *tes
Trust: p.trustS,
Users: p.usersS,
Access: p.accessS,
DynamicAccess: p.dynamicAccssS,
Presence: p.presenceS,
RetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
Expand Down Expand Up @@ -206,6 +209,7 @@ func (s *CacheSuite) TestOnlyRecentInit(c *check.C) {
Trust: p.trustS,
Users: p.usersS,
Access: p.accessS,
DynamicAccess: p.dynamicAccssS,
Presence: p.presenceS,
RetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
Expand Down Expand Up @@ -355,6 +359,7 @@ func (s *CacheSuite) preferRecent(c *check.C) {
Trust: p.trustS,
Users: p.usersS,
Access: p.accessS,
DynamicAccess: p.dynamicAccssS,
Presence: p.presenceS,
RetryPeriod: 200 * time.Millisecond,
EventsC: p.eventsC,
Expand Down
70 changes: 70 additions & 0 deletions lib/cache/collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package cache

import (
"context"

"github.com/gravitational/teleport/lib/backend"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/services"
Expand Down Expand Up @@ -108,13 +110,81 @@ func setupCollections(c *Cache, watches []services.WatchKind) (map[string]collec
return nil, trace.BadParameter("missing parameter Presence")
}
collections[watch.Kind] = &tunnelConnection{watch: watch, Cache: c}
case services.KindAccessRequest:
if c.Presence == nil {
return nil, trace.BadParameter("missing parameter Presence")
}
collections[watch.Kind] = &accessRequest{watch: watch, Cache: c}
default:
return nil, trace.BadParameter("resource %q is not supported", watch.Kind)
}
}
return collections, nil
}

type accessRequest struct {
*Cache
watch services.WatchKind
}

// erase erases all data in the collection
func (r *accessRequest) erase() error {
if err := r.dynamicAccessCache.DeleteAllAccessRequests(context.TODO()); err != nil {
if !trace.IsNotFound(err) {
return trace.Wrap(err)
}
}
return nil
}

func (r *accessRequest) fetch() error {
resources, err := r.DynamicAccess.GetAccessRequests(context.TODO(), services.AccessRequestFilter{})
if err != nil {
return trace.Wrap(err)
}
if err := r.erase(); err != nil {
return trace.Wrap(err)
}
for _, resource := range resources {
if err := r.dynamicAccessCache.UpsertAccessRequest(context.TODO(), resource); err != nil {
return trace.Wrap(err)
}
}
return nil
}

func (r *accessRequest) processEvent(event services.Event) error {
switch event.Type {
case backend.OpDelete:
err := r.dynamicAccessCache.DeleteAccessRequest(context.TODO(), event.Resource.GetName())
if err != nil {
// resource could be missing in the cache
// expired or not created, if the first consumed
// event is delete
if !trace.IsNotFound(err) {
r.Warningf("Failed to delete resource %v.", err)
return trace.Wrap(err)
}
}
case backend.OpPut:
resource, ok := event.Resource.(*services.AccessRequestV3)
if !ok {
return trace.BadParameter("unexpected type %T", event.Resource)
}
r.setTTL(resource)
if err := r.dynamicAccessCache.UpsertAccessRequest(context.TODO(), resource); err != nil {
return trace.Wrap(err)
}
default:
r.Warningf("Skipping unsupported event type %v.", event.Type)
}
return nil
}

func (r *accessRequest) watchKind() services.WatchKind {
return r.watch
}

type tunnelConnection struct {
*Cache
watch services.WatchKind
Expand Down
1 change: 1 addition & 0 deletions lib/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,7 @@ func (process *TeleportProcess) newAccessCache(cfg accessCacheConfig) (*cache.Ca
Trust: cfg.services,
Users: cfg.services,
Access: cfg.services,
DynamicAccess: cfg.services,
Presence: cfg.services,
Component: teleport.Component(append(cfg.cacheName, process.id, teleport.ComponentCache)...),
MetricComponent: teleport.Component(append(cfg.cacheName, teleport.ComponentCache)...),
Expand Down
10 changes: 10 additions & 0 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ type DynamicAccess interface {
DeleteAccessRequest(ctx context.Context, reqID string) error
}

// DynamicAccessExt is an extended dynamic access interface
// used to implement some auth server internals.
type DynamicAccessExt interface {
DynamicAccess
// UpsertAccessRequest creates or updates an access request.
UpsertAccessRequest(ctx context.Context, req AccessRequest) error
// DeleteAllAccessRequests deletes all existant access requests.
DeleteAllAccessRequests(ctx context.Context) error
}

// AccessRequest is a request for temporarily granted roles
type AccessRequest interface {
Resource
Expand Down
32 changes: 25 additions & 7 deletions lib/services/local/dynamic_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ type DynamicAccessService struct {
}

// NewDynamicAccessService returns new dynamic access service instance
func NewDynamicAccessService(backend backend.Backend) *AccessService {
return &AccessService{Backend: backend}
func NewDynamicAccessService(backend backend.Backend) *DynamicAccessService {
return &DynamicAccessService{Backend: backend}
}

func (s *AccessService) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error {
func (s *DynamicAccessService) CreateAccessRequest(ctx context.Context, req services.AccessRequest) error {
if err := req.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
Expand All @@ -50,7 +50,7 @@ func (s *AccessService) CreateAccessRequest(ctx context.Context, req services.Ac
return nil
}

func (s *AccessService) SetAccessRequestState(ctx context.Context, name string, state services.RequestState) error {
func (s *DynamicAccessService) SetAccessRequestState(ctx context.Context, name string, state services.RequestState) error {
item, err := s.Get(ctx, accessRequestKey(name))
if err != nil {
if trace.IsNotFound(err) {
Expand Down Expand Up @@ -80,7 +80,7 @@ func (s *AccessService) SetAccessRequestState(ctx context.Context, name string,
return nil
}

func (s *AccessService) GetAccessRequest(ctx context.Context, name string) (services.AccessRequest, error) {
func (s *DynamicAccessService) GetAccessRequest(ctx context.Context, name string) (services.AccessRequest, error) {
item, err := s.Get(ctx, accessRequestKey(name))
if err != nil {
if trace.IsNotFound(err) {
Expand All @@ -95,7 +95,7 @@ func (s *AccessService) GetAccessRequest(ctx context.Context, name string) (serv
return req, nil
}

func (s *AccessService) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
func (s *DynamicAccessService) GetAccessRequests(ctx context.Context, filter services.AccessRequestFilter) ([]services.AccessRequest, error) {
// Filters which specify ID are a special case since they will match exactly zero or one
// possible requests.
if filter.ID != "" {
Expand Down Expand Up @@ -129,7 +129,7 @@ func (s *AccessService) GetAccessRequests(ctx context.Context, filter services.A
return requests, nil
}

func (s *AccessService) DeleteAccessRequest(ctx context.Context, name string) error {
func (s *DynamicAccessService) DeleteAccessRequest(ctx context.Context, name string) error {
err := s.Delete(ctx, accessRequestKey(name))
if err != nil {
if trace.IsNotFound(err) {
Expand All @@ -140,6 +140,24 @@ func (s *AccessService) DeleteAccessRequest(ctx context.Context, name string) er
return nil
}

func (s *DynamicAccessService) DeleteAllAccessRequests(ctx context.Context) error {
return trace.Wrap(s.DeleteRange(ctx, backend.Key(accessRequestsPrefix), backend.RangeEnd(backend.Key(accessRequestsPrefix))))
}

func (s *DynamicAccessService) UpsertAccessRequest(ctx context.Context, req services.AccessRequest) error {
if err := req.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
item, err := itemFromAccessRequest(req)
if err != nil {
return trace.Wrap(err)
}
if _, err := s.Put(ctx, item); err != nil {
return trace.Wrap(err)
}
return nil
}

func itemFromAccessRequest(req services.AccessRequest) (backend.Item, error) {
value, err := services.GetAccessRequestMarshaler().MarshalAccessRequest(req)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions lib/services/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ type Services interface {
Events
ClusterConfiguration
Access
DynamicAccess
Presence
}
2 changes: 2 additions & 0 deletions tool/tctl/common/access_request_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ func (c *AccessRequestCommand) TryRun(cmd string, client auth.ClientI) (match bo
err = c.Deny(client)
case c.requestCreate.FullCommand():
err = c.Create(client)
case c.requestDelete.FullCommand():
err = c.Delete(client)
default:
return false, nil
}
Expand Down
36 changes: 24 additions & 12 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -1327,20 +1327,32 @@ Loop:
for {
select {
case event := <-watcher.Events():
if event.Type != backend.OpPut {
switch event.Type {
case backend.OpInit:
log.Infof("Access-request watcher initialized...")
continue Loop
case backend.OpPut:
r, ok := event.Resource.(*services.AccessRequestV3)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
if r.GetName() != req.GetName() || r.GetState().IsPending() {
log.Infof("Skipping put event id=%s,state=%s.", r.GetName(), r.GetState())
continue Loop
}
if !r.GetState().IsApproved() {
return trace.Errorf("request %s has been set to %s", r.GetName(), r.GetState().String())
}
return nil
case backend.OpDelete:
if event.Resource.GetName() != req.GetName() {
log.Infof("Skipping delete event id=%s", event.Resource.GetName())
continue Loop
}
return trace.Errorf("request %s has expired or been deleted...", event.Resource.GetName())
default:
log.Warnf("Skipping unknown event type %s", event.Type)
}
r, ok := event.Resource.(*services.AccessRequestV3)
if !ok {
return trace.Errorf("unexpected resource type %T", event.Resource)
}
if r.GetName() != req.GetName() || r.GetState().IsPending() {
continue Loop
}
if !r.GetState().IsApproved() {
return trace.Errorf("request %s has been set to %s", r.GetName(), r.GetState().String())
}
return nil
case <-watcher.Done():
utils.FatalError(watcher.Error())
}
Expand Down

0 comments on commit ab425c1

Please sign in to comment.