Skip to content

Commit

Permalink
windows: wait for control actions before returning
Browse files Browse the repository at this point in the history
Also buffer the OS signal so it's not potentially lost during Run.
  • Loading branch information
djdv committed May 12, 2021
1 parent ef35c56 commit a0e7b19
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 71 deletions.
16 changes: 11 additions & 5 deletions service_test.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package service_test package service_test


import ( import (
"fmt"
"os" "os"
"testing" "testing"
"time" "time"
Expand All @@ -22,22 +23,27 @@ func TestRunInterrupt(t *testing.T) {
t.Fatalf("New err: %s", err) t.Fatalf("New err: %s", err)
} }


retChan := make(chan error)
go func() {
if err = s.Run(); err != nil {
retChan <- fmt.Errorf("Run() err: %w", err)
}
}()
go func() { go func() {
<-time.After(1 * time.Second) <-time.After(1 * time.Second)
interruptProcess(t) interruptProcess(t)
}()


go func() {
for i := 0; i < 25 && p.numStopped == 0; i++ { for i := 0; i < 25 && p.numStopped == 0; i++ {
<-time.After(200 * time.Millisecond) <-time.After(200 * time.Millisecond)
} }
if p.numStopped == 0 { if p.numStopped == 0 {
t.Fatal("Run() hasn't been stopped") retChan <- fmt.Errorf("Run() hasn't been stopped")
} }
retChan <- nil
}() }()


if err = s.Run(); err != nil { if err = <-retChan; err != nil {
t.Fatalf("Run() err: %s", err) t.Fatal(err)
} }
} }


Expand Down
255 changes: 189 additions & 66 deletions service_windows.go
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -160,48 +160,64 @@ func (ws *windowsService) getError() error {
return ws.stopStartErr return ws.stopStartErr
} }


func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (bool, uint32) { func (ws *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, changes chan<- svc.Status) (ssec bool, exitCode uint32) {
const cmdsAccepted = svc.AcceptStop | svc.AcceptShutdown const exitFailure = 1

var err error
defer func() {
if err != nil {
ws.setError(err)
exitCode = exitFailure
}
}()

// Signal that we're starting.
changes <- svc.Status{State: svc.StartPending} changes <- svc.Status{State: svc.StartPending}


if err := ws.i.Start(ws); err != nil { // Perform the actual start.
ws.setError(err) if initErr := ws.i.Start(ws); initErr != nil {
return true, 1 err = initErr
return
} }


changes <- svc.Status{State: svc.Running, Accepts: cmdsAccepted} // Signal that we're ready.
changes <- svc.Status{
State: svc.Running,
Accepts: svc.AcceptStop | svc.AcceptShutdown,
}

// Expect service change requests.
var exitFunc func(s Service) error
loop: loop:
for { for c := range r {
c := <-r
switch c.Cmd { switch c.Cmd {
case svc.Interrogate: case svc.Interrogate:
changes <- c.CurrentStatus changes <- c.CurrentStatus
case svc.Stop:
changes <- svc.Status{State: svc.StopPending}
if err := ws.i.Stop(ws); err != nil {
ws.setError(err)
return true, 2
}
break loop
case svc.Shutdown: case svc.Shutdown:
changes <- svc.Status{State: svc.StopPending} if shutdowner, ok := ws.i.(Shutdowner); ok {
var err error exitFunc = shutdowner.Shutdown
if wsShutdown, ok := ws.i.(Shutdowner); ok { break loop
err = wsShutdown.Shutdown(ws)
} else {
err = ws.i.Stop(ws)
}
if err != nil {
ws.setError(err)
return true, 2
} }
fallthrough
case svc.Stop:
exitFunc = ws.i.Stop
break loop break loop
default: default:
continue loop err = fmt.Errorf("unexpected control request: %v", c.Cmd)
break loop
}
}

// We were requested to stop.
changes <- svc.Status{State: svc.StopPending}
if exitErr := exitFunc(ws); exitErr != nil {
if err != nil {
exitErr = fmt.Errorf("%s - %w", err, exitErr)
} }
err = exitErr
} }


return false, 0 return
} }


func (ws *windowsService) Install() error { func (ws *windowsService) Install() error {
Expand Down Expand Up @@ -249,19 +265,55 @@ func (ws *windowsService) Uninstall() error {
return err return err
} }
defer m.Disconnect() defer m.Disconnect()

// MSDN:
// "The DeleteService function marks a service for deletion
// from the service control manager database.
// The database entry is not removed until all open handles
// to the service have been closed by calls to the CloseServiceHandle function,
// and the service is not running."
//
// Since we want to try and wait for the delete to actually happen.
// We close this handle manually when appropriate.
s, err := m.OpenService(ws.Name) s, err := m.OpenService(ws.Name)
if err != nil { if err != nil {
return fmt.Errorf("service %s is not installed", ws.Name) return fmt.Errorf("service %s is not installed", ws.Name)
} }
defer s.Close()
err = s.Delete() if err = s.Delete(); err != nil {
if err != nil { s.Close()
return err return err
} }
err = eventlog.Remove(ws.Name) if err = eventlog.Remove(ws.Name); err != nil {
if err != nil { s.Close()
return fmt.Errorf("RemoveEventLogSource() failed: %s", err) return fmt.Errorf("RemoveEventLogSource() failed: %s", err)
} }

// Service is now marked for deletion by the system.
// Release our handle to it.
if err := s.Close(); err != nil {
return err
}

// Try to get the service handle back.
// If we get an error from the manager,
// we know the service has been deleted.
// Otherwise, we'll block and keep checking
// until the something returns an error, or we give up.
// Since the service is already marked for deletion,
// we don't consider the unblocking condition to be an error.
// But the service will still exist in the service manager's scope.
// And the caller of Uninstall will be on their own from there.
for attempts := 10; attempts != 0; attempts-- {
s, err := m.OpenService(ws.Name)
if err != nil {
break // expected
}
if err := s.Close(); err != nil {
return err
}
time.Sleep(100 * time.Millisecond)
}
return nil return nil
} }


