Skip to content

Commit

Permalink
Add support for concurrent callback handlers in the Client (#60)
Browse files Browse the repository at this point in the history
Change the implementation of the client-side callback handler to allow multiple
callback handlers to execute concurrently.

Although this change does not modify the API seen by the user of the client, it
is breaking in the sense that it removes the assumption that only one callback
handler could be active at a time.

Related changes:

 - Add jrpc2.ClientFromContext to expose the client to callback handlers.
 - Fix a cancellation bug in server-side callback handling.
 - Update tests for the new behaviour.
  • Loading branch information
creachadair authored Nov 26, 2021
1 parent 376d23d commit 49e872f
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 17 deletions.
35 changes: 29 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ type Client struct {
log func(string, ...interface{}) // write debug logs here
enctx encoder
snote func(*jmessage)
scall func(*jmessage) []byte
scall func(context.Context, *jmessage) []byte
chook func(*Client, *Response)

cbctx context.Context // terminates when the client is closed
cbcancel func() // cancels cbctx

allow1 bool // tolerate v1 replies with no version marker

mu sync.Mutex // protects the fields below
Expand All @@ -35,6 +38,7 @@ type Client struct {

// NewClient returns a new client that communicates with the server via ch.
func NewClient(ch channel.Channel, opts *ClientOptions) *Client {
cbctx, cbcancel := context.WithCancel(context.Background())
c := &Client{
done: new(sync.WaitGroup),
log: opts.logFunc(),
Expand All @@ -44,6 +48,9 @@ func NewClient(ch channel.Channel, opts *ClientOptions) *Client {
scall: opts.handleCallback(),
chook: opts.handleCancel(),

cbctx: cbctx,
cbcancel: cbcancel,

// Lock-protected fields
ch: ch,
pending: make(map[string]*Response),
Expand Down Expand Up @@ -99,7 +106,7 @@ func (c *Client) accept(ch receiver) error {
}

// handleRequest handles a callback or notification from the server. The
// caller must hold c.mu, and this blocks until the handler completes.
// caller must hold c.mu. This function does not block for the handler.
// Precondition: msg is a request or notification, not a response or error.
func (c *Client) handleRequest(msg *jmessage) {
if msg.isNotification() {
Expand All @@ -113,10 +120,22 @@ func (c *Client) handleRequest(msg *jmessage) {
} else if c.ch == nil {
c.log("Client channel is closed; discarding callback: %v", msg)
} else {
bits := c.scall(msg)
if err := c.ch.Send(bits); err != nil {
c.log("Sending reply for callback %v failed: %v", msg, err)
}
// Run the callback handler in its own goroutine. The context will be
// cancelled automatically when the client is closed.
ctx := context.WithValue(c.cbctx, clientKey{}, c)
c.done.Add(1)
go func() {
defer c.done.Done()
bits := c.scall(ctx, msg)

c.mu.Lock()
defer c.mu.Unlock()
if c.err != nil {
c.log("Discarding callback response: %v", c.err)
} else if err := c.ch.Send(bits); err != nil {
c.log("Sending reply for callback %v failed: %v", msg, err)
}
}()
}
}

Expand Down Expand Up @@ -392,10 +411,14 @@ func (c *Client) stop(err error) {
}
c.ch.Close()

// Unblock and fail any pending callbacks.
c.cbcancel()

// Unblock and fail any pending requests.
for _, p := range c.pending {
p.cancel()
}

c.err = err
c.ch = nil
}
Expand Down
9 changes: 9 additions & 0 deletions ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,12 @@ type inboundRequestKey struct{}
func ServerFromContext(ctx context.Context) *Server { return ctx.Value(serverKey{}).(*Server) }

type serverKey struct{}

// ClientFromContext returns the client associated with the given context.
// This will be populated on the context passed to callback handlers.
//
// A callback handler must not close the client, as the close will deadlock
// waiting for the callback to return.
func ClientFromContext(ctx context.Context) *Client { return ctx.Value(clientKey{}).(*Client) }

type clientKey struct{}
146 changes: 146 additions & 0 deletions jrpc2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -870,6 +871,151 @@ func TestClient_onCancelHook(t *testing.T) {
}
}

// Verify that client callback handlers are cancelled when the client stops.
func TestClient_closeEndsCallbacks(t *testing.T) {
ready := make(chan struct{})
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) error {
// Call back to the client and block indefinitely until it returns.
srv := jrpc2.ServerFromContext(ctx)
_, err := srv.Callback(ctx, "whatever", nil)
return err
}),
}, &server.LocalOptions{
Server: &jrpc2.ServerOptions{AllowPush: true},
Client: &jrpc2.ClientOptions{
OnCallback: handler.New(func(ctx context.Context) error {
// Signal the test that the callback handler is running. When the
// client is closed, it should terminate ctx and allow this to
// return. If that doesn't happen, time out and fail.
close(ready)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(10 * time.Second):
return errors.New("context not cancelled before timeout")
}
}),
},
})
go func() {
rsp, err := loc.Client.Call(context.Background(), "Test", nil)
if err == nil {
t.Errorf("Client call: got %+v, wanted error", rsp)
}
}()
<-ready
loc.Client.Close()
loc.Server.Wait()
}

