From ee0e4003ff7e48ba2f3ce0364b7bf5c6e4d8e30b Mon Sep 17 00:00:00 2001 From: Mario564 Date: Fri, 15 Nov 2024 10:53:49 -0800 Subject: [PATCH] Add `createUpdateSchema` in `drizzle-zod` + Additional tests --- drizzle-zod/src/create-schema.ts | 30 +++++++-- drizzle-zod/src/types.ts | 47 ++++++++++---- drizzle-zod/tests/pg.test.ts | 101 ++++++++++++++++++++++++++++++- 3 files changed, 160 insertions(+), 18 deletions(-) diff --git a/drizzle-zod/src/create-schema.ts b/drizzle-zod/src/create-schema.ts index 1df825470..b6e4fd778 100644 --- a/drizzle-zod/src/create-schema.ts +++ b/drizzle-zod/src/create-schema.ts @@ -1,9 +1,9 @@ -import { z } from 'zod'; +import { optional, z } from 'zod'; import { Column, getTableColumns, getViewSelectedFields, is, isTable, isView, SQL } from 'drizzle-orm'; import { columnToSchema } from './column'; import { isPgEnum, PgEnum } from 'drizzle-orm/pg-core'; import type { Table, View } from 'drizzle-orm'; -import type { CreateInsertSchema, CreateSchemaFactoryOptions, CreateSelectSchema } from './types'; +import type { CreateInsertSchema, CreateSchemaFactoryOptions, CreateSelectSchema, CreateUpdateSchema } from './types'; function getColumns(tableLike: Table | View) { return isTable(tableLike) ? getTableColumns(tableLike) : getViewSelectedFields(tableLike); @@ -73,7 +73,13 @@ const insertConditions = { never: (column?: Column) => column?.generated?.type === 'always' || column?.generatedIdentity?.type === 'always', optional: (column: Column) => !column.notNull || (column.notNull && column.hasDefault), nullable: (column: Column) => !column.notNull -} +}; + +const updateConditions = { + never: (column?: Column) => column?.generated?.type === 'always' || column?.generatedIdentity?.type === 'always', + optional: () => true, + nullable: (column: Column) => !column.notNull +}; export const createSelectSchema: CreateSelectSchema = ( entity: Table | View | PgEnum<[string, ...string[]]>, @@ -94,6 +100,14 @@ export const createInsertSchema: CreateInsertSchema = ( return handleColumns(columns, refine ?? {}, insertConditions) as any; } +export const createUpdateSchema: CreateUpdateSchema = ( + entity: Table, + refine?: Record +) => { + const columns = getColumns(entity); + return handleColumns(columns, refine ?? {}, updateConditions) as any; +} + export function createSchemaFactory(options?: CreateSchemaFactoryOptions) { const createSelectSchema: CreateSelectSchema = ( entity: Table | View | PgEnum<[string, ...string[]]>, @@ -114,5 +128,13 @@ export function createSchemaFactory(options?: CreateSchemaFactoryOptions) { return handleColumns(columns, refine ?? {}, insertConditions, options) as any; } - return { createSelectSchema, createInsertSchema }; + const createUpdateSchema: CreateUpdateSchema = ( + entity: Table, + refine?: Record + ) => { + const columns = getColumns(entity); + return handleColumns(columns, refine ?? {}, updateConditions, options) as any; + } + + return { createSelectSchema, createInsertSchema, createUpdateSchema }; } diff --git a/drizzle-zod/src/types.ts b/drizzle-zod/src/types.ts index 3f11116e7..b4b820528 100644 --- a/drizzle-zod/src/types.ts +++ b/drizzle-zod/src/types.ts @@ -54,22 +54,32 @@ export type GetZodType< export type BuildRefineColumns< TColumns extends Record -> = Simplify<{ +> = Simplify extends infer TSchema extends z.ZodTypeAny - ? TSchema - : z.ZodAny - : TColumns[K] extends infer TObject extends SelectedFieldsFlat - ? BuildRefineColumns + ? ColumnIsGeneratedAlwaysAs extends true + ? never + : GetZodType< + TColumn['_']['data'], + TColumn['_']['dataType'], + TColumn['_'] extends { enumValues: [string, ...string[]] } ? TColumn['_']['enumValues'] : undefined + > extends infer TSchema extends z.ZodTypeAny + ? TSchema + : z.ZodAny + : TColumns[K] extends infer TObject extends SelectedFieldsFlat | Table | View + ? BuildRefineColumns< + TObject extends Table + ? TObject['_']['columns'] + : TObject extends View + ? TObject['_']['selectedFields'] + : TObject + > : TColumns[K] -}>; +}>>; -export type BuildRefine> = BuildRefineColumns extends infer TBuildColumns +export type BuildRefine< + TColumns extends Record +> = BuildRefineColumns extends infer TBuildColumns ? { [K in keyof TBuildColumns]?: TBuildColumns[K] extends z.ZodTypeAny @@ -193,13 +203,24 @@ export interface CreateInsertSchema { (table: TTable): BuildSchema<'insert', TTable['_']['columns'], undefined>; < TTable extends Table, - TRefine extends BuildRefine + TRefine extends BuildRefine> >( table: TTable, refine?: TRefine ): BuildSchema<'insert', TTable['_']['columns'], TRefine>; } +export interface CreateUpdateSchema { + (table: TTable): BuildSchema<'update', TTable['_']['columns'], undefined>; + < + TTable extends Table, + TRefine extends BuildRefine> + >( + table: TTable, + refine?: TRefine + ): BuildSchema<'update', TTable['_']['columns'], TRefine>; +} + export interface CreateSchemaFactoryOptions { zodInstance?: any; } diff --git a/drizzle-zod/tests/pg.test.ts b/drizzle-zod/tests/pg.test.ts index 1b34f75c6..75d657dfd 100644 --- a/drizzle-zod/tests/pg.test.ts +++ b/drizzle-zod/tests/pg.test.ts @@ -1,7 +1,7 @@ import { char, date, getViewConfig, integer, pgEnum, pgMaterializedView, pgTable, pgView, serial, text, timestamp, varchar } from 'drizzle-orm/pg-core'; import { test } from 'vitest'; import { z } from 'zod'; -import { createInsertSchema, createSelectSchema } from '../src'; +import { createInsertSchema, createSelectSchema, createUpdateSchema } from '../src'; import { expectEnumValues, expectSchemaShape } from './utils.ts'; import { sql } from 'drizzle-orm'; @@ -28,6 +28,21 @@ test('table - insert', (t) => { expectSchemaShape(t, expected).from(result); }); +test('table - update', (t) => { + const table = pgTable('test', { + id: integer('id').generatedAlwaysAsIdentity().primaryKey(), + name: text('name').notNull(), + age: integer('age') + }); + + const result = createUpdateSchema(table); + const expected = z.object({ + name: z.string().optional(), + age: z.number().nullable().optional() + }); + expectSchemaShape(t, expected).from(result); +}); + test('view qb - select', (t) => { const table = pgTable('test', { id: serial('id').primaryKey(), @@ -123,6 +138,50 @@ test('nullability - select', (t) => { expectSchemaShape(t, expected).from(result); }); +test('nullability - insert', (t) => { + const table = pgTable('test', { + c1: integer(), + c2: integer().notNull(), + c3: integer().default(1), + c4: integer().notNull().default(1), + c5: integer().generatedAlwaysAs(1), + c6: integer().generatedAlwaysAsIdentity(), + c7: integer().generatedByDefaultAsIdentity(), + }); + + const result = createInsertSchema(table); + const expected = z.object({ + c1: z.number().int().nullable().optional(), + c2: z.number().int(), + c3: z.number().int().nullable().optional(), + c4: z.number().int().optional(), + c7: z.number().int().optional(), + }); + expectSchemaShape(t, expected).from(result); +}); + +test('nullability - update', (t) => { + const table = pgTable('test', { + c1: integer(), + c2: integer().notNull(), + c3: integer().default(1), + c4: integer().notNull().default(1), + c5: integer().generatedAlwaysAs(1), + c6: integer().generatedAlwaysAsIdentity(), + c7: integer().generatedByDefaultAsIdentity(), + }); + + const result = createUpdateSchema(table); + const expected = z.object({ + c1: z.number().int().nullable().optional(), + c2: z.number().int().optional(), + c3: z.number().int().nullable().optional(), + c4: z.number().int().optional(), + c7: z.number().int().optional(), + }); + expectSchemaShape(t, expected).from(result); +}); + test('refine table - select', (t) => { const table = pgTable('test', { c1: integer(), @@ -142,6 +201,46 @@ test('refine table - select', (t) => { expectSchemaShape(t, expected).from(result); }); +test('refine table - insert', (t) => { + const table = pgTable('test', { + c1: integer(), + c2: integer().notNull(), + c3: integer().notNull(), + c4: integer().generatedAlwaysAs(1) + }); + + const result = createInsertSchema(table, { + c2: (schema) => schema.max(1000), + c3: z.string().transform((v) => Number(v)) + }); + const expected = z.object({ + c1: z.number().int().nullable().optional(), + c2: z.number().int().max(1000), + c3: z.string().transform((v) => Number(v)) + }); + expectSchemaShape(t, expected).from(result); +}); + +test('refine table - update', (t) => { + const table = pgTable('test', { + c1: integer(), + c2: integer().notNull(), + c3: integer().notNull(), + c4: integer().generatedAlwaysAs(1) + }); + + const result = createUpdateSchema(table, { + c2: (schema) => schema.max(1000), + c3: z.string().transform((v) => Number(v)), + }); + const expected = z.object({ + c1: z.number().int().nullable().optional(), + c2: z.number().int().max(1000).optional(), + c3: z.string().transform((v) => Number(v)), + }); + expectSchemaShape(t, expected).from(result); +}); + test('refine view - select', (t) => { const table = pgTable('test', { c1: integer(),