diff --git a/go.mod b/go.mod index db561600b0..d717168c21 100644 --- a/go.mod +++ b/go.mod @@ -187,7 +187,7 @@ require ( golang.org/x/text v0.20.0 golang.org/x/time v0.8.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 - google.golang.org/grpc v1.68.0 // do not update to 1.68.0 until we find a way around https://github.com/grpc/grpc-go/pull/7535 + google.golang.org/grpc v1.68.0 google.golang.org/protobuf v1.35.2 gopkg.in/yaml.v3 v3.0.1 k8s.io/klog/v2 v2.130.1 diff --git a/internal/app/machined/pkg/system/runner/goroutine/goroutine.go b/internal/app/machined/pkg/system/runner/goroutine/goroutine.go index 3ec8a36cf5..fa04924c42 100644 --- a/internal/app/machined/pkg/system/runner/goroutine/goroutine.go +++ b/internal/app/machined/pkg/system/runner/goroutine/goroutine.go @@ -11,12 +11,16 @@ import ( "io" stdlibruntime "runtime" "sync" + "time" "github.com/siderolabs/talos/internal/app/machined/pkg/runtime" "github.com/siderolabs/talos/internal/app/machined/pkg/system/events" "github.com/siderolabs/talos/internal/app/machined/pkg/system/runner" ) +// ErrAborted is returned by the service when it's aborted (doesn't stop on timeout). +var ErrAborted = errors.New("service aborted") + // goroutineRunner is a runner.Runner that runs a service in a goroutine. type goroutineRunner struct { main FuncMain @@ -66,10 +70,32 @@ func (r *goroutineRunner) Run(eventSink events.Recorder) error { eventSink(events.StateRunning, "Service started as goroutine") - return r.wrappedMain() + errCh := make(chan error) + ctx := r.ctx + + go func() { + errCh <- r.wrappedMain(ctx) + }() + + select { + case <-r.ctx.Done(): + eventSink(events.StateStopping, "Service stopping") + case err := <-errCh: + // service finished on its own + return err + } + + select { + case <-time.After(r.opts.GracefulShutdownTimeout * 2): + eventSink(events.StateStopping, "Service hasn't stopped gracefully on timeout, aborting") + + return ErrAborted + case err := <-errCh: + return err + } } -func (r *goroutineRunner) wrappedMain() (err error) { +func (r *goroutineRunner) wrappedMain(ctx context.Context) (err error) { defer func() { if r := recover(); r != nil { buf := make([]byte, 8192) @@ -87,7 +113,7 @@ func (r *goroutineRunner) wrappedMain() (err error) { defer writerCloser() //nolint:errcheck - if err = r.main(r.ctx, r.runtime, w); !errors.Is(err, context.Canceled) { + if err = r.main(ctx, r.runtime, w); !errors.Is(err, context.Canceled) { return err // return error if it's not context.Canceled (service was not aborted) } diff --git a/internal/app/machined/pkg/system/runner/goroutine/goroutine_test.go b/internal/app/machined/pkg/system/runner/goroutine/goroutine_test.go index 526d70443c..af6af81b67 100644 --- a/internal/app/machined/pkg/system/runner/goroutine/goroutine_test.go +++ b/internal/app/machined/pkg/system/runner/goroutine/goroutine_test.go @@ -139,6 +139,38 @@ func (suite *GoroutineSuite) TestStop() { suite.Assert().NoError(<-errCh) } +func (suite *GoroutineSuite) TestStuckOnStop() { + r := goroutine.NewRunner(suite.r, "teststop", + func(ctx context.Context, data runtime.Runtime, logger io.Writer) error { + // hanging forever + select {} + }, + runner.WithLoggingManager(suite.loggingManager), + runner.WithGracefulShutdownTimeout(10*time.Millisecond), + ) + + suite.Assert().NoError(r.Open()) + + defer func() { suite.Assert().NoError(r.Close()) }() + + errCh := make(chan error) + + go func() { + errCh <- r.Run(MockEventSink) + }() + + time.Sleep(20 * time.Millisecond) + + select { + case <-errCh: + suite.Require().Fail("should not return yet") + default: + } + + suite.Assert().NoError(r.Stop()) + suite.Assert().ErrorIs(<-errCh, goroutine.ErrAborted) +} + func (suite *GoroutineSuite) TestRunLogs() { r := goroutine.NewRunner(suite.r, "logtest", func(ctx context.Context, data runtime.Runtime, logger io.Writer) error { diff --git a/internal/app/machined/pkg/system/services/apid.go b/internal/app/machined/pkg/system/services/apid.go index 4aafd4f429..f8587fe4c8 100644 --- a/internal/app/machined/pkg/system/services/apid.go +++ b/internal/app/machined/pkg/system/services/apid.go @@ -14,6 +14,7 @@ import ( "path/filepath" "strconv" "strings" + "time" "github.com/containerd/containerd/v2/pkg/cap" "github.com/containerd/containerd/v2/pkg/oci" @@ -200,6 +201,7 @@ func (o *APID) Runner(r runtime.Runtime) (runner.Runner, error) { runner.WithLoggingManager(r.Logging()), runner.WithContainerdAddress(constants.SystemContainerdAddress), runner.WithEnv(env), + runner.WithGracefulShutdownTimeout(15*time.Second), runner.WithCgroupPath(constants.CgroupApid), runner.WithSelinuxLabel(constants.SelinuxLabelApid), runner.WithOCISpecOpts( diff --git a/internal/app/machined/pkg/system/services/trustd.go b/internal/app/machined/pkg/system/services/trustd.go index 333d0c5bbe..42c3a97f04 100644 --- a/internal/app/machined/pkg/system/services/trustd.go +++ b/internal/app/machined/pkg/system/services/trustd.go @@ -13,6 +13,7 @@ import ( "os" "path/filepath" "strconv" + "time" "github.com/containerd/containerd/v2/pkg/cap" "github.com/containerd/containerd/v2/pkg/oci" @@ -164,6 +165,7 @@ func (t *Trustd) Runner(r runtime.Runtime) (runner.Runner, error) { runner.WithContainerdAddress(constants.SystemContainerdAddress), runner.WithEnv(env), runner.WithCgroupPath(constants.CgroupTrustd), + runner.WithGracefulShutdownTimeout(15*time.Second), runner.WithSelinuxLabel(constants.SelinuxLabelTrustd), runner.WithOCISpecOpts( oci.WithDroppedCapabilities(cap.Known()), diff --git a/pkg/chunker/stream/stream.go b/pkg/chunker/stream/stream.go index 76c5b83daa..ad39a1ae93 100644 --- a/pkg/chunker/stream/stream.go +++ b/pkg/chunker/stream/stream.go @@ -9,8 +9,10 @@ import ( "errors" "fmt" "io" + "os" "github.com/siderolabs/gen/xslices" + "github.com/siderolabs/go-circular" "github.com/siderolabs/talos/pkg/chunker" ) @@ -61,27 +63,35 @@ func NewChunker(ctx context.Context, source Source, setters ...Option) chunker.C } // Read implements ChunkReader. +// +//nolint:gocyclo func (c *Stream) Read() <-chan []byte { // Create a buffered channel of length 1. ch := make(chan []byte, 1) go func(ch chan []byte) { defer close(ch) - //nolint:errcheck - defer c.source.Close() + + ctx, cancel := context.WithCancel(c.ctx) + defer cancel() + + go func() { + <-ctx.Done() + c.source.Close() //nolint:errcheck + }() buf := make([]byte, c.options.Size) for { select { - case <-c.ctx.Done(): + case <-ctx.Done(): return default: } n, err := c.source.Read(buf) if err != nil { - if !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe)) { + if !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) || errors.Is(err, os.ErrClosed) || errors.Is(err, circular.ErrClosed)) { fmt.Printf("read error: %s\n", err.Error()) } @@ -93,7 +103,7 @@ func (c *Stream) Read() <-chan []byte { b := xslices.CopyN(buf, n) select { - case <-c.ctx.Done(): + case <-ctx.Done(): return case ch <- b: } diff --git a/pkg/chunker/stream/stream_test.go b/pkg/chunker/stream/stream_test.go index 6a4b20fb96..90d4872ed8 100644 --- a/pkg/chunker/stream/stream_test.go +++ b/pkg/chunker/stream/stream_test.go @@ -123,10 +123,6 @@ func (suite *StreamChunkerSuite) TestStreamingCancel() { ctxCancel() - // need any I/O for chunker to notice that context got canceled - //nolint:errcheck - suite.writer.Write([]byte("")) - suite.Require().Equal([]byte("abcdefghijklmno"), <-combinedCh) }