Skip to content

Commit df41f2b

Browse files
committed
feat: allow comparing fields from different models in mutation policies
- Generate TS checker functions to evaluate rules in JS runtime - Make sure fields needed in the checker are selected when reading entities Only supporting mutation rules (create, update, post-update, delete) because: 1. Evaluating read in JS runtime may result in reading lots of rows and then discard 2. Don't know how to support aggregation without reading all rows
1 parent ea105d1 commit df41f2b

File tree

9 files changed

+1163
-118
lines changed

9 files changed

+1163
-118
lines changed

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 154 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
22

3+
import deepmerge from 'deepmerge';
34
import { lowerCaseFirst } from 'lower-case-first';
45
import invariant from 'tiny-invariant';
56
import { P, match } from 'ts-pattern';
@@ -23,7 +24,7 @@ import { Logger } from '../logger';
2324
import { createDeferredPromise, createFluentPromise } from '../promise';
2425
import { PrismaProxyHandler } from '../proxy';
2526
import { QueryUtils } from '../query-utils';
26-
import type { CheckerConstraint } from '../types';
27+
import type { AdditionalCheckerFunc, CheckerConstraint } from '../types';
2728
import { clone, formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
2829
import { ConstraintSolver } from './constraint-solver';
2930
import { PolicyUtil } from './policy-utils';
@@ -152,8 +153,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
152153
}
153154

154155
const result = await this.modelClient[actionName](_args);
155-
this.policyUtils.postProcessForRead(result, this.model, origArgs);
156-
return result;
156+
return this.policyUtils.postProcessForRead(result, this.model, origArgs);
157157
}
158158

159159
//#endregion
@@ -779,10 +779,27 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
779779
}
780780
};
781781

782-
const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => {
782+
const _connectDisconnect = async (
783+
model: string,
784+
args: any,
785+
context: NestedWriteVisitorContext,
786+
operation: 'connect' | 'disconnect'
787+
) => {
783788
if (context.field?.backLink) {
784789
const backLinkField = this.policyUtils.getModelField(model, context.field.backLink);
785790
if (backLinkField?.isRelationOwner) {
791+
let uniqueFilter = args;
792+
if (operation === 'disconnect') {
793+
// disconnect filter is not unique, need to build a reversed query to
794+
// locate the entity and use its id fields as unique filter
795+
const reversedQuery = this.policyUtils.buildReversedQuery(context);
796+
const found = await db[model].findUnique({
797+
where: reversedQuery,
798+
select: this.policyUtils.makeIdSelection(model),
799+
});
800+
uniqueFilter = found && this.policyUtils.getIdFieldValues(model, found);
801+
}
802+
786803
// update happens on the related model, require updatable,
787804
// translate args to foreign keys so field-level policies can be checked
788805
const checkArgs: any = {};
@@ -794,10 +811,15 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
794811
}
795812
}
796813
}
797-
await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs);
798814

799-
// register post-update check
800-
await _registerPostUpdateCheck(model, args, args);
815+
// `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist
816+
if (uniqueFilter) {
817+
// check for update
818+
await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, checkArgs);
819+
820+
// register post-update check
821+
await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter);
822+
}
801823
}
802824
}
803825
};
@@ -970,14 +992,14 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
970992
}
971993
},
972994

973-
connect: async (model, args, context) => _connectDisconnect(model, args, context),
995+
connect: async (model, args, context) => _connectDisconnect(model, args, context, 'connect'),
974996

975997
connectOrCreate: async (model, args, context) => {
976998
// the where condition is already unique, so we can use it to check if the target exists
977999
const existing = await this.policyUtils.checkExistence(db, model, args.where);
9781000
if (existing) {
9791001
// connect
980-
await _connectDisconnect(model, args.where, context);
1002+
await _connectDisconnect(model, args.where, context, 'connect');
9811003
return true;
9821004
} else {
9831005
// create
@@ -997,7 +1019,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
9971019
}
9981020
},
9991021

1000-
disconnect: async (model, args, context) => _connectDisconnect(model, args, context),
1022+
disconnect: async (model, args, context) => _connectDisconnect(model, args, context, 'disconnect'),
10011023

