Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove remaining non-hierarchical locking in the rpc package. #406

Merged
merged 8 commits into from
Dec 29, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 81 additions & 73 deletions rpc/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,16 @@ func (c *Conn) handleBootstrap(ctx context.Context, id answerID) error {
return err
}

func idempotent(f func()) func() {
called := false
return func() {
if !called {
called = true
f()
}
}
}

func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capnp.ReleaseFunc) error {
rl := &releaseList{}
defer rl.Release()
Expand Down Expand Up @@ -694,18 +704,25 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
return nil
}

c.lk.Lock()
if c.lk.answers[id] != nil {
c.lk.Unlock()
releaseCall()
return rpcerr.Failedf("incoming call: answer ID %d reused", id)
}
var (
err error
p parsedCall
parseErr error
)
syncutil.With(&c.lk, func() {
if c.lk.answers[id] != nil {
rl.Add(releaseCall)
err = rpcerr.Failedf("incoming call: answer ID %d reused", id)
return
}

var p parsedCall
parseErr := c.parseCall(&p, call) // parseCall sets CapTable
parseErr = c.parseCall(&p, call) // parseCall sets CapTable
})
if err != nil {
return err
}

// Create return message.
c.lk.Unlock()
ret, send, retReleaser, err := c.newReturn()
if err != nil {
err = rpcerr.Annotate(err, "incoming call")
Expand All @@ -720,74 +737,74 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
ret.SetReleaseParamCaps(false)

// Find target and start call.
c.lk.Lock()
ans := &answer{
c: c,
id: id,
ret: ret,
sendMsg: send,
msgReleaser: retReleaser,
}
c.lk.Lock()
defer c.lk.Unlock()

c.lk.answers[id] = ans
if parseErr != nil {
parseErr = rpcerr.Annotate(parseErr, "incoming call")
ans.sendException(rl, parseErr)
c.lk.Unlock()
c.er.ReportError(parseErr)
releaseCall()
rl.Add(func() {
c.er.ReportError(parseErr)
releaseCall()
})
return nil
}
released := false
releaseArgs := func() {
if released {
return
}
released = true
releaseCall()

recv := capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: idempotent(releaseCall),
Returner: ans,
}

switch p.target.which {
case rpccp.MessageTarget_Which_importedCap:
ent := c.findExport(p.target.importedCap)
if ent == nil {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.msgReleaser = nil
c.lk.Unlock()
retReleaser.Decr()
releaseCall()
rl.Add(func() {
retReleaser.Decr()
releaseCall()
})
return rpcerr.Failedf("incoming call: unknown export ID %d", id)
}
c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
c.lk.Unlock()
pcall := ent.client.RecvCall(callCtx, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := ent.client.RecvCall(callCtx, recv)
// Place PipelineCaller into answer. Since the receive goroutine is
// the only one that uses answer.pcall, it's fine that there's a
// time gap for this being set.
ans.setPipelineCaller(p.method, pcall)
})
// Place PipelineCaller into answer. Since the receive goroutine is
// the only one that uses answer.pcall, it's fine that there's a
// time gap for this being set.
ans.setPipelineCaller(p.method, pcall)
return nil
case rpccp.MessageTarget_Which_promisedAnswer:
tgtAns := c.lk.answers[p.target.promisedAnswer]
if tgtAns == nil || tgtAns.flags.Contains(finishReceived) {
ans.ret = rpccp.Return{}
ans.sendMsg = nil
ans.msgReleaser = nil
c.lk.Unlock()
retReleaser.Decr()
releaseCall()
rl.Add(func() {
retReleaser.Decr()
releaseCall()
})
return rpcerr.Failedf("incoming call: use of unknown or finished answer ID %d for promised answer target", p.target.promisedAnswer)
}
if tgtAns.flags.Contains(resultsReady) {
if tgtAns.err != nil {
ans.sendException(rl, tgtAns.err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
return nil
}
// tgtAns.results is guaranteed to stay alive because it hasn't
Expand All @@ -798,17 +815,15 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
if err != nil {
err = rpcerr.Failedf("incoming call: read results from target answer: %w", err)
ans.sendException(rl, err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
c.er.ReportError(err)
return nil
}
sub, err := capnp.Transform(content, p.target.transform)
if err != nil {
// Not reporting, as this is the caller's fault.
ans.sendException(rl, err)
c.lk.Unlock()
releaseCall()
rl.Add(releaseCall)
return nil
}
iface := sub.Interface()
Expand All @@ -824,30 +839,22 @@ func (c *Conn) handleCall(ctx context.Context, call rpccp.Call, releaseCall capn
c.tasks.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
c.lk.Unlock()
pcall := tgt.RecvCall(callCtx, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := tgt.RecvCall(callCtx, recv)
ans.setPipelineCaller(p.method, pcall)
})
ans.setPipelineCaller(p.method, pcall)
} else {
// Results not ready, use pipeline caller.
tgtAns.pcalls.Add(1) // will be finished by answer.Return
var callCtx context.Context
callCtx, ans.cancel = context.WithCancel(c.bgctx)
tgt := tgtAns.pcall
c.tasks.Add(1) // will be finished by answer.Return
c.lk.Unlock()
pcall := tgt.PipelineRecv(callCtx, p.target.transform, capnp.Recv{
Args: p.args,
Method: p.method,
ReleaseArgs: releaseArgs,
Returner: ans,
rl.Add(func() {
pcall := tgt.PipelineRecv(callCtx, p.target.transform, recv)
tgtAns.pcalls.Done()
ans.setPipelineCaller(p.method, pcall)
})
tgtAns.pcalls.Done()
ans.setPipelineCaller(p.method, pcall)
}
return nil
default:
Expand Down Expand Up @@ -934,11 +941,14 @@ func parseTransform(list rpccp.PromisedAnswer_Op_List) ([]capnp.PipelineOp, erro
}

func (c *Conn) handleReturn(ctx context.Context, ret rpccp.Return, release capnp.ReleaseFunc) error {
rl := &releaseList{}
defer rl.Release()
c.lk.Lock()
defer c.lk.Unlock()

qid := questionID(ret.AnswerId())
if uint32(qid) >= uint32(len(c.lk.questions)) {
c.lk.Unlock()
release()
rl.Add(release)
return rpcerr.Failedf("incoming return: question %d does not exist", qid)
}
// Pop the question from the table. Receiving the Return message
Expand All @@ -947,8 +957,7 @@ func (c *Conn) handleReturn(ctx context.Context, ret rpccp.Return, release capnp
q := c.lk.questions[qid]
c.lk.questions[qid] = nil
if q == nil {
c.lk.Unlock()
release()
rl.Add(release)
return rpcerr.Failedf("incoming return: question %d does not exist", qid)
}
canceled := q.flags&finished != 0
Expand All @@ -962,11 +971,9 @@ func (c *Conn) handleReturn(ctx context.Context, ret rpccp.Return, release capnp
if q.flags&finishSent != 0 {
c.lk.questionID.remove(uint32(qid))
}
c.lk.Unlock()
release()
rl.Add(release)
default:
c.lk.Unlock()
release()
rl.Add(release)

go func() {
<-q.finishMsgSend
Expand All @@ -989,7 +996,6 @@ func (c *Conn) handleReturn(ctx context.Context, ret rpccp.Return, release capnp
// client or an error), so we save the ReleaseFunc for later:
q.release = release
}
c.lk.Unlock()
// We're going to potentially block fulfilling some promises so fork
// off a goroutine to avoid blocking the receive loop.
go func() {
Expand Down Expand Up @@ -1342,16 +1348,18 @@ func (c *Conn) handleDisembargo(ctx context.Context, d rpccp.Disembargo, release
defer release()

id := embargoID(d.Context().ReceiverLoopback())
c.lk.Lock()
e := c.findEmbargo(id)
var e *embargo
syncutil.With(&c.lk, func() {
e = c.findEmbargo(id)
if e != nil {
// TODO(soon): verify target matches the right import.
c.lk.embargoes[id] = nil
c.lk.embargoID.remove(uint32(id))
}
})
if e == nil {
c.lk.Unlock()
return rpcerr.Failedf("incoming disembargo: received sender loopback for unknown ID %d", id)
}
// TODO(soon): verify target matches the right import.
c.lk.embargoes[id] = nil
c.lk.embargoID.remove(uint32(id))
c.lk.Unlock()
e.lift()

case rpccp.Disembargo_context_Which_senderLoopback:
Expand Down