Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import {
isArrayExpr,
isDataModel,
isGeneratorDecl,
isReferenceExpr,
isTypeDef,
type Model,
} from '@zenstackhq/sdk/ast';
Expand All @@ -45,6 +44,7 @@ import {
} from 'ts-morph';
import { upperCaseFirst } from 'upper-case-first';
import { name } from '..';
import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils';
import { execPackage } from '../../../utils/exec-utils';
import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils';
import { trackPrismaSchemaError } from '../../prisma';
Expand Down Expand Up @@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
this.model.declarations
.filter((d): d is DataModel => isDelegateModel(d))
.forEach((dm) => {
const concreteModels = this.model.declarations.filter(
(d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
);
const concreteModels = getConcreteModels(dm);
if (concreteModels.length > 0) {
delegateInfo.push([dm, concreteModels]);
}
Expand Down Expand Up @@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
const typeName = typeAlias.getName();
const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName);
if (payloadRecord) {
const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]);
const discriminatorDecl = getDiscriminatorField(payloadRecord[0]);
if (discriminatorDecl) {
source = `${payloadRecord[1]
.map(
Expand Down Expand Up @@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
.filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX));
}

private getDiscriminatorField(delegate: DataModel) {
const delegateAttr = getAttribute(delegate, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}

private saveSourceFile(sf: SourceFile) {
if (this.options.preserveTsFiles) {
saveSourceFile(sf);
Expand Down
24 changes: 13 additions & 11 deletions packages/schema/src/plugins/enhancer/policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -839,16 +839,18 @@ export class ExpressionWriter {
operation = this.options.operationContext;
}

this.block(() => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
}
});
this.block(() =>
this.writeFieldCondition(fieldRef, () => {
if (operation === 'postUpdate') {
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
// e.g.:
// @@allow('all', check(author)) should not delegate "postUpdate" to author
this.writer.write(FALSE);
} else {
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
this.writer.write(`${targetGuardFunc}(context, db)`);
}
})
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ export class PolicyGenerator {
writer: CodeBlockWriter,
sourceFile: SourceFile
) {
// first handle several cases where a constant function can be used

if (kind === 'update' && allows.length === 0) {
// no allow rule for 'update', policy is constant based on if there's
// post-update counterpart
Expand Down
5 changes: 2 additions & 3 deletions packages/schema/src/plugins/prisma/schema-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ import path from 'path';
import semver from 'semver';
import { name } from '.';
import { getStringLiteral } from '../../language-server/validator/utils';
import { getConcreteModels } from '../../utils/ast-utils';
import { execPackage } from '../../utils/exec-utils';
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
import {
Expand Down Expand Up @@ -320,9 +321,7 @@ export class PrismaSchemaGenerator {
}

// collect concrete models inheriting this model
const concreteModels = decl.$container.declarations.filter(
(d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl)
);
const concreteModels = getConcreteModels(decl);

// generate an optional relation field in delegate base model to each concrete model
concreteModels.forEach((concrete) => {
Expand Down
28 changes: 27 additions & 1 deletion packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@ import {
BinaryExpr,
DataModel,
DataModelAttribute,
DataModelField,
Expression,
InheritableNode,
isBinaryExpr,
isDataModel,
isDataModelField,
isInvocationExpr,
isModel,
isReferenceExpr,
isTypeDef,
Model,
ModelImport,
TypeDef,
} from '@zenstackhq/language/ast';
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
import {
AstNode,
copyAstNode,
Expand Down Expand Up @@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode
}
return undefined;
}

/**
* Gets all concrete models that inherit from the given delegate model
*/
export function getConcreteModels(dataModel: DataModel): DataModel[] {
if (!isDelegateModel(dataModel)) {
return [];
}
return dataModel.$container.declarations.filter(
(d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel)
);
}

/**
* Gets the discriminator field for the given delegate model
*/
export function getDiscriminatorField(dataModel: DataModel) {
const delegateAttr = getAttribute(dataModel, '@@delegate');
if (!delegateAttr) {
return undefined;
}
const arg = delegateAttr.args[0]?.value;
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
}
Original file line number Diff line number Diff line change
Expand Up @@ -571,4 +571,84 @@ describe('Polymorphic Policy Test', () => {
expect(foundPost2.foo).toBeUndefined();
expect(foundPost2.bar).toBeUndefined();
});

it('respects concrete policies when read as base optional relation', async () => {
const { enhance } = await loadSchema(
`
model User {
id Int @id @default(autoincrement())
asset Asset?
@@allow('all', true)
}

model Asset {
id Int @id @default(autoincrement())
user User @relation(fields: [userId], references: [id])
userId Int @unique
type String

@@delegate(type)
@@allow('all', true)
}

model Post extends Asset {
title String
private Boolean
@@allow('create', true)
@@deny('read', private)
}
`
);

const fullDb = enhance(undefined, { kinds: ['delegate'] });
await fullDb.user.create({ data: { id: 1 } });
await fullDb.post.create({ data: { title: 'Post1', private: true, user: { connect: { id: 1 } } } });
await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
asset: expect.objectContaining({ type: 'Post' }),
});

const db = enhance();
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
expect(read.asset).toBeTruthy();
expect(read.asset.title).toBeUndefined();
});

it('respects concrete policies when read as base required relation', async () => {
const { enhance } = await loadSchema(
`
model User {
id Int @id @default(autoincrement())
asset Asset @relation(fields: [assetId], references: [id])
assetId Int @unique
@@allow('all', true)
}

model Asset {
id Int @id @default(autoincrement())
user User?
type String

@@delegate(type)
@@allow('all', true)
}

model Post extends Asset {
title String
private Boolean
@@deny('read', private)
}
`
);

const fullDb = enhance(undefined, { kinds: ['delegate'] });
await fullDb.post.create({ data: { id: 1, title: 'Post1', private: true, user: { create: { id: 1 } } } });
await expect(fullDb.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
asset: expect.objectContaining({ type: 'Post' }),
});

const db = enhance();
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
expect(read).toBeTruthy();
expect(read.asset.title).toBeUndefined();
});
});
80 changes: 80 additions & 0 deletions tests/regression/tests/issue-1930.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue 1930', () => {
it('regression', async () => {
const { enhance } = await loadSchema(
`
model Organization {
id String @id @default(cuid())
entities Entity[]

@@allow('all', true)
}

model Entity {
id String @id @default(cuid())
org Organization? @relation(fields: [orgId], references: [id])
orgId String?
contents EntityContent[]
entityType String
isDeleted Boolean @default(false)

@@delegate(entityType)

@@allow('all', !isDeleted)
}

model EntityContent {
id String @id @default(cuid())
entity Entity @relation(fields: [entityId], references: [id])
entityId String

entityContentType String

@@delegate(entityContentType)

@@allow('create', true)
@@allow('read', check(entity))
}

model Article extends Entity {
}

model ArticleContent extends EntityContent {
body String?
}

model OtherContent extends EntityContent {
data Int
}
`
);

const fullDb = enhance(undefined, { kinds: ['delegate'] });
const org = await fullDb.organization.create({ data: {} });
const article = await fullDb.article.create({
data: { org: { connect: { id: org.id } } },
});

const db = enhance();

// normal create/read
await expect(
db.articleContent.create({
data: { body: 'abc', entity: { connect: { id: article.id } } },
})
).toResolveTruthy();
await expect(db.article.findFirst({ include: { contents: true } })).resolves.toMatchObject({
contents: expect.arrayContaining([expect.objectContaining({ body: 'abc' })]),
});

// deleted article's contents are not readable
const deletedArticle = await fullDb.article.create({
data: { org: { connect: { id: org.id } }, isDeleted: true },
});
const content1 = await fullDb.articleContent.create({
data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } },
});
await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull();
});
});
Loading