10021024
set: async (model, args, context) => {
10031025
// find the set of items to be replaced
@@ -1012,10 +1034,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10121034
const currentSet = await db[model].findMany(findCurrSetArgs);
10131035

10141036
// register current set for update (foreign key)
1015-
await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context)));
1037+
await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, 'disconnect')));
10161038

10171039
// proceed with connecting the new set
1018-
await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context)));
1040+
await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, 'connect')));
10191041
},
10201042

10211043
delete: async (model, args, context) => {
@@ -1160,48 +1182,78 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11601182

11611183
args.data = this.validateUpdateInputSchema(this.model, args.data);
11621184

1163-
if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) {
1164-
// use a transaction to do post-update checks
1165-
const postWriteChecks: PostWriteCheckRecord[] = [];
1166-
return this.queryUtils.transaction(this.prisma, async (tx) => {
1167-
// collect pre-update values
1168-
let select = this.policyUtils.makeIdSelection(this.model);
1169-
const preValueSelect = this.policyUtils.getPreValueSelect(this.model);
1170-
if (preValueSelect) {
1171-
select = { ...select, ...preValueSelect };
1172-
}
1173-
const currentSetQuery = { select, where: args.where };
1174-
this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read');
1185+
const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'update');
11751186

1176-
if (this.shouldLogQuery) {
1177-
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
1178-
}
1179-
const currentSet = await tx[this.model].findMany(currentSetQuery);
1187+
const canProceedWithoutTransaction =
1188+
// no post-update rules
1189+
!this.policyUtils.hasAuthGuard(this.model, 'postUpdate') &&
1190+
// no Zod schema
1191+
!this.policyUtils.getZodSchema(this.model) &&
1192+
// no additional checker
1193+
!additionalChecker;
11801194

1181-
postWriteChecks.push(
1182-
...currentSet.map((preValue) => ({
1183-
model: this.model,
1184-
operation: 'postUpdate' as PolicyOperationKind,
1185-
uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue),
1186-
preValue: preValueSelect ? preValue : undefined,
1187-
}))
1188-
);
1189-
1190-
// proceed with the update
1191-
const result = await tx[this.model].updateMany(args);
1192-
1193-
// run post-write checks
1194-
await this.runPostWriteChecks(postWriteChecks, tx);
1195-
1196-
return result;
1197-
});
1198-
} else {
1195+
if (canProceedWithoutTransaction) {
11991196
// proceed without a transaction
12001197
if (this.shouldLogQuery) {
12011198
this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`);
12021199
}
12031200
return this.modelClient.updateMany(args);
12041201
}
1202+
1203+
// collect post-update checks
1204+
const postWriteChecks: PostWriteCheckRecord[] = [];
1205+
1206+
return this.queryUtils.transaction(this.prisma, async (tx) => {
1207+
// collect pre-update values
1208+
let select = this.policyUtils.makeIdSelection(this.model);
1209+
const preValueSelect = this.policyUtils.getPreValueSelect(this.model);
1210+
if (preValueSelect) {
1211+
select = { ...select, ...preValueSelect };
1212+
}
1213+
1214+
// merge selection required for running additional checker
1215+
const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector(this.model, 'update');
1216+
if (additionalCheckerSelector) {
1217+
select = deepmerge(select, additionalCheckerSelector);
1218+
}
1219+
1220+
const currentSetQuery = { select, where: args.where };
1221+
this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'update');
1222+
1223+
if (this.shouldLogQuery) {
1224+
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
1225+
}
1226+
let candidates = await tx[this.model].findMany(currentSetQuery);
1227+
1228+
if (additionalChecker) {
1229+
// filter candidates with additional checker and build an id filter
1230+
const r = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker);
1231+
candidates = r.filteredCandidates;
1232+
1233+
// merge id filter into update's where clause
1234+
args.where = args.where ? { AND: [args.where, r.idFilter] } : r.idFilter;
1235+
}
1236+
1237+
postWriteChecks.push(
1238+
...candidates.map((preValue) => ({
1239+
model: this.model,
1240+
operation: 'postUpdate' as PolicyOperationKind,
1241+
uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue),
1242+
preValue: preValueSelect ? preValue : undefined,
1243+
}))
1244+
);
1245+
1246+
// proceed with the update
1247+
if (this.shouldLogQuery) {
1248+
this.logger.info(`[policy] \`updateMany\` in tx for ${this.model}: ${formatObject(args)}`);
1249+
}
1250+
const result = await tx[this.model].updateMany(args);
1251+
1252+
// run post-write checks
1253+
await this.runPostWriteChecks(postWriteChecks, tx);
1254+
1255+
return result;
1256+
});
12051257
});
12061258
}
12071259

