Skip to content

Commit

Permalink
vitessdriver: full distributed tx support + tests
Browse files Browse the repository at this point in the history
Signed-off-by: Derek Perkins <derek@nozzle.io>
  • Loading branch information
derekperkins committed Dec 30, 2021
1 parent 65ccf9f commit 40e8160
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 40 deletions.
92 changes: 57 additions & 35 deletions go/vt/vitessdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"google.golang.org/protobuf/proto"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"

"vitess.io/vitess/go/vt/vtgate/grpcvtgateconn"
"vitess.io/vitess/go/vt/vtgate/vtgateconn"
Expand Down Expand Up @@ -119,6 +120,13 @@ func (d drv) Open(name string) (driver.Conn, error) {

c.setDefaults()

if c.Configuration.SessionToken != "" {
c.Configuration.sessionFromToken, err = sessionTokenToSession(c.Configuration.SessionToken)
if err != nil {
return nil, err
}
}

if c.convert, err = newConverter(&c.Configuration); err != nil {
return nil, err
}
Expand Down Expand Up @@ -173,7 +181,8 @@ type Configuration struct {
DriverName string `json:"-"`

// allows for some magic
Session *vtgateconn.VTGateSession
SessionToken string
sessionFromToken *vtgatepb.Session
}

// toJSON converts Configuration to the JSON string which is required by the
Expand Down Expand Up @@ -212,8 +221,8 @@ func (c *conn) dial() error {
if err != nil {
return err
}
if c.Configuration.Session != nil {
c.session = c.Configuration.Session
if c.Configuration.sessionFromToken != nil {
c.session = c.conn.SessionFromPb(c.Configuration.sessionFromToken)
} else {
c.session = c.conn.Session(c.Target, nil)
}
Expand Down Expand Up @@ -242,46 +251,40 @@ func (c *conn) Close() error {
return nil
}

// SessionTokenFromTx serializes the session on the tx, which can be reconstituted
// into a *sql.Tx using DistributedTxFromSessionToken
func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) {
var sessionToken string

err := tx.QueryRowContext(ctx, "vt_session_token").Scan(&sessionToken)
if err != nil {
return "", err
}

return sessionToken, nil
}

// DistributedTxFromSessionToken allows users to send serialized sessions over the wire and
// reconnect to an existing transaction
func DistributedTxFromSessionToken(ctx context.Context, sessionToken string) (*sql.Tx, error) {
session, err := sessionTokenToSession(sessionToken)
if err != nil {
return nil, err
// reconnect to an existing transaction. Setting the sessionToken and address on the
// supplied configuration is the minimum required
func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, error) {
if c.SessionToken == "" {
return nil, errors.New("c.SessionToken is required")
}
if c.Address == "" {
return nil, errors.New("c.Address is required")
}

db, err := OpenWithConfiguration(Configuration{
// include session here - there will be a new *DB created each time
// that stores the session state in the &conn{} struct
Session: session,
})
db, err := OpenWithConfiguration(c)
if err != nil {
return nil, err
}

// this should return the only connection associated with the db
c, err := db.Conn(ctx)
return db.BeginTx(ctx, nil)
}

// SessionTokenFromTx serializes the sessionFromToken on the tx, which can be reconstituted
// into a *sql.Tx using DistributedTxFromSessionToken
func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) {
var sessionToken string

err := tx.QueryRowContext(ctx, "vt_session_token").Scan(&sessionToken)
if err != nil {
return nil, err
return "", err
}

return c.BeginTx(ctx, nil)
return sessionToken, nil
}

func newSessionTokenRow(session *vtgateconn.VTGateSession, c *converter) (driver.Rows, error) {
func newSessionTokenRow(session *vtgatepb.Session, c *converter) (driver.Rows, error) {
sessionToken, err := sessionToSessionToken(session)
if err != nil {
return nil, err
Expand All @@ -300,7 +303,7 @@ func newSessionTokenRow(session *vtgateconn.VTGateSession, c *converter) (driver
return newRows(&qr, c), nil
}

func sessionToSessionToken(session *vtgateconn.VTGateSession) (string, error) {
func sessionToSessionToken(session *vtgatepb.Session) (string, error) {
b, err := proto.Marshal(session)
if err != nil {
return "", err
Expand All @@ -309,13 +312,13 @@ func sessionToSessionToken(session *vtgateconn.VTGateSession) (string, error) {
return base64.StdEncoding.EncodeToString(b), nil
}

func sessionTokenToSession(sessionToken string) (*vtgateconn.VTGateSession, error) {
func sessionTokenToSession(sessionToken string) (*vtgatepb.Session, error) {
b, err := base64.StdEncoding.DecodeString(sessionToken)
if err != nil {
return nil, err
}

var session *vtgateconn.VTGateSession
session := &vtgatepb.Session{}
err = proto.Unmarshal(b, session)
if err != nil {
return nil, err
Expand All @@ -325,6 +328,11 @@ func sessionTokenToSession(sessionToken string) (*vtgateconn.VTGateSession, erro
}

func (c *conn) Begin() (driver.Tx, error) {
// if we're loading from an existing session, we need to avoid starting a new transaction
if c.Configuration.SessionToken != "" {
return c, nil
}

if _, err := c.Exec("begin", nil); err != nil {
return nil, err
}
Expand All @@ -341,11 +349,25 @@ func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, err
}

func (c *conn) Commit() error {
// if we're loading from an existing session, disallow committing/rolling back the transaction
// this isn't a technical limitation, but is enforced to prevent misuse, so that only
// the original creator of the transaction can commit/rollback
if c.Configuration.SessionToken != "" {
return errors.New("calling Commit from a distributed tx is not allowed")
}

_, err := c.Exec("commit", nil)
return err
}

func (c *conn) Rollback() error {
// if we're loading from an existing session, disallow committing/rolling back the transaction
// this isn't a technical limitation, but is enforced to prevent misuse, so that only
// the original creator of the transaction can commit/rollback
if c.Configuration.SessionToken != "" {
return errors.New("calling Rollback from a distributed tx is not allowed")
}

_, err := c.Exec("rollback", nil)
return err
}
Expand Down Expand Up @@ -407,9 +429,9 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
}

func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
// special case for serializing the current session state
// special case for serializing the current sessionFromToken state
if query == "vt_session_token" {
return newSessionTokenRow(c.session, c.convert)
return newSessionTokenRow(c.session.SessionPb(), c.convert)
}

bv, err := c.convert.bindVarsFromNamedValues(args)
Expand Down
85 changes: 85 additions & 0 deletions go/vt/vitessdriver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,88 @@ func TestTxExecStreamingNotAllowed(t *testing.T) {
t.Errorf("err: %v, does not contain %s", err, want)
}
}

func TestSessionToken(t *testing.T) {
c := Configuration{
Protocol: "grpc",
Address: testAddress,
Target: "@primary",
}

ctx := context.Background()

db, err := OpenWithConfiguration(c)
if err != nil {
t.Fatal(err)
}
defer db.Close()

tx, err := db.Begin()
if err != nil {
t.Fatal(err)
}

s, err := tx.Prepare("txRequest")
if err != nil {
t.Fatal(err)
}

_, err = s.Exec(int64(0))
if err != nil {
t.Fatal(err)
}

sessionToken, err := SessionTokenFromTx(ctx, tx)
if err != nil {
t.Fatal(err)
}

distributedTxConfig := Configuration{
Address: testAddress,
Target: "@primary",
SessionToken: sessionToken,
}

sameTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
if err != nil {
t.Fatal(err)
}

newS, err := sameTx.Prepare("distributedTxRequest")
if err != nil {
t.Fatal(err)
}

_, err = newS.Exec(int64(1))
if err != nil {
t.Fatal(err)
}

// enforce that Rollback can't be called on the distributed tx
noRollbackTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
if err != nil {
t.Fatal(err)
}

err = noRollbackTx.Rollback()
if err != nil && err.Error() != "calling Rollback from a distributed tx is not allowed" {
t.Fatal(err)
}

// enforce that Commit can't be called on the distributed tx
noCommitTx, err := DistributedTxFromSessionToken(ctx, distributedTxConfig)
if err != nil {
t.Fatal(err)
}

err = noCommitTx.Commit()
if err != nil && err.Error() != "calling Commit from a distributed tx is not allowed" {
t.Fatal(err)
}

// finally commit the original tx
err = tx.Commit()
if err != nil {
t.Fatal(err)
}
}
28 changes: 25 additions & 3 deletions go/vt/vitessdriver/fakeserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@ limitations under the License.
package vitessdriver

import (
"context"
"errors"
"fmt"
"reflect"

"context"

"google.golang.org/protobuf/proto"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/vtgate/vtgateservice"

Expand Down Expand Up @@ -228,6 +226,30 @@ var execMap = map[string]struct {
result: &sqltypes.Result{},
session: session2,
},
"distributedTxRequest": {
execQuery: &queryExecute{
SQL: "distributedTxRequest",
BindVariables: map[string]*querypb.BindVariable{
"v1": sqltypes.Int64BindVariable(1),
},
Session: &vtgatepb.Session{
InTransaction: true,
ShardSessions: []*vtgatepb.Session_ShardSession{
{
Target: &querypb.Target{
Keyspace: "ks",
Shard: "1",
TabletType: topodatapb.TabletType_PRIMARY,
},
TransactionId: 1,
},
},
TargetString: "@rdonly",
},
},
result: &sqltypes.Result{},
session: session2,
},
"begin": {
execQuery: &queryExecute{
SQL: "begin",
Expand Down
16 changes: 14 additions & 2 deletions go/vt/vtgate/vtgateconn/vtgateconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@ limitations under the License.
package vtgateconn

import (
"context"
"flag"
"fmt"

"context"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"

Expand Down Expand Up @@ -55,6 +54,19 @@ func (conn *VTGateConn) Session(targetString string, options *querypb.ExecuteOpt
}
}

// SessionPb returns the underlying proto session.
func (sn *VTGateSession) SessionPb() *vtgatepb.Session {
return sn.session
}

// SessionFromPb returns a VTGateSession based on the provided proto session.
func (conn *VTGateConn) SessionFromPb(sn *vtgatepb.Session) *VTGateSession {
return &VTGateSession{
session: sn,
impl: conn.impl,
}
}

// ResolveTransaction resolves the 2pc transaction.
func (conn *VTGateConn) ResolveTransaction(ctx context.Context, dtid string) error {
return conn.impl.ResolveTransaction(ctx, dtid)
Expand Down

0 comments on commit 40e8160

Please sign in to comment.