Skip to content

Commit

Permalink
Merge pull request #288 from zenhack/fix-271
Browse files Browse the repository at this point in the history
Fix #271
  • Loading branch information
lthibault authored Aug 8, 2022
2 parents 17c65f2 + 779b538 commit e667e0d
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 9 deletions.
31 changes: 25 additions & 6 deletions answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package capnp

import (
"context"
"fmt"
"strconv"
"sync"

Expand Down Expand Up @@ -128,9 +129,7 @@ func (p *Promise) resolution() resolution {
func (p *Promise) Fulfill(result Ptr) {
defer p.mu.Unlock()
p.mu.Lock()
if !p.isUnresolved() {
panic("Promise.Fulfill called after Fulfill or Reject")
}
p.requireUnresolved("Fulfill")
p.resolve(result, nil)
}

Expand All @@ -145,9 +144,7 @@ func (p *Promise) Reject(e error) {
}
defer p.mu.Unlock()
p.mu.Lock()
if !p.isUnresolved() {
panic("Promise.Reject called after Fulfill or Reject")
}
p.requireUnresolved("Reject")
p.resolve(Ptr{}, e)
}

Expand All @@ -163,6 +160,28 @@ func (p *Promise) Resolve(r Ptr, e error) {
}
}

// requireUnresolved is a helper method for checking for duplicate
// calls to Fulfill() or Reject(); panics if the promise is not in
// the unresolved state.
//
// The callerMethod argument should be the name of the method which
// is invoking requireUnresolved. The panic message will report this
// value as well as the method that originally resolved the promise,
// and which method (Fulfill or Reject) was used to resolve it.
func (p *Promise) requireUnresolved(callerMethod string) {
if !p.isUnresolved() {
var prevMethod string
if p.err == nil {
prevMethod = "Fulfill"
} else {
prevMethod = fmt.Sprintf("Reject (error = %q)", p.err)
}

panic("Promise." + callerMethod +
" called after previous call to " + prevMethod)
}
}

// resolve moves p into the resolved state from unresolved. The caller
// must be holding onto p.mu.
func (p *Promise) resolve(r Ptr, e error) {
Expand Down
53 changes: 53 additions & 0 deletions rpc/level0_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"capnproto.org/go/capnp/v3/exc"
"capnproto.org/go/capnp/v3/pogs"
"capnproto.org/go/capnp/v3/rpc"
testcp "capnproto.org/go/capnp/v3/rpc/internal/testcapnp"
"capnproto.org/go/capnp/v3/rpc/transport"
"capnproto.org/go/capnp/v3/server"
rpccp "capnproto.org/go/capnp/v3/std/capnp/rpc"
Expand Down Expand Up @@ -1768,6 +1769,58 @@ func TestSendCancel(t *testing.T) {
}
}

func TestHandleReturn_regression(t *testing.T) {
t.Parallel()

p1, p2 := transport.NewPipe(1)

srv := testcp.PingPong_ServerToClient(pingPongServer{})
conn1 := rpc.NewConn(rpc.NewTransport(p2), &rpc.Options{
BootstrapClient: capnp.Client(srv),
})
defer conn1.Close()

conn2 := rpc.NewConn(rpc.NewTransport(p1), nil)
defer conn2.Close()

t.Run("MethodCallWithExpiredContext", func(t *testing.T) {
pp := testcp.PingPong(conn2.Bootstrap(context.Background()))
defer pp.Release()

// create an EXPIRED context
ctx, cancel := context.WithCancel(context.Background())
cancel()

f, release := pp.EchoNum(ctx, func(ps testcp.PingPong_echoNum_Params) error {
ps.SetN(42)
return nil
})
defer release()

_, err := f.Struct()
assert.ErrorIs(t, err, ctx.Err())
})

t.Run("BootstrapWithExpiredContext", func(t *testing.T) {
// create an EXPIRED context
ctx, cancel := context.WithCancel(context.Background())
cancel()

// NOTE: bootstrap with expired context
pp := testcp.PingPong(conn2.Bootstrap(ctx))
defer pp.Release()

f, release := pp.EchoNum(ctx, func(ps testcp.PingPong_echoNum_Params) error {
ps.SetN(42)
return nil
})
defer release()

_, err := f.Struct()
assert.ErrorIs(t, err, ctx.Err())
})
}

// finishTest drains both sides of a pipe and reports any errors to t.
func finishTest(t errorfer, conn *rpc.Conn, p2 rpc.Transport) {
ctx, cancel := context.WithCancel(context.Background())
Expand Down
8 changes: 5 additions & 3 deletions rpc/question.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,11 @@ func (q *question) PipelineSend(ctx context.Context, transform []capnp.PipelineO
return q.c.newPipelineCallMessage(m, q.id, transform, q2.id, s)
}, func(err error) {
if err != nil {
q.c.questions[q2.id] = nil
q.c.questionID.remove(uint32(q2.id))
q.p.Reject(rpcerr.Failedf("send message: %w", err))
syncutil.With(&q.c.mu, func() {
q.c.questions[q2.id] = nil
q.c.questionID.remove(uint32(q2.id))
})
q2.p.Reject(rpcerr.Failedf("send message: %w", err))
return
}

Expand Down

0 comments on commit e667e0d

Please sign in to comment.