Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support @@validate in type declarations #1868

Merged
merged 2 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import {
} from '@zenstackhq/language/ast';
import {
ExpressionContext,
getDataModelFieldReference,
getFieldReference,
getFunctionExpressionContext,
getLiteral,
isDataModelFieldReference,
Expand Down Expand Up @@ -96,7 +96,7 @@ export default class FunctionInvocationValidator implements AstValidator<Express
// first argument must refer to a model field
const firstArg = expr.args?.[0]?.value;
if (firstArg) {
if (!getDataModelFieldReference(firstArg)) {
if (!getFieldReference(firstArg)) {
accept('error', 'first argument must be a field reference', { node: firstArg });
}
}
Expand Down
5 changes: 3 additions & 2 deletions packages/schema/src/language-server/zmodel-scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,14 @@ export class ZModelScopeProvider extends DefaultScopeProvider {
const node = context.container as MemberAccessExpr;

// typedef's fields are only added to the scope if the access starts with `auth().`
const allowTypeDefScope = isAuthOrAuthMemberAccess(node.operand);
// or the member access resides inside a typedef
const allowTypeDefScope = isAuthOrAuthMemberAccess(node.operand) || !!getContainerOfType(node, isTypeDef);

return match(node.operand)
.when(isReferenceExpr, (operand) => {
// operand is a reference, it can only be a model/type-def field
const ref = operand.target.ref;
if (isDataModelField(ref)) {
if (isDataModelField(ref) || isTypeDefField(ref)) {
return this.createScopeForContainer(ref.type.reference?.ref, globalScope, allowTypeDefScope);
}
return EMPTY_SCOPE;
Expand Down
2 changes: 1 addition & 1 deletion packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ import {
ReferenceExpr,
StringLiteral,
} from '@zenstackhq/language/ast';
import { getIdFields } from '@zenstackhq/sdk';
import { getPrismaVersion } from '@zenstackhq/sdk/prisma';
import { match } from 'ts-pattern';
import { getIdFields } from '../../utils/ast-utils';

import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime';
import {
Expand Down
117 changes: 96 additions & 21 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
import {
ExpressionContext,
PluginError,
PluginGlobalOptions,
PluginOptions,
RUNTIME_PACKAGE,
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
ensureEmptyDir,
getAttributeArg,
getAttributeArgLiteral,
getDataModels,
getLiteralArray,
hasAttribute,
isDataModelFieldReference,
isDiscriminatorField,
isEnumFieldReference,
isForeignKeyField,
Expand All @@ -15,7 +22,7 @@
resolvePath,
saveSourceFile,
} from '@zenstackhq/sdk';
import { DataModel, EnumField, Model, TypeDef, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
import { DataModel, EnumField, Model, TypeDef, isArrayExpr, isDataModel, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import { streamAllContents } from 'langium';
Expand All @@ -26,7 +33,7 @@
import { getDefaultOutputFolder } from '../plugin-utils';
import Transformer from './transformer';
import { ObjectMode } from './types';
import { makeFieldSchema, makeValidationRefinements } from './utils/schema-gen';
import { makeFieldSchema } from './utils/schema-gen';

export class ZodSchemaGenerator {
private readonly sourceFiles: SourceFile[] = [];
Expand Down Expand Up @@ -294,7 +301,7 @@
sf.replaceWithText((writer) => {
this.addPreludeAndImports(typeDef, writer, output);

writer.write(`export const ${typeDef.name}Schema = z.object(`);
writer.write(`const baseSchema = z.object(`);
writer.inlineBlock(() => {
typeDef.fields.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
Expand All @@ -313,9 +320,24 @@
writer.writeLine(').strict();');
break;
}
});

// TODO: "@@validate" refinements
// compile "@@validate" to a function calling zod's `.refine()`
const refineFuncName = this.createRefineFunction(typeDef, writer);

if (refineFuncName) {
// export a schema without refinement for extensibility: `[Model]WithoutRefineSchema`
const noRefineSchema = `${upperCaseFirst(typeDef.name)}WithoutRefineSchema`;
writer.writeLine(`
/**
* \`${typeDef.name}\` schema prior to calling \`.refine()\` for extensibility.
*/
export const ${noRefineSchema} = baseSchema;
export const ${typeDef.name}Schema = ${refineFuncName}(${noRefineSchema});
`);
} else {
writer.writeLine(`export const ${typeDef.name}Schema = baseSchema;`);
}
});

return schemaName;
}
Expand Down Expand Up @@ -436,22 +458,7 @@
}

// compile "@@validate" to ".refine"
const refinements = makeValidationRefinements(model);
let refineFuncName: string | undefined;
if (refinements.length > 0) {
refineFuncName = `refine${upperCaseFirst(model.name)}`;
writer.writeLine(
`
/**
* Schema refinement function for applying \`@@validate\` rules.
*/
export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
'\n'
)};
}
`
);
}
const refineFuncName = this.createRefineFunction(model, writer);

// delegate discriminator fields are to be excluded from mutation schemas
const delegateDiscriminatorFields = model.fields.filter((field) => isDiscriminatorField(field));
Expand Down Expand Up @@ -658,6 +665,74 @@
return schemaName;
}

private createRefineFunction(decl: DataModel | TypeDef, writer: CodeBlockWriter) {
const refinements = this.makeValidationRefinements(decl);
let refineFuncName: string | undefined;
if (refinements.length > 0) {
refineFuncName = `refine${upperCaseFirst(decl.name)}`;
writer.writeLine(
`
/**
* Schema refinement function for applying \`@@validate\` rules.
*/
export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
'\n'
)};
}
`
);
return refineFuncName;
} else {
return undefined;
}
}

private makeValidationRefinements(decl: DataModel | TypeDef) {
const attrs = decl.attributes.filter((attr) => attr.decl.ref?.name === '@@validate');
const refinements = attrs
.map((attr) => {
const valueArg = getAttributeArg(attr, 'value');
if (!valueArg) {
return undefined;
}

const messageArg = getAttributeArgLiteral<string>(attr, 'message');
const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : '';

const pathArg = getAttributeArg(attr, 'path');
const path =
pathArg && isArrayExpr(pathArg)
? `path: ['${getLiteralArray<string>(pathArg)?.join(`', '`)}'],`
: '';

const options = `, { ${message} ${path} }`;

try {
let expr = new TypeScriptExpressionTransformer({
context: ExpressionContext.ValidationRule,
fieldReferenceContext: 'value',
}).transform(valueArg);

if (isDataModelFieldReference(valueArg)) {
// if the expression is a simple field reference, treat undefined
// as true since the all fields are optional in validation context
expr = `${expr} ?? true`;
}

return `.refine((value: any) => ${expr}${options})`;
} catch (err) {
if (err instanceof TypeScriptExpressionTransformerError) {
throw new PluginError(name, err.message);
} else {
throw err;
}
}
})
.filter((r) => !!r);

return refinements;
}

private makePartial(schema: string, fields?: string[]) {
if (fields) {
if (fields.length === 0) {
Expand Down
60 changes: 1 addition & 59 deletions packages/schema/src/plugins/zod/utils/schema-gen.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
import { getLiteral, isFromStdlib } from '@zenstackhq/sdk';
import {
ExpressionContext,
getAttributeArg,
getAttributeArgLiteral,
getLiteral,
getLiteralArray,
isDataModelFieldReference,
isFromStdlib,
PluginError,
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
} from '@zenstackhq/sdk';
import {
DataModel,
DataModelField,
DataModelFieldAttribute,
isArrayExpr,
isBooleanLiteral,
isDataModel,
isEnum,
Expand All @@ -25,7 +12,6 @@ import {
TypeDefField,
} from '@zenstackhq/sdk/ast';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '..';
import { isDefaultWithAuth } from '../../enhancer/enhancer-utils';

export function makeFieldSchema(field: DataModelField | TypeDefField) {
Expand Down Expand Up @@ -222,50 +208,6 @@ function makeZodSchema(field: DataModelField | TypeDefField) {
return schema;
}

export function makeValidationRefinements(model: DataModel) {
const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === '@@validate');
const refinements = attrs
.map((attr) => {
const valueArg = getAttributeArg(attr, 'value');
if (!valueArg) {
return undefined;
}

const messageArg = getAttributeArgLiteral<string>(attr, 'message');
const message = messageArg ? `message: ${JSON.stringify(messageArg)},` : '';

const pathArg = getAttributeArg(attr, 'path');
const path =
pathArg && isArrayExpr(pathArg) ? `path: ['${getLiteralArray<string>(pathArg)?.join(`', '`)}'],` : '';

const options = `, { ${message} ${path} }`;

try {
let expr = new TypeScriptExpressionTransformer({
context: ExpressionContext.ValidationRule,
fieldReferenceContext: 'value',
}).transform(valueArg);

if (isDataModelFieldReference(valueArg)) {
// if the expression is a simple field reference, treat undefined
// as true since the all fields are optional in validation context
expr = `${expr} ?? true`;
}

return `.refine((value: any) => ${expr}${options})`;
} catch (err) {
if (err instanceof TypeScriptExpressionTransformerError) {
throw new PluginError(name, err.message);
} else {
throw err;
}
}
})
.filter((r) => !!r);

return refinements;
}

function getAttrLiteralArg<T extends string | number>(attr: DataModelFieldAttribute, paramName: string) {
const arg = attr.args.find((arg) => arg.$resolvedParam?.name === paramName);
return arg && getLiteral<T>(arg.value);
Expand Down
46 changes: 1 addition & 45 deletions packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,19 @@ import {
BinaryExpr,
DataModel,
DataModelAttribute,
DataModelField,
Expression,
InheritableNode,
isArrayExpr,
isBinaryExpr,
isDataModel,
isDataModelField,
isInvocationExpr,
isMemberAccessExpr,
isModel,
isReferenceExpr,
isTypeDef,
Model,
ModelImport,
ReferenceExpr,
TypeDef,
} from '@zenstackhq/language/ast';
import {
getInheritanceChain,
getModelFieldsWithBases,
getRecursiveBases,
isDelegateModel,
isFromStdlib,
} from '@zenstackhq/sdk';
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import {
AstNode,
copyAstNode,
Expand Down Expand Up @@ -151,29 +140,6 @@ function cloneAst<T extends InheritableNode>(
return clone;
}

export function getIdFields(dataModel: DataModel) {
const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) =>
f.attributes.some((attr) => attr.decl.$refText === '@id')
);
if (fieldLevelId) {
return [fieldLevelId];
} else {
// get model level @@id attribute
const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id');
if (modelIdAttr) {
// get fields referenced in the attribute: @@id([field1, field2]])
if (!isArrayExpr(modelIdAttr.args[0]?.value)) {
return [];
}
const argValue = modelIdAttr.args[0].value;
return argValue.items
.filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr))
.map((expr) => expr.target.ref as DataModelField);
}
}
return [];
}

export function isAuthInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref);
}
Expand All @@ -186,16 +152,6 @@ export function isCheckInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'check' && isFromStdlib(node.function.ref);
}

export function getDataModelFieldReference(expr: Expression): DataModelField | undefined {
if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) {
return expr.target.ref;
} else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) {
return expr.member.ref;
} else {
return undefined;
}
}

export function resolveImportUri(imp: ModelImport): URI | undefined {
if (!imp.path) return undefined; // This will return true if imp.path is undefined, null, or an empty string ("").

Expand Down
Loading
Loading