From a149b59314948a017239d1004f304cb57671916f Mon Sep 17 00:00:00 2001
From: Brad Davidson
Date: Fri, 15 Nov 2024 22:11:47 +0000
Subject: [PATCH] Refactor load balancer server list and health checking

Signed-off-by: Brad Davidson
---
 pkg/agent/loadbalancer/config.go            |   8 +-
 pkg/agent/loadbalancer/loadbalancer.go      | 192 ++++----
 pkg/agent/loadbalancer/loadbalancer_test.go |  75 ++-
 pkg/agent/loadbalancer/servers.go           | 482 ++++++++++++++------
 pkg/agent/proxy/apiproxy.go                 |   4 +-
 pkg/etcd/etcdproxy.go                       |   2 +-
 6 files changed, 485 insertions(+), 278 deletions(-)

diff --git a/pkg/agent/loadbalancer/config.go b/pkg/agent/loadbalancer/config.go
index 9a2de3214fbb..b93d05e9a20f 100644
--- a/pkg/agent/loadbalancer/config.go
+++ b/pkg/agent/loadbalancer/config.go
@@ -15,8 +15,8 @@ type lbConfig struct {
 
 func (lb *LoadBalancer) writeConfig() error {
 	config := &lbConfig{
-		ServerURL:       lb.serverURL,
-		ServerAddresses: lb.serverAddresses,
+		ServerURL:       lb.scheme + "://" + lb.servers.getDefaultAddress(),
+		ServerAddresses: lb.servers.getAddresses(),
 	}
 	configOut, err := json.MarshalIndent(config, "", "  ")
 	if err != nil {
@@ -30,9 +30,9 @@ func (lb *LoadBalancer) updateConfig() error {
 	if configBytes, err := os.ReadFile(lb.configFile); err == nil {
 		config := &lbConfig{}
 		if err := json.Unmarshal(configBytes, config); err == nil {
-			if config.ServerURL == lb.serverURL {
+			if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() {
 				writeConfig = false
-				lb.setServers(config.ServerAddresses)
+				lb.Update(config.ServerAddresses)
 			}
 		}
 	}
diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go
index db9fa6f16f72..d272ab41857e 100644
--- a/pkg/agent/loadbalancer/loadbalancer.go
+++ b/pkg/agent/loadbalancer/loadbalancer.go
@@ -5,52 +5,29 @@ import (
 	"errors"
 	"fmt"
 	"net"
+	"net/url"
 	"os"
 	"path/filepath"
-	"sync"
+	"strings"
 	"time"
 
 	"github.com/inetaf/tcpproxy"
 	"github.com/k3s-io/k3s/pkg/version"
 	"github.com/sirupsen/logrus"
+	"k8s.io/apimachinery/pkg/util/wait"
 )
 
-// server tracks the connections to a server, so that they can be closed when the server is removed.
-type server struct {
-	// This mutex protects access to the connections map. All direct access to the map should be protected by it.
-	mutex       sync.Mutex
-	address     string
-	healthCheck func() bool
-	connections map[net.Conn]struct{}
-}
-
-// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed.
-type serverConn struct {
-	server *server
-	net.Conn
-}
-
 // LoadBalancer holds data for a local listener which forwards connections to a
 // pool of remote servers. It is not a proper load-balancer in that it does not
 // actually balance connections, but instead fails over to a new server only
 // when a connection attempt to the currently selected server fails.
 type LoadBalancer struct {
-	// This mutex protects access to servers map and randomServers list.
-	// All direct access to the servers map/list should be protected by it.
-	mutex sync.RWMutex
-	proxy *tcpproxy.Proxy
-
-	serviceName          string
-	configFile           string
-	localAddress         string
-	localServerURL       string
-	defaultServerAddress string
-	serverURL            string
-	serverAddresses      []string
-	randomServers        []string
-	servers              map[string]*server
-	currentServerAddress string
-	nextServerIndex      int
+	serviceName  string
+	configFile   string
+	scheme       string
+	localAddress string
+	servers      serverList
+	proxy        *tcpproxy.Proxy
 }
 
 const RandomPort = 0
@@ -63,7 +40,7 @@ var (
 )
 
 // New constructs a new LoadBalancer instance.
// The default server URL, and currently active servers, are stored in a file within the dataDir.
-func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
+func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) {
 	config := net.ListenConfig{Control: reusePort}
 	var localAddress string
 	if isIPv6 {
@@ -84,30 +61,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
 		return nil, err
 	}
 
-	// if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with.
-	localAddress = listener.Addr().String()
-
-	defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress)
+	serverURL, err := url.Parse(defaultServerURL)
 	if err != nil {
 		return nil, err
 	}
 
-	if serverURL == localServerURL {
-		logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
-		defaultServerAddress = ""
+	// Set explicit port from scheme
+	if serverURL.Port() == "" {
+		if strings.ToLower(serverURL.Scheme) == "http" {
+			serverURL.Host += ":80"
+		}
+		if strings.ToLower(serverURL.Scheme) == "https" {
+			serverURL.Host += ":443"
+		}
 	}
 
 	lb := &LoadBalancer{
-		serviceName:          serviceName,
-		configFile:           filepath.Join(dataDir, "etc", serviceName+".json"),
-		localAddress:         localAddress,
-		localServerURL:       localServerURL,
-		defaultServerAddress: defaultServerAddress,
-		servers:              make(map[string]*server),
-		serverURL:            serverURL,
+		serviceName:  serviceName,
+		configFile:   filepath.Join(dataDir, "etc", serviceName+".json"),
+		scheme:       serverURL.Scheme,
+		localAddress: listener.Addr().String(),
 	}
 
-	lb.setServers([]string{lb.defaultServerAddress})
+	// If the initial server URL points at ourselves, don't set a default server
+	// address; all dials will fail until servers are added.
+	if serverURL.Host == lb.localAddress {
+		logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName)
+	} else {
+		lb.servers.setDefaultAddress(serverURL.Host)
+	}
 
 	lb.proxy = &tcpproxy.Proxy{
 		ListenFunc: func(string, string) (net.Listener, error) {
@@ -126,92 +108,79 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo
 	if err := lb.proxy.Start(); err != nil {
 		return nil, err
 	}
-	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress)
+	logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress())
 
 	go lb.runHealthChecks(ctx)
 
 	return lb, nil
 }
 
+// runHealthChecks periodically health-checks all servers.
+func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
+	wait.Until(func() {
+		for _, server := range lb.servers.getServers() {
+			if server.healthCheck() {
+				lb.servers.recordSuccess(server, reason_health)
+			} else {
+				lb.servers.recordFailure(server, reason_health)
+			}
+		}
+	}, time.Second, ctx.Done())
+	logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
+}
+
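+// Update updates the list of server addresses to contain only the listed servers.
+//
+// A rough usage sketch: the address and probe body below are illustrative
+// only; real callers wire in their own checks via SetHealthCheck.
+//
+//	lb.Update([]string{"10.0.0.1:6443"})
+//	lb.SetHealthCheck("10.0.0.1:6443", func() bool {
+//		conn, err := net.DialTimeout("tcp", "10.0.0.1:6443", time.Second)
+//		if err != nil {
+//			return false
+//		}
+//		conn.Close()
+//		return true
+//	})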
func (lb *LoadBalancer) Update(serverAddresses []string) { - if lb == nil { - return - } - if !lb.setServers(serverAddresses) { + if !lb.servers.setAddresses(serverAddresses) { return } - logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress) + + logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) if err := lb.writeConfig(); err != nil { logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } } -func (lb *LoadBalancer) LoadBalancerServerURL() string { - if lb == nil { - return "" +// SetDefault sets the selected address as the default / fallback address +func (lb *LoadBalancer) SetDefault(serverAddress string) { + lb.servers.setDefaultAddress(serverAddress) + logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress) + + if err := lb.writeConfig(); err != nil { + logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } - return lb.localServerURL } -func (lb *LoadBalancer) ServerAddresses() []string { - if lb == nil { - return nil +// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. +func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) { + if err := lb.servers.setHealthCheck(address, healthCheck); err != nil { + logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err) + } else { + logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address) } - return lb.serverAddresses } func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - - var allChecksFailed bool - startIndex := lb.nextServerIndex - for { - targetServer := lb.currentServerAddress - - server := lb.servers[targetServer] - if server == nil || targetServer == "" { - logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer) - } else if allChecksFailed || server.healthCheck() { - dialTime := time.Now() - conn, err := server.dialContext(ctx, network, targetServer) - if err == nil { - return conn, nil - } - logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err) - // Don't close connections to the failed server if we're retrying with health checks ignored. - // We don't want to disrupt active connections if it is unlikely they will have anywhere to go. 
- if !allChecksFailed { - defer server.closeAll() - } - } else { - logrus.Debugf("Dial health check failed for %s", targetServer) + for _, server := range lb.servers.getServers() { + dialTime := time.Now() + conn, err := server.dialContext(ctx, network) + if err == nil { + lb.servers.recordSuccess(server, reason_dial) + return conn, nil } + lb.servers.recordFailure(server, reason_dial) + logrus.Debugf("Dial error from load balancer %s server %s after %s: %s", lb.serviceName, server.address, time.Now().Sub(dialTime), err) + } + return nil, errors.New("all servers failed") +} - newServer, err := lb.nextServer(targetServer) - if err != nil { - return nil, err - } - if targetServer != newServer { - logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer) - } - if ctx.Err() != nil { - return nil, ctx.Err() - } +func (lb *LoadBalancer) LocalURL() string { + return lb.scheme + "://" + lb.localAddress +} - maxIndex := len(lb.randomServers) - if startIndex > maxIndex { - startIndex = maxIndex - } - if lb.nextServerIndex == startIndex { - if allChecksFailed { - return nil, errors.New("all servers failed") - } - logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName) - allChecksFailed = true - } - } +func (lb *LoadBalancer) ServerAddresses() []string { + return lb.servers.getAddresses() } func onDialError(src net.Conn, dstDialErr error) { @@ -220,10 +189,9 @@ func onDialError(src net.Conn, dstDialErr error) { } // ResetLoadBalancer will delete the local state file for the load balancer on disk -func ResetLoadBalancer(dataDir, serviceName string) error { +func ResetLoadBalancer(dataDir, serviceName string) { stateFile := filepath.Join(dataDir, "etc", serviceName+".json") - if err := os.Remove(stateFile); err != nil { + if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) { logrus.Warn(err) } - return nil } diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go index cbfdf982c690..a1f86e467a28 100644 --- a/pkg/agent/loadbalancer/loadbalancer_test.go +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "slices" "strings" "testing" "time" @@ -111,15 +112,19 @@ func Test_UnitFailOver(t *testing.T) { t.Fatalf("New() failed: %v", err) } - parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) + parsedURL, err := url.Parse(lb.LocalURL()) if err != nil { t.Fatalf("url.Parse failed: %v", err) } localAddress := parsedURL.Host + t.Logf("Adding node1 server: %v", lb.servers.getServers()) + // add the node as a new server address. lb.Update([]string{node1Server.address()}) + t.Logf("Added node1 server: %v", lb.servers.getServers()) + // make sure connections go to the node conn1, err := net.Dial("tcp", localAddress) if err != nil { @@ -146,9 +151,7 @@ func Test_UnitFailOver(t *testing.T) { t.Log("conn1 closed on failure OK") - // make sure connection still goes to the first node - it is failing health checks but so - // is the default endpoint, so it should be tried first with health checks disabled, - // before failing back to the default. 
+	// connections should go to the default now that node1 has failed
 	conn2, err := net.Dial("tcp", localAddress)
 	if err != nil {
 		t.Fatalf("net.Dial failed: %v", err)
@@ -156,7 +159,7 @@
 	}
 	if result, err := ping(conn2); err != nil {
 		t.Fatalf("ping(conn2) failed: %v", err)
-	} else if result != "node1:ping" {
+	} else if result != "default:ping" {
 		t.Fatalf("Unexpected ping(conn2) result: %v", result)
 	}
@@ -168,7 +171,7 @@
 
 	if result, err := ping(conn2); err != nil {
 		t.Fatalf("ping(conn2) failed: %v", err)
-	} else if result != "node1:ping" {
+	} else if result != "default:ping" {
 		t.Fatalf("Unexpected ping(conn2) result: %v", result)
 	}
@@ -191,15 +194,13 @@
 
 	t.Log("conn3 tested OK")
 
-	if _, err := ping(conn2); err == nil {
-		t.Fatal("Unexpected successful ping on closed connection conn2")
-	}
-
-	t.Log("conn2 closed on failure OK")
+	t.Logf("Adding node2 server: %v", lb.servers.getServers())
 
 	// add the second node as a new server address.
 	lb.Update([]string{node2Server.address()})
 
+	t.Logf("Added node2 server: %v", lb.servers.getServers())
+
 	// make sure connection now goes to the second node,
 	// and connections to the default are closed.
 	conn4, err := net.Dial("tcp", localAddress)
@@ -219,11 +220,63 @@
 	// server, connections to the default server should be closed
 	time.Sleep(2 * time.Second)
 
+	if _, err := ping(conn2); err == nil {
+		t.Fatal("Unexpected successful ping on closed connection conn2")
+	}
+
+	t.Log("conn2 closed on failure OK")
+
 	if _, err := ping(conn3); err == nil {
 		t.Fatal("Unexpected successful ping on connection conn3")
 	}
 
 	t.Log("conn3 closed on failure OK")
+
+	t.Logf("Adding default server: %v", lb.servers.getServers())
+
+	// add the default as a full server
+	lb.Update([]string{node2Server.address(), defaultServer.address()})
+
+	// confirm that both servers are listed in the address list
+	serverAddresses := lb.ServerAddresses()
+	if len(serverAddresses) != 2 {
+		t.Fatalf("Unexpected server address count")
+	}
+
+	if !slices.Contains(serverAddresses, node2Server.address()) {
+		t.Fatalf("node2 server not in server address list")
+	}
+
+	if !slices.Contains(serverAddresses, defaultServer.address()) {
+		t.Fatalf("default server not in server address list")
+	}
+
+	// confirm that the default is still listed as default
+	if lb.servers.getDefaultAddress() != defaultServer.address() {
+		t.Fatalf("default server is not default")
+	}
+
+	t.Logf("Default server added OK: %v", lb.servers.getServers())
+
+	// remove the default as a server
+	lb.Update([]string{node2Server.address()})
+
+	// confirm that it is not listed as a server
+	serverAddresses = lb.ServerAddresses()
+	if len(serverAddresses) != 1 {
+		t.Fatalf("Unexpected server address count")
+	}
+
+	if slices.Contains(serverAddresses, defaultServer.address()) {
+		t.Fatalf("default server in server address list")
+	}
+
+	// but is still listed as the default
+	if lb.servers.getDefaultAddress() != defaultServer.address() {
+		t.Fatalf("default server is not default")
+	}
+
+	t.Logf("Default server removed OK: %v", lb.servers.getServers())
 }
 
 // Test_UnitFailFast confirms that connections to invalid addresses fail quickly
diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go
index 675bee5c5c86..14f3ce95177a 100644
--- a/pkg/agent/loadbalancer/servers.go
+++ b/pkg/agent/loadbalancer/servers.go
@@ -1,118 +1,369 @@
 package loadbalancer
 
 import (
+	"cmp"
"context" - "math/rand" + "fmt" "net" "slices" + "sync" "time" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/sets" - "k8s.io/apimachinery/pkg/util/wait" ) -func (lb *LoadBalancer) setServers(serverAddresses []string) bool { - serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress) - if len(serverAddresses) == 0 { +// serverList tracks potential backend servers for use by a loadbalancer. +type serverList struct { + // This mutex protects access to the server list. All direct access to the list should be protected by it. + mutex sync.Mutex + servers []*server +} + +// setServers updates the server list to contain only the selected addresses. +func (sl *serverList) setAddresses(addresses []string) bool { + newAddresses := sets.New(addresses...) + curAddresses := sets.New(sl.getAddresses()...) + if newAddresses.Equal(curAddresses) { return false } - lb.mutex.Lock() - defer lb.mutex.Unlock() + servers := sl.getServers() + defaultServer := sl.getDefaultServer() + closeAllFuncs := []func(){} - newAddresses := sets.NewString(serverAddresses...) - curAddresses := sets.NewString(lb.serverAddresses...) - if newAddresses.Equal(curAddresses) { - return false + // add new servers + for addedAddress := range newAddresses.Difference(curAddresses) { + if defaultServer != nil && defaultServer.address == addedAddress { + // promote the default server up to a full server, instead of adding a new entry for it + defaultServer.state = state_preferred + } else { + servers = append(servers, newServer(addedAddress, false)) + } } - for addedServer := range newAddresses.Difference(curAddresses) { - logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer) - lb.servers[addedServer] = &server{ - address: addedServer, - connections: make(map[net.Conn]struct{}), - healthCheck: func() bool { return true }, + // remove old servers + for removedAddress := range curAddresses.Difference(newAddresses) { + if defaultServer != nil && defaultServer.address == removedAddress { + // demote the default server down to standby, instead of deleting it + defaultServer.state = state_standby + closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll) + } else { + servers = slices.DeleteFunc(servers, func(s *server) bool { + if s.address == removedAddress { + // set state to invalid to prevent server from making additional connections + s.state = state_invalid + closeAllFuncs = append(closeAllFuncs, s.closeAll) + return true + } + return false + }) } } - for removedServer := range curAddresses.Difference(newAddresses) { - server := lb.servers[removedServer] - if server != nil { - logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer) - // Defer closing connections until after the new server list has been put into place. - // Closing open connections ensures that anything stuck retrying on a stale server is forced - // over to a valid endpoint. - defer server.closeAll() - // Don't delete the default server from the server map, in case we need to fall back to it. 
-			if removedServer != lb.defaultServerAddress {
-				delete(lb.servers, removedServer)
-			}
+	slices.SortFunc(servers, func(a, b *server) int { return cmp.Compare(b.state, a.state) })
+
+	// swap the new server list into place while holding the lock
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+	sl.servers = servers
+
+	// Close all connections to servers that were removed
+	for _, closeAll := range closeAllFuncs {
+		closeAll()
+	}
+
+	return true
+}
+
+// getAddresses returns the addresses of all servers.
+// If the default server is in standby state, indicating it is only present
+// because it is the default, it is not returned in this list.
+func (sl *serverList) getAddresses() []string {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	addresses := make([]string, 0, len(sl.servers))
+	for _, s := range sl.servers {
+		if s.isDefault && s.state == state_standby {
+			continue
+		}
+		addresses = append(addresses, s.address)
+	}
+	return addresses
+}
+
+// setDefaultAddress sets the server with the provided address as the default server.
+// The default flag is cleared on all other servers, and if the previous default was
+// only kept in the list because it was the default, it is removed.
+func (sl *serverList) setDefaultAddress(address string) {
+	servers := sl.getServers()
 
-	lb.serverAddresses = serverAddresses
-	lb.randomServers = append([]string{}, lb.serverAddresses...)
-	rand.Shuffle(len(lb.randomServers), func(i, j int) {
-		lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i]
-	})
-	// If the current server list does not contain the default server address,
-	// we want to include it in the random server list so that it can be tried if necessary.
-	// However, it should be treated as always failing health checks so that it is only
-	// used if all other endpoints are unavailable.
-	if !hasDefaultServer {
-		lb.randomServers = append(lb.randomServers, lb.defaultServerAddress)
-		if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok {
-			defaultServer.healthCheck = func() bool { return false }
-			lb.servers[lb.defaultServerAddress] = defaultServer
+	// deal with existing default first
+	if i := slices.IndexFunc(servers, func(s *server) bool { return s.isDefault }); i != -1 {
+		s := servers[i]
+		s.isDefault = false
+		if s.state == state_standby {
+			servers = slices.Delete(servers, i, i+1)
+			defer s.closeAll()
 		}
 	}
 
-	lb.currentServerAddress = lb.randomServers[0]
-	lb.nextServerIndex = 1
+	// get or create server with selected address
+	if s := sl.getServer(address); s != nil {
+		s.isDefault = true
+	} else {
+		s = newServer(address, true)
+		servers = append(servers, s)
+	}
 
-	return true
+	slices.SortFunc(servers, func(a, b *server) int { return cmp.Compare(b.state, a.state) })
+
+	// swap the new server list into place while holding the lock
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+	sl.servers = servers
 }
 
-// nextServer attempts to get the next server in the loadbalancer server list.
-// If another goroutine has already updated the current server address to point at
-// a different address than just failed, nothing is changed. Otherwise, a new server address
-// is stored to the currentServerAddress field, and returned for use.
-// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex.
-func (lb *LoadBalancer) nextServer(failedServer string) (string, error) {
-	// note: these fields are not protected by the mutex, so we clamp the index value and update
-	// the index/current address using local variables, to avoid time-of-check vs time-of-use
-	// race conditions caused by goroutine A incrementing it in between the time goroutine B
-	// validates its value, and uses it as a list index.
-	currentServerAddress := lb.currentServerAddress
-	nextServerIndex := lb.nextServerIndex
+// getDefaultAddress returns the address of the default server.
+func (sl *serverList) getDefaultAddress() string {
+	if s := sl.getDefaultServer(); s != nil {
+		return s.address
+	}
+	return ""
+}
+
+// getDefaultServer returns the default server.
+func (sl *serverList) getDefaultServer() *server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
+		return sl.servers[i]
+	}
+	return nil
+}
+
+// getServers returns a copy of the servers list that can be safely iterated over without holding a lock
+func (sl *serverList) getServers() []*server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	return slices.Clone(sl.servers)
+}
+
+// getServer returns the first server with the specified address
+func (sl *serverList) getServer(address string) *server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
+		return sl.servers[i]
+	}
+	return nil
+}
+
+// setHealthCheck updates the health check function for a server, replacing the
+// current function.
+func (sl *serverList) setHealthCheck(address string, healthCheck func() bool) error {
+	if s := sl.getServer(address); s != nil {
+		s.healthCheck = healthCheck
+		return nil
+	}
+	return fmt.Errorf("no server found for %s", address)
+}
+
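+// Sketch of the default-server bookkeeping implemented above (addresses are
+// illustrative): a default server is tracked even when it is not in the
+// configured address list.
+//
+//	sl.setDefaultAddress("172.16.0.1:6443")
+//	sl.setAddresses([]string{"10.0.0.1:6443"})
+//	sl.getAddresses()      // ["10.0.0.1:6443"], standby default omitted
+//	sl.getDefaultAddress() // "172.16.0.1:6443", still tracked as default
+//	sl.setAddresses([]string{"10.0.0.1:6443", "172.16.0.1:6443"})
+//	sl.getAddresses()      // both addresses, default promoted to a full server
+
+// recordSuccess records a successful check of a server, either via health-check or dial.
+// The server's state is adjusted accordingly.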
+func (sl *serverList) recordSuccess(srv *server, r reason) {
+	var new_active bool
+	var new_state state
+	switch srv.state {
+	case state_failed:
+		// dialed or health checked OK once, improve to recovering
+		new_state = state_recovering
+	case state_recovering:
+		if r == reason_health {
+			// only improve from recovering by also passing health check
+			new_state = state_preferred
+		}
+	case state_healthy:
+		if r == reason_dial {
+			// improved from healthy to active by being dialed
+			new_state = state_active
+			new_active = true
+		}
+	case state_preferred:
+		if r == reason_health {
+			if time.Now().Sub(srv.lastTransition) > time.Minute {
+				// has been preferred for a while without being dialed, demote to healthy
+				new_state = state_healthy
+			}
+		} else {
+			// improved from preferred to active by being dialed
+			new_state = state_active
+			new_active = true
+		}
+	}
+
+	// no-op if state did not change
+	if new_state == state_invalid {
+		return
+	}
+
+	logrus.Debugf("Server %s state to %s from successful %s", srv, new_state, r)
+	srv.state = new_state
+	srv.lastTransition = time.Now()
 
-	if len(lb.randomServers) == 0 {
-		return "", errors.New("No servers in load balancer proxy list")
+	// handle active transition and sort the server list while holding the lock
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+	if new_active {
+		for _, s := range sl.servers {
+			switch s.state {
+			case state_standby:
+				// close connections to the default server now that we have a new active server
+				defer s.closeAll()
+			case state_active:
+				// warn if another server was still active, it should have been marked failed first
+				// before the new one went active.
+				if srv.address != s.address {
+					logrus.Warnf("Multiple active servers found: current=%s, previous=%s", srv, s)
+					s.state = state_healthy
+				}
+			}
+		}
 	}
-	if len(lb.randomServers) == 1 {
-		return currentServerAddress, nil
+	slices.SortFunc(sl.servers, func(a, b *server) int { return cmp.Compare(int(b.state), int(a.state)) })
+}
+
+// recordFailure records a failed check of a server, either via health-check or dial.
+// The server's state is adjusted accordingly.
+func (sl *serverList) recordFailure(srv *server, r reason) {
+	var new_state state
+	switch srv.state {
+	case state_recovering:
+		if r == reason_dial {
+			// only demote from recovering if a dial fails; health checks may continue to fail despite it being dialable.
+			// just leave it in recovering and don't close any connections.
+			new_state = state_failed
+		}
+	case state_healthy, state_preferred, state_active:
+		// should not have any connections when in any state other than active, but close
+		// them all anyway to force failover.
+		defer srv.closeAll()
+		new_state = state_failed
 	}
-	if failedServer != currentServerAddress {
-		return currentServerAddress, nil
+
+	// no-op if state did not change
+	if new_state == state_invalid {
+		return
 	}
-	if nextServerIndex >= len(lb.randomServers) {
-		nextServerIndex = 0
+
+	logrus.Debugf("Server %s state to %s from failed %s", srv, new_state, r)
+	srv.state = new_state
+	srv.lastTransition = time.Now()
+
+	// sort the server list while holding the lock
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+	slices.SortFunc(sl.servers, func(a, b *server) int { return cmp.Compare(int(b.state), int(a.state)) })
+}
+
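+// Taken together, recordSuccess and recordFailure implement the following
+// transitions (a summary sketch; no-op cases omitted):
+//
+//	failed     --success (any)---------> recovering
+//	recovering --success (health)------> preferred
+//	recovering --failure (dial)--------> failed
+//	preferred  --success (dial)--------> active
+//	preferred  --success (health) >1m--> healthy
+//	healthy    --success (dial)--------> active
+//	healthy, preferred, active --failure (any)--> failed (connections closed)
+
+// server health states, in increasing order of preference.
+// The server list is kept sorted in descending order by this state value.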
+type state int + +const ( + state_invalid state = iota + state_failed // failed a health check or dial + state_standby // reserved for use by default server if not in server list + state_recovering // successfully health checked once, or dialed when failed + state_healthy // normal state + state_preferred // recently transitioned from recovering; should be preferred as others may go down for maintenance + state_active // currently active server +) + +func (s state) String() string { + switch s { + case state_invalid: + return "INVALID" + case state_failed: + return "FAILED" + case state_standby: + return "STANDBY" + case state_recovering: + return "RECOVERING" + case state_healthy: + return "HEALTHY" + case state_preferred: + return "PREFERRED" + case state_active: + return "ACTIVE" + default: + return "UNKNOWN" } +} - currentServerAddress = lb.randomServers[nextServerIndex] - nextServerIndex++ +type reason int - lb.currentServerAddress = currentServerAddress - lb.nextServerIndex = nextServerIndex +const ( + reason_dial reason = iota + reason_health +) - return currentServerAddress, nil +func (r reason) String() string { + switch r { + case reason_dial: + return "dial" + case reason_health: + return "health check" + default: + return "unknown reason" + } } -// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map -func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := defaultDialer.Dial(network, address) +// server tracks the connections to a server, so that they can be closed when the server is removed. +type server struct { + // This mutex protects access to the connections map. All direct access to the map should be protected by it. + mutex sync.Mutex + address string + isDefault bool + state state + lastTransition time.Time + healthCheck func() bool + connections map[net.Conn]struct{} +} + +func newServer(address string, isDefault bool) *server { + state := state_preferred + if isDefault { + state = state_standby + } + return &server{ + address: address, + isDefault: isDefault, + state: state, + lastTransition: time.Now(), + healthCheck: func() bool { return true }, + connections: make(map[net.Conn]struct{}), + } +} + +func (s *server) String() string { + format := "%s:%s" + if s.isDefault { + format = "*" + format + } + return fmt.Sprintf(format, s.address, s.state) +} + +// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map +func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) { + if s.state == state_invalid { + return nil, fmt.Errorf("server %s is stopping", s.address) + } + + conn, err := defaultDialer.Dial(network, s.address) if err != nil { return nil, err } @@ -132,7 +383,7 @@ func (s *server) closeAll() { defer s.mutex.Unlock() if l := len(s.connections); l > 0 { - logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address) + logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s) for conn := range s.connections { // Close the connection in a goroutine so that we don't hold the lock while doing so. go conn.Close() @@ -140,6 +391,12 @@ func (s *server) closeAll() { } } +// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. 
+type serverConn struct { + server *server + net.Conn +} + // Close removes the connection entry from the server's connection map, and // closes the wrapped connection. func (sc *serverConn) Close() error { @@ -149,74 +406,3 @@ func (sc *serverConn) Close() error { delete(sc.server.connections, sc) return sc.Conn.Close() } - -// SetDefault sets the selected address as the default / fallback address -func (lb *LoadBalancer) SetDefault(serverAddress string) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) - // if the old default server is not currently in use, remove it from the server map - if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer { - defer server.closeAll() - delete(lb.servers, lb.defaultServerAddress) - } - // if the new default server doesn't have an entry in the map, add one - but - // with a failing health check so that it is only used as a last resort. - if _, ok := lb.servers[serverAddress]; !ok { - lb.servers[serverAddress] = &server{ - address: serverAddress, - healthCheck: func() bool { return false }, - connections: make(map[net.Conn]struct{}), - } - } - - lb.defaultServerAddress = serverAddress - logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress) -} - -// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. -func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - if server := lb.servers[address]; server != nil { - logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address) - server.healthCheck = healthCheck - } else { - logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address) - } -} - -// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their -// connections closed, to force clients to switch over to a healthy server. -func (lb *LoadBalancer) runHealthChecks(ctx context.Context) { - previousStatus := map[string]bool{} - wait.Until(func() { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - var healthyServerExists bool - for address, server := range lb.servers { - status := server.healthCheck() - healthyServerExists = healthyServerExists || status - if status == false && previousStatus[address] == true { - // Only close connections when the server transitions from healthy to unhealthy; - // we don't want to re-close all the connections every time as we might be ignoring - // health checks due to all servers being marked unhealthy. - defer server.closeAll() - } - previousStatus[address] = status - } - - // If there is at least one healthy server, and the default server is not in the server list, - // close all the connections to the default server so that clients reconnect and switch over - // to a preferred server. 
- hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) - if healthyServerExists && !hasDefaultServer { - if server, ok := lb.servers[lb.defaultServerAddress]; ok { - defer server.closeAll() - } - } - }, time.Second, ctx.Done()) - logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName) -} diff --git a/pkg/agent/proxy/apiproxy.go b/pkg/agent/proxy/apiproxy.go index e711623e467e..89fabf599964 100644 --- a/pkg/agent/proxy/apiproxy.go +++ b/pkg/agent/proxy/apiproxy.go @@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor return nil, err } p.supervisorLB = lb - p.supervisorURL = lb.LoadBalancerServerURL() + p.supervisorURL = lb.LocalURL() p.apiServerURL = p.supervisorURL } @@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error { return err } p.apiServerLB = lb - p.apiServerURL = lb.LoadBalancerServerURL() + p.apiServerURL = lb.LocalURL() } else { p.apiServerURL = u.String() } diff --git a/pkg/etcd/etcdproxy.go b/pkg/etcd/etcdproxy.go index 57a2e48c80c1..cc615fddbed0 100644 --- a/pkg/etcd/etcdproxy.go +++ b/pkg/etcd/etcdproxy.go @@ -52,7 +52,7 @@ func NewETCDProxy(ctx context.Context, supervisorPort int, dataDir, etcdURL stri return nil, err } e.etcdLB = lb - e.etcdLBURL = lb.LoadBalancerServerURL() + e.etcdLBURL = lb.LocalURL() e.fallbackETCDAddress = u.Host e.etcdPort = u.Port()
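
For reviewers, a minimal sketch of how the refactored surface is consumed
(the service name, addresses, and probe helper are illustrative; error
handling elided):

	lb, err := loadbalancer.New(ctx, dataDir, "agent-lb", "https://172.16.0.1:6443", loadbalancer.RandomPort, false)
	if err != nil {
		return err
	}
	// learn the full server list, then attach real health checks
	lb.Update([]string{"10.0.0.1:6443", "10.0.0.2:6443"})
	for _, address := range lb.ServerAddresses() {
		address := address // capture for the closure
		lb.SetHealthCheck(address, func() bool { return probe(address) })
	}
	// clients then connect to lb.LocalURL(); the local listener forwards to
	// the healthiest server and fails over as server states change
	kubeconfigServerURL := lb.LocalURL()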