diff --git a/packages/core/src/driver.ts b/packages/core/src/driver.ts index 8694b1e4..b4df705a 100644 --- a/packages/core/src/driver.ts +++ b/packages/core/src/driver.ts @@ -119,14 +119,14 @@ export class Database { sel.args[0].having = Eval.and(query(...tables.map(name => sel.row[name]))) } sel.args[0].optional = Object.fromEntries(tables.map((name, index) => [name, optional?.[index]])) - return sel + return new Selection(this.getDriver(sel), sel) } else { const sel = new Selection(this.getDriver(tables[0]), tables) if (typeof query === 'function') { sel.args[0].having = Eval.and(query(sel.row)) } sel.args[0].optional = optional - return sel + return new Selection(this.getDriver(sel), sel) } } diff --git a/packages/memory/src/index.ts b/packages/memory/src/index.ts index 4ee9f048..b20bec28 100644 --- a/packages/memory/src/index.ts +++ b/packages/memory/src/index.ts @@ -1,5 +1,5 @@ import { clone, Dict, makeArray, noop, omit, pick, valueMap } from 'cosmokit' -import { Database, Driver, Eval, executeEval, executeQuery, executeSort, executeUpdate, RuntimeError, Selection } from '@minatojs/core' +import { Database, Driver, Eval, executeEval, executeQuery, executeSort, executeUpdate, isEvalExpr, RuntimeError, Selection } from '@minatojs/core' export namespace MemoryDriver { export interface Config {} @@ -47,6 +47,10 @@ export class MemoryDriver extends Driver { const { ref, query, table, args, model } = sel const { fields, group, having } = sel.args[0] const data = this.table(table, having).filter(row => executeQuery(row, query, ref)) + if (!group.length && fields && Object.values(args[0].fields ?? {}).some(x => isAggrExpr(x))) { + return [valueMap(fields!, (expr) => executeEval(data.map(row => ({ [ref]: row })), expr))] + } + const branches: { index: Dict; table: any[] }[] = [] const groupFields = group.length ? pick(fields!, group) : fields for (let row of executeSort(data, args[0], ref)) { @@ -164,4 +168,17 @@ export class MemoryDriver extends Driver { } } +const nonAggrKeys = ['$'] +const aggrKeys = ['$sum', '$avg', '$min', '$max', '$count'] + +function isAggrExpr(value: any) { + if (!isEvalExpr(value)) return false + for (const [key, args] of Object.entries(value)) { + if (!key.startsWith('$')) continue + if (nonAggrKeys.includes(key)) return false + if (aggrKeys.includes(key) || ((Array.isArray(args) ? args : [args]).some(x => isAggrExpr(x)))) return true + } + return false +} + export default MemoryDriver diff --git a/packages/mongo/src/utils.ts b/packages/mongo/src/utils.ts index 9f638224..f8b0852c 100644 --- a/packages/mongo/src/utils.ts +++ b/packages/mongo/src/utils.ts @@ -240,10 +240,12 @@ export class Transformer { stages.push({ $match: { $expr } }) } stages.push({ $project }) + $group['_id'] = model.parse($group['_id'], false) } else if (fields) { - const $project = valueMap(fields, (expr) => this.eval(expr)) + const $group: Dict = { _id: null } + const $project = valueMap(fields, (expr) => this.eval(expr, $group)) $project._id = 0 - stages.push({ $project }) + stages.push(...Object.keys($group).length === 1 ? [] : [{ $group }], { $project }) } else { const $project: Dict = { _id: 0 } for (const key in model.fields) { diff --git a/packages/mysql/src/index.ts b/packages/mysql/src/index.ts index 5571c9e2..ac798fb8 100644 --- a/packages/mysql/src/index.ts +++ b/packages/mysql/src/index.ts @@ -440,7 +440,7 @@ export class MySQLDriver extends Driver { const builder = new MySQLBuilder(sel.tables) const output = builder.parseEval(expr) const inner = builder.get(sel.table as Selection, true) - const [data] = await this.queue(`SELECT ${output} AS value FROM ${inner} ${sel.ref}`) + const [data] = await this.queue(`SELECT ${output} AS value FROM ${inner}`) return data.value } diff --git a/packages/sql-utils/src/index.ts b/packages/sql-utils/src/index.ts index b17490ed..f93c03a7 100644 --- a/packages/sql-utils/src/index.ts +++ b/packages/sql-utils/src/index.ts @@ -304,11 +304,15 @@ export class Builder { if (filter !== '1') { suffix = ` WHERE ${filter}` + suffix } + + if (inline && !args[0].fields && !suffix) { + return (prefix.startsWith('(') && prefix.endsWith(')')) ? `${prefix} ${ref}` : prefix + } + if (!prefix.includes(' ') || prefix.startsWith('(')) { suffix = ` ${ref}` + suffix } - if (inline && !args[0].fields && !suffix) return prefix const result = `SELECT ${keys} FROM ${prefix}${suffix}` return inline ? `(${result})` : result } diff --git a/packages/sqlite/src/index.ts b/packages/sqlite/src/index.ts index 07694b3a..dc457227 100644 --- a/packages/sqlite/src/index.ts +++ b/packages/sqlite/src/index.ts @@ -318,7 +318,7 @@ export class SQLiteDriver extends Driver { const builder = new SQLiteBuilder(sel.tables) const output = builder.parseEval(expr) const inner = builder.get(sel.table as Selection, true) - const { value } = this.#get(`SELECT ${output} AS value FROM ${inner} ${sel.ref}`) + const { value } = this.#get(`SELECT ${output} AS value FROM ${inner}`) return value } diff --git a/packages/tests/src/selection.ts b/packages/tests/src/selection.ts index 3a86d8e2..dc34bf63 100644 --- a/packages/tests/src/selection.ts +++ b/packages/tests/src/selection.ts @@ -26,7 +26,7 @@ function SelectionTests(database: Database) { value: 'integer', }) - database.migrate('foo', { deprecated: 'unsigned' }, async () => {}) + database.migrate('foo', { deprecated: 'unsigned' }, async () => { }) database.extend('bar', { id: 'unsigned', @@ -139,6 +139,33 @@ namespace SelectionTests { { id: 10 }, ]) }) + + it('aggregate', async () => { + await expect(database + .select('foo') + .project({ + count: row => $.count(row.id), + max: row => $.max(row.id), + min: row => $.min(row.id), + avg: row => $.avg(row.id), + }) + .execute() + ).to.eventually.deep.equal([ + { avg: 2, count: 3, max: 3, min: 1 }, + ]) + + await expect(database.select('foo') + .groupBy({}, { + count: row => $.count(row.id), + max: row => $.max(row.id), + min: row => $.min(row.id), + avg: row => $.avg(row.id), + }) + .execute() + ).to.eventually.deep.equal([ + { avg: 2, count: 3, max: 3, min: 1 }, + ]) + }) } export function aggregate(database: Database) { @@ -258,6 +285,24 @@ namespace SelectionTests { .execute() ).to.eventually.have.length(2) }) + + it('group', async () => { + await expect(database.join(['foo', 'bar'] as const, (foo, bar) => $.eq(foo.id, bar.pid)) + .groupBy('foo', { count: row => $.sum(row.bar.uid) }) + .orderBy(row => row.foo.id) + .execute()).to.eventually.deep.equal([ + { foo: { id: 1, value: 0 }, count: 6 }, + { foo: { id: 2, value: 2 }, count: 1 }, + { foo: { id: 3, value: 2 }, count: 1 }, + ]) + }) + + it('aggregate', async () => { + await expect(database + .join(['foo', 'bar'] as const) + .execute(row => $.count(row.bar.id)) + ).to.eventually.equal(6) + }) } }