Skip to content

Commit

Permalink
Refactor load balancer server list and health checking
Browse files Browse the repository at this point in the history
Signed-off-by: Brad Davidson <brad.davidson@rancher.com>
  • Loading branch information
brandond committed Dec 6, 2024
1 parent 95797c4 commit 911ee19
Show file tree
Hide file tree
Showing 9 changed files with 864 additions and 516 deletions.
21 changes: 9 additions & 12 deletions pkg/agent/loadbalancer/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ type lbConfig struct {

func (lb *LoadBalancer) writeConfig() error {
config := &lbConfig{
ServerURL: lb.serverURL,
ServerAddresses: lb.serverAddresses,
ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(),
ServerAddresses: lb.servers.getAddresses(),
}
configOut, err := json.MarshalIndent(config, "", " ")
if err != nil {
Expand All @@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error {
}

func (lb *LoadBalancer) updateConfig() error {
writeConfig := true
if configBytes, err := os.ReadFile(lb.configFile); err == nil {
config := &lbConfig{}
if err := json.Unmarshal(configBytes, config); err == nil {
if config.ServerURL == lb.serverURL {
writeConfig = false
lb.setServers(config.ServerAddresses)
// if the default server from the config matches our current default,
// load the rest of the addresses as well.
if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
lb.Update(config.ServerAddresses)
return nil
}
}
}
if writeConfig {
if err := lb.writeConfig(); err != nil {
return err
}
}
return nil
// config didn't exist or used a different default server, write the current config to disk.
return lb.writeConfig()
}
2 changes: 1 addition & 1 deletion pkg/agent/loadbalancer/httpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error {
func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) {
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Create a new HTTP proxy dialer
httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer)))
httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer)))
return httpProxyDialer, nil
} else if proxyURL.Scheme == "socks5" {
// For SOCKS5 proxies, use the proxy package's FromURL
Expand Down
6 changes: 4 additions & 2 deletions pkg/agent/loadbalancer/httpproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@ package loadbalancer

import (
"fmt"
"net"
"os"
"strings"
"testing"

"github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
)

var originalDialer proxy.Dialer
var defaultEnv map[string]string
var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"}

Expand All @@ -19,7 +20,7 @@ func init() {
}

func prepareEnv(env ...string) {
defaultDialer = &net.Dialer{}
originalDialer = defaultDialer
defaultEnv = map[string]string{}
for _, e := range proxyEnvs {
if v, ok := os.LookupEnv(e); ok {
Expand All @@ -34,6 +35,7 @@ func prepareEnv(env ...string) {
}

func restoreEnv() {
defaultDialer = originalDialer
for _, e := range proxyEnvs {
if v, ok := defaultEnv[e]; ok {
os.Setenv(e, v)
Expand Down
174 changes: 55 additions & 119 deletions pkg/agent/loadbalancer/loadbalancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,55 +2,29 @@ package loadbalancer

import (
"context"
"errors"
"fmt"
"net"
"net/url"
"os"
"path/filepath"
"sync"
"time"
"strings"

"github.com/inetaf/tcpproxy"
"github.com/k3s-io/k3s/pkg/version"
"github.com/sirupsen/logrus"
)

// server tracks the connections to a server, so that they can be closed when the server is removed.
type server struct {
// This mutex protects access to the connections map. All direct access to the map should be protected by it.
mutex sync.Mutex
address string
healthCheck func() bool
connections map[net.Conn]struct{}
}

// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
type serverConn struct {
server *server
net.Conn
}

// LoadBalancer holds data for a local listener which forwards connections to a
// pool of remote servers. It is not a proper load-balancer in that it does not
// actually balance connections, but instead fails over to a new server only
// when a connection attempt to the currently selected server fails.
type LoadBalancer struct {
// This mutex protects access to servers map and randomServers list.
// All direct access to the servers map/list should be protected by it.
mutex sync.RWMutex
proxy *tcpproxy.Proxy

serviceName string
configFile string
localAddress string
localServerURL string
defaultServerAddress string
serverURL string
serverAddresses []string
randomServers []string
servers map[string]*server
currentServerAddress string
nextServerIndex int
serviceName string
configFile string
scheme string
localAddress string
servers serverList
proxy *tcpproxy.Proxy
}

const RandomPort = 0
Expand All @@ -63,7 +37,7 @@ var (

// New contstructs a new LoadBalancer instance. The default server URL, and
// currently active servers, are stored in a file within the dataDir.
func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
config := net.ListenConfig{Control: reusePort}
var localAddress string
if isIPv6 {
Expand All @@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
return nil, err
}

// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
localAddress = listener.Addr().String()

defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
serverURL, err := url.Parse(defaultServerURL)
if err != nil {
return nil, err
}

if serverURL == localServerURL {
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
defaultServerAddress = ""
// Set explicit port from scheme
if serverURL.Port() == "" {
if strings.ToLower(serverURL.Scheme) == "http" {
serverURL.Host += ":80"
}
if strings.ToLower(serverURL.Scheme) == "https" {
serverURL.Host += ":443"
}
}

lb := &LoadBalancer{
serviceName: serviceName,
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
localAddress: localAddress,
localServerURL: localServerURL,
defaultServerAddress: defaultServerAddress,
servers: make(map[string]*server),
serverURL: serverURL,
serviceName: serviceName,
configFile: filepath.Join(dataDir, "etc", serviceName+".json"),
scheme: serverURL.Scheme,
localAddress: listener.Addr().String(),
}

lb.setServers([]string{lb.defaultServerAddress})
// if starting pointing at ourselves, don't set a default server address,
// which will cause all dials to fail until servers are added.
if serverURL.Host == lb.localAddress {
logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
} else {
lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host)
}

lb.proxy = &tcpproxy.Proxy{
ListenFunc: func(string, string) (net.Listener, error) {
Expand All @@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
}
lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{
Addr: serviceName,
DialContext: lb.dialContext,
DialContext: lb.servers.dialContext,
OnDialError: onDialError,
})

Expand All @@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
if err := lb.proxy.Start(); err != nil {
return nil, err
}
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress)
logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())

go lb.runHealthChecks(ctx)
go lb.servers.runHealthChecks(ctx, lb.serviceName)

return lb, nil
}

// Update updates the list of server addresses to contain only the listed servers.
func (lb *LoadBalancer) Update(serverAddresses []string) {
if lb == nil {
return
}
if !lb.setServers(serverAddresses) {
if !lb.servers.setAddresses(lb.serviceName, serverAddresses) {
return
}
logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress)

logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress())

if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
}
}

func (lb *LoadBalancer) LoadBalancerServerURL() string {
if lb == nil {
return ""
// SetDefault sets the selected address as the default / fallback address
func (lb *LoadBalancer) SetDefault(serverAddress string) {
lb.servers.setDefaultAddress(lb.serviceName, serverAddress)

if err := lb.writeConfig(); err != nil {
logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err)
}
return lb.localServerURL
}

func (lb *LoadBalancer) ServerAddresses() []string {
if lb == nil {
return nil
// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) {
if err := lb.servers.setHealthCheck(address, healthCheck); err != nil {
logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err)
} else {
logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address)
}
return lb.serverAddresses
}

func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
lb.mutex.RLock()
defer lb.mutex.RUnlock()

var allChecksFailed bool
startIndex := lb.nextServerIndex
for {
targetServer := lb.currentServerAddress

server := lb.servers[targetServer]
if server == nil || targetServer == "" {
logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer)
} else if allChecksFailed || server.healthCheck() {
dialTime := time.Now()
conn, err := server.dialContext(ctx, network, targetServer)
if err == nil {
return conn, nil
}
logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err)
// Don't close connections to the failed server if we're retrying with health checks ignored.
// We don't want to disrupt active connections if it is unlikely they will have anywhere to go.
if !allChecksFailed {
defer server.closeAll()
}
} else {
logrus.Debugf("Dial health check failed for %s", targetServer)
}

newServer, err := lb.nextServer(targetServer)
if err != nil {
return nil, err
}
if targetServer != newServer {
logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer)
}
if ctx.Err() != nil {
return nil, ctx.Err()
}
func (lb *LoadBalancer) LocalURL() string {
return lb.scheme + "://" + lb.localAddress
}

maxIndex := len(lb.randomServers)
if startIndex > maxIndex {
startIndex = maxIndex
}
if lb.nextServerIndex == startIndex {
if allChecksFailed {
return nil, errors.New("all servers failed")
}
logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName)
allChecksFailed = true
}
}
func (lb *LoadBalancer) ServerAddresses() []string {
return lb.servers.getAddresses()
}

func onDialError(src net.Conn, dstDialErr error) {
Expand All @@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) {
}

// ResetLoadBalancer will delete the local state file for the load balancer on disk
func ResetLoadBalancer(dataDir, serviceName string) error {
func ResetLoadBalancer(dataDir, serviceName string) {
stateFile := filepath.Join(dataDir, "etc", serviceName+".json")
if err := os.Remove(stateFile); err != nil {
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
logrus.Warn(err)
}
return nil
}
Loading

0 comments on commit 911ee19

Please sign in to comment.