From b7eb2c1c7d876668311f8d18810972dc8f7f0a3e Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Wed, 20 Nov 2024 18:53:09 +0000 Subject: [PATCH 01/14] Remove unused code from etcdproxy None of these fields or functions are used in k3s or rke2 Signed-off-by: Brad Davidson (cherry picked from commit f2f57b4a4b00ff80fdc96676f462a204150a0007) Signed-off-by: Brad Davidson --- pkg/etcd/etcdproxy.go | 58 ++++++------------------------------------- 1 file changed, 8 insertions(+), 50 deletions(-) diff --git a/pkg/etcd/etcdproxy.go b/pkg/etcd/etcdproxy.go index 55918850b3ff..141ba679b580 100644 --- a/pkg/etcd/etcdproxy.go +++ b/pkg/etcd/etcdproxy.go @@ -6,21 +6,16 @@ import ( "fmt" "net" "net/http" - "net/url" "strconv" "time" "github.com/k3s-io/k3s/pkg/agent/loadbalancer" - "github.com/pkg/errors" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/wait" ) type Proxy interface { Update(addresses []string) - ETCDURL() string - ETCDAddresses() []string - ETCDServerURL() string } var httpClient = &http.Client{ @@ -34,44 +29,22 @@ var httpClient = &http.Client{ // NewETCDProxy initializes a new proxy structure that contain a load balancer // which listens on port 2379 and proxy between etcd cluster members func NewETCDProxy(ctx context.Context, supervisorPort int, dataDir, etcdURL string, isIPv6 bool) (Proxy, error) { - u, err := url.Parse(etcdURL) - if err != nil { - return nil, errors.Wrap(err, "failed to parse etcd client URL") - } - - e := &etcdproxy{ - dataDir: dataDir, - initialETCDURL: etcdURL, - etcdURL: etcdURL, - supervisorPort: supervisorPort, - disconnect: map[string]context.CancelFunc{}, - } - lb, err := loadbalancer.New(ctx, dataDir, loadbalancer.ETCDServerServiceName, etcdURL, 2379, isIPv6) if err != nil { return nil, err } - e.etcdLB = lb - e.etcdLBURL = lb.LoadBalancerServerURL() - e.fallbackETCDAddress = u.Host - e.etcdPort = u.Port() - - return e, nil + return &etcdproxy{ + supervisorPort: supervisorPort, + etcdLB: lb, + disconnect: map[string]context.CancelFunc{}, + }, nil } type etcdproxy struct { - dataDir string - etcdLBURL string - - supervisorPort int - initialETCDURL string - etcdURL string - etcdPort string - fallbackETCDAddress string - etcdAddresses []string - etcdLB *loadbalancer.LoadBalancer - disconnect map[string]context.CancelFunc + supervisorPort int + etcdLB *loadbalancer.LoadBalancer + disconnect map[string]context.CancelFunc } func (e *etcdproxy) Update(addresses []string) { @@ -95,21 +68,6 @@ func (e *etcdproxy) Update(addresses []string) { } } -func (e *etcdproxy) ETCDURL() string { - return e.etcdURL -} - -func (e *etcdproxy) ETCDAddresses() []string { - if len(e.etcdAddresses) > 0 { - return e.etcdAddresses - } - return []string{e.fallbackETCDAddress} -} - -func (e *etcdproxy) ETCDServerURL() string { - return e.etcdURL -} - // start a polling routine that makes periodic requests to the etcd node's supervisor port. // If the request fails, the node is marked unhealthy. 
func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool { From 01bdb070847818e52936e178f0cb0c24a20d8c46 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 15 Nov 2024 00:50:31 +0000 Subject: [PATCH 02/14] Move http/socks proxy stuff to separate file Signed-off-by: Brad Davidson (cherry picked from commit 13e911378764cafb98030ebe80832739ae5ce87e) Signed-off-by: Brad Davidson --- pkg/agent/loadbalancer/httpproxy.go | 70 +++++++++++++++++++ .../{servers_test.go => httpproxy_test.go} | 0 pkg/agent/loadbalancer/servers.go | 61 ---------------- 3 files changed, 70 insertions(+), 61 deletions(-) create mode 100644 pkg/agent/loadbalancer/httpproxy.go rename pkg/agent/loadbalancer/{servers_test.go => httpproxy_test.go} (100%) diff --git a/pkg/agent/loadbalancer/httpproxy.go b/pkg/agent/loadbalancer/httpproxy.go new file mode 100644 index 000000000000..f14859bfe71c --- /dev/null +++ b/pkg/agent/loadbalancer/httpproxy.go @@ -0,0 +1,70 @@ +package loadbalancer + +import ( + "fmt" + "net" + "net/url" + "os" + "strconv" + "time" + + "github.com/k3s-io/k3s/pkg/version" + http_dialer "github.com/mwitkow/go-http-dialer" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/net/http/httpproxy" + "golang.org/x/net/proxy" +) + +var defaultDialer proxy.Dialer = &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, +} + +// SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections, +// if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured +// to indicate use of a HTTP proxy for the server URL. +func SetHTTPProxy(address string) error { + // Check if env variable for proxy is set + if useProxy, _ := strconv.ParseBool(os.Getenv(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED")); !useProxy || address == "" { + return nil + } + + serverURL, err := url.Parse(address) + if err != nil { + return errors.Wrapf(err, "failed to parse address %s", address) + } + + // Call this directly instead of using the cached environment used by http.ProxyFromEnvironment to allow for testing + proxyFromEnvironment := httpproxy.FromEnvironment().ProxyFunc() + proxyURL, err := proxyFromEnvironment(serverURL) + if err != nil { + return errors.Wrapf(err, "failed to get proxy for address %s", address) + } + if proxyURL == nil { + logrus.Debug(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED is true but no proxy is configured for URL " + serverURL.String()) + return nil + } + + dialer, err := proxyDialer(proxyURL, defaultDialer) + if err != nil { + return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL) + } + + defaultDialer = dialer + logrus.Debugf("Using proxy %s for agent connection to %s", proxyURL, serverURL) + return nil +} + +// proxyDialer creates a new proxy.Dialer that routes connections through the specified proxy. 
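+// HTTP and HTTPS proxies are reached through a CONNECT tunnel using go-http-dialer,
+// while SOCKS5 proxies are handled by golang.org/x/net/proxy. Note that the CONNECT
+// path asserts forward to a *net.Dialer; this holds for the package-level
+// defaultDialer, but any other proxy.Dialer implementation would panic here.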
+func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { + if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + // Create a new HTTP proxy dialer + httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer))) + return httpProxyDialer, nil + } else if proxyURL.Scheme == "socks5" { + // For SOCKS5 proxies, use the proxy package's FromURL + return proxy.FromURL(proxyURL, forward) + } + return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) +} diff --git a/pkg/agent/loadbalancer/servers_test.go b/pkg/agent/loadbalancer/httpproxy_test.go similarity index 100% rename from pkg/agent/loadbalancer/servers_test.go rename to pkg/agent/loadbalancer/httpproxy_test.go diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 660810525470..a0bfa3550cf2 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -2,66 +2,18 @@ package loadbalancer import ( "context" - "fmt" "math/rand" "net" - "net/url" - "os" "slices" - "strconv" "time" - "github.com/k3s-io/k3s/pkg/version" - http_dialer "github.com/mwitkow/go-http-dialer" "github.com/pkg/errors" - "golang.org/x/net/http/httpproxy" - "golang.org/x/net/proxy" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" ) -var defaultDialer proxy.Dialer = &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, -} - -// SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections, -// if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured -// to indicate use of a HTTP proxy for the server URL. -func SetHTTPProxy(address string) error { - // Check if env variable for proxy is set - if useProxy, _ := strconv.ParseBool(os.Getenv(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED")); !useProxy || address == "" { - return nil - } - - serverURL, err := url.Parse(address) - if err != nil { - return errors.Wrapf(err, "failed to parse address %s", address) - } - - // Call this directly instead of using the cached environment used by http.ProxyFromEnvironment to allow for testing - proxyFromEnvironment := httpproxy.FromEnvironment().ProxyFunc() - proxyURL, err := proxyFromEnvironment(serverURL) - if err != nil { - return errors.Wrapf(err, "failed to get proxy for address %s", address) - } - if proxyURL == nil { - logrus.Debug(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED is true but no proxy is configured for URL " + serverURL.String()) - return nil - } - - dialer, err := proxyDialer(proxyURL, defaultDialer) - if err != nil { - return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL) - } - - defaultDialer = dialer - logrus.Debugf("Using proxy %s for agent connection to %s", proxyURL, serverURL) - return nil -} - func (lb *LoadBalancer) setServers(serverAddresses []string) bool { serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress) if len(serverAddresses) == 0 { @@ -174,19 +126,6 @@ func (s *server) dialContext(ctx context.Context, network, address string) (net. return wrappedConn, nil } -// proxyDialer creates a new proxy.Dialer that routes connections through the specified proxy. 
-func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { - if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Create a new HTTP proxy dialer - httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer))) - return httpProxyDialer, nil - } else if proxyURL.Scheme == "socks5" { - // For SOCKS5 proxies, use the proxy package's FromURL - return proxy.FromURL(proxyURL, forward) - } - return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) -} - // closeAll closes all connections to the server, and removes their entries from the map func (s *server) closeAll() { s.mutex.Lock() From c0a44a1b51c8b93ca0e94e4d39698c89dc5c9283 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 15 Nov 2024 01:32:12 +0000 Subject: [PATCH 03/14] Separate persistent config struct from LoadBalancer and make fields private Signed-off-by: Brad Davidson (cherry picked from commit 67fd5fa9e5319e0880dd24d4dd188184e923b954) Signed-off-by: Brad Davidson --- pkg/agent/loadbalancer/config.go | 16 +++++++++++++--- pkg/agent/loadbalancer/loadbalancer.go | 18 ++++++++++++------ pkg/agent/loadbalancer/servers.go | 10 +++++----- pkg/etcd/etcdproxy.go | 2 +- 4 files changed, 31 insertions(+), 15 deletions(-) diff --git a/pkg/agent/loadbalancer/config.go b/pkg/agent/loadbalancer/config.go index 1620c8ab6bbc..9a2de3214fbb 100644 --- a/pkg/agent/loadbalancer/config.go +++ b/pkg/agent/loadbalancer/config.go @@ -7,8 +7,18 @@ import ( "github.com/k3s-io/k3s/pkg/agent/util" ) +// lbConfig stores loadbalancer state that should be persisted across restarts. +type lbConfig struct { + ServerURL string `json:"ServerURL"` + ServerAddresses []string `json:"ServerAddresses"` +} + func (lb *LoadBalancer) writeConfig() error { - configOut, err := json.MarshalIndent(lb, "", " ") + config := &lbConfig{ + ServerURL: lb.serverURL, + ServerAddresses: lb.serverAddresses, + } + configOut, err := json.MarshalIndent(config, "", " ") if err != nil { return err } @@ -18,9 +28,9 @@ func (lb *LoadBalancer) writeConfig() error { func (lb *LoadBalancer) updateConfig() error { writeConfig := true if configBytes, err := os.ReadFile(lb.configFile); err == nil { - config := &LoadBalancer{} + config := &lbConfig{} if err := json.Unmarshal(configBytes, config); err == nil { - if config.ServerURL == lb.ServerURL { + if config.ServerURL == lb.serverURL { writeConfig = false lb.setServers(config.ServerAddresses) } diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go index 6689a9e7ca39..db9fa6f16f72 100644 --- a/pkg/agent/loadbalancer/loadbalancer.go +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -45,13 +45,12 @@ type LoadBalancer struct { localAddress string localServerURL string defaultServerAddress string - ServerURL string - ServerAddresses []string + serverURL string + serverAddresses []string randomServers []string servers map[string]*server currentServerAddress string nextServerIndex int - Listener net.Listener } const RandomPort = 0 @@ -105,7 +104,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo localServerURL: localServerURL, defaultServerAddress: defaultServerAddress, servers: make(map[string]*server), - ServerURL: serverURL, + serverURL: serverURL, } lb.setServers([]string{lb.defaultServerAddress}) @@ -127,7 +126,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo if err := lb.proxy.Start(); err != nil { return nil, err } - logrus.Infof("Running load balancer %s %s -> 
%v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress) + logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress) go lb.runHealthChecks(ctx) @@ -141,7 +140,7 @@ func (lb *LoadBalancer) Update(serverAddresses []string) { if !lb.setServers(serverAddresses) { return } - logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.ServerAddresses, lb.defaultServerAddress) + logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress) if err := lb.writeConfig(); err != nil { logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) @@ -155,6 +154,13 @@ func (lb *LoadBalancer) LoadBalancerServerURL() string { return lb.localServerURL } +func (lb *LoadBalancer) ServerAddresses() []string { + if lb == nil { + return nil + } + return lb.serverAddresses +} + func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { lb.mutex.RLock() defer lb.mutex.RUnlock() diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index a0bfa3550cf2..675bee5c5c86 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -24,7 +24,7 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool { defer lb.mutex.Unlock() newAddresses := sets.NewString(serverAddresses...) - curAddresses := sets.NewString(lb.ServerAddresses...) + curAddresses := sets.NewString(lb.serverAddresses...) if newAddresses.Equal(curAddresses) { return false } @@ -53,8 +53,8 @@ func (lb *LoadBalancer) setServers(serverAddresses []string) bool { } } - lb.ServerAddresses = serverAddresses - lb.randomServers = append([]string{}, lb.ServerAddresses...) + lb.serverAddresses = serverAddresses + lb.randomServers = append([]string{}, lb.serverAddresses...) rand.Shuffle(len(lb.randomServers), func(i, j int) { lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] }) @@ -155,7 +155,7 @@ func (lb *LoadBalancer) SetDefault(serverAddress string) { lb.mutex.Lock() defer lb.mutex.Unlock() - hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress) + hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) // if the old default server is not currently in use, remove it from the server map if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer { defer server.closeAll() @@ -211,7 +211,7 @@ func (lb *LoadBalancer) runHealthChecks(ctx context.Context) { // If there is at least one healthy server, and the default server is not in the server list, // close all the connections to the default server so that clients reconnect and switch over // to a preferred server. 
- hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress) + hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) if healthyServerExists && !hasDefaultServer { if server, ok := lb.servers[lb.defaultServerAddress]; ok { defer server.closeAll() diff --git a/pkg/etcd/etcdproxy.go b/pkg/etcd/etcdproxy.go index 141ba679b580..ec781e11a3ae 100644 --- a/pkg/etcd/etcdproxy.go +++ b/pkg/etcd/etcdproxy.go @@ -51,7 +51,7 @@ func (e *etcdproxy) Update(addresses []string) { e.etcdLB.Update(addresses) validEndpoint := map[string]bool{} - for _, address := range e.etcdLB.ServerAddresses { + for _, address := range e.etcdLB.ServerAddresses() { validEndpoint[address] = true if _, ok := e.disconnect[address]; !ok { ctx, cancel := context.WithCancel(context.Background()) From fc3a2447a2dd9bde82b87f3a15c28577af4f27b4 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Sun, 17 Nov 2024 23:49:57 +0000 Subject: [PATCH 04/14] Refactor filterCN to use a Set instead of map[string]bool Signed-off-by: Brad Davidson (cherry picked from commit 95797c4a79de4ee712d9d17a62f0446471823a71) Signed-off-by: Brad Davidson --- pkg/cluster/address_controller.go | 35 ++++++++++--------------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/pkg/cluster/address_controller.go b/pkg/cluster/address_controller.go index 780942d0d3ae..bb73a20deac4 100644 --- a/pkg/cluster/address_controller.go +++ b/pkg/cluster/address_controller.go @@ -8,20 +8,17 @@ import ( controllerv1 "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1" "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/sets" ) func registerAddressHandlers(ctx context.Context, c *Cluster) { nodes := c.config.Runtime.Core.Core().V1().Node() a := &addressesHandler{ nodeController: nodes, - allowed: map[string]bool{}, + allowed: sets.New(c.config.SANs...), } - for _, cn := range c.config.SANs { - a.allowed[cn] = true - } - - logrus.Infof("Starting dynamiclistener CN filter node controller") + logrus.Infof("Starting dynamiclistener CN filter node controller with SANs: %v", c.config.SANs) nodes.OnChange(ctx, "server-cn-filter", a.sync) c.cnFilterFunc = a.filterCN } @@ -30,40 +27,30 @@ type addressesHandler struct { sync.RWMutex nodeController controllerv1.NodeController - allowed map[string]bool + allowed sets.Set[string] } // filterCN filters a list of potential server CNs (hostnames or IPs), removing any which do not correspond to // valid cluster servers (control-plane or etcd), or an address explicitly added via the tls-san option. 
func (a *addressesHandler) filterCN(cns ...string) []string { - if !a.nodeController.Informer().HasSynced() { + if len(cns) == 0 || !a.nodeController.Informer().HasSynced() { return cns } a.RLock() defer a.RUnlock() - filteredCNs := make([]string, 0, len(cns)) - for _, cn := range cns { - if a.allowed[cn] { - filteredCNs = append(filteredCNs, cn) - } else { - logrus.Debugf("CN filter controller rejecting certificate CN: %s", cn) - } - } - return filteredCNs + return a.allowed.Intersection(sets.New(cns...)).UnsortedList() } // sync updates the allowed address list to include addresses for the node func (a *addressesHandler) sync(key string, node *v1.Node) (*v1.Node, error) { - if node != nil { - if node.Labels[util.ControlPlaneRoleLabelKey] != "" || node.Labels[util.ETCDRoleLabelKey] != "" { - a.Lock() - defer a.Unlock() + if node != nil && (node.Labels[util.ControlPlaneRoleLabelKey] != "" || node.Labels[util.ETCDRoleLabelKey] != "") { + a.Lock() + defer a.Unlock() - for _, address := range node.Status.Addresses { - a.allowed[address.String()] = true - } + for _, address := range node.Status.Addresses { + a.allowed.Insert(address.String()) } } return node, nil From 90662bc1f349c64463b01ffcc61ef162e26f5009 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 15 Nov 2024 22:11:47 +0000 Subject: [PATCH 05/14] Refactor load balancer server list and health checking Signed-off-by: Brad Davidson (cherry picked from commit 911ee19a93a43ed467c1b20070ae213846747887) Signed-off-by: Brad Davidson --- pkg/agent/loadbalancer/config.go | 21 +- pkg/agent/loadbalancer/httpproxy.go | 2 +- pkg/agent/loadbalancer/httpproxy_test.go | 6 +- pkg/agent/loadbalancer/loadbalancer.go | 174 ++---- pkg/agent/loadbalancer/loadbalancer_test.go | 567 ++++++++++++-------- pkg/agent/loadbalancer/servers.go | 555 ++++++++++++++----- pkg/agent/proxy/apiproxy.go | 8 +- pkg/agent/tunnel/tunnel.go | 25 +- pkg/etcd/etcdproxy.go | 22 +- 9 files changed, 864 insertions(+), 516 deletions(-) diff --git a/pkg/agent/loadbalancer/config.go b/pkg/agent/loadbalancer/config.go index 9a2de3214fbb..b7d8f63f9d10 100644 --- a/pkg/agent/loadbalancer/config.go +++ b/pkg/agent/loadbalancer/config.go @@ -15,8 +15,8 @@ type lbConfig struct { func (lb *LoadBalancer) writeConfig() error { config := &lbConfig{ - ServerURL: lb.serverURL, - ServerAddresses: lb.serverAddresses, + ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(), + ServerAddresses: lb.servers.getAddresses(), } configOut, err := json.MarshalIndent(config, "", " ") if err != nil { @@ -26,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error { } func (lb *LoadBalancer) updateConfig() error { - writeConfig := true if configBytes, err := os.ReadFile(lb.configFile); err == nil { config := &lbConfig{} if err := json.Unmarshal(configBytes, config); err == nil { - if config.ServerURL == lb.serverURL { - writeConfig = false - lb.setServers(config.ServerAddresses) + // if the default server from the config matches our current default, + // load the rest of the addresses as well. + if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() { + lb.Update(config.ServerAddresses) + return nil } } } - if writeConfig { - if err := lb.writeConfig(); err != nil { - return err - } - } - return nil + // config didn't exist or used a different default server, write the current config to disk. 
+ return lb.writeConfig() } diff --git a/pkg/agent/loadbalancer/httpproxy.go b/pkg/agent/loadbalancer/httpproxy.go index f14859bfe71c..ea9711824975 100644 --- a/pkg/agent/loadbalancer/httpproxy.go +++ b/pkg/agent/loadbalancer/httpproxy.go @@ -60,7 +60,7 @@ func SetHTTPProxy(address string) error { func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { // Create a new HTTP proxy dialer - httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer))) + httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer))) return httpProxyDialer, nil } else if proxyURL.Scheme == "socks5" { // For SOCKS5 proxies, use the proxy package's FromURL diff --git a/pkg/agent/loadbalancer/httpproxy_test.go b/pkg/agent/loadbalancer/httpproxy_test.go index c8b8b5b924bb..07f72e927e77 100644 --- a/pkg/agent/loadbalancer/httpproxy_test.go +++ b/pkg/agent/loadbalancer/httpproxy_test.go @@ -2,15 +2,16 @@ package loadbalancer import ( "fmt" - "net" "os" "strings" "testing" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" + "golang.org/x/net/proxy" ) +var originalDialer proxy.Dialer var defaultEnv map[string]string var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"} @@ -19,7 +20,7 @@ func init() { } func prepareEnv(env ...string) { - defaultDialer = &net.Dialer{} + originalDialer = defaultDialer defaultEnv = map[string]string{} for _, e := range proxyEnvs { if v, ok := os.LookupEnv(e); ok { @@ -34,6 +35,7 @@ func prepareEnv(env ...string) { } func restoreEnv() { + defaultDialer = originalDialer for _, e := range proxyEnvs { if v, ok := defaultEnv[e]; ok { os.Setenv(e, v) diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go index db9fa6f16f72..2f6d33fbf4c2 100644 --- a/pkg/agent/loadbalancer/loadbalancer.go +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -2,55 +2,29 @@ package loadbalancer import ( "context" - "errors" "fmt" "net" + "net/url" "os" "path/filepath" - "sync" - "time" + "strings" "github.com/inetaf/tcpproxy" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" ) -// server tracks the connections to a server, so that they can be closed when the server is removed. -type server struct { - // This mutex protects access to the connections map. All direct access to the map should be protected by it. - mutex sync.Mutex - address string - healthCheck func() bool - connections map[net.Conn]struct{} -} - -// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. -type serverConn struct { - server *server - net.Conn -} - // LoadBalancer holds data for a local listener which forwards connections to a // pool of remote servers. It is not a proper load-balancer in that it does not // actually balance connections, but instead fails over to a new server only // when a connection attempt to the currently selected server fails. type LoadBalancer struct { - // This mutex protects access to servers map and randomServers list. - // All direct access to the servers map/list should be protected by it. 
- mutex sync.RWMutex - proxy *tcpproxy.Proxy - - serviceName string - configFile string - localAddress string - localServerURL string - defaultServerAddress string - serverURL string - serverAddresses []string - randomServers []string - servers map[string]*server - currentServerAddress string - nextServerIndex int + serviceName string + configFile string + scheme string + localAddress string + servers serverList + proxy *tcpproxy.Proxy } const RandomPort = 0 @@ -63,7 +37,7 @@ var ( // New contstructs a new LoadBalancer instance. The default server URL, and // currently active servers, are stored in a file within the dataDir. -func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { +func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { config := net.ListenConfig{Control: reusePort} var localAddress string if isIPv6 { @@ -84,30 +58,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo return nil, err } - // if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with. - localAddress = listener.Addr().String() - - defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress) + serverURL, err := url.Parse(defaultServerURL) if err != nil { return nil, err } - if serverURL == localServerURL { - logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) - defaultServerAddress = "" + // Set explicit port from scheme + if serverURL.Port() == "" { + if strings.ToLower(serverURL.Scheme) == "http" { + serverURL.Host += ":80" + } + if strings.ToLower(serverURL.Scheme) == "https" { + serverURL.Host += ":443" + } } lb := &LoadBalancer{ - serviceName: serviceName, - configFile: filepath.Join(dataDir, "etc", serviceName+".json"), - localAddress: localAddress, - localServerURL: localServerURL, - defaultServerAddress: defaultServerAddress, - servers: make(map[string]*server), - serverURL: serverURL, + serviceName: serviceName, + configFile: filepath.Join(dataDir, "etc", serviceName+".json"), + scheme: serverURL.Scheme, + localAddress: listener.Addr().String(), } - lb.setServers([]string{lb.defaultServerAddress}) + // if starting pointing at ourselves, don't set a default server address, + // which will cause all dials to fail until servers are added. 
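+	// Note: this comparison depends on the explicit default port filled in above,
+	// since listener.Addr() always reports an address with a port.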
+ if serverURL.Host == lb.localAddress { + logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) + } else { + lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host) + } lb.proxy = &tcpproxy.Proxy{ ListenFunc: func(string, string) (net.Listener, error) { @@ -116,7 +95,7 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo } lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ Addr: serviceName, - DialContext: lb.dialContext, + DialContext: lb.servers.dialContext, OnDialError: onDialError, }) @@ -126,92 +105,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo if err := lb.proxy.Start(); err != nil { return nil, err } - logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.serverAddresses, lb.defaultServerAddress) + logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) - go lb.runHealthChecks(ctx) + go lb.servers.runHealthChecks(ctx, lb.serviceName) return lb, nil } +// Update updates the list of server addresses to contain only the listed servers. func (lb *LoadBalancer) Update(serverAddresses []string) { - if lb == nil { - return - } - if !lb.setServers(serverAddresses) { + if !lb.servers.setAddresses(lb.serviceName, serverAddresses) { return } - logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.serverAddresses, lb.defaultServerAddress) + + logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) if err := lb.writeConfig(); err != nil { logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } } -func (lb *LoadBalancer) LoadBalancerServerURL() string { - if lb == nil { - return "" +// SetDefault sets the selected address as the default / fallback address +func (lb *LoadBalancer) SetDefault(serverAddress string) { + lb.servers.setDefaultAddress(lb.serviceName, serverAddress) + + if err := lb.writeConfig(); err != nil { + logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } - return lb.localServerURL } -func (lb *LoadBalancer) ServerAddresses() []string { - if lb == nil { - return nil +// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. 
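+// The callback is polled by the health-check loop started in New and should return
+// HealthCheckResultOK, HealthCheckResultFailed, or HealthCheckResultUnknown. Setting
+// a check for an address that is not currently in the server list only logs an error.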
+func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) { + if err := lb.servers.setHealthCheck(address, healthCheck); err != nil { + logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err) + } else { + logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address) } - return lb.serverAddresses } -func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - - var allChecksFailed bool - startIndex := lb.nextServerIndex - for { - targetServer := lb.currentServerAddress - - server := lb.servers[targetServer] - if server == nil || targetServer == "" { - logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer) - } else if allChecksFailed || server.healthCheck() { - dialTime := time.Now() - conn, err := server.dialContext(ctx, network, targetServer) - if err == nil { - return conn, nil - } - logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err) - // Don't close connections to the failed server if we're retrying with health checks ignored. - // We don't want to disrupt active connections if it is unlikely they will have anywhere to go. - if !allChecksFailed { - defer server.closeAll() - } - } else { - logrus.Debugf("Dial health check failed for %s", targetServer) - } - - newServer, err := lb.nextServer(targetServer) - if err != nil { - return nil, err - } - if targetServer != newServer { - logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer) - } - if ctx.Err() != nil { - return nil, ctx.Err() - } +func (lb *LoadBalancer) LocalURL() string { + return lb.scheme + "://" + lb.localAddress +} - maxIndex := len(lb.randomServers) - if startIndex > maxIndex { - startIndex = maxIndex - } - if lb.nextServerIndex == startIndex { - if allChecksFailed { - return nil, errors.New("all servers failed") - } - logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName) - allChecksFailed = true - } - } +func (lb *LoadBalancer) ServerAddresses() []string { + return lb.servers.getAddresses() } func onDialError(src net.Conn, dstDialErr error) { @@ -220,10 +157,9 @@ func onDialError(src net.Conn, dstDialErr error) { } // ResetLoadBalancer will delete the local state file for the load balancer on disk -func ResetLoadBalancer(dataDir, serviceName string) error { +func ResetLoadBalancer(dataDir, serviceName string) { stateFile := filepath.Join(dataDir, "etc", serviceName+".json") - if err := os.Remove(stateFile); err != nil { + if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) { logrus.Warn(err) } - return nil } diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go index cbfdf982c690..69b4fca10cab 100644 --- a/pkg/agent/loadbalancer/loadbalancer_test.go +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -5,19 +5,29 @@ import ( "context" "fmt" "net" - "net/url" + "strconv" "strings" "testing" "time" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" "github.com/sirupsen/logrus" ) +func Test_UnitLoadBalancer(t *testing.T) { + _, reporterConfig := GinkgoConfiguration() + reporterConfig.Verbose = testing.Verbose() + RegisterFailHandler(Fail) + RunSpecs(t, "LoadBalancer Suite", reporterConfig) +} + func init() { logrus.SetLevel(logrus.DebugLevel) } type testServer struct { + address string listener net.Listener conns []net.Conn prefix string @@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) { s := &testServer{ prefix: prefix, listener: listener, + address: listener.Addr().String(), } go s.serve() go func() { @@ -53,6 +64,7 @@ func (s *testServer) serve() { func (s *testServer) close() { logrus.Printf("testServer %s closing", s.prefix) + s.address = "" s.listener.Close() for _, conn := range s.conns { conn.Close() @@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) { } } -func (s *testServer) address() string { - return s.listener.Addr().String() -} - func ping(conn net.Conn) (string, error) { fmt.Fprintf(conn, "ping\n") result, err := bufio.NewReader(conn).ReadString('\n') @@ -82,221 +90,340 @@ func ping(conn net.Conn) (string, error) { return strings.TrimSpace(result), nil } -// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint) -// and then adds a new server (a node). The node server is then closed, and it is confirmed -// that new connections use the default server. -func Test_UnitFailOver(t *testing.T) { - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - defaultServer, err := createServer(ctx, "default") - if err != nil { - t.Fatalf("createServer(default) failed: %v", err) - } - - node1Server, err := createServer(ctx, "node1") - if err != nil { - t.Fatalf("createServer(node1) failed: %v", err) - } +var _ = Describe("LoadBalancer", func() { + // creates a LB using a default server (ie fixed registration endpoint) + // and then adds a new server (a node). The node server is then closed, and it is confirmed + // that new connections use the default server. + When("loadbalancer is running", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var defaultServer, node1Server, node2Server *testServer + var conn1, conn2, conn3, conn4 net.Conn + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + + defaultServer, err = createServer(ctx, "default") + Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") + + node1Server, err = createServer(ctx, "node1") + Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed") + + node2Server, err = createServer(ctx, "node2") + Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed") + + // start the loadbalancer with the default server as the only server + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("adds node1 as a server", func() { + // add the node as a new server address. 
+ lb.Update([]string{node1Server.address}) + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + + By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers())) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node1Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + }) + + It("connects to node1", func() { + // make sure connections go to the node + conn1, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result") - node2Server, err := createServer(ctx, "node2") - if err != nil { - t.Fatalf("createServer(node2) failed: %v", err) - } - - // start the loadbalancer with the default server as the only server - lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - - parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) - if err != nil { - t.Fatalf("url.Parse failed: %v", err) - } - localAddress := parsedURL.Host - - // add the node as a new server address. - lb.Update([]string{node1Server.address()}) - - // make sure connections go to the node - conn1, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } - if result, err := ping(conn1); err != nil { - t.Fatalf("ping(conn1) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn1) result: %v", result) - } - - t.Log("conn1 tested OK") - - // set failing health check for node 1 - lb.SetHealthCheck(node1Server.address(), func() bool { return false }) - - // Server connections are checked every second, now that node 1 is failed - // the connections to it should be closed. - time.Sleep(2 * time.Second) - - if _, err := ping(conn1); err == nil { - t.Fatal("Unexpected successful ping on closed connection conn1") - } - - t.Log("conn1 closed on failure OK") - - // make sure connection still goes to the first node - it is failing health checks but so - // is the default endpoint, so it should be tried first with health checks disabled, - // before failing back to the default. - conn2, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - - } - if result, err := ping(conn2); err != nil { - t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn2) result: %v", result) - } - - t.Log("conn2 tested OK") - - // make sure the health checks don't close the connection we just made - - // connections should only be closed when it transitions from health to unhealthy. 
- time.Sleep(2 * time.Second) - - if result, err := ping(conn2); err != nil { - t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn2) result: %v", result) - } - - t.Log("conn2 tested OK again") - - // shut down the first node server to force failover to the default - node1Server.close() - - // make sure new connections go to the default, and existing connections are closed - conn3, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - - } - if result, err := ping(conn3); err != nil { - t.Fatalf("ping(conn3) failed: %v", err) - } else if result != "default:ping" { - t.Fatalf("Unexpected ping(conn3) result: %v", result) - } + By("conn1 tested OK") + }) + + It("changes node1 state to failed", func() { + // set failing health check for node 1 + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed }) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node1Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(stateFailed)) + }) + + It("disconnects from node1", func() { + // Server connections are checked every second, now that node 1 is failed + // the connections to it should be closed. + Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") + + By("conn1 closed on failure OK") + + // connections shoould go to the default now that node 1 is failed + conn2, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") - t.Log("conn3 tested OK") + By("conn2 tested OK") + }) + + It("does not close connections unexpectedly", func() { + // make sure the health checks don't close the connection we just made - + // connections should only be closed when it transitions from health to unhealthy. + time.Sleep(2 * time.Second) + + Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") + + By("conn2 tested OK again") + }) + + It("closes connections when dial fails", func() { + // shut down the first node server to force failover to the default + node1Server.close() + + // make sure new connections go to the default, and existing connections are closed + conn3, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result") + + By("conn3 tested OK") + }) - if _, err := ping(conn2); err == nil { - t.Fatal("Unexpected successful ping on closed connection conn2") - } - - t.Log("conn2 closed on failure OK") - - // add the second node as a new server address. - lb.Update([]string{node2Server.address()}) - - // make sure connection now goes to the second node, - // and connections to the default are closed. 
- conn4, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - - } - if result, err := ping(conn4); err != nil { - t.Fatalf("ping(conn4) failed: %v", err) - } else if result != "node2:ping" { - t.Fatalf("Unexpected ping(conn4) result: %v", result) - } - - t.Log("conn4 tested OK") - - // Server connections are checked every second, now that we have a healthy - // server, connections to the default server should be closed - time.Sleep(2 * time.Second) - - if _, err := ping(conn3); err == nil { - t.Fatal("Unexpected successful ping on connection conn3") - } - - t.Log("conn3 closed on failure OK") -} - -// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly -func Test_UnitFailFast(t *testing.T) { - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverURL := "http://127.0.0.1:0/" - lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - - conn, err := net.Dial("tcp", lb.localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } - - done := make(chan error) - go func() { - _, err = ping(conn) - done <- err - }() - timeout := time.After(10 * time.Millisecond) - - select { - case err := <-done: - if err == nil { - t.Fatal("Unexpected successful ping from invalid address") - } - case <-timeout: - t.Fatal("Test timed out") - } -} - -// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail -// within the expected duration -func Test_UnitFailUnreachable(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test in short mode.") - } - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverAddr := "192.0.2.1:6443" - lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - - // Set failing health check to reduce retries - lb.SetHealthCheck(serverAddr, func() bool { return false }) - - conn, err := net.Dial("tcp", lb.localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } - - done := make(chan error) - go func() { - _, err = ping(conn) - done <- err - }() - timeout := time.After(11 * time.Second) - - select { - case err := <-done: - if err == nil { - t.Fatal("Unexpected successful ping from unreachable address") - } - case <-timeout: - t.Fatal("Test timed out") - } -} + It("replaces node2 as a server", func() { + // add the second node as a new server address. + lb.Update([]string{node2Server.address}) + lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + + By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers())) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node2Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + }) + + It("connects to node2", func() { + // make sure connection now goes to the second node, + // and connections to the default are closed. 
+ conn4, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result") + + By("conn4 tested OK") + }) + + It("does not close connections unexpectedly", func() { + // Server connections are checked every second, now that we have a healthy + // server, connections to the default server should be closed + time.Sleep(2 * time.Second) + + Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") + + By("conn2 closed on failure OK") + + Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") + + By("conn3 closed on failure OK") + }) + + It("adds default as a server", func() { + // add the default as a full server + lb.Update([]string{node2Server.address, defaultServer.address}) + lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(defaultServer.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + + By(fmt.Sprintf("Default server added: %v", lb.servers.getServers())) + }) + + It("returns the default server in the address list", func() { + // confirm that both servers are listed in the address list + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address)) + + // confirm that the default is still listed as default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + + }) + + It("does not return the default server in the address list after removing it", func() { + // remove the default as a server + lb.Update([]string{node2Server.address}) + By(fmt.Sprintf("Default removed: %v", lb.servers.getServers())) + + // confirm that it is not listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) + + // but is still listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + }) + + It("removes default server when no longer default", func() { + // set node2 as the default + lb.SetDefault(node2Server.address) + By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) + + // confirm that it is still listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) + + // and is listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default") + }) + + It("sets all three servers", func() { + // set node2 as the default + lb.SetDefault(defaultServer.address) + By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) + + lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address}) + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(defaultServer.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + + By(fmt.Sprintf("All servers set: %v", lb.servers.getServers())) + + // confirm that it is still listed as a server + 
Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address)) + + // and is listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + }) + }) + + // confirms that the loadbalancer will not dial itself + When("the default server is the loadbalancer", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var defaultServer *testServer + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + + defaultServer, err = createServer(ctx, "default") + Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") + address := defaultServer.address + defaultServer.close() + _, port, _ := net.SplitHostPort(address) + intPort, _ := strconv.Atoi(port) + + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails immediately", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + _, err = ping(conn) + Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection") + }) + }) + + // confirms that connnections to invalid addresses fail quickly + When("there are no valid addresses", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails fast", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(10 * time.Millisecond) + + select { + case err := <-done: + if err == nil { + Fail("Unexpected successful ping from invalid address") + } + case <-timeout: + Fail("Test timed out") + } + }) + }) + + // confirms that connnections to unreachable addresses do fail within the + // expected duration + When("the server is unreachable", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails with the correct timeout", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(11 * time.Second) + + select { + case err := <-done: + if err == nil { + Fail("Unexpected successful ping from unreachable address") + } + case <-timeout: + Fail("Test timed out") + } + }) + }) +}) diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 675bee5c5c86..7cdf8466ed81 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -1,118 +1,421 @@ package loadbalancer import ( + "cmp" "context" - "math/rand" + "errors" + "fmt" "net" "slices" + "sync" "time" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" 
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" ) -func (lb *LoadBalancer) setServers(serverAddresses []string) bool { - serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress) - if len(serverAddresses) == 0 { - return false - } +type HealthCheckFunc func() HealthCheckResult + +// HealthCheckResult indicates the status of a server health check poll. +// For health-checks that poll in the background, Unknown should be returned +// if a poll has not occurred since the last check. +type HealthCheckResult int - lb.mutex.Lock() - defer lb.mutex.Unlock() +const ( + HealthCheckResultUnknown HealthCheckResult = iota + HealthCheckResultFailed + HealthCheckResultOK +) + +// serverList tracks potential backend servers for use by a loadbalancer. +type serverList struct { + // This mutex protects access to the server list. All direct access to the list should be protected by it. + mutex sync.Mutex + servers []*server +} - newAddresses := sets.NewString(serverAddresses...) - curAddresses := sets.NewString(lb.serverAddresses...) +// setServers updates the server list to contain only the selected addresses. +func (sl *serverList) setAddresses(serviceName string, addresses []string) bool { + newAddresses := sets.New(addresses...) + curAddresses := sets.New(sl.getAddresses()...) if newAddresses.Equal(curAddresses) { return false } - for addedServer := range newAddresses.Difference(curAddresses) { - logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer) - lb.servers[addedServer] = &server{ - address: addedServer, - connections: make(map[net.Conn]struct{}), - healthCheck: func() bool { return true }, + sl.mutex.Lock() + defer sl.mutex.Unlock() + + var closeAllFuncs []func() + var defaultServer *server + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + defaultServer = sl.servers[i] + } + + // add new servers + for addedAddress := range newAddresses.Difference(curAddresses) { + if defaultServer != nil && defaultServer.address == addedAddress { + // make default server go through the same health check promotions as a new server when added + logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName) + defaultServer.state = stateUnchecked + defaultServer.lastTransition = time.Now() + } else { + s := newServer(addedAddress, false) + logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address) + sl.servers = append(sl.servers, s) } } - for removedServer := range curAddresses.Difference(newAddresses) { - server := lb.servers[removedServer] - if server != nil { - logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer) - // Defer closing connections until after the new server list has been put into place. - // Closing open connections ensures that anything stuck retrying on a stale server is forced - // over to a valid endpoint. - defer server.closeAll() - // Don't delete the default server from the server map, in case we need to fall back to it. 
- if removedServer != lb.defaultServerAddress { - delete(lb.servers, removedServer) - } + // remove old servers + for removedAddress := range curAddresses.Difference(newAddresses) { + if defaultServer != nil && defaultServer.address == removedAddress { + // demote the default server down to standby, instead of deleting it + defaultServer.state = stateStandby + closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll) + } else { + sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool { + if s.address == removedAddress { + logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address) + // set state to invalid to prevent server from making additional connections + s.state = stateInvalid + closeAllFuncs = append(closeAllFuncs, s.closeAll) + return true + } + return false + }) } } - lb.serverAddresses = serverAddresses - lb.randomServers = append([]string{}, lb.serverAddresses...) - rand.Shuffle(len(lb.randomServers), func(i, j int) { - lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] + slices.SortFunc(sl.servers, compareServers) + + // Close all connections to servers that were removed + for _, closeAll := range closeAllFuncs { + closeAll() + } + + return true +} + +// getAddresses returns the addresses of all servers. +// If the default server is in standby state, indicating it is only present +// because it is the default, it is not returned in this list. +func (sl *serverList) getAddresses() []string { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + addresses := make([]string, 0, len(sl.servers)) + for _, s := range sl.servers { + if s.isDefault && s.state == stateStandby { + continue + } + addresses = append(addresses, s.address) + } + return addresses +} + +// setDefault sets the server with the provided address as the default server. +// The default flag is cleared on all other servers, and if the server was previously +// only kept in the list because it was the default, it is removed. +func (sl *serverList) setDefaultAddress(serviceName, address string) { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // deal with existing default first + sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool { + if s.isDefault && s.address != address { + s.isDefault = false + if s.state == stateStandby { + s.state = stateInvalid + defer s.closeAll() + return true + } + } + return false }) - // If the current server list does not contain the default server address, - // we want to include it in the random server list so that it can be tried if necessary. - // However, it should be treated as always failing health checks so that it is only - // used if all other endpoints are unavailable. - if !hasDefaultServer { - lb.randomServers = append(lb.randomServers, lb.defaultServerAddress) - if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok { - defaultServer.healthCheck = func() bool { return false } - lb.servers[lb.defaultServerAddress] = defaultServer + + // update or create server with selected address + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 { + sl.servers[i].isDefault = true + } else { + sl.servers = append(sl.servers, newServer(address, true)) + } + + logrus.Infof("Updated load balancer %s default server: %s", serviceName, address) + slices.SortFunc(sl.servers, compareServers) +} + +// getDefault returns the address of the default server. 
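+// An empty string is returned if no default server has been set.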
+func (sl *serverList) getDefaultAddress() string {
+	if s := sl.getDefaultServer(); s != nil {
+		return s.address
+	}
+	return ""
+}
+
+// getDefault returns the default server.
+func (sl *serverList) getDefaultServer() *server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 {
+		return sl.servers[i]
+	}
+	return nil
+}
+
+// getServers returns a copy of the servers list that can be safely iterated over without holding a lock
+func (sl *serverList) getServers() []*server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	return slices.Clone(sl.servers)
+}
+
+// getServer returns the first server with the specified address
+func (sl *serverList) getServer(address string) *server {
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 {
+		return sl.servers[i]
+	}
+	return nil
+}
+
+// setHealthCheck updates the health check function for a server, replacing the
+// current function.
+func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error {
+	if s := sl.getServer(address); s != nil {
+		s.healthCheck = healthCheck
+		return nil
+	}
+	return fmt.Errorf("no server found for %s", address)
+}
+
+// recordSuccess records a successful check of a server, either via health-check or dial.
+// The server's state is adjusted accordingly.
+func (sl *serverList) recordSuccess(srv *server, r reason) {
+	var new_state state
+	switch srv.state {
+	case stateFailed, stateUnchecked:
+		// dialed or health checked OK once, improve to recovering
+		new_state = stateRecovering
+	case stateRecovering:
+		if r == reasonHealthCheck {
+			// was recovering due to successful dial or first health check, can now improve
+			if len(srv.connections) > 0 {
+				// server accepted connections while recovering, attempt to go straight to active
+				new_state = stateActive
+			} else {
+				// no connections, just make it preferred
+				new_state = statePreferred
+			}
+		}
+	case stateHealthy:
+		if r == reasonDial {
+			// improve from healthy to active by being dialed
+			new_state = stateActive
+		}
+	case statePreferred:
+		if r == reasonDial {
+			// improve from preferred to active by being dialed
+			new_state = stateActive
+		} else {
+			if time.Now().Sub(srv.lastTransition) > time.Minute {
+				// has been preferred for a while without being dialed, demote to healthy
+				new_state = stateHealthy
+			}
 		}
 	}
 
-	lb.currentServerAddress = lb.randomServers[0]
-	lb.nextServerIndex = 1
-	return true
+	// no-op if state did not change
+	if new_state == stateInvalid {
+		return
+	}
+
+	// handle active transition and sort the server list while holding the lock
+	sl.mutex.Lock()
+	defer sl.mutex.Unlock()
+
+	// handle states of other servers when attempting to make this one active
+	if new_state == stateActive {
+		for _, s := range sl.servers {
+			if srv.address == s.address {
+				continue
+			}
+			switch s.state {
+			case stateFailed, stateStandby, stateRecovering, stateHealthy:
+				// close connections to other non-active servers whenever we have a new active server
+				defer s.closeAll()
+			case stateActive:
+				if len(s.connections) > len(srv.connections) {
+					// if there is a currently active server that has more connections than we do,
+					// close our connections and go to preferred instead
+					new_state = statePreferred
+					defer srv.closeAll()
+				} else {
+					// otherwise, close its connections and demote it to preferred
+					s.state = statePreferred
+					defer s.closeAll()
+				}
+			}
+ } + } + + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return + } + + logrus.Infof("Server %s->%s from successful %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) } -// nextServer attempts to get the next server in the loadbalancer server list. -// If another goroutine has already updated the current server address to point at -// a different address than just failed, nothing is changed. Otherwise, a new server address -// is stored to the currentServerAddress field, and returned for use. -// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex. -func (lb *LoadBalancer) nextServer(failedServer string) (string, error) { - // note: these fields are not protected by the mutex, so we clamp the index value and update - // the index/current address using local variables, to avoid time-of-check vs time-of-use - // race conditions caused by goroutine A incrementing it in between the time goroutine B - // validates its value, and uses it as a list index. - currentServerAddress := lb.currentServerAddress - nextServerIndex := lb.nextServerIndex +// recordFailure records a failed check of a server, either via health-check or dial. +// The server's state is adjusted accordingly. +func (sl *serverList) recordFailure(srv *server, r reason) { + var new_state state + switch srv.state { + case stateUnchecked, stateRecovering: + if r == reasonDial { + // only demote from unchecked or recovering if a dial fails, health checks may + // continue to fail despite it being dialable. just leave it where it is + // and don't close any connections. + new_state = stateFailed + } + case stateHealthy, statePreferred, stateActive: + // should not have any connections when in any state other than active or + // recovering, but close them all anyway to force failover. + defer srv.closeAll() + new_state = stateFailed + } - if len(lb.randomServers) == 0 { - return "", errors.New("No servers in load balancer proxy list") + // no-op if state did not change + if new_state == stateInvalid { + return } - if len(lb.randomServers) == 1 { - return currentServerAddress, nil + + // sort the server list while holding the lock + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return } - if failedServer != currentServerAddress { - return currentServerAddress, nil + + logrus.Infof("Server %s->%s from failed %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) +} + +// state is possible server health states, in increasing order of preference. +// The server list is kept sorted in descending order by this state value. 
+type state int + +const ( + stateInvalid state = iota + stateFailed // failed a health check or dial + stateStandby // reserved for use by default server if not in server list + stateUnchecked // just added, has not been health checked + stateRecovering // successfully health checked once, or dialed when failed + stateHealthy // normal state + statePreferred // recently transitioned from recovering; should be preferred as others may go down for maintenance + stateActive // currently active server +) + +func (s state) String() string { + switch s { + case stateInvalid: + return "INVALID" + case stateFailed: + return "FAILED" + case stateStandby: + return "STANDBY" + case stateUnchecked: + return "UNCHECKED" + case stateRecovering: + return "RECOVERING" + case stateHealthy: + return "HEALTHY" + case statePreferred: + return "PREFERRED" + case stateActive: + return "ACTIVE" + default: + return "UNKNOWN" } - if nextServerIndex >= len(lb.randomServers) { - nextServerIndex = 0 +} + +// reason specifies the reason for a successful or failed health report +type reason int + +const ( + reasonDial reason = iota + reasonHealthCheck +) + +func (r reason) String() string { + switch r { + case reasonDial: + return "dial" + case reasonHealthCheck: + return "health check" + default: + return "unknown reason" } +} - currentServerAddress = lb.randomServers[nextServerIndex] - nextServerIndex++ +// server tracks the connections to a server, so that they can be closed when the server is removed. +type server struct { + // This mutex protects access to the connections map. All direct access to the map should be protected by it. + mutex sync.Mutex + address string + isDefault bool + state state + lastTransition time.Time + healthCheck HealthCheckFunc + connections map[net.Conn]struct{} +} - lb.currentServerAddress = currentServerAddress - lb.nextServerIndex = nextServerIndex +// newServer creates a new server, with a default health check +// and default/state fields appropriate for whether or not +// the server is a full server, or just a fallback default. 
+func newServer(address string, isDefault bool) *server { + state := stateUnchecked + if isDefault { + state = stateStandby + } + return &server{ + address: address, + isDefault: isDefault, + state: state, + lastTransition: time.Now(), + healthCheck: func() HealthCheckResult { return HealthCheckResultUnknown }, + connections: make(map[net.Conn]struct{}), + } +} - return currentServerAddress, nil +func (s *server) String() string { + format := "%s@%s" + if s.isDefault { + format += "*" + } + return fmt.Sprintf(format, s.address, s.state) } -// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map -func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := defaultDialer.Dial(network, address) +// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map +func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) { + if s.state == stateInvalid { + return nil, fmt.Errorf("server %s is stopping", s.address) + } + + conn, err := defaultDialer.Dial(network, s.address) if err != nil { return nil, err } @@ -132,7 +435,7 @@ func (s *server) closeAll() { defer s.mutex.Unlock() if l := len(s.connections); l > 0 { - logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address) + logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s) for conn := range s.connections { // Close the connection in a goroutine so that we don't hold the lock while doing so. go conn.Close() @@ -140,6 +443,12 @@ func (s *server) closeAll() { } } +// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. +type serverConn struct { + server *server + net.Conn +} + // Close removes the connection entry from the server's connection map, and // closes the wrapped connection. func (sc *serverConn) Close() error { @@ -150,73 +459,43 @@ func (sc *serverConn) Close() error { return sc.Conn.Close() } -// SetDefault sets the selected address as the default / fallback address -func (lb *LoadBalancer) SetDefault(serverAddress string) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress) - // if the old default server is not currently in use, remove it from the server map - if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer { - defer server.closeAll() - delete(lb.servers, lb.defaultServerAddress) - } - // if the new default server doesn't have an entry in the map, add one - but - // with a failing health check so that it is only used as a last resort. - if _, ok := lb.servers[serverAddress]; !ok { - lb.servers[serverAddress] = &server{ - address: serverAddress, - healthCheck: func() bool { return false }, - connections: make(map[net.Conn]struct{}), +// runHealthChecks periodically health-checks all servers. 
+func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) {
+	wait.Until(func() {
+		for _, s := range sl.getServers() {
+			switch s.healthCheck() {
+			case HealthCheckResultOK:
+				sl.recordSuccess(s, reasonHealthCheck)
+			case HealthCheckResultFailed:
+				sl.recordFailure(s, reasonHealthCheck)
+			}
 		}
-	}
-
-	lb.defaultServerAddress = serverAddress
-	logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress)
+	}, time.Second, ctx.Done())
+	logrus.Debugf("Stopped health checking for load balancer %s", serviceName)
 }
 
-// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function.
-func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) {
-	lb.mutex.Lock()
-	defer lb.mutex.Unlock()
-
-	if server := lb.servers[address]; server != nil {
-		logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address)
-		server.healthCheck = healthCheck
-	} else {
-		logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address)
+// dialContext attempts to dial a connection to a server from the server list.
+// Success or failure is recorded to ensure that server state is updated appropriately.
+func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) {
+	for _, s := range sl.getServers() {
+		dialTime := time.Now()
+		conn, err := s.dialContext(ctx, network)
+		if err == nil {
+			sl.recordSuccess(s, reasonDial)
+			return conn, nil
+		}
+		logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err)
+		sl.recordFailure(s, reasonDial)
 	}
+	return nil, errors.New("all servers failed")
 }
 
-// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their
-// connections closed, to force clients to switch over to a healthy server.
-func (lb *LoadBalancer) runHealthChecks(ctx context.Context) {
-	previousStatus := map[string]bool{}
-	wait.Until(func() {
-		lb.mutex.RLock()
-		defer lb.mutex.RUnlock()
-		var healthyServerExists bool
-		for address, server := range lb.servers {
-			status := server.healthCheck()
-			healthyServerExists = healthyServerExists || status
-			if status == false && previousStatus[address] == true {
-				// Only close connections when the server transitions from healthy to unhealthy;
-				// we don't want to re-close all the connections every time as we might be ignoring
-				// health checks due to all servers being marked unhealthy.
-				defer server.closeAll()
-			}
-			previousStatus[address] = status
-		}
-
-		// If there is at least one healthy server, and the default server is not in the server list,
-		// close all the connections to the default server so that clients reconnect and switch over
-		// to a preferred server.
-		hasDefaultServer := slices.Contains(lb.serverAddresses, lb.defaultServerAddress)
-		if healthyServerExists && !hasDefaultServer {
-			if server, ok := lb.servers[lb.defaultServerAddress]; ok {
-				defer server.closeAll()
-			}
-		}
-	}, time.Second, ctx.Done())
-	logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName)
+// compareServers is a comparison function that can be used to sort the server list
+// so that servers with a more preferred state, or higher number of connections, are ordered first.
+func compareServers(a, b *server) int { + c := cmp.Compare(b.state, a.state) + if c == 0 { + return cmp.Compare(len(b.connections), len(a.connections)) + } + return c } diff --git a/pkg/agent/proxy/apiproxy.go b/pkg/agent/proxy/apiproxy.go index e711623e467e..56d86a031366 100644 --- a/pkg/agent/proxy/apiproxy.go +++ b/pkg/agent/proxy/apiproxy.go @@ -22,7 +22,7 @@ type Proxy interface { SupervisorAddresses() []string APIServerURL() string IsAPIServerLBEnabled() bool - SetHealthCheck(address string, healthCheck func() bool) + SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) } // NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If @@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor return nil, err } p.supervisorLB = lb - p.supervisorURL = lb.LoadBalancerServerURL() + p.supervisorURL = lb.LocalURL() p.apiServerURL = p.supervisorURL } @@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) { p.supervisorAddresses = supervisorAddresses } -func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) { +func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) { if p.supervisorLB != nil { p.supervisorLB.SetHealthCheck(address, healthCheck) } @@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error { return err } p.apiServerLB = lb - p.apiServerURL = lb.LoadBalancerServerURL() + p.apiServerURL = lb.LocalURL() } else { p.apiServerURL = u.String() } diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index a5df415c7343..d04f9fdc0b22 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -14,6 +14,7 @@ import ( "github.com/gorilla/websocket" agentconfig "github.com/k3s-io/k3s/pkg/agent/config" + "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/agent/proxy" daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/util" @@ -310,7 +311,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan if _, ok := disconnect[address]; !ok { conn := a.connect(ctx, wg, address, tlsConfig) disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) + proxy.SetHealthCheck(address, conn.healthCheck) } } @@ -384,7 +385,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan if _, ok := disconnect[address]; !ok { conn := a.connect(ctx, nil, address, tlsConfig) disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) + proxy.SetHealthCheck(address, conn.healthCheck) } } @@ -427,22 +428,20 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo } type agentConnection struct { - cancel context.CancelFunc - connected func() bool + cancel context.CancelFunc + healthCheck loadbalancer.HealthCheckFunc } // connect initiates a connection to the remotedialer server. Incoming dial requests from // the server will be checked by the authorizer function prior to being fulfilled. func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection { + var status loadbalancer.HealthCheckResult + wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address) ws := &websocket.Dialer{ TLSClientConfig: tlsConfig, } - // Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. 
- // If we cannot connect, connected will be set to false when the initial connection attempt fails. - connected := true - once := sync.Once{} if waitGroup != nil { waitGroup.Add(1) @@ -454,7 +453,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup } onConnect := func(_ context.Context, _ *remotedialer.Session) error { - connected = true + status = loadbalancer.HealthCheckResultOK logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy") if waitGroup != nil { once.Do(waitGroup.Done) @@ -467,7 +466,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup for { // ConnectToProxy blocks until error or context cancellation err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect) - connected = false + status = loadbalancer.HealthCheckResultFailed if err != nil && !errors.Is(err, context.Canceled) { logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...") // wait between reconnection attempts to avoid hammering the server @@ -484,8 +483,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup }() return agentConnection{ - cancel: cancel, - connected: func() bool { return connected }, + cancel: cancel, + healthCheck: func() loadbalancer.HealthCheckResult { + return status + }, } } diff --git a/pkg/etcd/etcdproxy.go b/pkg/etcd/etcdproxy.go index ec781e11a3ae..156834440c08 100644 --- a/pkg/etcd/etcdproxy.go +++ b/pkg/etcd/etcdproxy.go @@ -48,6 +48,10 @@ type etcdproxy struct { } func (e *etcdproxy) Update(addresses []string) { + if e.etcdLB == nil { + return + } + e.etcdLB.Update(addresses) validEndpoint := map[string]bool{} @@ -70,10 +74,8 @@ func (e *etcdproxy) Update(addresses []string) { // start a polling routine that makes periodic requests to the etcd node's supervisor port. // If the request fails, the node is marked unhealthy. -func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool { - // Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. - // If we cannot connect, connected will be set to false when the initial connection attempt fails. - connected := true +func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc { + var status loadbalancer.HealthCheckResult host, _, _ := net.SplitHostPort(address) url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort))) @@ -89,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() } if err != nil || statusCode != http.StatusOK { logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode) - connected = false + status = loadbalancer.HealthCheckResultFailed } else { - connected = true + status = loadbalancer.HealthCheckResultOK } }, 5*time.Second, 1.0, true) - return func() bool { - return connected + return func() loadbalancer.HealthCheckResult { + // Reset the status to unknown on reading, until next time it is checked. + // This avoids having a health check result alter the server state between active checks. 
+ s := status + status = loadbalancer.HealthCheckResultUnknown + return s } } From ba27be130ce780746cc019f028229d6e20524b0b Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Thu, 5 Dec 2024 01:37:08 +0000 Subject: [PATCH 06/14] Add loadbalancer metrics Signed-off-by: Brad Davidson (cherry picked from commit 3d2fabb013ed1cd280e1711557382d64f536f327) Signed-off-by: Brad Davidson --- pkg/agent/loadbalancer/loadbalancer.go | 13 ++++++++++- pkg/agent/loadbalancer/metrics.go | 30 ++++++++++++++++++++++++++ pkg/agent/loadbalancer/servers.go | 9 +++++++- pkg/metrics/metrics.go | 3 +++ 4 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 pkg/agent/loadbalancer/metrics.go diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go index 2f6d33fbf4c2..09727db18922 100644 --- a/pkg/agent/loadbalancer/loadbalancer.go +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/inetaf/tcpproxy" "github.com/k3s-io/k3s/pkg/version" @@ -95,8 +96,18 @@ func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbS } lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ Addr: serviceName, - DialContext: lb.servers.dialContext, OnDialError: onDialError, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + start := time.Now() + status := "success" + conn, err := lb.servers.dialContext(ctx, network, address) + latency := time.Since(start) + if err != nil { + status = "error" + } + loadbalancerDials.WithLabelValues(serviceName, status).Observe(latency.Seconds()) + return conn, err + }, }) if err := lb.updateConfig(); err != nil { diff --git a/pkg/agent/loadbalancer/metrics.go b/pkg/agent/loadbalancer/metrics.go new file mode 100644 index 000000000000..11f27486eda7 --- /dev/null +++ b/pkg/agent/loadbalancer/metrics.go @@ -0,0 +1,30 @@ +package loadbalancer + +import ( + "github.com/k3s-io/k3s/pkg/version" + "github.com/prometheus/client_golang/prometheus" + "k8s.io/component-base/metrics" +) + +var ( + loadbalancerConnections = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: version.Program + "_loadbalancer_server_connections", + Help: "Count of current connections to loadbalancer server", + }, []string{"name", "server"}) + + loadbalancerState = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: version.Program + "_loadbalancer_server_health", + Help: "Current health value of loadbalancer server", + }, []string{"name", "server"}) + + loadbalancerDials = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: version.Program + "_loadbalancer_dial_duration_seconds", + Help: "Time taken to dial a connection to a backend server", + Buckets: metrics.ExponentialBuckets(0.001, 2, 15), + }, []string{"name", "status"}) +) + +// MustRegister registers loadbalancer metrics +func MustRegister(registerer prometheus.Registerer) { + registerer.MustRegister(loadbalancerConnections, loadbalancerState, loadbalancerDials) +} diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 7cdf8466ed81..13334ea881dc 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -79,6 +79,9 @@ func (sl *serverList) setAddresses(serviceName string, addresses []string) bool // set state to invalid to prevent server from making additional connections s.state = stateInvalid closeAllFuncs = append(closeAllFuncs, s.closeAll) + // remove metrics + loadbalancerState.DeleteLabelValues(serviceName, s.address) + 
loadbalancerConnections.DeleteLabelValues(serviceName, s.address) return true } return false @@ -459,7 +462,7 @@ func (sc *serverConn) Close() error { return sc.Conn.Close() } -// runHealthChecks periodically health-checks all servers. +// runHealthChecks periodically health-checks all servers and updates metrics func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) { wait.Until(func() { for _, s := range sl.getServers() { @@ -469,6 +472,10 @@ func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) { case HealthCheckResultFailed: sl.recordFailure(s, reasonHealthCheck) } + if s.state != stateInvalid { + loadbalancerState.WithLabelValues(serviceName, s.address).Set(float64(s.state)) + loadbalancerConnections.WithLabelValues(serviceName, s.address).Set(float64(len(s.connections))) + } } }, time.Second, ctx.Done()) logrus.Debugf("Stopped health checking for load balancer %s", serviceName) diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index a769e6a38418..eccb4abb0bbc 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -6,6 +6,7 @@ import ( "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/agent/https" + "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/prometheus/client_golang/prometheus/promhttp" lassometrics "github.com/rancher/lasso/pkg/metrics" @@ -32,6 +33,8 @@ var DefaultMetrics = &Config{ func init() { // ensure that lasso exposes metrics through the same registry used by Kubernetes components lassometrics.MustRegister(DefaultRegisterer) + // same for loadbalancer metrics + loadbalancer.MustRegister(DefaultRegisterer) } // Config holds fields for the metrics listener From 7e64d3868bdfd76145d87756f30dc71fdd06889d Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Thu, 21 Nov 2024 23:48:49 +0000 Subject: [PATCH 07/14] Use helper to set consistent rest.Config rate limits and timeouts Signed-off-by: Brad Davidson (cherry picked from commit 71918e0d69021e19f054e1001d781c57f5047983) Signed-off-by: Brad Davidson --- pkg/agent/netpol/netpol.go | 4 ++-- pkg/agent/tunnel/tunnel.go | 3 +-- pkg/cli/token/token.go | 3 +-- pkg/daemons/config/types.go | 2 ++ pkg/daemons/executor/embed.go | 3 +-- pkg/secretsencrypt/config.go | 3 +-- pkg/server/context.go | 3 +-- pkg/server/router.go | 8 +++++--- pkg/server/secrets-encrypt.go | 15 +-------------- pkg/server/server.go | 6 +++--- pkg/util/api.go | 5 ++--- pkg/util/client.go | 17 ++++++++++++++++- 12 files changed, 36 insertions(+), 36 deletions(-) diff --git a/pkg/agent/netpol/netpol.go b/pkg/agent/netpol/netpol.go index 5c892a668f36..a9f7a43f532e 100644 --- a/pkg/agent/netpol/netpol.go +++ b/pkg/agent/netpol/netpol.go @@ -26,12 +26,12 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/metrics" + "github.com/k3s-io/k3s/pkg/util" "github.com/pkg/errors" "github.com/sirupsen/logrus" v1core "k8s.io/api/core/v1" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" ) func init() { @@ -57,7 +57,7 @@ func Run(ctx context.Context, nodeConfig *config.Node) error { return nil } - restConfig, err := clientcmd.BuildConfigFromFlags("", nodeConfig.AgentConfig.KubeConfigK3sController) + restConfig, err := util.GetRESTConfig(nodeConfig.AgentConfig.KubeConfigK3sController) if err != nil { return err } diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index d04f9fdc0b22..e1eb566cc675 100644 --- 
a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -32,7 +32,6 @@ import ( "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/tools/cache" - "k8s.io/client-go/tools/clientcmd" toolswatch "k8s.io/client-go/tools/watch" "k8s.io/kubernetes/pkg/cluster/ports" ) @@ -70,7 +69,7 @@ func Setup(ctx context.Context, config *daemonconfig.Node, proxy proxy.Proxy) er return err } - nodeRestConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigKubelet) + nodeRestConfig, err := util.GetRESTConfig(config.AgentConfig.KubeConfigKubelet) if err != nil { return err } diff --git a/pkg/cli/token/token.go b/pkg/cli/token/token.go index e16038fea5b6..9d514d7b5286 100644 --- a/pkg/cli/token/token.go +++ b/pkg/cli/token/token.go @@ -24,7 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/util/duration" - "k8s.io/client-go/tools/clientcmd" bootstrapapi "k8s.io/cluster-bootstrap/token/api" bootstraputil "k8s.io/cluster-bootstrap/token/util" "k8s.io/utils/ptr" @@ -48,7 +47,7 @@ func create(app *cli.Context, cfg *cmds.Token) error { return err } - restConfig, err := clientcmd.BuildConfigFromFlags("", cfg.Kubeconfig) + restConfig, err := util.GetRESTConfig(cfg.Kubeconfig) if err != nil { return err } diff --git a/pkg/daemons/config/types.go b/pkg/daemons/config/types.go index f6336a2ba251..ffb67d702e02 100644 --- a/pkg/daemons/config/types.go +++ b/pkg/daemons/config/types.go @@ -17,6 +17,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/record" utilsnet "k8s.io/utils/net" ) @@ -369,6 +370,7 @@ type ControlRuntime struct { ClientETCDCert string ClientETCDKey string + K8s kubernetes.Interface K3s *k3s.Factory Core *core.Factory Event record.EventRecorder diff --git a/pkg/daemons/executor/embed.go b/pkg/daemons/executor/embed.go index 0553da84e3e0..7e69f956e84b 100644 --- a/pkg/daemons/executor/embed.go +++ b/pkg/daemons/executor/embed.go @@ -28,7 +28,6 @@ import ( "k8s.io/apiserver/pkg/authentication/authenticator" typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/cache" - "k8s.io/client-go/tools/clientcmd" toolswatch "k8s.io/client-go/tools/watch" cloudprovider "k8s.io/cloud-provider" cloudproviderapi "k8s.io/cloud-provider/api" @@ -269,7 +268,7 @@ func (e *Embedded) Docker(ctx context.Context, cfg *daemonconfig.Node) error { // waitForUntaintedNode watches nodes, waiting to find one not tainted as // uninitialized by the external cloud provider. 
func waitForUntaintedNode(ctx context.Context, kubeConfig string) error { - restConfig, err := clientcmd.BuildConfigFromFlags("", kubeConfig) + restConfig, err := util.GetRESTConfig(kubeConfig) if err != nil { return err } diff --git a/pkg/secretsencrypt/config.go b/pkg/secretsencrypt/config.go index aae309d8fbad..7d2f2e4a725b 100644 --- a/pkg/secretsencrypt/config.go +++ b/pkg/secretsencrypt/config.go @@ -15,7 +15,6 @@ import ( "github.com/k3s-io/k3s/pkg/version" "github.com/prometheus/common/expfmt" corev1 "k8s.io/api/core/v1" - "k8s.io/client-go/tools/clientcmd" "github.com/k3s-io/k3s/pkg/generated/clientset/versioned/scheme" "github.com/sirupsen/logrus" @@ -237,7 +236,7 @@ func GetEncryptionConfigMetrics(runtime *config.ControlRuntime, initialMetrics b var unixUpdateTime int64 var reloadSuccessCounter int64 var lastFailure string - restConfig, err := clientcmd.BuildConfigFromFlags("", runtime.KubeConfigSupervisor) + restConfig, err := util.GetRESTConfig(runtime.KubeConfigSupervisor) if err != nil { return 0, 0, err } diff --git a/pkg/server/context.go b/pkg/server/context.go index ac6724820ee3..fb4928e8f1ad 100644 --- a/pkg/server/context.go +++ b/pkg/server/context.go @@ -19,7 +19,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/record" ) @@ -43,7 +42,7 @@ func NewContext(ctx context.Context, config *Config, forServer bool) (*Context, if forServer { cfg = config.ControlConfig.Runtime.KubeConfigSupervisor } - restConfig, err := clientcmd.BuildConfigFromFlags("", cfg) + restConfig, err := util.GetRESTConfig(cfg) if err != nil { return nil, err } diff --git a/pkg/server/router.go b/pkg/server/router.go index ec60d5f3d9c9..b4a2dc57cbb7 100644 --- a/pkg/server/router.go +++ b/pkg/server/router.go @@ -34,6 +34,7 @@ import ( "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/endpoints/request" + typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" bootstrapapi "k8s.io/cluster-bootstrap/token/api" "k8s.io/kubernetes/pkg/auth/nodeidentifier" ) @@ -305,16 +306,17 @@ func fileHandler(fileName ...string) http.Handler { } func apiserversHandler(server *config.Control) http.Handler { - var endpointsClient coreclient.EndpointsClient + var endpointsClient typedcorev1.EndpointsInterface return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + ctx := req.Context() var endpoints []string if endpointsClient == nil { if server.Runtime.Core != nil { - endpointsClient = server.Runtime.Core.Core().V1().Endpoints() + endpointsClient = server.Runtime.K8s.CoreV1().Endpoints(metav1.NamespaceDefault) } } if endpointsClient != nil { - if endpoint, _ := endpointsClient.Get("default", "kubernetes", metav1.GetOptions{}); endpoint != nil { + if endpoint, _ := endpointsClient.Get(ctx, "kubernetes", metav1.GetOptions{}); endpoint != nil { endpoints = util.GetAddresses(endpoint) } } diff --git a/pkg/server/secrets-encrypt.go b/pkg/server/secrets-encrypt.go index 256c98ce1003..a3759d9617c4 100644 --- a/pkg/server/secrets-encrypt.go +++ b/pkg/server/secrets-encrypt.go @@ -27,8 +27,6 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" apiserverconfigv1 "k8s.io/apiserver/pkg/apis/apiserver/v1" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/pager" "k8s.io/client-go/util/retry" "k8s.io/utils/ptr" @@ -395,18 +393,7 @@ func reencryptAndRemoveKey(ctx 
context.Context, server *config.Control, skip boo } func updateSecrets(ctx context.Context, server *config.Control, nodeName string) error { - restConfig, err := clientcmd.BuildConfigFromFlags("", server.Runtime.KubeConfigSupervisor) - if err != nil { - return err - } - // For secrets we need a much higher QPS than default - restConfig.QPS = secretsencrypt.SecretQPS - restConfig.Burst = secretsencrypt.SecretBurst - k8s, err := kubernetes.NewForConfig(restConfig) - if err != nil { - return err - } - + k8s := server.Runtime.K8s nodeRef := &corev1.ObjectReference{ Kind: "Node", Name: nodeName, diff --git a/pkg/server/server.go b/pkg/server/server.go index 8c6f40c4330c..a8c1e0d470f7 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -35,7 +35,6 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" clientset "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" ) func ResolveDataDir(dataDir string) (string, error) { @@ -113,6 +112,7 @@ func runControllers(ctx context.Context, config *Config) error { controlConfig.Runtime.NodePasswdFile); err != nil { logrus.Warn(errors.Wrap(err, "error migrating node-password file")) } + controlConfig.Runtime.K8s = sc.K8s controlConfig.Runtime.K3s = sc.K3s controlConfig.Runtime.Event = sc.Event controlConfig.Runtime.Core = sc.Core @@ -208,7 +208,7 @@ func coreControllers(ctx context.Context, sc *Context, config *Config) error { } if !config.ControlConfig.DisableHelmController { - restConfig, err := clientcmd.BuildConfigFromFlags("", config.ControlConfig.Runtime.KubeConfigSupervisor) + restConfig, err := util.GetRESTConfig(config.ControlConfig.Runtime.KubeConfigSupervisor) if err != nil { return err } @@ -285,7 +285,7 @@ func stageFiles(ctx context.Context, sc *Context, controlConfig *config.Control) return err } - restConfig, err := clientcmd.BuildConfigFromFlags("", controlConfig.Runtime.KubeConfigSupervisor) + restConfig, err := util.GetRESTConfig(controlConfig.Runtime.KubeConfigSupervisor) if err != nil { return err } diff --git a/pkg/util/api.go b/pkg/util/api.go index 5ce53c49ba48..4df9ad73a945 100644 --- a/pkg/util/api.go +++ b/pkg/util/api.go @@ -23,7 +23,6 @@ import ( authorizationv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" coregetter "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/record" ) @@ -58,7 +57,7 @@ func GetAddresses(endpoint *v1.Endpoints) []string { // readyz endpoint instead of the deprecated healthz endpoint, and supports context. func WaitForAPIServerReady(ctx context.Context, kubeconfigPath string, timeout time.Duration) error { var lastErr error - restConfig, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath) + restConfig, err := GetRESTConfig(kubeconfigPath) if err != nil { return err } @@ -112,7 +111,7 @@ type genericAccessReviewRequest func(context.Context) (*authorizationv1.SubjectA // the access would be allowed. 
 func WaitForRBACReady(ctx context.Context, kubeconfigPath string, timeout time.Duration, ra authorizationv1.ResourceAttributes, user string, groups ...string) error {
 	var lastErr error
-	restConfig, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath)
+	restConfig, err := GetRESTConfig(kubeconfigPath)
 	if err != nil {
 		return err
 	}
diff --git a/pkg/util/client.go b/pkg/util/client.go
index 561a5cbc0817..a7ca9fe26b6d 100644
--- a/pkg/util/client.go
+++ b/pkg/util/client.go
@@ -5,12 +5,15 @@ import (
 	"os"
 	"runtime"
 	"strings"
+	"time"
 
 	"github.com/k3s-io/k3s/pkg/datadir"
 	"github.com/k3s-io/k3s/pkg/version"
+	"github.com/rancher/wrangler/v3/pkg/ratelimit"
 	"github.com/sirupsen/logrus"
 	"k8s.io/apimachinery/pkg/apis/meta/v1/validation"
 	clientset "k8s.io/client-go/kubernetes"
+	"k8s.io/client-go/rest"
 	"k8s.io/client-go/tools/clientcmd"
 )
 
@@ -28,7 +31,7 @@ func GetKubeConfigPath(file string) string {
 
 // GetClientSet creates a Kubernetes client from the kubeconfig at the provided path.
 func GetClientSet(file string) (clientset.Interface, error) {
-	restConfig, err := clientcmd.BuildConfigFromFlags("", file)
+	restConfig, err := GetRESTConfig(file)
 	if err != nil {
 		return nil, err
 	}
@@ -36,6 +39,18 @@ func GetClientSet(file string) (clientset.Interface, error) {
 	return clientset.NewForConfig(restConfig)
 }
 
+// GetRESTConfig returns a REST config with default timeouts and ratelimits cribbed from wrangler defaults.
+// ref: https://github.com/rancher/wrangler/blob/v3.0.0/pkg/clients/clients.go#L184-L190
+func GetRESTConfig(file string) (*rest.Config, error) {
+	restConfig, err := clientcmd.BuildConfigFromFlags("", file)
+	if err != nil {
+		return nil, err
+	}
+	restConfig.Timeout = 15 * time.Minute
+	restConfig.RateLimiter = ratelimit.None
+	return restConfig, nil
+}
+
 // GetUserAgent builds a complete UserAgent string for a given controller, including the node name if possible.
 func GetUserAgent(controllerName string) string {
 	nodeName := os.Getenv("NODE_NAME")
From a55fe8780d3ee50673c94f839c0ab7f1394a85e6 Mon Sep 17 00:00:00 2001
From: Brad Davidson
Date: Fri, 22 Nov 2024 01:27:18 +0000
Subject: [PATCH 08/14] Return apiserver addresses from both etcd and endpoints

Signed-off-by: Brad Davidson
(cherry picked from commit 168b344d1d40d3d9c86452cd0b1b38fca7ab7196)
Signed-off-by: Brad Davidson
---
 pkg/server/router.go | 95 +++++++++++++++++++++++++++++++++++++-------
 1 file changed, 81 insertions(+), 14 deletions(-)

diff --git a/pkg/server/router.go b/pkg/server/router.go
index b4a2dc57cbb7..fca554027880 100644
--- a/pkg/server/router.go
+++ b/pkg/server/router.go
@@ -19,6 +19,7 @@ import (
 	"github.com/k3s-io/k3s/pkg/bootstrap"
 	"github.com/k3s-io/k3s/pkg/cli/cmds"
 	"github.com/k3s-io/k3s/pkg/daemons/config"
+	"github.com/k3s-io/k3s/pkg/etcd"
 	"github.com/k3s-io/k3s/pkg/nodepassword"
 	"github.com/k3s-io/k3s/pkg/server/auth"
 	"github.com/k3s-io/k3s/pkg/util"
@@ -31,6 +32,7 @@ import (
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/types"
 	"k8s.io/apimachinery/pkg/util/json"
+	"k8s.io/apimachinery/pkg/util/sets"
 	"k8s.io/apimachinery/pkg/util/wait"
 	"k8s.io/apiserver/pkg/authentication/user"
 	"k8s.io/apiserver/pkg/endpoints/request"
@@ -305,22 +307,15 @@ func fileHandler(fileName ...string) http.Handler {
 	})
 }
 
+// apiserversHandler returns a list of apiserver addresses.
+// It attempts to merge results from both the apiserver and directly from etcd,
+// in case we are recovering from an apiserver outage that rendered the endpoint list unavailable.
func apiserversHandler(server *config.Control) http.Handler { - var endpointsClient typedcorev1.EndpointsInterface + collectAddresses := getAddressCollector(server) return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - ctx := req.Context() - var endpoints []string - if endpointsClient == nil { - if server.Runtime.Core != nil { - endpointsClient = server.Runtime.K8s.CoreV1().Endpoints(metav1.NamespaceDefault) - } - } - if endpointsClient != nil { - if endpoint, _ := endpointsClient.Get(ctx, "kubernetes", metav1.GetOptions{}); endpoint != nil { - endpoints = util.GetAddresses(endpoint) - } - } - + ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) + defer cancel() + endpoints := collectAddresses(ctx) resp.Header().Set("content-type", "application/json") if err := json.NewEncoder(resp).Encode(endpoints); err != nil { util.SendError(errors.Wrap(err, "failed to encode apiserver endpoints"), resp, req, http.StatusInternalServerError) @@ -526,3 +521,75 @@ func ensureSecret(ctx context.Context, config *Config, node *nodeInfo) { return false, nil }) } + +// addressGetter is a common signature for functions that return an address channel +type addressGetter func(ctx context.Context) <-chan []string + +// kubernetesGetter returns a function that returns a channel that can be read to get apiserver addresses from kubernetes endpoints +func kubernetesGetter(server *config.Control) addressGetter { + var endpointsClient typedcorev1.EndpointsInterface + return func(ctx context.Context) <-chan []string { + ch := make(chan []string, 1) + go func() { + if endpointsClient == nil { + if server.Runtime.K8s != nil { + endpointsClient = server.Runtime.K8s.CoreV1().Endpoints(metav1.NamespaceDefault) + } + } + if endpointsClient != nil { + if endpoint, err := endpointsClient.Get(ctx, "kubernetes", metav1.GetOptions{}); err != nil { + logrus.Debugf("Failed to get apiserver addresses from kubernetes: %v", err) + } else { + ch <- util.GetAddresses(endpoint) + } + } + close(ch) + }() + return ch + } +} + +// etcdGetter returns a function that returns a channel that can be read to get apiserver addresses from etcd +func etcdGetter(server *config.Control) addressGetter { + return func(ctx context.Context) <-chan []string { + ch := make(chan []string, 1) + go func() { + if addresses, err := etcd.GetAPIServerURLsFromETCD(ctx, server); err != nil { + logrus.Debugf("Failed to get apiserver addresses from etcd: %v", err) + } else { + ch <- addresses + } + close(ch) + }() + return ch + } +} + +// getAddressCollector returns a function that can be called to return +// apiserver addresses from both kubernetes and etcd +func getAddressCollector(server *config.Control) func(ctx context.Context) []string { + getFromKubernetes := kubernetesGetter(server) + getFromEtcd := etcdGetter(server) + + // read from both kubernetes and etcd in parallel, returning the collected results + return func(ctx context.Context) []string { + a := sets.Set[string]{} + r := []string{} + k8sCh := getFromKubernetes(ctx) + k8sOk := true + etcdCh := getFromEtcd(ctx) + etcdOk := true + + for k8sOk || etcdOk { + select { + case r, k8sOk = <-k8sCh: + a.Insert(r...) + case r, etcdOk = <-etcdCh: + a.Insert(r...) 
+ case <-ctx.Done(): + return a.UnsortedList() + } + } + return a.UnsortedList() + } +} From 67385e97b55741afc0a82b7d326818b4fcf4b93e Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 22 Nov 2024 01:47:17 +0000 Subject: [PATCH 09/14] Fall back to polling the supervisor for apiserver addresses when the watch fails Signed-off-by: Brad Davidson (cherry picked from commit c7ff957cae69d618e7a0b80c7e6c3936f29396c3) Signed-off-by: Brad Davidson --- pkg/agent/config/config.go | 25 ++++--- pkg/agent/tunnel/tunnel.go | 149 ++++++++++++++++++++++++------------- 2 files changed, 113 insertions(+), 61 deletions(-) diff --git a/pkg/agent/config/config.go b/pkg/agent/config/config.go index 4883297713ad..26043c8bd8e4 100644 --- a/pkg/agent/config/config.go +++ b/pkg/agent/config/config.go @@ -91,15 +91,24 @@ func KubeProxyDisabled(ctx context.Context, node *config.Node, proxy proxy.Proxy return disabled } -// APIServers returns a list of apiserver endpoints, suitable for seeding client loadbalancer configurations. +// WaitForAPIServers returns a list of apiserver endpoints, suitable for seeding client loadbalancer configurations. // This function will block until it can return a populated list of apiservers, or if the remote server returns // an error (indicating that it does not support this functionality). -func APIServers(ctx context.Context, node *config.Node, proxy proxy.Proxy) []string { +func WaitForAPIServers(ctx context.Context, node *config.Node, proxy proxy.Proxy) []string { var addresses []string + var info *clientaccess.Info var err error _ = wait.PollUntilContextCancel(ctx, 5*time.Second, true, func(ctx context.Context) (bool, error) { - addresses, err = getAPIServers(ctx, node, proxy) + if info == nil { + withCert := clientaccess.WithClientCertificate(node.AgentConfig.ClientKubeletCert, node.AgentConfig.ClientKubeletKey) + info, err = clientaccess.ParseAndValidateToken(proxy.SupervisorURL(), node.Token, withCert) + if err != nil { + logrus.Warnf("Failed to validate server token: %v", err) + return false, nil + } + } + addresses, err = GetAPIServers(ctx, info) if err != nil { logrus.Infof("Failed to retrieve list of apiservers from server: %v", err) return false, err @@ -760,14 +769,8 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N return nodeConfig, nil } -// getAPIServers attempts to return a list of apiservers from the server. -func getAPIServers(ctx context.Context, node *config.Node, proxy proxy.Proxy) ([]string, error) { - withCert := clientaccess.WithClientCertificate(node.AgentConfig.ClientKubeletCert, node.AgentConfig.ClientKubeletKey) - info, err := clientaccess.ParseAndValidateToken(proxy.SupervisorURL(), node.Token, withCert) - if err != nil { - return nil, err - } - +// GetAPIServers attempts to return a list of apiservers from the server. 
+func GetAPIServers(ctx context.Context, info *clientaccess.Info) ([]string, error) { data, err := info.Get("/v1-" + version.Program + "/apiservers") if err != nil { return nil, err diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index e1eb566cc675..2fe031ce627f 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "os" - "reflect" "strconv" "sync" "time" @@ -16,6 +15,7 @@ import ( agentconfig "github.com/k3s-io/k3s/pkg/agent/config" "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/agent/proxy" + "github.com/k3s-io/k3s/pkg/clientaccess" daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/version" @@ -27,6 +27,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes" @@ -138,17 +139,18 @@ func Setup(ctx context.Context, config *daemonconfig.Node, proxy proxy.Proxy) er // connecting to. If that fails, fall back to querying the endpoints list from Kubernetes. This // fallback requires that the server we're joining be running an apiserver, but is the only safe // thing to do if its supervisor is down-level and can't provide us with an endpoint list. - addresses := agentconfig.APIServers(ctx, config, proxy) - logrus.Infof("Got apiserver addresses from supervisor: %v", addresses) - + addresses := agentconfig.WaitForAPIServers(ctx, config, proxy) if len(addresses) > 0 { + logrus.Infof("Got apiserver addresses from supervisor: %v", addresses) if localSupervisorDefault { proxy.SetSupervisorDefault(addresses[0]) } proxy.Update(addresses) } else { - if endpoint, _ := client.CoreV1().Endpoints(metav1.NamespaceDefault).Get(ctx, "kubernetes", metav1.GetOptions{}); endpoint != nil { - addresses = util.GetAddresses(endpoint) + if endpoint, err := client.CoreV1().Endpoints(metav1.NamespaceDefault).Get(ctx, "kubernetes", metav1.GetOptions{}); err != nil { + logrus.Errorf("Failed to get apiserver addresses from kubernetes endpoints: %v", err) + } else { + addresses := util.GetAddresses(endpoint) logrus.Infof("Got apiserver addresses from kubernetes endpoints: %v", addresses) if len(addresses) > 0 { proxy.Update(addresses) @@ -159,7 +161,7 @@ func Setup(ctx context.Context, config *daemonconfig.Node, proxy proxy.Proxy) er wg := &sync.WaitGroup{} - go tunnel.watchEndpoints(ctx, apiServerReady, wg, tlsConfig, proxy) + go tunnel.watchEndpoints(ctx, apiServerReady, wg, tlsConfig, config, proxy) wait := make(chan int, 1) go func() { @@ -302,23 +304,21 @@ func (a *agentTunnel) watchPods(ctx context.Context, apiServerReady <-chan struc // WatchEndpoints attempts to create tunnels to all supervisor addresses. Once the // apiserver is up, go into a watch loop, adding and removing tunnels as endpoints come // and go from the cluster. -func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan struct{}, wg *sync.WaitGroup, tlsConfig *tls.Config, proxy proxy.Proxy) { - // Attempt to connect to supervisors, storing their cancellation function for later when we - // need to disconnect. 
- disconnect := map[string]context.CancelFunc{} - for _, address := range proxy.SupervisorAddresses() { - if _, ok := disconnect[address]; !ok { - conn := a.connect(ctx, wg, address, tlsConfig) - disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.healthCheck) - } - } +func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan struct{}, wg *sync.WaitGroup, tlsConfig *tls.Config, node *daemonconfig.Node, proxy proxy.Proxy) { + syncProxyAddresses := a.getProxySyncer(ctx, wg, tlsConfig, proxy) + refreshFromSupervisor := getAPIServersRequester(node, proxy, syncProxyAddresses) <-apiServerReady + endpoints := a.client.CoreV1().Endpoints(metav1.NamespaceDefault) fieldSelector := fields.Set{metav1.ObjectNameField: "kubernetes"}.String() lw := &cache.ListWatch{ ListFunc: func(options metav1.ListOptions) (object runtime.Object, e error) { + // if we're being called to re-list, then likely there was an + // interruption to the apiserver connection and the listwatch is retrying + // its connection. This is a good suggestion that it might be necessary + // to refresh the apiserver address from the supervisor. + go refreshFromSupervisor(ctx) options.FieldSelector = fieldSelector return endpoints.List(ctx, options) }, @@ -364,38 +364,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan // goroutine that sleeps for a short period before checking for changes and updating // the proxy addresses. If another update occurs, the previous update operation // will be cancelled and a new one queued. - go func() { - select { - case <-time.After(endpointDebounceDelay): - case <-debounceCtx.Done(): - return - } - - newAddresses := util.GetAddresses(endpoint) - if reflect.DeepEqual(newAddresses, proxy.SupervisorAddresses()) { - return - } - proxy.Update(newAddresses) - - validEndpoint := map[string]bool{} - - for _, address := range proxy.SupervisorAddresses() { - validEndpoint[address] = true - if _, ok := disconnect[address]; !ok { - conn := a.connect(ctx, nil, address, tlsConfig) - disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.healthCheck) - } - } - - for address, cancel := range disconnect { - if !validEndpoint[address] { - cancel() - delete(disconnect, address) - logrus.Infof("Stopped tunnel to %s", address) - } - } - }() + go syncProxyAddresses(debounceCtx, util.GetAddresses(endpoint)) } } } @@ -507,3 +476,83 @@ func (a *agentTunnel) dialContext(ctx context.Context, network, address string) } return defaultDialer.DialContext(ctx, network, address) } + +// proxySyncer is a common signature for functions that sync the proxy address list with a context +type proxySyncer func(ctx context.Context, addresses []string) + +// getProxySyncer returns a function that can be called to update the list of supervisors. +// This function is responsible for connecting to or disconnecting websocket tunnels, +// as well as updating the proxy loadbalancer server list. +func (a *agentTunnel) getProxySyncer(ctx context.Context, wg *sync.WaitGroup, tlsConfig *tls.Config, proxy proxy.Proxy) proxySyncer { + disconnect := map[string]context.CancelFunc{} + // Attempt to connect to supervisors, storing their cancellation function for later when we + // need to disconnect. 
+ for _, address := range proxy.SupervisorAddresses() { + if _, ok := disconnect[address]; !ok { + conn := a.connect(ctx, wg, address, tlsConfig) + disconnect[address] = conn.cancel + proxy.SetHealthCheck(address, conn.healthCheck) + } + } + + // return a function that can be called to update the address list. + // servers will be connected to or disconnected from as necessary, + // and the proxy addresses updated. + return func(debounceCtx context.Context, addresses []string) { + select { + case <-time.After(endpointDebounceDelay): + case <-debounceCtx.Done(): + return + } + + newAddresses := sets.New(addresses...) + curAddresses := sets.New(proxy.SupervisorAddresses()...) + if newAddresses.Equal(curAddresses) { + return + } + + proxy.Update(addresses) + + // add new servers + for address := range newAddresses.Difference(curAddresses) { + if _, ok := disconnect[address]; !ok { + conn := a.connect(ctx, nil, address, tlsConfig) + logrus.Infof("Started tunnel to %s", address) + disconnect[address] = conn.cancel + proxy.SetHealthCheck(address, conn.healthCheck) + } + } + + // remove old servers + for address := range curAddresses.Difference(newAddresses) { + if cancel, ok := disconnect[address]; ok { + cancel() + delete(disconnect, address) + logrus.Infof("Stopped tunnel to %s", address) + } + } + } +} + +// getAPIServersRequester returns a function that can be called to update the +// proxy apiserver endpoints with addresses retrieved from the supervisor. +func getAPIServersRequester(node *daemonconfig.Node, proxy proxy.Proxy, syncProxyAddresses proxySyncer) func(ctx context.Context) { + var info *clientaccess.Info + return func(ctx context.Context) { + if info == nil { + var err error + withCert := clientaccess.WithClientCertificate(node.AgentConfig.ClientKubeletCert, node.AgentConfig.ClientKubeletKey) + info, err = clientaccess.ParseAndValidateToken(proxy.SupervisorURL(), node.Token, withCert) + if err != nil { + logrus.Warnf("Failed to validate server token: %v", err) + return + } + } + + if addresses, err := agentconfig.GetAPIServers(ctx, info); err != nil { + logrus.Warnf("Failed to get apiserver addresses from supervisor: %v", err) + } else { + syncProxyAddresses(ctx, addresses) + } + } +} From 647e99fc38a9e71c395289b0020f6ec148337777 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 22 Nov 2024 17:23:41 +0000 Subject: [PATCH 10/14] Add command output to test failure message Signed-off-by: Brad Davidson (cherry picked from commit 81dda9d626e55e7912833670c3cc88caed53d898) Signed-off-by: Brad Davidson --- tests/e2e/validatecluster/validatecluster_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/validatecluster/validatecluster_test.go b/tests/e2e/validatecluster/validatecluster_test.go index accae34dadf8..efc2003fd8ec 100644 --- a/tests/e2e/validatecluster/validatecluster_test.go +++ b/tests/e2e/validatecluster/validatecluster_test.go @@ -238,8 +238,8 @@ var _ = Describe("Verify Create", Ordered, func() { }, "420s", "2s").Should(Succeed()) cmd := "kubectl --kubeconfig=" + kubeConfigFile + " exec volume-test -- sh -c 'echo local-path-test > /data/test'" - _, err = e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred()) + res, err = e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) cmd = "kubectl delete pod volume-test --kubeconfig=" + kubeConfigFile res, err = e2e.RunCommand(cmd) From d6682c2108fede34f2cb27c396237e6410fbb5ac Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 22 Nov 2024 18:11:41 
+0000 Subject: [PATCH 11/14] Fix integration test failure message The error message should be printf style, not just concatenated. The current message is garbled if the command or result contains things that look like formatting directives: `Internal error occurred: error sending request: Post "https://10.10.10.102:10250/exec/default/volume-test/volume-test?command=sh&command=-c&command=echo+local-path-test+%!!(MISSING)E(MISSING)+%!!(MISSING)F(MISSING)data%!!(MISSING)F(MISSING)test&error=1&output=1": proxy error from 127.0.0.1:6443 while dialing 10.10.10.102:10250, code 502: 502 Bad Gateway` Signed-off-by: Brad Davidson (cherry picked from commit 45195e2654853becc27b312978e5b79f6cfce8ac) Signed-off-by: Brad Davidson --- .../snapshotrestore/snapshotrestore_test.go | 2 +- .../validatecluster/validatecluster_test.go | 26 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/tests/e2e/snapshotrestore/snapshotrestore_test.go b/tests/e2e/snapshotrestore/snapshotrestore_test.go index dc47907f78c7..20deef629e74 100644 --- a/tests/e2e/snapshotrestore/snapshotrestore_test.go +++ b/tests/e2e/snapshotrestore/snapshotrestore_test.go @@ -95,7 +95,7 @@ var _ = Describe("Verify snapshots and cluster restores work", Ordered, func() { cmd := "kubectl get pods -o=name -l k8s-app=nginx-app-clusterip --field-selector=status.phase=Running --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) g.Expect(err).NotTo(HaveOccurred()) - g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: "+cmd+" result: "+res) + g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: %q result: %s", cmd, res) }, "240s", "5s").Should(Succeed()) }) diff --git a/tests/e2e/validatecluster/validatecluster_test.go b/tests/e2e/validatecluster/validatecluster_test.go index efc2003fd8ec..8db46e673cfa 100644 --- a/tests/e2e/validatecluster/validatecluster_test.go +++ b/tests/e2e/validatecluster/validatecluster_test.go @@ -95,7 +95,7 @@ var _ = Describe("Verify Create", Ordered, func() { cmd := "kubectl get pods -o=name -l k8s-app=nginx-app-clusterip --field-selector=status.phase=Running --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) Expect(err).NotTo(HaveOccurred()) - g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: "+cmd+" result: "+res) + g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: %q result: %s", cmd, res) }, "240s", "5s").Should(Succeed()) clusterip, _ := e2e.FetchClusterIP(kubeConfigFile, "nginx-clusterip-svc", false) @@ -130,7 +130,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-nodeport")) }, "240s", "5s").Should(Succeed()) } @@ -150,14 +150,14 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pods -o=name -l k8s-app=nginx-app-loadbalancer --field-selector=status.phase=Running --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-loadbalancer")) }, "240s", "5s").Should(Succeed()) Eventually(func(g Gomega) { cmd = "curl -L --insecure http://" + ip + ":" + port + "/name.html" res, err := 
e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-loadbalancer")) }, "240s", "5s").Should(Succeed()) } @@ -174,7 +174,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-ingress")) }, "240s", "5s").Should(Succeed()) } @@ -204,7 +204,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pods dnsutils --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("dnsutils")) }, "420s", "2s").Should(Succeed()) @@ -212,7 +212,7 @@ var _ = Describe("Verify Create", Ordered, func() { cmd := "kubectl --kubeconfig=" + kubeConfigFile + " exec -i -t dnsutils -- nslookup kubernetes.default" res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("kubernetes.default.svc.cluster.local")) }, "420s", "2s").Should(Succeed()) }) @@ -224,7 +224,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pvc local-path-pvc --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("local-path-pvc")) g.Expect(res).Should(ContainSubstring("Bound")) }, "420s", "2s").Should(Succeed()) @@ -232,18 +232,18 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pod volume-test --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("volume-test")) g.Expect(res).Should(ContainSubstring("Running")) }, "420s", "2s").Should(Succeed()) cmd := "kubectl --kubeconfig=" + kubeConfigFile + " exec volume-test -- sh -c 'echo local-path-test > /data/test'" res, err = e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) cmd = "kubectl delete pod volume-test --kubeconfig=" + kubeConfigFile res, err = e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) _, err = e2e.DeployWorkload("local-path-provisioner.yaml", kubeConfigFile, *hardened) Expect(err).NotTo(HaveOccurred(), "local-path-provisioner manifest not deployed") @@ -257,7 +257,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pod volume-test --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + 
g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("volume-test")) g.Expect(res).Should(ContainSubstring("Running")) @@ -266,7 +266,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl exec volume-test --kubeconfig=" + kubeConfigFile + " -- cat /data/test" res, err = e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) fmt.Println("Data after re-creation", res) g.Expect(res).Should(ContainSubstring("local-path-test")) }, "180s", "2s").Should(Succeed()) From 3f9d640700f0d2d65ee921cdb5a232b64dee6390 Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Fri, 22 Nov 2024 18:37:17 +0000 Subject: [PATCH 12/14] Tail journald logs into report on suite failure Signed-off-by: Brad Davidson (cherry picked from commit e9cf3a7ab5b67ef4484b6ea6d1e306b0e9949f13) Signed-off-by: Brad Davidson --- tests/e2e/dualstack/dualstack_test.go | 4 +++- tests/e2e/embeddedmirror/embeddedmirror_test.go | 4 +++- tests/e2e/externalip/externalip_test.go | 4 +++- tests/e2e/privateregistry/privateregistry_test.go | 5 +++-- tests/e2e/rootless/rootless_test.go | 4 +++- tests/e2e/rotateca/rotateca_test.go | 4 +++- tests/e2e/s3/s3_test.go | 5 +++-- .../secretsencryption/secretsencryption_test.go | 4 +++- tests/e2e/snapshotrestore/snapshotrestore_test.go | 4 +++- tests/e2e/splitserver/splitserver_test.go | 8 +++++--- tests/e2e/startup/startup_test.go | 4 +++- .../svcpoliciesandfirewall_test.go | 9 +++++---- tests/e2e/tailscale/tailscale_test.go | 4 +++- tests/e2e/testutils.go | 15 ++++++++++++++- tests/e2e/token/token_test.go | 4 +++- tests/e2e/upgradecluster/upgradecluster_test.go | 15 ++++++++------- tests/e2e/validatecluster/validatecluster_test.go | 4 +++- tests/e2e/wasm/wasm_test.go | 6 ++++-- 18 files changed, 75 insertions(+), 32 deletions(-) diff --git a/tests/e2e/dualstack/dualstack_test.go b/tests/e2e/dualstack/dualstack_test.go index c9612f9b7142..9262af922cea 100644 --- a/tests/e2e/dualstack/dualstack_test.go +++ b/tests/e2e/dualstack/dualstack_test.go @@ -195,7 +195,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/embeddedmirror/embeddedmirror_test.go b/tests/e2e/embeddedmirror/embeddedmirror_test.go index 7188b552b988..089fb465277b 100644 --- a/tests/e2e/embeddedmirror/embeddedmirror_test.go +++ b/tests/e2e/embeddedmirror/embeddedmirror_test.go @@ -146,7 +146,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/externalip/externalip_test.go b/tests/e2e/externalip/externalip_test.go index 524bb8340276..9d2150991924 100644 --- a/tests/e2e/externalip/externalip_test.go +++ b/tests/e2e/externalip/externalip_test.go @@ -165,7 +165,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } 
else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/privateregistry/privateregistry_test.go b/tests/e2e/privateregistry/privateregistry_test.go index 856f49b596c6..fe25a94e2181 100644 --- a/tests/e2e/privateregistry/privateregistry_test.go +++ b/tests/e2e/privateregistry/privateregistry_test.go @@ -149,8 +149,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/rootless/rootless_test.go b/tests/e2e/rootless/rootless_test.go index 361778c72db7..4a205934e3d5 100644 --- a/tests/e2e/rootless/rootless_test.go +++ b/tests/e2e/rootless/rootless_test.go @@ -167,7 +167,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, serverNodeNames)) + } else { Expect(e2e.GetCoverageReport(serverNodeNames)).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/rotateca/rotateca_test.go b/tests/e2e/rotateca/rotateca_test.go index c43ab4d10899..3a6f2b0ca14f 100644 --- a/tests/e2e/rotateca/rotateca_test.go +++ b/tests/e2e/rotateca/rotateca_test.go @@ -138,7 +138,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/s3/s3_test.go b/tests/e2e/s3/s3_test.go index fc3be6a5fde4..b61824525934 100644 --- a/tests/e2e/s3/s3_test.go +++ b/tests/e2e/s3/s3_test.go @@ -175,8 +175,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/secretsencryption/secretsencryption_test.go b/tests/e2e/secretsencryption/secretsencryption_test.go index 187dcedba2fc..763e2f0ba381 100644 --- a/tests/e2e/secretsencryption/secretsencryption_test.go +++ b/tests/e2e/secretsencryption/secretsencryption_test.go @@ -221,7 +221,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, serverNodeNames)) + } else { Expect(e2e.GetCoverageReport(serverNodeNames)).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/snapshotrestore/snapshotrestore_test.go b/tests/e2e/snapshotrestore/snapshotrestore_test.go index 20deef629e74..f9ca105cb24b 100644 --- a/tests/e2e/snapshotrestore/snapshotrestore_test.go +++ b/tests/e2e/snapshotrestore/snapshotrestore_test.go @@ -317,7 +317,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/splitserver/splitserver_test.go b/tests/e2e/splitserver/splitserver_test.go index c78520d67b41..642dbc1592e3 100644 
--- a/tests/e2e/splitserver/splitserver_test.go +++ b/tests/e2e/splitserver/splitserver_test.go @@ -283,9 +283,11 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { - allNodes := append(cpNodeNames, etcdNodeNames...) - allNodes = append(allNodes, agentNodeNames...) + allNodes := append(cpNodeNames, etcdNodeNames...) + allNodes = append(allNodes, agentNodeNames...) + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, allNodes)) + } else { Expect(e2e.GetCoverageReport(allNodes)).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/startup/startup_test.go b/tests/e2e/startup/startup_test.go index c926164fac14..fd1d71872490 100644 --- a/tests/e2e/startup/startup_test.go +++ b/tests/e2e/startup/startup_test.go @@ -310,7 +310,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go b/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go index 53128947a234..dec419e176c4 100644 --- a/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go +++ b/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go @@ -128,7 +128,7 @@ var _ = Describe("Verify Services Traffic policies and firewall config", Ordered Eventually(func(g Gomega) { externalIPs, _ := e2e.FetchExternalIPs(kubeConfigFile, lbSvcExt) g.Expect(externalIPs).To(HaveLen(1), "more than 1 exernalIP found") - g.Expect(externalIPs[0]).To(Equal(serverNodeIP),"external IP does not match servernodeIP") + g.Expect(externalIPs[0]).To(Equal(serverNodeIP), "external IP does not match servernodeIP") }, "25s", "5s").Should(Succeed()) }) @@ -154,7 +154,6 @@ var _ = Describe("Verify Services Traffic policies and firewall config", Ordered return e2e.RunCommand(cmd) }, "25s", "5s").ShouldNot(ContainSubstring("10.42")) - // Verify connectivity to the other nodeIP does not work because of external traffic policy=local for _, externalIP := range lbSvcExternalIPs { if externalIP == lbSvcExtExternalIPs[0] { @@ -250,7 +249,7 @@ var _ = Describe("Verify Services Traffic policies and firewall config", Ordered )) // Check the non working command fails because of internal traffic policy=local - Eventually(func() (bool) { + Eventually(func() bool { _, err := e2e.RunCommand(nonWorkingCmd) if err != nil && strings.Contains(err.Error(), "exit status") { // Treat exit status as a successful condition @@ -348,7 +347,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/tailscale/tailscale_test.go b/tests/e2e/tailscale/tailscale_test.go index 3def1ac41ab5..449840e4f990 100644 --- a/tests/e2e/tailscale/tailscale_test.go +++ b/tests/e2e/tailscale/tailscale_test.go @@ -118,7 +118,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { 
diff --git a/tests/e2e/testutils.go b/tests/e2e/testutils.go index 950a3c8af896..29edcc7f951d 100644 --- a/tests/e2e/testutils.go +++ b/tests/e2e/testutils.go @@ -48,7 +48,7 @@ type NodeError struct { type SvcExternalIP struct { IP string `json:"ip"` - ipMode string `json:"ipMode"` + IPMode string `json:"ipMode"` } type ObjIP struct { @@ -364,6 +364,19 @@ func GetJournalLogs(node string) (string, error) { return RunCmdOnNode(cmd, node) } +func TailJournalLogs(lines int, nodes []string) string { + logs := &strings.Builder{} + for _, node := range nodes { + cmd := fmt.Sprintf("journalctl -u k3s* --no-pager --lines=%d", lines) + if l, err := RunCmdOnNode(cmd, node); err != nil { + fmt.Fprintf(logs, "** failed to read journald log for node %s ***\n%v\n", node, err) + } else { + fmt.Fprintf(logs, "** journald log for node %s ***\n%s\n", node, l) + } + } + return logs.String() +} + // GetVagrantLog returns the logs of on vagrant commands that initialize the nodes and provision K3s on each node. // It also attempts to fetch the systemctl logs of K3s on nodes where the k3s.service failed. func GetVagrantLog(cErr error) string { diff --git a/tests/e2e/token/token_test.go b/tests/e2e/token/token_test.go index bd0cc38a1fc8..3b3c011d6ae7 100644 --- a/tests/e2e/token/token_test.go +++ b/tests/e2e/token/token_test.go @@ -202,7 +202,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/upgradecluster/upgradecluster_test.go b/tests/e2e/upgradecluster/upgradecluster_test.go index 18bd1cbee7b1..fab93a6bbd89 100644 --- a/tests/e2e/upgradecluster/upgradecluster_test.go +++ b/tests/e2e/upgradecluster/upgradecluster_test.go @@ -215,14 +215,13 @@ var _ = Describe("Verify Upgrade", Ordered, func() { }, "420s", "2s").Should(Succeed()) cmd := "kubectl --kubeconfig=" + kubeConfigFile + " exec volume-test -- sh -c 'echo local-path-test > /data/test'" - _, err = e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred()) + res, err := e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) fmt.Println("Data stored in pvc: local-path-test") cmd = "kubectl delete pod volume-test --kubeconfig=" + kubeConfigFile - res, err := e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred()) - fmt.Println(res) + res, err = e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) _, err = e2e.DeployWorkload("local-path-provisioner.yaml", kubeConfigFile, *hardened) Expect(err).NotTo(HaveOccurred(), "local-path-provisioner manifest not deployed") @@ -245,7 +244,7 @@ var _ = Describe("Verify Upgrade", Ordered, func() { Eventually(func() (string, error) { cmd := "kubectl exec volume-test --kubeconfig=" + kubeConfigFile + " -- cat /data/test" return e2e.RunCommand(cmd) - }, "180s", "2s").Should(ContainSubstring("local-path-test")) + }, "180s", "2s").Should(ContainSubstring("local-path-test"), "Failed to retrieve data from pvc") }) It("Upgrades with no issues", func() { @@ -385,7 +384,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if 
!failed || *ci { diff --git a/tests/e2e/validatecluster/validatecluster_test.go b/tests/e2e/validatecluster/validatecluster_test.go index 8db46e673cfa..2c4807cce98d 100644 --- a/tests/e2e/validatecluster/validatecluster_test.go +++ b/tests/e2e/validatecluster/validatecluster_test.go @@ -381,7 +381,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/wasm/wasm_test.go b/tests/e2e/wasm/wasm_test.go index 1e887a086a29..7fa216088b35 100644 --- a/tests/e2e/wasm/wasm_test.go +++ b/tests/e2e/wasm/wasm_test.go @@ -135,10 +135,12 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if failed && !*ci { - fmt.Println("FAILED!") + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) + } + if !failed || *ci { Expect(e2e.DestroyCluster()).To(Succeed()) Expect(os.Remove(kubeConfigFile)).To(Succeed()) } From 9bfde88ccdda5db950f2762887b099d39bf9d0dc Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Sat, 7 Dec 2024 00:30:25 +0000 Subject: [PATCH 13/14] Fix agent tunnel address on rke2 Fix issue where rke2 tunnel was trying to connect to apiserver port instead of supervisor Signed-off-by: Brad Davidson (cherry picked from commit 5a5b1361519805f0b7a653be82b3a140dc691a9d) Signed-off-by: Brad Davidson --- pkg/agent/tunnel/tunnel.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index 2fe031ce627f..98094a8d02dc 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -505,13 +505,14 @@ func (a *agentTunnel) getProxySyncer(ctx context.Context, wg *sync.WaitGroup, tl return } - newAddresses := sets.New(addresses...) + // Compare list of supervisor addresses before and after syncing apiserver + // endpoints into the proxy to figure out which supervisors we need to connect to + // or disconnect from. Note that the addresses we were passed will not match + // the supervisor addresses if the supervisor and apiserver are on different ports - + // they must be round-tripped through proxy.Update before comparing. curAddresses := sets.New(proxy.SupervisorAddresses()...) - if newAddresses.Equal(curAddresses) { - return - } - proxy.Update(addresses) + newAddresses := sets.New(proxy.SupervisorAddresses()...) 
// add new servers for address := range newAddresses.Difference(curAddresses) { From d1aa856519370527eb7483ac5d37a398a11cdf2a Mon Sep 17 00:00:00 2001 From: Brad Davidson Date: Mon, 9 Dec 2024 18:06:28 +0000 Subject: [PATCH 14/14] Add hidden flag/var for supervisor/apiserver listen config Add flags supervisor and apiserver ports and bind address so that we can add an e2e to cover supervisor and apiserver on separate ports, as used by rke2 Signed-off-by: Brad Davidson (cherry picked from commit e143e0fa12033fe2331558ea8e4ca86e813fdbb5) Signed-off-by: Brad Davidson --- pkg/cli/cmds/server.go | 23 ++++++++- tests/e2e/startup/startup_test.go | 84 +++++++++++++++++++++++++++++++ tests/e2e/testutils.go | 13 +++++ 3 files changed, 119 insertions(+), 1 deletion(-) diff --git a/pkg/cli/cmds/server.go b/pkg/cli/cmds/server.go index c398fc14c1bb..e2eee6ae022d 100644 --- a/pkg/cli/cmds/server.go +++ b/pkg/cli/cmds/server.go @@ -188,6 +188,27 @@ var ServerFlags = []cli.Flag{ Value: 6443, Destination: &ServerConfig.HTTPSPort, }, + &cli.IntFlag{ + Name: "supervisor-port", + EnvVar: version.ProgramUpper + "_SUPERVISOR_PORT", + Usage: "(experimental) Supervisor listen port override", + Hidden: true, + Destination: &ServerConfig.SupervisorPort, + }, + &cli.IntFlag{ + Name: "apiserver-port", + EnvVar: version.ProgramUpper + "_APISERVER_PORT", + Usage: "(experimental) apiserver internal listen port override", + Hidden: true, + Destination: &ServerConfig.APIServerPort, + }, + &cli.StringFlag{ + Name: "apiserver-bind-address", + EnvVar: version.ProgramUpper + "_APISERVER_BIND_ADDRESS", + Usage: "(experimental) apiserver internal bind address override", + Hidden: true, + Destination: &ServerConfig.APIServerBindAddress, + }, &cli.StringFlag{ Name: "advertise-address", Usage: "(listener) IPv4/IPv6 address that apiserver uses to advertise to members of the cluster (default: node-external-ip/node-ip)", @@ -195,7 +216,7 @@ var ServerFlags = []cli.Flag{ }, &cli.IntFlag{ Name: "advertise-port", - Usage: "(listener) Port that apiserver uses to advertise to members of the cluster (default: listen-port)", + Usage: "(listener) Port that apiserver uses to advertise to members of the cluster (default: https-listen-port)", Destination: &ServerConfig.AdvertisePort, }, &cli.StringSliceFlag{ diff --git a/tests/e2e/startup/startup_test.go b/tests/e2e/startup/startup_test.go index fd1d71872490..3c7dd13627ea 100644 --- a/tests/e2e/startup/startup_test.go +++ b/tests/e2e/startup/startup_test.go @@ -71,6 +71,12 @@ func KillK3sCluster(nodes []string) error { if _, err := e2e.RunCmdOnNode("k3s-killall.sh", node); err != nil { return err } + if _, err := e2e.RunCmdOnNode("journalctl --flush --sync --rotate --vacuum-size=1", node); err != nil { + return err + } + if _, err := e2e.RunCmdOnNode("rm -rf /etc/rancher/k3s/config.yaml.d", node); err != nil { + return err + } if strings.Contains(node, "server") { if _, err := e2e.RunCmdOnNode("rm -rf /var/lib/rancher/k3s/server/db", node); err != nil { return err @@ -93,6 +99,83 @@ var _ = BeforeSuite(func() { }) var _ = Describe("Various Startup Configurations", Ordered, func() { + Context("Verify dedicated supervisor port", func() { + It("Starts K3s with no issues", func() { + for _, node := range agentNodeNames { + cmd := "mkdir -p /etc/rancher/k3s/config.yaml.d; grep -F server: /etc/rancher/k3s/config.yaml | sed s/6443/9345/ > /tmp/99-server.yaml; sudo mv /tmp/99-server.yaml /etc/rancher/k3s/config.yaml.d/" + res, err := e2e.RunCmdOnNode(cmd, node) + By("checking command results: " + res) + 
Expect(err).NotTo(HaveOccurred()) + } + supervisorPortYAML := "supervisor-port: 9345\napiserver-port: 6443\napiserver-bind-address: 0.0.0.0\ndisable: traefik\nnode-taint: node-role.kubernetes.io/control-plane:NoExecute" + err := StartK3sCluster(append(serverNodeNames, agentNodeNames...), supervisorPortYAML, "") + Expect(err).NotTo(HaveOccurred(), e2e.GetVagrantLog(err)) + + fmt.Println("CLUSTER CONFIG") + fmt.Println("OS:", *nodeOS) + fmt.Println("Server Nodes:", serverNodeNames) + fmt.Println("Agent Nodes:", agentNodeNames) + kubeConfigFile, err = e2e.GenKubeConfigFile(serverNodeNames[0]) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Checks node and pod status", func() { + fmt.Printf("\nFetching node status\n") + Eventually(func(g Gomega) { + nodes, err := e2e.ParseNodes(kubeConfigFile, false) + g.Expect(err).NotTo(HaveOccurred()) + for _, node := range nodes { + g.Expect(node.Status).Should(Equal("Ready")) + } + }, "360s", "5s").Should(Succeed()) + _, _ = e2e.ParseNodes(kubeConfigFile, true) + + fmt.Printf("\nFetching pods status\n") + Eventually(func(g Gomega) { + pods, err := e2e.ParsePods(kubeConfigFile, false) + g.Expect(err).NotTo(HaveOccurred()) + for _, pod := range pods { + if strings.Contains(pod.Name, "helm-install") { + g.Expect(pod.Status).Should(Equal("Completed"), pod.Name) + } else { + g.Expect(pod.Status).Should(Equal("Running"), pod.Name) + } + } + }, "360s", "5s").Should(Succeed()) + _, _ = e2e.ParsePods(kubeConfigFile, true) + }) + + It("Returns pod metrics", func() { + cmd := "kubectl top pod -A" + Eventually(func() error { + _, err := e2e.RunCommand(cmd) + return err + }, "600s", "5s").Should(Succeed()) + }) + + It("Returns node metrics", func() { + cmd := "kubectl top node" + _, err := e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Runs an interactive command a pod", func() { + cmd := "kubectl run busybox --rm -it --restart=Never --image=rancher/mirrored-library-busybox:1.36.1 -- uname -a" + _, err := e2e.RunCmdOnNode(cmd, serverNodeNames[0]) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Collects logs from a pod", func() { + cmd := "kubectl logs -n kube-system -l k8s-app=metrics-server -c metrics-server" + _, err := e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Kills the cluster", func() { + err := KillK3sCluster(append(serverNodeNames, agentNodeNames...)) + Expect(err).NotTo(HaveOccurred()) + }) + }) Context("Verify CRI-Dockerd :", func() { It("Starts K3s with no issues", func() { dockerYAML := "docker: true" @@ -311,6 +394,7 @@ var _ = AfterEach(func() { var _ = AfterSuite(func() { if failed { + AddReportEntry("config", e2e.GetConfig(append(serverNodeNames, agentNodeNames...))) AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) diff --git a/tests/e2e/testutils.go b/tests/e2e/testutils.go index 29edcc7f951d..2d2cb12071b0 100644 --- a/tests/e2e/testutils.go +++ b/tests/e2e/testutils.go @@ -377,6 +377,19 @@ func TailJournalLogs(lines int, nodes []string) string { return logs.String() } +func GetConfig(nodes []string) string { + config := &strings.Builder{} + for _, node := range nodes { + cmd := "tar -Pc /etc/rancher/k3s/ | tar -vxPO" + if c, err := RunCmdOnNode(cmd, node); err != nil { + fmt.Fprintf(config, "** failed to get config for node %s ***\n%v\n", node, err) + } else { + fmt.Fprintf(config, "** config for node %s ***\n%s\n", node, c) + } + } + return 
config.String() +} + // GetVagrantLog returns the logs of on vagrant commands that initialize the nodes and provision K3s on each node. // It also attempts to fetch the systemctl logs of K3s on nodes where the k3s.service failed. func GetVagrantLog(cErr error) string {
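
Editor's illustration of the failure-message change in PATCH 11 above: the switch from string concatenation to printf-style arguments matters because Gomega's optional failure description appears to be run through fmt.Sprintf when it contains format verbs, which is exactly what the mangled "%!!(MISSING)E(MISSING)" output quoted in that commit message shows. The standalone sketch below reproduces the mechanism with fmt directly; the cmd and res values are made up for illustration and are not taken from the test suite.

    package main

    import "fmt"

    func main() {
        cmd := "kubectl exec volume-test -- sh -c 'echo local-path-test > /data/test'"
        // kubectl exec failures quote the request URL, which contains percent-escapes
        // such as %3E and %2F that look like fmt verbs.
        res := `Post "https://10.10.10.102:10250/exec/default/volume-test/volume-test?command=echo+local-path-test+%3E+%2Fdata%2Ftest": 502 Bad Gateway`

        // Concatenation builds one big string that is later treated as a format string,
        // so the %3E / %2F sequences are mangled into "%!E(MISSING)"-style noise.
        fmt.Println(fmt.Sprintf("failed cmd: " + cmd + " result: " + res))

        // Printf style keeps cmd and res out of the format string; they are substituted
        // as %q / %s arguments, so any '%' they contain is printed verbatim.
        fmt.Println(fmt.Sprintf("failed cmd: %q result: %s", cmd, res))
    }

Running this prints a garbled first line and a clean second line, which is why the assertions in PATCH 11 pass the command and result as arguments rather than concatenating them into the description.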
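
For reference, the dedicated-supervisor-port scenario exercised by the new e2e case in PATCH 14 needs only a small amount of configuration. The sketch below is illustrative only: the 9345 port, file paths, and keys are taken from the test above, the flags are hidden and marked experimental, and <server-ip> is a placeholder (the test derives the agent's server URL by rewriting the existing one from 6443 to 9345).

    # server: /etc/rancher/k3s/config.yaml
    supervisor-port: 9345
    apiserver-port: 6443
    apiserver-bind-address: 0.0.0.0

    # agent: /etc/rancher/k3s/config.yaml.d/99-server.yaml
    server: https://<server-ip>:9345

With this layout agents register through the supervisor listener on 9345 while the apiserver continues to serve on 6443, mirroring the split-port arrangement used by rke2 that the tunnel fix in PATCH 13 accounts for.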