diff --git a/internal/app/machined/internal/server/v1alpha1/v1alpha1_server.go b/internal/app/machined/internal/server/v1alpha1/v1alpha1_server.go index 82a9ac635b..da1e366ca3 100644 --- a/internal/app/machined/internal/server/v1alpha1/v1alpha1_server.go +++ b/internal/app/machined/internal/server/v1alpha1/v1alpha1_server.go @@ -450,12 +450,10 @@ func (s *Server) Shutdown(ctx context.Context, in *machine.ShutdownRequest) (rep // Upgrade initiates an upgrade. // -//nolint:gocyclo,cyclop +//nolint:gocyclo func (s *Server) Upgrade(ctx context.Context, in *machine.UpgradeRequest) (*machine.UpgradeResponse, error) { actorID := uuid.New().String() - var mu *concurrency.Mutex - ctx = context.WithValue(ctx, runtime.ActorIDCtxKey{}, actorID) if err := s.checkSupported(runtime.Upgrade); err != nil { @@ -471,23 +469,21 @@ func (s *Server) Upgrade(ctx context.Context, in *machine.UpgradeRequest) (*mach } if s.Controller.Runtime().Config().Machine().Type() != machinetype.TypeWorker && !in.GetForce() { - client, err := etcd.NewClientFromControlPlaneIPs(ctx, s.Controller.Runtime().State().V1Alpha2().Resources()) + etcdClient, err := etcd.NewClientFromControlPlaneIPs(ctx, s.Controller.Runtime().State().V1Alpha2().Resources()) if err != nil { return nil, fmt.Errorf("failed to create etcd client: %w", err) } // acquire the upgrade mutex - if mu, err = upgradeMutex(client); err != nil { + unlocker, err := tryLockUpgradeMutex(ctx, etcdClient) + if err != nil { return nil, fmt.Errorf("failed to acquire upgrade mutex: %w", err) } - if err = mu.TryLock(ctx); err != nil { - return nil, fmt.Errorf("failed to acquire upgrade lock: %w", err) - } - - if err = client.ValidateForUpgrade(ctx, s.Controller.Runtime().Config(), in.GetPreserve()); err != nil { - mu.Unlock(ctx) //nolint:errcheck + // unlock the mutex once the API call is done, as it protects only pre-upgrade checks + defer unlocker() + if err = etcdClient.ValidateForUpgrade(ctx, s.Controller.Runtime().Config(), in.GetPreserve()); err != nil { return nil, fmt.Errorf("error validating etcd for upgrade: %w", err) } } @@ -520,10 +516,6 @@ func (s *Server) Upgrade(ctx context.Context, in *machine.UpgradeRequest) (*mach } go func() { - if mu != nil { - defer mu.Unlock(ctx) //nolint:errcheck - } - if err := s.Controller.Run(runCtx, runtime.SequenceStageUpgrade, in); err != nil { if !runtime.IsRebootError(err) { log.Println("reboot for staged upgrade failed:", err) @@ -532,10 +524,6 @@ func (s *Server) Upgrade(ctx context.Context, in *machine.UpgradeRequest) (*mach }() } else { go func() { - if mu != nil { - defer mu.Unlock(ctx) //nolint:errcheck - } - if err := s.Controller.Run(runCtx, runtime.SequenceUpgrade, in); err != nil { if !runtime.IsRebootError(err) { log.Println("upgrade failed:", err) @@ -2268,17 +2256,37 @@ func capturePackets(ctx context.Context, w io.Writer, handle *afpacket.TPacket, } } -func upgradeMutex(c *etcd.Client) (*concurrency.Mutex, error) { - sess, err := concurrency.NewSession(c.Client, +func tryLockUpgradeMutex(ctx context.Context, etcdClient *etcd.Client) (unlock func(), err error) { + sess, err := concurrency.NewSession(etcdClient.Client, + concurrency.WithContext(ctx), concurrency.WithTTL(MinimumEtcdUpgradeLeaseLockSeconds), ) if err != nil { - return nil, err + return nil, fmt.Errorf("error establishing etcd concurrency session: %w", err) } mu := concurrency.NewMutex(sess, constants.EtcdTalosEtcdUpgradeMutex) - return mu, nil + if err = mu.TryLock(ctx); err != nil { + return nil, fmt.Errorf("error trying to lock etcd upgrade mutex: %w", err) + } + + log.Printf("etcd upgrade mutex locked with session ID %08x", sess.Lease()) + + return func() { + unlockCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := mu.Unlock(unlockCtx); err != nil { + log.Printf("error unlocking etcd upgrade mutex: %v", err) + } + + if err := sess.Close(); err != nil { + log.Printf("error closing etcd upgrade mutex session: %v", err) + } + + log.Printf("etcd upgrade mutex unlocked and session closed") + }, nil } // Netstat implements the machine.MachineServer interface.