From c2d66ac39a0f8091250e94414722ea073781a325 Mon Sep 17 00:00:00 2001
From: Vision Onyeaku
Date: Tue, 27 Feb 2024 19:34:32 +0100
Subject: [PATCH] Automatically cast timestamp strings as timestamp in queries

When making select queries, if the where clause compares against a timestamp
column declared with mode: 'string', the query fails with:

    ERROR: operator does not exist: timestamp without time zone >= text
    Hint: No operator matches the given name and argument types.
    You might need to add explicit type casts.

This patch fixes the error by explicitly casting the bound string parameter
as timestamp.
---
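Reviewer note, not part of the commit message: a minimal sketch of the query
shape this patch targets. The table, the helper function, and the `since`
parameter below are illustrative assumptions, not code taken from this patch;
only the mode: 'string' timestamp column and the gte() comparison mirror what
the new integration test exercises.

    import { and, eq, gte } from 'drizzle-orm';
    import type { AwsDataApiPgDatabase } from 'drizzle-orm/aws-data-api/pg';
    import { pgTable, serial, timestamp } from 'drizzle-orm/pg-core';

    // Illustrative table: `updatedAt` uses mode: 'string', so its values are bound as text.
    const users = pgTable('users', {
        id: serial('id').primaryKey(),
        updatedAt: timestamp('updated_at', { withTimezone: true, mode: 'string' }).notNull().defaultNow(),
    });

    // `db` is assumed to be an existing AwsDataApiPgDatabase instance.
    async function findUserUpdatedSince(db: AwsDataApiPgDatabase, since: string) {
        // Before this patch the comparison below was rejected with
        // "operator does not exist: ... >= text"; with the patch the dialect
        // wraps the bound parameter in cast(... as timestamp).
        return db
            .select()
            .from(users)
            .where(and(eq(users.id, 1), gte(users.updatedAt, since)));
    }

Because mode: 'string' sends the bound value as plain text, Postgres needs the
explicit cast that the overridden buildSelectQuery now injects around
PgTimestampString parameters.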

 drizzle-orm/src/aws-data-api/pg/driver.ts | 25 ++++++++++++++++++-
 integration-tests/tests/awsdatapi.test.ts | 30 ++++++++++++++++-------
 2 files changed, 45 insertions(+), 10 deletions(-)

diff --git a/drizzle-orm/src/aws-data-api/pg/driver.ts b/drizzle-orm/src/aws-data-api/pg/driver.ts
index e5acbe939..f90902c67 100644
--- a/drizzle-orm/src/aws-data-api/pg/driver.ts
+++ b/drizzle-orm/src/aws-data-api/pg/driver.ts
@@ -1,4 +1,4 @@
-import { entityKind } from '~/entity.ts';
+import { entityKind, is } from '~/entity.ts';
 import type { Logger } from '~/logger.ts';
 import { DefaultLogger } from '~/logger.ts';
 import { PgDatabase } from '~/pg-core/db.ts';
@@ -12,6 +12,8 @@ import {
 import type { DrizzleConfig } from '~/utils.ts';
 import type { AwsDataApiClient, AwsDataApiPgQueryResultHKT } from './session.ts';
 import { AwsDataApiSession } from './session.ts';
+import { type PgSelectConfig, PgTimestampString } from '~/pg-core/index.ts';
+import { Param, SQL, sql } from '~/index.ts';

 export interface PgDriverOptions {
 	logger?: Logger;
@@ -38,6 +40,27 @@ export class AwsPgDialect extends PgDialect {
 	override escapeParam(num: number): string {
 		return `:${num + 1}`;
 	}
+
+	override buildSelectQuery(config: PgSelectConfig): SQL {
+		if (config.where) {
+			config.where = this.castTimestampStringParamAsTimestamp(config.where);
+		}
+
+		return super.buildSelectQuery(config);
+	}
+
+	castTimestampStringParamAsTimestamp(existingSql: SQL): SQL {
+		return sql.join(existingSql.queryChunks.map((chunk) => {
+			if (is(chunk, Param) && is(chunk.encoder, PgTimestampString)) {
+				return sql`cast(${chunk.value} as timestamp)`;
+			}
+			if (is(chunk, SQL)) {
+				return this.castTimestampStringParamAsTimestamp(chunk);
+			}
+
+			return chunk;
+		}));
+	}
 }

 export function drizzle<TSchema extends Record<string, unknown> = Record<string, never>>(
diff --git a/integration-tests/tests/awsdatapi.test.ts b/integration-tests/tests/awsdatapi.test.ts
index 1f390eb70..1f4679b5a 100644
--- a/integration-tests/tests/awsdatapi.test.ts
+++ b/integration-tests/tests/awsdatapi.test.ts
@@ -5,7 +5,7 @@ import { fromIni } from '@aws-sdk/credential-providers';
 import type { TestFn } from 'ava';
 import anyTest from 'ava';
 import * as dotenv from 'dotenv';
-import { asc, eq, name, placeholder, sql, TransactionRollbackError } from 'drizzle-orm';
+import { and, asc, eq, gte, name, placeholder, sql, TransactionRollbackError } from 'drizzle-orm';
 import type { AwsDataApiPgDatabase } from 'drizzle-orm/aws-data-api/pg';
 import { drizzle } from 'drizzle-orm/aws-data-api/pg';
 import { migrate } from 'drizzle-orm/aws-data-api/pg/migrator';
@@ -18,6 +18,7 @@ const usersTable = pgTable('users', {
 	name: text('name').notNull(),
 	verified: boolean('verified').notNull().default(false),
 	jsonb: jsonb('jsonb').$type<string[]>(),
+	updatedAt: timestamp('updated_at', { withTimezone: true, mode: 'string' }).notNull().defaultNow(),
 	createdAt: timestamp('created_at', { withTimezone: true }).notNull().defaultNow(),
 });

@@ -62,7 +63,8 @@ test.beforeEach(async (t) => {
 				name text not null,
 				verified boolean not null default false,
 				jsonb jsonb,
-				created_at timestamptz not null default now()
+				created_at timestamptz not null default now(),
+				updated_at timestamptz not null default now()
 			)
 		`,
 	);
@@ -79,7 +81,7 @@ test.serial('select all fields', async (t) => {

 	t.assert(result[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
 	// t.assert(Math.abs(result[0]!.createdAt.getTime() - now) < 100);
-	t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);
+	t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt, updatedAt: result[0]!.updatedAt }]);
 });

 test.serial('select sql', async (t) => {
@@ -185,7 +187,7 @@ test.serial('update with returning all fields', async (t) => {

 	t.assert(users[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
 	// t.assert(Math.abs(users[0]!.createdAt.getTime() - now) < 100);
-	t.deepEqual(users, [{ id: 1, name: 'Jane', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
+	t.deepEqual(users, [{ id: 1, name: 'Jane', verified: false, jsonb: null, createdAt: users[0]!.createdAt, updatedAt: users[0]!.updatedAt }]);
 });

 test.serial('update with returning partial', async (t) => {
@@ -208,7 +210,7 @@ test.serial('delete with returning all fields', async (t) => {

 	t.assert(users[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
 	// t.assert(Math.abs(users[0]!.createdAt.getTime() - now) < 100);
-	t.deepEqual(users, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
+	t.deepEqual(users, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: users[0]!.createdAt, updatedAt: users[0]!.updatedAt }]);
 });

 test.serial('delete with returning partial', async (t) => {
@@ -228,13 +230,13 @@ test.serial('insert + select', async (t) => {

 	await db.insert(usersTable).values({ name: 'John' });
 	const result = await db.select().from(usersTable);
-	t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);
+	t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt, updatedAt: result[0]!.updatedAt }]);

 	await db.insert(usersTable).values({ name: 'Jane' });
 	const result2 = await db.select().from(usersTable);
 	t.deepEqual(result2, [
-		{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result2[0]!.createdAt },
-		{ id: 2, name: 'Jane', verified: false, jsonb: null, createdAt: result2[1]!.createdAt },
+		{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result2[0]!.createdAt, updatedAt: result2[0]!.updatedAt },
+		{ id: 2, name: 'Jane', verified: false, jsonb: null, createdAt: result2[1]!.createdAt, updatedAt: result2[1]!.updatedAt },
 	]);
 });

@@ -257,7 +259,7 @@ test.serial('insert with overridden default values', async (t) => {
 	await db.insert(usersTable).values({ name: 'John', verified: true });
 	const result = await db.select().from(usersTable);

-	t.deepEqual(result, [{ id: 1, name: 'John', verified: true, jsonb: null, createdAt: result[0]!.createdAt }]);
+	t.deepEqual(result, [{ id: 1, name: 'John', verified: true, jsonb: null, createdAt: result[0]!.createdAt, updatedAt: result[0]!.updatedAt }]);
 });

 test.serial('insert many', async (t) => {
@@ -430,6 +432,7 @@ test.serial('full join with alias', async (t) => {
 			verified: false,
 			jsonb: null,
 			createdAt: result[0]!.users.createdAt,
+			updatedAt: result[0]!.users.updatedAt,
 		},
 		customer: {
 			id: 11,
@@ -437,6 +440,7 @@
 			verified: false,
 			jsonb: null,
 			createdAt: result[0]!.customer!.createdAt,
+			updatedAt: result[0]!.customer!.updatedAt,
 		},
 	}]);
 });
@@ -858,6 +862,14 @@ test.serial('select from raw sql with mapped values', async (t) => {
 	]);
 });

+test.serial('select query with date works', async (t) => {
+	const { db } = t.context;
+	const [newUser] = await db.insert(usersTable).values({ name: 'John' }).returning();
+
+	const [result] = await db.select().from(usersTable).where(and(eq(usersTable.id, newUser!.id), gte(usersTable.updatedAt, newUser!.updatedAt)));
+	t.deepEqual(result, newUser);
+});
+
 test.after.always(async (t) => {
 	const ctx = t.context;
 	await ctx.db.execute(sql`drop table "users"`);