Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically cast timestamp strings as timestamp #1933

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion drizzle-orm/src/aws-data-api/pg/driver.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { entityKind } from '~/entity.ts';
import { entityKind, is } from '~/entity.ts';
import type { Logger } from '~/logger.ts';
import { DefaultLogger } from '~/logger.ts';
import { PgDatabase } from '~/pg-core/db.ts';
Expand All @@ -12,6 +12,8 @@ import {
import type { DrizzleConfig } from '~/utils.ts';
import type { AwsDataApiClient, AwsDataApiPgQueryResultHKT } from './session.ts';
import { AwsDataApiSession } from './session.ts';
import { PgSelectConfig, PgTimestampString } from '~/pg-core/index.ts';
import { Param, SQL, sql } from '~/index.ts';

export interface PgDriverOptions {
logger?: Logger;
Expand All @@ -38,6 +40,27 @@ export class AwsPgDialect extends PgDialect {
override escapeParam(num: number): string {
return `:${num + 1}`;
}

override buildSelectQuery(config: PgSelectConfig): SQL<unknown> {
if (config.where) {
config.where = this.castTimestampStringParamAsTimestamp(config.where)
}

return super.buildSelectQuery(config)
}

castTimestampStringParamAsTimestamp(existingSql: SQL<unknown>): SQL<unknown> {
return sql.join(existingSql.queryChunks.map((chunk) => {
if (is(chunk, Param) && is(chunk.encoder, PgTimestampString)) {
return sql`cast(${chunk.value} as timestamp)`
}
if (is(chunk, SQL)) {
return this.castTimestampStringParamAsTimestamp(chunk)
}

return chunk
}))
}
}

export function drizzle<TSchema extends Record<string, unknown> = Record<string, never>>(
Expand Down
30 changes: 21 additions & 9 deletions integration-tests/tests/awsdatapi.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { fromIni } from '@aws-sdk/credential-providers';
import type { TestFn } from 'ava';
import anyTest from 'ava';
import * as dotenv from 'dotenv';
import { asc, eq, name, placeholder, sql, TransactionRollbackError } from 'drizzle-orm';
import { and, asc, eq, gte, name, placeholder, sql, TransactionRollbackError } from 'drizzle-orm';
import type { AwsDataApiPgDatabase } from 'drizzle-orm/aws-data-api/pg';
import { drizzle } from 'drizzle-orm/aws-data-api/pg';
import { migrate } from 'drizzle-orm/aws-data-api/pg/migrator';
Expand All @@ -18,6 +18,7 @@ const usersTable = pgTable('users', {
name: text('name').notNull(),
verified: boolean('verified').notNull().default(false),
jsonb: jsonb('jsonb').$type<string[]>(),
updatedAt: timestamp('updated_at', { withTimezone: true, mode: 'string' }).notNull().defaultNow(),
createdAt: timestamp('created_at', { withTimezone: true }).notNull().defaultNow(),
});

Expand Down Expand Up @@ -62,7 +63,8 @@ test.beforeEach(async (t) => {
name text not null,
verified boolean not null default false,
jsonb jsonb,
created_at timestamptz not null default now()
created_at timestamptz not null default now(),
updated_at timestamptz not null default now()
)
`,
);
Expand All @@ -79,7 +81,7 @@ test.serial('select all fields', async (t) => {

t.assert(result[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
// t.assert(Math.abs(result[0]!.createdAt.getTime() - now) < 100);
t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);
t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt, updatedAt: result[0]!.createdAt }]);
});

test.serial('select sql', async (t) => {
Expand Down Expand Up @@ -185,7 +187,7 @@ test.serial('update with returning all fields', async (t) => {

t.assert(users[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
// t.assert(Math.abs(users[0]!.createdAt.getTime() - now) < 100);
t.deepEqual(users, [{ id: 1, name: 'Jane', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
t.deepEqual(users, [{ id: 1, name: 'Jane', verified: false, jsonb: null, createdAt: users[0]!.createdAt, updatedAt: users[0]!.updatedAt }]);
});

test.serial('update with returning partial', async (t) => {
Expand All @@ -208,7 +210,7 @@ test.serial('delete with returning all fields', async (t) => {

t.assert(users[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
// t.assert(Math.abs(users[0]!.createdAt.getTime() - now) < 100);
t.deepEqual(users, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
t.deepEqual(users, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: users[0]!.createdAt, updatedAt: users[0]!.updatedAt }]);
});

test.serial('delete with returning partial', async (t) => {
Expand All @@ -228,13 +230,13 @@ test.serial('insert + select', async (t) => {

await db.insert(usersTable).values({ name: 'John' });
const result = await db.select().from(usersTable);
t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);
t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt, updatedAt: result[0]!.updatedAt }]);

await db.insert(usersTable).values({ name: 'Jane' });
const result2 = await db.select().from(usersTable);
t.deepEqual(result2, [
{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result2[0]!.createdAt },
{ id: 2, name: 'Jane', verified: false, jsonb: null, createdAt: result2[1]!.createdAt },
{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result2[0]!.createdAt, updatedAt: result2[0]!.updatedAt },
{ id: 2, name: 'Jane', verified: false, jsonb: null, createdAt: result2[1]!.createdAt, updatedAt: result2[0]!.updatedAt },
]);
});

Expand All @@ -257,7 +259,7 @@ test.serial('insert with overridden default values', async (t) => {
await db.insert(usersTable).values({ name: 'John', verified: true });
const result = await db.select().from(usersTable);

t.deepEqual(result, [{ id: 1, name: 'John', verified: true, jsonb: null, createdAt: result[0]!.createdAt }]);
t.deepEqual(result, [{ id: 1, name: 'John', verified: true, jsonb: null, createdAt: result[0]!.createdAt, updatedAt: result[0]!.updatedAt }]);
});

test.serial('insert many', async (t) => {
Expand Down Expand Up @@ -430,13 +432,15 @@ test.serial('full join with alias', async (t) => {
verified: false,
jsonb: null,
createdAt: result[0]!.users.createdAt,
updatedAt: result[0]!.users.updatedAt
},
customer: {
id: 11,
name: 'Hans',
verified: false,
jsonb: null,
createdAt: result[0]!.customer!.createdAt,
updatedAt: result[0]!.users.updatedAt
},
}]);
});
Expand Down Expand Up @@ -858,6 +862,14 @@ test.serial('select from raw sql with mapped values', async (t) => {
]);
});

test.serial('select query with date works', async (t) => {
const { db } = t.context;
const [newUser] = await db.insert(usersTable).values({ name: 'John' }).returning()

const [result] = await db.select().from(usersTable).where(and(eq(usersTable.id, newUser!.id), gte(usersTable.updatedAt, newUser!.updatedAt)))
t.deepEqual(result, newUser)
})

test.after.always(async (t) => {
const ctx = t.context;
await ctx.db.execute(sql`drop table "users"`);
Expand Down