Skip to content

Commit

Permalink
fix(spanner/spannertest): fix ORDER BY combined with SELECT aliases (#…
Browse files Browse the repository at this point in the history
…3043)

This was broken due to the ORDER BY implementation that was a bit of a
hack. Replace that hack with a more explicit evaluation of the ORDER BY
sort keys, which also removed a duplicate chunk of row sorting code.

This also required a restructuring of how locking works during an
evaluation, which cleaned up some other code.
  • Loading branch information
dsymonds authored Oct 18, 2020
1 parent e58a55d commit 89a9df5
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 87 deletions.
2 changes: 1 addition & 1 deletion spanner/spannertest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ by ascending esotericism:

- expression functions
- more aggregation functions
- more joins types (INNER, CROSS, FULL, RIGHT)
- INSERT/UPDATE DML statements
- SELECT HAVING
- case insensitivity
- FULL JOIN
- alternate literal types (esp. strings)
- STRUCT types
- transaction simulation
Expand Down
257 changes: 172 additions & 85 deletions spanner/spannertest/db_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,27 @@ type selIter struct {
ec evalContext
cis []colInfo
list []spansql.Expr

distinct bool // whether this is a SELECT DISTINCT
seen []row
}

func (si selIter) Cols() []colInfo { return si.cis }
func (si selIter) Next() (row, error) {
func (si *selIter) Cols() []colInfo { return si.cis }
func (si *selIter) Next() (row, error) {
for {
r, err := si.next()
if err != nil {
return nil, err
}
if si.distinct && !si.keep(r) {
continue
}
return r, nil
}
}

// next retrieves the next row for the SELECT and evaluates its expression list.
func (si *selIter) next() (row, error) {
r, err := si.ri.Next()
if err != nil {
return nil, err
Expand All @@ -216,35 +233,17 @@ func (si selIter) Next() (row, error) {
return out, nil
}

// distinctIter applies a DISTINCT filter.
type distinctIter struct {
ri rowIter
seen []row
}

func (di *distinctIter) Cols() []colInfo { return di.ri.Cols() }
func (di *distinctIter) Next() (row, error) {
func (si *selIter) keep(r row) bool {
// This is hilariously inefficient; O(N^2) in the number of returned rows.
// Some sort of hashing could be done to deduplicate instead.
// This also breaks on array/struct types.
for {
row, err := di.ri.Next()
if err != nil {
return nil, err
for _, prev := range si.seen {
if rowEqual(prev, r) {
return false
}
dupe := false
for _, prev := range di.seen {
if rowEqual(prev, row) {
dupe = true
break
}
}
if dupe {
continue
}
di.seen = append(di.seen, row)
return row, nil
}
si.seen = append(si.seen, r)
return true
}

// offsetIter applies an OFFSET clause.
Expand Down Expand Up @@ -295,39 +294,70 @@ type queryParam struct {

type queryParams map[string]queryParam // TODO: change key to spansql.Param?

func (d *database) Query(q spansql.Query, params queryParams) (rowIter, error) {
// If there's an ORDER BY clause, extend the query to include the expressions we need
// so they get evaluated during evalSelect. TODO: Is this actually okay?
type queryContext struct {
params queryParams

tables []*table // sorted by name
tableIndex map[spansql.ID]*table
locks int
}

func (qc *queryContext) Lock() {
// Take locks in name order.
for _, t := range qc.tables {
t.mu.Lock()
qc.locks++
}
}

func (qc *queryContext) Unlock() {
for _, t := range qc.tables {
t.mu.Unlock()
qc.locks--
}
}

func (d *database) Query(q spansql.Query, params queryParams) (ri rowIter, err error) {
// Figure out the context of the query and take any required locks.
qc, err := d.queryContext(q, params)
if err != nil {
return nil, err
}
qc.Lock()
// On the way out, if there were locks taken, flatten the output
// and release the locks.
if qc.locks > 0 {
defer func() {
if err == nil {
ri, err = toRawIter(ri)
}
qc.Unlock()
}()
}

// Prepare auxiliary expressions to evaluate for ORDER BY.
var aux []spansql.Expr
var desc []bool
for _, o := range q.Order {
aux = append(aux, o.Expr)
desc = append(desc, o.Desc)
}
q.Select.List = append(q.Select.List, aux...)

ri, err := d.evalSelect(q.Select, params)
si, err := d.evalSelect(q.Select, qc)
if err != nil {
return nil, err
}
ri = si

// Apply ORDER BY.
if len(q.Order) > 0 {
raw, err := toRawIter(ri)
// Evaluate the selIter completely, and sort the rows by the auxiliary expressions.
rows, keys, err := evalSelectOrder(si, aux)
if err != nil {
return nil, err
}
sort.Slice(raw.rows, func(one, two int) bool {
r1, r2 := raw.rows[one], raw.rows[two]
aux1, aux2 := r1[len(r1)-len(aux):], r2[len(r2)-len(aux):] // sort keys
return compareValLists(aux1, aux2, desc) < 0
})
// Remove ORDER BY values.
raw.cols = raw.cols[:len(raw.cols)-len(aux)]
for i, row := range raw.rows {
raw.rows[i] = row[:len(row)-len(aux)]
}
ri = raw
sort.Sort(externalRowSorter{rows: rows, keys: keys, desc: desc})
ri = &rawIter{cols: si.cis, rows: rows}
}

// Apply LIMIT, OFFSET.
Expand All @@ -350,33 +380,76 @@ func (d *database) Query(q spansql.Query, params queryParams) (rowIter, error) {
return ri, nil
}

func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIter, evalErr error) {
ri = &nullIter{}
ec := evalContext{
func (d *database) queryContext(q spansql.Query, params queryParams) (*queryContext, error) {
qc := &queryContext{
params: params,
}

// Look for any mentioned tables and add them to qc.tableIndex.
addTable := func(name spansql.ID) error {
if _, ok := qc.tableIndex[name]; ok {
return nil // Already found this table.
}
t, err := d.table(name)
if err != nil {
return err
}
if qc.tableIndex == nil {
qc.tableIndex = make(map[spansql.ID]*table)
}
qc.tableIndex[name] = t
return nil
}
var findTables func(sf spansql.SelectFrom) error
findTables = func(sf spansql.SelectFrom) error {
switch sf := sf.(type) {
default:
return fmt.Errorf("can't prepare query context for SelectFrom of type %T", sf)
case spansql.SelectFromTable:
return addTable(sf.Table)
case spansql.SelectFromJoin:
if err := findTables(sf.LHS); err != nil {
return err
}
return findTables(sf.RHS)
}
}
for _, sf := range q.Select.From {
if err := findTables(sf); err != nil {
return nil, err
}
}

// Build qc.tables in name order so we can take locks in a well-defined order.
var names []spansql.ID
for name := range qc.tableIndex {
names = append(names, name)
}
sort.Slice(names, func(i, j int) bool { return names[i] < names[j] })
for _, name := range names {
qc.tables = append(qc.tables, qc.tableIndex[name])
}

return qc, nil
}

func (d *database) evalSelect(sel spansql.Select, qc *queryContext) (si *selIter, evalErr error) {
var ri rowIter = &nullIter{}
ec := evalContext{
params: qc.params,
}

// First stage is to identify the data source.
// If there's a FROM then that names a table to use.
if len(sel.From) > 1 {
return nil, fmt.Errorf("selecting with more than one FROM clause not yet supported")
}
if len(sel.From) == 1 {
var unlock func()
var err error
ec, ri, unlock, err = d.evalSelectFrom(ec, sel.From[0])
ec, ri, err = d.evalSelectFrom(qc, ec, sel.From[0])
if err != nil {
return nil, err
}
defer unlock()

// On the way out, convert the result to a rawIter
// so that any locked tables may be safely unlocked.
defer func() {
if evalErr == nil {
ri, evalErr = toRawIter(ri)
}
}()
}

// Apply WHERE.
Expand Down Expand Up @@ -577,31 +650,27 @@ func (d *database) evalSelect(sel spansql.Select, params queryParams) (ri rowIte
colInfos = append(colInfos, ci)
}
}
ri = selIter{

return &selIter{
ri: ri,
ec: ec,
cis: colInfos,
list: sel.List,
}

// Apply DISTINCT.
if sel.Distinct {
ri = &distinctIter{ri: ri}
}

return ri, nil
distinct: sel.Distinct, // Apply DISTINCT.
}, nil
}

func (d *database) evalSelectFrom(ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, func(), error) {
func (d *database) evalSelectFrom(qc *queryContext, ec evalContext, sf spansql.SelectFrom) (evalContext, rowIter, error) {
switch sf := sf.(type) {
default:
return ec, nil, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf)
return ec, nil, fmt.Errorf("selecting with FROM clause of type %T not yet supported", sf)
case spansql.SelectFromTable:
t, err := d.table(sf.Table)
if err != nil {
return ec, nil, nil, err
t, ok := qc.tableIndex[sf.Table]
if !ok {
// This shouldn't be possible; the queryContext should have discovered missing tables already.
return ec, nil, fmt.Errorf("unknown table %q", sf.Table)
}
t.mu.Lock()
ti := &tableIter{t: t}
if sf.Alias != "" {
ti.alias = sf.Alias
Expand All @@ -611,36 +680,33 @@ func (d *database) evalSelectFrom(ec evalContext, sf spansql.SelectFrom) (evalCo
ti.alias = sf.Table
}
ec.cols = ti.Cols()
return ec, ti, t.mu.Unlock, nil
return ec, ti, nil
case spansql.SelectFromJoin:
// TODO: Avoid the toRawIter calls here by rethinking how locking works throughout evalSelect,
// then doing the RHS recursive evalSelectFrom in joinIter.Next on demand.
// TODO: Avoid the toRawIter calls here by doing the RHS recursive evalSelectFrom in joinIter.Next on demand.

lhsEC, lhs, unlock, err := d.evalSelectFrom(ec, sf.LHS)
lhsEC, lhs, err := d.evalSelectFrom(qc, ec, sf.LHS)
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}
lhsRaw, err := toRawIter(lhs)
unlock()
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}

rhsEC, rhs, unlock, err := d.evalSelectFrom(ec, sf.RHS)
rhsEC, rhs, err := d.evalSelectFrom(qc, ec, sf.RHS)
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}
rhsRaw, err := toRawIter(rhs)
unlock()
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}

ji, ec, err := newJoinIter(lhsRaw, rhsRaw, lhsEC, rhsEC, sf)
if err != nil {
return ec, nil, nil, err
return ec, nil, err
}
return ec, ji, func() {}, nil
return ec, ji, nil
}
}

Expand Down Expand Up @@ -893,16 +959,37 @@ func (ji *joinIter) Next() (row, error) {
}
}

func evalSelectOrder(si *selIter, aux []spansql.Expr) (rows []row, keys [][]interface{}, err error) {
// This is like toRawIter except it also evaluates the auxiliary expressions for ORDER BY.
for {
r, err := si.Next()
if err == io.EOF {
break
} else if err != nil {
return nil, nil, err
}
key, err := si.ec.evalExprList(aux)
if err != nil {
return nil, nil, err
}

rows = append(rows, r.copyAllData())
keys = append(keys, key)
}
return
}

// externalRowSorter implements sort.Interface for a slice of rows
// with an external sort key.
type externalRowSorter struct {
rows []row
keys [][]interface{}
desc []bool // may be nil
}

func (ers externalRowSorter) Len() int { return len(ers.rows) }
func (ers externalRowSorter) Less(i, j int) bool {
return compareValLists(ers.keys[i], ers.keys[j], nil) < 0
return compareValLists(ers.keys[i], ers.keys[j], ers.desc) < 0
}
func (ers externalRowSorter) Swap(i, j int) {
ers.rows[i], ers.rows[j] = ers.rows[j], ers.rows[i]
Expand Down
Loading

0 comments on commit 89a9df5

Please sign in to comment.