Skip to content

Commit

Permalink
Add additional methods to Batch similar to what exists on Queryx
Browse files Browse the repository at this point in the history
  • Loading branch information
dkropachev committed Jun 20, 2024
1 parent 207ba87 commit b4c49dd
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 42 deletions.
34 changes: 34 additions & 0 deletions batchx.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package gocqlx

import (
"fmt"

"github.com/gocql/gocql"
)

Expand All @@ -27,6 +29,38 @@ func (b *Batch) BindStruct(qry *Queryx, arg interface{}) error {
return nil
}

// Bind binds query parameters to values from args using a mapper.
// If value cannot be found an error is reported.
func (b *Batch) Bind(qry *Queryx, args ...interface{}) error {
if len(qry.Names) != len(args) {
return fmt.Errorf("query requires %d arguments, but %d provided", len(qry.Names), len(args))
}
b.Query(qry.Statement(), args...)
return nil
}

// BindMap binds query named parameters to values from arg using a mapper.
// If value cannot be found an error is reported.
func (b *Batch) BindMap(qry *Queryx, arg map[string]interface{}) error {
args, err := qry.bindMapArgs(arg)
if err != nil {
return err
}
b.Query(qry.Statement(), args...)
return nil
}

// BindStructMap binds query named parameters to values from arg0 and arg1 using a mapper.
// If value cannot be found an error is reported.
func (b *Batch) BindStructMap(qry *Queryx, arg0 interface{}, arg1 map[string]interface{}) error {
args, err := qry.bindStructArgs(arg0, arg1)
if err != nil {
return err
}
b.Query(qry.Statement(), args...)
return nil
}

