diff --git a/go/vt/vitessdriver/driver.go b/go/vt/vitessdriver/driver.go index a7dd61caebf..7f8a50a4956 100644 --- a/go/vt/vitessdriver/driver.go +++ b/go/vt/vitessdriver/driver.go @@ -20,11 +20,17 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/base64" "encoding/json" "errors" + "fmt" "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" "vitess.io/vitess/go/vt/vtgate/grpcvtgateconn" "vitess.io/vitess/go/vt/vtgate/vtgateconn" ) @@ -167,6 +173,10 @@ type Configuration struct { // // Default: "vitess" DriverName string `json:"-"` + + // SessionToken is a protobuf encoded vtgatepb.Session represented as base64, which + // can be used to distribute a transaction over the wire. + SessionToken string } // toJSON converts Configuration to the JSON string which is required by the @@ -205,7 +215,15 @@ func (c *conn) dial() error { if err != nil { return err } - c.session = c.conn.Session(c.Target, nil) + if c.Configuration.SessionToken != "" { + sessionFromToken, err := sessionTokenToSession(c.Configuration.SessionToken) + if err != nil { + return err + } + c.session = c.conn.SessionFromPb(sessionFromToken) + } else { + c.session = c.conn.Session(c.Target, nil) + } return nil } @@ -231,7 +249,140 @@ func (c *conn) Close() error { return nil } +// DistributedTxFromSessionToken allows users to send serialized sessions over the wire and +// reconnect to an existing transaction. Setting the sessionToken and address on the +// supplied configuration is the minimum required +// WARNING: the original Tx must already have already done work on all shards to be affected, +// otherwise the ShardSessions will not be sent through in the session token, and thus will +// never be committed in the source. The returned validation function checks to make sure that +// the new transaction work has not added any new ShardSessions. +func DistributedTxFromSessionToken(ctx context.Context, c Configuration) (*sql.Tx, func() error, error) { + if c.SessionToken == "" { + return nil, nil, errors.New("c.SessionToken is required") + } + + session, err := sessionTokenToSession(c.SessionToken) + if err != nil { + return nil, nil, err + } + + // if there isn't 1 or more shards already referenced, no work in this Tx can be committed + originalShardSessionCount := len(session.ShardSessions) + if originalShardSessionCount == 0 { + return nil, nil, errors.New("there must be at least 1 ShardSession") + } + + db, err := OpenWithConfiguration(c) + if err != nil { + return nil, nil, err + } + + // this should return the only connection associated with the db + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return nil, nil, err + } + + // this is designed to be run after all new work has been done in the tx, similar to + // where you would traditionally run a tx.Commit, to help prevent you from silently + // losing transactional data. + validationFunc := func() error { + var sessionToken string + sessionToken, err = SessionTokenFromTx(ctx, tx) + if err != nil { + return err + } + + session, err = sessionTokenToSession(sessionToken) + if err != nil { + return err + } + + if len(session.ShardSessions) > originalShardSessionCount { + return fmt.Errorf("mismatched ShardSession count: originally %d, now %d", + originalShardSessionCount, len(session.ShardSessions), + ) + } + + return nil + } + + return tx, validationFunc, nil +} + +// SessionTokenFromTx serializes the sessionFromToken on the tx, which can be reconstituted +// into a *sql.Tx using DistributedTxFromSessionToken +func SessionTokenFromTx(ctx context.Context, tx *sql.Tx) (string, error) { + var sessionToken string + + err := tx.QueryRowContext(ctx, "vt_session_token").Scan(&sessionToken) + if err != nil { + return "", err + } + + session, err := sessionTokenToSession(sessionToken) + if err != nil { + return "", err + } + + // if there isn't 1 or more shards already referenced, no work in this Tx can be committed + originalShardSessionCount := len(session.ShardSessions) + if originalShardSessionCount == 0 { + return "", errors.New("there must be at least 1 ShardSession") + } + + return sessionToken, nil +} + +func newSessionTokenRow(session *vtgatepb.Session, c *converter) (driver.Rows, error) { + sessionToken, err := sessionToSessionToken(session) + if err != nil { + return nil, err + } + + qr := sqltypes.Result{ + Fields: []*querypb.Field{{ + Name: "vt_session_token", + Type: sqltypes.VarBinary, + }}, + Rows: [][]sqltypes.Value{{ + sqltypes.NewVarBinary(sessionToken), + }}, + } + + return newRows(&qr, c), nil +} + +func sessionToSessionToken(session *vtgatepb.Session) (string, error) { + b, err := proto.Marshal(session) + if err != nil { + return "", err + } + + return base64.StdEncoding.EncodeToString(b), nil +} + +func sessionTokenToSession(sessionToken string) (*vtgatepb.Session, error) { + b, err := base64.StdEncoding.DecodeString(sessionToken) + if err != nil { + return nil, err + } + + session := &vtgatepb.Session{} + err = proto.Unmarshal(b, session) + if err != nil { + return nil, err + } + + return session, nil +} + func (c *conn) Begin() (driver.Tx, error) { + // if we're loading from an existing session, we need to avoid starting a new transaction + if c.Configuration.SessionToken != "" { + return c, nil + } + if _, err := c.Exec("begin", nil); err != nil { return nil, err } @@ -248,11 +399,25 @@ func (c *conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, err } func (c *conn) Commit() error { + // if we're loading from an existing session, disallow committing/rolling back the transaction + // this isn't a technical limitation, but is enforced to prevent misuse, so that only + // the original creator of the transaction can commit/rollback + if c.Configuration.SessionToken != "" { + return errors.New("calling Commit from a distributed tx is not allowed") + } + _, err := c.Exec("commit", nil) return err } func (c *conn) Rollback() error { + // if we're loading from an existing session, disallow committing/rolling back the transaction + // this isn't a technical limitation, but is enforced to prevent misuse, so that only + // the original creator of the transaction can commit/rollback + if c.Configuration.SessionToken != "" { + return errors.New("calling Rollback from a distributed tx is not allowed") + } + _, err := c.Exec("rollback", nil) return err } @@ -314,6 +479,11 @@ func (c *conn) Query(query string, args []driver.Value) (driver.Rows, error) { } func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // special case for serializing the current sessionFromToken state + if query == "vt_session_token" { + return newSessionTokenRow(c.session.SessionPb(), c.convert) + } + bv, err := c.convert.bindVarsFromNamedValues(args) if err != nil { return nil, err diff --git a/go/vt/vitessdriver/driver_test.go b/go/vt/vitessdriver/driver_test.go index 36b5b91f612..6af7a534de8 100644 --- a/go/vt/vitessdriver/driver_test.go +++ b/go/vt/vitessdriver/driver_test.go @@ -224,7 +224,7 @@ func TestConfigurationToJSON(t *testing.T) { Streaming: true, DefaultLocation: "Local", } - want := `{"Protocol":"some-invalid-protocol","Address":"","Target":"ks2","Streaming":true,"DefaultLocation":"Local"}` + want := `{"Protocol":"some-invalid-protocol","Address":"","Target":"ks2","Streaming":true,"DefaultLocation":"Local","SessionToken":""}` json, err := config.toJSON() if err != nil { @@ -625,3 +625,103 @@ func TestTxExecStreamingNotAllowed(t *testing.T) { t.Errorf("err: %v, does not contain %s", err, want) } } + +func TestSessionToken(t *testing.T) { + c := Configuration{ + Protocol: "grpc", + Address: testAddress, + Target: "@primary", + } + + ctx := context.Background() + + db, err := OpenWithConfiguration(c) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + s, err := tx.Prepare("txRequest") + if err != nil { + t.Fatal(err) + } + + _, err = s.Exec(int64(0)) + if err != nil { + t.Fatal(err) + } + + sessionToken, err := SessionTokenFromTx(ctx, tx) + if err != nil { + t.Fatal(err) + } + + distributedTxConfig := Configuration{ + Address: testAddress, + Target: "@primary", + SessionToken: sessionToken, + } + + sameTx, sameValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + if err != nil { + t.Fatal(err) + } + + newS, err := sameTx.Prepare("distributedTxRequest") + if err != nil { + t.Fatal(err) + } + + _, err = newS.Exec(int64(1)) + if err != nil { + t.Fatal(err) + } + + err = sameValidationFunc() + if err != nil { + t.Fatal(err) + } + + // enforce that Rollback can't be called on the distributed tx + noRollbackTx, noRollbackValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + if err != nil { + t.Fatal(err) + } + + err = noRollbackValidationFunc() + if err != nil { + t.Fatal(err) + } + + err = noRollbackTx.Rollback() + if err == nil || err.Error() != "calling Rollback from a distributed tx is not allowed" { + t.Fatal(err) + } + + // enforce that Commit can't be called on the distributed tx + noCommitTx, noCommitValidationFunc, err := DistributedTxFromSessionToken(ctx, distributedTxConfig) + if err != nil { + t.Fatal(err) + } + + err = noCommitValidationFunc() + if err != nil { + t.Fatal(err) + } + + err = noCommitTx.Commit() + if err == nil || err.Error() != "calling Commit from a distributed tx is not allowed" { + t.Fatal(err) + } + + // finally commit the original tx + err = tx.Commit() + if err != nil { + t.Fatal(err) + } +} diff --git a/go/vt/vitessdriver/fakeserver_test.go b/go/vt/vitessdriver/fakeserver_test.go index 77591605e0a..eefa2abd285 100644 --- a/go/vt/vitessdriver/fakeserver_test.go +++ b/go/vt/vitessdriver/fakeserver_test.go @@ -17,21 +17,19 @@ limitations under the License. package vitessdriver import ( + "context" "errors" "fmt" "reflect" - "context" - "google.golang.org/protobuf/proto" "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/vtgate/vtgateservice" - binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + "vitess.io/vitess/go/vt/vtgate/vtgateservice" ) // fakeVTGateService has the server side of this fake @@ -228,6 +226,30 @@ var execMap = map[string]struct { result: &sqltypes.Result{}, session: session2, }, + "distributedTxRequest": { + execQuery: &queryExecute{ + SQL: "distributedTxRequest", + BindVariables: map[string]*querypb.BindVariable{ + "v1": sqltypes.Int64BindVariable(1), + }, + Session: &vtgatepb.Session{ + InTransaction: true, + ShardSessions: []*vtgatepb.Session_ShardSession{ + { + Target: &querypb.Target{ + Keyspace: "ks", + Shard: "1", + TabletType: topodatapb.TabletType_PRIMARY, + }, + TransactionId: 1, + }, + }, + TargetString: "@rdonly", + }, + }, + result: &sqltypes.Result{}, + session: session2, + }, "begin": { execQuery: &queryExecute{ SQL: "begin", diff --git a/go/vt/vtgate/vtgateconn/vtgateconn.go b/go/vt/vtgate/vtgateconn/vtgateconn.go index f216e12e6cf..6483526aabb 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn.go @@ -17,11 +17,10 @@ limitations under the License. package vtgateconn import ( + "context" "flag" "fmt" - "context" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" @@ -55,6 +54,19 @@ func (conn *VTGateConn) Session(targetString string, options *querypb.ExecuteOpt } } +// SessionPb returns the underlying proto session. +func (sn *VTGateSession) SessionPb() *vtgatepb.Session { + return sn.session +} + +// SessionFromPb returns a VTGateSession based on the provided proto session. +func (conn *VTGateConn) SessionFromPb(sn *vtgatepb.Session) *VTGateSession { + return &VTGateSession{ + session: sn, + impl: conn.impl, + } +} + // ResolveTransaction resolves the 2pc transaction. func (conn *VTGateConn) ResolveTransaction(ctx context.Context, dtid string) error { return conn.impl.ResolveTransaction(ctx, dtid)