Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clause/clause.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func (c Clause) Build(builder Builder) {

const (
PrimaryKey string = "~~~py~~~" // primary key
PrimaryKeys string = "~~~ps~~~" // primary keys
CurrentTable string = "~~~ct~~~" // current table
Associations string = "~~~as~~~" // associations
)
Expand Down
83 changes: 72 additions & 11 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (db *DB) Save(value interface{}) (tx *DB) {
// First finds the first record ordered by primary key, matching given conditions conds
func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
})
if len(conds) > 0 {
if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
Expand Down Expand Up @@ -146,7 +146,7 @@ func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) {
// Last finds the last record ordered by primary key, matching given conditions conds
func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
Desc: true,
})
if len(conds) > 0 {
Expand All @@ -173,9 +173,10 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) {

// FindInBatches finds all records in batches of batchSize
func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB {
// Use PrimaryKeys to handle composite primary key situations
var (
tx = db.Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
}).Session(&Session{})
queryDB = tx
rowsAffected int64
Expand All @@ -199,6 +200,7 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat
}
}

find:
for {
result := queryDB.Limit(batchSize).Find(dest)
rowsAffected += result.RowsAffected
Expand Down Expand Up @@ -227,17 +229,76 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat

// Optimize for-break
resultsValue := reflect.Indirect(reflect.ValueOf(dest))
if result.Statement.Schema.PrioritizedPrimaryField == nil {
if result.Statement.Schema.PrioritizedPrimaryField == nil && result.Statement.Schema.PrimaryFields != nil && len(result.Statement.Schema.PrimaryFields) == 1 {
tx.AddError(ErrPrimaryKeyRequired)
break
}

primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired)
break
// The following will build a where clause like this:
// struct {
// col1 uint `gorm:"primaryKey;autoIncrement:false"`
// col2 uint `gorm:"primaryKey;autoIncrement:false"`
// col3 uint `gorm:"primaryKey;autoIncrement:false"`
// }
// last row returned was col1 = 2, col2 = 3, col3 = 5
// where clause will be generated as follows
// WHERE (col1 > 2 OR (col1 = 2 AND col2 > 3) OR (col1 = 2 AND col2 = 3 AND col3 > 5))
// Detect composite primary keys
if result.Statement.Schema.PrimaryFields != nil {
pkCount := len(result.Statement.Schema.PrimaryFields)

// Handle composite primary key Where clauses
if pkCount > 1 {
var f *schema.Field
var orClauses []clause.Expression
for i := 0; i < pkCount; i++ {
var andClauses []clause.Expression
// Build 1st column GT clause
if i == 0 {
f = result.Statement.Schema.PrimaryFields[i]
primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
break find
}
orClauses = append(orClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
} else {
// Build AND clause and append to OR clauses
for j := 0; j <= i; j++ {
f = result.Statement.Schema.PrimaryFields[j]
primaryValue, zero := f.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
break find
}
if j == i {
// Build current outer column GT clause
andClauses = append(andClauses, clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
} else {
// Build all other columns EQ clause
andClauses = append(andClauses, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: f.DBName}, Value: primaryValue})
}
}
orClauses = append(orClauses, clause.And(andClauses...))
}
}
queryDB = tx.Clauses(clause.Or(orClauses...))
} else {
primaryValue, zero := result.Statement.Schema.PrimaryFields[0].ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: result.Statement.Schema.PrimaryFields[0].DBName}, Value: primaryValue})
}
} else {
primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1))
if zero {
tx.AddError(ErrPrimaryKeyRequired) //nolint:typecheck,errcheck,gosec
break
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}
queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue})
}

tx.RowsAffected = rowsAffected
Expand Down Expand Up @@ -307,7 +368,7 @@ func (db *DB) assignInterfacesToValue(values ...interface{}) {
// // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"}
func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
queryTx := db.Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
})

if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 {
Expand Down Expand Up @@ -347,7 +408,7 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) {
func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) {
tx = db.getInstance()
queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey},
Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKeys},
})

