Skip to content

Commit

Permalink
[MM-54335] Batching of call join/leave operations (#704)
Browse files Browse the repository at this point in the history
* Implement websocket broadcast to specific users

* Broadcast specific events to call participants only

* Instrument code to monitor app and store performance

* go mod tidy

* Remove debug metric

* Implement setCallEnded util

* Simple batcher implementation

* Implement batch context

* Abort batch on pre cb error

* addSessionsBatchers

* removeSessionsBatchers

* Use batching only when it may be useful

* Tests

* More logs

* Disable session auth check

* callState.Clone()

* Clone call state for simple rollback in case of error
  • Loading branch information
streamer45 authored May 27, 2024
1 parent ad857c6 commit 6b3bf34
Show file tree
Hide file tree
Showing 18 changed files with 1,727 additions and 193 deletions.
4 changes: 4 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ packages:
config:
interfaces:
LoggerIFace:
github.com/mattermost/rtcd/service/rtc:
config:
interfaces:
Metrics:
2 changes: 1 addition & 1 deletion server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ func (p *Plugin) handleEndCall(w http.ResponseWriter, r *http.Request) {
}

if state.Call.EndAt == 0 {
state.Call.EndAt = time.Now().UnixMilli()
setCallEnded(&state.Call)
}

if err := p.store.UpdateCall(&state.Call); err != nil {
Expand Down
113 changes: 113 additions & 0 deletions server/batching/batcher.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) 2022-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package batching

import (
"fmt"
"time"
)

type Context map[string]any

const (
ContextBatchNumKey = "batch_num"
ContextBatchSizeKey = "batch_size"
)

type Item func(ctx Context)
type BatchCb func(ctx Context) error

type Batcher struct {
cfg Config
itemsCh chan Item
stopCh chan struct{}
doneCh chan struct{}
batches int
}

type Config struct {
// The frequency at which batches should be executed.
Interval time.Duration
// The maximum size of the queue of items.
Size int
// An optional callback to be executed before processing a batch.
// This is where expensive operations should usually be performed
// in order to make the batching efficient.
PreRunCb BatchCb
// An optional callback to be executed after processing a batch.
PostRunCb BatchCb
}

// NewBatcher creates a new Batcher with the given config.
func NewBatcher(cfg Config) (*Batcher, error) {
if cfg.Interval <= 0 {
return nil, fmt.Errorf("interval should be > 0")
}

if cfg.Size <= 0 {
return nil, fmt.Errorf("size should be > 0")
}

return &Batcher{
cfg: cfg,
itemsCh: make(chan Item, cfg.Size),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
}, nil
}

// Push adds one item into the work queue.
func (b *Batcher) Push(item Item) error {
select {
case b.itemsCh <- item:
default:
return fmt.Errorf("failed to push item, channel is full")
}

return nil
}

// Start begins the processing of batches at the configured interval. Should only be called once.
func (b *Batcher) Start() {
go func() {
defer close(b.doneCh)
ticker := time.NewTicker(b.cfg.Interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if batchSize := len(b.itemsCh); batchSize > 0 {
b.batches++

ctx := Context{
ContextBatchNumKey: b.batches,
ContextBatchSizeKey: batchSize,
}

if b.cfg.PreRunCb != nil {
if err := b.cfg.PreRunCb(ctx); err != nil {
continue
}
}

for i := 0; i < batchSize; i++ {
(<-b.itemsCh)(ctx)
}

if b.cfg.PostRunCb != nil {
_ = b.cfg.PostRunCb(ctx)
}
}
case <-b.stopCh:
return
}
}
}()
}

// Stop stops the batching process. Should only be called once.
func (b *Batcher) Stop() {
close(b.stopCh)
<-b.doneCh
}
137 changes: 137 additions & 0 deletions server/batching/batcher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright (c) 2022-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package batching

