Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize mpsc #358

Merged
merged 9 commits into from
Nov 30, 2022
86 changes: 16 additions & 70 deletions exp/mpsc/mpsc.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
// Package mpsc implements a multiple-producer, single-consumer queue.
package mpsc

// N.B. this is a trivial wrapper around the spsc package, which just adds
// a lock to the Tx end to make it multiple-producer.

import (
"capnproto.org/go/capnp/v3/internal/chanmutex"
"context"
"sync"

"capnproto.org/go/capnp/v3/exp/spsc"
)

// A multiple-producer, single-consumer queue. Create one with New(),
Expand All @@ -16,99 +21,40 @@ type Queue[T any] struct {

// The receive end of a Queue.
type Rx[T any] struct {
// The head of the list. If the list is empty, this will be
// non-nil but have a locked mu field.
head *node[T]
rx spsc.Rx[T]
}

// The send/transmit end of a Queue.
type Tx[T any] struct {
// Mutex which must be held by senders. A goroutine must hold this
// lock to manipulate `tail`.
mu chanmutex.Mutex

// Pointer to the tail of the list. This will have a locked mu,
// and zero values for other fields.
tail *node[T]
}

// A node in the linked linst that makes up the queue internally.
type node[T any] struct {
// A mutex which guards the other fields in the node.
// Nodes start out with this locked, and then we unlock it
// after filling in the other fields.
mu chanmutex.Mutex

// The next node in the list, if any. Must be non-nil if
// mu is unlocked:
next *node[T]

// The value in this node:
value T
}

// Create a new node, with a locked mutex and zero values for
// the other fields.
func newNode[T any]() *node[T] {
return &node[T]{mu: chanmutex.NewLocked()}
mu sync.Mutex
tx spsc.Tx[T]
}

// Create a new, initially empty Queue.
func New[T any]() *Queue[T] {
node := newNode[T]()
q := spsc.New[T]()
return &Queue[T]{
Tx: Tx[T]{
tail: node,
mu: chanmutex.NewUnlocked(),
},
Rx: Rx[T]{head: node},
Tx: Tx[T]{tx: q.Tx},
Rx: Rx[T]{rx: q.Rx},
}
}

// Send a message on the queue.
func (tx *Tx[T]) Send(v T) {
newTail := newNode[T]()

tx.mu.Lock()

oldTail := tx.tail
oldTail.next = newTail
oldTail.value = v
tx.tail = newTail
oldTail.mu.Unlock()

tx.mu.Unlock()
defer tx.mu.Unlock()
tx.tx.Send(v)
}

// Receive a message from the queue. Blocks if the queue is empty.
// If the context ends before the receive happens, this returns
// ctx.Err().
func (rx *Rx[T]) Recv(ctx context.Context) (T, error) {
var zero T
select {
case <-rx.head.mu:
return rx.doRecv(), nil
case <-ctx.Done():
return zero, ctx.Err()
}
return rx.rx.Recv(ctx)
}

// Try to receive a message from the queue. If successful, ok will be true.
// If the queue is empty, this will return immediately with ok = false.
func (rx *Rx[T]) TryRecv() (v T, ok bool) {
var zero T
select {
case <-rx.head.mu:
return rx.doRecv(), true
default:
return zero, false
}
}

// Helper for shared logic between Recv and TryRecv. Must be holding
// rx.head.mu.
func (rx *Rx[T]) doRecv() T {
ret := rx.head.value
rx.head = rx.head.next
return ret
return rx.rx.TryRecv()
}
110 changes: 110 additions & 0 deletions exp/spsc/spsc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Package spsc implements a single-producer, single-consumer queue.
package spsc

// Implementation overview: The queue wraps a buffered channel; under normal
// operation the sender and receiver are just using this channel in the usual
// way, but if the channel fills up, instead of blocking, the sender will close
// the main channel, use a secondary "next" channel to send the receiver the
// next channel to read from (as part of a node, which includes the next "next"
// as well). next has buffer size 1 so this operation does not block either.
//
// If the receiver reaches the end of the items queue, it reads from "next"
// for more items.

import (
"context"
)

const itemsBuffer = 64

// A single-producer, single-consumer queue. Create one with New(),
// and send with Tx.Send(). Tx and Rx are each not safe for use by
// multiple goroutines, but two separate goroutines can use Tx and
// Rx respectively.
type Queue[T any] struct {
Tx[T]
Rx[T]
}

// The receive end of a Queue.
type Rx[T any] struct {
head node[T]
}

// The send/transmit end of a Queue.
type Tx[T any] struct {
// Pointer to the tail of the list. This will have a locked mu,
// and zero values for other fields.
tail node[T]
}

type node[T any] struct {
items chan T
next chan node[T]
}

func newNode[T any]() node[T] {
return node[T]{
items: make(chan T, itemsBuffer),
next: make(chan node[T], 1),
}
}

// Create a new, initially empty Queue.
func New[T any]() Queue[T] {
n := newNode[T]()
return Queue[T]{
Rx: Rx[T]{head: n},
Tx: Tx[T]{tail: n},
}
}

// Send a message on the queue.
func (tx *Tx[T]) Send(v T) {
for {
select {
case tx.tail.items <- v:
return
default:
close(tx.tail.items)
n := newNode[T]()
tx.tail.next <- n
tx.tail = n
}
}
}

// Receive a message from the queue. Blocks if the queue is empty.
// If the context ends before the receive happens, this returns
// ctx.Err().
func (rx *Rx[T]) Recv(ctx context.Context) (T, error) {
for {
select {
case <-ctx.Done():
var zero T
return zero, ctx.Err()
case v, ok := <-rx.head.items:
if ok {
return v, nil
}
rx.head = <-rx.head.next
}
}
}

// Try to receive a message from the queue. If successful, ok will be true.
// If the queue is empty, this will return immediately with ok = false.
func (rx *Rx[T]) TryRecv() (v T, ok bool) {
for {
select {
case v, ok = <-rx.head.items:
if !ok {
rx.head = <-rx.head.next
continue
}
return v, true
default:
return v, false
}
}
}
66 changes: 66 additions & 0 deletions exp/spsc/spsc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package spsc

import (
"context"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

// Test that filling/overflowing the internal items queue doesn't block the sender.
func TestFillItemsNonBlock(t *testing.T) {
t.Parallel()

q := New[int]()

for i := 0; i < itemsBuffer+1; i++ {
q.Send(i)
}
}

// Try filling the queue, then draining it with TryRecv().
func TestFillThenTryDrain(t *testing.T) {
t.Parallel()

q := New[int]()

limit := itemsBuffer + 1

for i := 0; i < limit; i++ {
q.Send(i)
}

for i := 0; i < limit; i++ {
v, ok := q.TryRecv()
assert.True(t, ok)
assert.Equal(t, i, v)
}
_, ok := q.TryRecv()
assert.False(t, ok)
}

// Try filling the queue, then draining it with Recv().
func TestFillThenDrain(t *testing.T) {
t.Parallel()

q := New[int]()

limit := itemsBuffer + 1

for i := 0; i < limit; i++ {
q.Send(i)
}

ctx := context.Background()
for i := 0; i < limit; i++ {
v, err := q.Recv(ctx)
assert.Nil(t, err)
assert.Equal(t, i, v)
}
ctx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
defer cancel()
_, err := q.Recv(ctx)
assert.NotNil(t, err)
assert.ErrorIs(t, err, ctx.Err())
}
2 changes: 1 addition & 1 deletion rpc/answer.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func errorAnswer(c *Conn, id answerID, err error) *answer {

// newReturn creates a new Return message.
func (c *Conn) newReturn(ctx context.Context) (rpccp.Return, func(), capnp.ReleaseFunc, error) {
msg, send, releaseMsg, err := c.transport.NewMessage(ctx)
msg, send, releaseMsg, err := c.transport.NewMessage()
if err != nil {
return rpccp.Return{}, nil, nil, rpcerr.Failedf("create return: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions rpc/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ type measuringTransport struct {
inUse, maxInUse uint64
}

func (t *measuringTransport) RecvMessage(ctx context.Context) (rpccp.Message, capnp.ReleaseFunc, error) {
msg, release, err := t.Transport.RecvMessage(ctx)
func (t *measuringTransport) RecvMessage() (rpccp.Message, capnp.ReleaseFunc, error) {
msg, release, err := t.Transport.RecvMessage()
if err != nil {
return msg, release, err
}
Expand Down
Loading