Skip to content

Commit

Permalink
chore: optimize DNSResolveCacheController
Browse files Browse the repository at this point in the history
Optimize `DNSResolveCacheController` type, including `dns.Server` optimization for easy start/stop. This PR ensures that we
delete server from runners on stop (even unexpected) and restart it properly. Also fixes incorrect assumption on unit-tests.

Fixes siderolabs#8563

This PR also does those things:
- Removes `utils.Runner`
- Removes `ctxutil.MonitorFn`
- Removes `dns.Runner`
- Removes `network.dnsRunner`

Signed-off-by: Dmitriy Matrenichev <dmitry.matrenichev@siderolabs.com>
(cherry picked from commit ba7cdc8)
  • Loading branch information
DmitriyMV authored and smira committed Apr 12, 2024
1 parent eca03b0 commit d5932a3
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 377 deletions.
143 changes: 53 additions & 90 deletions internal/app/machined/pkg/controllers/network/dns_resolve_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"context"
"errors"
"fmt"
"io"
"net"
"net/netip"
"sync"
Expand All @@ -19,9 +18,9 @@ import (
"github.com/cosi-project/runtime/pkg/safe"
"github.com/cosi-project/runtime/pkg/state"
"github.com/siderolabs/gen/optional"
"github.com/siderolabs/gen/pair"
"go.uber.org/zap"

"github.com/siderolabs/talos/internal/pkg/ctxutil"
"github.com/siderolabs/talos/internal/pkg/dns"
"github.com/siderolabs/talos/pkg/machinery/resources/network"
)
Expand All @@ -33,10 +32,9 @@ type DNSResolveCacheController struct {
mx sync.Mutex
handler *dns.Handler
cache *dns.Cache
runners map[runnerConfig]*dnsRunner
runners map[runnerConfig]pair.Pair[func(), <-chan struct{}]
reconcile chan struct{}
originalCtx context.Context //nolint:containedctx
wg sync.WaitGroup
}

