diff --git a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go index fef3a8a939..fb1c027d7e 100644 --- a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go +++ b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go @@ -8,7 +8,6 @@ import ( "context" "errors" "fmt" - "io" "net" "net/netip" "sync" @@ -19,9 +18,9 @@ import ( "github.com/cosi-project/runtime/pkg/safe" "github.com/cosi-project/runtime/pkg/state" "github.com/siderolabs/gen/optional" + "github.com/siderolabs/gen/pair" "go.uber.org/zap" - "github.com/siderolabs/talos/internal/pkg/ctxutil" "github.com/siderolabs/talos/internal/pkg/dns" "github.com/siderolabs/talos/pkg/machinery/resources/network" ) @@ -33,10 +32,9 @@ type DNSResolveCacheController struct { mx sync.Mutex handler *dns.Handler cache *dns.Cache - runners map[runnerConfig]*dnsRunner + runners map[runnerConfig]pair.Pair[func(), <-chan struct{}] reconcile chan struct{} originalCtx context.Context //nolint:containedctx - wg sync.WaitGroup } // Name implements controller.Controller interface. @@ -71,7 +69,7 @@ func (ctrl *DNSResolveCacheController) Outputs() []controller.Output { // //nolint:gocyclo,cyclop func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Runtime, logger *zap.Logger) error { - ctrl.init(ctx, logger) + ctrl.init(ctx) ctrl.mx.Lock() defer ctrl.mx.Unlock() @@ -81,9 +79,19 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run for { select { case <-ctx.Done(): - return ctxutil.Cause(ctx) + return nil case <-r.EventCh(): case <-ctrl.reconcile: + for cfg, stop := range ctrl.runners { + select { + default: + continue + case <-stop.F2: + } + + stop.F1() + delete(ctrl.runners, cfg) + } } cfg, err := safe.ReaderGetByID[*network.HostDNSConfig](ctx, r, network.HostDNSConfigID) @@ -101,7 +109,7 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run ctrl.stopRunners(ctx, true) if err = safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil { - return fmt.Errorf("error cleaning up dns status: %w", err) + return fmt.Errorf("error cleaning up dns status on disable: %w", err) } continue @@ -111,10 +119,10 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run for _, addr := range cfg.TypedSpec().ListenAddresses { for _, netwk := range []string{"udp", "tcp"} { - config := runnerConfig{net: netwk, addr: addr} + runnerCfg := runnerConfig{net: netwk, addr: addr} - if _, ok := ctrl.runners[config]; !ok { - runner, rErr := newDNSRunner(config, ctrl.cache, ctrl.Logger) + if _, ok := ctrl.runners[runnerCfg]; !ok { + runner, rErr := newDNSRunner(runnerCfg, ctrl.cache, ctrl.Logger) if rErr != nil { return fmt.Errorf("error creating dns runner: %w", rErr) } @@ -123,30 +131,23 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run continue } - ctrl.wg.Add(1) - - go func() { - defer ctrl.wg.Done() - - runner.Run(ctx, logger, ctrl.reconcile) - }() - - ctrl.runners[config] = runner + ctrl.runners[runnerCfg] = pair.MakePair(runner.Start(ctrl.handleDone(ctx, logger))) } - if err = ctrl.writeDNSStatus(ctx, r, config); err != nil { + if err = ctrl.writeDNSStatus(ctx, r, runnerCfg); err != nil { return fmt.Errorf("error writing dns status: %w", err) } - touchedRunners[config] = struct{}{} + touchedRunners[runnerCfg] = struct{}{} } } - for config := range ctrl.runners { - if _, ok := touchedRunners[config]; !ok { - ctrl.runners[config].Stop() + for runnerCfg, stop := range ctrl.runners { + if _, ok := touchedRunners[runnerCfg]; !ok { + stop.F1() + delete(ctrl.runners, runnerCfg) - delete(ctrl.runners, config) + continue } } @@ -182,7 +183,7 @@ func (ctrl *DNSResolveCacheController) writeDNSStatus(ctx context.Context, r con }) } -func (ctrl *DNSResolveCacheController) init(ctx context.Context, logger *zap.Logger) { +func (ctrl *DNSResolveCacheController) init(ctx context.Context) { if ctrl.runners != nil { if ctrl.originalCtx != ctx { // This should not happen, but if it does, it's a bug. @@ -195,7 +196,7 @@ func (ctrl *DNSResolveCacheController) init(ctx context.Context, logger *zap.Log ctrl.originalCtx = ctx ctrl.handler = dns.NewHandler(ctrl.Logger) ctrl.cache = dns.NewCache(ctrl.handler, ctrl.Logger) - ctrl.runners = map[runnerConfig]*dnsRunner{} + ctrl.runners = map[runnerConfig]pair.Pair[func(), <-chan struct{}]{} ctrl.reconcile = make(chan struct{}, 1) // Ensure we stop all runners when the context is canceled, no matter where we are currently. @@ -215,21 +216,34 @@ func (ctrl *DNSResolveCacheController) stopRunners(ctx context.Context, ignoreCt return } - for _, r := range ctrl.runners { - r.Stop() + for _, stop := range ctrl.runners { + stop.F1() } clear(ctrl.runners) ctrl.handler.Stop() - - ctrl.wg.Wait() } -type dnsRunner struct { - runner *dns.Runner - lis io.Closer - logger *zap.Logger +func (ctrl *DNSResolveCacheController) handleDone(ctx context.Context, logger *zap.Logger) func(err error) { + return func(err error) { + if ctx.Err() != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { + logger.Error("controller is closing, but error running dns server", zap.Error(err)) + } + + return + } + + if err != nil { + logger.Error("error running dns server", zap.Error(err)) + } + + select { + case ctrl.reconcile <- struct{}{}: + default: + } + } } type runnerConfig struct { @@ -237,7 +251,7 @@ type runnerConfig struct { addr netip.AddrPort } -func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsRunner, error) { +func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dns.Server, error) { if cfg.addr.Addr().Is6() { cfg.net += "6" } @@ -246,8 +260,6 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsR var serverOpts dns.ServerOptions - var lis io.Closer - switch cfg.net { case "udp", "udp6": packetConn, err := dns.NewUDPPacketConn(cfg.net, cfg.addr.String()) @@ -262,11 +274,10 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsR return nil, fmt.Errorf("error creating udp packet conn: %w", err) } - lis = packetConn - serverOpts = dns.ServerOptions{ PacketConn: packetConn, Handler: cache, + Logger: logger, } case "tcp", "tcp6": @@ -282,8 +293,6 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsR return nil, fmt.Errorf("error creating tcp listener: %w", err) } - lis = listener - serverOpts = dns.ServerOptions{ Listener: listener, Handler: cache, @@ -291,55 +300,9 @@ func newDNSRunner(cfg runnerConfig, cache *dns.Cache, logger *zap.Logger) (*dnsR WriteTimeout: 5 * time.Second, IdleTimeout: func() time.Duration { return 10 * time.Second }, MaxTCPQueries: -1, + Logger: logger, } } - runner := dns.NewRunner(dns.NewServer(serverOpts), logger) - - return &dnsRunner{ - runner: runner, - lis: lis, - logger: logger, - }, nil -} - -func (dnsRunner *dnsRunner) Run(ctx context.Context, logger *zap.Logger, reconcile chan<- struct{}) { - err := dnsRunner.runner.Run() - if err == nil { - if ctx.Err() == nil { - select { - case reconcile <- struct{}{}: - default: - } - } - - return - } - - if ctx.Err() == nil { - logger.Error("error running dns server, triggering reconcile", zap.Error(err)) - - select { - case reconcile <- struct{}{}: - default: - } - - return - } - - if !errors.Is(err, net.ErrClosed) { - logger.Error("controller is closing, but error running dns server", zap.Error(err)) - - return - } -} - -func (dnsRunner *dnsRunner) Stop() { - dnsRunner.runner.Stop() - - if err := dnsRunner.lis.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - dnsRunner.logger.Error("error closing listener", zap.Error(err)) - } else { - dnsRunner.logger.Debug("dns listener closed") - } + return dns.NewServer(serverOpts), nil } diff --git a/internal/app/machined/pkg/controllers/network/dns_resolve_cache_test.go b/internal/app/machined/pkg/controllers/network/dns_resolve_cache_test.go index 49763da916..84d6cad4ad 100644 --- a/internal/app/machined/pkg/controllers/network/dns_resolve_cache_test.go +++ b/internal/app/machined/pkg/controllers/network/dns_resolve_cache_test.go @@ -78,12 +78,19 @@ func (suite *DNSServer) TestResolving() { var res *dns.Msg - err := retry.Constant(2*time.Second, retry.WithUnits(100*time.Millisecond)).Retry(func() error { + err := retry.Constant(5*time.Second, retry.WithUnits(100*time.Millisecond)).Retry(func() error { r, err := dns.Exchange(msg, "127.0.0.53:"+port) + if err != nil { + return retry.ExpectedError(err) + } + + if r.Rcode != dns.RcodeSuccess { + return retry.ExpectedErrorf("expected rcode %d, got %d", dns.RcodeSuccess, r.Rcode) + } res = r - return retry.ExpectedError(err) + return nil }) suite.Require().NoError(err) suite.Require().Equal(dns.RcodeSuccess, res.Rcode, res) @@ -137,7 +144,7 @@ func (suite *DNSServer) TestSetupStartStop() { func TestDNSServer(t *testing.T) { suite.Run(t, &DNSServer{ DefaultSuite: ctest.DefaultSuite{ - Timeout: 5 * time.Second, + Timeout: 10 * time.Second, AfterSetup: func(suite *ctest.DefaultSuite) { suite.Require().NoError(suite.Runtime().RegisterController(&netctrl.DNSUpstreamController{})) suite.Require().NoError(suite.Runtime().RegisterController(&netctrl.DNSResolveCacheController{ diff --git a/internal/pkg/ctxutil/ctxutil.go b/internal/pkg/ctxutil/ctxutil.go index 1fff6bad64..892992200f 100644 --- a/internal/pkg/ctxutil/ctxutil.go +++ b/internal/pkg/ctxutil/ctxutil.go @@ -7,16 +7,6 @@ package ctxutil import "context" -// MonitorFn starts a function in a new goroutine and cancels the context with error as cause when the function returns. -// It returns the new context. -func MonitorFn(ctx context.Context, fn func() error) context.Context { - ctx, cancel := context.WithCancelCause(ctx) - - go func() { cancel(fn()) }() - - return ctx -} - // Cause returns the cause of the context error, or nil if there is no error or the error is a usual context error. func Cause(ctx context.Context) error { if c := context.Cause(ctx); c != ctx.Err() { diff --git a/internal/pkg/ctxutil/ctxutil_test.go b/internal/pkg/ctxutil/ctxutil_test.go deleted file mode 100644 index 1c92879d28..0000000000 --- a/internal/pkg/ctxutil/ctxutil_test.go +++ /dev/null @@ -1,33 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -package ctxutil_test - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/siderolabs/talos/internal/pkg/ctxutil" -) - -func TestStartFn(t *testing.T) { - ctx := ctxutil.MonitorFn(context.Background(), func() error { return nil }) - - <-ctx.Done() - - require.Equal(t, context.Canceled, ctx.Err()) - require.Nil(t, ctxutil.Cause(ctx)) - - myErr := errors.New("my error") - - ctx = ctxutil.MonitorFn(context.Background(), func() error { return myErr }) - - <-ctx.Done() - - require.Equal(t, context.Canceled, ctx.Err()) - require.Equal(t, myErr, ctxutil.Cause(ctx)) -} diff --git a/internal/pkg/dns/dns.go b/internal/pkg/dns/dns.go index 9bc7522a5b..838161c9ba 100644 --- a/internal/pkg/dns/dns.go +++ b/internal/pkg/dns/dns.go @@ -9,6 +9,7 @@ import ( "context" "errors" "fmt" + "io" "math/rand" "net" "slices" @@ -24,53 +25,8 @@ import ( "github.com/miekg/dns" "go.uber.org/zap" "golang.org/x/sys/unix" - - "github.com/siderolabs/talos/internal/pkg/utils" ) -// NewRunner creates a new Runner. -func NewRunner(srv Server, logger *zap.Logger) *Runner { - r := utils.NewRunner(srv.ActivateAndServe, srv.Shutdown, func(err error) bool { - // There a possible scenario where `Run` reached `ListenAndServe` and then yielded CPU time to another - // goroutine and then `Stop` reached `Shutdown`. In that case `ListenAndServe` will actually start after - // `Shutdown` and `Stop` method will forever block if we do not try again. - return strings.Contains(err.Error(), "server not started") - }) - - return &Runner{r: r, logger: logger} -} - -// Runner is a dns server handler. -type Runner struct { - r *utils.Runner - logger *zap.Logger -} - -// Server is a dns server. -type Server interface { - ActivateAndServe() error - Shutdown() error -} - -// Run runs dns server. -func (r *Runner) Run() error { - r.logger.Debug("starting dns server") - - err := r.r.Run() - - r.logger.Debug("dns server stopped", zap.Error(err)) - - return err -} - -// Stop stops dns server. It's safe to call even if server is already stopped. -func (r *Runner) Stop() { - err := r.r.Stop() - if err != nil { - r.logger.Warn("error shutting down dns server", zap.Error(err)) - } -} - // Cache is a [dns.Handler] to [plugin.Handler] adapter. type Cache struct { cache *cache.Cache @@ -211,22 +167,78 @@ type ServerOptions struct { WriteTimeout time.Duration IdleTimeout func() time.Duration MaxTCPQueries int + Logger *zap.Logger } // NewServer creates a new Server. -func NewServer(opts ServerOptions) Server { - return &server{&dns.Server{ - Listener: opts.Listener, - PacketConn: opts.PacketConn, - Handler: opts.Handler, - ReadTimeout: opts.ReadTimeout, - WriteTimeout: opts.WriteTimeout, - IdleTimeout: opts.IdleTimeout, - MaxTCPQueries: opts.MaxTCPQueries, - }} +func NewServer(opts ServerOptions) *Server { + return &Server{ + srv: &dns.Server{ + Listener: opts.Listener, + PacketConn: opts.PacketConn, + Handler: opts.Handler, + ReadTimeout: opts.ReadTimeout, + WriteTimeout: opts.WriteTimeout, + IdleTimeout: opts.IdleTimeout, + MaxTCPQueries: opts.MaxTCPQueries, + }, + logger: opts.Logger, + } } -type server struct{ *dns.Server } +// Server is a dns server. +type Server struct { + srv *dns.Server + logger *zap.Logger +} + +// Start starts the dns server. Returns a function to stop the server. +func (s *Server) Start(onDone func(err error)) (stop func(), stopped <-chan struct{}) { + done := make(chan struct{}) + + fn := sync.OnceFunc(func() { + for { + err := s.srv.Shutdown() + if err != nil { + if strings.Contains(err.Error(), "server not started") { + // There a possible scenario where `go func()` not yet reached `ActivateAndServe` and yielded CPU + // time to another goroutine and then this closure reached `Shutdown`. In that case + // `ActivateAndServe` will actually start after `Shutdown` and this closure will block forever + // because `go func()` will never exit and close `done` channel. + continue + } + + s.logger.Error("error shutting down dns server", zap.Error(err)) + } + + break + } + + closer := io.Closer(s.srv.Listener) + if closer == nil { + closer = s.srv.PacketConn + } + + if closer != nil { + err := closer.Close() + if err != nil && !errors.Is(err, net.ErrClosed) { + s.logger.Error("error closing dns server listener", zap.Error(err)) + } else { + s.logger.Debug("dns server listener closed") + } + } + + <-done + }) + + go func() { + defer close(done) + + onDone(s.srv.ActivateAndServe()) + }() + + return fn, done +} // NewTCPListener creates a new TCP listener. func NewTCPListener(network, addr string) (net.Listener, error) { diff --git a/internal/pkg/dns/dns_test.go b/internal/pkg/dns/dns_test.go index 702e2c6c0c..fb696a39ba 100644 --- a/internal/pkg/dns/dns_test.go +++ b/internal/pkg/dns/dns_test.go @@ -5,11 +5,7 @@ package dns_test import ( - "context" - "errors" "net" - "sync" - "sync/atomic" "testing" "time" @@ -18,10 +14,8 @@ import ( "github.com/siderolabs/gen/xslices" "github.com/siderolabs/gen/xtesting/check" "github.com/stretchr/testify/require" - "go.uber.org/zap" "go.uber.org/zap/zaptest" - "github.com/siderolabs/talos/internal/pkg/ctxutil" "github.com/siderolabs/talos/internal/pkg/dns" ) @@ -53,10 +47,8 @@ func TestDNS(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ctx, stop := newServer(t, test.nameservers...) - - stopOnce := sync.OnceFunc(stop) - defer stopOnce() + stop := newServer(t, test.nameservers...) + defer stop() time.Sleep(10 * time.Millisecond) @@ -69,20 +61,14 @@ func TestDNS(t *testing.T) { t.Logf("r: %s", r) - stopOnce() - - <-ctx.Done() - - require.NoError(t, ctxutil.Cause(ctx)) + stop() }) } } func TestDNSEmptyDestinations(t *testing.T) { - ctx, stop := newServer(t) - - stopOnce := sync.OnceFunc(stop) - defer stopOnce() + stop := newServer(t) + defer stop() time.Sleep(10 * time.Millisecond) @@ -94,14 +80,10 @@ func TestDNSEmptyDestinations(t *testing.T) { require.NoError(t, err) require.Equal(t, dnssrv.RcodeServerFailure, r.Rcode, r) - stopOnce() - - <-ctx.Done() - - require.NoError(t, ctxutil.Cause(ctx)) + stop() } -func newServer(t *testing.T, nameservers ...string) (context.Context, func()) { +func newServer(t *testing.T, nameservers ...string) func() { l := zaptest.NewLogger(t) handler := dns.NewHandler(l) @@ -121,12 +103,21 @@ func newServer(t *testing.T, nameservers ...string) (context.Context, func()) { pc, err := dns.NewUDPPacketConn("udp", "127.0.0.53:10700") require.NoError(t, err) - runner := dns.NewRunner(dns.NewServer(dns.ServerOptions{ + srv := dns.NewServer(dns.ServerOptions{ PacketConn: pc, Handler: dns.NewCache(handler, l), - }), l) + Logger: l, + }) - return ctxutil.MonitorFn(context.Background(), runner.Run), runner.Stop + stop, _ := srv.Start(func(err error) { + if err != nil { + t.Errorf("error running dns server: %v", err) + } + + t.Logf("dns server stopped") + }) + + return stop } func createQuery() *dnssrv.Msg { @@ -144,86 +135,3 @@ func createQuery() *dnssrv.Msg { }, } } - -func TestActivateFailure(t *testing.T) { - // Ensure that we correctly handle an error inside [dns.Runner.Run]. - l := zaptest.NewLogger(t) - - runner := dns.NewRunner(&testServer{t: t}, l) - - ctx := ctxutil.MonitorFn(context.Background(), runner.Run) - defer runner.Stop() - - <-ctx.Done() - - require.Equal(t, errFailed, ctxutil.Cause(ctx)) -} - -func TestRunnerStopsBeforeRun(t *testing.T) { - // Ensure that we correctly handle an error inside [dns.Runner.Run]. - l := zap.NewNop() - - for range 1000 { - runner := dns.NewRunner(&runnerStopper{}, l) - - ctx := ctxutil.MonitorFn(context.Background(), runner.Run) - runner.Stop() - - <-ctx.Done() - } - - for range 1000 { - runner := dns.NewRunner(&runnerStopper{}, l) - - runner.Stop() - ctx := ctxutil.MonitorFn(context.Background(), runner.Run) - - <-ctx.Done() - } -} - -type testServer struct { - t *testing.T -} - -var errFailed = errors.New("listen failure") - -func (ts *testServer) ActivateAndServe() error { return errFailed } - -func (ts *testServer) Shutdown() error { - ts.t.Fatal("should not be called") - - return nil -} - -func (ts *testServer) Name() string { - return "test-server" -} - -type runnerStopper struct { - val atomic.Pointer[chan struct{}] -} - -func (rs *runnerStopper) ActivateAndServe() error { - ch := make(chan struct{}) - - if rs.val.Swap(&ch) != nil { - panic("chan should be empty") - } - - <-ch - - return nil -} - -func (rs *runnerStopper) Shutdown() error { - chPtr := rs.val.Load() - - if chPtr == nil { - return errors.New("server not started") - } - - close(*chPtr) - - return nil -} diff --git a/internal/pkg/utils/utils.go b/internal/pkg/utils/utils.go deleted file mode 100644 index 3334b89418..0000000000 --- a/internal/pkg/utils/utils.go +++ /dev/null @@ -1,74 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -// Package utils provides various utility functions. -package utils - -import ( - "errors" - "sync/atomic" -) - -const ( - notRunning = iota - running - closing - closed -) - -// Runner is a fn/stop runner. -type Runner struct { - fn func() error - stop func() error - retryStop func(error) bool - status atomic.Int64 - done chan struct{} -} - -// NewRunner creates a new runner. -func NewRunner(fn, stop func() error, retryStop func(error) bool) *Runner { - return &Runner{fn: fn, stop: stop, retryStop: retryStop, done: make(chan struct{})} -} - -// Run runs fn. -func (r *Runner) Run() error { - defer func() { - if r.status.Swap(closed) != closed { - close(r.done) - } - }() - - if !r.status.CompareAndSwap(notRunning, running) { - return ErrAlreadyRunning - } - - return r.fn() -} - -var ( - // ErrAlreadyRunning is the error that is returned when runner is already running/closing/closed. - ErrAlreadyRunning = errors.New("runner is already running/closing/closed") - // ErrNotRunning is the error that is returned when runner is not running/closing/closed. - ErrNotRunning = errors.New("runner is not running/closing/closed") -) - -// Stop stops runner. It's safe to call even if runner is already stopped or in process of being stopped. -func (r *Runner) Stop() error { - if r.status.CompareAndSwap(notRunning, closing) || !r.status.CompareAndSwap(running, closing) { - return ErrNotRunning - } - - for { - err := r.stop() - if err != nil { - if r.retryStop(err) && r.status.Load() == closing { - continue - } - } - - <-r.done - - return err - } -}