Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DisconnectReceiptTimeout conn option #126

Merged
merged 1 commit into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,31 @@ const DefaultMsgSendTimeout = 10 * time.Second
// Default receipt timeout in Conn.Send function
const DefaultRcvReceiptTimeout = 30 * time.Second

// Default receipt timeout in Conn.Disconnect function
const DefaultDisconnectReceiptTimeout = 30 * time.Second

// Reply-To header used for temporary queues/RPC with rabbit.
const ReplyToHeader = "reply-to"

// A Conn is a connection to a STOMP server. Create a Conn using either
// the Dial or Connect function.
type Conn struct {
conn io.ReadWriteCloser
readCh chan *frame.Frame
writeCh chan writeRequest
version Version
session string
server string
readTimeout time.Duration
writeTimeout time.Duration
msgSendTimeout time.Duration
rcvReceiptTimeout time.Duration
hbGracePeriodMultiplier float64
closed bool
closeMutex *sync.Mutex
options *connOptions
log Logger
conn io.ReadWriteCloser
readCh chan *frame.Frame
writeCh chan writeRequest
version Version
session string
server string
readTimeout time.Duration
writeTimeout time.Duration
msgSendTimeout time.Duration
rcvReceiptTimeout time.Duration
disconnectReceiptTimeout time.Duration
hbGracePeriodMultiplier float64
closed bool
closeMutex *sync.Mutex
options *connOptions
log Logger
}

