diff --git a/jrpc2_test.go b/jrpc2_test.go index a2d5422..e594e31 100644 --- a/jrpc2_test.go +++ b/jrpc2_test.go @@ -1200,3 +1200,31 @@ func TestServerFromContext(t *testing.T) { t.Errorf("ServerFromContext: got %p, want %p", got, loc.Server) } } + +func TestServer_newContext(t *testing.T) { + // Prepare a context with a test value attached to it, that the handler can + // extract to verify that the base context was plumbed in correctly. + type ctxKey string + ctx := context.WithValue(context.Background(), ctxKey("test"), 42) + + loc := server.NewLocal(handler.Map{ + "Test": handler.New(func(ctx context.Context) error { + val := ctx.Value(ctxKey("test")) + if val == nil { + t.Error("Test value is not present in context") + } else if v, ok := val.(int); !ok || v != 42 { + t.Errorf("Wrong test value: got %+v, want %v", val, 42) + } + return nil + }), + }, &server.LocalOptions{ + Server: &jrpc2.ServerOptions{ + // Use the test context constructed above as the base request context. + NewContext: func() context.Context { return ctx }, + }, + }) + defer loc.Close() + if _, err := loc.Client.Call(context.Background(), "Test", nil); err != nil { + t.Errorf("Call failed: %v", err) + } +} diff --git a/opts.go b/opts.go index c63a954..92ca22e 100644 --- a/opts.go +++ b/opts.go @@ -44,6 +44,10 @@ type ServerOptions struct { // this setting does not constrain order of issue. Concurrency int + // If set, this function is called to create a new base request context. + // If unset, the server uses a background context. + NewContext func() context.Context + // If set, this function is called with the method name and encoded request // parameters received from the client, before they are delivered to the // handler. Its return value replaces the context and argument values. This @@ -94,6 +98,13 @@ func (s *ServerOptions) startTime() time.Time { return s.StartTime } +func (o *ServerOptions) newContext() func() context.Context { + if o == nil || o.NewContext == nil { + return context.Background + } + return o.NewContext +} + type decoder = func(context.Context, string, json.RawMessage) (context.Context, json.RawMessage, error) func (s *ServerOptions) decodeContext() (decoder, bool) { diff --git a/server.go b/server.go index 2ed4ef3..7c3dc9c 100644 --- a/server.go +++ b/server.go @@ -23,19 +23,20 @@ type logger = func(string, ...interface{}) // responses on a channel.Channel provided by the caller, and dispatches // requests to user-defined Handlers. type Server struct { - wg sync.WaitGroup // ready when workers are done at shutdown time - mux Assigner // associates method names with handlers - sem *semaphore.Weighted // bounds concurrent execution (default 1) - allow1 bool // allow v1 requests with no version marker - allowP bool // allow server notifications to the client - log logger // write debug logs here - rpcLog RPCLogger // log RPC requests and responses here - dectx decoder // decode context from request - ckreq verifier // request checking hook - expctx bool // whether to expect request context - metrics *metrics.M // metrics collected during execution - start time.Time // when Start was called - builtin bool // whether built-in rpc.* methods are enabled + wg sync.WaitGroup // ready when workers are done at shutdown time + mux Assigner // associates method names with handlers + sem *semaphore.Weighted // bounds concurrent execution (default 1) + allow1 bool // allow v1 requests with no version marker + allowP bool // allow server notifications to the client + log logger // write debug logs here + rpcLog RPCLogger // log RPC requests and responses here + newctx func() context.Context // create a new base request context + dectx decoder // decode context from request + ckreq verifier // request checking hook + expctx bool // whether to expect request context + metrics *metrics.M // metrics collected during execution + start time.Time // when Start was called + builtin bool // whether built-in rpc.* methods are enabled mu *sync.Mutex // protects the fields below @@ -74,6 +75,7 @@ func NewServer(mux Assigner, opts *ServerOptions) *Server { allowP: opts.allowPush(), log: opts.logger(), rpcLog: opts.rpcLog(), + newctx: opts.newContext(), dectx: dc, ckreq: opts.checkRequest(), expctx: exp, @@ -306,7 +308,7 @@ func (s *Server) checkAndAssign(next jmessages) tasks { // setContext constructs and attaches a request context to t, and reports // whether this succeeded. func (s *Server) setContext(t *task, id string) bool { - base, params, err := s.dectx(context.Background(), t.hreq.method, t.hreq.params) + base, params, err := s.dectx(s.newctx(), t.hreq.method, t.hreq.params) t.hreq.params = params if err != nil { t.err = Errorf(code.InternalError, "invalid request context: %v", err)