@@ -1328,14 +1380,53 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
13281380
this.policyUtils.tryReject(this.prisma, this.model, 'delete');
13291381

13301382
// inject policy conditions
1331-
args = args ?? {};
1383+
args = clone(args);
13321384
this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete');
13331385

1334-
// conduct the deletion
1335-
if (this.shouldLogQuery) {
1336-
this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`);
1386+
const additionalChecker = this.policyUtils.getAdditionalChecker(this.model, 'delete');
1387+
if (additionalChecker) {
1388+
// additional checker exists, need to run deletion inside a transaction
1389+
return this.queryUtils.transaction(this.prisma, async (tx) => {
1390+
// find the delete candidates, selecting id fields and fields needed for
1391+
// running the additional checker
1392+
let candidateSelect = this.policyUtils.makeIdSelection(this.model);
1393+
const additionalCheckerSelector = this.policyUtils.getAdditionalCheckerSelector(
1394+
this.model,
1395+
'delete'
1396+
);
1397+
if (additionalCheckerSelector) {
1398+
candidateSelect = deepmerge(candidateSelect, additionalCheckerSelector);
1399+
}
1400+
1401+
if (this.shouldLogQuery) {
1402+
this.logger.info(
1403+
`[policy] \`findMany\` ${this.model}: ${formatObject({
1404+
where: args.where,
1405+
select: candidateSelect,
1406+
})}`
1407+
);
1408+
}
1409+
const candidates = await tx[this.model].findMany({ where: args.where, select: candidateSelect });
1410+
1411+
// build a ID filter based on id values filtered by the additional checker
1412+
const { idFilter } = this.buildIdFilterWithAdditionalChecker(candidates, additionalChecker);
1413+
1414+
// merge the ID filter into the where clause
1415+
args.where = args.where ? { AND: [args.where, idFilter] } : idFilter;
1416+
1417+
// finally, conduct the deletion with the combined where clause
1418+
if (this.shouldLogQuery) {
1419+
this.logger.info(`[policy] \`deleteMany\` in tx for ${this.model}:\n${formatObject(args)}`);
1420+
}
1421+
return tx[this.model].deleteMany(args);
1422+
});
1423+
} else {
1424+
// conduct the deletion directly
1425+
if (this.shouldLogQuery) {
1426+
this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`);
1427+
}
1428+
return this.modelClient.deleteMany(args);
13371429
}
1338-
return this.modelClient.deleteMany(args);
13391430
});
13401431
}
13411432

@@ -1599,5 +1690,17 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
15991690
}
16001691
}
16011692

1693+
private buildIdFilterWithAdditionalChecker(candidates: any[], additionalChecker: AdditionalCheckerFunc) {
1694+
const filteredCandidates = candidates.filter((value) => additionalChecker({ user: this.context?.user }, value));
1695+
const idFields = this.policyUtils.getIdFields(this.model);
1696+
let idFilter: any;
1697+
if (idFields.length === 1) {
1698+
idFilter = { [idFields[0].name]: { in: filteredCandidates.map((x) => x[idFields[0].name]) } };
1699+
} else {
1700+
idFilter = { AND: filteredCandidates.map((x) => this.policyUtils.getIdFieldValues(this.model, x)) };
1701+
}
1702+
return { filteredCandidates, idFilter };
1703+
}
1704+
16021705
//#endregion
16031706
}

0 commit comments

Comments
 (0)