type writeRequest struct {
Expand Down Expand Up @@ -195,6 +199,7 @@ func Connect(conn io.ReadWriteCloser, opts ...func(*Conn) error) (*Conn, error)

c.msgSendTimeout = options.MsgSendTimeout
c.rcvReceiptTimeout = options.RcvReceiptTimeout
c.disconnectReceiptTimeout = options.DisconnectReceiptTimeout

if options.ResponseHeadersCallback != nil {
options.ResponseHeadersCallback(response.Header)
Expand Down Expand Up @@ -421,13 +426,18 @@ func (c *Conn) Disconnect() error {
C: ch,
}

response := <-ch
if response.Command != frame.RECEIPT {
return newError(response)
err := readReceiptWithTimeout(ch, c.disconnectReceiptTimeout, ErrDisconnectReceiptTimeout)
if err == nil {
c.closed = true
return c.conn.Close()
}

c.closed = true
return c.conn.Close()
if err == ErrDisconnectReceiptTimeout {
c.closed = true
_ = c.conn.Close()
}

return err
}

// MustDisconnect will disconnect 'ungracefully' from the STOMP server.
Expand Down Expand Up @@ -480,7 +490,7 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(*
return err
}

err = readReceiptWithTimeout(request, c.rcvReceiptTimeout)
err = readReceiptWithTimeout(request.C, c.rcvReceiptTimeout, ErrMsgReceiptTimeout)
if err != nil {
return err
}
Expand All @@ -497,16 +507,16 @@ func (c *Conn) Send(destination, contentType string, body []byte, opts ...func(*
return nil
}

func readReceiptWithTimeout(request writeRequest, timeout time.Duration) error {
func readReceiptWithTimeout(responseChan chan *frame.Frame, timeout time.Duration, timeoutErr error) error {
var timeoutChan <-chan time.Time
if timeout > 0 {
timeoutChan = time.After(timeout)
}

select {
case <-timeoutChan:
return ErrMsgReceiptTimeout
case response := <-request.C:
return timeoutErr
case response := <-responseChan:
if response.Command != frame.RECEIPT {
return newError(response)
}
Expand Down
16 changes: 15 additions & 1 deletion conn_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type connOptions struct {
HeartBeatError time.Duration
MsgSendTimeout time.Duration
RcvReceiptTimeout time.Duration
DisconnectReceiptTimeout time.Duration
HeartBeatGracePeriodMultiplier float64
Login, Passcode string
AcceptVersions []string
Expand All @@ -38,6 +39,7 @@ func newConnOptions(conn *Conn, opts []func(*Conn) error) (*connOptions, error)
HeartBeatError: DefaultHeartBeatError,
MsgSendTimeout: DefaultMsgSendTimeout,
RcvReceiptTimeout: DefaultRcvReceiptTimeout,
DisconnectReceiptTimeout: DefaultDisconnectReceiptTimeout,
Logger: log.StdLogger{},
}

Expand Down Expand Up @@ -146,9 +148,14 @@ var ConnOpt struct {

// RcvReceiptTimeout is a connect option that allows the client to specify
// how long to wait for a receipt in the Conn.Send function. This helps
// avoid deadlocks. If this is not specified, the default is 10 seconds.
// avoid deadlocks. If this is not specified, the default is 30 seconds.
RcvReceiptTimeout func(rcvReceiptTimeout time.Duration) func(*Conn) error

// DisconnectReceiptTimeout is a connect option that allows the client to specify
// how long to wait for a receipt in the Conn.Disconnect function. This helps
// avoid deadlocks. If this is not specified, the default is 30 seconds.
DisconnectReceiptTimeout func(disconnectReceiptTimeout time.Duration) func(*Conn) error

// HeartBeatGracePeriodMultiplier is used to calculate the effective read heart-beat timeout
// the broker will enforce for each client’s connection. The multiplier is applied to
// the read-timeout interval the client specifies in its CONNECT frame
Expand Down Expand Up @@ -248,6 +255,13 @@ func init() {
}
}

ConnOpt.DisconnectReceiptTimeout = func(disconnectReceiptTimeout time.Duration) func(*Conn) error {
return func(c *Conn) error {
c.options.DisconnectReceiptTimeout = disconnectReceiptTimeout
return nil
}
}

ConnOpt.HeartBeatGracePeriodMultiplier = func(multiplier float64) func(*Conn) error {
return func(c *Conn) error {
c.options.HeartBeatGracePeriodMultiplier = multiplier
Expand Down
37 changes: 33 additions & 4 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,35 @@ func (s *StompSuite) Test_connect_not_panic_on_empty_response(c *C) {
<-stop
}

func (s *StompSuite) Test_successful_disconnect_with_receipt_timeout(c *C) {
resetId()
fc1, fc2 := testutil.NewFakeConn(c)

defer func() {
fc2.Close()
}()

go func() {
reader := frame.NewReader(fc2)
writer := frame.NewWriter(fc2)

f1, err := reader.Read()
c.Assert(err, IsNil)
c.Assert(f1.Command, Equals, "CONNECT")
connectedFrame := frame.New("CONNECTED")
err = writer.Write(connectedFrame)
c.Assert(err, IsNil)
}()

client, err := Connect(fc1, ConnOpt.DisconnectReceiptTimeout(1 * time.Nanosecond))
c.Assert(err, IsNil)
c.Assert(client, NotNil)

err = client.Disconnect()
c.Assert(err, Equals, ErrDisconnectReceiptTimeout)
c.Assert(client.closed, Equals, true)
}

// Sets up a connection for testing
func connectHelper(c *C, version Version) (*Conn, *fakeReaderWriter) {
fc1, fc2 := testutil.NewFakeConn(c)
Expand Down Expand Up @@ -697,7 +726,7 @@ func (s *StompSuite) Test_TimeoutTriggers(c *C) {
C: make(chan *frame.Frame),
}

err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, NotNil)
}
Expand All @@ -715,7 +744,7 @@ func (s *StompSuite) Test_ChannelReceviesReceipt(c *C) {
}

go sendFrameHelper(&receipt, request.C)
err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, IsNil)
}
Expand All @@ -733,7 +762,7 @@ func (s *StompSuite) Test_ChannelReceviesNonReceipt(c *C) {
}

go sendFrameHelper(&receipt, request.C)
err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, NotNil)
}
Expand All @@ -751,7 +780,7 @@ func (s *StompSuite) Test_ZeroTimeout(c *C) {
}

go sendFrameHelper(&receipt, request.C)
err := readReceiptWithTimeout(request, timeout)
err := readReceiptWithTimeout(request.C, timeout, ErrMsgReceiptTimeout)

c.Assert(err, IsNil)
}
27 changes: 14 additions & 13 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,20 @@ import (

// Error values
var (
ErrInvalidCommand = newErrorMessage("invalid command")
ErrInvalidFrameFormat = newErrorMessage("invalid frame format")
ErrUnsupportedVersion = newErrorMessage("unsupported version")
ErrCompletedTransaction = newErrorMessage("transaction is completed")
ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0")
ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server")
ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto")
ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed")
ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly")
ErrAlreadyClosed = newErrorMessage("connection already closed")
ErrMsgSendTimeout = newErrorMessage("msg send timeout")
ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout")
ErrNilOption = newErrorMessage("nil option")
ErrInvalidCommand = newErrorMessage("invalid command")
ErrInvalidFrameFormat = newErrorMessage("invalid frame format")
ErrUnsupportedVersion = newErrorMessage("unsupported version")
ErrCompletedTransaction = newErrorMessage("transaction is completed")
ErrNackNotSupported = newErrorMessage("NACK not supported in STOMP 1.0")
ErrNotReceivedMessage = newErrorMessage("cannot ack/nack a message, not from server")
ErrCannotNackAutoSub = newErrorMessage("cannot send NACK for a subscription with ack:auto")
ErrCompletedSubscription = newErrorMessage("subscription is unsubscribed")
ErrClosedUnexpectedly = newErrorMessage("connection closed unexpectedly")
ErrAlreadyClosed = newErrorMessage("connection already closed")
ErrMsgSendTimeout = newErrorMessage("msg send timeout")
ErrMsgReceiptTimeout = newErrorMessage("msg receipt timeout")
ErrDisconnectReceiptTimeout = newErrorMessage("disconnect receipt timeout")
ErrNilOption = newErrorMessage("nil option")
)

// StompError implements the Error interface, and provides
Expand Down