diff --git a/pkg/sql/conn_io.go b/pkg/sql/conn_io.go index 1e422d0e50c1..beb4dbde3704 100644 --- a/pkg/sql/conn_io.go +++ b/pkg/sql/conn_io.go @@ -23,6 +23,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sessiondata" "github.com/cockroachdb/cockroach/pkg/sql/sqlbase" "github.com/cockroachdb/cockroach/pkg/util/log" + "github.com/cockroachdb/cockroach/pkg/util/ring" "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/errors" "github.com/lib/pq/oid" @@ -94,7 +95,7 @@ type StmtBuf struct { cond *sync.Cond // data contains the elements of the buffer. - data []Command + data ring.Buffer // []Command // startPos indicates the index of the first command currently in data // relative to the start of the connection. @@ -398,7 +399,7 @@ func (buf *StmtBuf) Push(ctx context.Context, cmd Command) error { if buf.mu.closed { return errors.AssertionFailedf("buffer is closed") } - buf.mu.data = append(buf.mu.data, cmd) + buf.mu.data.AddLast(cmd) buf.mu.lastPos++ buf.mu.cond.Signal() @@ -426,10 +427,11 @@ func (buf *StmtBuf) CurCmd() (Command, CmdPos, error) { if err != nil { return nil, 0, err } - if cmdIdx < len(buf.mu.data) { - return buf.mu.data[cmdIdx], curPos, nil + len := buf.mu.data.Len() + if cmdIdx < len { + return buf.mu.data.Get(cmdIdx).(Command), curPos, nil } - if cmdIdx != len(buf.mu.data) { + if cmdIdx != len { return nil, 0, errors.AssertionFailedf( "can only wait for next command; corrupt cursor: %d", errors.Safe(curPos)) } @@ -473,8 +475,7 @@ func (buf *StmtBuf) ltrim(ctx context.Context, pos CmdPos) { if buf.mu.startPos == pos { break } - buf.mu.data[0] = nil - buf.mu.data = buf.mu.data[1:] + buf.mu.data.RemoveFirst() buf.mu.startPos++ } } @@ -509,7 +510,7 @@ func (buf *StmtBuf) seekToNextBatch() error { buf.mu.Unlock() return err } - if cmdIdx == len(buf.mu.data) { + if cmdIdx == buf.mu.data.Len() { buf.mu.Unlock() return errors.AssertionFailedf("invalid seek start point") } @@ -529,7 +530,7 @@ func (buf *StmtBuf) seekToNextBatch() error { return err } - if _, ok := buf.mu.data[cmdIdx].(Sync); ok { + if _, ok := buf.mu.data.Get(cmdIdx).(Sync); ok { foundSync = true } diff --git a/pkg/sql/conn_io_test.go b/pkg/sql/conn_io_test.go index 1485f8dec6b5..14f7ade500ba 100644 --- a/pkg/sql/conn_io_test.go +++ b/pkg/sql/conn_io_test.go @@ -177,7 +177,7 @@ func TestStmtBufLtrim(t *testing.T) { buf.AdvanceOne() trimPos := CmdPos(2) buf.ltrim(ctx, trimPos) - if l := len(buf.mu.data); l != 3 { + if l := buf.mu.data.Len(); l != 3 { t.Fatalf("expected 3 left, got: %d", l) } if s := buf.mu.startPos; s != 2 { diff --git a/pkg/util/ring/ring_buffer.go b/pkg/util/ring/ring_buffer.go index 8dbb0d8cd751..d0cf5120e0c5 100644 --- a/pkg/util/ring/ring_buffer.go +++ b/pkg/util/ring/ring_buffer.go @@ -10,22 +10,22 @@ package ring -const bufferInitialSize = 8 - -// Buffer is a deque maintained over a ring buffer. Note: it is backed by -// a slice (unlike container/ring one that is backed by a linked list). +// Buffer is a deque maintained over a ring buffer. +// +// Note: it is backed by a slice (unlike container/ring which is backed by a +// linked list). type Buffer struct { buffer []interface{} - head int // the index of the front of the deque. - tail int // the index of the first position right after the end of the deque. + head int // the index of the front of the buffer + tail int // the index of the first position after the end of the buffer - // indicates whether the deque is empty, necessary to distinguish - // between an empty deque and a deque that uses all of its capacity. + // Indicates whether the buffer is empty. Necessary to distinguish + // between an empty buffer and a buffer that uses all of its capacity. nonEmpty bool } -// Len returns the number of elements in the deque. -func (r Buffer) Len() int { +// Len returns the number of elements in the Buffer. +func (r *Buffer) Len() int { if !r.nonEmpty { return 0 } @@ -38,83 +38,78 @@ func (r Buffer) Len() int { } } -// AddFirst add element to the front of the deque -// and doubles it's underlying slice if necessary. -func (r *Buffer) AddFirst(element interface{}) { - if cap(r.buffer) == 0 { - r.buffer = make([]interface{}, bufferInitialSize) - r.buffer[0] = element - r.tail = 1 - } else { - if r.Len() == cap(r.buffer) { - newBuffer := make([]interface{}, 2*cap(r.buffer)) - if r.head < r.tail { - copy(newBuffer[:r.Len()], r.buffer[r.head:r.tail]) - } else { - copy(newBuffer[:cap(r.buffer)-r.head], r.buffer[r.head:]) - copy(newBuffer[cap(r.buffer)-r.head:r.Len()], r.buffer[:r.tail]) - } - r.head = 0 - r.tail = cap(r.buffer) - r.buffer = newBuffer - } - r.head = (cap(r.buffer) + r.head - 1) % cap(r.buffer) - r.buffer[r.head] = element - } - r.nonEmpty = true +// Cap returns the capacity of the Buffer. +func (r *Buffer) Cap() int { + return cap(r.buffer) } -// AddLast adds element to the end of the deque -// and doubles it's underlying slice if necessary. -func (r *Buffer) AddLast(element interface{}) { - if cap(r.buffer) == 0 { - r.buffer = make([]interface{}, bufferInitialSize) - r.buffer[0] = element - r.tail = 1 - } else { - if r.Len() == cap(r.buffer) { - newBuffer := make([]interface{}, 2*cap(r.buffer)) - if r.head < r.tail { - copy(newBuffer[:r.Len()], r.buffer[r.head:r.tail]) - } else { - copy(newBuffer[:cap(r.buffer)-r.head], r.buffer[r.head:]) - copy(newBuffer[cap(r.buffer)-r.head:r.Len()], r.buffer[:r.tail]) - } - r.head = 0 - r.tail = cap(r.buffer) - r.buffer = newBuffer - } - r.buffer[r.tail] = element - r.tail = (r.tail + 1) % cap(r.buffer) - } - r.nonEmpty = true -} - -// Get returns an element at position pos in the deque (zero-based). -func (r Buffer) Get(pos int) interface{} { +// Get returns an element at position pos in the Buffer (zero-based). +func (r *Buffer) Get(pos int) interface{} { if !r.nonEmpty || pos < 0 || pos >= r.Len() { - panic("unexpected behavior: index out of bounds") + panic("index out of bounds") } return r.buffer[(pos+r.head)%cap(r.buffer)] } -// GetFirst returns an element at the front of the deque. -func (r Buffer) GetFirst() interface{} { +// GetFirst returns an element at the front of the Buffer. +func (r *Buffer) GetFirst() interface{} { if !r.nonEmpty { - panic("unexpected behavior: getting first from empty deque") + panic("getting first from empty ring buffer") } return r.buffer[r.head] } -// GetLast returns an element at the front of the deque. -func (r Buffer) GetLast() interface{} { +// GetLast returns an element at the front of the Buffer. +func (r *Buffer) GetLast() interface{} { if !r.nonEmpty { - panic("unexpected behavior: getting last from empty deque") + panic("getting last from empty ring buffer") } return r.buffer[(cap(r.buffer)+r.tail-1)%cap(r.buffer)] } -// RemoveFirst removes a single element from the front of the deque. +func (r *Buffer) grow(n int) { + newBuffer := make([]interface{}, n) + if r.head < r.tail { + copy(newBuffer[:r.Len()], r.buffer[r.head:r.tail]) + } else { + copy(newBuffer[:cap(r.buffer)-r.head], r.buffer[r.head:]) + copy(newBuffer[cap(r.buffer)-r.head:r.Len()], r.buffer[:r.tail]) + } + r.head = 0 + r.tail = cap(r.buffer) + r.buffer = newBuffer +} + +func (r *Buffer) maybeGrow() { + if r.Len() != cap(r.buffer) { + return + } + n := 2 * cap(r.buffer) + if n == 0 { + n = 1 + } + r.grow(n) +} + +// AddFirst add element to the front of the Buffer and doubles it's underlying +// slice if necessary. +func (r *Buffer) AddFirst(element interface{}) { + r.maybeGrow() + r.head = (cap(r.buffer) + r.head - 1) % cap(r.buffer) + r.buffer[r.head] = element + r.nonEmpty = true +} + +// AddLast adds element to the end of the Buffer and doubles it's underlying +// slice if necessary. +func (r *Buffer) AddLast(element interface{}) { + r.maybeGrow() + r.buffer[r.tail] = element + r.tail = (r.tail + 1) % cap(r.buffer) + r.nonEmpty = true +} + +// RemoveFirst removes a single element from the front of the Buffer. func (r *Buffer) RemoveFirst() { if r.Len() == 0 { panic("removing first from empty ring buffer") @@ -126,7 +121,7 @@ func (r *Buffer) RemoveFirst() { } } -// RemoveLast removes a single element from the end of the deque. +// RemoveLast removes a single element from the end of the Buffer. func (r *Buffer) RemoveLast() { if r.Len() == 0 { panic("removing last from empty ring buffer") @@ -139,6 +134,16 @@ func (r *Buffer) RemoveLast() { } } +// Reserve reserves the provided number of elemnets in the Buffer. It is an +// error to reserve a size less than the Buffer's current length. +func (r *Buffer) Reserve(n int) { + if n < r.Len() { + panic("reserving fewer elements than current length") + } else if n > cap(r.buffer) { + r.grow(n) + } +} + // Reset makes Buffer treat its underlying memory as if it were empty. This // allows for reusing the same memory again without explicitly removing old // elements. diff --git a/pkg/util/ring/ring_buffer_test.go b/pkg/util/ring/ring_buffer_test.go index e3ef82da51a0..fb250009cade 100644 --- a/pkg/util/ring/ring_buffer_test.go +++ b/pkg/util/ring/ring_buffer_test.go @@ -11,9 +11,10 @@ package ring import ( - "fmt" "math/rand" "testing" + + "github.com/stretchr/testify/require" ) const maxCount = 1000 @@ -22,43 +23,37 @@ func testRingBuffer(t *testing.T, count int) { var buffer Buffer naiveBuffer := make([]interface{}, 0, count) for elementIdx := 0; elementIdx < count; elementIdx++ { - if buffer.Len() != len(naiveBuffer) { - t.Errorf("Ring buffer returned incorrect Len: expected %v, found %v", len(naiveBuffer), buffer.Len()) - panic("") - } - - op := rand.Float64() - if op < 0.35 { + switch rand.Intn(4) { + case 0: buffer.AddFirst(elementIdx) naiveBuffer = append([]interface{}{elementIdx}, naiveBuffer...) - } else if op < 0.70 { + case 1: buffer.AddLast(elementIdx) naiveBuffer = append(naiveBuffer, elementIdx) - } else if op < 0.85 { + case 2: if len(naiveBuffer) > 0 { buffer.RemoveFirst() - naiveBuffer = naiveBuffer[1:] + // NB: shift to preserve length. + copy(naiveBuffer, naiveBuffer[1:]) + naiveBuffer = naiveBuffer[:len(naiveBuffer)-1] } - } else { + case 3: if len(naiveBuffer) > 0 { buffer.RemoveLast() naiveBuffer = naiveBuffer[:len(naiveBuffer)-1] } + default: + t.Fatal("unexpected") } + require.Equal(t, len(naiveBuffer), buffer.Len()) for pos, el := range naiveBuffer { res := buffer.Get(pos) - if res != el { - panic(fmt.Sprintf("Ring buffer returned incorrect value in position %v: expected %+v, found %+v", pos, el, res)) - } + require.Equal(t, el, res) } if len(naiveBuffer) > 0 { - if buffer.GetFirst() != naiveBuffer[0] { - panic(fmt.Sprintf("Ring buffer returned incorrect value of the first element: expected %+v, found %+v", naiveBuffer[0], buffer.GetFirst())) - } - if buffer.GetLast() != naiveBuffer[len(naiveBuffer)-1] { - panic(fmt.Sprintf("Ring buffer returned incorrect value of the last element: expected %+v, found %+v", naiveBuffer[len(naiveBuffer)-1], buffer.GetLast())) - } + require.Equal(t, naiveBuffer[0], buffer.GetFirst()) + require.Equal(t, naiveBuffer[len(naiveBuffer)-1], buffer.GetLast()) } } } @@ -68,3 +63,54 @@ func TestRingBuffer(t *testing.T) { testRingBuffer(t, count) } } + +func TestRingBufferCapacity(t *testing.T) { + var b Buffer + + require.Panics(t, func() { b.Reserve(-1) }) + require.Equal(t, 0, b.Len()) + require.Equal(t, 0, b.Cap()) + + b.Reserve(0) + require.Equal(t, 0, b.Len()) + require.Equal(t, 0, b.Cap()) + + b.AddFirst("a") + require.Equal(t, 1, b.Len()) + require.Equal(t, 1, b.Cap()) + require.Panics(t, func() { b.Reserve(0) }) + require.Equal(t, 1, b.Len()) + require.Equal(t, 1, b.Cap()) + b.Reserve(1) + require.Equal(t, 1, b.Len()) + require.Equal(t, 1, b.Cap()) + b.Reserve(2) + require.Equal(t, 1, b.Len()) + require.Equal(t, 2, b.Cap()) + + b.AddLast("z") + require.Equal(t, 2, b.Len()) + require.Equal(t, 2, b.Cap()) + require.Panics(t, func() { b.Reserve(1) }) + require.Equal(t, 2, b.Len()) + require.Equal(t, 2, b.Cap()) + b.Reserve(2) + require.Equal(t, 2, b.Len()) + require.Equal(t, 2, b.Cap()) + b.Reserve(9) + require.Equal(t, 2, b.Len()) + require.Equal(t, 9, b.Cap()) + + b.RemoveFirst() + require.Equal(t, 1, b.Len()) + require.Equal(t, 9, b.Cap()) + b.Reserve(1) + require.Equal(t, 1, b.Len()) + require.Equal(t, 9, b.Cap()) + b.RemoveLast() + require.Equal(t, 0, b.Len()) + require.Equal(t, 9, b.Cap()) + b.Reserve(0) + require.Equal(t, 0, b.Len()) + require.Equal(t, 9, b.Cap()) +}