diff --git a/examples/README.md b/examples/README.md index 79b48b3..ce89ae6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -5,6 +5,7 @@ We recommend you follow the examples in the following order: ## Basics 1. [Simple](./simple) +1. [Graceful Shutdown](./graceful-shutdown) 1. [Server banner and middleware](./banner) 1. [Identifying Users](./identity) 1. [Multiple authentication types](./multi-auth) diff --git a/examples/graceful-shutdown/main.go b/examples/graceful-shutdown/main.go new file mode 100644 index 0000000..bf24120 --- /dev/null +++ b/examples/graceful-shutdown/main.go @@ -0,0 +1,73 @@ +package main + +import ( + "context" + "errors" + "net" + "os" + "os/signal" + "syscall" + "time" + + "github.com/charmbracelet/log" + "github.com/charmbracelet/ssh" + "github.com/charmbracelet/wish" + "github.com/charmbracelet/wish/logging" +) + +const ( + host = "localhost" + port = "23234" +) + +func main() { + srv, err := wish.NewServer( + wish.WithAddress(net.JoinHostPort(host, port)), + wish.WithHostKeyPath(".ssh/id_ed25519"), + wish.WithMiddleware( + func(next ssh.Handler) ssh.Handler { + return func(sess ssh.Session) { + wish.Println(sess, "Hello, world!") + next(sess) + } + }, + logging.Middleware(), + ), + ) + if err != nil { + log.Error("Could not start server", "error", err) + } + + // Before starting our server, we create a channel and listen for some + // common interrupt signals. + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + // We then start the server in a goroutine, as we'll listen for the done + // signal later. + go func() { + log.Info("Starting SSH server", "host", host, "port", port) + if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + // We ignore ErrServerClosed because it is expected. + log.Error("Could not start server", "error", err) + done <- nil + } + }() + + // Here we wait for the done signal: this can be either an interrupt, or + // the server shutting down for any other reason. + <-done + + // When it arrives, we create a context with a timeout. + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer func() { cancel() }() + + // When we start the shutdown, the server will no longer accept new + // connections, but will wait as much as the given context allows for the + // active connections to finish. + // After the timeout, it shuts down anyway. + log.Info("Stopping SSH server") + if err := srv.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + log.Error("Could not stop server", "error", err) + } +} diff --git a/examples/simple/main.go b/examples/simple/main.go index 82dfbfc..2266f0c 100644 --- a/examples/simple/main.go +++ b/examples/simple/main.go @@ -1,13 +1,8 @@ package main import ( - "context" "errors" "net" - "os" - "os/signal" - "syscall" - "time" "github.com/charmbracelet/log" "github.com/charmbracelet/ssh" @@ -48,32 +43,9 @@ func main() { log.Error("Could not start server", "error", err) } - // Before starting our server, we create a channel and listen for some - // common interrupt signals. - done := make(chan os.Signal, 1) - signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) - - go func() { - log.Info("Starting SSH server", "host", host, "port", port) - if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { - // We ignore ErrServerClosed because it is expected. - log.Error("Could not start server", "error", err) - done <- nil - } - }() - - // Here we wait for the done signal. - // When it arrives, we create a context and start the shutdown. - <-done - - // When we start the shutdown, the server will no longer accept new - // connections, but will wait as much as the given context allows for the - // active connections to finish. - // After the timeout, it shuts down anyway. - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer func() { cancel() }() - log.Info("Stopping SSH server") - if err := srv.Shutdown(ctx); err != nil && !errors.Is(err, ssh.ErrServerClosed) { - log.Error("Could not stop server", "error", err) + log.Info("Starting SSH server", "host", host, "port", port) + if err = srv.ListenAndServe(); err != nil && !errors.Is(err, ssh.ErrServerClosed) { + // We ignore ErrServerClosed because it is expected. + log.Error("Could not start server", "error", err) } }