Skip to content

Commit

Permalink
fix: join groupBy, project aggr, joined field (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hieuzest authored Oct 27, 2023
1 parent 7a3e658 commit bee4282
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 9 deletions.
4 changes: 2 additions & 2 deletions packages/core/src/driver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ export class Database<S = any> {
sel.args[0].having = Eval.and(query(...tables.map(name => sel.row[name])))
}
sel.args[0].optional = Object.fromEntries(tables.map((name, index) => [name, optional?.[index]]))
return sel
return new Selection(this.getDriver(sel), sel)
} else {
const sel = new Selection(this.getDriver(tables[0]), tables)
if (typeof query === 'function') {
sel.args[0].having = Eval.and(query(sel.row))
}
sel.args[0].optional = optional
return sel
return new Selection(this.getDriver(sel), sel)
}
}

Expand Down
19 changes: 18 additions & 1 deletion packages/memory/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { clone, Dict, makeArray, noop, omit, pick, valueMap } from 'cosmokit'
import { Database, Driver, Eval, executeEval, executeQuery, executeSort, executeUpdate, RuntimeError, Selection } from '@minatojs/core'
import { Database, Driver, Eval, executeEval, executeQuery, executeSort, executeUpdate, isEvalExpr, RuntimeError, Selection } from '@minatojs/core'

export namespace MemoryDriver {
export interface Config {}
Expand Down Expand Up @@ -47,6 +47,10 @@ export class MemoryDriver extends Driver {
const { ref, query, table, args, model } = sel
const { fields, group, having } = sel.args[0]
const data = this.table(table, having).filter(row => executeQuery(row, query, ref))
if (!group.length && fields && Object.values(args[0].fields ?? {}).some(x => isAggrExpr(x))) {
return [valueMap(fields!, (expr) => executeEval(data.map(row => ({ [ref]: row })), expr))]
}

const branches: { index: Dict; table: any[] }[] = []
const groupFields = group.length ? pick(fields!, group) : fields
for (let row of executeSort(data, args[0], ref)) {
Expand Down Expand Up @@ -164,4 +168,17 @@ export class MemoryDriver extends Driver {
}
}

const nonAggrKeys = ['$']
const aggrKeys = ['$sum', '$avg', '$min', '$max', '$count']

function isAggrExpr(value: any) {
if (!isEvalExpr(value)) return false
for (const [key, args] of Object.entries(value)) {
if (!key.startsWith('$')) continue
if (nonAggrKeys.includes(key)) return false
if (aggrKeys.includes(key) || ((Array.isArray(args) ? args : [args]).some(x => isAggrExpr(x)))) return true
}
return false
}

export default MemoryDriver
6 changes: 4 additions & 2 deletions packages/mongo/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,12 @@ export class Transformer {
stages.push({ $match: { $expr } })
}
stages.push({ $project })
$group['_id'] = model.parse($group['_id'], false)
} else if (fields) {
const $project = valueMap(fields, (expr) => this.eval(expr))
const $group: Dict = { _id: null }
const $project = valueMap(fields, (expr) => this.eval(expr, $group))
$project._id = 0
stages.push({ $project })
stages.push(...Object.keys($group).length === 1 ? [] : [{ $group }], { $project })
} else {
const $project: Dict = { _id: 0 }
for (const key in model.fields) {
Expand Down
2 changes: 1 addition & 1 deletion packages/mysql/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ export class MySQLDriver extends Driver {
const builder = new MySQLBuilder(sel.tables)
const output = builder.parseEval(expr)
const inner = builder.get(sel.table as Selection, true)
const [data] = await this.queue(`SELECT ${output} AS value FROM ${inner} ${sel.ref}`)
const [data] = await this.queue(`SELECT ${output} AS value FROM ${inner}`)
return data.value
}

Expand Down
6 changes: 5 additions & 1 deletion packages/sql-utils/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -304,11 +304,15 @@ export class Builder {
if (filter !== '1') {
suffix = ` WHERE ${filter}` + suffix
}

if (inline && !args[0].fields && !suffix) {
return (prefix.startsWith('(') && prefix.endsWith(')')) ? `${prefix} ${ref}` : prefix
}

if (!prefix.includes(' ') || prefix.startsWith('(')) {
suffix = ` ${ref}` + suffix
}

if (inline && !args[0].fields && !suffix) return prefix
const result = `SELECT ${keys} FROM ${prefix}${suffix}`
return inline ? `(${result})` : result
}
Expand Down
2 changes: 1 addition & 1 deletion packages/sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ export class SQLiteDriver extends Driver {
const builder = new SQLiteBuilder(sel.tables)
const output = builder.parseEval(expr)
const inner = builder.get(sel.table as Selection, true)
const { value } = this.#get(`SELECT ${output} AS value FROM ${inner} ${sel.ref}`)
const { value } = this.#get(`SELECT ${output} AS value FROM ${inner}`)
return value
}

Expand Down
47 changes: 46 additions & 1 deletion packages/tests/src/selection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function SelectionTests(database: Database<Tables>) {
value: 'integer',
})

database.migrate('foo', { deprecated: 'unsigned' }, async () => {})
database.migrate('foo', { deprecated: 'unsigned' }, async () => { })

database.extend('bar', {
id: 'unsigned',
Expand Down Expand Up @@ -139,6 +139,33 @@ namespace SelectionTests {
{ id: 10 },
])
})

it('aggregate', async () => {
await expect(database
.select('foo')
.project({
count: row => $.count(row.id),
max: row => $.max(row.id),
min: row => $.min(row.id),
avg: row => $.avg(row.id),
})
.execute()
).to.eventually.deep.equal([
{ avg: 2, count: 3, max: 3, min: 1 },
])

await expect(database.select('foo')
.groupBy({}, {
count: row => $.count(row.id),
max: row => $.max(row.id),
min: row => $.min(row.id),
avg: row => $.avg(row.id),
})
.execute()
).to.eventually.deep.equal([
{ avg: 2, count: 3, max: 3, min: 1 },
])
})
}

export function aggregate(database: Database<Tables>) {
Expand Down Expand Up @@ -258,6 +285,24 @@ namespace SelectionTests {
.execute()
).to.eventually.have.length(2)
})

it('group', async () => {
await expect(database.join(['foo', 'bar'] as const, (foo, bar) => $.eq(foo.id, bar.pid))
.groupBy('foo', { count: row => $.sum(row.bar.uid) })
.orderBy(row => row.foo.id)
.execute()).to.eventually.deep.equal([
{ foo: { id: 1, value: 0 }, count: 6 },
{ foo: { id: 2, value: 2 }, count: 1 },
{ foo: { id: 3, value: 2 }, count: 1 },
])
})

it('aggregate', async () => {
await expect(database
.join(['foo', 'bar'] as const)
.execute(row => $.count(row.bar.id))
).to.eventually.equal(6)
})
}
}

Expand Down

0 comments on commit bee4282

Please sign in to comment.