From dfb5031f758df1f1173374301b06e71a60d7283a Mon Sep 17 00:00:00 2001 From: Tom Wieczorek Date: Thu, 13 Jun 2024 12:47:41 +0200 Subject: [PATCH] Use a ten second timeout for join requests This prevents k0s from hanging on idle network connections, such as random glitches, badly behaving load balancers, and so on. Add a context to the JoinClient methods and call them with a context that times out after 10 seconds when joining a controller. This has the side effect that joining will also be interruptible by SIGTERM and Ctrl+C. Signed-off-by: Tom Wieczorek --- cmd/controller/controller.go | 27 ++++++--- pkg/component/controller/etcd.go | 33 ++++++---- pkg/token/joinclient.go | 28 +++++---- pkg/token/joinclient_test.go | 100 +++++++++++++++++++++++++++++++ 4 files changed, 157 insertions(+), 31 deletions(-) create mode 100644 pkg/token/joinclient_test.go diff --git a/cmd/controller/controller.go b/cmd/controller/controller.go index 88fa1eb2e261..af1834d9dffc 100644 --- a/cmd/controller/controller.go +++ b/cmd/controller/controller.go @@ -691,16 +691,29 @@ func joinController(ctx context.Context, tokenArg string, certRootDir string) (* return nil, fmt.Errorf("wrong token type %s, expected type: controller-bootstrap", joinClient.JoinTokenType()) } + logrus.Info("Joining existing cluster via ", joinClient.Address()) + var caData v1beta1.CaResponse - err = retry.Do(func() error { - caData, err = joinClient.GetCA() + retryErr := retry.Do( + func() error { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + caData, err = joinClient.GetCA(ctx) + return err + }, + retry.Context(ctx), + retry.LastErrorOnly(true), + retry.OnRetry(func(attempt uint, err error) { + logrus.WithError(err).Debug("Failed to join in attempt #", attempt+1, ", retrying after backoff") + }), + ) + if retryErr != nil { if err != nil { - return fmt.Errorf("failed to sync CA: %w", err) + retryErr = err } - return nil - }, retry.Context(ctx)) - if err != nil { - return nil, err + return nil, fmt.Errorf("failed to join existing cluster via %s: %w", joinClient.Address(), retryErr) } + + logrus.Info("Got valid CA response, storing certificates") return joinClient, writeCerts(caData, certRootDir) } diff --git a/pkg/component/controller/etcd.go b/pkg/component/controller/etcd.go index ce98196784d9..60b4eb2ab2b6 100644 --- a/pkg/component/controller/etcd.go +++ b/pkg/component/controller/etcd.go @@ -25,6 +25,7 @@ import ( "strings" "time" + "github.com/avast/retry-go" "github.com/sirupsen/logrus" "go.etcd.io/etcd/client/pkg/v3/tlsutil" "golang.org/x/sync/errgroup" @@ -93,19 +94,29 @@ func (e *Etcd) Init(_ context.Context) error { return assets.Stage(e.K0sVars.BinDir, "etcd", constant.BinDirMode) } -func (e *Etcd) syncEtcdConfig(peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) { +func (e *Etcd) syncEtcdConfig(ctx context.Context, peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) { + logrus.Info("Synchronizing etcd config with existing cluster via ", e.JoinClient.Address()) + var etcdResponse v1beta1.EtcdResponse var err error - for i := 0; i < 20; i++ { - logrus.Debugf("trying to sync etcd config") - etcdResponse, err = e.JoinClient.JoinEtcd(peerURL) - if err == nil { - break + retryErr := retry.Do( + func() error { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + etcdResponse, err = e.JoinClient.JoinEtcd(ctx, peerURL) + return err + }, + retry.Context(ctx), + retry.LastErrorOnly(true), + retry.OnRetry(func(attempt uint, err error) { + logrus.WithError(err).Debug("Failed to synchronize etcd config in attempt #", attempt+1, ", retrying after backoff") + }), + ) + if retryErr != nil { + if err != nil { + retryErr = err } - time.Sleep(500 * time.Millisecond) - } - if err != nil { - return nil, err + return nil, fmt.Errorf("failed to synchronize etcd config with existing cluster via %s: %w", e.JoinClient.Address(), retryErr) } logrus.Debugf("got cluster info: %v", etcdResponse.InitialCluster) @@ -179,7 +190,7 @@ func (e *Etcd) Start(ctx context.Context) error { if file.Exists(filepath.Join(e.K0sVars.EtcdDataDir, "member", "snap", "db")) { logrus.Warnf("etcd db file(s) already exist, not gonna run join process") } else if e.JoinClient != nil { - initialCluster, err := e.syncEtcdConfig(peerURL, etcdCaCert, etcdCaCertKey) + initialCluster, err := e.syncEtcdConfig(ctx, peerURL, etcdCaCert, etcdCaCertKey) if err != nil { return fmt.Errorf("failed to sync etcd config: %w", err) } diff --git a/pkg/token/joinclient.go b/pkg/token/joinclient.go index 9883a24fbf99..b8e4f3dc03cf 100644 --- a/pkg/token/joinclient.go +++ b/pkg/token/joinclient.go @@ -18,6 +18,7 @@ package token import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/json" @@ -27,7 +28,6 @@ import ( "os" "github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1" - "github.com/sirupsen/logrus" "k8s.io/client-go/tools/clientcmd" ) @@ -74,16 +74,23 @@ func JoinClientFromToken(encodedToken string) (*JoinClient, error) { c.joinAddress = config.Host c.joinTokenType = GetTokenType(&raw) - logrus.Info("initialized join client successfully") return c, nil } +func (j *JoinClient) Address() string { + return j.joinAddress +} + +func (j *JoinClient) JoinTokenType() string { + return j.joinTokenType +} + // GetCA calls the CA sync API -func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) { +func (j *JoinClient) GetCA(ctx context.Context) (v1beta1.CaResponse, error) { var caData v1beta1.CaResponse - req, err := http.NewRequest(http.MethodGet, j.joinAddress+"/v1beta1/ca", nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.joinAddress+"/v1beta1/ca", nil) if err != nil { - return caData, err + return caData, fmt.Errorf("failed to create join request: %w", err) } req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken)) @@ -96,7 +103,6 @@ func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) { if resp.StatusCode != http.StatusOK { return caData, fmt.Errorf("unexpected response status: %s", resp.Status) } - logrus.Info("got valid CA response") b, err := io.ReadAll(resp.Body) if err != nil { return caData, err @@ -109,7 +115,7 @@ func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) { } // JoinEtcd calls the etcd join API -func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error) { +func (j *JoinClient) JoinEtcd(ctx context.Context, peerAddress string) (v1beta1.EtcdResponse, error) { var etcdResponse v1beta1.EtcdResponse etcdRequest := v1beta1.EtcdRequest{ PeerAddress: peerAddress, @@ -125,9 +131,9 @@ func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error) return etcdResponse, err } - req, err := http.NewRequest(http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf) if err != nil { - return etcdResponse, err + return etcdResponse, fmt.Errorf("failed to create join request: %w", err) } req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken)) resp, err := j.httpClient.Do(req) @@ -150,7 +156,3 @@ func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error) return etcdResponse, nil } - -func (j *JoinClient) JoinTokenType() string { - return j.joinTokenType -} diff --git a/pkg/token/joinclient_test.go b/pkg/token/joinclient_test.go new file mode 100644 index 000000000000..19e542c71e6b --- /dev/null +++ b/pkg/token/joinclient_test.go @@ -0,0 +1,100 @@ +/* +Copyright 2024 k0s authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package token_test + +import ( + "bytes" + "context" + "net" + "net/http" + "net/url" + "testing" + + "github.com/k0sproject/k0s/pkg/token" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJoinClient_Cancellation(t *testing.T) { + t.Parallel() + + for _, test := range []struct { + name string + funcUnderTest func(context.Context, *token.JoinClient) error + }{ + {"GetCA", func(ctx context.Context, c *token.JoinClient) error { + _, err := c.GetCA(ctx) + return err + }}, + {"JoinEtcd", func(ctx context.Context, c *token.JoinClient) error { + _, err := c.JoinEtcd(ctx, "") + return err + }}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + clientContext, cancelClientContext := context.WithCancelCause(context.Background()) + joinURL := startFakeJoinServer(t, func(_ http.ResponseWriter, req *http.Request) { + cancelClientContext(assert.AnError) // cancel the client's context + <-req.Context().Done() // block forever + }) + + kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), nil, "", "") + require.NoError(t, err) + tok, err := token.JoinEncode(bytes.NewReader(kubeconfig)) + require.NoError(t, err) + + underTest, err := token.JoinClientFromToken(tok) + require.NoError(t, err) + + err = test.funcUnderTest(clientContext, underTest) + assert.ErrorIs(t, err, context.Canceled, "Expected the call to be cancelled") + assert.Same(t, context.Cause(clientContext), assert.AnError, "Didn't receive an HTTP request") + }) + } +} + +func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *url.URL { + requestCtx, cancelRequests := context.WithCancel(context.Background()) + + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + require.NoError(t, err) + } + + server := &http.Server{ + Addr: listener.Addr().String(), + Handler: http.HandlerFunc(handler), + BaseContext: func(net.Listener) context.Context { return requestCtx }, + } + + serverError := make(chan error) + go func() { defer close(serverError); serverError <- server.Serve(listener) }() + + t.Cleanup(func() { + cancelRequests() + if !assert.NoError(t, server.Shutdown(context.Background()), "Couldn't shutdown HTTP server") { + return + } + assert.ErrorIs(t, <-serverError, http.ErrServerClosed, "HTTP server terminated unexpectedly") + }) + + return &url.URL{Scheme: "http", Host: server.Addr} +}