Skip to content

Commit

Permalink
Refactor load balancer server list and health checking
Browse files Browse the repository at this point in the history
Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
  • Loading branch information
brandond committed Nov 16, 2024
1 parent f3047f0 commit f38a6cf
Show file tree
Hide file tree
Showing 6 changed files with 406 additions and 285 deletions.
8 changes: 4 additions & 4 deletions pkg/agent/loadbalancer/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ type lbConfig struct {

func (lb *LoadBalancer) writeConfig() error {
config := &lbConfig{
ServerURL: lb.serverURL,
ServerAddresses: lb.serverAddresses,
ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(),
ServerAddresses: lb.servers.getAddresses(),
}
configOut, err := json.MarshalIndent(config, "", " ")
if err != nil {
Expand All @@ -30,9 +30,9 @@ func (lb *LoadBalancer) updateConfig() error {
if configBytes, err := os.ReadFile(lb.configFile); err == nil {
config := &lbConfig{}
if err := json.Unmarshal(configBytes, config); err == nil {
if config.ServerURL == lb.serverURL {
if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
writeConfig = false
lb.setServers(config.ServerAddresses)
lb.SetServers(config.ServerAddresses)
}
}
}
Expand Down
194 changes: 81 additions & 113 deletions pkg/agent/loadbalancer/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,29 @@ import (
"errors"
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"sync"
"strings"
"time"

"github.com/inetaf/tcpproxy"
"github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/wait"
)

// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
address string
healthCheck func() bool
connections map[net.Conn]struct{}
}

// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
type serverConn struct {
server *server
net.Conn
}

// LoadBalancer holds data for a local listener which forwards connections to a
// pool of remote servers. It is not a proper load-balancer in that it does not
// actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails.
type LoadBalancer struct {
// This mutex protects access to servers map and randomServers list.
// All direct access to the servers map/list should be protected by it.
mutex sync.RWMutex
proxy *tcpproxy.Proxy

serviceName string
configFile string
localAddress string
localServerURL string
defaultServerAddress string
serverURL string
serverAddresses []string
randomServers []string
servers map[string]*server
currentServerAddress string
nextServerIndex int
serviceName string
configFile string
scheme string
localAddress string
servers serverList
proxy *tcpproxy.Proxy
}

const RandomPort = 0
Expand All @@ -63,7 +40,7 @@ var (

// New constructs a new LoadBalancer instance. The default server URL, and
// currently active servers, are stored in a file within the dataDir.
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
config := net.ListenConfig{Control: reusePort}
var localAddress string
if isIPv6 {
Expand All @@ -84,30 +61,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
return nil, err
}

// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
localAddress = listener.Addr().String()

defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
serverURL, err := url.Parse(defaultServerURL)
if err != nil {
return nil, err
}

if serverURL == localServerURL {
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
defaultServerAddress = ""
// Set explicit port from scheme
if serverURL.Port() == "" {
if strings.ToLower(serverURL.Scheme) == "http" {
serverURL.Host += ":80"
}
if strings.ToLower(serverURL.Scheme) == "https" {
serverURL.Host += ":443"
}
}

lb := &LoadBalancer{
serviceName: serviceName,
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
localAddress: localAddress,
localServerURL: localServerURL,
defaultServerAddress: defaultServerAddress,
servers: make(map[string]*server),
serverURL: serverURL,
serviceName: serviceName,
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
scheme: serverURL.Scheme,
localAddress: listener.Addr().String(),
}

lb.setServers([]string{lb.defaultServerAddress})
// If the initial server URL points at ourselves, don't set a default server address.
// Leaving it unset will cause all dials to fail until servers are added.
if serverURL.Host == lb.localAddress {
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
} else {
lb.servers.setDefaultAddress(serverURL.Host)
}

lb.proxy = &tcpproxy.Proxy{
ListenFunc: func(string, string) (net.Listener, error) {
Expand All @@ -126,92 +108,79 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
if err := lb.proxy.Start(); err != nil {
return nil, err
}
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress)
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())

go lb.runHealthChecks(ctx)

return lb, nil
}

func (lb *LoadBalancer) Update(serverAddresses []string) {
if lb == nil {
return
}
if !lb.setServers(serverAddresses) {
// runHealthChecks periodically health-checks all servers.
func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
wait.Until(func() {
for _, server := range lb.servers.getServers() {
if server.healthCheck() {
lb.servers.recordSuccess(server, false)
} else {
lb.servers.recordFailure(server, false)
}
}
}, time.Second, ctx.Done())
logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
}

// SetServers updates the list of server addresses to contain only the listed servers.
func (lb *LoadBalancer) SetServers(serverAddresses []string) {
if !lb.servers.setAddresses(serverAddresses) {
return
}
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress)

logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress())

if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
}
}

func (lb *LoadBalancer) LoadBalancerServerURL() string {
if lb == nil {
return ""
// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.servers.setDefaultAddress(serverAddress)
logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)

if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
}
return lb.localServerURL
}

func (lb *LoadBalancer) ServerAddresses() []string {
if lb == nil {
return nil
// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
if err := lb.servers.setHealthCheck(address, healthCheck); err != nil {
logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err)
} else {
logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address)
}
return lb.serverAddresses
}

