diff --git a/executor/executor_test.go b/executor/executor_test.go index 351b2f9a69c2c..dc7ef6078e578 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1653,6 +1653,13 @@ func (s *testSuite) TestGeneratedColumnRead(c *C) { result = tk.MustQuery(`SELECT * FROM test_gc_read ORDER BY a`) result.Check(testkit.Rows(`10 `, `11 2 13 22`, `13 4 17 52`, `18 8 26 144`)) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec("insert into t values(18)") + tk.MustExec("update test_gc_read set a = a+1 where a in (select a from t)") + result = tk.MustQuery("select * from test_gc_read order by a") + result.Check(testkit.Rows(`10 `, `11 2 13 22`, `13 4 17 52`, `19 8 27 152`)) + // Test different types between generation expression and generated column. tk.MustExec(`CREATE TABLE test_gc_read_cast(a VARCHAR(255), b VARCHAR(255), c INT AS (JSON_EXTRACT(a, b)), d INT AS (JSON_EXTRACT(a, b)) STORED)`) tk.MustExec(`INSERT INTO test_gc_read_cast (a, b) VALUES ('{"a": "3"}', '$.a')`) diff --git a/executor/update_test.go b/executor/update_test.go index 405ac69a6faf1..3308322d0786b 100644 --- a/executor/update_test.go +++ b/executor/update_test.go @@ -227,3 +227,13 @@ func (s *testUpdateSuite) TestUpdateWithSubquery(c *C) { tk.MustExec("update t1 set status = 'N' where status = 'F' and (id in (select id from t2 where field = 'MAIN') or id2 in (select id from t2 where field = 'main'))") tk.MustQuery("select * from t1").Check(testkit.Rows("abc N abc")) } + +func (s *testUpdateSuite) TestUpdateMultiDatabaseTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop database if exists test2") + tk.MustExec("create database test2") + tk.MustExec("create table t(a int, b int generated always as (a+1) virtual)") + tk.MustExec("create table test2.t(a int, b int generated always as (a+1) virtual)") + tk.MustExec("update t, test2.t set test.t.a=1") +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 785c4f90dbde2..30ec613ada651 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2290,7 +2290,10 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { proj.SetChildren(p) p = proj } - orderedList, np, err := b.buildUpdateLists(tableList, update.List, p) + + var updateTableList []*ast.TableName + updateTableList = extractTableList(sel.From.TableRefs, updateTableList, true) + orderedList, np, err := b.buildUpdateLists(updateTableList, update.List, p) if err != nil { return nil, errors.Trace(err) } @@ -2323,8 +2326,6 @@ func (b *planBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A // If columns in set list contains generated columns, raise error. // And, fill virtualAssignments here; that's for generated columns. virtualAssignments := make([]*ast.Assignment, 0) - tableAsName := make(map[*model.TableInfo][]*model.CIStr) - extractTableAsNameForUpdate(p, tableAsName) for _, tn := range tableList { tableInfo := tn.TableInfo @@ -2340,12 +2341,10 @@ func (b *planBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A if _, ok := modifyColumns[columnFullName]; ok { return nil, nil, ErrBadGeneratedColumn.GenWithStackByArgs(colInfo.Name.O, tableInfo.Name.O) } - for _, asName := range tableAsName[tableInfo] { - virtualAssignments = append(virtualAssignments, &ast.Assignment{ - Column: &ast.ColumnName{Table: *asName, Name: colInfo.Name}, - Expr: tableVal.Cols()[i].GeneratedExpr, - }) - } + virtualAssignments = append(virtualAssignments, &ast.Assignment{ + Column: &ast.ColumnName{Schema: tn.Schema, Table: tn.Name, Name: colInfo.Name}, + Expr: tableVal.Cols()[i].GeneratedExpr, + }) } } @@ -2382,10 +2381,19 @@ func (b *planBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A p = np newList = append(newList, &expression.Assignment{Col: col, Expr: newExpr}) } + + tblDbMap := make(map[string]string, len(tableList)) + for _, tbl := range tableList { + tblDbMap[tbl.Name.L] = tbl.DBInfo.Name.L + } for _, assign := range newList { col := assign.Col dbName := col.DBName.L + // To solve issue#10028, we need to get database name by the table alias name. + if dbNameTmp, ok := tblDbMap[col.TblName.L]; ok { + dbName = dbNameTmp + } if dbName == "" { dbName = b.ctx.GetSessionVars().CurrentDB } @@ -2394,47 +2402,6 @@ func (b *planBuilder) buildUpdateLists(tableList []*ast.TableName, list []*ast.A return newList, p, nil } -// extractTableAsNameForUpdate extracts tables' alias names for update. -func extractTableAsNameForUpdate(p LogicalPlan, asNames map[*model.TableInfo][]*model.CIStr) { - switch x := p.(type) { - case *DataSource: - alias := extractTableAlias(p) - if alias != nil { - if _, ok := asNames[x.tableInfo]; !ok { - asNames[x.tableInfo] = make([]*model.CIStr, 0, 1) - } - asNames[x.tableInfo] = append(asNames[x.tableInfo], alias) - } - case *LogicalProjection: - if !x.calculateGenCols { - return - } - - ds, isDS := x.Children()[0].(*DataSource) - if !isDS { - // try to extract the DataSource below a LogicalUnionScan. - if us, isUS := x.Children()[0].(*LogicalUnionScan); isUS { - ds, isDS = us.Children()[0].(*DataSource) - } - } - if !isDS { - return - } - - alias := extractTableAlias(x) - if alias != nil { - if _, ok := asNames[ds.tableInfo]; !ok { - asNames[ds.tableInfo] = make([]*model.CIStr, 0, 1) - } - asNames[ds.tableInfo] = append(asNames[ds.tableInfo], alias) - } - default: - for _, child := range p.Children() { - extractTableAsNameForUpdate(child, asNames) - } - } -} - func (b *planBuilder) buildDelete(delete *ast.DeleteStmt) (Plan, error) { if b.pushTableHints(delete.TableHints) { // table hints are only visible in the current DELETE statement. @@ -2568,6 +2535,7 @@ func extractTableList(node ast.ResultSetNode, input []*ast.TableName, asName boo if x.AsName.L != "" && asName { newTableName := *s newTableName.Name = x.AsName + newTableName.Schema = model.NewCIStr("") input = append(input, &newTableName) } else { input = append(input, s) diff --git a/session/session_test.go b/session/session_test.go index f4da72ad9726b..418ff6e88c25e 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -2424,6 +2424,21 @@ s.a = t.a and t.c >= 1 and t.c <= 10000 and s.b !='xx';`) + // Fix issue 10028 + tk.MustExec("create database ap") + tk.MustExec("create database tp") + tk.MustExec("grant all privileges on ap.* to xxx") + tk.MustExec("grant select on tp.* to xxx") + tk.MustExec("flush privileges") + tk.MustExec("create table tp.record( id int,name varchar(128),age int)") + tk.MustExec("insert into tp.record (id,name,age) values (1,'john',18),(2,'lary',19),(3,'lily',18)") + tk.MustExec("create table ap.record( id int,name varchar(128),age int)") + tk.MustExec("insert into ap.record(id) values(1)") + c.Assert(tk1.Se.Auth(&auth.UserIdentity{Username: "xxx", Hostname: "localhost"}, + []byte(""), + []byte("")), IsTrue) + _, err2 := tk1.Exec("update ap.record t inner join tp.record tt on t.id=tt.id set t.name=tt.name") + c.Assert(err2, IsNil) } func (s *testSessionSuite) TestTxnGoString(c *C) {