import (
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestBatcher(t *testing.T) {
t.Run("NewBatcher", func(t *testing.T) {
t.Run("invalid interval", func(t *testing.T) {
b, err := NewBatcher(Config{
Size: 10,
})
require.EqualError(t, err, "interval should be > 0")
require.Nil(t, b)
})

t.Run("invalid size", func(t *testing.T) {
b, err := NewBatcher(Config{
Interval: time.Second,
})
require.EqualError(t, err, "size should be > 0")
require.Nil(t, b)
})

t.Run("valid", func(t *testing.T) {
b, err := NewBatcher(Config{
Interval: time.Second,
Size: 10,
})
require.NoError(t, err)
require.NotNil(t, b)
})
})

t.Run("basic batching", func(t *testing.T) {
b, err := NewBatcher(Config{
Interval: 10 * time.Millisecond,
Size: 10,
})
require.NoError(t, err)
require.NotNil(t, b)

b.Start()

var counter int

// Simulating some bursts of requests that need batching.
for i := 0; i < 10; i++ {
for j := 0; j < 10; j++ {
fmt.Printf("pushing %d\n", i*10+j)
err := b.Push(func(_ Context) {
fmt.Printf("executing %d\n", counter)
counter++
})
require.NoError(t, err)
}
time.Sleep(50 * time.Millisecond)
}

b.Stop()

require.Equal(t, 100, counter)
require.GreaterOrEqual(t, b.batches, 10)
})

t.Run("context aware batching", func(t *testing.T) {
preRunCb := func(ctx Context) error {
ctx["shared_state"] = 0
return nil
}

postRunCb := func(ctx Context) error {
require.Equal(t, ctx[ContextBatchSizeKey].(int), ctx["shared_state"])
return nil
}

b, err := NewBatcher(Config{
Interval: 10 * time.Millisecond,
Size: 10,
PreRunCb: preRunCb,
PostRunCb: postRunCb,
})
require.NoError(t, err)
require.NotNil(t, b)

b.Start()

// Simulating some bursts of requests that need batching.
for i := 0; i < 10; i++ {
for j := 0; j < 10; j++ {
err := b.Push(func(ctx Context) {
ctx["shared_state"] = ctx["shared_state"].(int) + 1
})
require.NoError(t, err)
}
time.Sleep(50 * time.Millisecond)
}

b.Stop()
})

t.Run("returning error", func(t *testing.T) {
b, err := NewBatcher(Config{
Interval: 10 * time.Millisecond,
Size: 100,
PreRunCb: func(_ Context) error {
return fmt.Errorf("some error")
},
})
require.NoError(t, err)
require.NotNil(t, b)

b.Start()

var counter int
// Simulating some bursts of requests that need batching.
for i := 0; i < 10; i++ {
for j := 0; j < 10; j++ {
err := b.Push(func(_ Context) {
counter++
})
require.NoError(t, err)
}
time.Sleep(50 * time.Millisecond)
}

b.Stop()
require.Zero(t, counter)
})
}
23 changes: 23 additions & 0 deletions server/db/calls_sessions_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,26 @@ func (s *Store) DeleteCallsSessions(callID string) error {

return nil
}

func (s *Store) GetCallSessionsCount(callID string, opts GetCallSessionOpts) (int, error) {
s.metrics.IncStoreOp("GetCallSessionsCount")
defer func(start time.Time) {
s.metrics.ObserveStoreMethodsTime("GetCallSessionsCount", time.Since(start).Seconds())
}(time.Now())

qb := getQueryBuilder(s.driverName).Select("COUNT(*)").
From("calls_sessions").
Where(sq.Eq{"CallID": callID})

q, args, err := qb.ToSql()
if err != nil {
return 0, fmt.Errorf("failed to prepare query: %w", err)
}

var count int
if err := s.dbXFromGetOpts(opts).Get(&count, q, args...); err != nil {
return 0, fmt.Errorf("failed to get call sessions count: %w", err)
}

return count, nil
}
43 changes: 37 additions & 6 deletions server/db/calls_sessions_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ import (
func TestCallsSessionsStore(t *testing.T) {
t.Parallel()
testStore(t, map[string]func(t *testing.T, store *Store){
"TestCreateCallSession": testCreateCallSession,
"TestDeleteCallSession": testDeleteCallSession,
"TestUpdateCallSession": testUpdateCallSession,
"TestGetCallSession": testGetCallSession,
"TestGetCallSessions": testGetCallSessions,
"TestDeleteCallsSessions": testDeleteCallsSessions,
"TestCreateCallSession": testCreateCallSession,
"TestDeleteCallSession": testDeleteCallSession,
"TestUpdateCallSession": testUpdateCallSession,
"TestGetCallSession": testGetCallSession,
"TestGetCallSessions": testGetCallSessions,
"TestDeleteCallsSessions": testDeleteCallsSessions,
"TestGetCallSessionsCount": testGetCallSessionsCount,
})
}

Expand Down Expand Up @@ -222,3 +223,33 @@ func testDeleteCallsSessions(t *testing.T, store *Store) {
require.Empty(t, sessions)
})
}

func testGetCallSessionsCount(t *testing.T, store *Store) {
t.Run("no sessions", func(t *testing.T) {
cnt, err := store.GetCallSessionsCount(model.NewId(), GetCallSessionOpts{})
require.NoError(t, err)
require.Zero(t, cnt)
})

t.Run("multiple sessions", func(t *testing.T) {
sessions := map[string]*public.CallSession{}
callID := model.NewId()
for i := 0; i < 10; i++ {
session := &public.CallSession{
ID: model.NewId(),
CallID: callID,
UserID: model.NewId(),
JoinAt: time.Now().UnixMilli(),
}

err := store.CreateCallSession(session)
require.NoError(t, err)

sessions[session.ID] = session
}

cnt, err := store.GetCallSessionsCount(callID, GetCallSessionOpts{})
require.NoError(t, err)
require.Equal(t, 10, cnt)
})
}
Loading

0 comments on commit 6b3bf34

Please sign in to comment.