Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Less write lock contention, better read timeout handling #97

Merged
merged 6 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
"net/http"
"net/url"
"reflect"
"runtime/pprof"
"sync/atomic"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
logging "github.com/ipfs/go-log/v2"
"go.opencensus.io/trace"
Expand Down Expand Up @@ -238,7 +240,7 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
hnd = h
}

go (&wsConn{
wconn := &wsConn{
conn: conn,
connFactory: connFactory,
reconnectBackoff: config.reconnectBackoff,
Expand All @@ -248,7 +250,14 @@ func websocketClient(ctx context.Context, addr string, namespace string, outs []
requests: requests,
stop: stop,
exiting: exiting,
}).handleWsConn(ctx)
}

go func() {
lbl := pprof.Labels("jrpc-mode", "wsclient", "jrpc-remote", addr, "jrpc-local", conn.LocalAddr().String(), "jrpc-uuid", uuid.New().String())
pprof.Do(ctx, lbl, func(ctx context.Context) {
wconn.handleWsConn(ctx)
})
}()

if err := c.provide(outs); err != nil {
return nil, err
Expand Down
39 changes: 38 additions & 1 deletion handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -446,11 +446,48 @@ func (s *handler) handle(ctx context.Context, req request, w func(func(io.Writer
log.Errorw("error and res returned", "request", req, "r.err", resp.Error, "res", res)
}

w(func(w io.Writer) {
withLazyWriter(w, func(w io.Writer) {
if err := json.NewEncoder(w).Encode(resp); err != nil {
log.Error(err)
stats.Record(ctx, metrics.RPCResponseError.M(1))
return
}
})
}

// withLazyWriter makes it possible to defer acquiring a writer until the first write.
// This is useful because json.Encode needs to marshal the response fully before writing, which may be
// a problem for very large responses.
func withLazyWriter(withWriterFunc func(func(io.Writer)), cb func(io.Writer)) {
lw := &lazyWriter{
withWriterFunc: withWriterFunc,

done: make(chan struct{}),
}

defer close(lw.done)
cb(lw)
}

type lazyWriter struct {
withWriterFunc func(func(io.Writer))

w io.Writer
done chan struct{}
}

func (lw *lazyWriter) Write(p []byte) (n int, err error) {
if lw.w == nil {
acquired := make(chan struct{})
go func() {
lw.withWriterFunc(func(w io.Writer) {
lw.w = w
close(acquired)
<-lw.done
})
}()
<-acquired
}

return lw.w.Write(p)
}
166 changes: 155 additions & 11 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import (
"net"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

Expand All @@ -33,7 +35,7 @@ func init() {
}

// SimpleServerHandler is a minimal RPC test handler exposing a running counter.
// (The scraped diff showed both the old `n int` and new `n int32` field lines;
// only the merged int32 field is kept.)
type SimpleServerHandler struct {
	// n is updated via sync/atomic (see Add/AddGet), so tests may call the
	// handler from many goroutines concurrently.
	n int32
}

type TestType struct {
Expand All @@ -57,14 +59,14 @@ func (h *SimpleServerHandler) Add(in int) error {
return errors.New("test")
}

h.n += in
atomic.AddInt32(&h.n, int32(in))

return nil
}

// AddGet atomically adds in to the handler's counter and returns the
// resulting counter value.
func (h *SimpleServerHandler) AddGet(in int) int {
	// Use the value returned by AddInt32: re-reading h.n afterwards (as the
	// original did) is a non-atomic read that races with concurrent callers
	// and may observe other goroutines' updates.
	return int(atomic.AddInt32(&h.n, int32(in)))
}

func (h *SimpleServerHandler) StringMatch(t TestType, i2 int64) (out TestOut, err error) {
Expand All @@ -88,7 +90,7 @@ func TestRawRequests(t *testing.T) {
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()

tc := func(req, resp string, n int) func(t *testing.T) {
tc := func(req, resp string, n int32) func(t *testing.T) {
return func(t *testing.T) {
rpcHandler.n = 0

Expand Down Expand Up @@ -225,7 +227,7 @@ func TestRPC(t *testing.T) {
// Add(int) error

require.NoError(t, client.Add(2))
require.Equal(t, 2, serverHandler.n)
require.Equal(t, 2, int(serverHandler.n))

err = client.Add(-3546)
require.EqualError(t, err, "test")
Expand All @@ -234,7 +236,7 @@ func TestRPC(t *testing.T) {

n := client.AddGet(3)
require.Equal(t, 5, n)
require.Equal(t, 5, serverHandler.n)
require.Equal(t, 5, int(serverHandler.n))

// StringMatch

Expand Down Expand Up @@ -268,7 +270,7 @@ func TestRPC(t *testing.T) {

// this one should actually work
noret.Add(4)
require.Equal(t, 9, serverHandler.n)
require.Equal(t, 9, int(serverHandler.n))
closer()

var noparam struct {
Expand Down Expand Up @@ -343,7 +345,7 @@ func TestRPCHttpClient(t *testing.T) {
// Add(int) error

require.NoError(t, client.Add(2))
require.Equal(t, 2, serverHandler.n)
require.Equal(t, 2, int(serverHandler.n))

err = client.Add(-3546)
require.EqualError(t, err, "test")
Expand All @@ -352,7 +354,7 @@ func TestRPCHttpClient(t *testing.T) {

n := client.AddGet(3)
require.Equal(t, 5, n)
require.Equal(t, 5, serverHandler.n)
require.Equal(t, 5, int(serverHandler.n))

// StringMatch

Expand All @@ -379,7 +381,7 @@ func TestRPCHttpClient(t *testing.T) {

// this one should actually work
noret.Add(4)
require.Equal(t, 9, serverHandler.n)
require.Equal(t, 9, int(serverHandler.n))
closer()

var noparam struct {
Expand Down Expand Up @@ -429,6 +431,41 @@ func TestRPCHttpClient(t *testing.T) {
closer()
}

// TestParallelRPC exercises the websocket client/server pair from many
// concurrent callers; 100 goroutines x 100 calls each must leave the server
// counter at exactly 20000.
func TestParallelRPC(t *testing.T) {
	// setup server

	serverHandler := &SimpleServerHandler{}

	rpcServer := NewServer()
	rpcServer.Register("SimpleServerHandler", serverHandler)

	// httptest stuff
	testServ := httptest.NewServer(rpcServer)
	defer testServ.Close()
	// setup client

	var client struct {
		Add func(int) error
	}
	closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &client, nil)
	require.NoError(t, err)
	defer closer()

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 100; j++ {
				// require.NoError calls t.FailNow, which must only run on the
				// goroutine running the test; t.Error is safe from here.
				if err := client.Add(2); err != nil {
					t.Error(err)
					return
				}
			}
		}()
	}
	wg.Wait()

	require.Equal(t, 20000, int(serverHandler.n))
}

type CtxHandler struct {
lk sync.Mutex

Expand Down Expand Up @@ -1414,3 +1451,110 @@ func TestReverseCallAliased(t *testing.T) {

closer()
}

type BigCallTestServerHandler struct {
}

type RecRes struct {
I int
R []RecRes
}

func (h *BigCallTestServerHandler) Do() (RecRes, error) {
var res RecRes
res.I = 123

for i := 0; i < 15000; i++ {
var ires RecRes
ires.I = i

for j := 0; j < 15000; j++ {
var jres RecRes
jres.I = j

ires.R = append(ires.R, jres)
}

res.R = append(res.R, ires)
}

fmt.Println("sending result")

return res, nil
}

func (h *BigCallTestServerHandler) Ch(ctx context.Context) (<-chan int, error) {
out := make(chan int)

go func() {
var i int
for {
select {
case <-ctx.Done():
fmt.Println("closing")
close(out)
return
case <-time.After(time.Second):
}
fmt.Println("sending")
out <- i
i++
}
}()

return out, nil
}

// TestBigResult tests that the connection doesn't die when sending a large result,
// and that requests which happen while a large result is being sent don't fail.
func TestBigResult(t *testing.T) {
magik6k marked this conversation as resolved.
Show resolved Hide resolved
if os.Getenv("I_HAVE_A_LOT_OF_MEMORY_AND_TIME") != "1" {
magik6k marked this conversation as resolved.
Show resolved Hide resolved
// needs ~40GB of memory and ~4 minutes to run
magik6k marked this conversation as resolved.
Show resolved Hide resolved
t.Skip("skipping test due to requiced resources, set I_HAVE_A_LOT_OF_MEMORY_AND_TIME=1 to run")
}

// setup server

serverHandler := &BigCallTestServerHandler{}

rpcServer := NewServer()
rpcServer.Register("SimpleServerHandler", serverHandler)

// httptest stuff
testServ := httptest.NewServer(rpcServer)
defer testServ.Close()
// setup client

var client struct {
Do func() (RecRes, error)
Ch func(ctx context.Context) (<-chan int, error)
}
closer, err := NewClient(context.Background(), "ws://"+testServ.Listener.Addr().String(), "SimpleServerHandler", &client, nil)
require.NoError(t, err)
defer closer()

chctx, cancel := context.WithCancel(context.Background())
defer cancel()

// client.Ch will generate some requests, which will require websocket locks,
// and before fixes in #97 would cause deadlocks / timeouts when combined with
// the large result processing from client.Do
ch, err := client.Ch(chctx)
require.NoError(t, err)

prevN := <-ch

go func() {
for n := range ch {
if n != prevN+1 {
panic("bad order")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider use testing.T methods to fail from this goroutine instead of panicing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically shouldn't use those in a goroutine that isn't running the test, panic was good enough

}
prevN = n
}
}()

_, err = client.Do()
require.NoError(t, err)

fmt.Println("done")
}
7 changes: 6 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ import (
"encoding/json"
"io"
"net/http"
"runtime/pprof"
"strings"
"time"

"github.com/google/uuid"
"github.com/gorilla/websocket"
)

Expand Down Expand Up @@ -77,7 +79,10 @@ func (s *RPCServer) handleWS(ctx context.Context, w http.ResponseWriter, r *http
}
}

wc.handleWsConn(ctx)
lbl := pprof.Labels("jrpc-mode", "wsserver", "jrpc-remote", r.RemoteAddr, "jrpc-uuid", uuid.New().String())
pprof.Do(ctx, lbl, func(ctx context.Context) {
wc.handleWsConn(ctx)
})

if err := c.Close(); err != nil {
log.Errorw("closing websocket connection", "error", err)
Expand Down
Loading