Skip to content

Commit c54d95d

Browse files
committed
feat: implement control plane loadbalancer
A simple wrapper around generic TCP loadbalancer which implements control plane loadbalancer. Signed-off-by: Andrey Smirnov <andrey.smirnov@talos-systems.com>
1 parent 4a6e29e commit c54d95d

File tree

2 files changed

+254
-0
lines changed

2 files changed

+254
-0
lines changed

controlplane/controlplane.go

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// This Source Code Form is subject to the terms of the Mozilla Public
2+
// License, v. 2.0. If a copy of the MPL was not distributed with this
3+
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4+
5+
// Package controlplane wraps generic TCP loadbalancer for Kubernetes controlplane endpoint LB.
6+
package controlplane
7+
8+
import (
9+
"fmt"
10+
"io"
11+
"log"
12+
"net"
13+
"strconv"
14+
"time"
15+
16+
"github.com/talos-systems/go-loadbalancer/loadbalancer"
17+
)
18+
19+
// LoadBalancer provides Kubernetes control plane TCP loadbalancer with a way to update endpoints (list of control plane nodes).
20+
type LoadBalancer struct {
21+
lb loadbalancer.TCP
22+
23+
done chan struct{}
24+
25+
endpoint string
26+
}
27+
28+
// NewLoadBalancer initializes the load balancer.
29+
//
30+
// If bindPort is zero, load balancer will bind to a random available port.
31+
func NewLoadBalancer(bindAddress string, bindPort int, logWriter io.Writer) (*LoadBalancer, error) {
32+
if bindPort == 0 {
33+
var err error
34+
35+
bindPort, err = findListenPort(bindAddress)
36+
if err != nil {
37+
return nil, fmt.Errorf("unable to find available port: %w", err)
38+
}
39+
}
40+
41+
lb := &LoadBalancer{
42+
endpoint: net.JoinHostPort(bindAddress, strconv.Itoa(bindPort)),
43+
}
44+
45+
// set aggressive timeouts to prevent proxying to unhealthy upstreams
46+
lb.lb.DialTimeout = 5 * time.Second
47+
lb.lb.KeepAlivePeriod = time.Second
48+
lb.lb.TCPUserTimeout = 5 * time.Second
49+
50+
lb.lb.Logger = log.New(logWriter, lb.endpoint+" ", log.Default().Flags())
51+
52+
// create a route without any upstreams yet
53+
if err := lb.lb.AddRoute(lb.endpoint, nil); err != nil {
54+
return nil, err
55+
}
56+
57+
return lb, nil
58+
}
59+
60+
// Endpoint returns loadbalancer endpoint as "host:port".
61+
func (lb *LoadBalancer) Endpoint() string {
62+
return lb.endpoint
63+
}
64+
65+
// Start the loadbalancer providing a channel which provides endpoint list update.
66+
//
67+
// Load balancer starts with an empty list of endpoints, so initial list should be provided on the channel.
68+
func (lb *LoadBalancer) Start(upstreamCh <-chan []string) error {
69+
if err := lb.lb.Start(); err != nil {
70+
return err
71+
}
72+
73+
lb.done = make(chan struct{})
74+
75+
go func() {
76+
for {
77+
select {
78+
case upstreams := <-upstreamCh:
79+
if err := lb.lb.ReconcileRoute(lb.endpoint, upstreams); err != nil {
80+
lb.lb.Logger.Printf("failed reconciling list of upstreams: %s", err)
81+
}
82+
case <-lb.done:
83+
return
84+
}
85+
}
86+
}()
87+
88+
return nil
89+
}
90+
91+
// Shutdown the loadbalancer listener and wait for the connections to be closed.
92+
func (lb *LoadBalancer) Shutdown() error {
93+
if err := lb.lb.Close(); err != nil {
94+
return err
95+
}
96+
97+
close(lb.done)
98+
99+
lb.lb.Wait() //nolint:errcheck
100+
101+
return nil
102+
}
103+
104+
func findListenPort(address string) (int, error) {
105+
l, err := net.Listen("tcp", net.JoinHostPort(address, "0"))
106+
if err != nil {
107+
return 0, err
108+
}
109+
110+
port := l.Addr().(*net.TCPAddr).Port //nolint:forcetypeassert
111+
112+
return port, l.Close()
113+
}

