diff --git a/coverage/coverage.html b/coverage/coverage.html index 5f75de3..bd12817 100644 --- a/coverage/coverage.html +++ b/coverage/coverage.html @@ -65,7 +65,7 @@ - + @@ -472,8 +472,15 @@ // DefaultBatchSize is the default batch size for bulk operations like // CreateItems. This value is used if the caller does not specify a size - // using the WithBatchSize(...) option. - DefaultBatchSize = 100 + // using the WithBatchSize(...) option. Note: some databases have a limit + // on the number of query parameters (postgres is currently 64k and sqlite + // is 32k) and/or size of a SQL statement (sqlite is currently 1bn bytes), + // so this value should be set to a value that is less than the limits for + // your target db. + // See: + // - https://www.postgresql.org/docs/current/limits.html + // - https://www.sqlite.org/limits.html + DefaultBatchSize = 1000 ) // VetForWriter provides an interface that Create and Update can use to vet the @@ -682,12 +689,12 @@ } } } + } - if opts.WithBeforeWrite != nil { - if err := opts.WithBeforeWrite(valCreateItems.Index(i).Interface()); err != nil { - return fmt.Errorf("%s: error before write: %w", op, err) - } - } + if opts.WithBeforeWrite != nil { + if err := opts.WithBeforeWrite(createItems); err != nil { + return fmt.Errorf("%s: error before write: %w", op, err) + } } db := rw.underlying.wrapped.WithContext(ctx) @@ -760,11 +767,9 @@ *opts.WithRowsAffected = tx.RowsAffected } if tx.RowsAffected > 0 && opts.WithAfterWrite != nil { - for i := 0; i < valCreateItems.Len(); i++ { - if err := opts.WithAfterWrite(valCreateItems.Index(i).Interface(), int(tx.RowsAffected)); err != nil { - return fmt.Errorf("%s: error after write: %w", op, err) - } - } + if err := opts.WithAfterWrite(createItems, int(tx.RowsAffected)); err != nil { + return fmt.Errorf("%s: error after write: %w", op, err) + } } return nil } @@ -1152,56 +1157,117 @@ } // DeleteItems will delete multiple items of the same type. Options supported: -// WithDebug, WithTable -func (rw *RW) DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error) { +// WithWhereClause, WithDebug, WithTable +func (rw *RW) DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error) { const op = "dbw.DeleteItems" - if rw.underlying == nil { - return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) - } - if len(deleteItems) == 0 { - return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter) - } + switch { + case rw.underlying == nil: + return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter) + case isNil(deleteItems): + return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter) + } + valDeleteItems := reflect.ValueOf(deleteItems) + switch { + case valDeleteItems.Kind() != reflect.Slice: + return noRowsAffected, fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter) + case valDeleteItems.Len() == 0: + return noRowsAffected, fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter) + + } if err := raiseErrorOnHooks(deleteItems); err != nil { return noRowsAffected, fmt.Errorf("%s: %w", op, err) } + opts := GetOpts(opt...) - if opts.WithLookup { - return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) - } - // verify that createItems are all the same type. + switch { + case opts.WithLookup: + return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter) + case opts.WithVersion != nil: + return noRowsAffected, fmt.Errorf("%s: with version is not a supported option: %w", op, ErrInvalidParameter) + } + + // we need to dig out the stmt so in just a sec we can make sure the PKs are + // set for all the items, so we'll just use the first item to do so. + mDb := rw.underlying.wrapped.Model(valDeleteItems.Index(0).Interface()) + err := mDb.Statement.Parse(valDeleteItems.Index(0).Interface()) + switch { + case err != nil: + return noRowsAffected, fmt.Errorf("%s: (internal error) error parsing stmt: %w", op, err) + case err == nil && mDb.Statement.Schema == nil: + return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown) + } + + // verify that deleteItems are all the same type, among a myriad of + // other things on the set of items var foundType reflect.Type - for i, v := range deleteItems { + + for i := 0; i < valDeleteItems.Len(); i++ { if i == 0 { - foundType = reflect.TypeOf(v) - } - currentType := reflect.TypeOf(v) - if foundType != currentType { - return noRowsAffected, fmt.Errorf("%s: items contain disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) + foundType = reflect.TypeOf(valDeleteItems.Index(i).Interface()) } + currentType := reflect.TypeOf(valDeleteItems.Index(i).Interface()) + switch { + case isNil(valDeleteItems.Index(i).Interface()) || currentType == nil: + return noRowsAffected, fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter) + case foundType != currentType: + return noRowsAffected, fmt.Errorf("%s: items contain disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter) + } + if opts.WithWhereClause == "" { + // make sure the PK is set for the current item + reflectValue := reflect.Indirect(reflect.ValueOf(valDeleteItems.Index(i).Interface())) + for _, pf := range mDb.Statement.Schema.PrimaryFields { + if _, isZero := pf.ValueOf(ctx, reflectValue); isZero { + return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter) + } + } + } } + if opts.WithBeforeWrite != nil { if err := opts.WithBeforeWrite(deleteItems); err != nil { return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err) } } - rowsDeleted := 0 - for _, item := range deleteItems { - cnt, err := rw.Delete(ctx, item, - WithDebug(opts.WithDebug), - WithTable(opts.WithTable), - ) - rowsDeleted += cnt - if err != nil { - return rowsDeleted, fmt.Errorf("%s: %w", op, err) + + db := rw.underlying.wrapped.WithContext(ctx) + if opts.WithDebug { + db = db.Debug() + } + + if opts.WithWhereClause != "" { + where, args, err := rw.whereClausesFromOpts(ctx, valDeleteItems.Index(0).Interface(), opts) + if err != nil { + return noRowsAffected, fmt.Errorf("%s: %w", op, err) + } + db = db.Where(where, args...) + } + + switch { + case opts.WithTable != "": + db = db.Table(opts.WithTable) + default: + tabler, ok := valDeleteItems.Index(0).Interface().(tableNamer) + if ok { + db = db.Table(tabler.TableName()) } } - if rowsDeleted > 0 && opts.WithAfterWrite != nil { + + db = db.Delete(deleteItems) + if db.Error != nil { + return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error) + } + rowsDeleted := int(db.RowsAffected) + if rowsDeleted > 0 && opts.WithAfterWrite != nil { if err := opts.WithAfterWrite(deleteItems, int(rowsDeleted)); err != nil { return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err) } } return rowsDeleted, nil } + +type tableNamer interface { + TableName() string +}