diff --git a/signal.go b/signal.go index a8ba94b..2a83a1d 100644 --- a/signal.go +++ b/signal.go @@ -8,6 +8,7 @@ import ( "runtime" "sync" "sync/atomic" + "syscall" "time" ) @@ -135,3 +136,32 @@ func (s *signalCtx) cancel(err error) { close(s.done) }) } + +// IsSignal returns true if the given error is a *SignalError that was +// generated upon receipt of one of the given signals. If no signal is +// passed, the function only tests for err to be of type *SginalError. +func IsSignal(err error, signals ...os.Signal) bool { + if e, ok := err.(*SignalError); ok { + if len(signals) == 0 { + return true + } + for _, signal := range signals { + if signal == e.Signal { + return true + } + } + } + return false +} + +// IsTermination returns true if the given error was caused by receiving a +// termination signal. +func IsTermination(err error) bool { + return IsSignal(err, syscall.SIGTERM) +} + +// IsInterruption returns true if the given error was caused by receiving an +// interruption signal. +func IsInterruption(err error) bool { + return IsSignal(err, syscall.SIGINT) +} diff --git a/signal_test.go b/signal_test.go index 478529c..69a6792 100644 --- a/signal_test.go +++ b/signal_test.go @@ -4,6 +4,7 @@ import ( "context" "os" "reflect" + "syscall" "testing" "time" ) @@ -117,3 +118,21 @@ func TestWithSignals(t *testing.T) { } }) } + +func TestIsTermination(t *testing.T) { + if !IsTermination(&SignalError{Signal: syscall.SIGTERM}) { + t.Error("SIGTERM wasn't recognized as a termination error") + } + if IsTermination(&SignalError{Signal: syscall.SIGINT}) { + t.Error("SIGINT was mistakenly recognized as a termination error") + } +} + +func TestIsInterruption(t *testing.T) { + if !IsInterruption(&SignalError{Signal: syscall.SIGINT}) { + t.Error("SIGINT wasn't recognized as a interruption error") + } + if IsInterruption(&SignalError{Signal: syscall.SIGTERM}) { + t.Error("SIGTERM was mistakenly recognized as a interruption error") + } +}