Skip to content

Commit

Permalink
Implement PG aggregate functions
Browse files Browse the repository at this point in the history
Implement `count`, `avg`, `sum`, `max` and `min` to PG dialect
  • Loading branch information
L-Mario564 committed Nov 9, 2023
1 parent 1a482ce commit 73981dc
Show file tree
Hide file tree
Showing 16 changed files with 358 additions and 28 deletions.
62 changes: 62 additions & 0 deletions drizzle-orm/src/built-in-function.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import { entityKind } from './entity.ts';
import { Dialect } from './column-builder.ts';
import { type SQLWrapper, type SQL, sql, DriverValueDecoder, GetDecoderResult } from './sql/sql.ts';

/** @internal */
export const BuiltInFunctionSQL = Symbol.for('drizzle:BuiltInFunctionSQL');

export interface BuiltInFunction<T = unknown> extends SQLWrapper {
// SQLWrapper runtime implementation is defined in 'sql/sql.ts'
}
export abstract class BuiltInFunction<T = unknown> implements SQLWrapper {
static readonly [entityKind]: string = 'BuiltInFunction';

declare readonly _: {
readonly type: T;
readonly dialect: Dialect;
};

/** @internal */
static readonly Symbol = {
SQL: BuiltInFunctionSQL as typeof BuiltInFunctionSQL,
};

/** @internal */
get [BuiltInFunctionSQL](): SQL<T> {
return this.sql;
};

protected sql: SQL<T>;

constructor(sql: SQL<T>) {
this.sql = sql;
}

as(alias: string): SQL.Aliased<T>;
/**
* @deprecated
* Use ``sql<DataType>`query`.as(alias)`` instead.
*/
as<TData>(): SQL<TData>;
/**
* @deprecated
* Use ``sql<DataType>`query`.as(alias)`` instead.
*/
as<TData>(alias: string): SQL.Aliased<TData>;
as(alias?: string): SQL<T> | SQL.Aliased<T> {
// TODO: remove with deprecated overloads
if (alias === undefined) {
return this.sql;
}

return this.sql.as(alias);
}

mapWith<
TDecoder extends
| DriverValueDecoder<any, any>
| DriverValueDecoder<any, any>['mapFromDriverValue'],
>(decoder: TDecoder): SQL<GetDecoderResult<TDecoder>> {
return this.sql.mapWith(decoder);
}
}
51 changes: 51 additions & 0 deletions drizzle-orm/src/distinct.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import { entityKind, is } from './entity.ts';
import { type SQLWrapper } from './index.ts';

/** @internal */
export const DistinctValue = Symbol.for('drizzle:DistinctValue');

export class Distinct<T extends SQLWrapper = SQLWrapper> {
static readonly [entityKind]: string = 'Distinct';

declare readonly _: {
readonly type: T;
};

/** @internal */
static readonly Symbol = {
Value: DistinctValue as typeof DistinctValue,
};

/** @internal */
[DistinctValue]: T;

constructor(value: T) {
this[DistinctValue] = value;
}
}

export type MaybeDistinct<T extends SQLWrapper = SQLWrapper> = T | Distinct<T>;

export type WithoutDistinct<T> = T extends Distinct ? T['_']['type'] : T;

export function distinct<T extends SQLWrapper = SQLWrapper>(value: T) {
return new Distinct(value);
}

/** @internal */
export function getValueWithDistinct<T>(value: T): {
value: WithoutDistinct<T>;
distinct: boolean;
} {
if (is(value, Distinct)) {
return {
value: value[DistinctValue],
distinct: true
} as any;
}

return {
value,
distinct: false
} as any;
}
2 changes: 2 additions & 0 deletions drizzle-orm/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
export * from './alias.ts';
export * from './built-in-function.ts';
export * from './column-builder.ts';
export * from './column.ts';
export * from './distinct.ts';
export * from './entity.ts';
export * from './errors.ts';
export * from './expressions.ts';
Expand Down
17 changes: 9 additions & 8 deletions drizzle-orm/src/operations.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { BuiltInFunction } from './built-in-function.ts';
import type { AnyColumn, Column } from './column.ts';
import type { SQL } from './sql/index.ts';
import type { Table } from './table.ts';
Expand All @@ -13,22 +14,22 @@ export type OptionalKeyOnly<
T extends Column,
> = TKey extends RequiredKeyOnly<TKey, T> ? never : TKey;

export type SelectedFieldsFlat<TColumn extends Column> = Record<
export type SelectedFieldsFlat<TColumn extends Column, TBuiltInFunction extends BuiltInFunction> = Record<
string,
TColumn | SQL | SQL.Aliased
TColumn | TBuiltInFunction | SQL | SQL.Aliased
>;

