From 6ccbcef9ebfc32fc733b949d4889fd4aa3e423b3 Mon Sep 17 00:00:00 2001 From: ti-srebot <66930949+ti-srebot@users.noreply.github.com> Date: Tue, 8 Mar 2022 16:07:49 +0800 Subject: [PATCH] executor: fix load data panic if the data is broken at escape character (#30868) (#31774) close pingcap/tidb#31589 --- executor/load_data.go | 135 +++++++++-------------------------------- executor/write_test.go | 24 +------- 2 files changed, 30 insertions(+), 129 deletions(-) diff --git a/executor/load_data.go b/executor/load_data.go index 1202675ebbce0..7d124c0cacf3d 100644 --- a/executor/load_data.go +++ b/executor/load_data.go @@ -363,65 +363,20 @@ func (e *LoadDataInfo) SetMaxRowsInBatch(limit uint64) { e.curBatchCnt = 0 } -// getValidData returns prevData and curData that starts from starting symbol. -// If the data doesn't have starting symbol, prevData is nil and curData is curData[len(curData)-startingLen+1:]. -// If curData size less than startingLen, curData is returned directly. -func (e *LoadDataInfo) getValidData(prevData, curData []byte) ([]byte, []byte) { - startingLen := len(e.LinesInfo.Starting) - if startingLen == 0 { - return prevData, curData - } - - prevLen := len(prevData) - if prevLen > 0 { - // starting symbol in the prevData - idx := strings.Index(string(hack.String(prevData)), e.LinesInfo.Starting) - if idx != -1 { - return prevData[idx:], curData - } - - // starting symbol in the middle of prevData and curData - restStart := curData - if len(curData) >= startingLen { - restStart = curData[:startingLen-1] - } - prevData = append(prevData, restStart...) - idx = strings.Index(string(hack.String(prevData)), e.LinesInfo.Starting) - if idx != -1 { - return prevData[idx:prevLen], curData - } - } - - // starting symbol in the curData +// getValidData returns curData that starts from starting symbol. +// If the data doesn't have starting symbol, return curData[len(curData)-startingLen+1:] and false. +func (e *LoadDataInfo) getValidData(curData []byte) ([]byte, bool) { idx := strings.Index(string(hack.String(curData)), e.LinesInfo.Starting) - if idx != -1 { - return nil, curData[idx:] + if idx == -1 { + return curData[len(curData)-len(e.LinesInfo.Starting)+1:], false } - // no starting symbol - if len(curData) >= startingLen { - curData = curData[len(curData)-startingLen+1:] - } - return nil, curData -} - -func (e *LoadDataInfo) isInQuoter(bs []byte) bool { - inQuoter := false - for i := 0; i < len(bs); i++ { - switch bs[i] { - case e.FieldsInfo.Enclosed: - inQuoter = !inQuoter - case e.FieldsInfo.Escaped: - i++ - default: - } - } - return inQuoter + return curData[idx:], true } -// IndexOfTerminator return index of terminator, if not, return -1. +// indexOfTerminator return index of terminator, if not, return -1. // normally, the field terminator and line terminator is short, so we just use brute force algorithm. -func (e *LoadDataInfo) IndexOfTerminator(bs []byte, inQuoter bool) int { +func (e *LoadDataInfo) indexOfTerminator(bs []byte) int { fieldTerm := []byte(e.FieldsInfo.Terminated) fieldTermLen := len(fieldTerm) lineTerm := []byte(e.LinesInfo.Terminated) @@ -459,15 +414,16 @@ func (e *LoadDataInfo) IndexOfTerminator(bs []byte, inQuoter bool) int { } } atFieldStart := true + inQuoter := false loop: for i := 0; i < len(bs); i++ { - if atFieldStart && bs[i] == e.FieldsInfo.Enclosed { + if atFieldStart && e.FieldsInfo.Enclosed != byte(0) && bs[i] == e.FieldsInfo.Enclosed { inQuoter = !inQuoter atFieldStart = false continue } restLen := len(bs) - i - 1 - if inQuoter && bs[i] == e.FieldsInfo.Enclosed { + if inQuoter && e.FieldsInfo.Enclosed != byte(0) && bs[i] == e.FieldsInfo.Enclosed { // look ahead to see if it is end of line or field. switch cmpTerm(restLen, bs[i+1:]) { case lineTermType: @@ -505,67 +461,32 @@ loop: // getLine returns a line, curData, the next data start index and a bool value. // If it has starting symbol the bool is true, otherwise is false. func (e *LoadDataInfo) getLine(prevData, curData []byte, ignore bool) ([]byte, []byte, bool) { - startingLen := len(e.LinesInfo.Starting) - prevData, curData = e.getValidData(prevData, curData) - if prevData == nil && len(curData) < startingLen { - return nil, curData, false - } - inquotor := e.isInQuoter(prevData) - prevLen := len(prevData) - terminatedLen := len(e.LinesInfo.Terminated) - curStartIdx := 0 - if prevLen < startingLen { - curStartIdx = startingLen - prevLen - } - endIdx := -1 - if len(curData) >= curStartIdx { - if ignore { - endIdx = strings.Index(string(hack.String(curData[curStartIdx:])), e.LinesInfo.Terminated) - } else { - endIdx = e.IndexOfTerminator(curData[curStartIdx:], inquotor) - } - } - if endIdx == -1 { - // no terminated symbol - if len(prevData) == 0 { - return nil, curData, true - } - - // terminated symbol in the middle of prevData and curData + if prevData != nil { curData = append(prevData, curData...) - if ignore { - endIdx = strings.Index(string(hack.String(curData[startingLen:])), e.LinesInfo.Terminated) - } else { - endIdx = e.IndexOfTerminator(curData[startingLen:], inquotor) + } + startLen := len(e.LinesInfo.Starting) + if startLen != 0 { + if len(curData) < startLen { + return nil, curData, false } - if endIdx != -1 { - nextDataIdx := startingLen + endIdx + terminatedLen - return curData[startingLen : startingLen+endIdx], curData[nextDataIdx:], true + var ok bool + curData, ok = e.getValidData(curData) + if !ok { + return nil, curData, false } - // no terminated symbol - return nil, curData, true - } - - // terminated symbol in the curData - nextDataIdx := curStartIdx + endIdx + terminatedLen - if len(prevData) == 0 { - return curData[curStartIdx : curStartIdx+endIdx], curData[nextDataIdx:], true } - - // terminated symbol in the curData - prevData = append(prevData, curData[:nextDataIdx]...) + var endIdx int if ignore { - endIdx = strings.Index(string(hack.String(prevData[startingLen:])), e.LinesInfo.Terminated) + endIdx = strings.Index(string(hack.String(curData[startLen:])), e.LinesInfo.Terminated) } else { - endIdx = e.IndexOfTerminator(prevData[startingLen:], inquotor) + endIdx = e.indexOfTerminator(curData[startLen:]) } - if endIdx >= prevLen { - return prevData[startingLen : startingLen+endIdx], curData[nextDataIdx:], true + + if endIdx == -1 { + return nil, curData, true } - // terminated symbol in the middle of prevData and curData - lineLen := startingLen + endIdx + terminatedLen - return prevData[startingLen : startingLen+endIdx], curData[lineLen-prevLen:], true + return curData[startLen : startLen+endIdx], curData[startLen+endIdx+len(e.LinesInfo.Terminated):], true } // InsertData inserts data into specified table according to the specified format. diff --git a/executor/write_test.go b/executor/write_test.go index ec326d0b6d436..2584408e5d7c8 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/config" "github.com/pingcap/tidb/executor" "github.com/pingcap/tidb/kv" - "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/planner/core" @@ -2118,33 +2117,14 @@ func TestLoadDataEscape(t *testing.T) { {nil, []byte("7\trtn0ZbN\n"), []string{"7|" + string([]byte{'r', 't', 'n', '0', 'Z', 'b', 'N'})}, nil, trivialMsg}, {nil, []byte("8\trtn0Zb\\N\n"), []string{"8|" + string([]byte{'r', 't', 'n', '0', 'Z', 'b', 'N'})}, nil, trivialMsg}, {nil, []byte("9\ttab\\ tab\n"), []string{"9|tab tab"}, nil, trivialMsg}, + // data broken at escape character. + {[]byte("1\ta string\\"), []byte("\n1\n"), []string{"1|a string\n1"}, nil, trivialMsg}, } deleteSQL := "delete from load_data_test" selectSQL := "select * from load_data_test;" checkCases(tests, ld, t, tk, ctx, selectSQL, deleteSQL) } -func TestLoadDataWithLongContent(t *testing.T) { - e := &executor.LoadDataInfo{ - FieldsInfo: &ast.FieldsClause{Terminated: ",", Escaped: '\\', Enclosed: '"'}, - LinesInfo: &ast.LinesClause{Terminated: "\n"}, - } - tests := []struct { - content string - inQuoter bool - expectedIndex int - }{ - {"123,123\n123,123", false, 7}, - {"123123\\n123123", false, -1}, - {"123123\n123123", true, -1}, - {"123123\n123123\"\n", true, 14}, - } - - for _, tt := range tests { - require.Equal(t, tt.expectedIndex, e.IndexOfTerminator([]byte(tt.content), tt.inQuoter)) - } -} - // TestLoadDataSpecifiedColumns reuse TestLoadDataEscape's test case :-) func TestLoadDataSpecifiedColumns(t *testing.T) { trivialMsg := "Records: 1 Deleted: 0 Skipped: 0 Warnings: 0"