Skip to content

Commit

Permalink
Use a ten second timeout for join requests
Browse files Browse the repository at this point in the history
This prevents k0s from hanging on idle network connections, such as
random glitches, badly behaving load balancers, and so on.

Add a context to the JoinClient methods and call them with a context
that times out after 10 seconds when joining a controller. This has
the side effect that joining will also be interruptible by SIGTERM and
Ctrl+C.

Signed-off-by: Tom Wieczorek <twieczorek@mirantis.com>
  • Loading branch information
twz123 committed Jun 13, 2024
1 parent 64ea131 commit dfb5031
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 31 deletions.
27 changes: 20 additions & 7 deletions cmd/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,16 +691,29 @@ func joinController(ctx context.Context, tokenArg string, certRootDir string) (*
return nil, fmt.Errorf("wrong token type %s, expected type: controller-bootstrap", joinClient.JoinTokenType())
}

logrus.Info("Joining existing cluster via ", joinClient.Address())

var caData v1beta1.CaResponse
err = retry.Do(func() error {
caData, err = joinClient.GetCA()
retryErr := retry.Do(
func() error {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
caData, err = joinClient.GetCA(ctx)
return err
},
retry.Context(ctx),
retry.LastErrorOnly(true),
retry.OnRetry(func(attempt uint, err error) {
logrus.WithError(err).Debug("Failed to join in attempt #", attempt+1, ", retrying after backoff")
}),
)
if retryErr != nil {
if err != nil {
return fmt.Errorf("failed to sync CA: %w", err)
retryErr = err
}
return nil
}, retry.Context(ctx))
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to join existing cluster via %s: %w", joinClient.Address(), retryErr)
}

logrus.Info("Got valid CA response, storing certificates")
return joinClient, writeCerts(caData, certRootDir)
}
33 changes: 22 additions & 11 deletions pkg/component/controller/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"strings"
"time"

"github.com/avast/retry-go"
"github.com/sirupsen/logrus"
"go.etcd.io/etcd/client/pkg/v3/tlsutil"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -93,19 +94,29 @@ func (e *Etcd) Init(_ context.Context) error {
return assets.Stage(e.K0sVars.BinDir, "etcd", constant.BinDirMode)
}

