Skip to content

Commit

Permalink
Merge pull request #57 from bmatsuo/bmatsuo/fix-cgo-arguments
Browse files Browse the repository at this point in the history
Fix cgo argument check panic for small slices from bytes.Buffer (and similar types)
  • Loading branch information
bmatsuo committed Mar 24, 2016
2 parents f9ee9a4 + 50ccdc0 commit cd9b251
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 35 deletions.
52 changes: 35 additions & 17 deletions lmdb/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ package lmdb
#include "lmdbgo.h"
*/
import "C"
import "runtime"
import (
"runtime"
"unsafe"
)

// These flags are used exclusively for Cursor.Get.
const (
Expand Down Expand Up @@ -169,10 +172,9 @@ func (c *Cursor) getVal0(op uint) (key, val *C.MDB_val, err error) {
func (c *Cursor) getVal1(setkey []byte, op uint) (key, val *C.MDB_val, err error) {
key = new(C.MDB_val)
val = new(C.MDB_val)
kdata, kn := valBytes(setkey)
ret := C.lmdbgo_mdb_cursor_get1(
c._c,
kdata, C.size_t(kn),
unsafe.Pointer(&setkey[0]), C.size_t(len(setkey)),
(*C.MDB_val)(key), (*C.MDB_val)(val),
C.MDB_cursor_op(op),
)
Expand All @@ -186,28 +188,36 @@ func (c *Cursor) getVal1(setkey []byte, op uint) (key, val *C.MDB_val, err error
func (c *Cursor) getVal2(setkey, setval []byte, op uint) (key, val *C.MDB_val, err error) {
key = new(C.MDB_val)
val = new(C.MDB_val)
kdata, kn := valBytes(setkey)
vdata, vn := valBytes(setval)
ret := C.lmdbgo_mdb_cursor_get2(
c._c,
kdata, C.size_t(kn),
vdata, C.size_t(vn),
unsafe.Pointer(&setkey[0]), C.size_t(len(setkey)),
unsafe.Pointer(&setval[0]), C.size_t(len(setval)),
(*C.MDB_val)(key), (*C.MDB_val)(val),
C.MDB_cursor_op(op),
)
return key, val, operrno("mdb_cursor_get", ret)
}

func (c *Cursor) putNilKey(flags uint) error {
ret := C.lmdbgo_mdb_cursor_put2(c._c, nil, 0, nil, 0, C.uint(flags))
return operrno("mdb_cursor_put", ret)
}

// Put stores an item in the database.
//
// See mdb_cursor_put.
func (c *Cursor) Put(key, val []byte, flags uint) error {
kdata, kn := valBytes(key)
vdata, vn := valBytes(val)
if len(key) == 0 {
return c.putNilKey(flags)
}
vn := len(val)
if vn == 0 {
val = []byte{0}
}
ret := C.lmdbgo_mdb_cursor_put2(
c._c,
kdata, C.size_t(kn),
vdata, C.size_t(vn),
unsafe.Pointer(&key[0]), C.size_t(len(key)),
unsafe.Pointer(&val[0]), C.size_t(len(val)),
C.uint(flags),
)
return operrno("mdb_cursor_put", ret)
Expand All @@ -217,11 +227,14 @@ func (c *Cursor) Put(key, val []byte, flags uint) error {
// avoiding a memcopy. The returned byte slice is only valid in txn's thread,
// before it has terminated.
func (c *Cursor) PutReserve(key []byte, n int, flags uint) ([]byte, error) {
kdata, kn := valBytes(key)
if len(key) == 0 {
return nil, c.putNilKey(flags)
}

val := &C.MDB_val{mv_size: C.size_t(n)}
ret := C.lmdbgo_mdb_cursor_put1(
c._c,
kdata, C.size_t(kn),
unsafe.Pointer(&key[0]), C.size_t(len(key)),
(*C.MDB_val)(val),
C.uint(flags|C.MDB_RESERVE),
)
Expand All @@ -238,13 +251,18 @@ func (c *Cursor) PutReserve(key []byte, n int, flags uint) ([]byte, error) {
//
// See mdb_cursor_put.
func (c *Cursor) PutMulti(key []byte, page []byte, stride int, flags uint) error {
kdata, kn := valBytes(key)
vdata, _ := valBytes(page)
if len(key) == 0 {
return c.putNilKey(flags)
}
if len(page) == 0 {
page = []byte{0}
}

vn := WrapMulti(page, stride).Len()
ret := C.lmdbgo_mdb_cursor_putmulti(
c._c,
kdata, C.size_t(kn),
vdata, C.size_t(vn), C.size_t(stride),
unsafe.Pointer(&key[0]), C.size_t(len(key)),
unsafe.Pointer(&page[0]), C.size_t(vn), C.size_t(stride),
C.uint(flags|C.MDB_MULTIPLE),
)
return operrno("mdb_cursor_put", ret)
Expand Down
53 changes: 53 additions & 0 deletions lmdb/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,59 @@ func TestCursor_Close(t *testing.T) {
}
}

func TestCursor_bytesBuffer(t *testing.T) {
env := setup(t)
defer clean(env, t)

db, err := openRoot(env, 0)
if err != nil {
t.Error(err)
return
}

err = env.Update(func(txn *Txn) (err error) {
cur, err := txn.OpenCursor(db)
if err != nil {
return err
}
defer cur.Close()
k := new(bytes.Buffer)
k.WriteString("hello")
v := new(bytes.Buffer)
v.WriteString("world")
return cur.Put(k.Bytes(), v.Bytes(), 0)
})
if err != nil {
t.Error(err)
return
}

err = env.View(func(txn *Txn) (err error) {
cur, err := txn.OpenCursor(db)
if err != nil {
return err
}
defer cur.Close()
k := new(bytes.Buffer)
k.WriteString("hello")
_k, v, err := cur.Get(k.Bytes(), nil, SetKey)
if err != nil {
return err
}
if !bytes.Equal(_k, k.Bytes()) {
return fmt.Errorf("unexpected key: %q", _k)
}
if !bytes.Equal(v, []byte("world")) {
return fmt.Errorf("unexpected value: %q", v)
}
return nil
})
if err != nil {
t.Error(err)
return
}
}

func TestCursor_PutReserve(t *testing.T) {
env := setup(t)
defer clean(env, t)
Expand Down
33 changes: 24 additions & 9 deletions lmdb/txn.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ func (txn *Txn) Get(dbi DBI, key []byte) ([]byte, error) {
val := new(C.MDB_val)
ret := C.lmdbgo_mdb_get(
txn._txn, C.MDB_dbi(dbi),
kdata, C.size_t(kn),
unsafe.Pointer(&kdata[0]), C.size_t(kn),
(*C.MDB_val)(val),
)
err := operrno("mdb_get", ret)
Expand All @@ -288,16 +288,29 @@ func (txn *Txn) Get(dbi DBI, key []byte) ([]byte, error) {
return txn.bytes(val), nil
}

func (txn *Txn) putNilKey(dbi DBI, flags uint) error {
// mdb_put with an empty key will always fail
ret := C.lmdbgo_mdb_put2(txn._txn, C.MDB_dbi(dbi), nil, 0, nil, 0, C.uint(flags))
return operrno("mdb_put", ret)
}

// Put stores an item in database dbi.
//
// See mdb_put.
func (txn *Txn) Put(dbi DBI, key []byte, val []byte, flags uint) error {
kdata, kn := valBytes(key)
vdata, vn := valBytes(val)
kn := len(key)
if kn == 0 {
return txn.putNilKey(dbi, flags)
}
vn := len(val)
if vn == 0 {
val = []byte{0}
}

ret := C.lmdbgo_mdb_put2(
txn._txn, C.MDB_dbi(dbi),
kdata, C.size_t(kn),
vdata, C.size_t(vn),
unsafe.Pointer(&key[0]), C.size_t(kn),
unsafe.Pointer(&val[0]), C.size_t(vn),
C.uint(flags),
)
return operrno("mdb_put", ret)
Expand All @@ -307,11 +320,13 @@ func (txn *Txn) Put(dbi DBI, key []byte, val []byte, flags uint) error {
// avoiding a memcopy. The returned byte slice is only valid in txn's thread,
// before it has terminated.
func (txn *Txn) PutReserve(dbi DBI, key []byte, n int, flags uint) ([]byte, error) {
kdata, kn := valBytes(key)
if len(key) == 0 {
return nil, txn.putNilKey(dbi, flags)
}
val := &C.MDB_val{mv_size: C.size_t(n)}
ret := C.lmdbgo_mdb_put1(
txn._txn, C.MDB_dbi(dbi),
kdata, C.size_t(kn),
unsafe.Pointer(&key[0]), C.size_t(len(key)),
(*C.MDB_val)(val),
C.uint(flags|C.MDB_RESERVE),
)
Expand All @@ -331,8 +346,8 @@ func (txn *Txn) Del(dbi DBI, key, val []byte) error {
vdata, vn := valBytes(val)
ret := C.lmdbgo_mdb_del(
txn._txn, C.MDB_dbi(dbi),
kdata, C.size_t(kn),
vdata, C.size_t(vn),
unsafe.Pointer(&kdata[0]), C.size_t(kn),
unsafe.Pointer(&vdata[0]), C.size_t(vn),
)
return operrno("mdb_del", ret)
}
Expand Down
69 changes: 69 additions & 0 deletions lmdb/txn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,35 @@ func TestTxn_Del_dup(t *testing.T) {
}
}

func TestTexn_Put_emptyValue(t *testing.T) {
env := setup(t)
defer clean(env, t)

var db DBI
err := env.Update(func(txn *Txn) (err error) {
db, err = txn.OpenRoot(0)
if err != nil {
return err
}
err = txn.Put(db, []byte("k"), nil, 0)
if err != nil {
return err
}
v, err := txn.Get(db, []byte("k"))
if err != nil {
return err
}
if len(v) != 0 {
t.Errorf("value: %q (!= \"\")", v)
}
return nil
})
if err != nil {
t.Error(err)
return
}
}

func TestTxn_PutReserve(t *testing.T) {
env := setup(t)
defer clean(env, t)
Expand Down Expand Up @@ -253,6 +282,46 @@ func TestTxn_PutReserve(t *testing.T) {
}
}

func TestTxn_bytesBuffer(t *testing.T) {
env := setup(t)
defer clean(env, t)

db, err := openRoot(env, 0)
if err != nil {
t.Error(err)
return
}

err = env.Update(func(txn *Txn) (err error) {
k := new(bytes.Buffer)
k.WriteString("hello")
v := new(bytes.Buffer)
v.WriteString("world")
return txn.Put(db, k.Bytes(), v.Bytes(), 0)
})
if err != nil {
t.Error(err)
return
}

err = env.View(func(txn *Txn) (err error) {
k := new(bytes.Buffer)
k.WriteString("hello")
v, err := txn.Get(db, k.Bytes())
if err != nil {
return err
}
if !bytes.Equal(v, []byte("world")) {
return fmt.Errorf("unexpected value: %q", v)
}
return nil
})
if err != nil {
t.Error(err)
return
}
}

func TestTxn_Put_overwrite(t *testing.T) {
env := setup(t)
defer clean(env, t)
Expand Down
12 changes: 7 additions & 5 deletions lmdb/val.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,19 @@ func (m *Multi) Page() []byte {
return m.page[:len(m.page):len(m.page)]
}

func valBytes(b []byte) (unsafe.Pointer, int) {
var eb = []byte{0}

func valBytes(b []byte) ([]byte, int) {
if len(b) == 0 {
return nil, 0
return eb, 0
}
return unsafe.Pointer(&b[0]), len(b)
return b, len(b)
}

func wrapVal(b []byte) *C.MDB_val {
ptr, n := valBytes(b)
p, n := valBytes(b)
return &C.MDB_val{
mv_data: unsafe.Pointer(ptr),
mv_data: unsafe.Pointer(&p[0]),
mv_size: C.size_t(n),
}
}
Expand Down
8 changes: 4 additions & 4 deletions lmdb/val_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,17 @@ func TestMultiVal_panic(t *testing.T) {

func TestValBytes(t *testing.T) {
ptr, n := valBytes(nil)
if ptr != nil {
t.Errorf("unexpected non-nil pointer")
if len(ptr) == 0 {
t.Errorf("unexpected unaddressable slice")
}
if n != 0 {
t.Errorf("unexpected length: %d (expected 0)", n)
}

b := []byte("abc")
ptr, n = valBytes(b)
if ptr == nil {
t.Errorf("unexpected nil pointer")
if len(ptr) == 0 {
t.Errorf("unexpected unaddressable slice")
}
if n != 3 {
t.Errorf("unexpected length: %d (expected %d)", n, len(b))
Expand Down

0 comments on commit cd9b251

Please sign in to comment.