Skip to content

Commit

Permalink
Add Server.Shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
grahamtonysmith authored Mar 21, 2023
1 parent 90d596c commit e7e393a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
52 changes: 50 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package smtp

import (
"context"
"crypto/tls"
"errors"
"io"
Expand All @@ -13,7 +14,10 @@ import (
"github.com/emersion/go-sasl"
)

var errTCPAndLMTP = errors.New("smtp: cannot start LMTP server listening on a TCP socket")
var (
errTCPAndLMTP = errors.New("smtp: cannot start LMTP server listening on a TCP socket")
ErrServerClosed = errors.New("smtp: server already closed")
)

// A function that creates SASL servers.
type SaslServerFactory func(conn *Conn) sasl.Server
Expand Down Expand Up @@ -64,6 +68,8 @@ type Server struct {
// The server backend.
Backend Backend

wg sync.WaitGroup

caps []string
auths map[string]SaslServerFactory
done chan struct{}
Expand Down Expand Up @@ -135,7 +141,11 @@ func (s *Server) Serve(l net.Listener) error {
}
return err
}

s.wg.Add(1)
go func() {
defer s.wg.Done()

err := s.handleConn(newConn(c, s))
if err != nil {
s.ErrorLog.Printf("handler error: %s", err)
Expand Down Expand Up @@ -253,7 +263,7 @@ func (s *Server) ListenAndServeTLS() error {
func (s *Server) Close() error {
select {
case <-s.done:
return errors.New("smtp: server already closed")
return ErrServerClosed
default:
close(s.done)
}
Expand All @@ -274,6 +284,44 @@ func (s *Server) Close() error {
return err
}

// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
// listeners and then waiting indefinitely for connections to return to
// idle and then shut down.
// If the provided context expires before the shutdown is complete,
// Shutdown returns the context's error, otherwise it returns any
// error returned from closing the Server's underlying Listener(s).
func (s *Server) Shutdown(ctx context.Context) error {
select {
case <-s.done:
return ErrServerClosed
default:
close(s.done)
}

var err error
s.locker.Lock()
for _, l := range s.listeners {
if lerr := l.Close(); lerr != nil && err == nil {
err = lerr
}
}
s.locker.Unlock()

connDone := make(chan struct{})
go func() {
defer close(connDone)
s.wg.Wait()
}()

select {
case <-ctx.Done():
return ctx.Err()
case <-connDone:
return err
}
}

// EnableAuth enables an authentication mechanism on this server.
//
// This function should not be called directly, it must only be used by
Expand Down
31 changes: 31 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package smtp_test
import (
"bufio"
"bytes"
"context"
"errors"
"io"
"io/ioutil"
Expand Down Expand Up @@ -1234,3 +1235,33 @@ func TestServer_TooLongCommand(t *testing.T) {
t.Fatal("Invalid too long MAIL response:", scanner.Text())
}
}

func TestServerShutdown(t *testing.T) {
_, s, c, _ := testServerGreeted(t)

ctx := context.Background()
errChan := make(chan error)
go func() {
defer close(errChan)

errChan <- s.Shutdown(ctx)
errChan <- s.Shutdown(ctx)
}()

select {
case err := <-errChan:
t.Fatal("Expected no err because conn is open:", err)
default:
c.Close()
}

errOne := <-errChan
if errOne != nil {
t.Fatal("Expected err to be nil:", errOne)
}

errTwo := <-errChan
if errTwo != smtp.ErrServerClosed {
t.Fatal("Expected err to be ErrServerClosed:", errTwo)
}
}

0 comments on commit e7e393a

Please sign in to comment.