// ExecuteBatch executes a batch operation and returns nil if successful
// otherwise an error describing the failure.
func (s *Session) ExecuteBatch(batch *Batch) error {
Expand Down
172 changes: 130 additions & 42 deletions batchx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,52 +52,140 @@ func TestBatch(t *testing.T) {
SongID: mustParseUUID("60fc234a-8481-4343-93bb-72ecab404863"),
}

insertSong := qb.Insert("batch_test.songs").
Columns("id", "title", "album", "artist", "tags", "data").Query(session)
insertPlaylist := qb.Insert("batch_test.playlists").
Columns("id", "title", "album", "artist", "song_id").Query(session)
selectSong := qb.Select("batch_test.songs").Where(qb.Eq("id")).Query(session)
selectPlaylist := qb.Select("batch_test.playlists").Where(qb.Eq("id")).Query(session)

t.Run("batch inserts", func(t *testing.T) {
t.Parallel()

type batchQry struct {
qry *gocqlx.Queryx
arg interface{}
}

qrys := []batchQry{
{qry: insertSong, arg: song},
{qry: insertPlaylist, arg: playlist},
}

b := session.NewBatch(gocql.LoggedBatch)
for _, qry := range qrys {
if err := b.BindStruct(qry.qry, qry.arg); err != nil {
t.Fatal("BindStruct failed:", err)
}
}
if err := session.ExecuteBatch(b); err != nil {
t.Fatal("batch execution:", err)
}

// verify song was inserted
var gotSong Song
if err := selectSong.BindStruct(song).Get(&gotSong); err != nil {
t.Fatal("select song:", err)
}
if diff := cmp.Diff(gotSong, song); diff != "" {
t.Errorf("expected %v song, got %v, diff: %q", song, gotSong, diff)
}

// verify playlist item was inserted
var gotPlayList PlaylistItem
if err := selectPlaylist.BindStruct(playlist).Get(&gotPlayList); err != nil {
t.Fatal("select song:", err)
tcases := []struct {
name string
methodSong func(*gocqlx.Batch, *gocqlx.Queryx, Song) error
methodPlaylist func(*gocqlx.Batch, *gocqlx.Queryx, PlaylistItem) error
}{
{
name: "BindStruct",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
return b.BindStruct(q, song)
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
return b.BindStruct(q, playlist)
},
},
{
name: "BindMap",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
return b.BindMap(q, map[string]interface{}{
"id": song.ID,
"title": song.Title,
"album": song.Album,
"artist": song.Artist,
"tags": song.Tags,
"data": song.Data,
})
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
return b.BindMap(q, map[string]interface{}{
"id": playlist.ID,
"title": playlist.Title,
"album": playlist.Album,
"artist": playlist.Artist,
"song_id": playlist.SongID,
})
},
},
{
name: "Bind",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
return b.Bind(q, song.ID, song.Title, song.Album, song.Artist, song.Tags, song.Data)
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
return b.Bind(q, playlist.ID, playlist.Title, playlist.Album, playlist.Artist, playlist.SongID)
},
},
{
name: "BindStructMap",
methodSong: func(b *gocqlx.Batch, q *gocqlx.Queryx, song Song) error {
in := map[string]interface{}{
"title": song.Title,
"album": song.Album,
}
return b.BindStructMap(q, struct {
ID gocql.UUID
Artist string
Tags []string
Data []byte
}{
ID: song.ID,
Artist: song.Artist,
Tags: song.Tags,
Data: song.Data,
}, in)
},
methodPlaylist: func(b *gocqlx.Batch, q *gocqlx.Queryx, playlist PlaylistItem) error {
in := map[string]interface{}{
"title": playlist.Title,
"album": playlist.Album,
}
return b.BindStructMap(q, struct {
ID gocql.UUID
Artist string
SongID gocql.UUID
}{
ID: playlist.ID,
Artist: playlist.Artist,
SongID: playlist.SongID,
},
in,
)
},
},
}
if diff := cmp.Diff(gotPlayList, playlist); diff != "" {
t.Errorf("expected %v playList, got %v, diff: %q", playlist, gotPlayList, diff)
for _, tcase := range tcases {
t.Run(tcase.name, func(t *testing.T) {
insertSong := qb.Insert("batch_test.songs").
Columns("id", "title", "album", "artist", "tags", "data").Query(session)
insertPlaylist := qb.Insert("batch_test.playlists").
Columns("id", "title", "album", "artist", "song_id").Query(session)
selectSong := qb.Select("batch_test.songs").Where(qb.Eq("id")).Query(session)
selectPlaylist := qb.Select("batch_test.playlists").Where(qb.Eq("id")).Query(session)
deleteSong := qb.Delete("batch_test.songs").Where(qb.Eq("id")).Query(session)
deletePlaylist := qb.Delete("batch_test.playlists").Where(qb.Eq("id")).Query(session)

b := session.NewBatch(gocql.LoggedBatch)

if err = tcase.methodSong(b, insertSong, song); err != nil {
t.Fatal("insert song:", err)
}
if err = tcase.methodPlaylist(b, insertPlaylist, playlist); err != nil {
t.Fatal("insert playList:", err)
}

if err := session.ExecuteBatch(b); err != nil {
t.Fatal("batch execution:", err)
}

// verify song was inserted
var gotSong Song
if err := selectSong.BindStruct(song).Get(&gotSong); err != nil {
t.Fatal("select song:", err)
}
if diff := cmp.Diff(gotSong, song); diff != "" {
t.Errorf("expected %v song, got %v, diff: %q", song, gotSong, diff)
}

// verify playlist item was inserted
var gotPlayList PlaylistItem
if err := selectPlaylist.BindStruct(playlist).Get(&gotPlayList); err != nil {
t.Fatal("select playList:", err)
}
if diff := cmp.Diff(gotPlayList, playlist); diff != "" {
t.Errorf("expected %v playList, got %v, diff: %q", playlist, gotPlayList, diff)
}
if err = deletePlaylist.BindStruct(playlist).Exec(); err != nil {
t.Error("delete playlist:", err)
}
if err = deleteSong.BindStruct(song).Exec(); err != nil {
t.Error("delete song:", err)
}
})
}
})
}

0 comments on commit b4c49dd

Please sign in to comment.