func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()

var allChecksFailed bool
startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress

server := lb.servers[targetServer]
if server == nil || targetServer == "" {
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
} else if allChecksFailed || server.healthCheck() {
dialTime := time.Now()
conn, err := server.dialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
}
logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err)
// Don't close connections to the failed server if we're retrying with health checks ignored.
// We don't want to disrupt active connections if it is unlikely they will have anywhere to go.
if !allChecksFailed {
defer server.closeAll()
}
} else {
logrus.Debugf("Dial health check failed for %s", targetServer)
}
func (lb *LoadBalancer) LocalURL() string {
return lb.scheme + "://" + lb.localAddress
}

newServer, err := lb.nextServer(targetServer)
if err != nil {
return nil, err
}
if targetServer != newServer {
logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer)
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
func (lb *LoadBalancer) ServerAddresses() []string {
return lb.servers.getAddresses()
}

maxIndex := len(lb.randomServers)
if startIndex > maxIndex {
startIndex = maxIndex
}
if lb.nextServerIndex == startIndex {
if allChecksFailed {
return nil, errors.New("all servers failed")
}
logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName)
allChecksFailed = true
func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
for _, server := range lb.servers.getServers() {
dialTime := time.Now()
conn, err := server.dialContext(ctx, network)
if err == nil {
lb.servers.recordSuccess(server, true)
return conn, nil
}
lb.servers.recordFailure(server, true)
logrus.Debugf("Dial error from load balancer %s server %s after %s: %s", lb.serviceName, server.address, time.Now().Sub(dialTime), err)
}
return nil, errors.New("all servers failed")
}

func onDialError(src net.Conn, dstDialErr error) {
Expand All @@ -220,10 +189,9 @@ func onDialError(src net.Conn, dstDialErr error) {
}

// ResetLoadBalancer will delete the local state file for the load balancer on disk
func ResetLoadBalancer(dataDir, serviceName string) error {
func ResetLoadBalancer(dataDir, serviceName string) {
stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
if err := os.Remove(stateFile); err != nil {
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
logrus.Warn(err)
}
return nil
}
20 changes: 6 additions & 14 deletions pkg/agent/loadbalancer/loadbalancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ func Test_UnitFailOver(t *testing.T) {
t.Fatalf("New() failed: %v", err)
}

parsedURL, err := url.Parse(lb.LoadBalancerServerURL())
parsedURL, err := url.Parse(lb.LocalURL())
if err != nil {
t.Fatalf("url.Parse failed: %v", err)
}
localAddress := parsedURL.Host

// add the node as a new server address.
lb.Update([]string{node1Server.address()})
lb.SetServers([]string{node1Server.address()})

// make sure connections go to the node
conn1, err := net.Dial("tcp", localAddress)
Expand Down Expand Up @@ -146,17 +146,15 @@ func Test_UnitFailOver(t *testing.T) {

t.Log("conn1 closed on failure OK")

// make sure connection still goes to the first node - it is failing health checks but so
// is the default endpoint, so it should be tried first with health checks disabled,
// before failing back to the default.
// connections should go to the default now that node 1 has failed
conn2, err := net.Dial("tcp", localAddress)
if err != nil {
t.Fatalf("net.Dial failed: %v", err)

}
if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
} else if result != "default:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}

Expand All @@ -168,7 +166,7 @@ func Test_UnitFailOver(t *testing.T) {

if result, err := ping(conn2); err != nil {
t.Fatalf("ping(conn2) failed: %v", err)
} else if result != "node1:ping" {
} else if result != "default:ping" {
t.Fatalf("Unexpected ping(conn2) result: %v", result)
}

Expand All @@ -191,14 +189,8 @@ func Test_UnitFailOver(t *testing.T) {

t.Log("conn3 tested OK")

if _, err := ping(conn2); err == nil {
t.Fatal("Unexpected successful ping on closed connection conn2")
}

t.Log("conn2 closed on failure OK")

// add the second node as a new server address.
lb.Update([]string{node2Server.address()})
lb.SetServers([]string{node2Server.address()})

// make sure connection now goes to the second node,
// and connections to the default are closed.
Expand Down
Loading

0 comments on commit f38a6cf

Please sign in to comment.