diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index f37886c93..8c11a2a72 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -15,7 +15,7 @@ import { } from '@zenstackhq/language/ast'; import { ExpressionContext, - getDataModelFieldReference, + getFieldReference, getFunctionExpressionContext, getLiteral, isDataModelFieldReference, @@ -96,7 +96,7 @@ export default class FunctionInvocationValidator implements AstValidator { // operand is a reference, it can only be a model/type-def field const ref = operand.target.ref; - if (isDataModelField(ref)) { + if (isDataModelField(ref) || isTypeDefField(ref)) { return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope); } return EMPTY_SCOPE; diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index d893d729f..cdc37cc81 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -30,9 +30,9 @@ import { ReferenceExpr, StringLiteral, } from '@zenstackhq/language/ast'; +import { getIdFields } from '@zenstackhq/sdk'; import { getPrismaVersion } from '@zenstackhq/sdk/prisma'; import { match } from 'ts-pattern'; -import { getIdFields } from '../../utils/ast-utils'; import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; import { diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 0cdc8d44a..46e6505fe 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -1,11 +1,18 @@ import { + ExpressionContext, PluginError, PluginGlobalOptions, PluginOptions, RUNTIME_PACKAGE, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, ensureEmptyDir, + getAttributeArg, + getAttributeArgLiteral, getDataModels, + getLiteralArray, hasAttribute, + isDataModelFieldReference, isDiscriminatorField, isEnumFieldReference, isForeignKeyField, @@ -15,7 +22,7 @@ import { resolvePath, saveSourceFile, } from '@zenstackhq/sdk'; -import { DataModel, EnumField, Model, TypeDef, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast'; +import { DataModel, EnumField, Model, TypeDef, isArrayExpr, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast'; import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers'; import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma'; import { streamAllContents } from 'langium'; @@ -26,7 +33,7 @@ import { name } from '.'; import { getDefaultOutputFolder } from '../plugin-utils'; import Transformer from './transformer'; import { ObjectMode } from './types'; -import { makeFieldSchema, makeValidationRefinements } from './utils/schema-gen'; +import { makeFieldSchema } from './utils/schema-gen'; export class ZodSchemaGenerator { private readonly sourceFiles: SourceFile[] = []; @@ -294,7 +301,7 @@ export class ZodSchemaGenerator { sf.replaceWithText((writer) => { this.addPreludeAndImports(typeDef, writer, output); - writer.write(`export const ${typeDef.name}Schema = z.object(`); + writer.write(`const baseSchema = z.object(`); writer.inlineBlock(() => { typeDef.fields.forEach((field) => { writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); @@ -313,9 +320,24 @@ export class ZodSchemaGenerator { writer.writeLine(').strict();'); break; } - }); - // TODO: "@@validate" refinements + // compile "@@validate" to a function calling zod's `.refine()` + const refineFuncName = this.createRefineFunction(typeDef, writer); + + if (refineFuncName) { + // export a schema without refinement for extensibility: `[Model]WithoutRefineSchema` + const noRefineSchema = `${upperCaseFirst(typeDef.name)}WithoutRefineSchema`; + writer.writeLine(` +/** + * \`${typeDef.name}\` schema prior to calling \`.refine()\` for extensibility. + */ +export const ${noRefineSchema} = baseSchema; +export const ${typeDef.name}Schema = ${refineFuncName}(${noRefineSchema}); +`); + } else { + writer.writeLine(`export const ${typeDef.name}Schema = baseSchema;`); + } + }); return schemaName; } @@ -436,22 +458,7 @@ export class ZodSchemaGenerator { } // compile "@@validate" to ".refine" - const refinements = makeValidationRefinements(model); - let refineFuncName: string | undefined; - if (refinements.length > 0) { - refineFuncName = `refine${upperCaseFirst(model.name)}`; - writer.writeLine( - ` -/** - * Schema refinement function for applying \`@@validate\` rules. - */ -export function ${refineFuncName}(schema: z.ZodType) { return schema${refinements.join( - '\n' - )}; -} -` - ); - } + const refineFuncName = this.createRefineFunction(model, writer); // delegate discriminator fields are to be excluded from mutation schemas const delegateDiscriminatorFields = model.fields.filter((field) => isDiscriminatorField(field)); @@ -658,6 +665,74 @@ export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema}; return schemaName; } + private createRefineFunction(decl: DataModel | TypeDef, writer: CodeBlockWriter) { + const refinements = this.makeValidationRefinements(decl); + let refineFuncName: string | undefined; + if (refinements.length > 0) { + refineFuncName = `refine${upperCaseFirst(decl.name)}`; + writer.writeLine( + ` + /** + * Schema refinement function for applying \`@@validate\` rules. + */ + export function ${refineFuncName}(schema: z.ZodType) { return schema${refinements.join( + '\n' + )}; + } + ` + ); + return refineFuncName; + } else { + return undefined; + } + } + + private makeValidationRefinements(decl: DataModel | TypeDef) { + const attrs = decl.attributes.filter((attr) => attr.decl.ref?.name === '@@validate'); + const refinements = attrs + .map((attr) => { + const valueArg = getAttributeArg(attr, 'value'); + if (!valueArg) { + return undefined; + } + + const messageArg = getAttributeArgLiteral(attr, 'message'); + const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : ''; + + const pathArg = getAttributeArg(attr, 'path'); + const path = + pathArg && isArrayExpr(pathArg) + ? `path: ['${getLiteralArray(pathArg)?.join(`', '`)}'],` + : ''; + + const options = `, { ${message} ${path} }`; + + try { + let expr = new TypeScriptExpressionTransformer({ + context: ExpressionContext.ValidationRule, + fieldReferenceContext: 'value', + }).transform(valueArg); + + if (isDataModelFieldReference(valueArg)) { + // if the expression is a simple field reference, treat undefined + // as true since the all fields are optional in validation context + expr = `${expr} ?? true`; + } + + return `.refine((value: any) => ${expr}${options})`; + } catch (err) { + if (err instanceof TypeScriptExpressionTransformerError) { + throw new PluginError(name, err.message); + } else { + throw err; + } + } + }) + .filter((r) => !!r); + + return refinements; + } + private makePartial(schema: string, fields?: string[]) { if (fields) { if (fields.length === 0) { diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index c130934b2..37466adbb 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -1,20 +1,7 @@ +import { getLiteral, isFromStdlib } from '@zenstackhq/sdk'; import { - ExpressionContext, - getAttributeArg, - getAttributeArgLiteral, - getLiteral, - getLiteralArray, - isDataModelFieldReference, - isFromStdlib, - PluginError, - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '@zenstackhq/sdk'; -import { - DataModel, DataModelField, DataModelFieldAttribute, - isArrayExpr, isBooleanLiteral, isDataModel, isEnum, @@ -25,7 +12,6 @@ import { TypeDefField, } from '@zenstackhq/sdk/ast'; import { upperCaseFirst } from 'upper-case-first'; -import { name } from '..'; import { isDefaultWithAuth } from '../../enhancer/enhancer-utils'; export function makeFieldSchema(field: DataModelField | TypeDefField) { @@ -222,50 +208,6 @@ function makeZodSchema(field: DataModelField | TypeDefField) { return schema; } -export function makeValidationRefinements(model: DataModel) { - const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === '@@validate'); - const refinements = attrs - .map((attr) => { - const valueArg = getAttributeArg(attr, 'value'); - if (!valueArg) { - return undefined; - } - - const messageArg = getAttributeArgLiteral(attr, 'message'); - const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : ''; - - const pathArg = getAttributeArg(attr, 'path'); - const path = - pathArg && isArrayExpr(pathArg) ? `path: ['${getLiteralArray(pathArg)?.join(`', '`)}'],` : ''; - - const options = `, { ${message} ${path} }`; - - try { - let expr = new TypeScriptExpressionTransformer({ - context: ExpressionContext.ValidationRule, - fieldReferenceContext: 'value', - }).transform(valueArg); - - if (isDataModelFieldReference(valueArg)) { - // if the expression is a simple field reference, treat undefined - // as true since the all fields are optional in validation context - expr = `${expr} ?? true`; - } - - return `.refine((value: any) => ${expr}${options})`; - } catch (err) { - if (err instanceof TypeScriptExpressionTransformerError) { - throw new PluginError(name, err.message); - } else { - throw err; - } - } - }) - .filter((r) => !!r); - - return refinements; -} - function getAttrLiteralArg(attr: DataModelFieldAttribute, paramName: string) { const arg = attr.args.find((arg) => arg.$resolvedParam?.name === paramName); return arg && getLiteral(arg.value); diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 73e17d2e0..a6fab7ea5 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -2,30 +2,19 @@ import { BinaryExpr, DataModel, DataModelAttribute, - DataModelField, Expression, InheritableNode, - isArrayExpr, isBinaryExpr, isDataModel, isDataModelField, isInvocationExpr, - isMemberAccessExpr, isModel, - isReferenceExpr, isTypeDef, Model, ModelImport, - ReferenceExpr, TypeDef, } from '@zenstackhq/language/ast'; -import { - getInheritanceChain, - getModelFieldsWithBases, - getRecursiveBases, - isDelegateModel, - isFromStdlib, -} from '@zenstackhq/sdk'; +import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, copyAstNode, @@ -151,29 +140,6 @@ function cloneAst( return clone; } -export function getIdFields(dataModel: DataModel) { - const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => - f.attributes.some((attr) => attr.decl.$refText === '@id') - ); - if (fieldLevelId) { - return [fieldLevelId]; - } else { - // get model level @@id attribute - const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); - if (modelIdAttr) { - // get fields referenced in the attribute: @@id([field1, field2]]) - if (!isArrayExpr(modelIdAttr.args[0]?.value)) { - return []; - } - const argValue = modelIdAttr.args[0].value; - return argValue.items - .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) - .map((expr) => expr.target.ref as DataModelField); - } - } - return []; -} - export function isAuthInvocation(node: AstNode) { return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); } @@ -186,16 +152,6 @@ export function isCheckInvocation(node: AstNode) { return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref); } -export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { - if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { - return expr.target.ref; - } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { - return expr.member.ref; - } else { - return undefined; - } -} - export function resolveImportUri(imp: ModelImport): URI | undefined { if (!imp.path) return undefined; // This will return true if imp.path is undefined, null, or an empty string (""). diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 5b34a0e0c..46c2a82c1 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -28,6 +28,7 @@ import { isObjectExpr, isReferenceExpr, isTypeDef, + isTypeDefField, Model, Reference, ReferenceExpr, @@ -518,17 +519,17 @@ export function getIdFields(decl: DataModel | TypeDef) { } const argValue = modelIdAttr.args[0].value; return argValue.items - .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) + .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getFieldReference(expr)) .map((expr) => expr.target.ref as DataModelField); } } return []; } -export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { - if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { +export function getFieldReference(expr: Expression): DataModelField | TypeDefField | undefined { + if (isReferenceExpr(expr) && (isDataModelField(expr.target.ref) || isTypeDefField(expr.target.ref))) { return expr.target.ref; - } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { + } else if (isMemberAccessExpr(expr) && (isDataModelField(expr.member.ref) || isTypeDefField(expr.member.ref))) { return expr.member.ref; } else { return undefined; diff --git a/tests/integration/tests/enhancements/json/crud.test.ts b/tests/integration/tests/enhancements/json/crud.test.ts index be1f218d0..12c35ed09 100644 --- a/tests/integration/tests/enhancements/json/crud.test.ts +++ b/tests/integration/tests/enhancements/json/crud.test.ts @@ -191,6 +191,46 @@ describe('Json field CRUD', () => { ).toResolveTruthy(); }); + it('respects refine validation rules', async () => { + const params = await loadSchema( + ` + type Address { + city String @length(2, 10) + } + + type Profile { + age Int @gte(18) + address Address? + @@validate(age > 18 && length(address.city, 2, 2)) + } + + model User { + id Int @id @default(autoincrement()) + profile Profile @json + @@allow('all', true) + } + `, + { + provider: 'postgresql', + dbUrl, + } + ); + + prisma = params.prisma; + const schema = params.zodSchemas.models.ProfileSchema; + + expect(schema.safeParse({ age: 10, address: { city: 'NY' } })).toMatchObject({ success: false }); + expect(schema.safeParse({ age: 20, address: { city: 'NYC' } })).toMatchObject({ success: false }); + expect(schema.safeParse({ age: 20, address: { city: 'NY' } })).toMatchObject({ success: true }); + + const db = params.enhance(); + await expect(db.user.create({ data: { profile: { age: 10 } } })).toBeRejectedByPolicy(); + await expect( + db.user.create({ data: { profile: { age: 20, address: { city: 'NYC' } } } }) + ).toBeRejectedByPolicy(); + await expect(db.user.create({ data: { profile: { age: 20, address: { city: 'NY' } } } })).toResolveTruthy(); + }); + it('respects enums used by data models', async () => { const params = await loadSchema( `