diff --git a/go/test/endtoend/vtgate/foreignkey/fk_test.go b/go/test/endtoend/vtgate/foreignkey/fk_test.go index 147bd6f4d83..eb4ba14580e 100644 --- a/go/test/endtoend/vtgate/foreignkey/fk_test.go +++ b/go/test/endtoend/vtgate/foreignkey/fk_test.go @@ -136,10 +136,11 @@ func TestUpdateWithFK(t *testing.T) { utils.Exec(t, conn, `insert into u_t2(id, col2) values (342, 123), (19, 1234)`) utils.Exec(t, conn, `insert into u_t3(id, col3) values (32, 123), (1, 12)`) - t.Run("Cascade update with a new value", func(t *testing.T) { - t.Skip("This doesn't work right now. We are able to only cascade updates for which the data already exists in the parent table") - _ = utils.Exec(t, conn, `update u_t1 set col1 = 2 where id = 100`) - }) + // Cascade update with a new value + _ = utils.Exec(t, conn, `update u_t1 set col1 = 2 where id = 100`) + // Verify the result in u_t2 and u_t3 as well. + utils.AssertMatches(t, conn, `select * from u_t2 order by id`, `[[INT64(19) INT64(1234)] [INT64(342) NULL]]`) + utils.AssertMatches(t, conn, `select * from u_t3 order by id`, `[[INT64(1) INT64(12)] [INT64(32) INT64(2)]]`) // Update u_t1 which has a foreign key constraint to u_t2 with SET NULL type, and to u_t3 with CASCADE type. qr = utils.Exec(t, conn, `update u_t1 set col1 = 13 where id = 100`) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index c4d0bdd3199..21eb2bed982 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2532,3 +2532,14 @@ func (v *visitor) visitAllSelects(in SelectStatement, f func(p *Select, idx int) } panic("switch should be exhaustive") } + +func IsNonLiteral(updExprs UpdateExprs) bool { + for _, updateExpr := range updExprs { + switch updateExpr.Expr.(type) { + case *Argument, *NullVal, BoolVal, *Literal: + default: + return true + } + } + return false +} diff --git a/go/vt/vterrors/code.go b/go/vt/vterrors/code.go index aeddacc4a1e..b2f83b898b8 100644 --- a/go/vt/vterrors/code.go +++ b/go/vt/vterrors/code.go @@ -82,7 +82,6 @@ var ( VT12001 = errorWithoutState("VT12001", vtrpcpb.Code_UNIMPLEMENTED, "unsupported: %s", "This statement is unsupported by Vitess. Please rewrite your query to use supported syntax.") VT12002 = errorWithoutState("VT12002", vtrpcpb.Code_UNIMPLEMENTED, "unsupported: cross-shard foreign keys", "Vitess does not support cross shard foreign keys.") - VT12003 = errorWithoutState("VT12002", vtrpcpb.Code_UNIMPLEMENTED, "unsupported: foreign keys management at vitess", "Vitess does not support managing foreign keys tables.") // VT13001 General Error VT13001 = errorWithoutState("VT13001", vtrpcpb.Code_INTERNAL, "[BUG] %s", "This error should not happen and is a bug. Please file an issue on GitHub: https://github.com/vitessio/vitess/issues/new/choose.") @@ -148,7 +147,6 @@ var ( VT10001, VT12001, VT12002, - VT12003, VT13001, VT13002, VT14001, diff --git a/go/vt/vtgate/engine/cached_size.go b/go/vt/vtgate/engine/cached_size.go index 8ef69bfb89e..dcaefd270ed 100644 --- a/go/vt/vtgate/engine/cached_size.go +++ b/go/vt/vtgate/engine/cached_size.go @@ -294,43 +294,6 @@ func (cached *FkChild) CachedSize(alloc bool) int64 { } return size } -func (cached *FkParent) CachedSize(alloc bool) int64 { - if cached == nil { - return int64(0) - } - size := int64(0) - if alloc { - size += int64(80) - } - // field Values []vitess.io/vitess/go/vt/sqlparser.Exprs - { - size += hack.RuntimeAllocSize(int64(cap(cached.Values)) * int64(24)) - for _, elem := range cached.Values { - { - size += hack.RuntimeAllocSize(int64(cap(elem)) * int64(16)) - for _, elem := range elem { - if cc, ok := elem.(cachedObject); ok { - size += cc.CachedSize(true) - } - } - } - } - } - // field Cols []vitess.io/vitess/go/vt/vtgate/engine.CheckCol - { - size += hack.RuntimeAllocSize(int64(cap(cached.Cols)) * int64(22)) - for _, elem := range cached.Cols { - size += elem.CachedSize(false) - } - } - // field BvName string - size += hack.RuntimeAllocSize(int64(len(cached.BvName))) - // field Exec vitess.io/vitess/go/vt/vtgate/engine.Primitive - if cc, ok := cached.Exec.(cachedObject); ok { - size += cc.CachedSize(true) - } - return size -} func (cached *FkVerify) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) @@ -339,7 +302,7 @@ func (cached *FkVerify) CachedSize(alloc bool) int64 { if alloc { size += int64(48) } - // field Verify []*vitess.io/vitess/go/vt/vtgate/engine.FkParent + // field Verify []*vitess.io/vitess/go/vt/vtgate/engine.Verify { size += hack.RuntimeAllocSize(int64(cap(cached.Verify)) * int64(8)) for _, elem := range cached.Verify { @@ -1280,6 +1243,22 @@ func (cached *VStream) CachedSize(alloc bool) int64 { size += hack.RuntimeAllocSize(int64(len(cached.Position))) return size } +func (cached *Verify) CachedSize(alloc bool) int64 { + if cached == nil { + return int64(0) + } + size := int64(0) + if alloc { + size += int64(32) + } + // field Exec vitess.io/vitess/go/vt/vtgate/engine.Primitive + if cc, ok := cached.Exec.(cachedObject); ok { + size += cc.CachedSize(true) + } + // field Typ string + size += hack.RuntimeAllocSize(int64(len(cached.Typ))) + return size +} func (cached *VindexFunc) CachedSize(alloc bool) int64 { if cached == nil { return int64(0) diff --git a/go/vt/vtgate/engine/fk_verify.go b/go/vt/vtgate/engine/fk_verify.go index 408317bb37c..350aeec59e0 100644 --- a/go/vt/vtgate/engine/fk_verify.go +++ b/go/vt/vtgate/engine/fk_verify.go @@ -23,28 +23,30 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" ) -// FkParent is a primitive that represents a parent table foreign key constraint to verify against. -type FkParent struct { - Values []sqlparser.Exprs - Cols []CheckCol - BvName string - +// Verify contains the verification primitve and its type i.e. parent or child +type Verify struct { Exec Primitive + Typ string } // FkVerify is a primitive that verifies that the foreign key constraints in parent tables are satisfied. // It does this by executing a select distinct query on the parent table with the values that are being inserted/updated. type FkVerify struct { - Verify []*FkParent + Verify []*Verify Exec Primitive txNeeded } +// constants for verification type. +const ( + ParentVerify = "VerifyParent" + ChildVerify = "VerifyChild" +) + // RouteType implements the Primitive interface func (f *FkVerify) RouteType() string { return "FKVerify" @@ -67,35 +69,13 @@ func (f *FkVerify) GetFields(ctx context.Context, vcursor VCursor, bindVars map[ // TryExecute implements the Primitive interface func (f *FkVerify) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) { - for _, fk := range f.Verify { - pt := newProbeTable(fk.Cols) - newBv := &querypb.BindVariable{ - Type: querypb.Type_TUPLE, - } - for _, exprs := range fk.Values { - var row sqltypes.Row - var values []*querypb.Value - for _, expr := range exprs { - val, err := getValue(expr, bindVars) - if err != nil { - return nil, vterrors.Wrapf(err, "unable to get value for the expression %v", expr) - } - row = append(row, val) - values = append(values, sqltypes.ValueToProto(val)) - } - if exists, err := pt.exists(row); err != nil { - return nil, err - } else if !exists { - newBv.Values = append(newBv.Values, &querypb.Value{Type: querypb.Type_TUPLE, Values: values}) - } - } - distinctValues := len(newBv.Values) - qr, err := vcursor.ExecutePrimitive(ctx, fk.Exec, map[string]*querypb.BindVariable{fk.BvName: newBv}, wantfields) + for _, v := range f.Verify { + qr, err := vcursor.ExecutePrimitive(ctx, v.Exec, bindVars, wantfields) if err != nil { return nil, err } - if distinctValues != len(qr.Rows) { - return nil, vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.NoReferencedRow2, "Cannot add or update a child row: a foreign key constraint fails") + if len(qr.Rows) > 0 { + return nil, getError(v.Typ) } } return vcursor.ExecutePrimitive(ctx, f.Exec, bindVars, wantfields) @@ -103,42 +83,16 @@ func (f *FkVerify) TryExecute(ctx context.Context, vcursor VCursor, bindVars map // TryStreamExecute implements the Primitive interface func (f *FkVerify) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error { - for _, fk := range f.Verify { - pt := newProbeTable(fk.Cols) - newBv := &querypb.BindVariable{ - Type: querypb.Type_TUPLE, - } - for _, exprs := range fk.Values { - var row sqltypes.Row - var values []*querypb.Value - for _, expr := range exprs { - val, err := getValue(expr, bindVars) - if err != nil { - return vterrors.Wrapf(err, "unable to get value for the expression %v", expr) - } - row = append(row, val) - values = append(values, sqltypes.ValueToProto(val)) + for _, v := range f.Verify { + err := vcursor.StreamExecutePrimitive(ctx, v.Exec, bindVars, wantfields, func(qr *sqltypes.Result) error { + if len(qr.Rows) > 0 { + return getError(v.Typ) } - if exists, err := pt.exists(row); err != nil { - return err - } else if !exists { - newBv.Values = append(newBv.Values, &querypb.Value{Type: querypb.Type_TUPLE, Values: values}) - } - } - distinctValues := len(newBv.Values) - - seenRows := 0 - err := vcursor.StreamExecutePrimitive(ctx, fk.Exec, map[string]*querypb.BindVariable{fk.BvName: newBv}, wantfields, func(qr *sqltypes.Result) error { - seenRows += len(qr.Rows) return nil }) if err != nil { return err } - - if distinctValues != seenRows { - return vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.NoReferencedRow2, "Cannot add or update a child row: a foreign key constraint fails") - } } return vcursor.StreamExecutePrimitive(ctx, f.Exec, bindVars, wantfields, callback) } @@ -147,17 +101,15 @@ func (f *FkVerify) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVa func (f *FkVerify) Inputs() ([]Primitive, []map[string]any) { var inputs []Primitive var inputsMap []map[string]any - for idx, parent := range f.Verify { + for idx, v := range f.Verify { inputsMap = append(inputsMap, map[string]any{ - inputName: fmt.Sprintf("VerifyParent-%d", idx+1), - "BvName": parent.BvName, - "Cols": parent.Cols, + inputName: fmt.Sprintf("%s-%d", v.Typ, idx+1), }) - inputs = append(inputs, parent.Exec) + inputs = append(inputs, v.Exec) } inputs = append(inputs, f.Exec) inputsMap = append(inputsMap, map[string]any{ - inputName: "Child", + inputName: "PostVerify", }) return inputs, inputsMap @@ -169,24 +121,9 @@ func (f *FkVerify) description() PrimitiveDescription { var _ Primitive = (*FkVerify)(nil) -func getValue(expr sqlparser.Expr, bindVars map[string]*querypb.BindVariable) (sqltypes.Value, error) { - switch e := expr.(type) { - case *sqlparser.Literal: - return sqlparser.LiteralToValue(e) - case sqlparser.BoolVal: - b := int32(0) - if e { - b = 1 - } - return sqltypes.NewInt32(b), nil - case *sqlparser.NullVal: - return sqltypes.NULL, nil - case *sqlparser.Argument: - bv, exists := bindVars[e.Name] - if !exists { - return sqltypes.Value{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] bind variable %s missing", e.Name) - } - return sqltypes.BindVariableToValue(bv) +func getError(typ string) error { + if typ == ParentVerify { + return vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.NoReferencedRow2, "Cannot add or update a child row: a foreign key constraint fails") } - return sqltypes.Value{}, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] unexpected expression type: %T", expr) + return vterrors.NewErrorf(vtrpcpb.Code_FAILED_PRECONDITION, vterrors.RowIsReferenced2, "Cannot delete or update a parent row: a foreign key constraint fails") } diff --git a/go/vt/vtgate/engine/fk_verify_test.go b/go/vt/vtgate/engine/fk_verify_test.go index 5821510f93d..5635a32bc2c 100644 --- a/go/vt/vtgate/engine/fk_verify_test.go +++ b/go/vt/vtgate/engine/fk_verify_test.go @@ -22,16 +22,21 @@ import ( "github.com/stretchr/testify/require" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/vindexes" ) func TestFKVerifyUpdate(t *testing.T) { verifyP := &Route{ - Query: "select distinct cola, colb from parent where (cola, colb) in ::__vals", + Query: "select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null", + RoutingParameters: &RoutingParameters{ + Opcode: Unsharded, + Keyspace: &vindexes.Keyspace{Name: "ks"}, + }, + } + verifyC := &Route{ + Query: "select 1 from grandchild g join child c on g.cola = c.cola and g.colb = c.colb where c.foo = 48", RoutingParameters: &RoutingParameters{ Opcode: Unsharded, Keyspace: &vindexes.Keyspace{Name: "ks"}, @@ -47,29 +52,19 @@ func TestFKVerifyUpdate(t *testing.T) { }, } fkc := &FkVerify{ - Verify: []*FkParent{ - { - Values: []sqlparser.Exprs{{sqlparser.NewIntLiteral("1"), sqlparser.NewStrLiteral("a")}}, - Cols: []CheckCol{ - {Col: 0, Type: sqltypes.Int64, Collation: collations.CollationBinaryID}, - {Col: 1, Type: sqltypes.VarChar, Collation: collations.CollationUtf8mb4ID}, - }, - BvName: "__vals", - Exec: verifyP, - }, - }, - Exec: childP, + Verify: []*Verify{{Exec: verifyP, Typ: ParentVerify}}, + Exec: childP, } t.Run("foreign key verification success", func(t *testing.T) { - fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("cola|colb", "int64|varchar"), "1|a") + fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64")) vc := newDMLTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.NoError(t, err) vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `ExecuteMultiShard ks.0: select distinct cola, colb from parent where (cola, colb) in ::__vals {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}}} false false`, + `ExecuteMultiShard ks.0: select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null {} false false`, `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: update child set cola = 1, colb = 'a' where foo = 48 {} true true`, }) @@ -79,22 +74,22 @@ func TestFKVerifyUpdate(t *testing.T) { require.NoError(t, err) vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select distinct cola, colb from parent where (cola, colb) in ::__vals ks.0: {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}}} `, + `StreamExecuteMulti select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null ks.0: {} `, `ResolveDestinations ks [] Destinations:DestinationAllShards()`, `ExecuteMultiShard ks.0: update child set cola = 1, colb = 'a' where foo = 48 {} true true`, }) }) - t.Run("foreign key verification failure", func(t *testing.T) { + t.Run("parent foreign key verification failure", func(t *testing.T) { // No results from select, should cause the foreign key verification to fail. - fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("cola|colb", "int64|varchar")) + fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64"), "1", "1", "1") vc := newDMLTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) require.ErrorContains(t, err, "Cannot add or update a child row: a foreign key constraint fails") vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `ExecuteMultiShard ks.0: select distinct cola, colb from parent where (cola, colb) in ::__vals {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}}} false false`, + `ExecuteMultiShard ks.0: select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null {} false false`, }) vc.Rewind() @@ -102,86 +97,29 @@ func TestFKVerifyUpdate(t *testing.T) { require.ErrorContains(t, err, "Cannot add or update a child row: a foreign key constraint fails") vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select distinct cola, colb from parent where (cola, colb) in ::__vals ks.0: {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}}} `, - }) - }) -} - -// TestFKVerifyInsert tests the functionality of FkVerify Primitive to verify the validation of the foreign key constraints when executing an insert. -func TestFKVerifyInsert(t *testing.T) { - verifyP := &Route{ - Query: "select distinct cola, colb from parent where (cola, colb) in ::__vals", - RoutingParameters: &RoutingParameters{ - Opcode: Unsharded, - Keyspace: &vindexes.Keyspace{Name: "ks"}, - }, - } - childP := &Insert{ - Opcode: InsertUnsharded, - Query: "insert into child (cola, colb, x) values (1, 'a', 1), (2, 'b', 2), (1, 'a', 3),", - Keyspace: &vindexes.Keyspace{Name: "ks"}, - } - fkc := &FkVerify{ - Verify: []*FkParent{ - { - Values: []sqlparser.Exprs{ - {sqlparser.NewIntLiteral("1"), sqlparser.NewStrLiteral("a")}, - {sqlparser.NewIntLiteral("2"), sqlparser.NewStrLiteral("b")}, - {sqlparser.NewIntLiteral("1"), sqlparser.NewStrLiteral("a")}, - }, - Cols: []CheckCol{ - {Col: 0, Type: sqltypes.Int64, Collation: collations.CollationBinaryID}, - {Col: 1, Type: sqltypes.VarChar, Collation: collations.CollationUtf8mb4ID}, - }, - BvName: "__vals", - Exec: verifyP, - }, - }, - Exec: childP, - } - - t.Run("foreign key verification success", func(t *testing.T) { - fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("cola|colb", "int64|varchar"), "1|a", "2|b") - vc := newDMLTestVCursor("0") - vc.results = []*sqltypes.Result{fakeRes} - _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) - require.NoError(t, err) - vc.ExpectLog(t, []string{ - `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `ExecuteMultiShard ks.0: select distinct cola, colb from parent where (cola, colb) in ::__vals {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}} values:{type:TUPLE values:{type:INT64 value:"2"} values:{type:VARCHAR value:"b"}}} false false`, - `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `ExecuteMultiShard ks.0: insert into child (cola, colb, x) values (1, 'a', 1), (2, 'b', 2), (1, 'a', 3), {} true true`, - }) - - vc.Rewind() - err = fkc.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { return nil }) - require.NoError(t, err) - vc.ExpectLog(t, []string{ - `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select distinct cola, colb from parent where (cola, colb) in ::__vals ks.0: {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}} values:{type:TUPLE values:{type:INT64 value:"2"} values:{type:VARCHAR value:"b"}}} `, - `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `ExecuteMultiShard ks.0: insert into child (cola, colb, x) values (1, 'a', 1), (2, 'b', 2), (1, 'a', 3), {} true true`, + `StreamExecuteMulti select 1 from child c left join parent p on p.cola = 1 and p.colb = 'a' where p.cola is null and p.colb is null ks.0: {} `, }) }) - t.Run("foreign key verification failure", func(t *testing.T) { - // Only 1 result from select, should cause the foreign key verification to fail. - fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("cola|colb", "int64|varchar"), "2|b") + fkc.Verify[0] = &Verify{Exec: verifyC, Typ: ChildVerify} + t.Run("child foreign key verification failure", func(t *testing.T) { + // No results from select, should cause the foreign key verification to fail. + fakeRes := sqltypes.MakeTestResult(sqltypes.MakeTestFields("1", "int64"), "1", "1", "1") vc := newDMLTestVCursor("0") vc.results = []*sqltypes.Result{fakeRes} _, err := fkc.TryExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true) - require.ErrorContains(t, err, "Cannot add or update a child row: a foreign key constraint fails") + require.ErrorContains(t, err, "Cannot delete or update a parent row: a foreign key constraint fails") vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `ExecuteMultiShard ks.0: select distinct cola, colb from parent where (cola, colb) in ::__vals {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}} values:{type:TUPLE values:{type:INT64 value:"2"} values:{type:VARCHAR value:"b"}}} false false`, + `ExecuteMultiShard ks.0: select 1 from grandchild g join child c on g.cola = c.cola and g.colb = c.colb where c.foo = 48 {} false false`, }) vc.Rewind() err = fkc.TryStreamExecute(context.Background(), vc, map[string]*querypb.BindVariable{}, true, func(result *sqltypes.Result) error { return nil }) - require.ErrorContains(t, err, "Cannot add or update a child row: a foreign key constraint fails") + require.ErrorContains(t, err, "Cannot delete or update a parent row: a foreign key constraint fails") vc.ExpectLog(t, []string{ `ResolveDestinations ks [] Destinations:DestinationAllShards()`, - `StreamExecuteMulti select distinct cola, colb from parent where (cola, colb) in ::__vals ks.0: {__vals: type:TUPLE values:{type:TUPLE values:{type:INT64 value:"1"} values:{type:VARCHAR value:"a"}} values:{type:TUPLE values:{type:INT64 value:"2"} values:{type:VARCHAR value:"b"}}} `, + `StreamExecuteMulti select 1 from grandchild g join child c on g.cola = c.cola and g.colb = c.colb where c.foo = 48 ks.0: {} `, }) }) } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index 6cfbac52e0d..b5d67c9d994 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -222,7 +222,8 @@ type ( TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error - // Inputs is a slice containing the inputs to this Primitive + // Inputs is a slice containing the inputs to this Primitive. + // The returned map has additional information about the inputs, that is used in the description. Inputs() ([]Primitive, []map[string]any) // description is the description, sans the inputs, of this Primitive. diff --git a/go/vt/vtgate/planbuilder/delete.go b/go/vt/vtgate/planbuilder/delete.go index 034d1fa020f..5fa743e7034 100644 --- a/go/vt/vtgate/planbuilder/delete.go +++ b/go/vt/vtgate/planbuilder/delete.go @@ -57,7 +57,7 @@ func gen4DeleteStmtPlanner( } if ks, tables := ctx.SemTable.SingleUnshardedKeyspace(); ks != nil { - if fkManagementNotRequired(vschema, tables) { + if fkManagementNotRequired(ctx, vschema, tables) { plan := deleteUnshardedShortcut(deleteStmt, ks, tables) plan = pushCommentDirectivesOnPlan(plan, deleteStmt) return newPlanResult(plan.Primitive(), operators.QualifiedTables(ks, tables)...), nil @@ -94,7 +94,7 @@ func gen4DeleteStmtPlanner( return newPlanResult(plan.Primitive(), operators.TablesUsed(op)...), nil } -func fkManagementNotRequired(vschema plancontext.VSchema, vTables []*vindexes.Table) bool { +func fkManagementNotRequired(ctx *plancontext.PlanningContext, vschema plancontext.VSchema, vTables []*vindexes.Table) bool { // Find the foreign key mode and check for any managed child foreign keys. for _, vTable := range vTables { ksMode, err := vschema.ForeignKeyMode(vTable.Keyspace.Name) @@ -104,7 +104,7 @@ func fkManagementNotRequired(vschema plancontext.VSchema, vTables []*vindexes.Ta if ksMode != vschemapb.Keyspace_FK_MANAGED { continue } - childFks := vTable.ChildFKsNeedsHandling(vindexes.DeleteAction) + childFks := vTable.ChildFKsNeedsHandling(ctx.VerifyAllFKs, vindexes.DeleteAction) if len(childFks) > 0 { return false } diff --git a/go/vt/vtgate/planbuilder/fk_verify.go b/go/vt/vtgate/planbuilder/fk_verify.go new file mode 100644 index 00000000000..71638f88b9b --- /dev/null +++ b/go/vt/vtgate/planbuilder/fk_verify.go @@ -0,0 +1,103 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package planbuilder + +import ( + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" + "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" +) + +var _ logicalPlan = (*fkVerify)(nil) + +type verifyLP struct { + verify logicalPlan + typ string +} + +// fkVerify is the logicalPlan for engine.FkVerify. +type fkVerify struct { + input logicalPlan + verify []*verifyLP +} + +// newFkVerify builds a new fkVerify. +func newFkVerify(input logicalPlan, verify []*verifyLP) *fkVerify { + return &fkVerify{ + input: input, + verify: verify, + } +} + +// Primitive implements the logicalPlan interface +func (fkc *fkVerify) Primitive() engine.Primitive { + var verify []*engine.Verify + for _, v := range fkc.verify { + verify = append(verify, &engine.Verify{ + Exec: v.verify.Primitive(), + Typ: v.typ, + }) + } + return &engine.FkVerify{ + Exec: fkc.input.Primitive(), + Verify: verify, + } +} + +// Wireup implements the logicalPlan interface +func (fkc *fkVerify) Wireup(ctx *plancontext.PlanningContext) error { + for _, v := range fkc.verify { + err := v.verify.Wireup(ctx) + if err != nil { + return err + } + } + return fkc.input.Wireup(ctx) +} + +// Rewrite implements the logicalPlan interface +func (fkc *fkVerify) Rewrite(inputs ...logicalPlan) error { + if len(fkc.verify) != len(inputs)-1 { + return vterrors.VT13001("fkVerify: wrong number of inputs") + } + fkc.input = inputs[0] + for i := 1; i < len(inputs); i++ { + fkc.verify[i-1].verify = inputs[i] + } + return nil +} + +// ContainsTables implements the logicalPlan interface +func (fkc *fkVerify) ContainsTables() semantics.TableSet { + return fkc.input.ContainsTables() +} + +// Inputs implements the logicalPlan interface +func (fkc *fkVerify) Inputs() []logicalPlan { + inputs := []logicalPlan{fkc.input} + for _, v := range fkc.verify { + inputs = append(inputs, v.verify) + } + return inputs +} + +// OutputColumns implements the logicalPlan interface +func (fkc *fkVerify) OutputColumns() []sqlparser.SelectExpr { + return nil +} diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index e17548c68b1..602a61ccc81 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -67,6 +67,8 @@ func transformToLogicalPlan(ctx *plancontext.PlanningContext, op ops.Operator) ( return transformDistinct(ctx, op) case *operators.FkCascade: return transformFkCascade(ctx, op) + case *operators.FkVerify: + return transformFkVerify(ctx, op) } return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToLogicalPlan)", op)) @@ -110,6 +112,33 @@ func transformFkCascade(ctx *plancontext.PlanningContext, fkc *operators.FkCasca return newFkCascade(parentLP, selLP, children), nil } +// transformFkVerify transforms a FkVerify operator into a logical plan. +func transformFkVerify(ctx *plancontext.PlanningContext, fkv *operators.FkVerify) (logicalPlan, error) { + inputLP, err := transformToLogicalPlan(ctx, fkv.Input) + if err != nil { + return nil, err + } + + // Once we have the input logical plan, we can create the primitives for the verification operators. + // For all of these, we don't need the semTable anymore. We set it to nil, to avoid using an incorrect one. + ctx.SemTable = nil + + // Go over the children and convert them to Primitives too. + var verify []*verifyLP + for _, v := range fkv.Verify { + lp, err := transformToLogicalPlan(ctx, v.Op) + if err != nil { + return nil, err + } + verify = append(verify, &verifyLP{ + verify: lp, + typ: v.Typ, + }) + } + + return newFkVerify(inputLP, verify), nil +} + func transformAggregator(ctx *plancontext.PlanningContext, op *operators.Aggregator) (logicalPlan, error) { plan, err := transformToLogicalPlan(ctx, op.Source) if err != nil { diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 15e1833703c..ba8e56b4f1c 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -47,7 +47,9 @@ func ToSQL(ctx *plancontext.PlanningContext, op ops.Operator) (sqlparser.Stateme if err != nil { return nil, nil, err } - q.sortTables() + if ctx.SemTable != nil { + q.sortTables() + } return q.stmt, q.dmlOperator, nil } diff --git a/go/vt/vtgate/planbuilder/operators/ast2op.go b/go/vt/vtgate/planbuilder/operators/ast2op.go index 8e1e76c65c6..9103b37fcb1 100644 --- a/go/vt/vtgate/planbuilder/operators/ast2op.go +++ b/go/vt/vtgate/planbuilder/operators/ast2op.go @@ -111,14 +111,19 @@ func createOperatorFromUnion(ctx *plancontext.PlanningContext, node *sqlparser.U return &Horizon{Source: union, Query: node}, nil } -func createOpFromStmt(ctx *plancontext.PlanningContext, stmt sqlparser.Statement) (ops.Operator, error) { +// createOpFromStmt creates an operator from the given statement. It takes in two additional arguments— +// 1. verifyAllFKs: For this given statement, do we need to verify validity of all the foreign keys on the vtgate level. +// 2. fkToIgnore: The foreign key constraint to specifically ignore while planning the statement. +func createOpFromStmt(ctx *plancontext.PlanningContext, stmt sqlparser.Statement, verifyAllFKs bool, fkToIgnore string) (ops.Operator, error) { newCtx, err := plancontext.CreatePlanningContext(stmt, ctx.ReservedVars, ctx.VSchema, ctx.PlannerVersion) if err != nil { return nil, err } - ctx = newCtx - return PlanQuery(ctx, stmt) + newCtx.VerifyAllFKs = verifyAllFKs + newCtx.ParentFKToIgnore = fkToIgnore + + return PlanQuery(newCtx, stmt) } func createOperatorFromInsert(ctx *plancontext.PlanningContext, ins *sqlparser.Insert) (ops.Operator, error) { @@ -151,7 +156,7 @@ func createOperatorFromInsert(ctx *plancontext.PlanningContext, ins *sqlparser.I return nil, err } if ksMode == vschemapb.Keyspace_FK_MANAGED { - parentFKs := vindexTable.ParentFKsNeedsHandling() + parentFKs := vindexTable.ParentFKsNeedsHandling(ctx.VerifyAllFKs, ctx.ParentFKToIgnore) if len(parentFKs) > 0 { return nil, vterrors.VT12002() } @@ -592,13 +597,15 @@ func isRestrict(onDelete sqlparser.ReferenceAction) bool { // createSelectionOp creates the selection operator to select the parent columns for the foreign key constraints. // The Select statement looks something like this - `SELECT FROM WHERE ` // TODO (@Harshit, @GuptaManan100): Compress the columns in the SELECT statement, if there are multiple foreign key constraints using the same columns. -func createSelectionOp(ctx *plancontext.PlanningContext, selectExprs []sqlparser.SelectExpr, tableExprs sqlparser.TableExprs, where *sqlparser.Where) (ops.Operator, error) { +func createSelectionOp(ctx *plancontext.PlanningContext, selectExprs []sqlparser.SelectExpr, tableExprs sqlparser.TableExprs, where *sqlparser.Where, limit *sqlparser.Limit) (ops.Operator, error) { selectionStmt := &sqlparser.Select{ SelectExprs: selectExprs, From: tableExprs, Where: where, + Limit: limit, } - return createOpFromStmt(ctx, selectionStmt) + // There are no foreign keys to check for a select query, so we can pass anything for verifyAllFKs and fkToIgnore. + return createOpFromStmt(ctx, selectionStmt, false /* verifyAllFKs */, "" /* fkToIgnore */) } func selectParentColumns(fk vindexes.ChildFKInfo, lastOffset int) ([]int, []sqlparser.SelectExpr) { diff --git a/go/vt/vtgate/planbuilder/operators/ast2op_test.go b/go/vt/vtgate/planbuilder/operators/ast2op_test.go index 0e5f86fc8bc..4dbcf49e80a 100644 --- a/go/vt/vtgate/planbuilder/operators/ast2op_test.go +++ b/go/vt/vtgate/planbuilder/operators/ast2op_test.go @@ -22,6 +22,7 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -36,14 +37,20 @@ func Test_fkNeedsHandlingForUpdates(t *testing.T) { Name: sqlparser.NewIdentifierCS("t2"), Keyspace: &vindexes.Keyspace{Name: "ks2"}, } + t3 := &vindexes.Table{ + Name: sqlparser.NewIdentifierCS("t3"), + Keyspace: &vindexes.Keyspace{Name: "ks"}, + } tests := []struct { - name string - updateExprs sqlparser.UpdateExprs - parentFks []vindexes.ParentFKInfo - childFks []vindexes.ChildFKInfo - parentFKsWanted []bool - childFKsWanted []bool + name string + verifyAllFks bool + parentFkToIgnore string + updateExprs sqlparser.UpdateExprs + parentFks []vindexes.ParentFKInfo + childFks []vindexes.ChildFKInfo + parentFKsWanted []bool + childFKsWanted []bool }{{ name: "No Fks filtered", updateExprs: sqlparser.UpdateExprs{ @@ -111,15 +118,74 @@ func Test_fkNeedsHandlingForUpdates(t *testing.T) { parentFks: []vindexes.ParentFKInfo{ {Table: t2, ChildColumns: sqlparser.MakeColumns("b", "a", "c")}, {Table: t2, ChildColumns: sqlparser.MakeColumns("a", "b")}, + {Table: t3, ChildColumns: sqlparser.MakeColumns("a", "b")}, + }, + parentFKsWanted: []bool{false, true, false}, + childFKsWanted: []bool{true}, + }, { + name: "Unsharded fk with verifyAllFk", + verifyAllFks: true, + updateExprs: sqlparser.UpdateExprs{ + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("a"), Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("c"), Expr: &sqlparser.NullVal{}}, + }, + childFks: []vindexes.ChildFKInfo{ + {Table: t2, ParentColumns: sqlparser.MakeColumns("a", "b", "c")}, + }, + parentFks: []vindexes.ParentFKInfo{ + {Table: t2, ChildColumns: sqlparser.MakeColumns("b", "a", "c")}, + {Table: t2, ChildColumns: sqlparser.MakeColumns("a", "b")}, + {Table: t3, ChildColumns: sqlparser.MakeColumns("a", "b")}, + {Table: t3, ChildColumns: sqlparser.MakeColumns("a", "b", "c")}, + }, + parentFKsWanted: []bool{false, true, true, false}, + childFKsWanted: []bool{true}, + }, { + name: "Mixed case", + verifyAllFks: true, + parentFkToIgnore: "ks.t1abks.t3", + updateExprs: sqlparser.UpdateExprs{ + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("a"), Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("c"), Expr: &sqlparser.NullVal{}}, + }, + childFks: []vindexes.ChildFKInfo{ + {Table: t2, ParentColumns: sqlparser.MakeColumns("a", "b", "c")}, + }, + parentFks: []vindexes.ParentFKInfo{ + {Table: t2, ChildColumns: sqlparser.MakeColumns("b", "a", "c")}, + {Table: t2, ChildColumns: sqlparser.MakeColumns("a", "b")}, + {Table: t3, ChildColumns: sqlparser.MakeColumns("a", "b")}, + {Table: t3, ChildColumns: sqlparser.MakeColumns("a", "b", "c")}, }, - parentFKsWanted: []bool{false, true}, + parentFKsWanted: []bool{false, true, false, false}, + childFKsWanted: []bool{true}, + }, { + name: "Ignore Fk specified", + parentFkToIgnore: "ks.t1aefks2.t2", + updateExprs: sqlparser.UpdateExprs{ + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("a"), Expr: sqlparser.NewIntLiteral("1")}, + &sqlparser.UpdateExpr{Name: sqlparser.NewColName("c"), Expr: &sqlparser.NullVal{}}, + }, + childFks: []vindexes.ChildFKInfo{ + {Table: t2, ParentColumns: sqlparser.MakeColumns("a", "b", "c")}, + }, + parentFks: []vindexes.ParentFKInfo{ + {Table: t2, ChildColumns: sqlparser.MakeColumns("b", "a", "c")}, + {Table: t2, ChildColumns: sqlparser.MakeColumns("a", "b")}, + {Table: t2, ChildColumns: sqlparser.MakeColumns("a", "e", "f")}, + }, + parentFKsWanted: []bool{false, true, false}, childFKsWanted: []bool{true}, }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t1.ParentForeignKeys = tt.parentFks t1.ChildForeignKeys = tt.childFks - parentFksGot, childFksGot := getFKRequirementsForUpdate(tt.updateExprs, t1) + ctx := &plancontext.PlanningContext{ + VerifyAllFKs: tt.verifyAllFks, + ParentFKToIgnore: tt.parentFkToIgnore, + } + parentFksGot, childFksGot := getFKRequirementsForUpdate(ctx, tt.updateExprs, t1) var pFks []vindexes.ParentFKInfo for idx, expected := range tt.parentFKsWanted { if expected { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_delete_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_delete_op.go index 55db837e147..796e71f05bf 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_delete_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_delete_op.go @@ -55,7 +55,7 @@ func createOperatorFromDelete(ctx *plancontext.PlanningContext, deleteStmt *sqlp return delOp, nil } - childFks := vindexTable.ChildFKsNeedsHandling(vindexes.DeleteAction) + childFks := vindexTable.ChildFKsNeedsHandling(ctx.VerifyAllFKs, vindexes.DeleteAction) // If there are no foreign key constraints, then we don't need to do anything. if len(childFks) == 0 { return delOp, nil @@ -145,7 +145,7 @@ func createFkCascadeOpForDelete(ctx *plancontext.PlanningContext, parentOp ops.O } fkChildren = append(fkChildren, fkChild) } - selectionOp, err := createSelectionOp(ctx, selectExprs, delStmt.TableExprs, delStmt.Where) + selectionOp, err := createSelectionOp(ctx, selectExprs, delStmt.TableExprs, delStmt.Where, nil) if err != nil { return nil, err } @@ -196,7 +196,8 @@ func createFkChildForDelete(ctx *plancontext.PlanningContext, fk vindexes.ChildF return nil, vterrors.VT09016() } - childOp, err := createOpFromStmt(ctx, childStmt) + // For the child statement of a DELETE query, we don't need to verify all the FKs on VTgate or ignore any foreign key explicitly. + childOp, err := createOpFromStmt(ctx, childStmt, false /* verifyAllFKs */, "" /* fkToIgnore */) if err != nil { return nil, err } diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_update_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_update_op.go index 8c358547064..f1efb5d2a0b 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_update_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_update_op.go @@ -52,26 +52,17 @@ func createOperatorFromUpdate(ctx *plancontext.PlanningContext, updStmt *sqlpars return updOp, nil } - parentFks, childFks := getFKRequirementsForUpdate(updStmt.Exprs, vindexTable) + parentFks, childFks := getFKRequirementsForUpdate(ctx, updStmt.Exprs, vindexTable) if len(childFks) == 0 && len(parentFks) == 0 { return updOp, nil } - if len(parentFks) > 0 { - return nil, vterrors.VT12003() - } - - // If there are no foreign key constraints, then we don't need to do anything. - if len(childFks) == 0 { - return updOp, nil - } - // If the delete statement has a limit, we don't support it yet. if updStmt.Limit != nil { - return nil, vterrors.VT12001("foreign keys management at vitess with limit") + return nil, vterrors.VT12001("update with limit with foreign key constraints") } - return createFKCascadeOp(ctx, updOp, updClone, childFks) + return buildFkOperator(ctx, updOp, updClone, parentFks, childFks, vindexTable) } func createUpdateOperator(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, vindexTable *vindexes.Table, qt *QueryTable, routing Routing) (ops.Operator, error) { @@ -127,9 +118,9 @@ func createUpdateOperator(ctx *plancontext.PlanningContext, updStmt *sqlparser.U // getFKRequirementsForUpdate analyzes update expressions to determine which foreign key constraints needs management at the VTGate. // It identifies parent and child foreign keys that require verification or cascade operations due to column updates. -func getFKRequirementsForUpdate(updateExprs sqlparser.UpdateExprs, vindexTable *vindexes.Table) ([]vindexes.ParentFKInfo, []vindexes.ChildFKInfo) { - parentFks := vindexTable.ParentFKsNeedsHandling() - childFks := vindexTable.ChildFKsNeedsHandling(vindexes.UpdateAction) +func getFKRequirementsForUpdate(ctx *plancontext.PlanningContext, updateExprs sqlparser.UpdateExprs, vindexTable *vindexes.Table) ([]vindexes.ParentFKInfo, []vindexes.ChildFKInfo) { + parentFks := vindexTable.ParentFKsNeedsHandling(ctx.VerifyAllFKs, ctx.ParentFKToIgnore) + childFks := vindexTable.ChildFKsNeedsHandling(ctx.VerifyAllFKs, vindexes.UpdateAction) if len(childFks) == 0 && len(parentFks) == 0 { return nil, nil } @@ -139,7 +130,7 @@ func getFKRequirementsForUpdate(updateExprs sqlparser.UpdateExprs, vindexTable * // Go over all the update expressions for _, updateExpr := range updateExprs { // Any foreign key to a child table for a column that has been updated - // will require the cascade operations to happen, so we include all such foreign keys. + // will require the cascade operations or restrict verification to happen, so we include all such foreign keys. for idx, childFk := range childFks { if childFk.ParentColumns.FindColumn(updateExpr.Name.Name) >= 0 { cFksRequired[idx] = true @@ -186,38 +177,67 @@ func getFKRequirementsForUpdate(updateExprs sqlparser.UpdateExprs, vindexTable * return pFksNeedsHandling, cFksNeedsHandling } -func createFKCascadeOp(ctx *plancontext.PlanningContext, parentOp ops.Operator, updStmt *sqlparser.Update, childFks []vindexes.ChildFKInfo) (ops.Operator, error) { - // We only support simple expressions in update queries with cascade. - for _, updateExpr := range updStmt.Exprs { - switch updateExpr.Expr.(type) { - case *sqlparser.Argument, *sqlparser.NullVal, sqlparser.BoolVal, *sqlparser.Literal: - default: - return nil, vterrors.VT12001("foreign keys management at vitess with non-literal values") +func buildFkOperator(ctx *plancontext.PlanningContext, updOp ops.Operator, updClone *sqlparser.Update, parentFks []vindexes.ParentFKInfo, childFks []vindexes.ChildFKInfo, updatedTable *vindexes.Table) (ops.Operator, error) { + // We only support simple expressions in update queries for foreign key handling. + if sqlparser.IsNonLiteral(updClone.Exprs) { + return nil, vterrors.VT12001("update expression with non-literal values with foreign key constraints") + } + + restrictChildFks, cascadeChildFks := splitChildFks(childFks) + + op, err := createFKCascadeOp(ctx, updOp, updClone, cascadeChildFks, updatedTable) + if err != nil { + return nil, err + } + + return createFKVerifyOp(ctx, op, updClone, parentFks, restrictChildFks) +} + +// splitChildFks splits the child foreign keys into restrict and cascade list as restrict is handled through Verify operator and cascade is handled through Cascade operator. +func splitChildFks(fks []vindexes.ChildFKInfo) (restrictChildFks, cascadeChildFks []vindexes.ChildFKInfo) { + for _, fk := range fks { + // Any RESTRICT type foreign keys that arrive here for 2 reasons— + // 1. cross-shard/cross-keyspace RESTRICT cases. + // 2. shard-scoped/unsharded RESTRICT cases arising because we have to validate all the foreign keys on VTGate. + if isRestrict(fk.OnUpdate) { + // For RESTRICT foreign keys, we need to verify that there are no child rows corresponding to the rows being updated. + // This is done using a FkVerify Operator. + restrictChildFks = append(restrictChildFks, fk) + } else { + // For all the other foreign keys like CASCADE, SET NULL, we have to cascade the update to the children, + // This is done by using a FkCascade Operator. + cascadeChildFks = append(cascadeChildFks, fk) } } + return +} + +func createFKCascadeOp(ctx *plancontext.PlanningContext, parentOp ops.Operator, updStmt *sqlparser.Update, childFks []vindexes.ChildFKInfo, updatedTable *vindexes.Table) (ops.Operator, error) { + if len(childFks) == 0 { + return parentOp, nil + } var fkChildren []*FkChild var selectExprs []sqlparser.SelectExpr for _, fk := range childFks { - // Any RESTRICT type foreign keys that arrive here, - // are cross-shard/cross-keyspace RESTRICT cases, which we don't currently support. + // We should have already filtered out update restrict foreign keys. if isRestrict(fk.OnUpdate) { - return nil, vterrors.VT12002() + return nil, vterrors.VT13001("ON UPDATE RESTRICT foreign keys should already be filtered") } // We need to select all the parent columns for the foreign key constraint, to use in the update of the child table. cols, exprs := selectParentColumns(fk, len(selectExprs)) selectExprs = append(selectExprs, exprs...) - fkChild, err := createFkChildForUpdate(ctx, fk, updStmt, cols) + fkChild, err := createFkChildForUpdate(ctx, fk, updStmt, cols, updatedTable) if err != nil { return nil, err } fkChildren = append(fkChildren, fkChild) } - selectionOp, err := createSelectionOp(ctx, selectExprs, updStmt.TableExprs, updStmt.Where) + selectionOp, err := createSelectionOp(ctx, selectExprs, updStmt.TableExprs, updStmt.Where, nil) if err != nil { return nil, err } @@ -230,98 +250,290 @@ func createFKCascadeOp(ctx *plancontext.PlanningContext, parentOp ops.Operator, } // createFkChildForUpdate creates the update query operator for the child table based on the foreign key constraints. -func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, updStmt *sqlparser.Update, cols []int) (*FkChild, error) { - // Reserve a bind variable name - bvName := ctx.ReservedVars.ReserveVariable(foriegnKeyContraintValues) - - // Create child update operator +func createFkChildForUpdate(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, updStmt *sqlparser.Update, cols []int, updatedTable *vindexes.Table) (*FkChild, error) { // Create a ValTuple of child column names var valTuple sqlparser.ValTuple for _, column := range fk.ChildColumns { valTuple = append(valTuple, sqlparser.NewColName(column.String())) } + // Reserve a bind variable name + bvName := ctx.ReservedVars.ReserveVariable(foriegnKeyContraintValues) // Create a comparison expression for WHERE clause compExpr := sqlparser.NewComparisonExpr(sqlparser.InOp, valTuple, sqlparser.NewListArg(bvName), nil) - - // Populate the update expressions and the where clause for the child update query based on the foreign key constraint type. var childWhereExpr sqlparser.Expr = compExpr - var childUpdateExprs sqlparser.UpdateExprs + var childOp ops.Operator + var err error switch fk.OnUpdate { case sqlparser.Cascade: - // For CASCADE type constraint, the query looks like this - - // `UPDATE SET WHERE IN ()` + childOp, err = buildChildUpdOpForCascade(ctx, fk, updStmt, childWhereExpr, updatedTable) + case sqlparser.SetNull: + childOp, err = buildChildUpdOpForSetNull(ctx, fk, updStmt, childWhereExpr, valTuple) + case sqlparser.SetDefault: + return nil, vterrors.VT09016() + } + if err != nil { + return nil, err + } + + return &FkChild{ + BVName: bvName, + Cols: cols, + Op: childOp, + }, nil +} + +// buildChildUpdOpForCascade builds the child update statement operator for the CASCADE type foreign key constraint. +// The query looks like this - +// +// `UPDATE SET WHERE IN ()` +func buildChildUpdOpForCascade(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, updStmt *sqlparser.Update, childWhereExpr sqlparser.Expr, updatedTable *vindexes.Table) (ops.Operator, error) { + // The update expressions are the same as the update expressions in the parent update query + // with the column names replaced with the child column names. + var childUpdateExprs sqlparser.UpdateExprs + for _, updateExpr := range updStmt.Exprs { + colIdx := fk.ParentColumns.FindColumn(updateExpr.Name.Name) + if colIdx == -1 { + continue + } - // The update expressions are the same as the update expressions in the parent update query + // The where condition is the same as the comparison expression above // with the column names replaced with the child column names. - for _, updateExpr := range updStmt.Exprs { - colIdx := fk.ParentColumns.FindColumn(updateExpr.Name.Name) - if colIdx == -1 { - continue + childUpdateExprs = append(childUpdateExprs, &sqlparser.UpdateExpr{ + Name: sqlparser.NewColName(fk.ChildColumns[colIdx].String()), + Expr: updateExpr.Expr, + }) + } + // Because we could be updating the child to a non-null value, + // We have to run with foreign key checks OFF because the parent isn't guaranteed to have + // the data being updated to. + parsedComments := sqlparser.Comments{ + "/*+ SET_VAR(foreign_key_checks=OFF) */", + }.Parsed() + childUpdStmt := &sqlparser.Update{ + Comments: parsedComments, + Exprs: childUpdateExprs, + TableExprs: []sqlparser.TableExpr{sqlparser.NewAliasedTableExpr(fk.Table.GetTableName(), "")}, + Where: &sqlparser.Where{Type: sqlparser.WhereClause, Expr: childWhereExpr}, + } + // Since we are running the child update with foreign key checks turned off, + // we need to verify the validity of the remaining foreign keys on VTGate, + // while specifically ignoring the parent foreign key in question. + return createOpFromStmt(ctx, childUpdStmt, true, fk.String(updatedTable)) + +} + +// buildChildUpdOpForSetNull builds the child update statement operator for the SET NULL type foreign key constraint. +// The query looks like this - +// +// `UPDATE SET +// WHERE IN () +// [AND NOT IN ()]` +func buildChildUpdOpForSetNull(ctx *plancontext.PlanningContext, fk vindexes.ChildFKInfo, updStmt *sqlparser.Update, childWhereExpr sqlparser.Expr, valTuple sqlparser.ValTuple) (ops.Operator, error) { + // For the SET NULL type constraint, we need to set all the child columns to NULL. + var childUpdateExprs sqlparser.UpdateExprs + for _, column := range fk.ChildColumns { + childUpdateExprs = append(childUpdateExprs, &sqlparser.UpdateExpr{ + Name: sqlparser.NewColName(column.String()), + Expr: &sqlparser.NullVal{}, + }) + } + + // SET NULL cascade should be avoided for the case where the parent columns remains unchanged on the update. + // We need to add a condition to the where clause to handle this case. + // The additional condition looks like [AND NOT IN ()]. + // If any of the parent columns is being set to NULL, then we don't need this condition. + var updateValues sqlparser.ValTuple + colSetToNull := false + for _, updateExpr := range updStmt.Exprs { + colIdx := fk.ParentColumns.FindColumn(updateExpr.Name.Name) + if colIdx >= 0 { + if sqlparser.IsNull(updateExpr.Expr) { + colSetToNull = true + break } + updateValues = append(updateValues, updateExpr.Expr) + } + } + if !colSetToNull { + childWhereExpr = &sqlparser.AndExpr{ + Left: childWhereExpr, + Right: sqlparser.NewComparisonExpr(sqlparser.NotInOp, valTuple, updateValues, nil), + } + } + childUpdStmt := &sqlparser.Update{ + Exprs: childUpdateExprs, + TableExprs: []sqlparser.TableExpr{sqlparser.NewAliasedTableExpr(fk.Table.GetTableName(), "")}, + Where: &sqlparser.Where{Type: sqlparser.WhereClause, Expr: childWhereExpr}, + } + return createOpFromStmt(ctx, childUpdStmt, false, "") +} - // The where condition is the same as the comparison expression above - // with the column names replaced with the child column names. - childUpdateExprs = append(childUpdateExprs, &sqlparser.UpdateExpr{ - Name: sqlparser.NewColName(fk.ChildColumns[colIdx].String()), - Expr: updateExpr.Expr, - }) +// createFKVerifyOp creates the verify operator for the parent foreign key constraints. +func createFKVerifyOp(ctx *plancontext.PlanningContext, childOp ops.Operator, updStmt *sqlparser.Update, parentFks []vindexes.ParentFKInfo, restrictChildFks []vindexes.ChildFKInfo) (ops.Operator, error) { + if len(parentFks) == 0 && len(restrictChildFks) == 0 { + return childOp, nil + } + + var Verify []*VerifyOp + // This validates that new values exists on the parent table. + for _, fk := range parentFks { + op, err := createFkVerifyOpForParentFKForUpdate(ctx, updStmt, fk) + if err != nil { + return nil, err } - case sqlparser.SetNull: - // For SET NULL type constraint, the query looks like this - - // `UPDATE SET - // WHERE IN () - // [AND NOT IN ()]` - - // For the SET NULL type constraint, we need to set all the child columns to NULL. - for _, column := range fk.ChildColumns { - childUpdateExprs = append(childUpdateExprs, &sqlparser.UpdateExpr{ - Name: sqlparser.NewColName(column.String()), - Expr: &sqlparser.NullVal{}, - }) + Verify = append(Verify, &VerifyOp{ + Op: op, + Typ: engine.ParentVerify, + }) + } + // This validates that the old values don't exist on the child table. + for _, fk := range restrictChildFks { + op, err := createFkVerifyOpForChildFKForUpdate(ctx, updStmt, fk) + if err != nil { + return nil, err } + Verify = append(Verify, &VerifyOp{ + Op: op, + Typ: engine.ChildVerify, + }) + } - // SET NULL cascade should be avoided for the case where the parent columns remains unchanged on the update. - // We need to add a condition to the where clause to handle this case. - // The additional condition looks like [AND NOT IN ()]. - // If any of the parent columns is being set to NULL, then we don't need this condition. - var updateValues sqlparser.ValTuple - colSetToNull := false + return &FkVerify{ + Verify: Verify, + Input: childOp, + }, nil +} + +// Each parent foreign key constraint is verified by an anti join query of the form: +// select 1 from child_tbl left join parent_tbl on +// where and and limit 1 +// E.g: +// Child (c1, c2) references Parent (p1, p2) +// update Child set c1 = 1 where id = 1 +// verify query: +// select 1 from Child left join Parent on Parent.p1 = 1 and Parent.p2 = Child.c2 +// where Parent.p1 is null and Parent.p2 is null and Child.id = 1 +// and Child.c2 is not null +// limit 1 +func createFkVerifyOpForParentFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, pFK vindexes.ParentFKInfo) (ops.Operator, error) { + childTblExpr := updStmt.TableExprs[0].(*sqlparser.AliasedTableExpr) + childTbl, err := childTblExpr.TableName() + if err != nil { + return nil, err + } + parentTbl := pFK.Table.GetTableName() + var whereCond sqlparser.Expr + var joinCond sqlparser.Expr + for idx, column := range pFK.ChildColumns { + var matchedExpr *sqlparser.UpdateExpr for _, updateExpr := range updStmt.Exprs { - colIdx := fk.ParentColumns.FindColumn(updateExpr.Name.Name) - if colIdx >= 0 { - if sqlparser.IsNull(updateExpr.Expr) { - colSetToNull = true - break - } - updateValues = append(updateValues, updateExpr.Expr) + if column.Equal(updateExpr.Name.Name) { + matchedExpr = updateExpr + break } } - if !colSetToNull { - childWhereExpr = &sqlparser.AndExpr{ - Left: compExpr, - Right: sqlparser.NewComparisonExpr(sqlparser.NotInOp, valTuple, updateValues, nil), + parentIsNullExpr := &sqlparser.IsExpr{ + Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), + Right: sqlparser.IsNullOp, + } + var predicate sqlparser.Expr = parentIsNullExpr + var joinExpr sqlparser.Expr + if matchedExpr == nil { + predicate = &sqlparser.AndExpr{ + Left: parentIsNullExpr, + Right: &sqlparser.IsExpr{ + Left: sqlparser.NewColNameWithQualifier(pFK.ChildColumns[idx].String(), childTbl), + Right: sqlparser.IsNotNullOp, + }, + } + joinExpr = &sqlparser.ComparisonExpr{ + Operator: sqlparser.EqualOp, + Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), + Right: sqlparser.NewColNameWithQualifier(pFK.ChildColumns[idx].String(), childTbl), + } + } else { + joinExpr = &sqlparser.ComparisonExpr{ + Operator: sqlparser.EqualOp, + Left: sqlparser.NewColNameWithQualifier(pFK.ParentColumns[idx].String(), parentTbl), + Right: prefixColNames(childTbl, matchedExpr.Expr), } } - case sqlparser.SetDefault: - return nil, vterrors.VT09016() - } - childStmt := &sqlparser.Update{ - Exprs: childUpdateExprs, - TableExprs: []sqlparser.TableExpr{sqlparser.NewAliasedTableExpr(fk.Table.GetTableName(), "")}, - Where: &sqlparser.Where{Type: sqlparser.WhereClause, Expr: childWhereExpr}, + if idx == 0 { + joinCond, whereCond = joinExpr, predicate + continue + } + joinCond = &sqlparser.AndExpr{Left: joinCond, Right: joinExpr} + whereCond = &sqlparser.AndExpr{Left: whereCond, Right: predicate} + } + // add existing where condition on the update statement + if updStmt.Where != nil { + whereCond = &sqlparser.AndExpr{Left: whereCond, Right: prefixColNames(childTbl, updStmt.Where.Expr)} } + return createSelectionOp(ctx, + sqlparser.SelectExprs{sqlparser.NewAliasedExpr(sqlparser.NewIntLiteral("1"), "")}, + []sqlparser.TableExpr{ + sqlparser.NewJoinTableExpr( + childTblExpr, + sqlparser.LeftJoinType, + sqlparser.NewAliasedTableExpr(parentTbl, ""), + sqlparser.NewJoinCondition(joinCond, nil)), + }, + sqlparser.NewWhere(sqlparser.WhereClause, whereCond), + sqlparser.NewLimitWithoutOffset(1)) +} - childOp, err := createOpFromStmt(ctx, childStmt) +// Each child foreign key constraint is verified by a join query of the form: +// select 1 from child_tbl join parent_tbl on where limit 1 +// E.g: +// Child (c1, c2) references Parent (p1, p2) +// update Parent set p1 = 1 where id = 1 +// verify query: +// select 1 from Child join Parent on Parent.p1 = Child.c1 and Parent.p2 = Child.c2 +// where Parent.id = 1 limit 1 +func createFkVerifyOpForChildFKForUpdate(ctx *plancontext.PlanningContext, updStmt *sqlparser.Update, cFk vindexes.ChildFKInfo) (ops.Operator, error) { + // ON UPDATE RESTRICT foreign keys that require validation, should only be allowed in the case where we + // are verifying all the FKs on vtgate level. + if !ctx.VerifyAllFKs { + return nil, vterrors.VT12002() + } + parentTblExpr := updStmt.TableExprs[0].(*sqlparser.AliasedTableExpr) + parentTbl, err := parentTblExpr.TableName() if err != nil { return nil, err } + childTbl := cFk.Table.GetTableName() + var joinCond sqlparser.Expr + for idx := range cFk.ParentColumns { + joinExpr := &sqlparser.ComparisonExpr{ + Operator: sqlparser.EqualOp, + Left: sqlparser.NewColNameWithQualifier(cFk.ParentColumns[idx].String(), parentTbl), + Right: sqlparser.NewColNameWithQualifier(cFk.ChildColumns[idx].String(), childTbl), + } - return &FkChild{ - BVName: bvName, - Cols: cols, - Op: childOp, - }, nil + if idx == 0 { + joinCond = joinExpr + continue + } + joinCond = &sqlparser.AndExpr{Left: joinCond, Right: joinExpr} + } + + var whereCond sqlparser.Expr + // add existing where condition on the update statement + if updStmt.Where != nil { + whereCond = prefixColNames(parentTbl, updStmt.Where.Expr) + } + return createSelectionOp(ctx, + sqlparser.SelectExprs{sqlparser.NewAliasedExpr(sqlparser.NewIntLiteral("1"), "")}, + []sqlparser.TableExpr{ + sqlparser.NewJoinTableExpr( + parentTblExpr, + sqlparser.NormalJoinType, + sqlparser.NewAliasedTableExpr(childTbl, ""), + sqlparser.NewJoinCondition(joinCond, nil)), + }, + sqlparser.NewWhere(sqlparser.WhereClause, whereCond), + sqlparser.NewLimitWithoutOffset(1)) } diff --git a/go/vt/vtgate/planbuilder/operators/fk_cascade.go b/go/vt/vtgate/planbuilder/operators/fk_cascade.go index f4528694c39..a9afbde0a7c 100644 --- a/go/vt/vtgate/planbuilder/operators/fk_cascade.go +++ b/go/vt/vtgate/planbuilder/operators/fk_cascade.go @@ -102,5 +102,5 @@ func (fkc *FkCascade) GetOrdering() ([]ops.OrderBy, error) { // ShortDescription implements the Operator interface func (fkc *FkCascade) ShortDescription() string { - return "FkCascade" + return "" } diff --git a/go/vt/vtgate/planbuilder/operators/fk_verify.go b/go/vt/vtgate/planbuilder/operators/fk_verify.go new file mode 100644 index 00000000000..8c2431d26fc --- /dev/null +++ b/go/vt/vtgate/planbuilder/operators/fk_verify.go @@ -0,0 +1,80 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package operators + +import ( + "vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops" +) + +// VerifyOp keeps the information about the foreign key verification operation. +// It is a Parent verification or a Child verification. +type VerifyOp struct { + Op ops.Operator + Typ string +} + +// FkVerify is used to represent a foreign key verification operation +// as an operator. This operator is created for DML queries that require +// verifications on the existence of the rows in the parent table (for example, INSERT and UPDATE). +type FkVerify struct { + Verify []*VerifyOp + Input ops.Operator + + noColumns + noPredicates +} + +var _ ops.Operator = (*FkVerify)(nil) + +// Inputs implements the Operator interface +func (fkv *FkVerify) Inputs() []ops.Operator { + inputs := []ops.Operator{fkv.Input} + for _, v := range fkv.Verify { + inputs = append(inputs, v.Op) + } + return inputs +} + +// SetInputs implements the Operator interface +func (fkv *FkVerify) SetInputs(operators []ops.Operator) { + fkv.Input = operators[0] + if len(fkv.Verify) != len(operators)-1 { + panic("mismatched number of verify inputs") + } + for i := 1; i < len(operators); i++ { + fkv.Verify[i-1].Op = operators[i] + } +} + +// Clone implements the Operator interface +func (fkv *FkVerify) Clone(inputs []ops.Operator) ops.Operator { + newFkv := &FkVerify{ + Verify: fkv.Verify, + } + newFkv.SetInputs(inputs) + return newFkv +} + +// GetOrdering implements the Operator interface +func (fkv *FkVerify) GetOrdering() ([]ops.OrderBy, error) { + return nil, nil +} + +// ShortDescription implements the Operator interface +func (fkv *FkVerify) ShortDescription() string { + return "" +} diff --git a/go/vt/vtgate/planbuilder/operators/horizon_planning.go b/go/vt/vtgate/planbuilder/operators/horizon_planning.go index 4dfb185f07e..f55a84fb6a1 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_planning.go @@ -342,10 +342,7 @@ func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Proje } expr = semantics.RewriteDerivedTableExpression(expr, derivedTbl) - out, err := prefixColNames(tblName, expr) - if err != nil { - return err - } + out := prefixColNames(tblName, expr) alias := sqlparser.UnescapedString(out) predicate.LHSExprs[idx] = sqlparser.NewColNameWithQualifier(alias, derivedTblName) @@ -357,15 +354,14 @@ func exposeColumnsThroughDerivedTable(ctx *plancontext.PlanningContext, p *Proje // prefixColNames adds qualifier prefixes to all ColName:s. // We want to be more explicit than the user was to make sure we never produce invalid SQL -func prefixColNames(tblName sqlparser.TableName, e sqlparser.Expr) (out sqlparser.Expr, err error) { - out = sqlparser.CopyOnRewrite(e, nil, func(cursor *sqlparser.CopyOnWriteCursor) { +func prefixColNames(tblName sqlparser.TableName, e sqlparser.Expr) sqlparser.Expr { + return sqlparser.CopyOnRewrite(e, nil, func(cursor *sqlparser.CopyOnWriteCursor) { col, ok := cursor.Node().(*sqlparser.ColName) if !ok { return } col.Qualifier = tblName }, nil).(sqlparser.Expr) - return } func createProjectionWithTheseColumns( diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 7d54d01b565..9bd1778ab6c 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -28,13 +28,12 @@ import ( "strings" "testing" - "vitess.io/vitess/go/test/vschemawrapper" - "github.com/nsf/jsondiff" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" + "vitess.io/vitess/go/test/vschemawrapper" "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/servenv" @@ -152,7 +151,10 @@ func setFks(t *testing.T, vschema *vindexes.VSchema) { // u_tbl3(col2) -> u_tbl2(col2) Cascade Null. // u_tbl4(col4) -> u_tbl3(col3) Restrict. // u_tbl6(col6) -> u_tbl5(col5) Restrict. - // u_tbl8(col8) -> u_tbl9(col9) Cascade Null. + // u_tbl8(col8) -> u_tbl9(col9) Null Null. + // u_tbl8(col8) -> u_tbl6(col6) Cascade Null. + // u_tbl4(col4) -> u_tbl7(col7) Cascade Cascade. + // u_tbl9(col9) -> u_tbl4(col4) Restrict Restrict. _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl2", createFkDefinition([]string{"col2"}, "u_tbl1", []string{"col1"}, sqlparser.Cascade, sqlparser.Cascade)) _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl9", createFkDefinition([]string{"col9"}, "u_tbl1", []string{"col1"}, sqlparser.SetNull, sqlparser.NoAction)) @@ -161,6 +163,9 @@ func setFks(t *testing.T, vschema *vindexes.VSchema) { _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl4", createFkDefinition([]string{"col4"}, "u_tbl3", []string{"col3"}, sqlparser.Restrict, sqlparser.Restrict)) _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl6", createFkDefinition([]string{"col6"}, "u_tbl5", []string{"col5"}, sqlparser.DefaultAction, sqlparser.DefaultAction)) _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl8", createFkDefinition([]string{"col8"}, "u_tbl9", []string{"col9"}, sqlparser.SetNull, sqlparser.SetNull)) + _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl8", createFkDefinition([]string{"col8"}, "u_tbl6", []string{"col6"}, sqlparser.Cascade, sqlparser.CASCADE)) + _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl4", createFkDefinition([]string{"col4"}, "u_tbl7", []string{"col7"}, sqlparser.Cascade, sqlparser.Cascade)) + _ = vschema.AddForeignKey("unsharded_fk_allow", "u_tbl9", createFkDefinition([]string{"col9"}, "u_tbl4", []string{"col4"}, sqlparser.Restrict, sqlparser.Restrict)) } } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 6fc99232379..28c592758f9 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -43,6 +43,15 @@ type PlanningContext struct { // DelegateAggregation tells us when we are allowed to split an aggregation across vtgate and mysql // We aggregate within a shard, and then at the vtgate level we aggregate the incoming shard aggregates DelegateAggregation bool + + // VerifyAllFKs tells whether we need verification for all the fk constraints on VTGate. + // This is required for queries we are running with /*+ SET_VAR(foreign_key_checks=OFF) */ + VerifyAllFKs bool + + // ParentFKToIgnore stores a specific parent foreign key that we would need to ignore while planning + // a certain query. This field is used in UPDATE CASCADE planning, wherein while planning the child update + // query, we need to ignore the parent foreign key constraint that caused the cascade in question. + ParentFKToIgnore string } func CreatePlanningContext(stmt sqlparser.Statement, reservedVars *sqlparser.ReservedVars, vschema VSchema, version querypb.ExecuteOptions_PlannerVersion) (*PlanningContext, error) { @@ -70,13 +79,13 @@ func CreatePlanningContext(stmt sqlparser.Statement, reservedVars *sqlparser.Res }, nil } -func (c *PlanningContext) IsSubQueryToReplace(e sqlparser.Expr) bool { +func (ctx *PlanningContext) IsSubQueryToReplace(e sqlparser.Expr) bool { ext, ok := e.(*sqlparser.Subquery) if !ok { return false } - for _, extractedSubq := range c.SemTable.GetSubqueryNeedingRewrite() { - if extractedSubq.Merged && c.SemTable.EqualsExpr(extractedSubq.Subquery, ext) { + for _, extractedSubq := range ctx.SemTable.GetSubqueryNeedingRewrite() { + if extractedSubq.Merged && ctx.SemTable.EqualsExpr(extractedSubq.Subquery, ext) { return true } } diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index fcf770107cd..c7666a07bea 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -490,9 +490,84 @@ } }, { - "comment": "update table with column's parent foreign key cross shard - disallowed", + "comment": "update table with column's parent foreign key cross shard", "query": "update tbl10 set col = 'foo'", - "plan": "VT12002: unsupported: foreign keys management at vitess" + "plan": { + "QueryType": "UPDATE", + "Original": "update tbl10 set col = 'foo'", + "Instructions": { + "OperatorType": "FKVerify", + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Projection", + "Expressions": [ + "INT64(1) as 1" + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": "INT64(1)", + "Inputs": [ + { + "OperatorType": "Filter", + "Predicate": "tbl3.col is null", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0,R:0", + "TableName": "tbl10_tbl3", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "FieldQuery": "select 1 from tbl10 where 1 != 1", + "Query": "select 1 from tbl10 lock in share mode", + "Table": "tbl10" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "FieldQuery": "select tbl3.col from tbl3 where 1 != 1", + "Query": "select tbl3.col from tbl3 where tbl3.col = 'foo' lock in share mode", + "Table": "tbl3" + } + ] + } + ] + } + ] + } + ] + }, + { + "InputName": "PostVerify", + "OperatorType": "Update", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update tbl10 set col = 'foo'", + "Table": "tbl10" + } + ] + }, + "TablesUsed": [ + "sharded_fk_allow.tbl10", + "sharded_fk_allow.tbl3" + ] + } }, { "comment": "delete table with shard scoped foreign key set default - disallowed", @@ -630,7 +705,7 @@ "Sharded": false }, "TargetTabletType": "PRIMARY", - "Query": "update u_tbl2 set col2 = 'foo' where (col2) in ::fkc_vals", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl2 set col2 = 'foo' where (col2) in ::fkc_vals", "Table": "u_tbl2" } ] @@ -711,17 +786,17 @@ { "comment": "update in a table with limit - disallowed", "query": "update u_tbl2 set col2 = 'bar' limit 2", - "plan": "VT12001: unsupported: foreign keys management at vitess with limit" + "plan": "VT12001: unsupported: update with limit with foreign key constraints" }, { "comment": "update in a table with non-literal value - set null fail due to child update where condition", "query": "update u_tbl2 set m = 2, col2 = col1 + 'bar' where id = 1", - "plan": "VT12001: unsupported: foreign keys management at vitess with non-literal values" + "plan": "VT12001: unsupported: update expression with non-literal values with foreign key constraints" }, { "comment": "update in a table with non-literal value - with cascade fail as the cascade value is not known", "query": "update u_tbl1 set m = 2, col1 = x + 'bar' where id = 1", - "plan": "VT12001: unsupported: foreign keys management at vitess with non-literal values" + "plan": "VT12001: unsupported: update expression with non-literal values with foreign key constraints" }, { "comment": "update in a table with a child table having SET DEFAULT constraint - disallowed", @@ -732,5 +807,255 @@ "comment": "delete in a table with limit - disallowed", "query": "delete from u_tbl2 limit 2", "plan": "VT12001: unsupported: foreign keys management at vitess with limit" + }, + { + "comment": "update with fk on cross-shard with a where condition on non-literal value - disallowed", + "query": "update tbl3 set coly = colx + 10 where coly = 10", + "plan": "VT12001: unsupported: update expression with non-literal values with foreign key constraints" + }, + { + "comment": "update with fk on cross-shard with a where condition", + "query": "update tbl3 set coly = 20 where coly = 10", + "plan": { + "QueryType": "UPDATE", + "Original": "update tbl3 set coly = 20 where coly = 10", + "Instructions": { + "OperatorType": "FKVerify", + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Projection", + "Expressions": [ + "INT64(1) as 1" + ], + "Inputs": [ + { + "OperatorType": "Limit", + "Count": "INT64(1)", + "Inputs": [ + { + "OperatorType": "Filter", + "Predicate": "tbl1.t1col1 is null", + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "R:0,R:0", + "TableName": "tbl3_tbl1", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "FieldQuery": "select 1 from tbl3 where 1 != 1", + "Query": "select 1 from tbl3 where tbl3.coly = 10 lock in share mode", + "Table": "tbl3" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "FieldQuery": "select tbl1.t1col1 from tbl1 where 1 != 1", + "Query": "select tbl1.t1col1 from tbl1 where tbl1.t1col1 = 20 lock in share mode", + "Table": "tbl1" + } + ] + } + ] + } + ] + } + ] + }, + { + "InputName": "PostVerify", + "OperatorType": "Update", + "Variant": "Scatter", + "Keyspace": { + "Name": "sharded_fk_allow", + "Sharded": true + }, + "TargetTabletType": "PRIMARY", + "Query": "update tbl3 set coly = 20 where tbl3.coly = 10", + "Table": "tbl3" + } + ] + }, + "TablesUsed": [ + "sharded_fk_allow.tbl1", + "sharded_fk_allow.tbl3" + ] + } + }, + { + "comment": "Update in a table with shard-scoped foreign keys with cascade that requires a validation of a different parent foreign key", + "query": "update u_tbl6 set col6 = 'foo'", + "plan": { + "QueryType": "UPDATE", + "Original": "update u_tbl6 set col6 = 'foo'", + "Instructions": { + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col6 from u_tbl6 where 1 != 1", + "Query": "select col6 from u_tbl6 lock in share mode", + "Table": "u_tbl6" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "FKVerify", + "BvName": "fkc_vals", + "Cols": [ + 0 + ], + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where 1 != 1", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = 'foo' where (u_tbl8.col8) in ::fkc_vals and u_tbl9.col9 is null limit 1", + "Table": "u_tbl8, u_tbl9" + }, + { + "InputName": "PostVerify", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl8 set col8 = 'foo' where (u_tbl8.col8) in ::fkc_vals", + "Table": "u_tbl8" + } + ] + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_tbl6 set col6 = 'foo'", + "Table": "u_tbl6" + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_tbl6", + "unsharded_fk_allow.u_tbl8", + "unsharded_fk_allow.u_tbl9" + ] + } + }, + { + "comment": "Update that cascades and requires parent fk and restrict child fk verification", + "query": "update u_tbl7 set col7 = 'foo'", + "plan": { + "QueryType": "UPDATE", + "Original": "update u_tbl7 set col7 = 'foo'", + "Instructions": { + "OperatorType": "FkCascade", + "Inputs": [ + { + "InputName": "Selection", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select col7 from u_tbl7 where 1 != 1", + "Query": "select col7 from u_tbl7 lock in share mode", + "Table": "u_tbl7" + }, + { + "InputName": "CascadeChild-1", + "OperatorType": "FKVerify", + "BvName": "fkc_vals", + "Cols": [ + 0 + ], + "Inputs": [ + { + "InputName": "VerifyParent-1", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where 1 != 1", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = 'foo' where (u_tbl4.col4) in ::fkc_vals and u_tbl3.col3 is null limit 1", + "Table": "u_tbl3, u_tbl4" + }, + { + "InputName": "VerifyChild-2", + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", + "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and u_tbl4.col4 = u_tbl9.col9 limit 1", + "Table": "u_tbl4, u_tbl9" + }, + { + "InputName": "PostVerify", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update /*+ SET_VAR(foreign_key_checks=OFF) */ u_tbl4 set col4 = 'foo' where (u_tbl4.col4) in ::fkc_vals", + "Table": "u_tbl4" + } + ] + }, + { + "InputName": "Parent", + "OperatorType": "Update", + "Variant": "Unsharded", + "Keyspace": { + "Name": "unsharded_fk_allow", + "Sharded": false + }, + "TargetTabletType": "PRIMARY", + "Query": "update u_tbl7 set col7 = 'foo'", + "Table": "u_tbl7" + } + ] + }, + "TablesUsed": [ + "unsharded_fk_allow.u_tbl3", + "unsharded_fk_allow.u_tbl4", + "unsharded_fk_allow.u_tbl7", + "unsharded_fk_allow.u_tbl9" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/update.go b/go/vt/vtgate/planbuilder/update.go index a729bc96c0a..496f8ddbf22 100644 --- a/go/vt/vtgate/planbuilder/update.go +++ b/go/vt/vtgate/planbuilder/update.go @@ -24,7 +24,6 @@ import ( "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -49,7 +48,7 @@ func gen4UpdateStmtPlanner( } if ks, tables := ctx.SemTable.SingleUnshardedKeyspace(); ks != nil { - if fkManagementNotRequiredForUpdate(ctx.SemTable, vschema, tables, updStmt.Exprs) { + if fkManagementNotRequiredForUpdate(ctx, vschema, tables, updStmt.Exprs) { plan := updateUnshardedShortcut(updStmt, ks, tables) plan = pushCommentDirectivesOnPlan(plan, updStmt) return newPlanResult(plan.Primitive(), operators.QualifiedTables(ks, tables)...), nil @@ -87,7 +86,7 @@ func gen4UpdateStmtPlanner( } // TODO: Handle all this in semantic analysis. -func fkManagementNotRequiredForUpdate(semTable *semantics.SemTable, vschema plancontext.VSchema, vTables []*vindexes.Table, updateExprs sqlparser.UpdateExprs) bool { +func fkManagementNotRequiredForUpdate(ctx *plancontext.PlanningContext, vschema plancontext.VSchema, vTables []*vindexes.Table, updateExprs sqlparser.UpdateExprs) bool { childFkMap := make(map[string][]vindexes.ChildFKInfo) // Find the foreign key mode and check for any managed child foreign keys. @@ -99,14 +98,14 @@ func fkManagementNotRequiredForUpdate(semTable *semantics.SemTable, vschema plan if ksMode != vschemapb.Keyspace_FK_MANAGED { continue } - childFks := vTable.ChildFKsNeedsHandling(vindexes.UpdateAction) + childFks := vTable.ChildFKsNeedsHandling(ctx.VerifyAllFKs, vindexes.UpdateAction) if len(childFks) > 0 { childFkMap[vTable.String()] = childFks } } getFKInfo := func(expr *sqlparser.UpdateExpr) ([]vindexes.ParentFKInfo, []vindexes.ChildFKInfo) { - tblInfo, err := semTable.TableInfoForExpr(expr.Name) + tblInfo, err := ctx.SemTable.TableInfoForExpr(expr.Name) if err != nil { return nil, nil } diff --git a/go/vt/vtgate/vindexes/foreign_keys.go b/go/vt/vtgate/vindexes/foreign_keys.go index 2d111d7e18e..9ac5d93e7a4 100644 --- a/go/vt/vtgate/vindexes/foreign_keys.go +++ b/go/vt/vtgate/vindexes/foreign_keys.go @@ -44,12 +44,13 @@ func (fk *ParentFKInfo) MarshalJSON() ([]byte, error) { }) } -func (fk *ParentFKInfo) String() string { +func (fk *ParentFKInfo) String(childTable *Table) string { var str strings.Builder - str.WriteString(fk.Table.Name.String()) + str.WriteString(childTable.String()) for _, column := range fk.ChildColumns { str.WriteString(column.String()) } + str.WriteString(fk.Table.String()) for _, column := range fk.ParentColumns { str.WriteString(column.String()) } @@ -88,12 +89,13 @@ func (fk *ChildFKInfo) MarshalJSON() ([]byte, error) { }) } -func (fk *ChildFKInfo) String() string { +func (fk *ChildFKInfo) String(parentTable *Table) string { var str strings.Builder - str.WriteString(fk.Table.Name.String()) + str.WriteString(fk.Table.String()) for _, column := range fk.ChildColumns { str.WriteString(column.String()) } + str.WriteString(parentTable.String()) for _, column := range fk.ParentColumns { str.WriteString(column.String()) } @@ -112,29 +114,20 @@ func NewChildFkInfo(childTbl *Table, fkDef *sqlparser.ForeignKeyDefinition) Chil } } -// AddForeignKey is for testing only. -func (vschema *VSchema) AddForeignKey(ksname, childTableName string, fkConstraint *sqlparser.ForeignKeyDefinition) error { - ks, ok := vschema.Keyspaces[ksname] - if !ok { - return fmt.Errorf("keyspace %s not found in vschema", ksname) - } - cTbl, ok := ks.Tables[childTableName] - if !ok { - return fmt.Errorf("child table %s not found in keyspace %s", childTableName, ksname) - } - parentTableName := fkConstraint.ReferenceDefinition.ReferencedTable.Name.String() - pTbl, ok := ks.Tables[parentTableName] - if !ok { - return fmt.Errorf("parent table %s not found in keyspace %s", parentTableName, ksname) - } - pTbl.ChildForeignKeys = append(pTbl.ChildForeignKeys, NewChildFkInfo(cTbl, fkConstraint)) - cTbl.ParentForeignKeys = append(cTbl.ParentForeignKeys, NewParentFkInfo(pTbl, fkConstraint)) - return nil -} - // ParentFKsNeedsHandling returns all the parent fk constraints on this table that are not shard scoped. -func (t *Table) ParentFKsNeedsHandling() (fks []ParentFKInfo) { +func (t *Table) ParentFKsNeedsHandling(verifyAllFKs bool, fkToIgnore string) (fks []ParentFKInfo) { for _, fk := range t.ParentForeignKeys { + // Check if we need to specifically ignore this foreign key + if fkToIgnore != "" && fk.String(t) == fkToIgnore { + continue + } + + // If we require all the foreign keys, add them all. + if verifyAllFKs { + fks = append(fks, fk) + continue + } + // If the keyspaces are different, then the fk definition // is going to go across shards. if fk.Table.Keyspace.Name != t.Keyspace.Name { @@ -156,7 +149,11 @@ func (t *Table) ParentFKsNeedsHandling() (fks []ParentFKInfo) { // ChildFKsNeedsHandling retuns the child foreign keys that needs to be handled by the vtgate. // This can be either the foreign key is not shard scoped or the child tables needs cascading. -func (t *Table) ChildFKsNeedsHandling(getAction func(fk ChildFKInfo) sqlparser.ReferenceAction) (fks []ChildFKInfo) { +func (t *Table) ChildFKsNeedsHandling(verifyAllFKs bool, getAction func(fk ChildFKInfo) sqlparser.ReferenceAction) (fks []ChildFKInfo) { + // If we require all the foreign keys, return the entire list. + if verifyAllFKs { + return t.ChildForeignKeys + } for _, fk := range t.ChildForeignKeys { // If the keyspaces are different, then the fk definition // is going to go across shards. @@ -225,3 +222,23 @@ func isShardScoped(pTable *Table, cTable *Table, pCols sqlparser.Columns, cCols } return true } + +// AddForeignKey is for testing only. +func (vschema *VSchema) AddForeignKey(ksname, childTableName string, fkConstraint *sqlparser.ForeignKeyDefinition) error { + ks, ok := vschema.Keyspaces[ksname] + if !ok { + return fmt.Errorf("keyspace %s not found in vschema", ksname) + } + cTbl, ok := ks.Tables[childTableName] + if !ok { + return fmt.Errorf("child table %s not found in keyspace %s", childTableName, ksname) + } + parentTableName := fkConstraint.ReferenceDefinition.ReferencedTable.Name.String() + pTbl, ok := ks.Tables[parentTableName] + if !ok { + return fmt.Errorf("parent table %s not found in keyspace %s", parentTableName, ksname) + } + pTbl.ChildForeignKeys = append(pTbl.ChildForeignKeys, NewChildFkInfo(cTbl, fkConstraint)) + cTbl.ParentForeignKeys = append(cTbl.ParentForeignKeys, NewParentFkInfo(pTbl, fkConstraint)) + return nil +} diff --git a/go/vt/vtgate/vindexes/foreign_keys_test.go b/go/vt/vtgate/vindexes/foreign_keys_test.go index b5dfa57fbce..147614edcbf 100644 --- a/go/vt/vtgate/vindexes/foreign_keys_test.go +++ b/go/vt/vtgate/vindexes/foreign_keys_test.go @@ -72,6 +72,8 @@ func TestTable_CrossShardParentFKs(t *testing.T) { name string table *Table wantCrossShardFKTables []string + verifyAllFKs bool + fkToIgnore string }{{ name: "No Parent FKs", table: &Table{ @@ -87,6 +89,15 @@ func TestTable_CrossShardParentFKs(t *testing.T) { ParentForeignKeys: []ParentFKInfo{pkInfo(unshardedTbl, []string{"col4"}, []string{"col1"})}, }, wantCrossShardFKTables: []string{}, + }, { + name: "Unsharded keyspace with verify all FKs", + verifyAllFKs: true, + table: &Table{ + ColumnVindexes: []*ColumnVindex{col1Vindex}, + Keyspace: uks2, + ParentForeignKeys: []ParentFKInfo{pkInfo(unshardedTbl, []string{"col4"}, []string{"col1"})}, + }, + wantCrossShardFKTables: []string{"t1"}, }, { name: "Keyspaces don't match", // parent table is on uks2 table: &Table{ @@ -94,6 +105,24 @@ func TestTable_CrossShardParentFKs(t *testing.T) { ParentForeignKeys: []ParentFKInfo{pkInfo(unshardedTbl, []string{"col4"}, []string{"col1"})}, }, wantCrossShardFKTables: []string{"t1"}, + }, { + name: "Keyspaces don't match with ignore fk", // parent table is on uks2 + fkToIgnore: "uks.col1uks2.t1col4", + table: &Table{ + Keyspace: uks, + ParentForeignKeys: []ParentFKInfo{pkInfo(unshardedTbl, []string{"col4"}, []string{"col1"})}, + }, + wantCrossShardFKTables: []string{}, + }, { + name: "Unsharded keyspace with verify all FKs and fk to ignore", + verifyAllFKs: true, + fkToIgnore: "uks2.col1uks2.t1col4", + table: &Table{ + ColumnVindexes: []*ColumnVindex{col1Vindex}, + Keyspace: uks2, + ParentForeignKeys: []ParentFKInfo{pkInfo(unshardedTbl, []string{"col4"}, []string{"col1"})}, + }, + wantCrossShardFKTables: []string{}, }, { name: "Column Vindexes don't match", // primary vindexes on different vindex type table: &Table{ @@ -137,7 +166,7 @@ func TestTable_CrossShardParentFKs(t *testing.T) { }} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - crossShardFks := tt.table.ParentFKsNeedsHandling() + crossShardFks := tt.table.ParentFKsNeedsHandling(tt.verifyAllFKs, tt.fkToIgnore) var crossShardFkTables []string for _, fk := range crossShardFks { crossShardFkTables = append(crossShardFkTables, fk.Table.Name.String()) @@ -185,6 +214,7 @@ func TestChildFKs(t *testing.T) { } tests := []struct { + verifyAllFKs bool name string table *Table expChildTbls []string @@ -203,6 +233,15 @@ func TestChildFKs(t *testing.T) { ChildForeignKeys: []ChildFKInfo{ckInfo(unshardedTbl, []string{"col4"}, []string{"col1"}, sqlparser.Restrict)}, }, expChildTbls: []string{}, + }, { + name: "restrict unsharded with verify all fks", + verifyAllFKs: true, + table: &Table{ + ColumnVindexes: []*ColumnVindex{col1Vindex}, + Keyspace: uks2, + ChildForeignKeys: []ChildFKInfo{ckInfo(unshardedTbl, []string{"col4"}, []string{"col1"}, sqlparser.Restrict)}, + }, + expChildTbls: []string{"t1"}, }, { name: "restrict shard scoped", table: &Table{ @@ -211,6 +250,15 @@ func TestChildFKs(t *testing.T) { ChildForeignKeys: []ChildFKInfo{ckInfo(shardedSingleColTbl, []string{"col1"}, []string{"col1"}, sqlparser.Restrict)}, }, expChildTbls: []string{}, + }, { + name: "restrict shard scoped with verify all fks", + verifyAllFKs: true, + table: &Table{ + ColumnVindexes: []*ColumnVindex{col1Vindex}, + Keyspace: sks, + ChildForeignKeys: []ChildFKInfo{ckInfo(shardedSingleColTbl, []string{"col1"}, []string{"col1"}, sqlparser.Restrict)}, + }, + expChildTbls: []string{"t1"}, }, { name: "restrict Keyspaces don't match", table: &Table{ @@ -246,7 +294,7 @@ func TestChildFKs(t *testing.T) { deleteAction := func(fk ChildFKInfo) sqlparser.ReferenceAction { return fk.OnDelete } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - childFks := tt.table.ChildFKsNeedsHandling(deleteAction) + childFks := tt.table.ChildFKsNeedsHandling(tt.verifyAllFKs, deleteAction) var actualChildTbls []string for _, fk := range childFks { actualChildTbls = append(actualChildTbls, fk.Table.Name.String())