From 0aa93f87d62855ad3245f5f8a217399f1e450ed4 Mon Sep 17 00:00:00 2001 From: Peter McAtominey Date: Fri, 31 Jul 2020 23:10:53 +0100 Subject: [PATCH 1/4] mage: cancel context on SIGINT On receiving an interrupt signal, mage cancels the context allowing the magefile to perform any cleanup before exiting. A second interrupt signal will kill the magefile process without delay. The behaviour for a timeout remains unchanged (context is canclled and the magefile exits). --- mage/main.go | 6 +++ mage/main_test.go | 91 +++++++++++++++++++++++++++++++- mage/template.go | 46 ++++++++++------ mage/testdata/signals/signals.go | 47 +++++++++++++++++ 4 files changed, 172 insertions(+), 18 deletions(-) create mode 100644 mage/testdata/signals/signals.go diff --git a/mage/main.go b/mage/main.go index cccb0870..dd1bfc8b 100644 --- a/mage/main.go +++ b/mage/main.go @@ -11,11 +11,13 @@ import ( "log" "os" "os/exec" + "os/signal" "path/filepath" "regexp" "runtime" "sort" "strings" + "syscall" "text/template" "time" @@ -650,6 +652,10 @@ func RunCompiled(inv Invocation, exePath string, errlog *log.Logger) int { c.Env = append(c.Env, fmt.Sprintf("MAGEFILE_TIMEOUT=%s", inv.Timeout.String())) } debug.Print("running magefile with mage vars:\n", strings.Join(filter(c.Env, "MAGEFILE"), "\n")) + // catch SIGINT to allow magefile to handle them + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT) + defer signal.Stop(sigCh) err := c.Run() if !sh.CmdRan(err) { errlog.Printf("failed to run compiled magefile: %v", err) diff --git a/mage/main_test.go b/mage/main_test.go index 9fd5dbb4..a02fa963 100644 --- a/mage/main_test.go +++ b/mage/main_test.go @@ -20,6 +20,7 @@ import ( "runtime" "strconv" "strings" + "syscall" "testing" "time" @@ -1146,7 +1147,7 @@ func TestCompiledFlags(t *testing.T) { if err == nil { t.Fatalf("expected an error because of timeout") } - got = stdout.String() + got = stderr.String() want = "context deadline exceeded" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) @@ -1235,7 +1236,7 @@ func TestCompiledEnvironmentVars(t *testing.T) { if err == nil { t.Fatalf("expected an error because of timeout") } - got = stdout.String() + got = stderr.String() want = "context deadline exceeded" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) @@ -1305,6 +1306,92 @@ func TestCompiledVerboseFlag(t *testing.T) { } } +func TestSignals(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + dir := "./testdata/signals" + compileDir, err := ioutil.TempDir(dir, "") + if err != nil { + t.Fatal(err) + } + name := filepath.Join(compileDir, "mage_out") + // The CompileOut directory is relative to the + // invocation directory, so chop off the invocation dir. + outName := "./" + name[len(dir)-1:] + defer os.RemoveAll(compileDir) + inv := Invocation{ + Dir: dir, + Stdout: stdout, + Stderr: stderr, + CompileOut: outName, + } + code := Invoke(inv) + if code != 0 { + t.Errorf("expected to exit with code 0, but got %v, stderr: %s", code, stderr) + } + + run := func(stdout, stderr *bytes.Buffer, filename string, target string, signals ...syscall.Signal) error { + stderr.Reset() + stdout.Reset() + cmd := exec.Command(filename, target) + cmd.Stderr = stderr + cmd.Stdout = stdout + if err := cmd.Start(); err != nil { + return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s", + filename, target, err, stdout, stderr) + } + pid := cmd.Process.Pid + go func() { + time.Sleep(time.Millisecond * 500) + for _, s := range signals { + syscall.Kill(pid, s) + time.Sleep(time.Millisecond * 50) + } + }() + if err := cmd.Wait(); err != nil { + return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s", + filename, target, err, stdout, stderr) + } + return nil + } + + if err := run(stdout, stderr, name, "exitsAfterSighup", syscall.SIGHUP); err != nil { + t.Fatal(err) + } + got := stdout.String() + want := "received sighup\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "exitsAfterSigint", syscall.SIGINT); err != nil { + t.Fatal(err) + } + got = stdout.String() + want = "exiting...done\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "exitsAfterCancel", syscall.SIGINT); err != nil { + t.Fatal(err) + } + got = stdout.String() + want = "exiting...done\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT, syscall.SIGINT); err == nil { + t.Fatalf("expected an error because of force kill") + } + got = stderr.String() + want = "Error: target killed\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } +} + func TestClean(t *testing.T) { if err := os.RemoveAll(mg.CacheDir()); err != nil { t.Error("error removing cache dir:", err) diff --git a/mage/template.go b/mage/template.go index fbe69719..af7b7007 100644 --- a/mage/template.go +++ b/mage/template.go @@ -14,10 +14,12 @@ import ( "io/ioutil" "log" "os" + "os/signal" "path/filepath" "sort" "strconv" "strings" + "syscall" "text/tabwriter" "time" {{range .Imports}}{{.UniqueName}} "{{.Path}}" @@ -260,17 +262,19 @@ Options: var ctxCancel func() getContext := func() (context.Context, func()) { - if ctx != nil { - return ctx, ctxCancel + if ctx == nil { + ctx, ctxCancel = context.WithCancel(context.Background()) } + return ctx, ctxCancel + } + + getTimeout := func() <-chan time.Time { if args.Timeout != 0 { - ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout) - } else { - ctx = context.Background() - ctxCancel = func() {} + return time.After(args.Timeout) } - return ctx, ctxCancel + + return make(chan time.Time) } runTarget := func(fn func(context.Context) error) interface{} { @@ -285,15 +289,25 @@ Options: err := fn(ctx) d <- err }() - select { - case <-ctx.Done(): - cancel() - e := ctx.Err() - fmt.Printf("ctx err: %v\n", e) - return e - case err = <-d: - cancel() - return err + timeoutCh := getTimeout() + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT) + for { + select { + case <-sigCh: + select { + case <-ctx.Done(): + return fmt.Errorf("target killed") + default: + cancel() + } + case <-timeoutCh: + cancel() + return fmt.Errorf("context deadline exceeded") + case err = <-d: + cancel() + return err + } } } // This is necessary in case there aren't any targets, to avoid an unused diff --git a/mage/testdata/signals/signals.go b/mage/testdata/signals/signals.go new file mode 100644 index 00000000..4c58116b --- /dev/null +++ b/mage/testdata/signals/signals.go @@ -0,0 +1,47 @@ +//+build mage + +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" +) + +// Exits after receiving SIGHUP +func ExitsAfterSighup(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGHUP) + <-sigC + fmt.Println("received sighup") +} + +// Exits after SIGINT and wait +func ExitsAfterSigint(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGINT) + <-sigC + fmt.Printf("exiting...") + time.Sleep(200 * time.Millisecond) + fmt.Println("done") +} + +// Exits after ctx cancel and wait +func ExitsAfterCancel(ctx context.Context) { + <-ctx.Done() + fmt.Printf("exiting...") + time.Sleep(200 * time.Millisecond) + fmt.Println("done") +} + +// Ignores all signals, requires killing +func IgnoresSignals(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGINT) + for { + <-sigC + } +} From 26b0dea96da76df9f455c7df68961e5e2cd5f03c Mon Sep 17 00:00:00 2001 From: Peter McAtominey Date: Sat, 26 Dec 2020 17:04:55 +0000 Subject: [PATCH 2/4] mage: add cleanup timeout to cancel --- mage/main_test.go | 23 +++++++++++++++++++++-- mage/template.go | 20 +++++++++++++++----- mage/testdata/signals/signals.go | 5 ++++- parse/parse.go | 8 ++++---- 4 files changed, 44 insertions(+), 12 deletions(-) diff --git a/mage/main_test.go b/mage/main_test.go index 692716a0..12cb99ea 100644 --- a/mage/main_test.go +++ b/mage/main_test.go @@ -1387,12 +1387,22 @@ func TestSignals(t *testing.T) { if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) } + got = stderr.String() + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } if err := run(stdout, stderr, name, "exitsAfterCancel", syscall.SIGINT); err != nil { t.Fatal(err) } got = stdout.String() - want = "exiting...done\n" + want = "exiting...done\ndeferred cleanup\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + got = stderr.String() + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) } @@ -1401,7 +1411,16 @@ func TestSignals(t *testing.T) { t.Fatalf("expected an error because of force kill") } got = stderr.String() - want = "Error: target killed\n" + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nexiting mage\nError: exit forced\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT); err == nil { + t.Fatalf("expected an error because of force kill") + } + got = stderr.String() + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nError: cleanup timeout exceeded\n" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) } diff --git a/mage/template.go b/mage/template.go index 2da2c884..09103167 100644 --- a/mage/template.go +++ b/mage/template.go @@ -276,7 +276,7 @@ Options: return make(chan time.Time) } - runTarget := func(fn func(context.Context) error) interface{} { + runTarget := func(logger *log.Logger, fn func(context.Context) error) interface{} { var err interface{} ctx, cancel := getContext() d := make(chan interface{}) @@ -294,11 +294,21 @@ Options: for { select { case <-sigCh: + logger.Println("cancelling mage targets, waiting up to 5 seconds for cleanup...") + cancel() + cleanupCh := time.After(5 * time.Second) + select { - case <-ctx.Done(): - return fmt.Errorf("target killed") - default: - cancel() + // target exited by itself + case err = <-d: + return err + // cleanup timeout exceeded + case <-cleanupCh: + return fmt.Errorf("cleanup timeout exceeded") + // second SIGINT received + case <-sigCh: + logger.Println("exiting mage") + return fmt.Errorf("exit forced") } case <-timeoutCh: cancel() diff --git a/mage/testdata/signals/signals.go b/mage/testdata/signals/signals.go index 4c58116b..96d62552 100644 --- a/mage/testdata/signals/signals.go +++ b/mage/testdata/signals/signals.go @@ -31,13 +31,16 @@ func ExitsAfterSigint(ctx context.Context) { // Exits after ctx cancel and wait func ExitsAfterCancel(ctx context.Context) { + defer func() { + fmt.Println("deferred cleanup") + }() <-ctx.Done() fmt.Printf("exiting...") time.Sleep(200 * time.Millisecond) fmt.Println("done") } -// Ignores all signals, requires killing +// Ignores all signals, requires killing via timeout or second SIGINT func IgnoresSignals(ctx context.Context) { sigC := make(chan os.Signal, 1) signal.Notify(sigC, syscall.SIGINT) diff --git a/parse/parse.go b/parse/parse.go index 1549f9e6..7227137e 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -93,7 +93,7 @@ func (f Function) ExecCode() (string, error) { wrapFn := func(ctx context.Context) error { return %s(ctx) } - err := runTarget(wrapFn)`[1:] + err := runTarget(logger, wrapFn)`[1:] return fmt.Sprintf(out, name), nil } if f.IsContext && !f.IsError { @@ -102,7 +102,7 @@ func (f Function) ExecCode() (string, error) { %s(ctx) return nil } - err := runTarget(wrapFn)`[1:] + err := runTarget(logger, wrapFn)`[1:] return fmt.Sprintf(out, name), nil } if !f.IsContext && f.IsError { @@ -110,7 +110,7 @@ func (f Function) ExecCode() (string, error) { wrapFn := func(ctx context.Context) error { return %s() } - err := runTarget(wrapFn)`[1:] + err := runTarget(logger, wrapFn)`[1:] return fmt.Sprintf(out, name), nil } if !f.IsContext && !f.IsError { @@ -119,7 +119,7 @@ func (f Function) ExecCode() (string, error) { %s() return nil } - err := runTarget(wrapFn)`[1:] + err := runTarget(logger, wrapFn)`[1:] return fmt.Sprintf(out, name), nil } return "", fmt.Errorf("Error formatting ExecCode code for %#v", f) From 1314893cb38e4e6c2b4530c66907cac105346732 Mon Sep 17 00:00:00 2001 From: Nate Finch Date: Mon, 28 Nov 2022 23:05:46 -0500 Subject: [PATCH 3/4] switch back to using a timeout on the context, so the timeout spans calls to different targets --- mage/template.go | 68 +++++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/mage/template.go b/mage/template.go index 09103167..733e60dd 100644 --- a/mage/template.go +++ b/mage/template.go @@ -258,24 +258,26 @@ Options: } var ctx context.Context - var ctxCancel func() + ctxCancel := func(){} + + // by deferring in a closure, we let the cancel function get replaced + // by the getContext function. + defer func() { + ctxCancel() + }() getContext := func() (context.Context, func()) { if ctx == nil { - ctx, ctxCancel = context.WithCancel(context.Background()) + if args.Timeout != 0 { + ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout) + } else { + ctx, ctxCancel = context.WithCancel(context.Background()) + } } return ctx, ctxCancel } - getTimeout := func() <-chan time.Time { - if args.Timeout != 0 { - return time.After(args.Timeout) - } - - return make(chan time.Time) - } - runTarget := func(logger *log.Logger, fn func(context.Context) error) interface{} { var err interface{} ctx, cancel := getContext() @@ -288,35 +290,35 @@ Options: err := fn(ctx) d <- err }() - timeoutCh := getTimeout() sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT) - for { + select { + case <-sigCh: + logger.Println("cancelling mage targets, waiting up to 5 seconds for cleanup...") + cancel() + cleanupCh := time.After(5 * time.Second) + select { - case <-sigCh: - logger.Println("cancelling mage targets, waiting up to 5 seconds for cleanup...") - cancel() - cleanupCh := time.After(5 * time.Second) - - select { - // target exited by itself - case err = <-d: - return err - // cleanup timeout exceeded - case <-cleanupCh: - return fmt.Errorf("cleanup timeout exceeded") - // second SIGINT received - case <-sigCh: - logger.Println("exiting mage") - return fmt.Errorf("exit forced") - } - case <-timeoutCh: - cancel() - return fmt.Errorf("context deadline exceeded") + // target exited by itself case err = <-d: - cancel() return err + // cleanup timeout exceeded + case <-cleanupCh: + return fmt.Errorf("cleanup timeout exceeded") + // second SIGINT received + case <-sigCh: + logger.Println("exiting mage") + return fmt.Errorf("exit forced") } + case <-ctx.Done(): + cancel() + e := ctx.Err() + fmt.Printf("ctx err: %v\n", e) + return e + case err = <-d: + // we intentionally don't cancel the context here, because + // the next target will need to run with the same context. + return err } } // This is necessary in case there aren't any targets, to avoid an unused From e340775b6ce40238e107137127996f733a82b157 Mon Sep 17 00:00:00 2001 From: Nate Finch Date: Mon, 28 Nov 2022 23:29:11 -0500 Subject: [PATCH 4/4] test on 1.18, too --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ffbc5871..7463f9f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,6 +9,7 @@ jobs: fail-fast: false matrix: go-version: + - 1.18.x - 1.17.x - 1.16.x - 1.15.x