Skip to content
This repository has been archived by the owner on Sep 7, 2021. It is now read-only.
This repository is currently being migrated. It's locked while the migration is in progress.

fix statement.LimitN(0) will delete or update all data #1119

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions session_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if err != nil {
return 0, err
}
if len(condSQL) == 0 && session.statement.LimitN == 0 {
pLimitN := session.statement.LimitN
if len(condSQL) == 0 && (pLimitN == nil || *pLimitN == 0) {
return 0, ErrNeedDeletedCond
}

Expand All @@ -115,8 +116,9 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
if len(session.statement.OrderStr) > 0 {
orderSQL += fmt.Sprintf(" ORDER BY %s", session.statement.OrderStr)
}
if session.statement.LimitN > 0 {
orderSQL += fmt.Sprintf(" LIMIT %d", session.statement.LimitN)
if pLimitN != nil && *pLimitN > 0 {
limitNValue := *pLimitN
orderSQL += fmt.Sprintf(" LIMIT %d", limitNValue)
}

if len(orderSQL) > 0 {
Expand All @@ -135,7 +137,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} else {
deleteSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
// TODO: how to handle delete limit on mssql?
case core.MSSQL:
return 0, ErrNotImplemented
default:
Expand Down Expand Up @@ -176,7 +178,7 @@ func (session *Session) Delete(bean interface{}) (int64, error) {
} else {
realSQL += " WHERE " + inSQL
}
// TODO: how to handle delete limit on mssql?
// TODO: how to handle delete limit on mssql?
case core.MSSQL:
return 0, ErrNotImplemented
default:
Expand Down
12 changes: 6 additions & 6 deletions session_iterate.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
}

var bufferSize = session.statement.bufferSize
var limit = session.statement.LimitN
if limit > 0 && bufferSize > limit {
bufferSize = limit
var pLimitN = session.statement.LimitN
if pLimitN != nil && bufferSize > *pLimitN {
bufferSize = *pLimitN
}
var start = session.statement.Start
v := rValue(bean)
Expand All @@ -83,11 +83,11 @@ func (session *Session) bufferIterate(bean interface{}, fun IterFunc) error {
}

start = start + slice.Elem().Len()
if limit > 0 && idx+bufferSize > limit {
bufferSize = limit - idx
if pLimitN != nil && idx+bufferSize > *pLimitN {
bufferSize = *pLimitN - idx
}

if bufferSize <= 0 || slice.Elem().Len() < bufferSize || idx == limit {
if bufferSize <= 0 || slice.Elem().Len() < bufferSize || (pLimitN != nil && idx == *pLimitN) {
break
}
}
Expand Down
13 changes: 7 additions & 6 deletions session_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,12 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
var tableName = session.statement.TableName()
// TODO: Oracle support needed
var top string
if st.LimitN > 0 {
if st.LimitN != nil {
limitValue := *st.LimitN
if st.Engine.dialect.DBType() == core.MYSQL {
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
condSQL = condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
} else if st.Engine.dialect.DBType() == core.SQLITE {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("rowid IN (SELECT rowid FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
Expand All @@ -311,7 +312,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condSQL = "WHERE " + condSQL
}
} else if st.Engine.dialect.DBType() == core.POSTGRES {
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", st.LimitN)
tempCondSQL := condSQL + fmt.Sprintf(" LIMIT %d", limitValue)
cond = cond.And(builder.Expr(fmt.Sprintf("CTID IN (SELECT CTID FROM %v %v)",
session.engine.Quote(tableName), tempCondSQL), condArgs...))
condSQL, condArgs, err = builder.ToSQL(cond)
Expand All @@ -326,7 +327,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
if st.OrderStr != "" && st.Engine.dialect.DBType() == core.MSSQL &&
table != nil && len(table.PrimaryKeys) == 1 {
cond = builder.Expr(fmt.Sprintf("%s IN (SELECT TOP (%d) %s FROM %v%v)",
table.PrimaryKeys[0], st.LimitN, table.PrimaryKeys[0],
table.PrimaryKeys[0], limitValue, table.PrimaryKeys[0],
session.engine.Quote(tableName), condSQL), condArgs...)

condSQL, condArgs, err = builder.ToSQL(cond)
Expand All @@ -337,7 +338,7 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
condSQL = "WHERE " + condSQL
}
} else {
top = fmt.Sprintf("TOP (%d) ", st.LimitN)
top = fmt.Sprintf("TOP (%d) ", limitValue)
}
}
}
Expand Down
31 changes: 19 additions & 12 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type Statement struct {
RefTable *core.Table
Engine *Engine
Start int
LimitN int
LimitN *int
idParam *core.PK
OrderStr string
JoinStr string
Expand Down Expand Up @@ -66,7 +66,7 @@ type Statement struct {
func (statement *Statement) Init() {
statement.RefTable = nil
statement.Start = 0
statement.LimitN = 0
statement.LimitN = nil
statement.OrderStr = ""
statement.UseCascade = true
statement.JoinStr = ""
Expand Down Expand Up @@ -689,7 +689,7 @@ func (statement *Statement) Top(limit int) *Statement {

// Limit generate LIMIT start, limit statement
func (statement *Statement) Limit(limit int, start ...int) *Statement {
statement.LimitN = limit
statement.LimitN = &limit
Copy link
Member

@BetaCat0 BetaCat0 Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think statement.LimitN here will never get a nil val since you've already passed an non-pointer value within param limit. Furthermore, you should also write some cases to test your code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if limit > 0

Copy link
Member

@BetaCat0 BetaCat0 Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And maybe we still need an extra param to indicate that the value "0" in param limit is specified intentionally or not(considering using a bool param?).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a critical bug, it did delete all my table data, but i just want to delete none

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Frank-hust I will accept your PR, could you please resolve the conflict and add some tests for it?

if len(start) > 0 {
statement.Start = start[0]
}
Expand Down Expand Up @@ -1062,9 +1062,11 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
}

pLimitN := statement.LimitN
if dialect.DBType() == core.MSSQL {
if statement.LimitN > 0 {
top = fmt.Sprintf(" TOP %d ", statement.LimitN)
if pLimitN != nil {
LimitNValue := *pLimitN
top = fmt.Sprintf(" TOP %d ", LimitNValue)
}
if statement.Start > 0 {
var column string
Expand Down Expand Up @@ -1125,16 +1127,20 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n
if needLimit {
if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
if statement.Start > 0 {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
} else if statement.LimitN > 0 {
fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
if pLimitN != nil {
fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", *pLimitN, statement.Start)
} else {
fmt.Fprintf(&buf, "LIMIT 0 OFFSET %v", statement.Start)
}
} else if pLimitN != nil {
fmt.Fprint(&buf, " LIMIT ", *pLimitN)
}
} else if dialect.DBType() == core.ORACLE {
if statement.Start != 0 || statement.LimitN != 0 {
if statement.Start != 0 || pLimitN != nil {
oldString := buf.String()
buf.Reset()
fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start)
columnStr, columnStr, oldString, statement.Start+*pLimitN, statement.Start)
}
}
}
Expand Down Expand Up @@ -1191,8 +1197,9 @@ func (statement *Statement) convertIDSQL(sqlStr string) string {
}

var top string
if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
top = fmt.Sprintf("TOP %d ", statement.LimitN)
pLimitN := statement.LimitN
if pLimitN != nil && statement.Engine.dialect.DBType() == core.MSSQL {
top = fmt.Sprintf("TOP %d ", *pLimitN)
}

newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
Expand Down