Skip to content

Commit

Permalink
Add health-check support to loadbalancer
Browse files Browse the repository at this point in the history
* Adds support for health-checking loadbalancer servers. If a
  health-check fails when dialing, all existing connections to the
  server will be closed.
* Wires up a remotedialer tunnel connectivity check as the health check
  for supervisor/apiserver connections.
* Wires up a simple ping request to the supervisor port as the health
  check for etcd connections.

Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
  • Loading branch information
brandond committed Mar 21, 2024
1 parent 8aecc26 commit a9157f7
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 58 deletions.
35 changes: 13 additions & 22 deletions pkg/agent/loadbalancer/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ import (

// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
address string
healthCheck func() bool
connections map[net.Conn]struct{}
}

Expand All @@ -31,7 +34,9 @@ type serverConn struct {
// actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails.
type LoadBalancer struct {
mutex sync.Mutex
// This mutex protects access to servers map and randomServers list.
// All direct access to the servers map/list should be protected by it.
mutex sync.RWMutex
proxy *tcpproxy.Proxy

serviceName string
Expand Down Expand Up @@ -123,26 +128,9 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
}
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress)

return lb, nil
}

func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.mutex.Lock()
defer lb.mutex.Unlock()

_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress)
}
// if the new default server doesn't have an entry in the map, add one
if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{connections: make(map[net.Conn]struct{})}
}
go lb.runHealthChecks(ctx)

lb.defaultServerAddress = serverAddress
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
return lb, nil
}

func (lb *LoadBalancer) Update(serverAddresses []string) {
Expand All @@ -166,15 +154,18 @@ func (lb *LoadBalancer) LoadBalancerServerURL() string {
return lb.localServerURL
}

func (lb *LoadBalancer) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()

startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress

server := lb.servers[targetServer]
if server == nil || targetServer == "" {
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
} else {
} else if server.healthCheck() {
conn, err := server.dialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
Expand Down
73 changes: 66 additions & 7 deletions pkg/agent/loadbalancer/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"os"
"strconv"
"time"

"github.com/k3s-io/k3s/pkg/version"
http_dialer "github.com/mwitkow/go-http-dialer"
Expand All @@ -17,6 +18,7 @@ import (

"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apimachinery/pkg/util/wait"
)

var defaultDialer proxy.Dialer = &net.Dialer{}
Expand Down Expand Up @@ -73,7 +75,11 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {

for addedServer := range newAddresses.Difference(curAddresses) {
logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer)
lb.servers[addedServer] = &server{connections: make(map[net.Conn]struct{})}
lb.servers[addedServer] = &server{
address: addedServer,
connections: make(map[net.Conn]struct{}),
healthCheck: func() bool { return true },
}
}

for removedServer := range curAddresses.Difference(newAddresses) {
Expand Down Expand Up @@ -106,8 +112,8 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool {
}

func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
lb.mutex.Lock()
defer lb.mutex.Unlock()
lb.mutex.RLock()
defer lb.mutex.RUnlock()

if len(lb.randomServers) == 0 {
return "", errors.New("No servers in load balancer proxy list")
Expand Down Expand Up @@ -162,10 +168,12 @@ func (s *server) closeAll() {
s.mutex.Lock()
defer s.mutex.Unlock()

logrus.Debugf("Closing %d connections to load balancer server", len(s.connections))
for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close()
if l := len(s.connections); l > 0 {
logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address)
for conn := range s.connections {
// Close the connection in a goroutine so that we don't hold the lock while doing so.
go conn.Close()
}
}
}

Expand All @@ -178,3 +186,54 @@ func (sc *serverConn) Close() error {
delete(sc.server.connections, sc)
return sc.Conn.Close()
}

// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.mutex.Lock()
defer lb.mutex.Unlock()

_, hasOriginalServer := sortServers(lb.ServerAddresses, lb.defaultServerAddress)
// if the old default server is not currently in use, remove it from the server map
if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasOriginalServer {
defer server.closeAll()
delete(lb.servers, lb.defaultServerAddress)
}
// if the new default server doesn't have an entry in the map, add one
if _, ok := lb.servers[serverAddress]; !ok {
lb.servers[serverAddress] = &server{
address: serverAddress,
healthCheck: func() bool { return true },
connections: make(map[net.Conn]struct{}),
}
}

lb.defaultServerAddress = serverAddress
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
}

// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
lb.mutex.Lock()
defer lb.mutex.Unlock()

if server := lb.servers[address]; server != nil {
logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
server.healthCheck = healthCheck
} else {
logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
}
}

// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
// connections closed, to force clients to switch over to a healthy server.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
wait.Until(func() {
lb.mutex.RLock()
defer lb.mutex.RUnlock()
for _, server := range lb.servers {
if !server.healthCheck() {
defer server.closeAll()
}
}
}, time.Second, ctx.Done())
}
18 changes: 17 additions & 1 deletion pkg/agent/proxy/apiproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxy

