From c2cf0e30938d1d7e4d3d6512bbde91953c5a8086 Mon Sep 17 00:00:00 2001 From: Tim Gross Date: Fri, 14 Jun 2024 16:49:57 -0400 Subject: [PATCH] consul: add preflight check for created ACL tokens Nomad creates a Consul ACL token for each service for registering it in Consul or bootstrapping the Envoy proxy (for service mesh workloads). Nomad always talks to the local Consul agent and never directly to the Consul servers. But the local Consul agent talks to the Consul servers in stale consistency mode to reduce load on the servers. This can result in the Nomad client making the Envoy bootstrap request with a token that has not yet replicated to the follower that the local client is connected to. This request gets a 404 on the ACL token and that negative entry gets cached, preventing any retries from succeeding. To workaround this, we'll use a method described by our friends over on `consul-k8s` where after creating the service token we try to read the token from the local agent in stale consistency mode (which prevents a failed read from being cached). This cannot completely eliminate this source of error because it's possible that Consul cluster replication is unhealthy at the time we need it, but this should make Envoy bootstrap significantly more robust. In this changeset, we add the preflight check after we login via Workload Identity and in the function we use to derive tokens in the legacy workflow. We've added the timeouts to be configurable via node metadata rather than the usual static configuration because for most cases, users should not need to touch or even know these values are configurable; the configuration is mostly available for testing. Fixes: https://github.com/hashicorp/nomad/issues/9307 Fixes: https://github.com/hashicorp/nomad/issues/20516 Fixes: https://github.com/hashicorp/nomad/issues/10451 Ref: https://github.com/hashicorp/consul-k8s/pull/887 Ref: https://hashicorp.atlassian.net/browse/NET-10051 --- client/allocrunner/consul_hook.go | 20 ++- client/allocrunner/taskrunner/sids_hook.go | 5 +- .../allocrunner/taskrunner/sids_hook_test.go | 4 +- .../taskrunner/task_runner_test.go | 4 +- client/client.go | 42 +++++- client/consul/consul.go | 86 +++++++++++- client/consul/consul_test.go | 122 ++++++++++++++++++ client/consul/consul_testing.go | 5 + client/consul/identities.go | 6 +- client/consul/identities_test.go | 17 +-- client/consul/identities_testing.go | 5 +- 11 files changed, 283 insertions(+), 33 deletions(-) create mode 100644 client/consul/consul_test.go diff --git a/client/allocrunner/consul_hook.go b/client/allocrunner/consul_hook.go index 54a64c8f2d2..b4e087e06f8 100644 --- a/client/allocrunner/consul_hook.go +++ b/client/allocrunner/consul_hook.go @@ -4,6 +4,7 @@ package allocrunner import ( + "context" "fmt" consulapi "github.com/hashicorp/consul/api" @@ -27,7 +28,9 @@ type consulHook struct { hookResources *cstructs.AllocHookResources envBuilder *taskenv.Builder - logger log.Logger + logger log.Logger + shutdownCtx context.Context + shutdownCancelFn context.CancelFunc } type consulHookConfig struct { @@ -51,6 +54,7 @@ type consulHookConfig struct { } func newConsulHook(cfg consulHookConfig) *consulHook { + shutdownCtx, shutdownCancelFn := context.WithCancel(context.Background()) h := &consulHook{ alloc: cfg.alloc, allocdir: cfg.allocdir, @@ -59,6 +63,8 @@ func newConsulHook(cfg consulHookConfig) *consulHook { consulClientConstructor: cfg.consulClientConstructor, hookResources: cfg.hookResources, envBuilder: cfg.envBuilder(), + shutdownCtx: shutdownCtx, + shutdownCancelFn: shutdownCancelFn, } h.logger = cfg.logger.Named(h.Name()) return h @@ -225,7 +231,12 @@ func (h *consulHook) getConsulToken(cluster string, req consul.JWTLoginRequest) return nil, fmt.Errorf("failed to retrieve Consul client for cluster %s: %v", cluster, err) } - return client.DeriveTokenWithJWT(req) + t, err := client.DeriveTokenWithJWT(req) + if err == nil { + err = client.TokenPreflightCheck(h.shutdownCtx, t) + } + + return t, err } func (h *consulHook) clientForCluster(cluster string) (consul.Client, error) { @@ -248,6 +259,11 @@ func (h *consulHook) Postrun() error { return nil } +// Shutdown will get called when the client is gracefully stopping. +func (h *consulHook) Shutdown() { + h.shutdownCancelFn() +} + // Destroy cleans up any remaining Consul tokens if the alloc is GC'd or fails // to restore after a client restart. func (h *consulHook) Destroy() error { diff --git a/client/allocrunner/taskrunner/sids_hook.go b/client/allocrunner/taskrunner/sids_hook.go index 7d5e780a045..99b47a09b73 100644 --- a/client/allocrunner/taskrunner/sids_hook.go +++ b/client/allocrunner/taskrunner/sids_hook.go @@ -151,7 +151,8 @@ func (h *sidsHook) Prestart( } } - // need to ask for a new SI token & persist it to disk + // COMPAT(1.9): this code path exists only to support the legacy (non-WI) + // workflow. remove for 1.9.0. if token == "" { if token, err = h.deriveSIToken(ctx); err != nil { return err @@ -255,7 +256,7 @@ func (h *sidsHook) kill(ctx context.Context, reason error) { func (h *sidsHook) tryDerive(ctx context.Context, ch chan<- siDerivationResult) { for attempt := 0; backoff(ctx, attempt); attempt++ { - tokens, err := h.sidsClient.DeriveSITokens(h.alloc, []string{h.task.Name}) + tokens, err := h.sidsClient.DeriveSITokens(ctx, h.alloc, []string{h.task.Name}) switch { case err == nil: diff --git a/client/allocrunner/taskrunner/sids_hook_test.go b/client/allocrunner/taskrunner/sids_hook_test.go index c93a303624c..0df13f3cb5d 100644 --- a/client/allocrunner/taskrunner/sids_hook_test.go +++ b/client/allocrunner/taskrunner/sids_hook_test.go @@ -191,7 +191,7 @@ func TestSIDSHook_deriveSIToken_timeout(t *testing.T) { r := require.New(t) siClient := consulclient.NewMockServiceIdentitiesClient() - siClient.DeriveTokenFn = func(allocation *structs.Allocation, strings []string) (m map[string]string, err error) { + siClient.DeriveTokenFn = func(context.Context, *structs.Allocation, []string) (m map[string]string, err error) { select { // block forever, hopefully triggering a timeout in the caller } @@ -288,7 +288,7 @@ func TestTaskRunner_DeriveSIToken_UnWritableTokenFile(t *testing.T) { trConfig.ClientConfig.GetDefaultConsul().Token = uuid.Generate() // derive token works just fine - deriveFn := func(*structs.Allocation, []string) (map[string]string, error) { + deriveFn := func(context.Context, *structs.Allocation, []string) (map[string]string, error) { return map[string]string{task.Name: uuid.Generate()}, nil } siClient := trConfig.ConsulSI.(*consulclient.MockServiceIdentitiesClient) diff --git a/client/allocrunner/taskrunner/task_runner_test.go b/client/allocrunner/taskrunner/task_runner_test.go index 7e820d9cc41..73061acb5fb 100644 --- a/client/allocrunner/taskrunner/task_runner_test.go +++ b/client/allocrunner/taskrunner/task_runner_test.go @@ -1466,7 +1466,7 @@ func TestTaskRunner_BlockForSIDSToken(t *testing.T) { // control when we get a Consul SI token token := uuid.Generate() waitCh := make(chan struct{}) - deriveFn := func(*structs.Allocation, []string) (map[string]string, error) { + deriveFn := func(context.Context, *structs.Allocation, []string) (map[string]string, error) { <-waitCh return map[string]string{task.Name: token}, nil } @@ -1530,7 +1530,7 @@ func TestTaskRunner_DeriveSIToken_Retry(t *testing.T) { // control when we get a Consul SI token (recoverable failure on first call) token := uuid.Generate() deriveCount := 0 - deriveFn := func(*structs.Allocation, []string) (map[string]string, error) { + deriveFn := func(context.Context, *structs.Allocation, []string) (map[string]string, error) { if deriveCount > 0 { return map[string]string{task.Name: token}, nil diff --git a/client/client.go b/client/client.go index e1fdc98fcf0..8a4e8fb2af3 100644 --- a/client/client.go +++ b/client/client.go @@ -4,6 +4,7 @@ package client import ( + "context" "errors" "fmt" "maps" @@ -28,7 +29,7 @@ import ( "github.com/hashicorp/nomad/client/allocrunner/taskrunner/getter" "github.com/hashicorp/nomad/client/allocwatcher" "github.com/hashicorp/nomad/client/config" - consulApi "github.com/hashicorp/nomad/client/consul" + consulApiShim "github.com/hashicorp/nomad/client/consul" "github.com/hashicorp/nomad/client/devicemanager" "github.com/hashicorp/nomad/client/dynamicplugins" "github.com/hashicorp/nomad/client/fingerprint" @@ -232,7 +233,7 @@ type Client struct { // consulProxiesFunc gets an interface to Nomad's custom Consul client for // looking up supported envoy versions - consulProxiesFunc consulApi.SupportedProxiesAPIFunc + consulProxiesFunc consulApiShim.SupportedProxiesAPIFunc // consulCatalog is the subset of Consul's Catalog API Nomad uses for self // service discovery @@ -256,7 +257,7 @@ type Client struct { // tokensClient is Nomad Client's custom Consul client for requesting Consul // Service Identity tokens through Nomad Server. - tokensClient consulApi.ServiceIdentityAPI + tokensClient consulApiShim.ServiceIdentityAPI // vaultClients is used to interact with Vault for token and secret renewals vaultClients map[string]vaultclient.VaultClient @@ -348,7 +349,7 @@ var ( // registered via https://golang.org/pkg/net/rpc/#Server.RegisterName in place // of the client's normal RPC handlers. This allows server tests to override // the behavior of the client. -func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxiesFunc consulApi.SupportedProxiesAPIFunc, consulServices serviceregistration.Handler, rpcs map[string]interface{}) (*Client, error) { +func NewClient(cfg *config.Config, consulCatalog consul.CatalogAPI, consulProxiesFunc consulApiShim.SupportedProxiesAPIFunc, consulServices serviceregistration.Handler, rpcs map[string]interface{}) (*Client, error) { // Create the tls wrapper var tlsWrap tlsutil.RegionWrapper if cfg.TLSConfig.EnableRPC { @@ -2813,7 +2814,7 @@ func (c *Client) newAllocRunnerConfig( // identity tokens. // DEPRECATED: remove in 1.9.0 func (c *Client) setupConsulTokenClient() error { - tc := consulApi.NewIdentitiesClient(c.logger, c.deriveSIToken) + tc := consulApiShim.NewIdentitiesClient(c.logger, c.deriveSIToken) c.tokensClient = tc return nil } @@ -2960,7 +2961,7 @@ func (c *Client) deriveToken(alloc *structs.Allocation, taskNames []string, vcli // deriveSIToken takes an allocation and a set of tasks and derives Consul // Service Identity tokens for each of the tasks by requesting them from the // Nomad Server. -func (c *Client) deriveSIToken(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { +func (c *Client) deriveSIToken(ctx context.Context, alloc *structs.Allocation, taskNames []string) (map[string]string, error) { tasks, err := verifiedTasks(c.logger, alloc, taskNames) if err != nil { return nil, err @@ -3001,7 +3002,36 @@ func (c *Client) deriveSIToken(alloc *structs.Allocation, taskNames []string) (m // https://www.consul.io/api/acl/tokens.html#read-a-token // https://www.consul.io/docs/internals/security.html + consulConfigs := c.config.GetConsulConfigs(c.logger) + consulClientConstructor := consulApiShim.NewConsulClientFactory(c.Node()) + + tg := alloc.Job.LookupTaskGroup(alloc.TaskGroup) + tgNs := tg.Consul.GetNamespace() + + for task, secretID := range resp.Tokens { + t := tg.LookupTask(task) + ns := t.Consul.GetNamespace() + if ns == "" { + ns = tgNs + } + cluster := tg.LookupTask(task).GetConsulClusterName(tg) + consulConfig := consulConfigs[cluster] + consulClient, err := consulClientConstructor(consulConfig, c.logger) + if err != nil { + return nil, err + } + + err = consulClient.TokenPreflightCheck(ctx, &consulapi.ACLToken{ + Namespace: ns, + SecretID: secretID, + }) + if err != nil { + return nil, err + } + } + m := maps.Clone(resp.Tokens) + return m, nil } diff --git a/client/consul/consul.go b/client/consul/consul.go index 26096319cef..1222ee802ae 100644 --- a/client/consul/consul.go +++ b/client/consul/consul.go @@ -4,20 +4,24 @@ package consul import ( + "context" "fmt" + "time" consulapi "github.com/hashicorp/consul/api" "github.com/hashicorp/go-hclog" "github.com/hashicorp/go-multierror" + "github.com/hashicorp/nomad/helper" "github.com/hashicorp/nomad/helper/useragent" "github.com/hashicorp/nomad/nomad/structs" "github.com/hashicorp/nomad/nomad/structs/config" ) -// TokenDeriverFunc takes an allocation and a set of tasks and derives a -// service identity token for each. Requests go through nomad server. -type TokenDeriverFunc func(*structs.Allocation, []string) (map[string]string, error) +// TokenDeriverFunc takes an allocation and a set of tasks and derives a service +// identity token for each. Requests go through nomad server and the local +// Consul agent. +type TokenDeriverFunc func(context.Context, *structs.Allocation, []string) (map[string]string, error) // ServiceIdentityAPI is the interface the Nomad Client uses to request Consul // Service Identity tokens through Nomad Server. @@ -27,7 +31,7 @@ type TokenDeriverFunc func(*structs.Allocation, []string) (map[string]string, er type ServiceIdentityAPI interface { // DeriveSITokens contacts the nomad server and requests consul service // identity tokens be generated for tasks in the allocation. - DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) + DeriveSITokens(ctx context.Context, alloc *structs.Allocation, tasks []string) (map[string]string, error) } // SupportedProxiesAPI is the interface the Nomad Client uses to request from @@ -57,6 +61,10 @@ type Client interface { DeriveTokenWithJWT(JWTLoginRequest) (*consulapi.ACLToken, error) RevokeTokens([]*consulapi.ACLToken) error + + // TokenPreflightCheck verifies that a token has been replicated before we + // try to use it for registering services or bootstrapping Envoy + TokenPreflightCheck(context.Context, *consulapi.ACLToken) error } type consulClient struct { @@ -70,6 +78,12 @@ type consulClient struct { config *config.ConsulConfig logger hclog.Logger + + // preflightCheckTimeout/BaseInterval control how long the client will wait + // for Consul ACLs tokens to be fully replicated before giving up on the + // allocation; these are configurable via node metadata + preflightCheckTimeout time.Duration + preflightCheckBaseInterval time.Duration } // ConsulClientFunc creates a new Consul client for the specific Consul config @@ -78,7 +92,14 @@ type ConsulClientFunc func(config *config.ConsulConfig, logger hclog.Logger) (Cl // NewConsulClientFactory returns a ConsulClientFunc that closes over the // partition func NewConsulClientFactory(node *structs.Node) ConsulClientFunc { + + // these node values will be evaluated at the time we create the hooks, so + // we don't need to worry about them changing out from under us partition := node.Attributes["consul.partition"] + preflightCheckTimeout := durationFromMeta( + node, "consul.token_preflight_check.timeout", time.Second*10) + preflightCheckBaseInterval := durationFromMeta( + node, "consul.token_preflight_check.base", time.Millisecond*500) return func(config *config.ConsulConfig, logger hclog.Logger) (Client, error) { if config == nil { @@ -88,9 +109,11 @@ func NewConsulClientFactory(node *structs.Node) ConsulClientFunc { logger = logger.Named("consul").With("name", config.Name) c := &consulClient{ - config: config, - logger: logger, - partition: partition, + config: config, + logger: logger, + partition: partition, + preflightCheckTimeout: preflightCheckTimeout, + preflightCheckBaseInterval: preflightCheckBaseInterval, } // Get the Consul API configuration @@ -115,6 +138,18 @@ func NewConsulClientFactory(node *structs.Node) ConsulClientFunc { } } +func durationFromMeta(node *structs.Node, key string, defaultDur time.Duration) time.Duration { + val := node.Meta[key] + if key == "" { + return defaultDur + } + d, err := time.ParseDuration(val) + if err != nil || d == 0 { + return defaultDur + } + return d +} + // DeriveTokenWithJWT takes a JWT from request and returns a consul token. func (c *consulClient) DeriveTokenWithJWT(req JWTLoginRequest) (*consulapi.ACLToken, error) { t, _, err := c.client.ACL().Login(&consulapi.ACLLoginParams{ @@ -141,3 +176,40 @@ func (c *consulClient) RevokeTokens(tokens []*consulapi.ACLToken) error { return mErr.ErrorOrNil() } + +// TokenPreflightCheck verifies that a token has been replicated before we +// try to use it for registering services or bootstrapping Envoy +func (c *consulClient) TokenPreflightCheck(pctx context.Context, t *consulapi.ACLToken) error { + timer, timerStop := helper.NewStoppedTimer() + defer timerStop() + + var retry uint64 + var err error + ctx, cancel := context.WithTimeout(pctx, c.preflightCheckTimeout) + defer cancel() + + for { + _, _, err = c.client.ACL().TokenReadSelf(&consulapi.QueryOptions{ + Namespace: t.Namespace, + Partition: c.partition, + AllowStale: true, + Token: t.SecretID, + }) + if err == nil { + return nil + } + + retry++ + backoff := helper.Backoff( + c.preflightCheckBaseInterval, c.preflightCheckBaseInterval*2, retry) + c.logger.Trace("waiting for Consul stale query on token", + "error", err, "backoff", backoff) + timer.Reset(backoff) + select { + case <-ctx.Done(): + return err + case <-timer.C: + continue + } + } +} diff --git a/client/consul/consul_test.go b/client/consul/consul_test.go new file mode 100644 index 00000000000..479c6fdb003 --- /dev/null +++ b/client/consul/consul_test.go @@ -0,0 +1,122 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package consul + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + consulapi "github.com/hashicorp/consul/api" + "github.com/hashicorp/nomad/helper/testlog" + "github.com/hashicorp/nomad/helper/uuid" + "github.com/hashicorp/nomad/nomad/mock" + "github.com/hashicorp/nomad/nomad/structs/config" + "github.com/shoenig/test/must" +) + +type mockConsulServer struct { + httpSrv *httptest.Server + + lock sync.RWMutex + errorCodeOnTokenSelf int + countTokenSelf int +} + +func (m *mockConsulServer) resetTokenSelf(errNo int) { + m.lock.Lock() + defer m.lock.Unlock() + m.countTokenSelf = 0 + m.errorCodeOnTokenSelf = errNo +} + +func newMockConsulServer() *mockConsulServer { + + srv := &mockConsulServer{} + + mux := http.NewServeMux() + mux.HandleFunc("/v1/acl/token/self", func(w http.ResponseWriter, r *http.Request) { + + srv.lock.RLock() + defer srv.lock.RUnlock() + srv.countTokenSelf++ + + if srv.errorCodeOnTokenSelf == 0 { + secretID := r.Header.Get("X-Consul-Token") + token := &consulapi.ACLToken{ + SecretID: secretID, + } + buf, _ := json.Marshal(token) + fmt.Fprintf(w, string(buf)) + return + } + + w.WriteHeader(srv.errorCodeOnTokenSelf) + fmt.Fprintf(w, "{}") + }) + + srv.httpSrv = httptest.NewServer(mux) + return srv +} + +// TestConsul_TokenPreflightCheck verifies the retry logic for +func TestConsul_TokenPreflightCheck(t *testing.T) { + + consulSrv := newMockConsulServer() + consulSrv.resetTokenSelf(404) + + node := mock.Node() + node.Meta["consul.token_preflight_check.timeout"] = "100ms" + node.Meta["consul.token_preflight_check.base"] = "10ms" + factory := NewConsulClientFactory(node) + + cfg := &config.ConsulConfig{ + Addr: consulSrv.httpSrv.URL, + } + client, err := factory(cfg, testlog.HCLogger(t)) + must.NoError(t, err) + + token := &consulapi.ACLToken{ + SecretID: uuid.Generate(), + Namespace: "foo", + } + + preflightErrorCh := make(chan error) + + ctx1, cancel1 := context.WithTimeout(context.TODO(), time.Second*5) + defer cancel1() + + go func() { + preflightErrorCh <- client.TokenPreflightCheck(ctx1, token) + }() + + select { + case <-ctx1.Done(): + t.Fatal("test timed out before check timed out") + case err := <-preflightErrorCh: + must.EqError(t, err, "Unexpected response code: 404 ({})") + must.GreaterEq(t, 5, consulSrv.countTokenSelf) + } + + consulSrv.resetTokenSelf(0) + ctx2, cancel2 := context.WithTimeout(context.TODO(), time.Second*5) + defer cancel2() + + go func() { + preflightErrorCh <- client.TokenPreflightCheck(ctx2, token) + }() + + select { + case <-ctx2.Done(): + t.Fatal("test timed out and check should not have timed out") + case err := <-preflightErrorCh: + must.NoError(t, err, must.Sprintf("preflight should pass: %v", err)) + must.Eq(t, 1, consulSrv.countTokenSelf) + } +} diff --git a/client/consul/consul_testing.go b/client/consul/consul_testing.go index b0e281c5839..4ea9403366b 100644 --- a/client/consul/consul_testing.go +++ b/client/consul/consul_testing.go @@ -4,6 +4,7 @@ package consul import ( + "context" "crypto/md5" "encoding/hex" @@ -48,3 +49,7 @@ func (mc *MockConsulClient) RevokeTokens(tokens []*consulapi.ACLToken) error { } return nil } + +func (mc *MockConsulClient) TokenPreflightCheck(_ context.Context, _ *consulapi.ACLToken) error { + return nil +} diff --git a/client/consul/identities.go b/client/consul/identities.go index e7cf3669456..47125daae22 100644 --- a/client/consul/identities.go +++ b/client/consul/identities.go @@ -4,6 +4,8 @@ package consul import ( + "context" + "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/nomad/structs" ) @@ -25,8 +27,8 @@ func NewIdentitiesClient(logger hclog.Logger, tokenDeriver TokenDeriverFunc) *id } } -func (c *identitiesClient) DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) { - tokens, err := c.tokenDeriver(alloc, tasks) +func (c *identitiesClient) DeriveSITokens(ctx context.Context, alloc *structs.Allocation, tasks []string) (map[string]string, error) { + tokens, err := c.tokenDeriver(ctx, alloc, tasks) if err != nil { c.logger.Error("error deriving SI token", "error", err, "alloc_id", alloc.ID, "task_names", tasks) return nil, err diff --git a/client/consul/identities_test.go b/client/consul/identities_test.go index e307efc3214..3b17087d788 100644 --- a/client/consul/identities_test.go +++ b/client/consul/identities_test.go @@ -4,36 +4,37 @@ package consul import ( + "context" "errors" "testing" "github.com/hashicorp/nomad/ci" "github.com/hashicorp/nomad/helper/testlog" "github.com/hashicorp/nomad/nomad/structs" - "github.com/stretchr/testify/require" + "github.com/shoenig/test/must" ) func TestSI_DeriveTokens(t *testing.T) { ci.Parallel(t) logger := testlog.HCLogger(t) - dFunc := func(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + dFunc := func(context.Context, *structs.Allocation, []string) (map[string]string, error) { return map[string]string{"a": "b"}, nil } tc := NewIdentitiesClient(logger, dFunc) - tokens, err := tc.DeriveSITokens(nil, nil) - require.NoError(t, err) - require.Equal(t, map[string]string{"a": "b"}, tokens) + tokens, err := tc.DeriveSITokens(context.TODO(), nil, nil) + must.NoError(t, err) + must.Eq(t, map[string]string{"a": "b"}, tokens) } func TestSI_DeriveTokens_error(t *testing.T) { ci.Parallel(t) logger := testlog.HCLogger(t) - dFunc := func(alloc *structs.Allocation, taskNames []string) (map[string]string, error) { + dFunc := func(context.Context, *structs.Allocation, []string) (map[string]string, error) { return nil, errors.New("some failure") } tc := NewIdentitiesClient(logger, dFunc) - _, err := tc.DeriveSITokens(&structs.Allocation{ID: "a1"}, nil) - require.Error(t, err) + _, err := tc.DeriveSITokens(context.TODO(), &structs.Allocation{ID: "a1"}, nil) + must.Error(t, err) } diff --git a/client/consul/identities_testing.go b/client/consul/identities_testing.go index b2a3e51a2b4..05eabb0d8f5 100644 --- a/client/consul/identities_testing.go +++ b/client/consul/identities_testing.go @@ -4,6 +4,7 @@ package consul import ( + "context" "sync" "github.com/hashicorp/nomad/helper/uuid" @@ -35,13 +36,13 @@ func NewMockServiceIdentitiesClient() *MockServiceIdentitiesClient { } } -func (mtc *MockServiceIdentitiesClient) DeriveSITokens(alloc *structs.Allocation, tasks []string) (map[string]string, error) { +func (mtc *MockServiceIdentitiesClient) DeriveSITokens(ctx context.Context, alloc *structs.Allocation, tasks []string) (map[string]string, error) { mtc.lock.Lock() defer mtc.lock.Unlock() // if the DeriveTokenFn is explicitly set, use that if mtc.DeriveTokenFn != nil { - return mtc.DeriveTokenFn(alloc, tasks) + return mtc.DeriveTokenFn(ctx, alloc, tasks) } // generate a token for each task, unless the mock has an error ready for