diff --git a/engine.go b/engine.go index 0c794512d..d8c99c181 100644 --- a/engine.go +++ b/engine.go @@ -177,6 +177,7 @@ func (engine *Engine) SupportInsertMany() bool { // QuoteStr Engine's database use which character as quote. // mysql, sqlite use ` and postgres use " +// Deprecated, use Quote() instead func (engine *Engine) QuoteStr() string { return engine.dialect.QuoteStr() } @@ -196,13 +197,10 @@ func (engine *Engine) Quote(value string) string { return value } - if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { - return value - } - - value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) + buf := builder.StringBuilder{} + engine.QuoteTo(&buf, value) - return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr() + return buf.String() } // QuoteTo quotes string and writes into the buffer @@ -216,20 +214,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) { return } - if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' { - buf.WriteString(value) + quotePair := engine.dialect.Quote("") + + if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote + _, _ = buf.WriteString(value) return + } else { + prefix, suffix := quotePair[0], quotePair[1] + + _ = buf.WriteByte(prefix) + for i := 0; i < len(value); i++ { + if value[i] == '.' { + _ = buf.WriteByte(suffix) + _ = buf.WriteByte('.') + _ = buf.WriteByte(prefix) + } else { + _ = buf.WriteByte(value[i]) + } + } + _ = buf.WriteByte(suffix) } - - value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1) - - buf.WriteString(engine.dialect.QuoteStr()) - buf.WriteString(value) - buf.WriteString(engine.dialect.QuoteStr()) } func (engine *Engine) quote(sql string) string { - return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr() + return engine.dialect.Quote(sql) } // SqlType will be deprecated, please use SQLType instead @@ -1581,7 +1589,7 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{ func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) { switch sqlTypeName { case core.Time: - s := t.Format("2006-01-02 15:04:05") //time.RFC3339 + s := t.Format("2006-01-02 15:04:05") // time.RFC3339 v = s[11:19] case core.Date: v = t.Format("2006-01-02") diff --git a/helpers.go b/helpers.go index db8fc581f..a31e922c0 100644 --- a/helpers.go +++ b/helpers.go @@ -281,7 +281,7 @@ func rValue(bean interface{}) reflect.Value { func rType(bean interface{}) reflect.Type { sliceValue := reflect.Indirect(reflect.ValueOf(bean)) - //return reflect.TypeOf(sliceValue.Interface()) + // return reflect.TypeOf(sliceValue.Interface()) return sliceValue.Type() } @@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool { func indexName(tableName, idxName string) string { return fmt.Sprintf("IDX_%v_%v", tableName, idxName) } + +func eraseAny(value string, strToErase ...string) string { + if len(strToErase) == 0 { + return value + } + var replaceSeq []string + for _, s := range strToErase { + replaceSeq = append(replaceSeq, s, "") + } + + replacer := strings.NewReplacer(replaceSeq...) + + return replacer.Replace(value) +} + +func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string { + for i := range cols { + cols[i] = quoteFunc(cols[i]) + } + return strings.Join(cols, sep+" ") +} diff --git a/helpers_test.go b/helpers_test.go index d57c54aec..7e3171268 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -4,7 +4,11 @@ package xorm -import "testing" +import ( + "testing" + + "github.com/stretchr/testify/assert" +) func TestSplitTag(t *testing.T) { var cases = []struct { @@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) { } } } + +func TestEraseAny(t *testing.T) { + raw := "SELECT * FROM `table`.[table_name]" + assert.EqualValues(t, raw, eraseAny(raw)) + assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`")) + assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]")) +} + +func TestQuoteColumns(t *testing.T) { + cols := []string{"f1", "f2", "f3"} + quoteFunc := func(value string) string { + return "[" + value + "]" + } + + assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ",")) +} diff --git a/session_insert.go b/session_insert.go index 3cff48f61..713565661 100644 --- a/session_insert.go +++ b/session_insert.go @@ -242,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error var sql string if session.engine.dialect.DBType() == core.ORACLE { - temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", + temp := fmt.Sprintf(") INTO %s (%v) VALUES (", session.engine.Quote(tableName), - session.engine.QuoteStr(), - strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), - session.engine.QuoteStr()) - sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL", + quoteColumns(colNames, session.engine.Quote, ",")) + sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL", session.engine.Quote(tableName), - session.engine.QuoteStr(), - strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), - session.engine.QuoteStr(), + quoteColumns(colNames, session.engine.Quote, ","), strings.Join(colMultiPlaces, temp)) } else { - sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", + sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)", session.engine.Quote(tableName), - session.engine.QuoteStr(), - strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), - session.engine.QuoteStr(), + quoteColumns(colNames, session.engine.Quote, ","), strings.Join(colMultiPlaces, "),(")) } res, err := session.exec(sql, args...) @@ -378,11 +372,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) { output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement) } if len(colPlaces) > 0 { - sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)", + sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)", session.engine.Quote(tableName), - session.engine.QuoteStr(), - strings.Join(colNames, session.engine.Quote(", ")), - session.engine.QuoteStr(), + quoteColumns(colNames, session.engine.Quote, ","), output, colPlaces) } else { diff --git a/session_update.go b/session_update.go index 216c4e87d..85b0bb0bf 100644 --- a/session_update.go +++ b/session_update.go @@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string, return ErrCacheFailed } kvs := strings.Split(strings.TrimSpace(sqls[1]), ",") + for idx, kv := range kvs { sps := strings.SplitN(kv, "=", 2) sps2 := strings.Split(sps[0], ".") colName := sps2[len(sps2)-1] - if strings.Contains(colName, "`") { - colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1)) - } else if strings.Contains(colName, session.engine.QuoteStr()) { - colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1)) + // treat quote prefix, suffix and '`' as quotes + quotes := append(strings.Split(session.engine.Quote(""), ""), "`") + if strings.ContainsAny(colName, strings.Join(quotes, "")) { + colName = strings.TrimSpace(eraseAny(colName, quotes...)) } else { session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName) return ErrCacheFailed @@ -221,19 +222,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } } - //for update action to like "column = column + ?" + // for update action to like "column = column + ?" incColumns := session.statement.getInc() for _, v := range incColumns { colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?") args = append(args, v.arg) } - //for update action to like "column = column - ?" + // for update action to like "column = column - ?" decColumns := session.statement.getDec() for _, v := range decColumns { colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?") args = append(args, v.arg) } - //for update action to like "column = expression" + // for update action to like "column = expression" exprColumns := session.statement.getExpr() for _, v := range exprColumns { colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr) @@ -382,7 +383,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6 } if cacher := session.engine.getCacher(tableName); cacher != nil && session.statement.UseCache { - //session.cacheUpdate(table, tableName, sqlStr, args...) + // session.cacheUpdate(table, tableName, sqlStr, args...) session.engine.logger.Debug("[cacheUpdate] clear table ", tableName) cacher.ClearIds(tableName) cacher.ClearBeans(tableName) diff --git a/session_update_test.go b/session_update_test.go index 415c699ff..c90ec5bd9 100644 --- a/session_update_test.go +++ b/session_update_test.go @@ -11,8 +11,8 @@ import ( "testing" "time" - "xorm.io/core" "github.com/stretchr/testify/assert" + "xorm.io/core" ) func TestUpdateMap(t *testing.T) { diff --git a/statement.go b/statement.go index 2dd9a3846..585378a8b 100644 --- a/statement.go +++ b/statement.go @@ -6,7 +6,6 @@ package xorm import ( "database/sql/driver" - "errors" "fmt" "reflect" "strings" @@ -398,7 +397,7 @@ func (statement *Statement) buildUpdates(bean interface{}, continue } } else { - //TODO: how to handler? + // TODO: how to handler? panic("not supported") } } else { @@ -579,21 +578,9 @@ func (statement *Statement) getExpr() map[string]exprParam { func (statement *Statement) col2NewColsWithQuote(columns ...string) []string { newColumns := make([]string, 0) + quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") for _, col := range columns { - col = strings.Replace(col, "`", "", -1) - col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1) - ccols := strings.Split(col, ",") - for _, c := range ccols { - fields := strings.Split(strings.TrimSpace(c), ".") - if len(fields) == 1 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])) - } else if len(fields) == 2 { - newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+ - statement.Engine.quote(fields[1])) - } else { - panic(errors.New("unwanted colnames")) - } - } + newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...))) } return newColumns } @@ -764,7 +751,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition return statement } tbs := strings.Split(tp.TableName(), ".") - var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr()) + quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") + + var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) case *builder.Builder: @@ -774,7 +763,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition return statement } tbs := strings.Split(tp.TableName(), ".") - var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr()) + quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`") + + var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, "")) fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition) statement.joinArgs = append(statement.joinArgs, subQueryArgs...) default: @@ -1246,7 +1237,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) { var whereStr = sqls[1] - //TODO: for postgres only, if any other database? + // TODO: for postgres only, if any other database? var paraStr string if statement.Engine.dialect.DBType() == core.POSTGRES { paraStr = "$"