result := queryTx.Find(dest, conds...)
Expand Down
61 changes: 48 additions & 13 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,27 +105,62 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
write(v.Raw, v.Alias)
}
case clause.Column:
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}

if v.Name == clause.PrimaryKey {
// Handle composite primary keys explicitly
if v.Name == clause.PrimaryKeys {
if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if stmt.Schema.PrimaryFields != nil {
for idx, s := range stmt.Schema.PrimaryFieldDBNames {
if idx > 0 {
writer.WriteByte(',') //nolint:typecheck,errcheck,gosec
}
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
}
write(v.Raw, s)
}
} else if len(stmt.Schema.DBNames) > 0 {
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
}
write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
}
} else {
write(v.Raw, v.Name)
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.') //nolint:typecheck,errcheck,gosec
}

if v.Name == clause.PrimaryKey {
switch {
case stmt.Schema == nil:
stmt.DB.AddError(ErrModelValueRequired) //nolint:typecheck,errcheck,gosec,staticcheck
case stmt.Schema.PrioritizedPrimaryField != nil:
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
case len(stmt.Schema.DBNames) > 0:
write(v.Raw, stmt.Schema.DBNames[0])
default:
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck,gosec,staticcheck
}
} else {
write(v.Raw, v.Name)
}
}

if v.Alias != "" {
Expand Down
69 changes: 69 additions & 0 deletions tests/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,75 @@ func TestFindInBatchesWithError(t *testing.T) {
}
}

func TestFindInBatchesCompositeKey(t *testing.T) {
coupons := []Coupon{
{AmountOff: 1.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
{ProductId: "1", Desc: "find_in_batches"},
{ProductId: "2", Desc: "find_in_batches"},
{ProductId: "3", Desc: "find_in_batches"},
}},
{AmountOff: 2.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
{ProductId: "1", Desc: "find_in_batches"},
{ProductId: "2", Desc: "find_in_batches"},
{ProductId: "3", Desc: "find_in_batches"},
}},
{AmountOff: 3.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
{ProductId: "1", Desc: "find_in_batches"},
{ProductId: "2", Desc: "find_in_batches"},
{ProductId: "3", Desc: "find_in_batches"},
}},
{AmountOff: 4.0, PercentOff: 0.5, AppliesToProduct: []*CouponProduct{
{ProductId: "1", Desc: "find_in_batches"},
{ProductId: "2", Desc: "find_in_batches"},
{ProductId: "3", Desc: "find_in_batches"},
}},
}

DB.Create(&coupons)

var (
results []CouponProduct
lastBatch int
)

if result := DB.Table("coupon_products as cp").Where(&CouponProduct{Desc: "find_in_batches"}).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
lastBatch = batch

if tx.RowsAffected != 2 {
t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
}

if len(results) != 2 {
t.Errorf("Incorrect coupon_product length, expects: 2, got %v", len(results))
}

for idx := range results {
results[idx].Desc = results[idx].Desc + "_new"
}

if err := tx.Save(results).Error; err != nil {
t.Fatalf("failed to save coupon_product, got error %v", err)
}

return nil
}); result.Error != nil || result.RowsAffected != 12 {
t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
}

if lastBatch != 6 {
t.Errorf("incorrect final batch, expects: %v, got %v", 6, lastBatch)
}

var count int64
DB.Model(&CouponProduct{}).Where(&CouponProduct{Desc: "find_in_batches_new"}).Count(&count)
if count != 12 {
t.Errorf("incorrect count after update, expects: %v, got %v", 12, count)
}

DB.Unscoped().Where(&CouponProduct{Desc: "find_in_batches_new"}).Delete(&CouponProduct{})
DB.Unscoped().Where("id in (1,2,3)").Delete(&Coupon{})
}

func TestFillSmallerStruct(t *testing.T) {
user := User{Name: "SmallerUser", Age: 100}
DB.Save(&user)
Expand Down
Loading