Skip to content

Commit

Permalink
Merge pull request #359 from tdakkota/refactor/multicore-decrypting
Browse files Browse the repository at this point in the history
refactor(mtproto): add worker-based message decrypting
  • Loading branch information
ernado authored May 28, 2021
2 parents b861cd9 + 6269362 commit a4a8c69
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 77 deletions.
47 changes: 47 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
name: Benchmark
on:
push:
tags:
- v*
branches:
- main
pull_request:
workflow_dispatch:

jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout latest code
uses: actions/checkout@v2.3.4

- name: Install Go
uses: actions/setup-go@v2.1.3
with:
go-version: 1.16

- name: Get Go environment
id: go-env
run: echo "::set-output name=modcache::$(go env GOMODCACHE)"
- name: Set up cache
uses: actions/cache@v2.1.5
with:
path: ${{ steps.go-env.outputs.modcache }}
key: benchmark-${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
benchmark-${{ runner.os }}-go-
- name: Download dependencies
run: go mod download && go mod tidy

# Run all benchmarks.
- name: Run tests
run: go test -v -bench . -run ^$ ./... | tee benchmark.txt

- name: Upload artifact
uses: actions/upload-artifact@v2.2.3
with:
name: benchmark-result
path: benchmark.txt
if-no-files-found: error
retention-days: 1
5 changes: 1 addition & 4 deletions internal/mtproto/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ type Conn struct {
pingInterval time.Duration

readConcurrency int
messages chan *crypto.EncryptedMessageData
gotSession *tdsync.Ready

dialTimeout time.Duration
Expand Down Expand Up @@ -135,7 +134,6 @@ func New(dialer Dialer, opt Options) *Conn {
pingInterval: opt.PingInterval,

readConcurrency: opt.ReadConcurrency,
messages: make(chan *crypto.EncryptedMessageData, opt.ReadConcurrency),
gotSession: tdsync.NewReady(),

rpc: opt.engine,
Expand Down Expand Up @@ -197,13 +195,12 @@ func (c *Conn) Run(ctx context.Context, f func(ctx context.Context) error) error
g := tdsync.NewLogGroup(ctx, c.log.Named("group"))
g.Go("handleClose", c.handleClose)
g.Go("pingLoop", c.pingLoop)
g.Go("readLoop", c.readLoop)
g.Go("ackLoop", c.ackLoop)
g.Go("saltsLoop", c.saltLoop)
g.Go("userCallback", f)

for i := 0; i < c.readConcurrency; i++ {
g.Go("readEncryptedMessages-"+strconv.Itoa(i), c.readEncryptedMessages)
g.Go("readLoop-"+strconv.Itoa(i), c.readLoop)
}
if err := g.Wait(); err != nil {
return xerrors.Errorf("group: %w", err)
Expand Down
6 changes: 5 additions & 1 deletion internal/mtproto/encrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package mtproto
import (
"fmt"
"testing"
"time"

"go.uber.org/zap"

"github.com/gotd/neo"
"github.com/gotd/td/bin"
"github.com/gotd/td/internal/crypto"
)
Expand Down Expand Up @@ -44,6 +46,7 @@ func BenchmarkEncryption(b *testing.B) {
rand: Zero{},
log: zap.NewNop(),
cipher: crypto.NewClientCipher(Zero{}),
clock: neo.NewTime(time.Now()),
}
for i := 0; i < 256; i++ {
c.authKey.Value[i] = byte(i)
Expand All @@ -52,7 +55,8 @@ func BenchmarkEncryption(b *testing.B) {
for _, payload := range []int{
128,
1024,
5000,
16 * 1024,
512 * 1024,
} {
b.Run(fmt.Sprintf("%d", payload), func(b *testing.B) {
benchPayload(b, c, payload)
Expand Down
2 changes: 1 addition & 1 deletion internal/mtproto/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (opt *Options) setDefaults() {
if opt.Handler == nil {
opt.Handler = nopHandler{}
}
if opt.ReadConcurrency == 0 {
if opt.ReadConcurrency < 2 {
opt.setDefaultConcurrency()
}
if opt.Cipher == nil {
Expand Down
89 changes: 39 additions & 50 deletions internal/mtproto/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,7 @@ func checkMessageID(now time.Time, rawID int64) error {
return nil
}

func (c *Conn) read(ctx context.Context, b *bin.Buffer) (*crypto.EncryptedMessageData, error) {
b.Reset()
if err := c.conn.Recv(ctx, b); err != nil {
return nil, err
}

func (c *Conn) decryptMessage(b *bin.Buffer) (*crypto.EncryptedMessageData, error) {
session := c.session()
msg, err := c.cipher.DecryptFromBuffer(session.Key, b)
if err != nil {
Expand All @@ -74,6 +69,36 @@ func (c *Conn) read(ctx context.Context, b *bin.Buffer) (*crypto.EncryptedMessag
return msg, nil
}

func (c *Conn) consumeMessage(ctx context.Context, buf *bin.Buffer) error {
msg, err := c.decryptMessage(buf)
if errors.Is(err, errRejected) {
c.log.Warn("Ignoring rejected message", zap.Error(err))
return nil
}
if err != nil {
return xerrors.Errorf("consume message: %w", err)
}

if err := c.handleMessage(msg.MessageID, &bin.Buffer{Buf: msg.Data()}); err != nil {
// Probably we can return here, but this will shutdown whole
// connection which can be unexpected.
c.log.Warn("Error while handling message", zap.Error(err))
// Sending acknowledge even on error. Client should restore
// from missing updates via explicit pts check and getDiff call.
}

needAck := (msg.SeqNo & 0x01) != 0
if needAck {
select {
case <-ctx.Done():
return ctx.Err()
case c.ackSendChan <- msg.MessageID:
}
}

return nil
}

func (c *Conn) noUpdates(err error) bool {
// Checking for read timeout.
var syscall *net.OpError
Expand Down Expand Up @@ -106,6 +131,7 @@ func (c *Conn) handleAuthKeyNotFound(ctx context.Context) error {

func (c *Conn) readLoop(ctx context.Context) (err error) {
b := new(bin.Buffer)

log := c.log.Named("read")
log.Debug("Read loop started")
defer func() {
Expand All @@ -115,21 +141,17 @@ func (c *Conn) readLoop(ctx context.Context) (err error) {
}
l.Debug("Read loop done")
}()
defer close(c.messages)

for {
msg, err := c.read(ctx, b)
if errors.Is(err, errRejected) {
c.log.Warn("Ignoring rejected message", zap.Error(err))
continue
}
b.Reset()

err = c.conn.Recv(ctx, b)
if err == nil {
select {
case <-ctx.Done():
return ctx.Err()
case c.messages <- msg:
continue
if err := c.consumeMessage(ctx, b); err != nil {
return err
}

continue
}

select {
Expand Down Expand Up @@ -157,36 +179,3 @@ func (c *Conn) readLoop(ctx context.Context) (err error) {
}
}
}

func (c *Conn) readEncryptedMessages(ctx context.Context) error {
b := new(bin.Buffer)
for {
select {
case <-ctx.Done():
return ctx.Err()
case msg, ok := <-c.messages:
if !ok {
return nil
}
b.ResetTo(msg.Data())

if err := c.handleMessage(msg.MessageID, b); err != nil {
// Probably we can return here, but this will shutdown whole
// connection which can be unexpected.
c.log.Warn("Error while handling message", zap.Error(err))
// Sending acknowledge even on error. Client should restore
// from missing updates via explicit pts check and getDiff call.
}

needAck := (msg.SeqNo & 0x01) != 0
if needAck {
select {
case <-ctx.Done():
return ctx.Err()
case c.ackSendChan <- msg.MessageID:
continue
}
}
}
}
}
30 changes: 18 additions & 12 deletions internal/mtproto/read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package mtproto
import (
"context"
"crypto/rand"
"errors"
"fmt"
"io"
"runtime"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/zap"

"github.com/gotd/neo"
Expand Down Expand Up @@ -74,24 +75,31 @@ func (n noopBuf) Consume(id int64) bool {

type constantConn struct {
data []byte
counter *atomic.Int64
cancel context.CancelFunc
counter int
mux sync.Mutex
}

func (c constantConn) Send(ctx context.Context, b *bin.Buffer) error {
func (c *constantConn) Send(ctx context.Context, b *bin.Buffer) error {
return nil
}

func (c constantConn) Recv(ctx context.Context, b *bin.Buffer) error {
if c.counter.Dec() == 0 {
func (c *constantConn) Recv(ctx context.Context, b *bin.Buffer) error {
c.mux.Lock()
exit := c.counter == 0
if exit {
c.mux.Unlock()
c.cancel()
return ctx.Err()
return errors.New("error")
}
c.counter--
c.mux.Unlock()

b.Put(c.data)
return nil
}

func (c constantConn) Close() error {
func (c *constantConn) Close() error {
return nil
}

Expand Down Expand Up @@ -132,10 +140,10 @@ func benchRead(payloadSize int) func(b *testing.B) {
defer cancel()

conn := Conn{
conn: constantConn{
conn: &constantConn{
data: msg.Raw(),
counter: atomic.NewInt64(int64(b.N)),
cancel: cancel,
counter: b.N,
},
handler: nopHandler{},
clock: c,
Expand All @@ -146,17 +154,15 @@ func benchRead(payloadSize int) func(b *testing.B) {
messageIDBuf: noopBuf{},
authKey: authKey,
readConcurrency: procs,
messages: make(chan *crypto.EncryptedMessageData, procs),
}
grp := tdsync.NewCancellableGroup(ctx)

b.ResetTimer()
b.ReportAllocs()
b.SetBytes(int64(payloadSize))

grp.Go(conn.readLoop)
for i := 0; i < procs; i++ {
grp.Go(conn.readEncryptedMessages)
grp.Go(conn.readLoop)
}
a.ErrorIs(grp.Wait(), context.Canceled)
}
Expand Down
19 changes: 10 additions & 9 deletions telegram/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,15 +191,16 @@ func NewClient(appID int, appHash string, opt Options) *Client {
}

client.opts = mtproto.Options{
PublicKeys: opt.PublicKeys,
Random: opt.Random,
Logger: opt.Logger,
AckBatchSize: opt.AckBatchSize,
AckInterval: opt.AckInterval,
RetryInterval: opt.RetryInterval,
MaxRetries: opt.MaxRetries,
MessageID: opt.MessageID,
Clock: opt.Clock,
PublicKeys: opt.PublicKeys,
Random: opt.Random,
Logger: opt.Logger,
AckBatchSize: opt.AckBatchSize,
AckInterval: opt.AckInterval,
RetryInterval: opt.RetryInterval,
MaxRetries: opt.MaxRetries,
ReadConcurrency: opt.ReadConcurrency,
MessageID: opt.MessageID,
Clock: opt.Clock,

Types: tmap.New(
tg.TypesMap(),
Expand Down
4 changes: 4 additions & 0 deletions telegram/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ type Options struct {
AckInterval time.Duration
RetryInterval time.Duration
MaxRetries int
// ReadConcurrency is a count of workers to decrypt and decode incoming messages.
// Should be more than 2 to make effect. Otherwise ignored.
ReadConcurrency int

// Device is device config.
// Will be sent with session creation request.
Expand Down Expand Up @@ -96,6 +99,7 @@ func (opt *Options) setDefaults() {
if opt.MaxRetries == 0 {
opt.MaxRetries = 5
}
// Keep ReadConcurrency is zero, mtproto.Options will set default value.
opt.Device.SetDefaults()
if opt.Clock == nil {
opt.Clock = clock.System
Expand Down

0 comments on commit a4a8c69

Please sign in to comment.