diff --git a/pkg/sql/internal.go b/pkg/sql/internal.go index 0736ff93f5d1..89eade396f8c 100644 --- a/pkg/sql/internal.go +++ b/pkg/sql/internal.go @@ -159,13 +159,14 @@ func (ie *InternalExecutor) runWithEx( ctx context.Context, txn *kv.Txn, w ieResultWriter, + mode ieExecutionMode, sd *sessiondata.SessionData, stmtBuf *StmtBuf, wg *sync.WaitGroup, syncCallback func([]*streamingCommandResult), errCallback func(error), ) error { - ex, err := ie.initConnEx(ctx, txn, w, sd, stmtBuf, syncCallback) + ex, err := ie.initConnEx(ctx, txn, w, mode, sd, stmtBuf, syncCallback) if err != nil { return err } @@ -201,13 +202,19 @@ func (ie *InternalExecutor) initConnEx( ctx context.Context, txn *kv.Txn, w ieResultWriter, + mode ieExecutionMode, sd *sessiondata.SessionData, stmtBuf *StmtBuf, syncCallback func([]*streamingCommandResult), ) (*connExecutor, error) { clientComm := &internalClientComm{ w: w, + mode: mode, sync: syncCallback, + resetRowsAffected: func() { + var zero int + _ = w.addResult(ctx, ieIteratorResult{rowsAffected: &zero}) + }, } applicationStats := ie.s.sqlStats.GetApplicationStats(sd.ApplicationName, true /* internal */) @@ -565,7 +572,7 @@ func (ie *InternalExecutor) queryInternalBuffered( // We will run the query to completion, so we can use an async result // channel. rw := newAsyncIEResultChannel() - it, err := ie.execInternal(ctx, opName, rw, txn, sessionDataOverride, stmt, qargs...) + it, err := ie.execInternal(ctx, opName, rw, defaultIEExecutionMode, txn, sessionDataOverride, stmt, qargs...) if err != nil { return nil, nil, err } @@ -667,7 +674,11 @@ func (ie *InternalExecutor) ExecEx( // We will run the query to completion, so we can use an async result // channel. rw := newAsyncIEResultChannel() - it, err := ie.execInternal(ctx, opName, rw, txn, session, stmt, qargs...) + // Since we only return the number of rows affected as given by the + // rowsIterator, we execute this stmt in "rows affected" mode allowing the + // internal executor to transparently retry. + const mode = rowsAffectedIEExecutionMode + it, err := ie.execInternal(ctx, opName, rw, mode, txn, session, stmt, qargs...) if err != nil { return 0, err } @@ -706,7 +717,7 @@ func (ie *InternalExecutor) QueryIteratorEx( qargs ...interface{}, ) (isql.Rows, error) { return ie.execInternal( - ctx, opName, newSyncIEResultChannel(), txn, session, stmt, qargs..., + ctx, opName, newSyncIEResultChannel(), defaultIEExecutionMode, txn, session, stmt, qargs..., ) } @@ -787,6 +798,7 @@ func (ie *InternalExecutor) execInternal( ctx context.Context, opName string, rw *ieResultChannel, + mode ieExecutionMode, txn *kv.Txn, sessionDataOverride sessiondata.InternalExecutorOverride, stmt string, @@ -905,7 +917,7 @@ func (ie *InternalExecutor) execInternal( errCallback := func(err error) { _ = rw.addResult(ctx, ieIteratorResult{err: err}) } - err = ie.runWithEx(ctx, txn, rw, sd, stmtBuf, &wg, syncCallback, errCallback) + err = ie.runWithEx(ctx, txn, rw, mode, sd, stmtBuf, &wg, syncCallback, errCallback) if err != nil { return nil, err } @@ -1028,7 +1040,7 @@ func (ie *InternalExecutor) commitTxn(ctx context.Context) error { rw := newAsyncIEResultChannel() stmtBuf := NewStmtBuf() - ex, err := ie.initConnEx(ctx, ie.extraTxnState.txn, rw, sd, stmtBuf, nil /* syncCallback */) + ex, err := ie.initConnEx(ctx, ie.extraTxnState.txn, rw, defaultIEExecutionMode, sd, stmtBuf, nil /* syncCallback */) if err != nil { return errors.Wrap(err, "cannot create conn executor to commit txn") } @@ -1081,6 +1093,26 @@ func (ie *InternalExecutor) checkIfTxnIsConsistent(txn *kv.Txn) error { return nil } +// ieExecutionMode determines how the internal executor consumes the results of +// the statement evaluation. +type ieExecutionMode int + +const ( + // defaultIEExecutionMode is the execution mode in which the results of the + // statement evaluation are consumed according to the statement's type. + defaultIEExecutionMode ieExecutionMode = iota + // rowsAffectedIEExecutionMode is the execution mode in which the internal + // executor is only interested in the number of rows affected, regardless of + // the statement's type. + // + // With this mode, if a stmt encounters a retry error, the internal executor + // will proceed to transparently reset the number of rows affected (if any + // have been seen by the rowsIterator) and retry the corresponding command. + // Such behavior makes sense given that in production code at most one + // command in the StmtBuf results in "rows affected". + rowsAffectedIEExecutionMode +) + // internalClientComm is an implementation of ClientComm used by the // InternalExecutor. Result rows are streamed on the channel to the // ieResultWriter. @@ -1099,6 +1131,15 @@ type internalClientComm struct { // The results of the query execution will be written into w. w ieResultWriter + // mode determines how the results of the query execution are consumed. + mode ieExecutionMode + + // resetRowsAffected is a callback that sends a single ieIteratorResult + // object to w in order to set the number of rows affected to zero. Only + // used in rowsAffectedIEExecutionMode when discarding a result (indicating + // that a command will be retried). + resetRowsAffected func() + // sync, if set, is called whenever a Sync is executed with all accumulated // results since the last Sync. sync func([]*streamingCommandResult) @@ -1136,6 +1177,9 @@ func (icc *internalClientComm) createRes(pos CmdPos) *streamingCommandResult { // results slice at the moment and all previous results have been // "finalized"). icc.results = icc.results[:len(icc.results)-1] + if icc.mode == rowsAffectedIEExecutionMode { + icc.resetRowsAffected() + } }, } icc.results = append(icc.results, res) @@ -1231,6 +1275,14 @@ func (icc *internalClientComm) Close() {} // ClientPos is part of the ClientLock interface. func (icc *internalClientComm) ClientPos() CmdPos { + if icc.mode == rowsAffectedIEExecutionMode { + // With the "rows affected" mode, any command can be rewound since we + // assume that only a single command results in actual "rows affected", + // and in Discard we will reset the number to zero (if we were in + // process of evaluation that command when we encountered the retry + // error). + return -1 + } // Find the latest result that cannot be rewound. lastDelivered := CmdPos(-1) for _, r := range icc.results { diff --git a/pkg/sql/internal_test.go b/pkg/sql/internal_test.go index 6db7acf1e927..e7f8ca678ed6 100644 --- a/pkg/sql/internal_test.go +++ b/pkg/sql/internal_test.go @@ -714,6 +714,23 @@ func TestInternalExecutorRetryAfterRows(t *testing.T) { if !testutils.IsError(err, "inject_retry_errors_enabled") { t.Fatalf("expected to see injected retry error, got %v", err) } + + // Now verify that ExecEx correctly and transparently to us retries the + // stmt. + numRows, err := ie.ExecEx( + ctx, "read rows", nil, /* txn */ + sessiondata.InternalExecutorOverride{ + User: username.MakeSQLUsernameFromPreNormalizedString(username.RootUser), + InjectRetryErrorsEnabled: true, + }, + "SELECT * FROM test.t", + ) + if err != nil { + t.Fatal(err) + } + if numRows != 1 { + t.Fatalf("expected 1 rowsAffected, got %d", numRows) + } } // TODO(andrei): Test that descriptor leases are released by the