controlplane/controlplane_test.go

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// This Source Code Form is subject to the terms of the Mozilla Public
2+
// License, v. 2.0. If a copy of the MPL was not distributed with this
3+
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
4+
5+
package controlplane_test
6+
7+
import (
8+
"fmt"
9+
"io/ioutil"
10+
"net"
11+
"os"
12+
"strconv"
13+
"testing"
14+
"time"
15+
16+
"github.com/stretchr/testify/assert"
17+
"github.com/stretchr/testify/require"
18+
"github.com/talos-systems/go-retry/retry"
19+
"go.uber.org/goleak"
20+
21+
"github.com/talos-systems/go-loadbalancer/controlplane"
22+
)
23+
24+
type mockUpstream struct {
25+
addr string
26+
l net.Listener
27+
28+
identity string
29+
}
30+
31+
func (u *mockUpstream) Start() error {
32+
var err error
33+
34+
u.l, err = net.Listen("tcp", "localhost:0")
35+
if err != nil {
36+
return err
37+
}
38+
39+
u.addr = u.l.Addr().String()
40+
41+
go u.serve()
42+
43+
return nil
44+
}
45+
46+
func (u *mockUpstream) serve() {
47+
for {
48+
c, err := u.l.Accept()
49+
if err != nil {
50+
return
51+
}
52+
53+
c.Write([]byte(u.identity)) //nolint: errcheck
54+
c.Close() //nolint: errcheck
55+
}
56+
}
57+
58+
func (u *mockUpstream) Close() {
59+
u.l.Close() //nolint: errcheck
60+
}
61+
62+
func TestLoadBalancer(t *testing.T) {
63+
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
64+
65+
const (
66+
upstreamCount = 5
67+
pivot = 2
68+
)
69+
70+
upstreams := make([]mockUpstream, upstreamCount)
71+
for i := range upstreams {
72+
upstreams[i].identity = strconv.Itoa(i)
73+
require.NoError(t, upstreams[i].Start())
74+
}
75+
76+
upstreamAddrs := make([]string, len(upstreams))
77+
for i := range upstreamAddrs {
78+
upstreamAddrs[i] = upstreams[i].addr
79+
}
80+
81+
lb, err := controlplane.NewLoadBalancer("localhost", 0, os.Stderr)
82+
require.NoError(t, err)
83+
84+
upstreamCh := make(chan []string)
85+
86+
require.NoError(t, lb.Start(upstreamCh))
87+
88+
upstreamCh <- upstreamAddrs[:pivot]
89+
90+
readIdentity := func() (int, error) {
91+
c, err := net.Dial("tcp", lb.Endpoint())
92+
if err != nil {
93+
return 0, retry.ExpectedError(err)
94+
}
95+
96+
defer c.Close() //nolint:errcheck
97+
98+
id, err := ioutil.ReadAll(c)
99+
if err != nil {
100+
return 0, retry.ExpectedError(err)
101+
}
102+
103+
return strconv.Atoi(string(id))
104+
}
105+
106+
assert.NoError(t, retry.Constant(10*time.Second, retry.WithUnits(time.Second)).Retry(func() error {
107+
identity, err := readIdentity()
108+
if err != nil {
109+
return err
110+
}
111+
112+
if identity < 0 || identity > pivot-1 {
113+
return fmt.Errorf("unexpected response: %d", identity)
114+
}
115+
116+
return nil
117+
}))
118+
119+
// change the upstreams
120+
upstreamCh <- upstreamAddrs[pivot:]
121+
122+
assert.NoError(t, retry.Constant(10*time.Second, retry.WithUnits(time.Second)).Retry(func() error {
123+
identity, err := readIdentity()
124+
if err != nil {
125+
return err
126+
}
127+
128+
// upstreams are not changed immediately, there might be some stale responses
129+
if identity < pivot {
130+
return retry.ExpectedErrorf("unexpected response: %d", identity)
131+
}
132+
133+
return nil
134+
}))
135+
136+
assert.NoError(t, lb.Shutdown())
137+
138+
for i := range upstreams {
139+
upstreams[i].Close()
140+
}
141+
}

0 commit comments

Comments
 (0)