Skip to content

Commit

Permalink
Add parameter type sensitivity to type encoding
Browse files Browse the repository at this point in the history
This requires recording the response to parsing to know what types are
expected for parameters, and then passing that information to the type
encoder.

Signed-off-by: Daniel Farina <daniel@heroku.com>
  • Loading branch information
Daniel Farina committed Jan 5, 2013
1 parent 48e71d6 commit 6235e1b
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 15 deletions.
6 changes: 6 additions & 0 deletions buf.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ func (b *readBuf) int32() (n int) {
return
}

func (b *readBuf) oid() (n oid) {
n = oid(binary.BigEndian.Uint32(*b))
*b = (*b)[4:]
return
}

func (b *readBuf) int16() (n int) {
n = int(binary.BigEndian.Uint16(*b))
*b = (*b)[2:]
Expand Down
32 changes: 20 additions & 12 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,19 @@ func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
case '1', '2', 'N':
case 't':
st.nparams = int(r.int16())
st.paramTyps = make([]oid, st.nparams, st.nparams)

for i := 0; i < st.nparams; i += 1 {
st.paramTyps[i] = r.oid()
}
case 'T':
n := r.int16()
st.cols = make([]string, n)
st.ooid = make([]int, n)
st.rowTyps = make([]oid, n)
for i := range st.cols {
st.cols[i] = r.string()
r.next(6)
st.ooid[i] = r.int32()
st.rowTyps[i] = r.oid()
r.next(8)
}
case 'n':
Expand Down Expand Up @@ -390,14 +395,17 @@ func (cn *conn) auth(r *readBuf, o Values) {
}
}

type oid uint32

type stmt struct {
cn *conn
name string
query string
cols []string
nparams int
ooid []int
closed bool
cn *conn
name string
query string
cols []string
nparams int
rowTyps []oid
paramTyps []oid
closed bool
}

func (st *stmt) Close() (err error) {
Expand Down Expand Up @@ -470,11 +478,11 @@ func (st *stmt) exec(v []driver.Value) {
w.string(st.name)
w.int16(0)
w.int16(len(v))
for _, x := range v {
for i, x := range v {
if x == nil {
w.int32(-1)
} else {
b := encode(x)
b := encode(x, st.paramTyps[i])
w.int32(len(b))
w.bytes(b)
}
Expand Down Expand Up @@ -584,7 +592,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
dest[i] = nil
continue
}
dest[i] = decode(r.next(l), rs.st.ooid[i])
dest[i] = decode(r.next(l), rs.st.rowTyps[i])
}
return
default:
Expand Down
14 changes: 11 additions & 3 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@ import (
"time"
)

func encode(x interface{}) []byte {
func encode(x interface{}, pgtypoid oid) []byte {
switch v := x.(type) {
case int64:
return []byte(fmt.Sprintf("%d", v))
case float32, float64:
return []byte(fmt.Sprintf("%f", v))
case []byte:
return []byte(fmt.Sprintf("\\x%x", v))
if pgtypoid == t_bytea {
return []byte(fmt.Sprintf("\\x%x", v))
}

return v
case string:
if pgtypoid == t_bytea {
return []byte(fmt.Sprintf("\\x%x", v))
}

return []byte(v)
case bool:
return []byte(fmt.Sprintf("%t", v))
Expand All @@ -29,7 +37,7 @@ func encode(x interface{}) []byte {
panic("not reached")
}

func decode(s []byte, typ int) interface{} {
func decode(s []byte, typ oid) interface{} {
switch typ {
case t_bytea:
s = s[2:] // trim off "\\x"
Expand Down
18 changes: 18 additions & 0 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,21 @@ func TestStringWithNul(t *testing.T) {
"injection attacks may be plausible")
}
}

func TestByteToText(t *testing.T) {
db := openTestConn(t)
defer db.Close()

b := []byte("hello world")
row := db.QueryRow("SELECT $1::text", b)

var result []byte
err := row.Scan(&result)
if err != nil {
t.Fatal(err)
}

if string(result) != string(b) {
t.Fatalf("expected %v but got %v", b, result)
}
}

0 comments on commit 6235e1b

Please sign in to comment.