Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(minato): impl some math functions #64

Merged
merged 6 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion packages/core/src/eval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ export namespace Eval {
subtract: Binary<number, number>
div: Binary<number, number>
divide: Binary<number, number>
mod: Binary<number, number>
modulo: Binary<number, number>

// mathematic
abs: Unary<number, number>
floor: Unary<number, number>
ceil: Unary<number, number>
round: Unary<number, number>
exp: Unary<number, number>
log<A extends boolean>(x: Term<number, A>, base?: Term<number, A>): Expr<number, A>
pow: Binary<number, number>
power: Binary<number, number>
random(): Expr<number, false>

// comparison
eq: Multi<Comparable, boolean>
Expand Down Expand Up @@ -116,7 +129,7 @@ operators['$'] = getRecursive
type UnaryCallback<T> = T extends (value: infer R) => Eval.Expr<infer S> ? (value: R, data: any[]) => S : never
function unary<K extends keyof Eval.Static>(key: K, callback: UnaryCallback<Eval.Static[K]>): Eval.Static[K] {
operators[`$${key}`] = callback
return (value: any) => Eval(key, value) as any
return ((value: any) => Eval(key, value)) as any
}

type MultivariateCallback<T> = T extends (...args: infer R) => Eval.Expr<infer S> ? (args: R, data: any) => S : never
Expand Down Expand Up @@ -153,6 +166,18 @@ Eval.add = multary('add', (args, data) => args.reduce<number>((prev, curr) => pr
Eval.mul = Eval.multiply = multary('multiply', (args, data) => args.reduce<number>((prev, curr) => prev * executeEval(data, curr), 1))
Eval.sub = Eval.subtract = multary('subtract', ([left, right], data) => executeEval(data, left) - executeEval(data, right))
Eval.div = Eval.divide = multary('divide', ([left, right], data) => executeEval(data, left) / executeEval(data, right))
Eval.mod = Eval.modulo = multary('modulo', ([left, right], data) => executeEval(data, left) % executeEval(data, right))

// mathematic
Eval.abs = unary('abs', (arg, data) => Math.abs(executeEval(data, arg)))
Eval.floor = unary('floor', (arg, data) => Math.floor(executeEval(data, arg)))
Eval.ceil = unary('ceil', (arg, data) => Math.ceil(executeEval(data, arg)))
Eval.round = unary('round', (arg, data) => Math.round(executeEval(data, arg)))
Eval.exp = unary('exp', (arg, data) => Math.exp(executeEval(data, arg)))
Eval.log = multary('log', ([left, right], data) => Math.log(executeEval(data, left)) / Math.log(executeEval(data, right ?? Math.E)))
Eval.pow = Eval.power = multary('power', ([left, right], data) => Math.pow(executeEval(data, left), executeEval(data, right)))
Eval.random = () => Eval('random', {})
operators.$random = () => Math.random()

// comparison
Eval.eq = comparator('eq', (left, right) => left === right)
Expand Down
98 changes: 56 additions & 42 deletions packages/mongo/src/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Dict, isNullable, valueMap } from 'cosmokit'
import { isComparable, Query, Selection } from '@minatojs/core'
import { Eval, isComparable, Query, Selection } from '@minatojs/core'
import { Filter, FilterOperators } from 'mongodb'

function createFieldFilter(query: Query.FieldQuery, key: string) {
Expand Down Expand Up @@ -72,14 +72,66 @@ function transformFieldQuery(query: Query.FieldQuery, key: string, filters: Filt
return result
}

export type ExtractUnary<T> = T extends [infer U] ? U : T

export type EvalOperators = {
[K in keyof Eval.Static as `$${K}`]?: (expr: ExtractUnary<Parameters<Eval.Static[K]>>, group?: object) => any
} & { $: (expr: any, group?: object) => any }

const aggrKeys = ['$sum', '$avg', '$min', '$max', '$count', '$length', '$array']

export class Transformer {
private counter = 0
private evalOperators: EvalOperators
public walkedKeys: string[]

constructor(public virtualKey?: string, public lookup?: boolean, public recursivePrefix: string = '$') {
this.walkedKeys = []

this.evalOperators = {
$: (arg, group) => {
if (typeof arg === 'string') {
this.walkedKeys.push(this.getActualKey(arg))
return this.recursivePrefix + this.getActualKey(arg)
} else if (this.lookup) {
this.walkedKeys.push(arg[0] + '.' + this.getActualKey(arg[1]))
return this.recursivePrefix + arg[0] + '.' + this.getActualKey(arg[1])
} else {
this.walkedKeys.push(this.getActualKey(arg[1]))
return this.recursivePrefix + this.getActualKey(arg[1])
}
},
$if: (arg, group) => ({ $cond: arg.map(val => this.eval(val, group)) }),
$array: (arg, group) => this.transformEvalExpr(arg),
$object: (arg, group) => this.transformEvalExpr(arg),

$length: (arg, group) => ({ $size: this.eval(arg, group) }),
$nin: (arg, group) => ({ $not: { $in: arg.map(val => this.eval(val, group)) } }),

$modulo: (arg, group) => ({ $mod: arg.map(val => this.eval(val, group)) }),
$log: ([left, right], group) => isNullable(right)
? { $ln: this.eval(left, group) }
: { $log: [this.eval(left, group), this.eval(right, group)] },
$power: (arg, group) => ({ $pow: arg.map(val => this.eval(val, group)) }),
$random: (arg, group) => ({ $rand: {} }),

$number: (arg, group) => {
const value = this.eval(arg, group)
return {
$ifNull: [{
$switch: {
branches: [
{
case: { $eq: [{ $type: value }, 'date'] },
then: { $floor: { $divide: [{ $toLong: value }, 1000] } },
},
],
default: { $toDouble: value },
},
}, 0],
}
},
}
}

public createKey() {
Expand All @@ -97,47 +149,9 @@ export class Transformer {
return { $literal: expr }
}

if (expr.$) {
if (typeof expr.$ === 'string') {
this.walkedKeys.push(this.getActualKey(expr.$))
return this.recursivePrefix + this.getActualKey(expr.$)
} else if (this.lookup) {
this.walkedKeys.push(expr.$[0] + '.' + this.getActualKey(expr.$[1]))
return this.recursivePrefix + expr.$[0] + '.' + this.getActualKey(expr.$[1])
} else {
this.walkedKeys.push(this.getActualKey(expr.$[1]))
return this.recursivePrefix + this.getActualKey(expr.$[1])
}
}

if (expr.$if) {
return { $cond: expr.$if.map(val => this.eval(val, group)) }
}

if (expr.$object || expr.$array) {
return this.transformEvalExpr(expr.$object || expr.$array)
}

if (expr.$length) {
return { $size: this.eval(expr.$length) }
}

if (expr.$nin) {
return { $not: { $in: expr.$nin.map(val => this.eval(val, group)) } }
}

if (expr.$number) {
const value = this.eval(expr.$number)
return {
$switch: {
branches: [
{
case: { $eq: [{ $type: value }, 'date'] },
then: { $floor: { $divide: [{ $toLong: value }, 1000] } },
},
],
default: { $toDouble: value },
},
for (const key in expr) {
if (this.evalOperators[key]) {
return this.evalOperators[key](expr[key], group)
}
}

Expand Down
20 changes: 18 additions & 2 deletions packages/postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,23 @@ class PostgresBuilder extends Builder {
// number
$add: (args) => `(${args.map(arg => this.parseEval(arg, 'double precision')).join(' + ')})`,
$multiply: (args) => `(${args.map(arg => this.parseEval(arg, 'double precision')).join(' * ')})`,
$modulo: ([left, right]) => {
const dividend = this.parseEval(left, 'double precision'), divisor = this.parseEval(right, 'double precision')
return `${dividend} - (${divisor} * floor(${dividend} / ${divisor}))`
},
$log: ([left, right]) => isNullable(right)
? `ln(${this.parseEval(left, 'double precision')})`
: `ln(${this.parseEval(left, 'double precision')}) / ln(${this.parseEval(right, 'double precision')})`,
$random: () => `random()`,

$eq: this.binary('=', 'text'),

$number: (arg) => {
const value = this.parseEval(arg)
const res = this.state.sqlType === 'raw' ? `${value}::double precision`
: `extract(epoch from ${value})::integer`
: `extract(epoch from ${value})::bigint`
this.state.sqlType = 'raw'
return res
return `coalesce(${res}, 0)`
},

$sum: (expr) => this.createAggr(expr, value => `coalesce(sum(${value})::double precision, 0)`, undefined, 'double precision'),
Expand Down Expand Up @@ -437,6 +445,14 @@ export class PostgresDriver extends Driver {
debug(_, query, parameters) {
logger.debug(`> %s` + (parameters.length ? `\nparameters: %o` : ``), query, parameters.length ? parameters : '')
},
transform: {
value: {
from: (value, column) => {
if (column.type === 20) return Number(value)
return value
},
},
},
...config,
}

Expand Down
13 changes: 12 additions & 1 deletion packages/sql-utils/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,17 @@ export class Builder {
$multiply: (args) => `(${args.map(arg => this.parseEval(arg)).join(' * ')})`,
$subtract: this.binary('-'),
$divide: this.binary('/'),
$modulo: this.binary('%'),

// mathemetic
$abs: (arg) => `abs(${this.parseEval(arg)})`,
$floor: (arg) => `floor(${this.parseEval(arg)})`,
$ceil: (arg) => `ceil(${this.parseEval(arg)})`,
$round: (arg) => `round(${this.parseEval(arg)})`,
$exp: (arg) => `exp(${this.parseEval(arg)})`,
$log: (args) => `log(${args.filter(x => !isNullable(x)).map(arg => this.parseEval(arg)).reverse().join(', ')})`,
$power: (args) => `power(${args.map(arg => this.parseEval(arg)).join(', ')})`,
$random: () => `rand()`,

// string
$concat: (args) => `concat(${args.map(arg => this.parseEval(arg)).join(', ')})`,
Expand Down Expand Up @@ -134,7 +145,7 @@ export class Builder {
: this.state.sqlType === 'time' ? `unix_timestamp(convert_tz(addtime('1970-01-01 00:00:00', ${value}), '${this._timezone}', '+0:00'))`
: `unix_timestamp(convert_tz(${value}, '${this._timezone}', '+0:00'))`
this.state.sqlType = 'raw'
return res
return `ifnull(${res}, 0)`
},

// aggregation
Expand Down
8 changes: 7 additions & 1 deletion packages/sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class SQLiteBuilder extends Builder {

this.evalOperators.$if = (args) => `iif(${args.map(arg => this.parseEval(arg)).join(', ')})`
this.evalOperators.$concat = (args) => `(${args.map(arg => this.parseEval(arg)).join('||')})`
this.evalOperators.$modulo = ([left, right]) => `modulo(${this.parseEval(left)}, ${this.parseEval(right)})`
this.evalOperators.$log = ([left, right]) => isNullable(right)
? `log(${this.parseEval(left)})`
: `log(${this.parseEval(left)}) / log(${this.parseEval(right)})`
this.evalOperators.$length = (expr) => this.createAggr(expr, value => `count(${value})`, value => {
if (this.state.sqlType === 'json') {
this.state.sqlType = 'raw'
Expand All @@ -66,7 +70,7 @@ class SQLiteBuilder extends Builder {
const res = this.state.sqlType === 'raw' ? `cast(${this.parseEval(arg)} as double)`
: `cast(${value} / 1000 as integer)`
this.state.sqlType = 'raw'
return res
return `ifnull(${res}, 0)`
}

this.define<boolean, number>({
Expand Down Expand Up @@ -275,6 +279,8 @@ export class SQLiteDriver extends Driver {
}
this.db.create_function('regexp', (pattern, str) => +new RegExp(pattern).test(str))
this.db.create_function('json_array_contains', (array, value) => +(JSON.parse(array) as any[]).includes(JSON.parse(value)))
this.db.create_function('modulo', (left, right) => left % right)
this.db.create_function('rand', () => Math.random())
}

#joinKeys(keys?: string[]) {
Expand Down
8 changes: 8 additions & 0 deletions packages/tests/src/selection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ namespace SelectionTests {
{ id: 1, value: 0 },
])
})

it('random', async () => {
await expect(database.select('foo').orderBy(row => $.random()).execute(['id'])).to.eventually.have.deep.members([
{ id: 1 },
{ id: 2 },
{ id: 3 },
])
})
}

export function project(database: Database<Tables>) {
Expand Down
25 changes: 25 additions & 0 deletions packages/tests/src/update.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ interface Bar {
id: number
text?: string
num?: number
double?: number
bool?: boolean
list?: string[]
timestamp?: Date
Expand All @@ -30,6 +31,7 @@ function OrmOperations(database: Database<Tables>) {
id: 'unsigned',
text: 'string',
num: 'integer',
double: 'double',
bool: 'boolean',
list: 'list',
timestamp: 'timestamp',
Expand Down Expand Up @@ -313,6 +315,29 @@ namespace OrmOperations {
date.setHours(0, 0, 0, 0)
await expect(database.eval('temp2', row => $.array($.number(row.date)), { num: 192 })).to.eventually.deep.equal([+date / 1000])
await expect(database.eval('temp2', row => $.array($.number(row.time)), { num: 193 })).to.eventually.deep.equal([43200 + date.getTimezoneOffset() * 60])
await expect(database.eval('temp2', row => $.min($.number(row.timestamp)))).to.eventually.deep.equal(0)
})

it('math functions', async () => {
const table = await setup(database, 'temp2', barTable)
table[0].double = 123.45
table[0].num = 6
await database.set('temp2', table[0].id, { double: table[0].double, num: table[0].num })
await expect(database.eval('temp2', row => $.max($.abs($.sub(0, row.double))), table[0].id)).to.eventually.deep.eq(table[0].double)
await expect(database.eval('temp2', row => $.max($.mod(row.double, row.num)), table[0].id)).to.eventually.deep.eq(table[0].double % table[0].num)
await expect(database.eval('temp2', row => $.max($.ceil(row.double)), table[0].id)).to.eventually.deep.eq(Math.ceil(table[0].double))
await expect(database.eval('temp2', row => $.max($.floor(row.double)), table[0].id)).to.eventually.deep.eq(Math.floor(table[0].double))
await expect(database.eval('temp2', row => $.max($.round(row.double)), table[0].id)).to.eventually.deep.eq(Math.round(table[0].double))
await expect(database.eval('temp2', row => $.max($.exp(row.double)), table[0].id)).to.eventually.deep.eq(Math.exp(table[0].double))
await expect(database.eval('temp2', row => $.max($.log(row.double)), table[0].id)).to.eventually.deep.eq(Math.log(table[0].double))
await expect(database.eval('temp2', row => $.max($.floor($.log(row.double, 3))), table[0].id)).to.eventually.deep.eq(Math.floor(Math.log(table[0].double) / Math.log(3)))
await expect(database.eval('temp2', row => $.max($.floor($.pow(row.double, row.num))), table[0].id))
.to.eventually.deep.eq(Math.floor(Math.pow(table[0].double, table[0].num)))
})

it('$.random', async () => {
await setup(database, 'temp2', barTable)
await expect(database.eval('temp2', row => $.max($.random()))).to.eventually.gt(0).lt(1)
})
}
}
Expand Down