Skip to content

Commit

Permalink
Merge pull request #307 from wneessen/feature/269_goroutineconcurrenc…
Browse files Browse the repository at this point in the history
…y-safety

go-mail goroutine-/thread-safety
  • Loading branch information
wneessen authored Sep 27, 2024
2 parents 077c85b + c1f6ef0 commit 65a91a2
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 101 deletions.
69 changes: 41 additions & 28 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net"
"os"
"strings"
"sync"
"time"

"github.com/wneessen/go-mail/log"
Expand Down Expand Up @@ -87,12 +88,12 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con

// Client is the SMTP client struct
type Client struct {
// connection is the net.Conn that the smtp.Client is based on
connection net.Conn

// Timeout for the SMTP server connection
connTimeout time.Duration

// dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc

// dsn indicates that we want to use DSN for the Client
dsn bool

Expand All @@ -102,24 +103,34 @@ type Client struct {
// dsnrntype defines the DSNRcptNotifyOption in case DSN is enabled
dsnrntype []string

// isEncrypted indicates if a Client connection is encrypted or not
isEncrypted bool

// noNoop indicates the Noop is to be skipped
noNoop bool
// fallbackPort is used as an alternative port number in case the primary port is unavailable or
// fails to bind.
fallbackPort int

// HELO/EHLO string for the greeting the target SMTP server
helo string

// Hostname of the target SMTP server to connect to
host string

// isEncrypted indicates if a Client connection is encrypted or not
isEncrypted bool

// logger is a logger that implements the log.Logger interface
logger log.Logger

// mutex is used to synchronize access to shared resources, ensuring that only one goroutine can
// modify them at a time.
mutex sync.RWMutex

// noNoop indicates the Noop is to be skipped
noNoop bool

// pass is the corresponding SMTP AUTH password
pass string

// Port of the SMTP server to connect to
port int
fallbackPort int
// port specifies the network port number on which the server listens for incoming connections.
port int

// smtpAuth is a pointer to smtp.Auth
smtpAuth smtp.Auth
Expand All @@ -130,26 +141,20 @@ type Client struct {
// smtpClient is the smtp.Client that is set up when using the Dial*() methods
smtpClient *smtp.Client

// Use SSL for the connection
useSSL bool

// tlspolicy sets the client to use the provided TLSPolicy for the STARTTLS protocol
tlspolicy TLSPolicy

// tlsconfig represents the tls.Config setting for the STARTTLS connection
tlsconfig *tls.Config

// user is the SMTP AUTH username
user string

// useDebugLog enables the debug logging on the SMTP client
useDebugLog bool

// logger is a logger that implements the log.Logger interface
logger log.Logger
// user is the SMTP AUTH username
user string

// dialContextFunc is a custom DialContext function to dial target SMTP server
dialContextFunc DialContextFunc
// Use SSL for the connection
useSSL bool
}

// Option returns a function that can be used for grouping Client options
Expand Down Expand Up @@ -550,6 +555,9 @@ func (c *Client) SetLogger(logger log.Logger) {

// SetTLSConfig overrides the current *tls.Config with the given *tls.Config value
func (c *Client) SetTLSConfig(tlsconfig *tls.Config) error {
c.mutex.Lock()
defer c.mutex.Unlock()

if tlsconfig == nil {
return ErrInvalidTLSConfig
}
Expand Down Expand Up @@ -589,6 +597,9 @@ func (c *Client) setDefaultHelo() error {

// DialWithContext establishes a connection to the SMTP server with a given context.Context
func (c *Client) DialWithContext(dialCtx context.Context) error {
c.mutex.Lock()
defer c.mutex.Unlock()

ctx, cancel := context.WithDeadline(dialCtx, time.Now().Add(c.connTimeout))
defer cancel()

Expand All @@ -602,17 +613,16 @@ func (c *Client) DialWithContext(dialCtx context.Context) error {
c.dialContextFunc = tlsDialer.DialContext
}
}
var err error
c.connection, err = c.dialContextFunc(ctx, "tcp", c.ServerAddr())
connection, err := c.dialContextFunc(ctx, "tcp", c.ServerAddr())
if err != nil && c.fallbackPort != 0 {
// TODO: should we somehow log or append the previous error?
c.connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
connection, err = c.dialContextFunc(ctx, "tcp", c.serverFallbackAddr())
}
if err != nil {
return err
}

client, err := smtp.NewClient(c.connection, c.host)
client, err := smtp.NewClient(connection, c.host)
if err != nil {
return err
}
Expand Down Expand Up @@ -691,7 +701,7 @@ func (c *Client) DialAndSendWithContext(ctx context.Context, messages ...*Msg) e
// checkConn makes sure that a required server connection is available and extends the
// connection deadline
func (c *Client) checkConn() error {
if c.connection == nil {
if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection
}

Expand All @@ -701,7 +711,7 @@ func (c *Client) checkConn() error {
}
}

if err := c.connection.SetDeadline(time.Now().Add(c.connTimeout)); err != nil {
if err := c.smtpClient.UpdateDeadline(c.connTimeout); err != nil {
return ErrDeadlineExtendFailed
}
return nil
Expand All @@ -715,7 +725,7 @@ func (c *Client) serverFallbackAddr() string {

// tls tries to make sure that the STARTTLS requirements are satisfied
func (c *Client) tls() error {
if c.connection == nil {
if !c.smtpClient.HasConnection() {
return ErrNoActiveConnection
}
if !c.useSSL && c.tlspolicy != NoTLS {
Expand Down Expand Up @@ -791,6 +801,9 @@ func (c *Client) auth() error {
// sendSingleMsg sends out a single message and returns an error if the transmission/delivery fails.
// It is invoked by the public Send methods
func (c *Client) sendSingleMsg(message *Msg) error {
c.mutex.Lock()
defer c.mutex.Unlock()

if message.encoding == NoEncoding {
if ok, _ := c.smtpClient.Extension("8BITMIME"); !ok {
return &SendError{Reason: ErrNoUnencoded, isTemp: false, affectedMsg: message}
Expand Down
Loading

0 comments on commit 65a91a2

Please sign in to comment.