Skip to content

Commit

Permalink
Merge pull request #1 from livingforjesus/fix-arrayvalue-issue
Browse files Browse the repository at this point in the history
Fix issue with insert/update array in aws-data-api
  • Loading branch information
livingforjesus authored Feb 21, 2024
2 parents 0da1cba + a175ed5 commit d2d6ce4
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 9 deletions.
13 changes: 13 additions & 0 deletions drizzle-orm/src/aws-data-api/common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ export function getValueFromDataApi(field: Field) {
if (field.arrayValue.stringValues !== undefined) {
return field.arrayValue.stringValues;
}
if (field.arrayValue.longValues !== undefined) {
return field.arrayValue.longValues;
}
if (field.arrayValue.doubleValues !== undefined) {
return field.arrayValue.doubleValues;
}
if (field.arrayValue.booleanValues !== undefined) {
return field.arrayValue.booleanValues;
}
if (field.arrayValue.arrayValues !== undefined) {
return field.arrayValue.arrayValues
}

throw new Error('Unknown array type');
} else {
throw new Error('Unknown type');
Expand Down
34 changes: 32 additions & 2 deletions drizzle-orm/src/aws-data-api/pg/driver.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { entityKind } from '~/entity.ts';
import { entityKind, is } from '~/entity.ts';
import type { Logger } from '~/logger.ts';
import { DefaultLogger } from '~/logger.ts';
import { PgDatabase } from '~/pg-core/db.ts';
Expand All @@ -9,9 +9,11 @@ import {
type RelationalSchemaConfig,
type TablesRelationalConfig,
} from '~/relations.ts';
import type { DrizzleConfig } from '~/utils.ts';
import type { DrizzleConfig, UpdateSet } from '~/utils.ts';
import type { AwsDataApiClient, AwsDataApiPgQueryResultHKT } from './session.ts';
import { AwsDataApiSession } from './session.ts';
import { PgArray, PgColumn, PgInsertConfig, PgTable, TableConfig } from '~/pg-core/index.ts';
import { Param, SQL, Table, sql } from '~/index.ts';

export interface PgDriverOptions {
logger?: Logger;
Expand All @@ -38,6 +40,34 @@ export class AwsPgDialect extends PgDialect {
override escapeParam(num: number): string {
return `:${num + 1}`;
}

override buildInsertQuery({ table, values, onConflict, returning }: PgInsertConfig<PgTable<TableConfig>>): SQL<unknown> {
const columns: Record<string, PgColumn> = table[Table.Symbol.Columns];
const colEntries: [string, PgColumn][] = Object.entries(columns);
for (let value of values) {
for (const [fieldName, col] of colEntries) {
const colValue = value[fieldName];
if (is(colValue, Param) && colValue.value !== undefined && is(colValue.encoder, PgArray) && Array.isArray(colValue.value)) {
value[fieldName] = sql`cast(${col.mapToDriverValue(colValue.value)} as ${sql.raw(colValue.encoder.getSQLType())})`
}
}
}

return super.buildInsertQuery({table, values, onConflict, returning})
}

override buildUpdateSet(table: PgTable<TableConfig>, set: UpdateSet): SQL<unknown> {
const columns: Record<string, PgColumn> = table[Table.Symbol.Columns];

Object.entries(set)
.forEach(([colName, colValue]) => {
const currentColumn = columns[colName];
if (currentColumn && is(colValue, Param) && colValue.value !== undefined && is(colValue.encoder, PgArray) && Array.isArray(colValue.value)) {
set[colName] = sql`cast(${currentColumn?.mapToDriverValue(colValue.value)} as ${sql.raw(colValue.encoder.getSQLType())})`
}
})
return super.buildUpdateSet(table, set)
}
}

export function drizzle<TSchema extends Record<string, unknown> = Record<string, never>>(
Expand Down
42 changes: 35 additions & 7 deletions integration-tests/tests/awsdatapi.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const usersTable = pgTable('users', {
name: text('name').notNull(),
verified: boolean('verified').notNull().default(false),
jsonb: jsonb('jsonb').$type<string[]>(),
bestTexts: text('best_texts').array().default(sql`'{}'`).notNull(),
createdAt: timestamp('created_at', { withTimezone: true }).notNull().defaultNow(),
});

Expand Down Expand Up @@ -62,6 +63,7 @@ test.beforeEach(async (t) => {
name text not null,
verified boolean not null default false,
jsonb jsonb,
best_texts text[] not null default '{}',
created_at timestamptz not null default now()
)
`,
Expand All @@ -79,7 +81,7 @@ test.serial('select all fields', async (t) => {

t.assert(result[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
// t.assert(Math.abs(result[0]!.createdAt.getTime() - now) < 100);
t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);
t.deepEqual(result, [{ bestTexts: [], id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);
});

test.serial('select sql', async (t) => {
Expand Down Expand Up @@ -185,7 +187,7 @@ test.serial('update with returning all fields', async (t) => {

t.assert(users[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
// t.assert(Math.abs(users[0]!.createdAt.getTime() - now) < 100);
t.deepEqual(users, [{ id: 1, name: 'Jane', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
t.deepEqual(users, [{ id: 1, bestTexts: [], name: 'Jane', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
});

test.serial('update with returning partial', async (t) => {
Expand All @@ -208,7 +210,7 @@ test.serial('delete with returning all fields', async (t) => {

t.assert(users[0]!.createdAt instanceof Date); // eslint-disable-line no-instanceof/no-instanceof
// t.assert(Math.abs(users[0]!.createdAt.getTime() - now) < 100);
t.deepEqual(users, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
t.deepEqual(users, [{ bestTexts: [], id: 1, name: 'John', verified: false, jsonb: null, createdAt: users[0]!.createdAt }]);
});

test.serial('delete with returning partial', async (t) => {
Expand All @@ -228,13 +230,13 @@ test.serial('insert + select', async (t) => {

await db.insert(usersTable).values({ name: 'John' });
const result = await db.select().from(usersTable);
t.deepEqual(result, [{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);
t.deepEqual(result, [{ bestTexts: [], id: 1, name: 'John', verified: false, jsonb: null, createdAt: result[0]!.createdAt }]);

await db.insert(usersTable).values({ name: 'Jane' });
const result2 = await db.select().from(usersTable);
t.deepEqual(result2, [
{ id: 1, name: 'John', verified: false, jsonb: null, createdAt: result2[0]!.createdAt },
{ id: 2, name: 'Jane', verified: false, jsonb: null, createdAt: result2[1]!.createdAt },
{ bestTexts: [], id: 1, name: 'John', verified: false, jsonb: null, createdAt: result2[0]!.createdAt },
{ bestTexts: [], id: 2, name: 'Jane', verified: false, jsonb: null, createdAt: result2[1]!.createdAt },
]);
});

Expand All @@ -257,7 +259,7 @@ test.serial('insert with overridden default values', async (t) => {
await db.insert(usersTable).values({ name: 'John', verified: true });
const result = await db.select().from(usersTable);

t.deepEqual(result, [{ id: 1, name: 'John', verified: true, jsonb: null, createdAt: result[0]!.createdAt }]);
t.deepEqual(result, [{ bestTexts: [], id: 1, name: 'John', verified: true, jsonb: null, createdAt: result[0]!.createdAt }]);
});

test.serial('insert many', async (t) => {
Expand Down Expand Up @@ -426,12 +428,14 @@ test.serial('full join with alias', async (t) => {
t.deepEqual(result, [{
users: {
id: 10,
bestTexts: [],
name: 'Ivan',
verified: false,
jsonb: null,
createdAt: result[0]!.users.createdAt,
},
customer: {
bestTexts: [],
id: 11,
name: 'Hans',
verified: false,
Expand Down Expand Up @@ -858,6 +862,30 @@ test.serial('select from raw sql with mapped values', async (t) => {
]);
});

test.serial('insert with array values works', async (t) => {
const { db } = t.context;

const bestTexts = ['text1', 'text2', 'text3']
const [insertResult] = await db.insert(usersTable).values({
name: 'John',
bestTexts
}).returning()

t.deepEqual(insertResult?.bestTexts , bestTexts);
});

test.serial('update with array values works', async (t) => {
const { db } = t.context;
const [newUser] = await db.insert(usersTable).values({ name: 'John' }).returning()

const bestTexts = ['text4', 'text5', 'text6']
const [insertResult] = await db.update(usersTable).set({
bestTexts
}).where(eq(usersTable.id, newUser!.id)).returning()

t.deepEqual(insertResult?.bestTexts , bestTexts);
});

test.after.always(async (t) => {
const ctx = t.context;
await ctx.db.execute(sql`drop table "users"`);
Expand Down

0 comments on commit d2d6ce4

Please sign in to comment.