Skip to content

Commit

Permalink
pgstmt: handle raw (#26)
Browse files Browse the repository at this point in the history
* pgstmt: handle raw

* add timestamp test
  • Loading branch information
acoshift authored Dec 21, 2022
1 parent 26d9674 commit 2760bed
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 6 deletions.
10 changes: 10 additions & 0 deletions pgstmt/arg.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ func Arg(v interface{}) interface{} {
return arg{v}
case arg:
case notArg:
case raw:
case _any:
case all:
case defaultValue:
Expand All @@ -30,6 +31,15 @@ type notArg struct {
value interface{}
}

// Raw marks value as raw sql without escape
func Raw(v interface{}) interface{} {
return raw{v}
}

type raw struct {
value interface{}
}

// Any marks value as any($?)
func Any(v interface{}) interface{} {
return _any{v}
Expand Down
16 changes: 13 additions & 3 deletions pgstmt/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import (
"fmt"
"strconv"
"strings"
"time"

"github.com/lib/pq"
)

type buffer struct {
Expand Down Expand Up @@ -49,7 +52,7 @@ func build(b *buffer) (string, []interface{}) {
for _, x := range p {
switch x := x.(type) {
default:
q = append(q, convertToString(x))
q = append(q, convertToString(x, false))
case builder:
q = append(q, f(x.build(), " "))
case arg:
Expand Down Expand Up @@ -80,11 +83,14 @@ func build(b *buffer) (string, []interface{}) {
return query, args
}

func convertToString(x interface{}) string {
func convertToString(x interface{}, quoteStr bool) string {
switch x := x.(type) {
default:
return fmt.Sprint(x)
case string:
if quoteStr {
return pq.QuoteLiteral(x)
}
return x
case int:
return strconv.Itoa(x)
Expand All @@ -94,8 +100,12 @@ func convertToString(x interface{}) string {
return strconv.FormatInt(x, 10)
case bool:
return strconv.FormatBool(x)
case time.Time:
return convertToString(string(pq.FormatTimestamp(x)), true)
case notArg:
return convertToString(x.value)
return convertToString(x.value, true)
case raw:
return fmt.Sprint(x.value)
case defaultValue:
return "default"
}
Expand Down
2 changes: 1 addition & 1 deletion pgstmt/cond.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (st *cond) Op(field, op string, value interface{}) {
func (st *cond) OpRaw(field, op string, rawValue interface{}) {
var x group
x.sep = " "
x.push(field, op, rawValue)
x.push(field, op, Raw(rawValue))
st.ops.push(&x)
}

Expand Down
7 changes: 5 additions & 2 deletions pgstmt/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgstmt_test

import (
"testing"
"time"

"github.com/stretchr/testify/assert"

Expand All @@ -15,7 +16,7 @@ func TestUpdate(t *testing.T) {
q, args := pgstmt.Update(func(b pgstmt.UpdateStatement) {
b.Table("users")
b.Set("name").To("test")
b.Set("email", "address", "updated_at").To("test@localhost", "123", pgstmt.NotArg("now()"))
b.Set("email", "address", "updated_at").To("test@localhost", "123", pgstmt.Raw("now()"))
b.Set("age").ToRaw(1)
b.Where(func(b pgstmt.Cond) {
b.Eq("id", 5)
Expand Down Expand Up @@ -89,6 +90,7 @@ func TestUpdate(t *testing.T) {
b.Set("name").ToRaw("p.name")
b.Set("address").ToRaw("p.address")
b.Set("updated_at").ToRaw("now()")
b.Set("date").To(pgstmt.NotArg(time.Date(2022, 1, 2, 3, 4, 5, 6, time.UTC)))
b.From("users")
b.InnerJoin("profiles p").Using("email")
b.Where(func(b pgstmt.Cond) {
Expand All @@ -101,7 +103,8 @@ func TestUpdate(t *testing.T) {
update users
set name = p.name,
address = p.address,
updated_at = now()
updated_at = now(),
date = '2022-01-02 03:04:05.000000006Z'
from users
inner join profiles p using (email)
where (users.id = $1)
Expand Down

0 comments on commit 2760bed

Please sign in to comment.