Skip to content

Commit

Permalink
Merge pull request #5451 from systay/last_insert_id
Browse files Browse the repository at this point in the history
Handle `last_insert_id()` and `database()` in vtgate
  • Loading branch information
sougou authored Jan 7, 2020
2 parents 32b04a8 + a40f4ce commit 219b34f
Show file tree
Hide file tree
Showing 28 changed files with 903 additions and 251 deletions.
4 changes: 0 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ require (
github.com/golang/mock v1.3.1
github.com/golang/protobuf v1.3.2
github.com/golang/snappy v0.0.0-20170215233205-553a64147049
github.com/google/btree v1.0.0 // indirect
github.com/golangci/gocyclo v0.0.0-20180528144436-0a533e8fa43d // indirect
github.com/golangci/golangci-lint v1.21.0 // indirect
github.com/golangci/revgrep v0.0.0-20180812185044-276a5c0a1039 // indirect
Expand All @@ -51,9 +50,6 @@ require (
github.com/mattn/go-runewidth v0.0.1 // indirect
github.com/minio/minio-go v0.0.0-20190131015406-c8a261de75c1
github.com/mitchellh/go-testing-interface v1.0.0 // indirect
github.com/mitchellh/mapstructure v1.1.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.1 // indirect
github.com/olekukonko/tablewriter v0.0.0-20160115111002-cca8bbc07984
github.com/opentracing-contrib/go-grpc v0.0.0-20180928155321-4b5a12d3ff02
github.com/opentracing/opentracing-go v1.1.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,8 @@ golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgw
golang.org/x/tools v0.0.0-20190719005602-e377ae9d6386/go.mod h1:jcCCGcm9btYwXyDqrUWc6MKQKKGJCWEQ3AfLSRIbEuI=
golang.org/x/tools v0.0.0-20190830154057-c17b040389b9 h1:5/jaG/gKlo3xxvUn85ReNyTlN7BvlPPsxC6sHZKjGEE=
golang.org/x/tools v0.0.0-20190830154057-c17b040389b9/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191209225234-22774f7dae43 h1:NfPq5mgc5ArFgVLCpeS4z07IoxSAqVfV/gQ5vxdgaxI=
golang.org/x/tools v0.0.0-20191209225234-22774f7dae43/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190910044552-dd2b5c81c578/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190911151314-feee8acb394c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20190930201159-7c411dea38b0/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
Expand Down
275 changes: 143 additions & 132 deletions go/vt/proto/vtgate/vtgate.pb.go

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ func TestNormalize(t *testing.T) {
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.BytesBindVariable([]byte("aa")),
},
}, {
// str val in select
in: "select 'aa' from t",
outstmt: "select :bv1 from t",
outbv: map[string]*querypb.BindVariable{
"bv1": sqltypes.BytesBindVariable([]byte("aa")),
},
}, {
// int val
in: "select * from t where v1 = 1",
Expand Down
40 changes: 40 additions & 0 deletions go/vt/vtgate/endtoend/database_func_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
Copyright 2019 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package endtoend

import (
"context"
"fmt"
"testing"

"vitess.io/vitess/go/mysql"
)

func TestDatabaseFunc(t *testing.T) {
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
if err != nil {
t.Fatal(err)
}
defer conn.Close()

exec(t, conn, "use ks")
qr := exec(t, conn, "select database()")
if got, want := fmt.Sprintf("%v", qr.Rows), `[[VARCHAR("ks")]]`; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
}
41 changes: 41 additions & 0 deletions go/vt/vtgate/endtoend/last_insert_id_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
Copyright 2019 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package endtoend

import (
"context"
"fmt"
"testing"

"vitess.io/vitess/go/mysql"
)

func TestLastInsertId(t *testing.T) {
ctx := context.Background()
conn, err := mysql.Connect(ctx, &vtParams)
if err != nil {
t.Fatal(err)
}
defer conn.Close()

exec(t, conn, "insert into t1_last_insert_id(id1) values(42)")

qr := exec(t, conn, "select last_insert_id()")
if got, want := fmt.Sprintf("%v", qr.Rows), `[[INT64(1)]]`; got != want {
t.Errorf("select:\n%v want\n%v", got, want)
}
}
16 changes: 16 additions & 0 deletions go/vt/vtgate/endtoend/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ create table t2_id4_idx(
primary key(id),
key idx_id4(id4)
) Engine=InnoDB;
create table t1_last_insert_id(
id bigint not null auto_increment,
id1 bigint,
primary key(id)
) Engine=InnoDB;
`

vschema = &vschemapb.Keyspace{
Expand Down Expand Up @@ -151,6 +157,16 @@ create table t2_id4_idx(
Type: sqltypes.VarChar,
}},
},
"t1_last_insert_id": {
ColumnVindexes: []*vschemapb.ColumnVindex{{
Column: "id1",
Name: "hash",
}},
Columns: []*vschemapb.Column{{
Name: "id1",
Type: sqltypes.Int64,
}},
},
},
}
)
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ const (
// This is used for sending different IN clause values
// to different shards.
ListVarName = "__vals"
//LastInsertIDName is a reserved bind var name for last_insert_id()
LastInsertIDName = "__lastInsertId"
//DBVarName is a reserved bind var name for database()
DBVarName = "__vtdbname"
)

// VCursor defines the interface the engine will use
Expand Down Expand Up @@ -94,6 +98,10 @@ type Plan struct {
Rows uint64 `json:",omitempty"`
// Total number of errors
Errors uint64 `json:",omitempty"`
// NeedsLastInsertID signals whether this plan will need to be provided with last_insert_id
NeedsLastInsertID bool `json:"-"` // don't include in the json representation
// NeedsDatabaseName signals whether this plan will need to be provided with the database name
NeedsDatabaseName bool `json:"-"` // don't include in the json representation
}

// AddStats updates the plan execution statistics
Expand Down
22 changes: 18 additions & 4 deletions go/vt/vtgate/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st

switch stmtType {
case sqlparser.StmtSelect:
return e.handleExec(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, dest, logStats)
return e.handleExec(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, dest, logStats, stmtType)
case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete:
safeSession := safeSession

Expand All @@ -212,7 +212,7 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st
// at the beginning, but never after.
safeSession.SetAutocommittable(mustCommit)

qr, err := e.handleExec(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, dest, logStats)
qr, err := e.handleExec(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, dest, logStats, stmtType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -247,7 +247,7 @@ func (e *Executor) execute(ctx context.Context, safeSession *SafeSession, sql st
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unrecognized statement: %s", sql)
}

func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, destKeyspace string, destTabletType topodatapb.TabletType, dest key.Destination, logStats *LogStats) (*sqltypes.Result, error) {
func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, destKeyspace string, destTabletType topodatapb.TabletType, dest key.Destination, logStats *LogStats, stmtType sqlparser.StatementType) (*sqltypes.Result, error) {
if dest != nil {
// V1 mode or V3 mode with a forced shard or range target
// TODO(sougou): change this flow to go through V3 functions
Expand Down Expand Up @@ -306,8 +306,19 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql
return nil, err
}

qr, err := plan.Instructions.Execute(vcursor, bindVars, true)
if plan.NeedsLastInsertID {
bindVars[engine.LastInsertIDName] = sqltypes.Uint64BindVariable(safeSession.GetLastInsertId())
}
if plan.NeedsDatabaseName {
keyspace, _, _, _ := e.ParseDestinationTarget(safeSession.TargetString)
if keyspace == "" {
bindVars[engine.DBVarName] = sqltypes.NullBindVariable
} else {
bindVars[engine.DBVarName] = sqltypes.StringBindVariable(keyspace)
}
}

qr, err := plan.Instructions.Execute(vcursor, bindVars, true)
logStats.ExecuteTime = time.Since(execStart)

e.updateQueryCounts(plan.Instructions.RouteType(), plan.Instructions.GetKeyspaceName(), plan.Instructions.GetTableName(), int64(logStats.ShardQueries))
Expand All @@ -318,6 +329,9 @@ func (e *Executor) handleExec(ctx context.Context, safeSession *SafeSession, sql
errCount = 1
} else {
logStats.RowsAffected = qr.RowsAffected
if qr != nil && stmtType == sqlparser.StmtInsert {
safeSession.LastInsertId = qr.InsertID
}
}

// Check if there was partial DML execution. If so, rollback the transaction.
Expand Down
17 changes: 17 additions & 0 deletions go/vt/vtgate/executor_dml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"strings"
"testing"

"github.com/stretchr/testify/require"

"golang.org/x/net/context"

"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -1858,3 +1860,18 @@ func TestDeleteEqualWithPrepare(t *testing.T) {
t.Errorf("sbclookup.Queries:\n%+v, want\n%+v\n", sbclookup.Queries, wantQueries)
}
}

func TestUpdateLastInsertID(t *testing.T) {
executor, sbc1, _, _ := createExecutorEnv()

sql := "update user set a = last_insert_id() where id = 1"
masterSession.LastInsertId = 43
_, err := executorExec(executor, sql, map[string]*querypb.BindVariable{})
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{
Sql: "update user set a = :__lastInsertId where id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */",
BindVariables: map[string]*querypb.BindVariable{"__lastInsertId": sqltypes.Uint64BindVariable(43)},
}}

require.Equal(t, wantQueries, sbc1.Queries)
}
Loading

0 comments on commit 219b34f

Please sign in to comment.