Skip to content

Commit

Permalink
pgstmt: union (#27)
Browse files Browse the repository at this point in the history
* pgstmt: add union

* pgstmt: add join union to select statement

* add offset

* add nested union

* update test
  • Loading branch information
acoshift authored Jan 26, 2023
1 parent 0c02819 commit 50e6664
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 0 deletions.
48 changes: 48 additions & 0 deletions pgstmt/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,31 @@ type SelectStatement interface {
From(table ...string)
FromSelect(f func(b SelectStatement), as string)
FromValues(f func(b Values), as string)

Join(table string) Join
InnerJoin(table string) Join
FullOuterJoin(table string) Join
LeftJoin(table string) Join
RightJoin(table string) Join

JoinSelect(f func(b SelectStatement), as string) Join
InnerJoinSelect(f func(b SelectStatement), as string) Join
FullOuterJoinSelect(f func(b SelectStatement), as string) Join
LeftJoinSelect(f func(b SelectStatement), as string) Join
RightJoinSelect(f func(b SelectStatement), as string) Join

JoinLateralSelect(f func(b SelectStatement), as string) Join
InnerJoinLateralSelect(f func(b SelectStatement), as string) Join
FullOuterJoinLateralSelect(f func(b SelectStatement), as string) Join
LeftJoinLateralSelect(f func(b SelectStatement), as string) Join
RightJoinLateralSelect(f func(b SelectStatement), as string) Join

JoinUnion(f func(b UnionStatement), as string) Join
InnerJoinUnion(f func(b UnionStatement), as string) Join
FullOuterJoinUnion(f func(b UnionStatement), as string) Join
LeftJoinUnion(f func(b UnionStatement), as string) Join
RightJoinUnion(f func(b UnionStatement), as string) Join

Where(f func(b Cond))
GroupBy(col ...string)
Having(f func(b Cond))
Expand Down Expand Up @@ -217,6 +227,44 @@ func (st *selectStmt) RightJoinLateralSelect(f func(b SelectStatement), as strin
return st.joinSelect("right join lateral", f, as)
}

func (st *selectStmt) joinUnion(typ string, f func(b UnionStatement), as string) Join {
var x unionStmt
f(&x)

var b buffer
b.push(paren(x.make()))
if as != "" {
b.push(as)
}

j := join{
typ: typ,
table: &b,
}
st.joins.push(&j)
return &j
}

func (st *selectStmt) JoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("join", f, as)
}

func (st *selectStmt) InnerJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("inner join", f, as)
}

func (st *selectStmt) FullOuterJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("full outer join", f, as)
}

func (st *selectStmt) LeftJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("left join", f, as)
}

func (st *selectStmt) RightJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("right join", f, as)
}

func (st *selectStmt) Where(f func(b Cond)) {
f(&st.where)
}
Expand Down
31 changes: 31 additions & 0 deletions pgstmt/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,37 @@ func TestSelect(t *testing.T) {
`,
nil,
},
{
"inner join union",
pgstmt.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table1")
b.InnerJoinUnion(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table2")
})
b.AllSelect(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table3")
})
b.OrderBy("id").Desc()
b.Limit(100)
}, "t").Using("id")
}),
`
select id
from table1
inner join (
(select id from table2)
union all
(select id from table3)
order by id desc
limit 100
) t using (id)
`,
nil,
},
}

for _, tC := range cases {
Expand Down
99 changes: 99 additions & 0 deletions pgstmt/union.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package pgstmt

func Union(f func(b UnionStatement)) *Result {
var st unionStmt
f(&st)
return newResult(build(st.make()))
}

type UnionStatement interface {
Select(f func(b SelectStatement))
AllSelect(f func(b SelectStatement))
Union(f func(b UnionStatement))
AllUnion(f func(b UnionStatement))
OrderBy(col string) OrderBy
Limit(n int64)
Offset(n int64)
}

type unionStmt struct {
b buffer
orderBy group
limit *int64
offset *int64
}

func (st *unionStmt) Select(f func(b SelectStatement)) {
var x selectStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union", paren(x.make()))
}
}

func (st *unionStmt) AllSelect(f func(b SelectStatement)) {
var x selectStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union all", paren(x.make()))
}
}

func (st *unionStmt) Union(f func(b UnionStatement)) {
var x unionStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union", paren(x.make()))
}
}

func (st *unionStmt) AllUnion(f func(b UnionStatement)) {
var x unionStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union all", paren(x.make()))
}
}

func (st *unionStmt) OrderBy(col string) OrderBy {
p := orderBy{
col: col,
}
st.orderBy.push(&p)
return &p
}

func (st *unionStmt) Limit(n int64) {
st.limit = &n
}

func (st *unionStmt) Offset(n int64) {
st.offset = &n
}

func (st *unionStmt) make() *buffer {
var b buffer
b.push(&st.b)
if !st.orderBy.empty() {
b.push("order by", &st.orderBy)
}
if st.limit != nil {
b.push("limit", *st.limit)
}
if st.offset != nil {
b.push("offset", *st.offset)
}
return &b
}
94 changes: 94 additions & 0 deletions pgstmt/union_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package pgstmt_test

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/acoshift/pgsql/pgstmt"
)

func TestUnion(t *testing.T) {
t.Parallel()

cases := []struct {
name string
result *pgstmt.Result
query string
args []interface{}
}{
{
"union select",
pgstmt.Union(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table1")
})
b.AllSelect(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table2")
})
b.OrderBy("id")
b.Limit(10)
b.Offset(2)
}),
`
(select id from table1)
union all (select id from table2)
order by id
limit 10 offset 2
`,
nil,
},
{
"union nested",
pgstmt.Union(func(b pgstmt.UnionStatement) {
b.Union(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table1")
})
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table2")
})
})
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table3")
})
b.AllUnion(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table4")
})
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table5")
})
})
}),
`
(
(select id from table1)
union (select id from table2)
)
union (select id from table3)
union all (
(select id from table4)
union
(select id from table5)
)
`,
nil,
},
}

for _, tC := range cases {
t.Run(tC.name, func(t *testing.T) {
q, args := tC.result.SQL()
assert.Equal(t, stripSpace(tC.query), q)
assert.EqualValues(t, tC.args, args)
})
}
}
2 changes: 2 additions & 0 deletions pgstmt/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ func stripSpace(s string) string {
}
s = p
}
s = strings.ReplaceAll(s, "( ", "(")
s = strings.ReplaceAll(s, " )", ")")
return s
}

0 comments on commit 50e6664

Please sign in to comment.