Skip to content

Commit 6235e1b

Browse files
author
Daniel Farina
committed
Add parameter type sensitivity to type encoding
This requires recording the response to parsing to know what types are expected for parameters, and then passing that information to the type encoder. Signed-off-by: Daniel Farina <daniel@heroku.com>
1 parent 48e71d6 commit 6235e1b

File tree

4 files changed

+55
-15
lines changed

4 files changed

+55
-15
lines changed

buf.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ func (b *readBuf) int32() (n int) {
1313
return
1414
}
1515

16+
func (b *readBuf) oid() (n oid) {
17+
n = oid(binary.BigEndian.Uint32(*b))
18+
*b = (*b)[4:]
19+
return
20+
}
21+
1622
func (b *readBuf) int16() (n int) {
1723
n = int(binary.BigEndian.Uint16(*b))
1824
*b = (*b)[2:]

conn.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,19 @@ func (cn *conn) prepareTo(q, stmtName string) (_ driver.Stmt, err error) {
191191
case '1', '2', 'N':
192192
case 't':
193193
st.nparams = int(r.int16())
194+
st.paramTyps = make([]oid, st.nparams, st.nparams)
195+
196+
for i := 0; i < st.nparams; i += 1 {
197+
st.paramTyps[i] = r.oid()
198+
}
194199
case 'T':
195200
n := r.int16()
196201
st.cols = make([]string, n)
197-
st.ooid = make([]int, n)
202+
st.rowTyps = make([]oid, n)
198203
for i := range st.cols {
199204
st.cols[i] = r.string()
200205
r.next(6)
201-
st.ooid[i] = r.int32()
206+
st.rowTyps[i] = r.oid()
202207
r.next(8)
203208
}
204209
case 'n':
@@ -390,14 +395,17 @@ func (cn *conn) auth(r *readBuf, o Values) {
390395
}
391396
}
392397

398+
type oid uint32
399+
393400
type stmt struct {
394-
cn *conn
395-
name string
396-
query string
397-
cols []string
398-
nparams int
399-
ooid []int
400-
closed bool
401+
cn *conn
402+
name string
403+
query string
404+
cols []string
405+
nparams int
406+
rowTyps []oid
407+
paramTyps []oid
408+
closed bool
401409
}
402410

403411
func (st *stmt) Close() (err error) {
@@ -470,11 +478,11 @@ func (st *stmt) exec(v []driver.Value) {
470478
w.string(st.name)
471479
w.int16(0)
472480
w.int16(len(v))
473-
for _, x := range v {
481+
for i, x := range v {
474482
if x == nil {
475483
w.int32(-1)
476484
} else {
477-
b := encode(x)
485+
b := encode(x, st.paramTyps[i])
478486
w.int32(len(b))
479487
w.bytes(b)
480488
}
@@ -584,7 +592,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
584592
dest[i] = nil
585593
continue
586594
}
587-
dest[i] = decode(r.next(l), rs.st.ooid[i])
595+
dest[i] = decode(r.next(l), rs.st.rowTyps[i])
588596
}
589597
return
590598
default:

encode.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,23 @@ import (
88
"time"
99
)
1010

11-
func encode(x interface{}) []byte {
11+
func encode(x interface{}, pgtypoid oid) []byte {
1212
switch v := x.(type) {
1313
case int64:
1414
return []byte(fmt.Sprintf("%d", v))
1515
case float32, float64:
1616
return []byte(fmt.Sprintf("%f", v))
1717
case []byte:
18-
return []byte(fmt.Sprintf("\\x%x", v))
18+
if pgtypoid == t_bytea {
19+
return []byte(fmt.Sprintf("\\x%x", v))
20+
}
21+
22+
return v
1923
case string:
24+
if pgtypoid == t_bytea {
25+
return []byte(fmt.Sprintf("\\x%x", v))
26+
}
27+
2028
return []byte(v)
2129
case bool:
2230
return []byte(fmt.Sprintf("%t", v))
@@ -29,7 +37,7 @@ func encode(x interface{}) []byte {
2937
panic("not reached")
3038
}
3139

32-
func decode(s []byte, typ int) interface{} {
40+
func decode(s []byte, typ oid) interface{} {
3341
switch typ {
3442
case t_bytea:
3543
s = s[2:] // trim off "\\x"

encode_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,21 @@ func TestStringWithNul(t *testing.T) {
100100
"injection attacks may be plausible")
101101
}
102102
}
103+
104+
func TestByteToText(t *testing.T) {
105+
db := openTestConn(t)
106+
defer db.Close()
107+
108+
b := []byte("hello world")
109+
row := db.QueryRow("SELECT $1::text", b)
110+
111+
var result []byte
112+
err := row.Scan(&result)
113+
if err != nil {
114+
t.Fatal(err)
115+
}
116+
117+
if string(result) != string(b) {
118+
t.Fatalf("expected %v but got %v", b, result)
119+
}
120+
}

0 commit comments

Comments
 (0)