export type SelectedFieldsFlatFull<TColumn extends Column> = Record<
export type SelectedFieldsFlatFull<TColumn extends Column, TBuiltInFunction extends BuiltInFunction> = Record<
string,
TColumn | SQL | SQL.Aliased
TColumn | TBuiltInFunction | SQL | SQL.Aliased
>;

export type SelectedFields<TColumn extends Column, TTable extends Table> = Record<
export type SelectedFields<TColumn extends Column, TTable extends Table, TBuiltInFunction extends BuiltInFunction> = Record<
string,
SelectedFieldsFlat<TColumn>[string] | TTable | SelectedFieldsFlat<TColumn>
SelectedFieldsFlat<TColumn, TBuiltInFunction>[string] | TTable | SelectedFieldsFlat<TColumn, TBuiltInFunction>
>;

export type SelectedFieldsOrdered<TColumn extends Column> = {
export type SelectedFieldsOrdered<TColumn extends Column, TBuiltInFunction extends BuiltInFunction> = {
path: string[];
field: TColumn | SQL | SQL.Aliased;
field: TColumn | TBuiltInFunction | SQL | SQL.Aliased;
}[];
9 changes: 6 additions & 3 deletions drizzle-orm/src/pg-core/dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ import { ViewBaseConfig } from '~/view-common.ts';
import { View } from '~/view.ts';
import type { PgSession } from './session.ts';
import { type PgMaterializedView, PgViewBase } from './view.ts';
import { BuiltInFunction, BuiltInFunctionSQL } from '~/built-in-function.ts';
import type { PgBuiltInFunction } from './functions/common.ts';

export class PgDialect {
static readonly [entityKind]: string = 'PgDialect';
Expand Down Expand Up @@ -154,8 +156,9 @@ export class PgDialect {

if (is(field, SQL.Aliased) && field.isSelectionField) {
chunk.push(sql.identifier(field.fieldAlias));
} else if (is(field, SQL.Aliased) || is(field, SQL)) {
const query = is(field, SQL.Aliased) ? field.sql : field;
} else if (is(field, SQL.Aliased) || is(field, SQL) || is(field, BuiltInFunction)) {
const field_ = is(field, BuiltInFunction) ? field[BuiltInFunctionSQL] : field
const query = is(field_, SQL.Aliased) ? field_.sql : field_;

if (isSingleTable) {
chunk.push(
Expand Down Expand Up @@ -211,7 +214,7 @@ export class PgDialect {
setOperators,
}: PgSelectConfig,
): SQL {
const fieldsList = fieldsFlat ?? orderSelectedFields<PgColumn>(fields);
const fieldsList = fieldsFlat ?? orderSelectedFields<PgColumn, PgBuiltInFunction>(fields);
for (const f of fieldsList) {
if (
is(f.field, Column)
Expand Down
172 changes: 172 additions & 0 deletions drizzle-orm/src/pg-core/functions/aggregate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import { is, entityKind } from '~/entity.ts';
import { PgColumn } from '../columns/index.ts';
import { type SQL, sql, type SQLWrapper, isSQLWrapper, SQLChunk } from '~/sql/index.ts';
import { PgBuiltInFunction } from './common.ts';
import { type MaybeDistinct, getValueWithDistinct } from '~/distinct.ts';

export class PgAggregateFunction<T = unknown> extends PgBuiltInFunction<T> {
static readonly [entityKind]: string = 'PgAggregateFunction';

filterWhere(where?: SQL | undefined): this {
if (where) {
this.sql.append(sql` filter (where ${where})`);
}
return this;
}
}

/**
* Returns the number of values in `expression`.
*
* ## Examples
*
* ```ts
* // Number employees with null values
* db.select({ value: count() }).from(employees)
* // Number of employees where `name` is not null
* db.select({ value: count(employees.name) }).from(employees)
* // Number of employees where `name` is distinct (no duplicates)
* db.select({ value: count(distinct(employees.name)) }).from(employees)
* // Number of employees where their salaries are greater than $2,000
* db.select({ value: count().filterWhere(gt(employees.salary, 2000)) }).from(employees)
* ```
*/
export function count<T extends 'number' | 'bigint' | undefined = undefined>(expression?: MaybeDistinct<SQLWrapper> | '*', config?: {
mode: T;
}): PgAggregateFunction<T extends 'number' ? number : bigint> {
const { value, distinct } = getValueWithDistinct(expression);
const chunks: SQLChunk[] = [];

if (distinct) {
chunks.push(sql`distinct `);
}
chunks.push(isSQLWrapper(value) ? value : sql`*`);

const sql_ = sql
.join([sql`count(`, ...chunks, sql`)` ])
.mapWith(config?.mode === 'number' ? Number : BigInt);

return new PgAggregateFunction(sql_) as any;
}

/**
* Returns the average (arithmetic mean) of all non-null values in `expression`.
*
* ## Examples
*
* ```ts
* // Average salary of an employee
* db.select({ value: avg(employees.salary) }).from(employees)
* // Average salary of an employee where `salary` is distinct (no duplicates)
* db.select({ value: avg(distinct(employees.salary)) }).from(employees)
* // Average salary of an employee where their salaries are greater than $2,000
* db.select({ value: avg(employees.salary).filterWhere(gt(employees.salary, 2000)) }).from(employees)
* ```
*/
export function avg<T extends 'number' | 'bigint' | 'string' | undefined = undefined>(expression: MaybeDistinct<SQLWrapper>, config?: {
mode: T;
}): PgAggregateFunction<(T extends 'bigint' ? bigint : T extends 'number' ? number : string) | null> {
const { value, distinct } = getValueWithDistinct(expression);
const chunks: SQLChunk[] = [];

if (distinct) {
chunks.push(sql`distinct `);
}
chunks.push(value);

let sql_ = sql.join([sql`avg(`, ...chunks, sql`)`]);

if (config?.mode === 'bigint') {
sql_ = sql_.mapWith(BigInt);
} else if (config?.mode === 'number') {
sql_ = sql_.mapWith(Number);
}

return new PgAggregateFunction(sql_) as any;
}

/**
* Returns the sum of all non-null values in `expression`.
*
* ## Examples
*
* ```ts
* // Sum of every employee's salary
* db.select({ value: sum(employees.salary) }).from(employees)
* // Sum of every employee's salary where `salary` is distinct (no duplicates)
* db.select({ value: sum(distinct(employees.salary)) }).from(employees)
* // Sum of every employee's salary where their salaries are greater than $2,000
* db.select({ value: sum(employees.salary).filterWhere(gt(employees.salary, 2000)) }).from(employees)
* ```
*/
export function sum<T extends 'number' | 'bigint' | 'string' | undefined = undefined>(expression: MaybeDistinct<SQLWrapper>, config?: {
mode: T;
}): PgAggregateFunction<(T extends 'bigint' ? bigint : T extends 'number' ? number : string) | null> {
const { value, distinct } = getValueWithDistinct(expression);
const chunks: SQLChunk[] = [];

if (distinct) {
chunks.push(sql`distinct `);
}
chunks.push(value);

let sql_ = sql.join([sql`sum(`, ...chunks, sql`)`]);

if (config?.mode === 'bigint') {
sql_ = sql_.mapWith(BigInt);
} else if (config?.mode === 'number') {
sql_ = sql_.mapWith(Number);
}

return new PgAggregateFunction(sql_) as any;
}

/**
* Returns the maximum value in `expression`.
*
* ## Examples
*
* ```ts
* // The employee with the highest salary
* db.select({ value: max(employees.salary) }).from(employees)
* ```
*/
export function max<T extends SQLWrapper>(expression: T): T extends PgColumn
? PgAggregateFunction<T['_']['data'] | null>
: PgAggregateFunction<string | null>
{
let sql_ = sql.join([sql`max(`, expression, sql`)`]);

if (is(expression, PgColumn)) {
sql_ = sql_.mapWith(expression);
} else {
sql_ = sql_.mapWith(String);
}

return new PgAggregateFunction(sql_) as any;
}

/**
* Returns the minimum value in `expression`.
*
* ## Examples
*
* ```ts
* // The employee with the lowest salary
* db.select({ value: min(employees.salary) }).from(employees)
* ```
*/
export function min<T extends SQLWrapper>(expression: T): T extends PgColumn
? PgAggregateFunction<T['_']['data'] | null>
: PgAggregateFunction<string | null>
{
let sql_ = sql.join([sql`min(`, expression, sql`)`]);

if (is(expression, PgColumn)) {
sql_ = sql_.mapWith(expression);
} else {
sql_ = sql_.mapWith(String);
}

return new PgAggregateFunction(sql_) as any;
}
11 changes: 11 additions & 0 deletions drizzle-orm/src/pg-core/functions/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import { BuiltInFunction } from '~/built-in-function.ts';
import { entityKind } from '~/entity.ts';

export class PgBuiltInFunction<T = unknown> extends BuiltInFunction<T> {
static readonly [entityKind]: string = 'PgBuiltInFunction';

declare readonly _: {
readonly type: T;
readonly dialect: 'pg';
};
}
2 changes: 2 additions & 0 deletions drizzle-orm/src/pg-core/functions/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export * from './aggregate.ts';
export * from './common.ts';
1 change: 1 addition & 0 deletions drizzle-orm/src/pg-core/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export * from './columns/index.ts';
export * from './db.ts';
export * from './dialect.ts';
export * from './foreign-keys.ts';
export * from './functions/index.ts';
export * from './indexes.ts';
export * from './primary-keys.ts';
export * from './query-builders/index.ts';
Expand Down
Loading

0 comments on commit 73981dc

Please sign in to comment.