import (
"context"
"net"
sysnet "net"
"net/url"
"strconv"
Expand All @@ -21,6 +22,7 @@ type Proxy interface {
SupervisorAddresses() []string
APIServerURL() string
IsAPIServerLBEnabled() bool
SetHealthCheck(address string, healthCheck func() bool)
}

// NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If
Expand Down Expand Up @@ -70,6 +72,7 @@ type proxy struct {
apiServerEnabled bool

apiServerURL string
apiServerPort string
supervisorURL string
supervisorPort string
initialSupervisorURL string
Expand All @@ -96,6 +99,18 @@ func (p *proxy) Update(addresses []string) {
p.supervisorAddresses = supervisorAddresses
}

func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) {
if p.supervisorLB != nil {
p.supervisorLB.SetHealthCheck(address, healthCheck)
}

if p.apiServerLB != nil {
host, _, _ := net.SplitHostPort(address)
address = net.JoinHostPort(host, p.apiServerPort)
p.apiServerLB.SetHealthCheck(address, healthCheck)
}
}

func (p *proxy) setSupervisorPort(addresses []string) []string {
var newAddresses []string
for _, address := range addresses {
Expand All @@ -119,7 +134,8 @@ func (p *proxy) SetAPIServerPort(ctx context.Context, port int, isIPv6 bool) err
if err != nil {
return errors.Wrapf(err, "failed to parse server URL %s", p.initialSupervisorURL)
}
u.Host = sysnet.JoinHostPort(u.Hostname(), strconv.Itoa(port))
p.apiServerPort = strconv.Itoa(port)
u.Host = sysnet.JoinHostPort(u.Hostname(), p.apiServerPort)

p.apiServerURL = u.String()
p.apiServerEnabled = true
Expand Down
56 changes: 43 additions & 13 deletions pkg/agent/tunnel/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tunnel
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"os"
Expand Down Expand Up @@ -289,7 +290,9 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
disconnect := map[string]context.CancelFunc{}
for _, address := range proxy.SupervisorAddresses() {
if _, ok := disconnect[address]; !ok {
disconnect[address] = a.connect(ctx, wg, address, tlsConfig)
conn := a.connect(ctx, wg, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
}
}

Expand Down Expand Up @@ -361,7 +364,9 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan
for _, address := range proxy.SupervisorAddresses() {
validEndpoint[address] = true
if _, ok := disconnect[address]; !ok {
disconnect[address] = a.connect(ctx, nil, address, tlsConfig)
conn := a.connect(ctx, nil, address, tlsConfig)
disconnect[address] = conn.cancel
proxy.SetHealthCheck(address, conn.connected)
}
}

Expand Down Expand Up @@ -403,32 +408,54 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo
return false
}

type agentConnection struct {
cancel context.CancelFunc
connected func() bool
}

// connect initiates a connection to the remotedialer server. Incoming dial requests from
// the server will be checked by the authorizer function prior to being fulfilled.
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) context.CancelFunc {
func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection {
wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address)
ws := &websocket.Dialer{
TLSClientConfig: tlsConfig,
}

// Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect.
// If we cannot connect, connected will be set to false when the initial connection attempt fails.
connected := true

once := sync.Once{}
if waitGroup != nil {
waitGroup.Add(1)
}

ctx, cancel := context.WithCancel(rootCtx)
auth := func(proto, address string) bool {
return a.authorized(rootCtx, proto, address)
}

onConnect := func(_ context.Context, _ *remotedialer.Session) error {
connected = true
logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy")
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return nil
}

// Start remotedialer connect loop in a goroutine to ensure a connection to the target server
go func() {
for {
remotedialer.ClientConnect(ctx, wsURL, nil, ws, func(proto, address string) bool {
return a.authorized(rootCtx, proto, address)
}, func(_ context.Context, _ *remotedialer.Session) error {
if waitGroup != nil {
once.Do(waitGroup.Done)
}
return nil
})

// ConnectToProxy blocks until error or context cancellation
err := remotedialer.ConnectToProxy(ctx, wsURL, nil, auth, ws, onConnect)
connected = false
if err != nil && !errors.Is(err, context.Canceled) {
logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconecting...")
// wait between reconnection attempts to avoid hammering the server
time.Sleep(endpointDebounceDelay)
}
// If the context has been cancelled, exit the goroutine instead of retrying
if ctx.Err() != nil {
if waitGroup != nil {
once.Do(waitGroup.Done)
Expand All @@ -438,7 +465,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup
}
}()

return cancel
return agentConnection{
cancel: cancel,
connected: func() bool { return connected },
}
}

// isKubeletPort returns true if the connection is to a reserved TCP port on a loopback address.
Expand Down
2 changes: 1 addition & 1 deletion pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (c *Cluster) Start(ctx context.Context) (<-chan struct{}, error) {
clientURL.Host = clientURL.Hostname() + ":2379"
clientURLs = append(clientURLs, clientURL.String())
}
etcdProxy, err := etcd.NewETCDProxy(ctx, true, c.config.DataDir, clientURLs[0], utilsnet.IsIPv6CIDR(c.config.ServiceIPRanges[0]))
etcdProxy, err := etcd.NewETCDProxy(ctx, c.config.SupervisorPort, c.config.DataDir, clientURLs[0], utilsnet.IsIPv6CIDR(c.config.ServiceIPRanges[0]))
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/cluster/managed.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func (c *Cluster) assignManagedDriver(ctx context.Context) error {
return nil
}

// setupEtcdProxy periodically updates the etcd proxy with the current list of
// setupEtcdProxy starts a goroutine to periodically update the etcd proxy with the current list of
// cluster client URLs, as retrieved from etcd.
func (c *Cluster) setupEtcdProxy(ctx context.Context, etcdProxy etcd.Proxy) {
if c.managedDB == nil {
Expand All @@ -138,15 +138,15 @@ func (c *Cluster) setupEtcdProxy(ctx context.Context, etcdProxy etcd.Proxy) {
for range t.C {
newAddresses, err := c.managedDB.GetMembersClientURLs(ctx)
if err != nil {
logrus.Warnf("failed to get etcd client URLs: %v", err)
logrus.Warnf("Failed to get etcd client URLs: %v", err)
continue
}
// client URLs are a full URI, but the proxy only wants host:port
var hosts []string
for _, address := range newAddresses {
u, err := url.Parse(address)
if err != nil {
logrus.Warnf("failed to parse etcd client URL: %v", err)
logrus.Warnf("Failed to parse etcd client URL: %v", err)
continue
}
hosts = append(hosts, u.Host)
Expand All @@ -162,7 +162,7 @@ func (c *Cluster) deleteNodePasswdSecret(ctx context.Context) {
secretsClient := c.config.Runtime.Core.Core().V1().Secret()
if err := nodepassword.Delete(secretsClient, nodeName); err != nil {
if apierrors.IsNotFound(err) {
logrus.Debugf("node password secret is not found for node %s", nodeName)
logrus.Debugf("Node password secret is not found for node %s", nodeName)
return
}
logrus.Warnf("failed to delete old node password secret: %v", err)
Expand Down
Loading

0 comments on commit a9157f7

Please sign in to comment.