// Name implements controller.Controller interface.
Expand Down Expand Up @@ -71,7 +69,7 @@ func (ctrl *DNSResolveCacheController) Outputs() []controller.Output {
//
//nolint:gocyclo,cyclop
func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Runtime, logger *zap.Logger) error {
ctrl.init(ctx, logger)
ctrl.init(ctx)

ctrl.mx.Lock()
defer ctrl.mx.Unlock()
Expand All @@ -81,9 +79,19 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run
for {
select {
case <-ctx.Done():
return ctxutil.Cause(ctx)
return nil
case <-r.EventCh():
case <-ctrl.reconcile:
for cfg, stop := range ctrl.runners {
select {
default:
continue
case <-stop.F2:
}

stop.F1()
delete(ctrl.runners, cfg)
}
}

cfg, err := safe.ReaderGetByID[*network.HostDNSConfig](ctx, r, network.HostDNSConfigID)
Expand All @@ -101,7 +109,7 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run
ctrl.stopRunners(ctx, true)

if err = safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil {
return fmt.Errorf("error cleaning up dns status: %w", err)
return fmt.Errorf("error cleaning up dns status on disable: %w", err)
}

continue
Expand All @@ -111,10 +119,10 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run

for _, addr := range cfg.TypedSpec().ListenAddresses {
for _, netwk := range []string{"udp", "tcp"} {
config := runnerConfig{net: netwk, addr: addr}
runnerCfg := runnerConfig{net: netwk, addr: addr}

if _, ok := ctrl.runners[config]; !ok {
runner, rErr := newDNSRunner(config, ctrl.cache, ctrl.Logger)
if _, ok := ctrl.runners[runnerCfg]; !ok {
runner, rErr := newDNSRunner(runnerCfg, ctrl.cache, ctrl.Logger)
if rErr != nil {
return fmt.Errorf("error creating dns runner: %w", rErr)
}
Expand All @@ -123,30 +131,23 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run
continue
}

ctrl.wg.Add(1)

go func() {
defer ctrl.wg.Done()

runner.Run(ctx, logger, ctrl.reconcile)
}()

ctrl.runners[config] = runner
ctrl.runners[runnerCfg] = pair.MakePair(runner.Start(ctrl.handleDone(ctx, logger)))
}

if err = ctrl.writeDNSStatus(ctx, r, config); err != nil {
if err = ctrl.writeDNSStatus(ctx, r, runnerCfg); err != nil {
return fmt.Errorf("error writing dns status: %w", err)
}

touchedRunners[config] = struct{}{}
touchedRunners[runnerCfg] = struct{}{}
}
}

for config := range ctrl.runners {
if _, ok := touchedRunners[config]; !ok {
ctrl.runners[config].Stop()
for runnerCfg, stop := range ctrl.runners {
if _, ok := touchedRunners[runnerCfg]; !ok {
stop.F1()
delete(ctrl.runners, runnerCfg)

delete(ctrl.runners, config)
continue
}
}

Expand Down Expand Up @@ -182,7 +183,7 @@ func (ctrl *DNSResolveCacheController) writeDNSStatus(ctx context.Context, r con
})
}

func (ctrl *DNSResolveCacheController) init(ctx context.Context, logger *zap.Logger) {
func (ctrl *DNSResolveCacheController) init(ctx context.Context) {
if ctrl.runners != nil {
if ctrl.originalCtx != ctx {
// This should not happen, but if it does, it's a bug.
Expand All @@ -195,7 +196,7 @@ func (ctrl *DNSResolveCacheController) init(ctx context.Context, logger *zap.Log
ctrl.originalCtx = ctx
ctrl.handler = dns.NewHandler(ctrl.Logger)
ctrl.cache = dns.NewCache(ctrl.handler, ctrl.Logger)
ctrl.runners = map[runnerConfig]*dnsRunner{}
ctrl.runners = map[runnerConfig]pair.Pair[func(), <-chan struct{}]{}
ctrl.reconcile = make(chan struct{}, 1)

// Ensure we stop all runners when the context is canceled, no matter where we are currently.
Expand All @@ -215,29 +216,42 @@ func (ctrl *DNSResolveCacheController) stopRunners(ctx context.Context, ignoreCt
return
}

for _, r := range ctrl.runners {
r.Stop()
for _, stop := range ctrl.runners {
stop.F1()
}

clear(ctrl.runners)

ctrl.handler.Stop()

ctrl.wg.Wait()
}

type dnsRunner struct {
runner *dns.Runner
lis io.Closer
logger *zap.Logger
func (ctrl *DNSResolveCacheController) handleDone(ctx context.Context, logger *zap.Logger) func(err error) {
return func(err error) {
if ctx.Err() != nil {
if err != nil && !errors.Is(err, net.ErrClosed) {
logger.Error("controller is closing, but error running dns server", zap.Error(err))
}

return
}

if err != nil {
logger.Error("error running dns server", zap.Error(err))
}

select {
case ctrl.reconcile <- struct{}{}:
default:
}
}
}

type runnerConfig struct {
net string
addr netip.AddrPort
}

func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsRunner, error) {
func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dns.Server, error) {
if cfg.addr.Addr().Is6() {
cfg.net += "6"
}
Expand All @@ -246,8 +260,6 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsR

var serverOpts dns.ServerOptions

var lis io.Closer

switch cfg.net {
case "udp", "udp6":
packetConn, err := dns.NewUDPPacketConn(cfg.net, cfg.addr.String())
Expand All @@ -262,11 +274,10 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsR
return nil, fmt.Errorf("error creating udp packet conn: %w", err)
}

lis = packetConn

serverOpts = dns.ServerOptions{
PacketConn: packetConn,
Handler: cache,
Logger: logger,
}

case "tcp", "tcp6":
Expand All @@ -282,64 +293,16 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsR
return nil, fmt.Errorf("error creating tcp listener: %w", err)
}

lis = listener

serverOpts = dns.ServerOptions{
Listener: listener,
Handler: cache,
ReadTimeout: 3 * time.Second,
WriteTimeout: 5 * time.Second,
IdleTimeout: func() time.Duration { return 10 * time.Second },
MaxTCPQueries: -1,
Logger: logger,
}
}

runner := dns.NewRunner(dns.NewServer(serverOpts), logger)

return &dnsRunner{
runner: runner,
lis: lis,
logger: logger,
}, nil
}

func (dnsRunner *dnsRunner) Run(ctx context.Context, logger *zap.Logger, reconcile chan<- struct{}) {
err := dnsRunner.runner.Run()
if err == nil {
if ctx.Err() == nil {
select {
case reconcile <- struct{}{}:
default:
}
}

return
}

if ctx.Err() == nil {
logger.Error("error running dns server, triggering reconcile", zap.Error(err))

select {
case reconcile <- struct{}{}:
default:
}

return
}

if !errors.Is(err, net.ErrClosed) {
logger.Error("controller is closing, but error running dns server", zap.Error(err))

return
}
}

func (dnsRunner *dnsRunner) Stop() {
dnsRunner.runner.Stop()

if err := dnsRunner.lis.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
dnsRunner.logger.Error("error closing listener", zap.Error(err))
} else {
dnsRunner.logger.Debug("dns listener closed")
}
return dns.NewServer(serverOpts), nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,19 @@ func (suite *DNSServer) TestResolving() {

var res *dns.Msg

err := retry.Constant(2*time.Second, retry.WithUnits(100*time.Millisecond)).Retry(func() error {
err := retry.Constant(5*time.Second, retry.WithUnits(100*time.Millisecond)).Retry(func() error {
r, err := dns.Exchange(msg, "127.0.0.53:"+port)
if err != nil {
return retry.ExpectedError(err)
}

if r.Rcode != dns.RcodeSuccess {
return retry.ExpectedErrorf("expected rcode %d, got %d", dns.RcodeSuccess, r.Rcode)
}

res = r

return retry.ExpectedError(err)
return nil
})
suite.Require().NoError(err)
suite.Require().Equal(dns.RcodeSuccess, res.Rcode, res)
Expand Down Expand Up @@ -137,7 +144,7 @@ func (suite *DNSServer) TestSetupStartStop() {
func TestDNSServer(t *testing.T) {
suite.Run(t, &DNSServer{
DefaultSuite: ctest.DefaultSuite{
Timeout: 5 * time.Second,
Timeout: 10 * time.Second,
AfterSetup: func(suite *ctest.DefaultSuite) {
suite.Require().NoError(suite.Runtime().RegisterController(&netctrl.DNSUpstreamController{}))
suite.Require().NoError(suite.Runtime().RegisterController(&netctrl.DNSResolveCacheController{
Expand Down
10 changes: 0 additions & 10 deletions internal/pkg/ctxutil/ctxutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,6 @@ package ctxutil

import "context"

// MonitorFn starts a function in a new goroutine and cancels the context with error as cause when the function returns.
// It returns the new context.
func MonitorFn(ctx context.Context, fn func() error) context.Context {
ctx, cancel := context.WithCancelCause(ctx)

go func() { cancel(fn()) }()

return ctx
}

// Cause returns the cause of the context error, or nil if there is no error or the error is a usual context error.
func Cause(ctx context.Context) error {
if c := context.Cause(ctx); c != ctx.Err() {
Expand Down
33 changes: 0 additions & 33 deletions internal/pkg/ctxutil/ctxutil_test.go

This file was deleted.

Loading

0 comments on commit d5932a3

Please sign in to comment.