diff --git a/rpc/rpc.go b/rpc/rpc.go index 0f5a6504..83065ac0 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -88,10 +88,11 @@ type Conn struct { bootstrap capnp.Client er errReporter abortTimeout time.Duration + baseContext func() context.Context // bgctx is a Context that is canceled when shutdown starts. Note - // that it's parent is context.Background(), so we can rely on this - // being the *only* time it will be canceled. + // that if baseContext is not provided, it's parent is context.Background(), + // so we can rely on this being the *only* time it will be canceled. bgctx context.Context // tasks block shutdown. @@ -202,6 +203,11 @@ type Options struct { // by Dial or Accept on the Network itself; application code should not // set this. Network Network + + // BaseContext is an optional function that returns a base context + // for any incoming connection. If ommitted, the context.Background() + // will be used instead. + BaseContext func() context.Context } // Logger is used for logging by the RPC system. Each method logs @@ -231,8 +237,9 @@ type Logger interface { // requests from the transport. func NewConn(t Transport, opts *Options) *Conn { c := &Conn{ - transport: t, - closed: make(chan struct{}), + transport: t, + baseContext: context.Background, + closed: make(chan struct{}), } sender := spsc.New[asyncSend]() @@ -248,6 +255,10 @@ func NewConn(t Transport, opts *Options) *Conn { c.abortTimeout = opts.AbortTimeout c.network = opts.Network c.remotePeerID = opts.RemotePeerID + + if opts.BaseContext != nil { + c.baseContext = opts.BaseContext + } } if c.abortTimeout == 0 { c.abortTimeout = 100 * time.Millisecond @@ -261,7 +272,7 @@ func NewConn(t Transport, opts *Options) *Conn { func (c *Conn) startBackgroundTasks() { // We use an errgroup to link the lifetime of background tasks // to each other. - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(c.baseContext()) g, ctx := errgroup.WithContext(ctx) c.bgctx = ctx