// Verify that it is possible for multiple callback handlers to execute
// concurrently.
func TestClient_concurrentCallbacks(t *testing.T) {
ready1 := make(chan struct{})
ready2 := make(chan struct{})
release := make(chan struct{})

loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) []string {
srv := jrpc2.ServerFromContext(ctx)

// Call two callbacks concurrently, wait until they are both running,
// then ungate them and wait for them both to reply. Return their
// responses back to the test for validation.
ss := make([]string, 2)
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
rsp, err := srv.Callback(ctx, "C1", nil)
if err != nil {
t.Errorf("Callback C1 failed: %v", err)
} else {
rsp.UnmarshalResult(&ss[0])
}
}()
go func() {
defer wg.Done()
rsp, err := srv.Callback(ctx, "C2", nil)
if err != nil {
t.Errorf("Callback C2 failed: %v", err)
} else {
rsp.UnmarshalResult(&ss[1])
}
}()
<-ready1 // C1 is ready
<-ready2 // C2 is ready
close(release) // allow all callbacks to proceed
wg.Wait() // wait for all callbacks to be done
return ss
}),
}, &server.LocalOptions{
Server: &jrpc2.ServerOptions{AllowPush: true},
Client: &jrpc2.ClientOptions{
OnCallback: handler.Func(func(ctx context.Context, req *jrpc2.Request) (interface{}, error) {
// A trivial callback that reports its method name.
// The name is used to select which invocation we are serving.
switch req.Method() {
case "C1":
close(ready1)
case "C2":
close(ready2)
default:
return nil, fmt.Errorf("unexpected method %q", req.Method())
}
<-release
return req.Method(), nil
}),
},
})

var got []string
if err := loc.Client.CallResult(context.Background(), "Test", nil, &got); err != nil {
t.Errorf("Call Test failed: %v", err)
}
want := []string{"C1", "C2"}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Wrong callback results: (-want, +got)\n%s", diff)
}
}

// Verify that a callback can successfully call "up" into the server.
func TestClient_callbackUpCall(t *testing.T) {
const pingMessage = "kittens!"

var probe string
loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) error {
// Call back to the client, and propagate its response.
srv := jrpc2.ServerFromContext(ctx)
_, err := srv.Callback(ctx, "whatever", nil)
return err
}),
"Ping": handler.New(func(context.Context) string {
// This method is called by the client-side callback.
return pingMessage
}),
}, &server.LocalOptions{
Server: &jrpc2.ServerOptions{AllowPush: true},
Client: &jrpc2.ClientOptions{
OnCallback: handler.New(func(ctx context.Context) error {
// Call back up into the server.
cli := jrpc2.ClientFromContext(ctx)
return cli.CallResult(ctx, "Ping", nil, &probe)
}),
},
})

if _, err := loc.Client.Call(context.Background(), "Test", nil); err != nil {
t.Errorf("Call Test failed: %v", err)
}
loc.Close()
if probe != pingMessage {
t.Errorf("Probe response: got %q, want %q", probe, pingMessage)
}
}

// Verify that the context encoding/decoding hooks work.
func TestContextPlumbing(t *testing.T) {
want := time.Now().Add(10 * time.Second)
Expand Down
18 changes: 10 additions & 8 deletions opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,17 @@ type ClientOptions struct {
OnNotify func(*Request)

// If set, this function is called if a request is received from the server.
// If unset, server requests are logged and discarded. At most one
// invocation of this callback will be active at a time.
// Server callbacks are a non-standard extension of JSON-RPC.
// If unset, server requests are logged and discarded. Multiple invocations
// of the callback handler may be active concurrently.
//
// The callback handler can retrieve the client from its context using the
// jrpc2.ClientFromContext function. The context terminates when the client
// is closed.
//
// If a callback handler panics, the client will recover the panic and
// report a system error back to the server describing the error.
//
// Server callbacks are a non-standard extension of JSON-RPC.
OnCallback func(context.Context, *Request) (interface{}, error)

// If set, this function is called when the context for a request terminates.
Expand Down Expand Up @@ -214,15 +219,12 @@ func (c *ClientOptions) handleCancel() func(*Client, *Response) {
return c.OnCancel
}

func (c *ClientOptions) handleCallback() func(*jmessage) []byte {
func (c *ClientOptions) handleCallback() func(context.Context, *jmessage) []byte {
if c == nil || c.OnCallback == nil {
return nil
}
cb := c.OnCallback
return func(req *jmessage) []byte {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

return func(ctx context.Context, req *jmessage) []byte {
// Recover panics from the callback handler to ensure the server gets a
// response even if the callback fails without a result.
//
Expand Down
5 changes: 2 additions & 3 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,8 @@ func (s *Server) stop(err error) {

// Cancel any in-flight requests that made it out of the queue, and
// terminate any pending callback invocations.
for id, rsp := range s.call {
delete(s.call, id)
rsp.cancel()
for _, rsp := range s.call {
rsp.cancel() // the waiter will clean up the map
}
for id, cancel := range s.used {
cancel()
Expand Down

0 comments on commit 49e872f

Please sign in to comment.