Skip to content

Commit

Permalink
server: Add an option to create a base request context. (#56)
Browse files Browse the repository at this point in the history
Previously, the server used a background context as the basis for a request
context. This change adds a NewContext server option that the server now uses
instead. The default is to still use a background context.
  • Loading branch information
creachadair authored Sep 26, 2021
1 parent bf24a3e commit aba174f
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 14 deletions.
28 changes: 28 additions & 0 deletions jrpc2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1200,3 +1200,31 @@ func TestServerFromContext(t *testing.T) {
t.Errorf("ServerFromContext: got %p, want %p", got, loc.Server)
}
}

func TestServer_newContext(t *testing.T) {
// Prepare a context with a test value attached to it, that the handler can
// extract to verify that the base context was plumbed in correctly.
type ctxKey string
ctx := context.WithValue(context.Background(), ctxKey("test"), 42)

loc := server.NewLocal(handler.Map{
"Test": handler.New(func(ctx context.Context) error {
val := ctx.Value(ctxKey("test"))
if val == nil {
t.Error("Test value is not present in context")
} else if v, ok := val.(int); !ok || v != 42 {
t.Errorf("Wrong test value: got %+v, want %v", val, 42)
}
return nil
}),
}, &server.LocalOptions{
Server: &jrpc2.ServerOptions{
// Use the test context constructed above as the base request context.
NewContext: func() context.Context { return ctx },
},
})
defer loc.Close()
if _, err := loc.Client.Call(context.Background(), "Test", nil); err != nil {
t.Errorf("Call failed: %v", err)
}
}
11 changes: 11 additions & 0 deletions opts.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ type ServerOptions struct {
// this setting does not constrain order of issue.
Concurrency int

// If set, this function is called to create a new base request context.
// If unset, the server uses a background context.
NewContext func() context.Context

// If set, this function is called with the method name and encoded request
// parameters received from the client, before they are delivered to the
// handler. Its return value replaces the context and argument values. This
Expand Down Expand Up @@ -94,6 +98,13 @@ func (s *ServerOptions) startTime() time.Time {
return s.StartTime
}

func (o *ServerOptions) newContext() func() context.Context {
if o == nil || o.NewContext == nil {
return context.Background
}
return o.NewContext
}

type decoder = func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error)

func (s *ServerOptions) decodeContext() (decoder, bool) {
Expand Down
30 changes: 16 additions & 14 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,20 @@ type logger = func(string, ...interface{})
// responses on a channel.Channel provided by the caller, and dispatches
// requests to user-defined Handlers.
type Server struct {
wg sync.WaitGroup // ready when workers are done at shutdown time
mux Assigner // associates method names with handlers
sem *semaphore.Weighted // bounds concurrent execution (default 1)
allow1 bool // allow v1 requests with no version marker
allowP bool // allow server notifications to the client
log logger // write debug logs here
rpcLog RPCLogger // log RPC requests and responses here
dectx decoder // decode context from request
ckreq verifier // request checking hook
expctx bool // whether to expect request context
metrics *metrics.M // metrics collected during execution
start time.Time // when Start was called
builtin bool // whether built-in rpc.* methods are enabled
wg sync.WaitGroup // ready when workers are done at shutdown time
mux Assigner // associates method names with handlers
sem *semaphore.Weighted // bounds concurrent execution (default 1)
allow1 bool // allow v1 requests with no version marker
allowP bool // allow server notifications to the client
log logger // write debug logs here
rpcLog RPCLogger // log RPC requests and responses here
newctx func() context.Context // create a new base request context
dectx decoder // decode context from request
ckreq verifier // request checking hook
expctx bool // whether to expect request context
metrics *metrics.M // metrics collected during execution
start time.Time // when Start was called
builtin bool // whether built-in rpc.* methods are enabled

mu *sync.Mutex // protects the fields below

Expand Down Expand Up @@ -74,6 +75,7 @@ func NewServer(mux Assigner, opts *ServerOptions) *Server {
allowP: opts.allowPush(),
log: opts.logger(),
rpcLog: opts.rpcLog(),
newctx: opts.newContext(),
dectx: dc,
ckreq: opts.checkRequest(),
expctx: exp,
Expand Down Expand Up @@ -306,7 +308,7 @@ func (s *Server) checkAndAssign(next jmessages) tasks {
// setContext constructs and attaches a request context to t, and reports
// whether this succeeded.
func (s *Server) setContext(t *task, id string) bool {
base, params, err := s.dectx(context.Background(), t.hreq.method, t.hreq.params)
base, params, err := s.dectx(s.newctx(), t.hreq.method, t.hreq.params)
t.hreq.params = params
if err != nil {
t.err = Errorf(code.InternalError, "invalid request context: %v", err)
Expand Down

0 comments on commit aba174f

Please sign in to comment.