Skip to content

Commit

Permalink
feat: refact DeleteItems(...) to batch deletes
Browse files Browse the repository at this point in the history
  • Loading branch information
jimlambrt committed Jul 27, 2024
1 parent 1d5f53b commit d99a9b0
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 32 deletions.
111 changes: 86 additions & 25 deletions delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,53 +70,114 @@ func (rw *RW) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er
}

// DeleteItems will delete multiple items of the same type. Options supported:
// WithDebug, WithTable
func (rw *RW) DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error) {
// WithWhereClause, WithDebug, WithTable
func (rw *RW) DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error) {
const op = "dbw.DeleteItems"
if rw.underlying == nil {
switch {
case rw.underlying == nil:
return noRowsAffected, fmt.Errorf("%s: missing underlying db: %w", op, ErrInvalidParameter)
}
if len(deleteItems) == 0 {
case isNil(deleteItems):
return noRowsAffected, fmt.Errorf("%s: no interfaces to delete: %w", op, ErrInvalidParameter)
}
valDeleteItems := reflect.ValueOf(deleteItems)
switch {
case valDeleteItems.Kind() != reflect.Slice:
return noRowsAffected, fmt.Errorf("%s: not a slice: %w", op, ErrInvalidParameter)
case valDeleteItems.Len() == 0:
return noRowsAffected, fmt.Errorf("%s: missing items: %w", op, ErrInvalidParameter)

}
if err := raiseErrorOnHooks(deleteItems); err != nil {
return noRowsAffected, fmt.Errorf("%s: %w", op, err)
}

opts := GetOpts(opt...)
if opts.WithLookup {
switch {
case opts.WithLookup:
return noRowsAffected, fmt.Errorf("%s: with lookup not a supported option: %w", op, ErrInvalidParameter)
case opts.WithVersion != nil:
return noRowsAffected, fmt.Errorf("%s: with version is not a supported option: %w", op, ErrInvalidParameter)
}
// verify that createItems are all the same type.

// we need to dig out the stmt so in just a sec we can make sure the PKs are
// set for all the items, so we'll just use the first item to do so.
mDb := rw.underlying.wrapped.Model(valDeleteItems.Index(0).Interface())
err := mDb.Statement.Parse(valDeleteItems.Index(0).Interface())
switch {
case err != nil:
return noRowsAffected, fmt.Errorf("%s: (internal error) error parsing stmt: %w", op, err)
case err == nil && mDb.Statement.Schema == nil:
return noRowsAffected, fmt.Errorf("%s: (internal error) unable to parse stmt: %w", op, ErrUnknown)
}

// verify that createItems are all the same type, among a myriad of
// other things on the set of items
var foundType reflect.Type
for i, v := range deleteItems {

for i := 0; i < valDeleteItems.Len(); i++ {
if i == 0 {
foundType = reflect.TypeOf(v)
foundType = reflect.TypeOf(valDeleteItems.Index(i).Interface())
}
currentType := reflect.TypeOf(v)
if foundType != currentType {
currentType := reflect.TypeOf(valDeleteItems.Index(i).Interface())
switch {
case isNil(valDeleteItems.Index(i).Interface()) || currentType == nil:
return noRowsAffected, fmt.Errorf("%s: unable to determine type of item %d: %w", op, i, ErrInvalidParameter)
case foundType != currentType:
return noRowsAffected, fmt.Errorf("%s: items contain disparate types. item %d is not a %s: %w", op, i, foundType.Name(), ErrInvalidParameter)
}
}
if opts.WithBeforeWrite != nil {
if err := opts.WithBeforeWrite(deleteItems); err != nil {
return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err)
if opts.WithWhereClause == "" {
// make sure the PK is set for the current item
reflectValue := reflect.Indirect(reflect.ValueOf(valDeleteItems.Index(i).Interface()))
for _, pf := range mDb.Statement.Schema.PrimaryFields {
if _, isZero := pf.ValueOf(ctx, reflectValue); isZero {
return noRowsAffected, fmt.Errorf("%s: primary key %s is not set: %w", op, pf.Name, ErrInvalidParameter)
}
}
}
if opts.WithBeforeWrite != nil {
if err := opts.WithBeforeWrite(valDeleteItems.Index(i).Interface()); err != nil {
return noRowsAffected, fmt.Errorf("%s: error before write: %w", op, err)
}
}
}
rowsDeleted := 0
for _, item := range deleteItems {
cnt, err := rw.Delete(ctx, item,
WithDebug(opts.WithDebug),
WithTable(opts.WithTable),
)
rowsDeleted += cnt
db := rw.underlying.wrapped.WithContext(ctx)
if opts.WithDebug {
db = db.Debug()
}

if opts.WithWhereClause != "" {
where, args, err := rw.whereClausesFromOpts(ctx, valDeleteItems.Index(0).Interface(), opts)
if err != nil {
return rowsDeleted, fmt.Errorf("%s: %w", op, err)
return noRowsAffected, fmt.Errorf("%s: %w", op, err)
}
db = db.Where(where, args...)
}

switch {
case opts.WithTable != "":
db = db.Table(opts.WithTable)
default:
tabler, ok := valDeleteItems.Index(0).Interface().(tableNamer)
if ok {
db = db.Table(tabler.TableName())
}
}

db = db.Delete(deleteItems)
if db.Error != nil {
return noRowsAffected, fmt.Errorf("%s: %w", op, db.Error)
}
rowsDeleted := int(db.RowsAffected)
if rowsDeleted > 0 && opts.WithAfterWrite != nil {
if err := opts.WithAfterWrite(deleteItems, int(rowsDeleted)); err != nil {
return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err)
for i := 0; i < valDeleteItems.Len(); i++ {
if err := opts.WithAfterWrite(valDeleteItems.Index(i).Interface(), int(rowsDeleted)); err != nil {
return rowsDeleted, fmt.Errorf("%s: error after write: %w", op, err)
}
}
}
return rowsDeleted, nil
}

type tableNamer interface {
TableName() string
}
120 changes: 114 additions & 6 deletions delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package dbw_test
import (
"context"
"errors"
"fmt"
"testing"

"github.com/hashicorp/go-dbw"
"github.com/hashicorp/go-dbw/internal/dbtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm/schema"
)

func TestDb_Delete(t *testing.T) {
Expand Down Expand Up @@ -278,8 +280,8 @@ func TestDb_DeleteItems(t *testing.T) {
testWithTableUser, err := dbtest.NewTestUser()
require.NoError(t, err)

createFn := func() []interface{} {
results := []interface{}{}
createFn := func() interface{} {
results := []*dbtest.TestUser{}
for i := 0; i < 10; i++ {
u := testUser(t, testRw, "", "", "")
results = append(results, u)
Expand Down Expand Up @@ -312,8 +314,9 @@ func TestDb_DeleteItems(t *testing.T) {
}

type args struct {
deleteItems []interface{}
opt []dbw.Option
deleteItems interface{}
opt []dbw.Option
deleteItemsIds []string
}
tests := []struct {
name string
Expand All @@ -324,12 +327,14 @@ func TestDb_DeleteItems(t *testing.T) {
wantOplogMsgs bool
wantErr bool
wantErrIs error
wantErrContains string
}{
{
name: "simple",
rw: dbw.New(db),
args: args{
deleteItems: createFn(),
opt: []dbw.Option{dbw.WithDebug(true)},
},
wantRowsDeleted: 10,
wantErr: false,
Expand All @@ -344,6 +349,36 @@ func TestDb_DeleteItems(t *testing.T) {
wantRowsDeleted: 10,
wantErr: false,
},
{
name: "success-WithWhereClause",
rw: dbw.New(db),
args: func() args {
users := []*dbtest.TestUser{}
for i := 0; i < 10; i++ {
u := testUser(t, testRw, fmt.Sprintf("name-%d", i), "", "")
users = append(users, u)
}
return args{
deleteItems: users,
opt: []dbw.Option{
dbw.WithWhere("name in(?,?)", "name-0", "name-1"),
dbw.WithDebug(true),
},
deleteItemsIds: []string{users[0].PublicId, users[1].PublicId},
}
}(),
wantRowsDeleted: 2,
},
{
name: "err-bad-where-clause",
rw: dbw.New(db),
args: args{
deleteItems: createFn(),
opt: []dbw.Option{dbw.WithWhere("not a valid where clause")},
},
wantErr: true,
wantErrContains: "syntax error",
},
{
name: "with-table-fail",
rw: dbw.New(db),
Expand Down Expand Up @@ -436,6 +471,66 @@ func TestDb_DeleteItems(t *testing.T) {
wantErr: true,
wantErrIs: dbw.ErrInvalidParameter,
},
{
name: "err-not-slice",
rw: dbw.New(db),
args: args{
deleteItems: "not-a-slice",
},
wantErr: true,
wantErrIs: dbw.ErrInvalidParameter,
wantErrContains: "not a slice",
},
{
name: "err-WithVersion",
rw: dbw.New(db),
args: args{
deleteItems: createFn(),
opt: []dbw.Option{dbw.WithVersion(func() *uint32 { i := uint32(1); return &i }())},
},
wantErr: true,
wantErrIs: dbw.ErrInvalidParameter,
wantErrContains: "with version is not a supported option",
},
{
name: "err-parse-stmt",
rw: dbw.New(db),
args: args{
deleteItems: []int{1, 2, 3},
},
wantErr: true,
wantErrIs: schema.ErrUnsupportedDataType,
wantErrContains: "error parsing stmt: unsupported data type",
},
{
name: "err-slice-contains-nil-item",
rw: dbw.New(db),
args: args{
deleteItems: func() interface{} {
items := createFn()
items = append(items.([]*dbtest.TestUser), nil)
return items
}(),
},
wantErr: true,
wantErrIs: dbw.ErrInvalidParameter,
wantErrContains: "unable to determine type of item ",
},
{
name: "err-empty-pk",
rw: dbw.New(db),
args: args{
deleteItems: func() interface{} {
items := createFn()
users := items.([]*dbtest.TestUser)
users[len(users)-1].PublicId = ""
return items
}(),
},
wantErr: true,
wantErrIs: dbw.ErrInvalidParameter,
wantErrContains: "primary key PublicId is not set",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -445,14 +540,27 @@ func TestDb_DeleteItems(t *testing.T) {
require.Error(err)
if tt.wantErrIs != nil {
assert.ErrorIs(err, tt.wantErrIs)
fmt.Printf("error is: %T\n", errors.Unwrap(err))
}
if tt.wantErrContains != "" {
assert.ErrorContains(err, tt.wantErrContains)
}
return
}
require.NoError(err)
assert.Equal(tt.wantRowsDeleted, rowsDeleted)
for _, item := range tt.args.deleteItems {
var deletedIds []string
switch {
case len(tt.args.deleteItemsIds) > 0:
deletedIds = tt.args.deleteItemsIds
default:
for _, item := range tt.args.deleteItems.([]*dbtest.TestUser) {
deletedIds = append(deletedIds, item.PublicId)
}
}
for _, id := range deletedIds {
u := dbtest.AllocTestUser()
u.PublicId = item.(*dbtest.TestUser).PublicId
u.PublicId = id
err := tt.rw.LookupByPublicId(context.Background(), &u)
require.Error(err)
assert.ErrorIs(err, dbw.ErrRecordNotFound)
Expand Down
2 changes: 1 addition & 1 deletion writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ type Writer interface {
// is returned the caller must decide what to do with the transaction, which
// almost always should be to rollback. Delete returns the number of rows
// deleted or an error.
DeleteItems(ctx context.Context, deleteItems []interface{}, opt ...Option) (int, error)
DeleteItems(ctx context.Context, deleteItems interface{}, opt ...Option) (int, error)

// Exec will execute the sql with the values as parameters. The int returned
// is the number of rows affected by the sql. No options are currently
Expand Down

0 comments on commit d99a9b0

Please sign in to comment.