diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ffbc5871..7463f9f7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -9,6 +9,7 @@ jobs: fail-fast: false matrix: go-version: + - 1.18.x - 1.17.x - 1.16.x - 1.15.x diff --git a/mage/main.go b/mage/main.go index 65964228..0a50da6d 100644 --- a/mage/main.go +++ b/mage/main.go @@ -12,11 +12,13 @@ import ( "log" "os" "os/exec" + "os/signal" "path/filepath" "regexp" "runtime" "sort" "strings" + "syscall" "text/template" "time" @@ -737,6 +739,10 @@ func RunCompiled(inv Invocation, exePath string, errlog *log.Logger) int { c.Env = append(c.Env, fmt.Sprintf("MAGEFILE_TIMEOUT=%s", inv.Timeout.String())) } debug.Print("running magefile with mage vars:\n", strings.Join(filter(c.Env, "MAGEFILE"), "\n")) + // catch SIGINT to allow magefile to handle them + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT) + defer signal.Stop(sigCh) err := c.Run() if !sh.CmdRan(err) { errlog.Printf("failed to run compiled magefile: %v", err) diff --git a/mage/main_test.go b/mage/main_test.go index 2074c64a..4ea4d0b1 100644 --- a/mage/main_test.go +++ b/mage/main_test.go @@ -22,6 +22,7 @@ import ( "runtime" "strconv" "strings" + "syscall" "testing" "time" @@ -1292,7 +1293,7 @@ func TestCompiledFlags(t *testing.T) { if err == nil { t.Fatalf("expected an error because of timeout") } - got = stdout.String() + got = stderr.String() want = "context deadline exceeded" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) @@ -1384,7 +1385,7 @@ func TestCompiledEnvironmentVars(t *testing.T) { if err == nil { t.Fatalf("expected an error because of timeout") } - got = stdout.String() + got = stderr.String() want = "context deadline exceeded" if strings.Contains(got, want) == false { t.Errorf("got %q, does not contain %q", got, want) @@ -1457,6 +1458,111 @@ func TestCompiledVerboseFlag(t *testing.T) { } } +func TestSignals(t *testing.T) { + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + dir := "./testdata/signals" + compileDir, err := ioutil.TempDir(dir, "") + if err != nil { + t.Fatal(err) + } + name := filepath.Join(compileDir, "mage_out") + // The CompileOut directory is relative to the + // invocation directory, so chop off the invocation dir. + outName := "./" + name[len(dir)-1:] + defer os.RemoveAll(compileDir) + inv := Invocation{ + Dir: dir, + Stdout: stdout, + Stderr: stderr, + CompileOut: outName, + } + code := Invoke(inv) + if code != 0 { + t.Errorf("expected to exit with code 0, but got %v, stderr: %s", code, stderr) + } + + run := func(stdout, stderr *bytes.Buffer, filename string, target string, signals ...syscall.Signal) error { + stderr.Reset() + stdout.Reset() + cmd := exec.Command(filename, target) + cmd.Stderr = stderr + cmd.Stdout = stdout + if err := cmd.Start(); err != nil { + return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s", + filename, target, err, stdout, stderr) + } + pid := cmd.Process.Pid + go func() { + time.Sleep(time.Millisecond * 500) + for _, s := range signals { + syscall.Kill(pid, s) + time.Sleep(time.Millisecond * 50) + } + }() + if err := cmd.Wait(); err != nil { + return fmt.Errorf("running '%s %s' failed with: %v\nstdout: %s\nstderr: %s", + filename, target, err, stdout, stderr) + } + return nil + } + + if err := run(stdout, stderr, name, "exitsAfterSighup", syscall.SIGHUP); err != nil { + t.Fatal(err) + } + got := stdout.String() + want := "received sighup\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "exitsAfterSigint", syscall.SIGINT); err != nil { + t.Fatal(err) + } + got = stdout.String() + want = "exiting...done\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + got = stderr.String() + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "exitsAfterCancel", syscall.SIGINT); err != nil { + t.Fatal(err) + } + got = stdout.String() + want = "exiting...done\ndeferred cleanup\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + got = stderr.String() + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT, syscall.SIGINT); err == nil { + t.Fatalf("expected an error because of force kill") + } + got = stderr.String() + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nexiting mage\nError: exit forced\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } + + if err := run(stdout, stderr, name, "ignoresSignals", syscall.SIGINT); err == nil { + t.Fatalf("expected an error because of force kill") + } + got = stderr.String() + want = "cancelling mage targets, waiting up to 5 seconds for cleanup...\nError: cleanup timeout exceeded\n" + if strings.Contains(got, want) == false { + t.Errorf("got %q, does not contain %q", got, want) + } +} + func TestCompiledDeterministic(t *testing.T) { dir := "./testdata/compiled" compileDir, err := ioutil.TempDir(dir, "") diff --git a/mage/template.go b/mage/template.go index 4f7125eb..af822f0f 100644 --- a/mage/template.go +++ b/mage/template.go @@ -14,10 +14,12 @@ import ( _ioutil "io/ioutil" _log "log" "os" + "os/signal" _filepath "path/filepath" _sort "sort" "strconv" _strings "strings" + "syscall" _tabwriter "text/tabwriter" "time" {{range .Imports}}{{.UniqueName}} "{{.Path}}" @@ -256,23 +258,27 @@ Options: } var ctx context.Context - var ctxCancel func() + ctxCancel := func(){} + + // by deferring in a closure, we let the cancel function get replaced + // by the getContext function. + defer func() { + ctxCancel() + }() getContext := func() (context.Context, func()) { - if ctx != nil { - return ctx, ctxCancel + if ctx == nil { + if args.Timeout != 0 { + ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout) + } else { + ctx, ctxCancel = context.WithCancel(context.Background()) + } } - if args.Timeout != 0 { - ctx, ctxCancel = context.WithTimeout(context.Background(), args.Timeout) - } else { - ctx = context.Background() - ctxCancel = func() {} - } return ctx, ctxCancel } - runTarget := func(fn func(context.Context) error) interface{} { + runTarget := func(logger *_log.Logger, fn func(context.Context) error) interface{} { var err interface{} ctx, cancel := getContext() d := make(chan interface{}) @@ -284,14 +290,34 @@ Options: err := fn(ctx) d <- err }() + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT) select { + case <-sigCh: + logger.Println("cancelling mage targets, waiting up to 5 seconds for cleanup...") + cancel() + cleanupCh := time.After(5 * time.Second) + + select { + // target exited by itself + case err = <-d: + return err + // cleanup timeout exceeded + case <-cleanupCh: + return _fmt.Errorf("cleanup timeout exceeded") + // second SIGINT received + case <-sigCh: + logger.Println("exiting mage") + return _fmt.Errorf("exit forced") + } case <-ctx.Done(): cancel() e := ctx.Err() _fmt.Printf("ctx err: %v\n", e) return e case err = <-d: - cancel() + // we intentionally don't cancel the context here, because + // the next target will need to run with the same context. return err } } diff --git a/mage/testdata/signals/signals.go b/mage/testdata/signals/signals.go new file mode 100644 index 00000000..96d62552 --- /dev/null +++ b/mage/testdata/signals/signals.go @@ -0,0 +1,50 @@ +//+build mage + +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" +) + +// Exits after receiving SIGHUP +func ExitsAfterSighup(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGHUP) + <-sigC + fmt.Println("received sighup") +} + +// Exits after SIGINT and wait +func ExitsAfterSigint(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGINT) + <-sigC + fmt.Printf("exiting...") + time.Sleep(200 * time.Millisecond) + fmt.Println("done") +} + +// Exits after ctx cancel and wait +func ExitsAfterCancel(ctx context.Context) { + defer func() { + fmt.Println("deferred cleanup") + }() + <-ctx.Done() + fmt.Printf("exiting...") + time.Sleep(200 * time.Millisecond) + fmt.Println("done") +} + +// Ignores all signals, requires killing via timeout or second SIGINT +func IgnoresSignals(ctx context.Context) { + sigC := make(chan os.Signal, 1) + signal.Notify(sigC, syscall.SIGINT) + for { + <-sigC + } +} diff --git a/parse/parse.go b/parse/parse.go index 14c5bcd9..48cf1ece 100644 --- a/parse/parse.go +++ b/parse/parse.go @@ -169,7 +169,7 @@ func (f Function) ExecCode() string { } out += ` } - ret := runTarget(wrapFn)` + ret := runTarget(logger, wrapFn)` return out }