Expand All @@ -287,7 +339,7 @@ func (ws *windowsService) Run() error {
return err return err
} }


sigChan := make(chan os.Signal) sigChan := make(chan os.Signal, 1)


signal.Notify(sigChan, os.Interrupt) signal.Notify(sigChan, os.Interrupt)


Expand Down Expand Up @@ -349,26 +401,20 @@ func (ws *windowsService) Start() error {
return err return err
} }
defer s.Close() defer s.Close()
return s.Start()
}


func (ws *windowsService) Stop() error { if err = maybeWaitForPending(s); err != nil {
m, err := mgr.Connect()
if err != nil {
return err return err
} }
defer m.Disconnect()


s, err := m.OpenService(ws.Name) initErr := s.Start()
if err != nil { if initErr != nil {
return err return initErr
} }
defer s.Close()


return ws.stopWait(s) return maybeWaitForPending(s)
} }


func (ws *windowsService) Restart() error { func (ws *windowsService) Stop() error {
m, err := mgr.Connect() m, err := mgr.Connect()
if err != nil { if err != nil {
return err return err
Expand All @@ -381,42 +427,119 @@ func (ws *windowsService) Restart() error {
} }
defer s.Close() defer s.Close()


err = ws.stopWait(s) if err = maybeWaitForPending(s); err != nil {
if err != nil {
return err return err
} }


return s.Start() if _, err = s.Control(svc.Stop); err != nil {
return err
}

return maybeWaitForPending(s)
} }


func (ws *windowsService) stopWait(s *mgr.Service) error { func (ws *windowsService) Restart() error {
// First stop the service. Then wait for the service to if err := ws.Stop(); err != nil {
// actually stop before starting it.
status, err := s.Control(svc.Stop)
if err != nil {
return err return err
} }
return ws.Start()
}

// statusInterval retreives a (bounded) duration from the status,
// or provides a default.
func statusInterval(status svc.Status) time.Duration {
// MSDN:
// "Do not wait longer than the wait hint. A good interval is
// one-tenth of the wait hint but not less than 1 second
// and not more than 10 seconds."
const (
lower = time.Second
upper = time.Second * 10
)

waitDuration := (time.Duration(status.WaitHint) * time.Millisecond) / 10
if waitDuration < lower {
waitDuration = lower
} else if waitDuration > upper {
waitDuration = upper
}
return waitDuration
}

// waitForStateChange polls the service until its state matches the desiredState,
// and error is encountered, or we timeout.
func waitForStateChange(s *mgr.Service, currentStatus svc.Status, desiredState svc.State) error {
const defaultAttempts = 10
var (
initialInterval = statusInterval(currentStatus)
queryTicker = time.NewTicker(initialInterval)
queryTimer *time.Timer
)
// If the service is providing hints,
// use them, otherwise use a default timeout.
if currentStatus.CheckPoint != 0 {
queryTimer = time.NewTimer(initialInterval)
} else {
queryTimer = time.NewTimer(initialInterval * defaultAttempts)
}
defer func() {
queryTicker.Stop()
queryTimer.Stop()
}()

var (
currentState = currentStatus.State
lastCheckpoint uint32
)
for currentState != desiredState {
select {
case <-queryTicker.C:
currentStatus, queryErr := s.Query()
if queryErr != nil {
return queryErr
}


timeDuration := time.Millisecond * 50 currentState = currentStatus.State

if currentState == desiredState {
timeout := time.After(getStopTimeout() + (timeDuration * 2)) return nil
tick := time.NewTicker(timeDuration) }
defer tick.Stop()


for status.State != svc.Stopped { if currentStatus.CheckPoint > lastCheckpoint {
select { // Service progressed,
case <-tick.C: // give it more time to complete.
status, err = s.Query() if !queryTimer.Stop() {
if err != nil { <-queryTimer.C
return err }
queryTimer.Reset(statusInterval(currentStatus))
} }
case <-timeout: lastCheckpoint = currentStatus.CheckPoint
break case <-queryTimer.C:
return fmt.Errorf("service did not enter desired state (%v) before we timed out",
desiredState)
} }
} }
return nil return nil
} }


func maybeWaitForPending(s *mgr.Service) error {
status, err := s.Query()
if err != nil {
return err
}

var wantState svc.State
switch status.State {
case svc.StartPending:
wantState = svc.Running
case svc.StopPending:
wantState = svc.Stopped
default:
return nil
}

return waitForStateChange(s, status, wantState)
}

// getStopTimeout fetches the time before windows will kill the service. // getStopTimeout fetches the time before windows will kill the service.
func getStopTimeout() time.Duration { func getStopTimeout() time.Duration {
// For default and paths see https://support.microsoft.com/en-us/kb/146092 // For default and paths see https://support.microsoft.com/en-us/kb/146092
Expand Down

0 comments on commit a0e7b19

Please sign in to comment.