func (e *Etcd) syncEtcdConfig(peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) {
func (e *Etcd) syncEtcdConfig(ctx context.Context, peerURL, etcdCaCert, etcdCaCertKey string) ([]string, error) {
logrus.Info("Synchronizing etcd config with existing cluster via ", e.JoinClient.Address())

var etcdResponse v1beta1.EtcdResponse
var err error
for i := 0; i < 20; i++ {
logrus.Debugf("trying to sync etcd config")
etcdResponse, err = e.JoinClient.JoinEtcd(peerURL)
if err == nil {
break
retryErr := retry.Do(
func() error {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
etcdResponse, err = e.JoinClient.JoinEtcd(ctx, peerURL)
return err
},
retry.Context(ctx),
retry.LastErrorOnly(true),
retry.OnRetry(func(attempt uint, err error) {
logrus.WithError(err).Debug("Failed to synchronize etcd config in attempt #", attempt+1, ", retrying after backoff")
}),
)
if retryErr != nil {
if err != nil {
retryErr = err
}
time.Sleep(500 * time.Millisecond)
}
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to synchronize etcd config with existing cluster via %s: %w", e.JoinClient.Address(), retryErr)
}

logrus.Debugf("got cluster info: %v", etcdResponse.InitialCluster)
Expand Down Expand Up @@ -179,7 +190,7 @@ func (e *Etcd) Start(ctx context.Context) error {
if file.Exists(filepath.Join(e.K0sVars.EtcdDataDir, "member", "snap", "db")) {
logrus.Warnf("etcd db file(s) already exist, not gonna run join process")
} else if e.JoinClient != nil {
initialCluster, err := e.syncEtcdConfig(peerURL, etcdCaCert, etcdCaCertKey)
initialCluster, err := e.syncEtcdConfig(ctx, peerURL, etcdCaCert, etcdCaCertKey)
if err != nil {
return fmt.Errorf("failed to sync etcd config: %w", err)
}
Expand Down
28 changes: 15 additions & 13 deletions pkg/token/joinclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package token

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
Expand All @@ -27,7 +28,6 @@ import (
"os"

"github.com/k0sproject/k0s/pkg/apis/k0s/v1beta1"
"github.com/sirupsen/logrus"
"k8s.io/client-go/tools/clientcmd"
)

Expand Down Expand Up @@ -74,16 +74,23 @@ func JoinClientFromToken(encodedToken string) (*JoinClient, error) {
c.joinAddress = config.Host
c.joinTokenType = GetTokenType(&raw)

logrus.Info("initialized join client successfully")
return c, nil
}

func (j *JoinClient) Address() string {
return j.joinAddress
}

func (j *JoinClient) JoinTokenType() string {
return j.joinTokenType
}

// GetCA calls the CA sync API
func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) {
func (j *JoinClient) GetCA(ctx context.Context) (v1beta1.CaResponse, error) {
var caData v1beta1.CaResponse
req, err := http.NewRequest(http.MethodGet, j.joinAddress+"/v1beta1/ca", nil)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, j.joinAddress+"/v1beta1/ca", nil)
if err != nil {
return caData, err
return caData, fmt.Errorf("failed to create join request: %w", err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken))

Expand All @@ -96,7 +103,6 @@ func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) {
if resp.StatusCode != http.StatusOK {
return caData, fmt.Errorf("unexpected response status: %s", resp.Status)
}
logrus.Info("got valid CA response")
b, err := io.ReadAll(resp.Body)
if err != nil {
return caData, err
Expand All @@ -109,7 +115,7 @@ func (j *JoinClient) GetCA() (v1beta1.CaResponse, error) {
}

// JoinEtcd calls the etcd join API
func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error) {
func (j *JoinClient) JoinEtcd(ctx context.Context, peerAddress string) (v1beta1.EtcdResponse, error) {
var etcdResponse v1beta1.EtcdResponse
etcdRequest := v1beta1.EtcdRequest{
PeerAddress: peerAddress,
Expand All @@ -125,9 +131,9 @@ func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error)
return etcdResponse, err
}

req, err := http.NewRequest(http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, j.joinAddress+"/v1beta1/etcd/members", buf)
if err != nil {
return etcdResponse, err
return etcdResponse, fmt.Errorf("failed to create join request: %w", err)
}
req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", j.bearerToken))
resp, err := j.httpClient.Do(req)
Expand All @@ -150,7 +156,3 @@ func (j *JoinClient) JoinEtcd(peerAddress string) (v1beta1.EtcdResponse, error)

return etcdResponse, nil
}

func (j *JoinClient) JoinTokenType() string {
return j.joinTokenType
}
100 changes: 100 additions & 0 deletions pkg/token/joinclient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
Copyright 2024 k0s authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package token_test

import (
"bytes"
"context"
"net"
"net/http"
"net/url"
"testing"

"github.com/k0sproject/k0s/pkg/token"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestJoinClient_Cancellation(t *testing.T) {
t.Parallel()

for _, test := range []struct {
name string
funcUnderTest func(context.Context, *token.JoinClient) error
}{
{"GetCA", func(ctx context.Context, c *token.JoinClient) error {
_, err := c.GetCA(ctx)
return err
}},
{"JoinEtcd", func(ctx context.Context, c *token.JoinClient) error {
_, err := c.JoinEtcd(ctx, "")
return err
}},
} {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()

clientContext, cancelClientContext := context.WithCancelCause(context.Background())
joinURL := startFakeJoinServer(t, func(_ http.ResponseWriter, req *http.Request) {
cancelClientContext(assert.AnError) // cancel the client's context
<-req.Context().Done() // block forever
})

kubeconfig, err := token.GenerateKubeconfig(joinURL.String(), nil, "", "")
require.NoError(t, err)
tok, err := token.JoinEncode(bytes.NewReader(kubeconfig))
require.NoError(t, err)

underTest, err := token.JoinClientFromToken(tok)
require.NoError(t, err)

err = test.funcUnderTest(clientContext, underTest)
assert.ErrorIs(t, err, context.Canceled, "Expected the call to be cancelled")
assert.Same(t, context.Cause(clientContext), assert.AnError, "Didn't receive an HTTP request")
})
}
}

func startFakeJoinServer(t *testing.T, handler func(http.ResponseWriter, *http.Request)) *url.URL {
requestCtx, cancelRequests := context.WithCancel(context.Background())

listener, err := net.Listen("tcp", "localhost:0")
if err != nil {
require.NoError(t, err)
}

server := &http.Server{
Addr: listener.Addr().String(),
Handler: http.HandlerFunc(handler),
BaseContext: func(net.Listener) context.Context { return requestCtx },
}

serverError := make(chan error)
go func() { defer close(serverError); serverError <- server.Serve(listener) }()

t.Cleanup(func() {
cancelRequests()
if !assert.NoError(t, server.Shutdown(context.Background()), "Couldn't shutdown HTTP server") {
return
}
assert.ErrorIs(t, <-serverError, http.ErrServerClosed, "HTTP server terminated unexpectedly")
})

return &url.URL{Scheme: "http", Host: server.Addr}
}

0 comments on commit dfb5031

Please sign in to comment.