Skip to content
This repository has been archived by the owner on Jan 28, 2021. It is now read-only.

Give vitess a chance to enforce connection timeouts if we've been waiting for a row for longer #801

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions server/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@ package server

import (
"context"
"sync"

opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go"
"github.com/src-d/go-mysql-server/sql"
"sync"
"vitess.io/vitess/go/mysql"
)

Expand Down
91 changes: 69 additions & 22 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
sqle "github.com/src-d/go-mysql-server"
"github.com/src-d/go-mysql-server/auth"
"github.com/src-d/go-mysql-server/sql"
errors "gopkg.in/src-d/go-errors.v1"
"gopkg.in/src-d/go-errors.v1"

"github.com/sirupsen/logrus"
"vitess.io/vitess/go/mysql"
Expand All @@ -21,25 +21,29 @@ import (

var regKillCmd = regexp.MustCompile(`^kill (?:(query|connection) )?(\d+)$`)

var errConnectionNotFound = errors.NewKind("Connection not found: %c")
var errConnectionNotFound = errors.NewKind("connection not found: %c")
// ErrRowTimeout will be returned if the wait for the row is longer than the connection timeout
var ErrRowTimeout = errors.NewKind("row read wait bigger than connection timeout")
juanjux marked this conversation as resolved.
Show resolved Hide resolved

// TODO parametrize
const rowsBatch = 100

// Handler is a connection handler for a SQLe engine.
type Handler struct {
mu sync.Mutex
e *sqle.Engine
sm *SessionManager
c map[uint32]*mysql.Conn
mu sync.Mutex
e *sqle.Engine
sm *SessionManager
c map[uint32]*mysql.Conn
readTimeout time.Duration
}

// NewHandler creates a new Handler given a SQLe engine.
func NewHandler(e *sqle.Engine, sm *SessionManager) *Handler {
func NewHandler(e *sqle.Engine, sm *SessionManager, rt time.Duration) *Handler {
return &Handler{
e: e,
sm: sm,
c: make(map[uint32]*mysql.Conn),
e: e,
sm: sm,
c: make(map[uint32]*mysql.Conn),
readTimeout: rt,
}
}

Expand Down Expand Up @@ -103,6 +107,42 @@ func (h *Handler) ComQuery(

var r *sqltypes.Result
var proccesedAtLeastOneBatch bool

rowchan := make(chan sql.Row)
errchan := make(chan error)
quit := make(chan struct{})

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] You could put defer close(quit) here, instead of calling it in multiple places below. As far as I can see, anytime this function returns, you'd want quit closed anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer to make it explicit so the code reader can see where we are actively signaling the goroutine.


// This goroutine will be select{}ed giving a chance to Vitess to call the
// handler.CloseConnection callback and enforcing the timeout if configured
go func() {
for {
select {
case <-quit:
return
default:
}
row, err := rows.Next()
if err != nil {
errchan <- err
return
}
rowchan <- row
}
}()

// Default waitTime is one 1 minute if there is not timeout configured, in which case
// it will loop to iterate again unless the socket died by the OS timeout or other problems.
// If there is a timeout, it will be enforced to ensure that Vitess has a chance to
// call Handler.CloseConnection()
waitTime := 1 * time.Minute

if h.readTimeout > 0 {
waitTime = h.readTimeout
}
timer := time.NewTimer(waitTime)
defer timer.Stop()

rowLoop:
for {
if r == nil {
r = &sqltypes.Result{Fields: schemaToFields(schema)}
Expand All @@ -115,26 +155,32 @@ func (h *Handler) ComQuery(

r = nil
proccesedAtLeastOneBatch = true

continue
}

row, err := rows.Next()
if err != nil {
select {
case err = <-errchan:
if err == io.EOF {
break
break rowLoop
}

return err
}
case row := <-rowchan:
outputRow, err := rowToSQL(schema, row)
if err != nil {
close(quit)
return err
juanjux marked this conversation as resolved.
Show resolved Hide resolved
}

outputRow, err := rowToSQL(schema, row)
if err != nil {
return err
r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
case <-timer.C:
if h.readTimeout != 0 {
juanjux marked this conversation as resolved.
Show resolved Hide resolved
// Return so Vitess can call the CloseConnection callback
close(quit)
return ErrRowTimeout.New()
}
}

r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
timer.Reset(waitTime)
}

if err := rows.Close(); err != nil {
Expand All @@ -149,6 +195,7 @@ func (h *Handler) ComQuery(
return nil
}

close(quit)
return callback(r)
}

Expand Down
73 changes: 57 additions & 16 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"
"reflect"
"testing"
"time"
"unsafe"

sqle "github.com/src-d/go-mysql-server"
Expand Down Expand Up @@ -37,22 +38,23 @@ func setupMemDB(require *require.Assertions) *sqle.Engine {
return e
}

// This session builder is used as dummy mysql Conn is not complete and
// causes panic when accessing remote address.
func testSessionBuilder(c *mysql.Conn, addr string) sql.Session {
const client = "127.0.0.1:34567"
return sql.NewSession(addr, client, c.User, c.ConnectionID)
}

func TestHandlerOutput(t *testing.T) {
// This session builder is used as dummy mysql Conn is not complete and
// causes panic when accessing remote address.
testSessionBuilder := func(c *mysql.Conn, addr string) sql.Session {
client := "127.0.0.1:34567"
return sql.NewSession(addr, client, c.User, c.ConnectionID)
}

e := setupMemDB(require.New(t))
dummyConn := &mysql.Conn{ConnectionID: 1}
handler := NewHandler(e, NewSessionManager(testSessionBuilder, opentracing.NoopTracer{}, "foo"))
handler := NewHandler(e, NewSessionManager(testSessionBuilder, opentracing.NoopTracer{}, "foo"), 0)
handler.NewConnection(dummyConn)

type exptectedValues struct {
callsToCallback int
lenLastBacth int
lenLastBatch int
lastRowsAffected uint64
}

Expand All @@ -70,7 +72,7 @@ func TestHandlerOutput(t *testing.T) {
query: "SELECT * FROM test",
expected: exptectedValues{
callsToCallback: 11,
lenLastBacth: 10,
lenLastBatch: 10,
lastRowsAffected: uint64(10),
},
},
Expand All @@ -81,7 +83,7 @@ func TestHandlerOutput(t *testing.T) {
query: "SELECT * FROM test limit 100",
expected: exptectedValues{
callsToCallback: 1,
lenLastBacth: 100,
lenLastBatch: 100,
lastRowsAffected: uint64(100),
},
},
Expand All @@ -92,7 +94,7 @@ func TestHandlerOutput(t *testing.T) {
query: "SELECT * FROM test limit 60",
expected: exptectedValues{
callsToCallback: 1,
lenLastBacth: 60,
lenLastBatch: 60,
lastRowsAffected: uint64(60),
},
},
Expand All @@ -103,7 +105,7 @@ func TestHandlerOutput(t *testing.T) {
query: "SELECT * FROM test limit 200",
expected: exptectedValues{
callsToCallback: 2,
lenLastBacth: 100,
lenLastBatch: 100,
lastRowsAffected: uint64(100),
},
},
Expand All @@ -114,7 +116,7 @@ func TestHandlerOutput(t *testing.T) {
query: "SELECT * FROM test limit 530",
expected: exptectedValues{
callsToCallback: 6,
lenLastBacth: 30,
lenLastBatch: 30,
lastRowsAffected: uint64(30),
},
},
Expand All @@ -123,18 +125,18 @@ func TestHandlerOutput(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var callsToCallback int
var lenLastBacth int
var lenLastBatch int
var lastRowsAffected uint64
err := handler.ComQuery(test.conn, test.query, func(res *sqltypes.Result) error {
callsToCallback++
lenLastBacth = len(res.Rows)
lenLastBatch = len(res.Rows)
lastRowsAffected = res.RowsAffected
return nil
})

require.NoError(t, err)
require.Equal(t, test.expected.callsToCallback, callsToCallback)
require.Equal(t, test.expected.lenLastBacth, lenLastBacth)
require.Equal(t, test.expected.lenLastBatch, lenLastBatch)
require.Equal(t, test.expected.lastRowsAffected, lastRowsAffected)

})
Expand Down Expand Up @@ -174,6 +176,7 @@ func TestHandlerKill(t *testing.T) {
opentracing.NoopTracer{},
"foo",
),
0,
)

require.Len(handler.c, 0)
Expand Down Expand Up @@ -243,3 +246,41 @@ func TestSchemaToFields(t *testing.T) {
fields := schemaToFields(schema)
require.Equal(expected, fields)
}

func TestHandlerTimeout(t *testing.T) {
require := require.New(t)

e := setupMemDB(require)
e2 := setupMemDB(require)

timeOutHandler := NewHandler(
e, NewSessionManager(testSessionBuilder, opentracing.NoopTracer{}, "foo"),
1 * time.Second)

noTimeOutHandler := NewHandler(
e2, NewSessionManager(testSessionBuilder, opentracing.NoopTracer{}, "foo"),
0)
require.Equal(1 * time.Second, timeOutHandler.readTimeout)
require.Equal(0 * time.Second, noTimeOutHandler.readTimeout)

connTimeout := newConn(1)
timeOutHandler.NewConnection(connTimeout)

connNoTimeout := newConn(2)
noTimeOutHandler.NewConnection(connNoTimeout)

err := timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result) error {
return nil
})
require.EqualError(err, "row read wait bigger than connection timeout")

err = timeOutHandler.ComQuery(connTimeout, "SELECT SLEEP(0.5)", func(res *sqltypes.Result) error {
return nil
})
require.NoError(err)

err = noTimeOutHandler.ComQuery(connNoTimeout, "SELECT SLEEP(2)", func(res *sqltypes.Result) error {
return nil
})
require.NoError(err)
}
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder) (*Server, error) {
cfg.ConnWriteTimeout = 0
}

handler := NewHandler(e, NewSessionManager(sb, tracer, cfg.Address))
handler := NewHandler(e, NewSessionManager(sb, tracer, cfg.Address), cfg.ConnReadTimeout)
a := cfg.Auth.Mysql()
l, err := mysql.NewListener(cfg.Protocol, cfg.Address, a, handler, cfg.ConnReadTimeout, cfg.ConnWriteTimeout)
if err != nil {
Expand Down