diff --git a/builtin/plugin/backend.go b/builtin/plugin/backend.go index d33fe9c1a8eb..751588905870 100644 --- a/builtin/plugin/backend.go +++ b/builtin/plugin/backend.go @@ -7,6 +7,8 @@ import ( "reflect" "sync" + log "github.com/hashicorp/go-hclog" + uuid "github.com/hashicorp/go-uuid" "github.com/hashicorp/vault/sdk/framework" "github.com/hashicorp/vault/sdk/helper/consts" @@ -38,7 +40,7 @@ func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, // Backend returns an instance of the backend, either as a plugin if external // or as a concrete implementation if builtin, casted as logical.Backend. -func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { +func Backend(ctx context.Context, conf *logical.BackendConfig) (*PluginBackend, error) { var b PluginBackend name := conf.Config["plugin_name"] @@ -80,7 +82,7 @@ func Backend(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, // PluginBackend is a thin wrapper around plugin.BackendPluginClient type PluginBackend struct { - logical.Backend + Backend logical.Backend sync.RWMutex config *logical.BackendConfig @@ -118,12 +120,12 @@ func (b *PluginBackend) startBackend(ctx context.Context, storage logical.Storag if !b.loaded { if b.Backend.Type() != nb.Type() { nb.Cleanup(ctx) - b.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType) + b.Backend.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchType) return ErrMismatchType } if !reflect.DeepEqual(b.Backend.SpecialPaths(), nb.SpecialPaths()) { nb.Cleanup(ctx) - b.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths) + b.Backend.Logger().Warn("failed to start plugin process", "plugin", b.config.Config["plugin_name"], "error", ErrMismatchPaths) return ErrMismatchPaths } } @@ -169,7 +171,7 @@ func (b *PluginBackend) lazyLoadBackend(ctx context.Context, storage logical.Sto // Reload plugin if it's an rpc.ErrShutdown b.Lock() if b.canary == canary { - b.Logger().Debug("reloading plugin backend", "plugin", b.config.Config["plugin_name"]) + b.Backend.Logger().Debug("reloading plugin backend", "plugin", b.config.Config["plugin_name"]) err := b.startBackend(ctx, storage) if err != nil { b.Unlock() @@ -220,3 +222,52 @@ func (b *PluginBackend) HandleExistenceCheck(ctx context.Context, req *logical.R func (b *PluginBackend) Initialize(ctx context.Context, req *logical.InitializationRequest) error { return nil } + +// SpecialPaths is a thin wrapper used to ensure we grab the lock for race purposes +func (b *PluginBackend) SpecialPaths() *logical.Paths { + b.RLock() + defer b.RUnlock() + return b.Backend.SpecialPaths() +} + +// System is a thin wrapper used to ensure we grab the lock for race purposes +func (b *PluginBackend) System() logical.SystemView { + b.RLock() + defer b.RUnlock() + return b.Backend.System() +} + +// Logger is a thin wrapper used to ensure we grab the lock for race purposes +func (b *PluginBackend) Logger() log.Logger { + b.RLock() + defer b.RUnlock() + return b.Backend.Logger() +} + +// Cleanup is a thin wrapper used to ensure we grab the lock for race purposes +func (b *PluginBackend) Cleanup(ctx context.Context) { + b.RLock() + defer b.RUnlock() + b.Backend.Cleanup(ctx) +} + +// InvalidateKey is a thin wrapper used to ensure we grab the lock for race purposes +func (b *PluginBackend) InvalidateKey(ctx context.Context, key string) { + b.RLock() + defer b.RUnlock() + b.Backend.InvalidateKey(ctx, key) +} + +// Setup is a thin wrapper used to ensure we grab the lock for race purposes +func (b *PluginBackend) Setup(ctx context.Context, config *logical.BackendConfig) error { + b.RLock() + defer b.RUnlock() + return b.Backend.Setup(ctx, config) +} + +// Type is a thin wrapper used to ensure we grab the lock for race purposes +func (b *PluginBackend) Type() logical.BackendType { + b.RLock() + defer b.RUnlock() + return b.Backend.Type() +} diff --git a/helper/forwarding/types.pb.go b/helper/forwarding/types.pb.go index 3a036f4726aa..dfc3aa7ce2b7 100644 --- a/helper/forwarding/types.pb.go +++ b/helper/forwarding/types.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: helper/forwarding/types.proto diff --git a/helper/identity/mfa/types.pb.go b/helper/identity/mfa/types.pb.go index 789def20f0fe..f306ad4048be 100644 --- a/helper/identity/mfa/types.pb.go +++ b/helper/identity/mfa/types.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: helper/identity/mfa/types.proto diff --git a/helper/identity/types.pb.go b/helper/identity/types.pb.go index a392d24bc313..278f361d772d 100644 --- a/helper/identity/types.pb.go +++ b/helper/identity/types.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: helper/identity/types.proto diff --git a/helper/storagepacker/types.pb.go b/helper/storagepacker/types.pb.go index bd7b780cd5a9..235b5c011c9c 100644 --- a/helper/storagepacker/types.pb.go +++ b/helper/storagepacker/types.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: helper/storagepacker/types.proto diff --git a/http/util.go b/http/util.go index cd1d6838c591..b4c8923cc3ee 100644 --- a/http/util.go +++ b/http/util.go @@ -64,7 +64,7 @@ func rateLimitQuotaWrapping(handler http.Handler, core *vault.Core) http.Handler Type: quotas.TypeRateLimit, Path: path, MountPath: mountPath, - Role: core.DetermineRoleFromLoginRequest(mountPath, bodyBytes, r.Context()), + Role: core.DetermineRoleFromLoginRequestFromBytes(mountPath, bodyBytes, r.Context()), NamespacePath: ns.Path, ClientAddress: parseRemoteIPAddress(r), }) diff --git a/physical/raft/types.pb.go b/physical/raft/types.pb.go index 5fca8f6c3e81..470cabffcd61 100644 --- a/physical/raft/types.pb.go +++ b/physical/raft/types.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: physical/raft/types.proto diff --git a/sdk/database/dbplugin/database.pb.go b/sdk/database/dbplugin/database.pb.go index 7c9e08a9b03e..524ddc05e038 100644 --- a/sdk/database/dbplugin/database.pb.go +++ b/sdk/database/dbplugin/database.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: sdk/database/dbplugin/database.proto diff --git a/sdk/database/dbplugin/v5/proto/database.pb.go b/sdk/database/dbplugin/v5/proto/database.pb.go index a5f52dab999d..b2010276bca7 100644 --- a/sdk/database/dbplugin/v5/proto/database.pb.go +++ b/sdk/database/dbplugin/v5/proto/database.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: sdk/database/dbplugin/v5/proto/database.proto diff --git a/sdk/helper/pluginutil/multiplexing.pb.go b/sdk/helper/pluginutil/multiplexing.pb.go index d0ff51e57b24..b681bf359061 100644 --- a/sdk/helper/pluginutil/multiplexing.pb.go +++ b/sdk/helper/pluginutil/multiplexing.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: sdk/helper/pluginutil/multiplexing.proto diff --git a/sdk/logical/identity.pb.go b/sdk/logical/identity.pb.go index 4b1a36b39826..d0e2ab6227ea 100644 --- a/sdk/logical/identity.pb.go +++ b/sdk/logical/identity.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: sdk/logical/identity.proto diff --git a/sdk/logical/plugin.pb.go b/sdk/logical/plugin.pb.go index 1fb53f9a79c9..ed456ef2db90 100644 --- a/sdk/logical/plugin.pb.go +++ b/sdk/logical/plugin.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: sdk/logical/plugin.proto diff --git a/sdk/plugin/pb/backend.pb.go b/sdk/plugin/pb/backend.pb.go index dbad4da977ce..39480c82aea5 100644 --- a/sdk/plugin/pb/backend.pb.go +++ b/sdk/plugin/pb/backend.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: sdk/plugin/pb/backend.proto diff --git a/vault/activity/activity_log.pb.go b/vault/activity/activity_log.pb.go index 21c58e5675f3..7c60ec617f57 100644 --- a/vault/activity/activity_log.pb.go +++ b/vault/activity/activity_log.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: vault/activity/activity_log.proto diff --git a/vault/core.go b/vault/core.go index 8fbb5c5df3a7..14db32b7380b 100644 --- a/vault/core.go +++ b/vault/core.go @@ -175,7 +175,7 @@ func (e *ErrInvalidKey) Error() string { return fmt.Sprintf("invalid key: %v", e.Reason) } -type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth) error +type RegisterAuthFunc func(context.Context, time.Duration, string, *logical.Auth, string) error type activeAdvertisement struct { RedirectAddr string `json:"redirect_addr"` @@ -3324,22 +3324,30 @@ func (c *Core) CheckPluginPerms(pluginName string) (err error) { return err } +// DetermineRoleFromLoginRequestFromBytes will determine the role that should be applied to a quota for a given +// login request, accepting a byte payload +func (c *Core) DetermineRoleFromLoginRequestFromBytes(mountPoint string, payload []byte, ctx context.Context) string { + data := make(map[string]interface{}) + err := jsonutil.DecodeJSON(payload, &data) + if err != nil { + // Cannot discern a role from a request we cannot parse + return "" + } + + return c.DetermineRoleFromLoginRequest(mountPoint, data, ctx) +} + // DetermineRoleFromLoginRequest will determine the role that should be applied to a quota for a given // login request -func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, payload []byte, ctx context.Context) string { +func (c *Core) DetermineRoleFromLoginRequest(mountPoint string, data map[string]interface{}, ctx context.Context) string { + c.authLock.RLock() + defer c.authLock.RUnlock() matchingBackend := c.router.MatchingBackend(ctx, mountPoint) if matchingBackend == nil || matchingBackend.Type() != logical.TypeCredential { // Role based quotas do not apply to this request return "" } - data := make(map[string]interface{}) - err := jsonutil.DecodeJSON(payload, &data) - if err != nil { - // Cannot discern a role from a request we cannot parse - return "" - } - resp, err := matchingBackend.HandleRequest(ctx, &logical.Request{ MountPoint: mountPoint, Path: "login", diff --git a/vault/core_util.go b/vault/core_util.go index 2a609acfd903..58bce82f40c8 100644 --- a/vault/core_util.go +++ b/vault/core_util.go @@ -166,7 +166,7 @@ func (c *Core) quotaLeaseWalker(ctx context.Context, callback func(request *quot return nil } -func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leaseIDs []string) error { +func (c *Core) quotasHandleLeases(ctx context.Context, action quotas.LeaseAction, leases []*quotas.QuotaLeaseInformation) error { return nil } diff --git a/vault/expiration.go b/vault/expiration.go index 4bda9bded750..e99ce18147b4 100644 --- a/vault/expiration.go +++ b/vault/expiration.go @@ -471,9 +471,15 @@ func (m *ExpirationManager) invalidate(key string) { m.pending.Delete(leaseID) m.leaseCount-- - if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { - m.logger.Error("failed to update quota on lease invalidation", "error", err) - return + // Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to + // accurately update quota lease information. + // Note that cachedLeaseInfo should never be nil under normal operation. + if pending.cachedLeaseInfo != nil { + leaseInfo := "as.QuotaLeaseInformation{LeaseId: leaseID, Role: pending.cachedLeaseInfo.LoginRole} + if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { + m.logger.Error("failed to update quota on lease invalidation", "error", err) + return + } } default: // Update the lease in memory @@ -486,14 +492,21 @@ func (m *ExpirationManager) invalidate(key string) { // other maps, and update metrics/quotas if appropriate. m.nonexpiring.Delete(leaseID) - if _, ok := m.irrevocable.Load(leaseID); ok { + if info, ok := m.irrevocable.Load(leaseID); ok { + irrevocable := info.(pendingInfo) m.irrevocable.Delete(leaseID) m.irrevocableLeaseCount-- m.leaseCount-- - if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { - m.logger.Error("failed to update quota on lease invalidation", "error", err) - return + // Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to + // accurately update quota lease information. + // Note that cachedLeaseInfo should never be nil under normal operation. + if irrevocable.cachedLeaseInfo != nil { + leaseInfo := "as.QuotaLeaseInformation{LeaseId: leaseID, Role: irrevocable.cachedLeaseInfo.LoginRole} + if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { + m.logger.Error("failed to update quota on lease invalidation", "error", err) + return + } } } return @@ -1389,7 +1402,7 @@ func (m *ExpirationManager) RenewToken(ctx context.Context, req *logical.Request // Register is used to take a request and response with an associated // lease. The secret gets assigned a LeaseID and the management of // of lease is assumed by the expiration manager. -func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, resp *logical.Response) (id string, retErr error) { +func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, resp *logical.Response, loginRole string) (id string, retErr error) { defer metrics.MeasureSince([]string{"expire", "register"}, time.Now()) te := req.TokenEntry() @@ -1431,6 +1444,7 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, Path: req.Path, Data: resp.Data, Secret: resp.Secret, + LoginRole: loginRole, IssueTime: time.Now(), ExpireTime: resp.Secret.ExpirationTime(), namespace: ns, @@ -1524,7 +1538,7 @@ func (m *ExpirationManager) Register(ctx context.Context, req *logical.Request, // RegisterAuth is used to take an Auth response with an associated lease. // The token does not get a LeaseID, but the lease management is handled by // the expiration manager. -func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth) error { +func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenEntry, auth *logical.Auth, loginRole string) error { defer metrics.MeasureSince([]string{"expire", "register-auth"}, time.Now()) // Triggers failure of RegisterAuth. This should only be set and triggered @@ -1576,6 +1590,7 @@ func (m *ExpirationManager) RegisterAuth(ctx context.Context, te *logical.TokenE ClientToken: auth.ClientToken, Auth: auth, Path: te.Path, + LoginRole: loginRole, IssueTime: time.Now(), ExpireTime: authExpirationTime, namespace: tokenNS, @@ -1721,6 +1736,7 @@ func (m *ExpirationManager) inMemoryLeaseInfo(le *leaseEntry) *leaseEntry { if le.isIrrevocable() { ret.RevokeErr = le.RevokeErr } + ret.LoginRole = le.LoginRole return ret } @@ -1795,9 +1811,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry) { info.(pendingInfo).timer.Stop() m.pending.Delete(le.LeaseID) m.leaseCount-- - if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []string{le.LeaseID}); err != nil { - m.logger.Error("failed to update quota on lease deletion", "error", err) - return + // Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to + // accurately update quota lease information. + // Note that cachedLeaseInfo should never be nil under normal operation. + if pending.cachedLeaseInfo != nil { + leaseInfo := "as.QuotaLeaseInformation{LeaseId: le.LeaseID, Role: le.LoginRole} + if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { + m.logger.Error("failed to update quota on lease deletion", "error", err) + return + } } } return @@ -1849,9 +1871,15 @@ func (m *ExpirationManager) updatePendingInternal(le *leaseEntry) { if leaseCreated { m.leaseCount++ - if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []string{le.LeaseID}); err != nil { - m.logger.Error("failed to update quota on lease creation", "error", err) - return + // Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to + // accurately update quota lease information. + // Note that cachedLeaseInfo should never be nil under normal operation. + if pending.cachedLeaseInfo != nil { + leaseInfo := "as.QuotaLeaseInformation{LeaseId: le.LeaseID, Role: le.LoginRole} + if err := m.core.quotasHandleLeases(m.quitContext, quotas.LeaseActionCreated, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { + m.logger.Error("failed to update quota on lease creation", "error", err) + return + } } } } @@ -2450,9 +2478,15 @@ func (m *ExpirationManager) removeFromPending(ctx context.Context, leaseID strin m.pending.Delete(leaseID) if decrementCounters { m.leaseCount-- - // Log but do not fail; unit tests (and maybe Tidy on production systems) - if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []string{leaseID}); err != nil { - m.logger.Error("failed to update quota on revocation", "error", err) + // Avoid nil pointer dereference. Without cachedLeaseInfo we do not have enough information to + // accurately update quota lease information. + // Note that cachedLeaseInfo should never be nil under normal operation. + if pending.cachedLeaseInfo != nil { + leaseInfo := "as.QuotaLeaseInformation{LeaseId: leaseID, Role: pending.cachedLeaseInfo.LoginRole} + // Log but do not fail; unit tests (and maybe Tidy on production systems) + if err := m.core.quotasHandleLeases(ctx, quotas.LeaseActionDeleted, []*quotas.QuotaLeaseInformation{leaseInfo}); err != nil { + m.logger.Error("failed to update quota on revocation", "error", err) + } } } } @@ -2663,6 +2697,11 @@ type leaseEntry struct { ExpireTime time.Time `json:"expire_time"` LastRenewalTime time.Time `json:"last_renewal_time"` + // LoginRole is used to indicate which login role (if applicable) this lease + // was created with. This is required to decrement lease count quotas + // based on login roles upon lease expiry. + LoginRole string `json:"login_role"` + // Version is used to track new different versions of leases. V0 (or // zero-value) had non-root namespaced secondary indexes live in the root // namespace, and V1 has secondary indexes live in the matching namespace. diff --git a/vault/expiration_test.go b/vault/expiration_test.go index 84a8e0076184..cf59ba3e3515 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -324,6 +324,103 @@ func TestExpiration_TotalLeaseCount(t *testing.T) { } } +func TestExpiration_TotalLeaseCount_WithRoles(t *testing.T) { + // Quotas and internal lease count tracker are coupled, so this is a proxy + // for testing the total lease count quota + c, _, _ := TestCoreUnsealed(t) + exp := c.expiration + + expectedCount := 0 + otherNS := &namespace.Namespace{ + ID: "nsid", + Path: "foo/bar", + } + for i := 0; i < 50; i++ { + le := &leaseEntry{ + LeaseID: "lease" + fmt.Sprintf("%d", i), + Path: "foo/bar/" + fmt.Sprintf("%d", i), + LoginRole: "loginRole" + fmt.Sprintf("%d", i), + namespace: namespace.RootNamespace, + IssueTime: time.Now(), + ExpireTime: time.Now().Add(time.Hour), + } + + otherNSle := &leaseEntry{ + LeaseID: "lease" + fmt.Sprintf("%d", i) + "/blah.nsid", + Path: "foo/bar/" + fmt.Sprintf("%d", i) + "/blah.nsid", + LoginRole: "loginRole" + fmt.Sprintf("%d", i), + namespace: otherNS, + IssueTime: time.Now(), + ExpireTime: time.Now().Add(time.Hour), + } + + exp.pendingLock.Lock() + if err := exp.persistEntry(namespace.RootContext(nil), le); err != nil { + exp.pendingLock.Unlock() + t.Fatalf("error persisting irrevocable entry: %v", err) + } + exp.updatePendingInternal(le) + expectedCount++ + + if err := exp.persistEntry(namespace.RootContext(nil), otherNSle); err != nil { + exp.pendingLock.Unlock() + t.Fatalf("error persisting irrevocable entry: %v", err) + } + exp.updatePendingInternal(otherNSle) + expectedCount++ + exp.pendingLock.Unlock() + } + + // add some irrevocable leases to each count to ensure they are counted too + // note: irrevocable leases almost certainly have an expire time set in the + // past, but for this exercise it should be fine to set it to whatever + for i := 50; i < 60; i++ { + le := &leaseEntry{ + LeaseID: "lease" + fmt.Sprintf("%d", i+1), + Path: "foo/bar/" + fmt.Sprintf("%d", i+1), + LoginRole: "loginRole" + fmt.Sprintf("%d", i), + namespace: namespace.RootNamespace, + IssueTime: time.Now(), + ExpireTime: time.Now(), + RevokeErr: "some err message", + } + + otherNSle := &leaseEntry{ + LeaseID: "lease" + fmt.Sprintf("%d", i+1) + "/blah.nsid", + Path: "foo/bar/" + fmt.Sprintf("%d", i+1) + "/blah.nsid", + LoginRole: "loginRole" + fmt.Sprintf("%d", i), + namespace: otherNS, + IssueTime: time.Now(), + ExpireTime: time.Now(), + RevokeErr: "some err message", + } + + exp.pendingLock.Lock() + if err := exp.persistEntry(namespace.RootContext(nil), le); err != nil { + exp.pendingLock.Unlock() + t.Fatalf("error persisting irrevocable entry: %v", err) + } + exp.updatePendingInternal(le) + expectedCount++ + + if err := exp.persistEntry(namespace.RootContext(nil), otherNSle); err != nil { + exp.pendingLock.Unlock() + t.Fatalf("error persisting irrevocable entry: %v", err) + } + exp.updatePendingInternal(otherNSle) + expectedCount++ + exp.pendingLock.Unlock() + } + + exp.pendingLock.RLock() + count := exp.leaseCount + exp.pendingLock.RUnlock() + + if count != expectedCount { + t.Errorf("bad lease count. expected %d, got %d", expectedCount, count) + } +} + func TestExpiration_Tidy(t *testing.T) { var err error @@ -477,7 +574,7 @@ func TestExpiration_Tidy(t *testing.T) { "test_key": "test_value", }, } - _, err := exp.Register(namespace.RootContext(nil), req, resp) + _, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -636,7 +733,7 @@ func benchmarkExpirationBackend(b *testing.B, physicalBackend physical.Backend, "secret_key": "abcd", }, } - _, err = exp.Register(namespace.RootContext(nil), req, resp) + _, err = exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { b.Fatalf("err: %v", err) } @@ -698,7 +795,7 @@ func BenchmarkExpiration_Create_Leases(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { req.Path = fmt.Sprintf("prod/aws/%d", i) - _, err = exp.Register(namespace.RootContext(nil), req, resp) + _, err = exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { b.Fatalf("err: %v", err) } @@ -743,7 +840,7 @@ func TestExpiration_Restore(t *testing.T) { "secret_key": "abcd", }, } - _, err := exp.Register(namespace.RootContext(nil), req, resp) + _, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -815,7 +912,7 @@ func TestExpiration_Register(t *testing.T) { }, } - id, err := exp.Register(namespace.RootContext(nil), req, resp) + id, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -829,6 +926,49 @@ func TestExpiration_Register(t *testing.T) { } } +func TestExpiration_Register_Role(t *testing.T) { + exp := mockExpiration(t) + role := "role1" + req := &logical.Request{ + Operation: logical.ReadOperation, + Path: "prod/aws/foo", + ClientToken: "foobar", + } + req.SetTokenEntry(&logical.TokenEntry{ID: "foobar", NamespaceID: "root"}) + resp := &logical.Response{ + Secret: &logical.Secret{ + LeaseOptions: logical.LeaseOptions{ + TTL: time.Hour, + }, + }, + Data: map[string]interface{}{ + "access_key": "xyz", + "secret_key": "abcd", + }, + } + + id, err := exp.Register(namespace.RootContext(nil), req, resp, role) + if err != nil { + t.Fatalf("err: %v", err) + } + + if !strings.HasPrefix(id, req.Path) { + t.Fatalf("bad: %s", id) + } + + if len(id) <= len(req.Path) { + t.Fatalf("bad: %s", id) + } + + le, err := exp.loadEntry(exp.quitContext, id) + if err != nil { + t.Fatalf("err: %v", err) + } + if le.LoginRole != role { + t.Fatalf("Login role incorrect. Expected %s, received %s", role, le.LoginRole) + } +} + func TestExpiration_Register_BatchToken(t *testing.T) { c, _, rootToken := TestCoreUnsealed(t) exp := c.expiration @@ -883,7 +1023,7 @@ func TestExpiration_Register_BatchToken(t *testing.T) { }, } - leaseID, err := exp.Register(namespace.RootContext(nil), req, resp) + leaseID, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -952,7 +1092,41 @@ func TestExpiration_RegisterAuth(t *testing.T) { Path: "auth/github/login", NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") + if err != nil { + t.Fatalf("err: %v", err) + } + + te = &logical.TokenEntry{ + Path: "auth/github/../login", + NamespaceID: namespace.RootNamespaceID, + } + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") + if err == nil { + t.Fatal("expected error") + } +} + +func TestExpiration_RegisterAuth_Role(t *testing.T) { + exp := mockExpiration(t) + role := "role1" + root, err := exp.tokenStore.rootToken(context.Background()) + if err != nil { + t.Fatalf("err: %v", err) + } + + auth := &logical.Auth{ + ClientToken: root.ID, + LeaseOptions: logical.LeaseOptions{ + TTL: time.Hour, + }, + } + + te := &logical.TokenEntry{ + Path: "auth/github/login", + NamespaceID: namespace.RootNamespaceID, + } + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, role) if err != nil { t.Fatalf("err: %v", err) } @@ -961,7 +1135,7 @@ func TestExpiration_RegisterAuth(t *testing.T) { Path: "auth/github/../login", NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, role) if err == nil { t.Fatal("expected error") } @@ -985,7 +1159,7 @@ func TestExpiration_RegisterAuth_NoLease(t *testing.T) { Policies: []string{"root"}, NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1034,13 +1208,13 @@ func TestExpiration_RegisterAuth_NoTTL(t *testing.T) { } // First on core - err = c.RegisterAuth(ctx, 0, "auth/github/login", auth) + err = c.RegisterAuth(ctx, 0, "auth/github/login", auth, "") if err != nil { t.Fatal(err) } auth.TokenPolicies[0] = "default" - err = c.RegisterAuth(ctx, 0, "auth/github/login", auth) + err = c.RegisterAuth(ctx, 0, "auth/github/login", auth, "") if err == nil { t.Fatal("expected error") } @@ -1053,14 +1227,14 @@ func TestExpiration_RegisterAuth_NoTTL(t *testing.T) { Policies: []string{"root"}, NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(ctx, te, auth) + err = exp.RegisterAuth(ctx, te, auth, "") if err != nil { t.Fatalf("err: %v", err) } // Test non-root token with zero TTL te.Policies = []string{"default"} - err = exp.RegisterAuth(ctx, te, auth) + err = exp.RegisterAuth(ctx, te, auth, "") if err == nil { t.Fatal("expected error") } @@ -1098,7 +1272,7 @@ func TestExpiration_Revoke(t *testing.T) { }, } - id, err := exp.Register(namespace.RootContext(nil), req, resp) + id, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1145,7 +1319,7 @@ func TestExpiration_RevokeOnExpire(t *testing.T) { }, } - _, err = exp.Register(namespace.RootContext(nil), req, resp) + _, err = exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1208,7 +1382,7 @@ func TestExpiration_RevokePrefix(t *testing.T) { "secret_key": "abcd", }, } - _, err := exp.Register(namespace.RootContext(nil), req, resp) + _, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1277,7 +1451,7 @@ func TestExpiration_RevokeByToken(t *testing.T) { "secret_key": "abcd", }, } - _, err := exp.Register(namespace.RootContext(nil), req, resp) + _, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1376,7 +1550,7 @@ func TestExpiration_RevokeByToken_Blocking(t *testing.T) { "secret_key": "abcd", }, } - _, err := exp.Register(namespace.RootContext(nil), req, resp) + _, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1448,7 +1622,7 @@ func TestExpiration_RenewToken(t *testing.T) { Path: "auth/token/login", NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1497,7 +1671,7 @@ func TestExpiration_RenewToken_period(t *testing.T) { Path: "auth/token/login", NamespaceID: namespace.RootNamespaceID, } - err := exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err := exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1578,7 +1752,7 @@ func TestExpiration_RenewToken_period_backend(t *testing.T) { NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1635,7 +1809,7 @@ func TestExpiration_RenewToken_NotRenewable(t *testing.T) { Path: "auth/foo/login", NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1688,7 +1862,7 @@ func TestExpiration_Renew(t *testing.T) { }, } - id, err := exp.Register(namespace.RootContext(nil), req, resp) + id, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1759,7 +1933,7 @@ func TestExpiration_Renew_NotRenewable(t *testing.T) { }, } - id, err := exp.Register(namespace.RootContext(nil), req, resp) + id, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1810,7 +1984,7 @@ func TestExpiration_Renew_RevokeOnExpire(t *testing.T) { }, } - id, err := exp.Register(namespace.RootContext(nil), req, resp) + id, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1887,7 +2061,7 @@ func TestExpiration_Renew_FinalSecond(t *testing.T) { } ctx := namespace.RootContext(nil) - id, err := exp.Register(ctx, req, resp) + id, err := exp.Register(ctx, req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1962,7 +2136,7 @@ func TestExpiration_Renew_FinalSecond_Lease(t *testing.T) { } ctx := namespace.RootContext(nil) - id, err := exp.Register(ctx, req, resp) + id, err := exp.Register(ctx, req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -2647,7 +2821,7 @@ func sampleToken(t *testing.T, exp *ExpirationManager, path string, expiring boo Policies: auth.Policies, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -2822,7 +2996,7 @@ func registerOneLease(t *testing.T, ctx context.Context, exp *ExpirationManager) }, } - leaseID, err := exp.Register(ctx, req, resp) + leaseID, err := exp.Register(ctx, req, resp, "") if err != nil { t.Fatal(err) } diff --git a/vault/logical_system_quotas.go b/vault/logical_system_quotas.go index e6970a62b1ff..9d8de5769742 100644 --- a/vault/logical_system_quotas.go +++ b/vault/logical_system_quotas.go @@ -211,7 +211,7 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc { } authBackend := b.Core.router.MatchingBackend(namespace.ContextWithNamespace(ctx, ns), mountPath) if authBackend == nil || authBackend.Type() != logical.TypeCredential { - return logical.ErrorResponse("Mount path '%s' is not a valid auth method and therefore unsuitable for use with role-based quotas", mountPath), nil + return logical.ErrorResponse("Mount path %q is not a valid auth method and therefore unsuitable for use with role-based quotas", mountPath), nil } // We will always error as we aren't supplying real data, but we're looking for "unsupported operation" in particular _, err := authBackend.HandleRequest(ctx, &logical.Request{ @@ -219,7 +219,7 @@ func (b *SystemBackend) handleRateLimitQuotasUpdate() framework.OperationFunc { Operation: logical.ResolveRoleOperation, }) if err != nil && (err == logical.ErrUnsupportedOperation || err == logical.ErrUnsupportedPath) { - return logical.ErrorResponse("Mount path '%s' does not support use with role-based quotas", mountPath), nil + return logical.ErrorResponse("Mount path %q does not support use with role-based quotas", mountPath), nil } } diff --git a/vault/logical_system_test.go b/vault/logical_system_test.go index 55e06d9b0d04..8c773ca6e97b 100644 --- a/vault/logical_system_test.go +++ b/vault/logical_system_test.go @@ -1708,7 +1708,7 @@ func TestSystemBackend_revokePrefixAuth_newUrl(t *testing.T) { TTL: time.Hour, }, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1772,7 +1772,7 @@ func TestSystemBackend_revokePrefixAuth_origUrl(t *testing.T) { TTL: time.Hour, }, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -3617,7 +3617,7 @@ func TestSystemBackend_PathWildcardPreflight(t *testing.T) { ClientToken: te.ID, Accessor: te.Accessor, Orphan: true, - }); err != nil { + }, ""); err != nil { t.Fatal(err) } diff --git a/vault/login_mfa.go b/vault/login_mfa.go index d92cd6304d07..2a0803683a06 100644 --- a/vault/login_mfa.go +++ b/vault/login_mfa.go @@ -716,7 +716,7 @@ func (b *LoginMFABackend) handleMFALoginValidate(ctx context.Context, req *logic } // MFA validation has passed. Let's generate the token - resp, err := b.Core.LoginMFACreateToken(ctx, cachedResponseAuth.RequestPath, cachedResponseAuth.CachedAuth) + resp, err := b.Core.LoginMFACreateToken(ctx, cachedResponseAuth.RequestPath, cachedResponseAuth.CachedAuth, req.Data) if err != nil { return nil, fmt.Errorf("failed to create a token. error: %v", err) } @@ -742,7 +742,7 @@ func (c *Core) teardownLoginMFA() error { // LoginMFACreateToken creates a token after the login MFA is validated. // It also applies the lease quotas on the original login request path. -func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth) (*logical.Response, error) { +func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAuth *logical.Auth, loginRequestData map[string]interface{}) (*logical.Response, error) { auth := cachedAuth resp := &logical.Response{ Auth: auth, @@ -761,6 +761,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ Path: reqPath, MountPath: strings.TrimPrefix(mountPoint, ns.Path), + Role: c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx), NamespacePath: ns.Path, }) @@ -780,7 +781,7 @@ func (c *Core) LoginMFACreateToken(ctx context.Context, reqPath string, cachedAu // note that we don't need to handle the error for the following function right away. // The function takes the response as in input variable and modify it. So, the returned // arguments are resp and err. - leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp) + leaseGenerated, resp, err := c.LoginCreateToken(ctx, ns, reqPath, mountPoint, resp, loginRequestData) if quotaResp.Access != nil { quotaAckErr := c.ackLeaseQuota(quotaResp.Access, leaseGenerated) diff --git a/vault/quotas/quotas.go b/vault/quotas/quotas.go index f58a778ccba9..cac8ba58918d 100644 --- a/vault/quotas/quotas.go +++ b/vault/quotas/quotas.go @@ -168,6 +168,17 @@ type Manager struct { lock *sync.RWMutex } +// QuotaLeaseInformation contains all of the information lease-count quotas require +// from a lease to uniquely identify the lease-count quota to increment/decrement +type QuotaLeaseInformation struct { + // We can determine path and namespace from leaseId + LeaseId string + + // We need the role as it's not part of the leaseId, and is required + // to uniquely identify a lease count quota + Role string +} + // Quota represents the common properties of every quota type type Quota interface { // allow checks the if the request is allowed by the quota type implementation. diff --git a/vault/request_forwarding_service.pb.go b/vault/request_forwarding_service.pb.go index d7170ded484e..f894dd619320 100644 --- a/vault/request_forwarding_service.pb.go +++ b/vault/request_forwarding_service.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: vault/request_forwarding_service.proto diff --git a/vault/request_handling.go b/vault/request_handling.go index b40bde28e47a..2e32864b4b9a 100644 --- a/vault/request_handling.go +++ b/vault/request_handling.go @@ -969,9 +969,11 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp } leaseGenerated := false + loginRole := c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx) quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ Path: req.Path, MountPath: strings.TrimPrefix(req.MountPoint, ns.Path), + Role: loginRole, NamespacePath: ns.Path, }) if quotaErr != nil { @@ -1111,7 +1113,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp return nil, auth, retErr } - leaseID, err := registerFunc(ctx, req, resp) + leaseID, err := registerFunc(ctx, req, resp, loginRole) if err != nil { c.logger.Error("failed to register lease", "request_path", req.Path, "error", err) retErr = multierror.Append(retErr, ErrInternalError) @@ -1191,7 +1193,7 @@ func (c *Core) handleRequest(ctx context.Context, req *logical.Request) (retResp Path: resp.Auth.CreationPath, NamespaceID: ns.ID, } - if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth); err != nil { + if err := c.expiration.RegisterAuth(ctx, registeredTokenEntry, resp.Auth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil { // Best-effort clean up on error, so we log the cleanup error as // a warning but still return as internal error. if err := c.tokenStore.revokeOrphan(ctx, resp.Auth.ClientToken); err != nil { @@ -1390,6 +1392,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re quotaResp, quotaErr := c.applyLeaseCountQuota(ctx, "as.Request{ Path: req.Path, MountPath: strings.TrimPrefix(req.MountPoint, ns.Path), + Role: c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx), NamespacePath: ns.Path, }) @@ -1576,7 +1579,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re // Attach the display name, might be used by audit backends req.DisplayName = auth.DisplayName - leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp) + leaseGen, respTokenCreate, errCreateToken := c.LoginCreateToken(ctx, ns, req.Path, source, resp, req.Data) leaseGenerated = leaseGen if errCreateToken != nil { return respTokenCreate, nil, errCreateToken @@ -1607,7 +1610,7 @@ func (c *Core) handleLoginRequest(ctx context.Context, req *logical.Request) (re // LoginCreateToken creates a token as a result of a login request. // If MFA is enforced, mfa/validate endpoint calls this functions // after successful MFA validation to generate the token. -func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response) (bool, *logical.Response, error) { +func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, reqPath, mountPoint string, resp *logical.Response, loginRequestData map[string]interface{}) (bool, *logical.Response, error) { auth := resp.Auth source := strings.TrimPrefix(mountPoint, credentialRoutePrefix) @@ -1669,7 +1672,7 @@ func (c *Core) LoginCreateToken(ctx context.Context, ns *namespace.Namespace, re } leaseGenerated := false - err = registerFunc(ctx, tokenTTL, reqPath, auth) + err = registerFunc(ctx, tokenTTL, reqPath, auth, c.DetermineRoleFromLoginRequest(mountPoint, loginRequestData, ctx)) switch { case err == nil: if auth.TokenType != logical.TokenTypeBatch { @@ -1736,7 +1739,9 @@ func blockRequestIfErrorImpl(_ *Core, _, _ string) error { return nil } // RegisterAuth uses a logical.Auth object to create a token entry in the token // store, and registers a corresponding token lease to the expiration manager. -func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path string, auth *logical.Auth) error { +// role is the login role used as part of the creation of the token entry. If not +// relevant, can be omitted (by being provided as ""). +func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path string, auth *logical.Auth, role string) error { // We first assign token policies to what was returned from the backend // via auth.Policies. Then, we get the full set of policies into // auth.Policies from the backend + entity information -- this is not @@ -1786,7 +1791,7 @@ func (c *Core) RegisterAuth(ctx context.Context, tokenTTL time.Duration, path st auth.Renewable = false case logical.TokenTypeService: // Register with the expiration manager - if err := c.expiration.RegisterAuth(ctx, &te, auth); err != nil { + if err := c.expiration.RegisterAuth(ctx, &te, auth, role); err != nil { if err := c.tokenStore.revokeOrphan(ctx, te.ID); err != nil { c.logger.Warn("failed to clean up token lease during login request", "request_path", path, "error", err) } diff --git a/vault/request_handling_util.go b/vault/request_handling_util.go index a08709d5ca92..f8549e214f88 100644 --- a/vault/request_handling_util.go +++ b/vault/request_handling_util.go @@ -42,7 +42,7 @@ func forward(ctx context.Context, c *Core, req *logical.Request) (*logical.Respo panic("forward called in OSS Vault") } -func getLeaseRegisterFunc(c *Core) (func(context.Context, *logical.Request, *logical.Response) (string, error), error) { +func getLeaseRegisterFunc(c *Core) (func(context.Context, *logical.Request, *logical.Response, string) (string, error), error) { return c.expiration.Register, nil } diff --git a/vault/token_store_test.go b/vault/token_store_test.go index d3a7d655737e..8891fdc29ad7 100644 --- a/vault/token_store_test.go +++ b/vault/token_store_test.go @@ -330,7 +330,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { NamespaceID: namespace.RootNamespaceID, } - if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), registryEntry, auth); err != nil { + if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), registryEntry, auth, ""); err != nil { t.Fatal(err) } @@ -375,7 +375,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { }, ClientToken: ent.ID, } - if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { + if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { t.Fatal(err) } @@ -420,7 +420,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { }, ClientToken: ent.ID, } - if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { + if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { t.Fatal(err) } @@ -462,7 +462,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { }, ClientToken: ent.ID, } - if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { + if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { t.Fatal(err) } @@ -496,7 +496,7 @@ func TestTokenStore_TokenEntryUpgrade(t *testing.T) { }, ClientToken: ent.ID, } - if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth); err != nil { + if err := ts.expiration.RegisterAuth(namespace.RootContext(nil), ent, auth, ""); err != nil { t.Fatal(err) } @@ -572,7 +572,7 @@ func testMakeTokenViaRequestContext(t testing.TB, ctx context.Context, ts *Token } if resp.Auth.TokenType != logical.TokenTypeBatch { - if err := ts.expiration.RegisterAuth(ctx, te, resp.Auth); err != nil { + if err := ts.expiration.RegisterAuth(ctx, te, resp.Auth, ""); err != nil { t.Fatal(err) } } @@ -618,7 +618,7 @@ func testMakeTokenDirectly(t testing.TB, ts *TokenStore, te *logical.TokenEntry) CreationPath: te.Path, TokenType: te.Type, } - err := ts.expiration.RegisterAuth(namespace.RootContext(nil), te, auth) + err := ts.expiration.RegisterAuth(namespace.RootContext(nil), te, auth, "") switch err { case nil: if te.Type == logical.TokenTypeBatch { @@ -861,7 +861,7 @@ func TestTokenStore_HandleRequest_Renew_Revoke_Accessor(t *testing.T) { t.Fatal("token entry was nil") } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -1322,7 +1322,7 @@ func TestTokenStore_Revoke_Leases(t *testing.T) { "secret_key": "abcd", }, } - leaseID, err := ts.expiration.Register(namespace.RootContext(nil), req, resp) + leaseID, err := ts.expiration.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatalf("err: %v", err) } @@ -2208,7 +2208,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { Renewable: true, }, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -2230,7 +2230,7 @@ func TestTokenStore_HandleRequest_Revoke(t *testing.T) { Renewable: true, }, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -2623,7 +2623,7 @@ func TestTokenStore_HandleRequest_Renew(t *testing.T) { Renewable: true, }, } - err = exp.RegisterAuth(namespace.RootContext(nil), root, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), root, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -3113,7 +3113,7 @@ func TestTokenStore_HandleRequest_RenewSelf(t *testing.T) { Renewable: true, }, } - err = exp.RegisterAuth(namespace.RootContext(nil), root, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), root, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -5787,7 +5787,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) { NamespaceID: namespace.RootNamespaceID, } - err = exp.RegisterAuth(namespace.RootContext(nil), te, auth) + err = exp.RegisterAuth(namespace.RootContext(nil), te, auth, "") if err != nil { t.Fatalf("err: %v", err) } @@ -5820,7 +5820,7 @@ func TestTokenStore_TidyLeaseRevocation(t *testing.T) { leases := []string{} for i := 0; i < 10; i++ { - leaseID, err := exp.Register(namespace.RootContext(nil), req, resp) + leaseID, err := exp.Register(namespace.RootContext(nil), req, resp, "") if err != nil { t.Fatal(err) } diff --git a/vault/tokens/token.pb.go b/vault/tokens/token.pb.go index 0f4515bc9f7e..f765e3a699b5 100644 --- a/vault/tokens/token.pb.go +++ b/vault/tokens/token.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 +// protoc-gen-go v1.28.0 // protoc v3.19.4 // source: vault/tokens/token.proto diff --git a/vault/wrapping.go b/vault/wrapping.go index 5b1d32080677..0f613ba79b72 100644 --- a/vault/wrapping.go +++ b/vault/wrapping.go @@ -325,7 +325,7 @@ DONELISTHANDLING: } // Register the wrapped token with the expiration manager - if err := c.expiration.RegisterAuth(ctx, &te, wAuth); err != nil { + if err := c.expiration.RegisterAuth(ctx, &te, wAuth, c.DetermineRoleFromLoginRequest(req.MountPoint, req.Data, ctx)); err != nil { // Revoke since it's not yet being tracked for expiration c.tokenStore.revokeOrphan(ctx, te.ID) c.logger.Error("failed to register cubbyhole wrapping token lease", "request_path", req.Path, "error", err)