11/* eslint-disable @typescript-eslint/no-explicit-any */ 
22
3+ import  deepmerge  from  'deepmerge' ; 
34import  {  lowerCaseFirst  }  from  'lower-case-first' ; 
45import  invariant  from  'tiny-invariant' ; 
56import  {  P ,  match  }  from  'ts-pattern' ; 
@@ -23,7 +24,7 @@ import { Logger } from '../logger';
2324import  {  createDeferredPromise ,  createFluentPromise  }  from  '../promise' ; 
2425import  {  PrismaProxyHandler  }  from  '../proxy' ; 
2526import  {  QueryUtils  }  from  '../query-utils' ; 
26- import  type  {  CheckerConstraint  }  from  '../types' ; 
27+ import  type  {  AdditionalCheckerFunc ,   CheckerConstraint  }  from  '../types' ; 
2728import  {  clone ,  formatObject ,  isUnsafeMutate ,  prismaClientValidationError  }  from  '../utils' ; 
2829import  {  ConstraintSolver  }  from  './constraint-solver' ; 
2930import  {  PolicyUtil  }  from  './policy-utils' ; 
@@ -152,8 +153,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
152153        } 
153154
154155        const  result  =  await  this . modelClient [ actionName ] ( _args ) ; 
155-         this . policyUtils . postProcessForRead ( result ,  this . model ,  origArgs ) ; 
156-         return  result ; 
156+         return  this . policyUtils . postProcessForRead ( result ,  this . model ,  origArgs ) ; 
157157    } 
158158
159159    //#endregion 
@@ -779,10 +779,27 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
779779            } 
780780        } ; 
781781
782-         const  _connectDisconnect  =  async  ( model : string ,  args : any ,  context : NestedWriteVisitorContext )  =>  { 
782+         const  _connectDisconnect  =  async  ( 
783+             model : string , 
784+             args : any , 
785+             context : NestedWriteVisitorContext , 
786+             operation : 'connect'  |  'disconnect' 
787+         )  =>  { 
783788            if  ( context . field ?. backLink )  { 
784789                const  backLinkField  =  this . policyUtils . getModelField ( model ,  context . field . backLink ) ; 
785790                if  ( backLinkField ?. isRelationOwner )  { 
791+                     let  uniqueFilter  =  args ; 
792+                     if  ( operation  ===  'disconnect' )  { 
793+                         // disconnect filter is not unique, need to build a reversed query to 
794+                         // locate the entity and use its id fields as unique filter 
795+                         const  reversedQuery  =  this . policyUtils . buildReversedQuery ( context ) ; 
796+                         const  found  =  await  db [ model ] . findUnique ( { 
797+                             where : reversedQuery , 
798+                             select : this . policyUtils . makeIdSelection ( model ) , 
799+                         } ) ; 
800+                         uniqueFilter  =  found  &&  this . policyUtils . getIdFieldValues ( model ,  found ) ; 
801+                     } 
802+ 
786803                    // update happens on the related model, require updatable, 
787804                    // translate args to foreign keys so field-level policies can be checked 
788805                    const  checkArgs : any  =  { } ; 
@@ -794,10 +811,15 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
794811                            } 
795812                        } 
796813                    } 
797-                     await  this . policyUtils . checkPolicyForUnique ( model ,  args ,  'update' ,  db ,  checkArgs ) ; 
798814
799-                     // register post-update check 
800-                     await  _registerPostUpdateCheck ( model ,  args ,  args ) ; 
815+                     // `uniqueFilter` can be undefined if the entity to be disconnected doesn't exist 
816+                     if  ( uniqueFilter )  { 
817+                         // check for update 
818+                         await  this . policyUtils . checkPolicyForUnique ( model ,  uniqueFilter ,  'update' ,  db ,  checkArgs ) ; 
819+ 
820+                         // register post-update check 
821+                         await  _registerPostUpdateCheck ( model ,  uniqueFilter ,  uniqueFilter ) ; 
822+                     } 
801823                } 
802824            } 
803825        } ; 
@@ -970,14 +992,14 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
970992                } 
971993            } , 
972994
973-             connect : async  ( model ,  args ,  context )  =>  _connectDisconnect ( model ,  args ,  context ) , 
995+             connect : async  ( model ,  args ,  context )  =>  _connectDisconnect ( model ,  args ,  context ,   'connect' ) , 
974996
975997            connectOrCreate : async  ( model ,  args ,  context )  =>  { 
976998                // the where condition is already unique, so we can use it to check if the target exists 
977999                const  existing  =  await  this . policyUtils . checkExistence ( db ,  model ,  args . where ) ; 
9781000                if  ( existing )  { 
9791001                    // connect 
980-                     await  _connectDisconnect ( model ,  args . where ,  context ) ; 
1002+                     await  _connectDisconnect ( model ,  args . where ,  context ,   'connect' ) ; 
9811003                    return  true ; 
9821004                }  else  { 
9831005                    // create 
@@ -997,7 +1019,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
9971019                } 
9981020            } , 
9991021
1000-             disconnect : async  ( model ,  args ,  context )  =>  _connectDisconnect ( model ,  args ,  context ) , 
1022+             disconnect : async  ( model ,  args ,  context )  =>  _connectDisconnect ( model ,  args ,  context ,   'disconnect' ) , 
10011023
10021024            set : async  ( model ,  args ,  context )  =>  { 
10031025                // find the set of items to be replaced 
@@ -1012,10 +1034,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10121034                const  currentSet  =  await  db [ model ] . findMany ( findCurrSetArgs ) ; 
10131035
10141036                // register current set for update (foreign key) 
1015-                 await  Promise . all ( currentSet . map ( ( item )  =>  _connectDisconnect ( model ,  item ,  context ) ) ) ; 
1037+                 await  Promise . all ( currentSet . map ( ( item )  =>  _connectDisconnect ( model ,  item ,  context ,   'disconnect' ) ) ) ; 
10161038
10171039                // proceed with connecting the new set 
1018-                 await  Promise . all ( enumerate ( args ) . map ( ( item )  =>  _connectDisconnect ( model ,  item ,  context ) ) ) ; 
1040+                 await  Promise . all ( enumerate ( args ) . map ( ( item )  =>  _connectDisconnect ( model ,  item ,  context ,   'connect' ) ) ) ; 
10191041            } , 
10201042
10211043            delete : async  ( model ,  args ,  context )  =>  { 
@@ -1160,48 +1182,78 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11601182
11611183            args . data  =  this . validateUpdateInputSchema ( this . model ,  args . data ) ; 
11621184
1163-             if  ( this . policyUtils . hasAuthGuard ( this . model ,  'postUpdate' )  ||  this . policyUtils . getZodSchema ( this . model ) )  { 
1164-                 // use a transaction to do post-update checks 
1165-                 const  postWriteChecks : PostWriteCheckRecord [ ]  =  [ ] ; 
1166-                 return  this . queryUtils . transaction ( this . prisma ,  async  ( tx )  =>  { 
1167-                     // collect pre-update values 
1168-                     let  select  =  this . policyUtils . makeIdSelection ( this . model ) ; 
1169-                     const  preValueSelect  =  this . policyUtils . getPreValueSelect ( this . model ) ; 
1170-                     if  ( preValueSelect )  { 
1171-                         select  =  {  ...select ,  ...preValueSelect  } ; 
1172-                     } 
1173-                     const  currentSetQuery  =  {  select,  where : args . where  } ; 
1174-                     this . policyUtils . injectAuthGuardAsWhere ( tx ,  currentSetQuery ,  this . model ,  'read' ) ; 
1185+             const  additionalChecker  =  this . policyUtils . getAdditionalChecker ( this . model ,  'update' ) ; 
11751186
1176-                     if  ( this . shouldLogQuery )  { 
1177-                         this . logger . info ( `[policy] \`findMany\` ${ this . model } ${ formatObject ( currentSetQuery ) }  ) ; 
1178-                     } 
1179-                     const  currentSet  =  await  tx [ this . model ] . findMany ( currentSetQuery ) ; 
1187+             const  canProceedWithoutTransaction  = 
1188+                 // no post-update rules 
1189+                 ! this . policyUtils . hasAuthGuard ( this . model ,  'postUpdate' )  && 
1190+                 // no Zod schema 
1191+                 ! this . policyUtils . getZodSchema ( this . model )  && 
1192+                 // no additional checker 
1193+                 ! additionalChecker ; 
11801194
1181-                     postWriteChecks . push ( 
1182-                         ...currentSet . map ( ( preValue )  =>  ( { 
1183-                             model : this . model , 
1184-                             operation : 'postUpdate'  as  PolicyOperationKind , 
1185-                             uniqueFilter : this . policyUtils . getEntityIds ( this . model ,  preValue ) , 
1186-                             preValue : preValueSelect  ? preValue  : undefined , 
1187-                         } ) ) 
1188-                     ) ; 
1189- 
1190-                     // proceed with the update 
1191-                     const  result  =  await  tx [ this . model ] . updateMany ( args ) ; 
1192- 
1193-                     // run post-write checks 
1194-                     await  this . runPostWriteChecks ( postWriteChecks ,  tx ) ; 
1195- 
1196-                     return  result ; 
1197-                 } ) ; 
1198-             }  else  { 
1195+             if  ( canProceedWithoutTransaction )  { 
11991196                // proceed without a transaction 
12001197                if  ( this . shouldLogQuery )  { 
12011198                    this . logger . info ( `[policy] \`updateMany\` ${ this . model } ${ formatObject ( args ) }  ) ; 
12021199                } 
12031200                return  this . modelClient . updateMany ( args ) ; 
12041201            } 
1202+ 
1203+             // collect post-update checks 
1204+             const  postWriteChecks : PostWriteCheckRecord [ ]  =  [ ] ; 
1205+ 
1206+             return  this . queryUtils . transaction ( this . prisma ,  async  ( tx )  =>  { 
1207+                 // collect pre-update values 
1208+                 let  select  =  this . policyUtils . makeIdSelection ( this . model ) ; 
1209+                 const  preValueSelect  =  this . policyUtils . getPreValueSelect ( this . model ) ; 
1210+                 if  ( preValueSelect )  { 
1211+                     select  =  {  ...select ,  ...preValueSelect  } ; 
1212+                 } 
1213+ 
1214+                 // merge selection required for running additional checker 
1215+                 const  additionalCheckerSelector  =  this . policyUtils . getAdditionalCheckerSelector ( this . model ,  'update' ) ; 
1216+                 if  ( additionalCheckerSelector )  { 
1217+                     select  =  deepmerge ( select ,  additionalCheckerSelector ) ; 
1218+                 } 
1219+ 
1220+                 const  currentSetQuery  =  {  select,  where : args . where  } ; 
1221+                 this . policyUtils . injectAuthGuardAsWhere ( tx ,  currentSetQuery ,  this . model ,  'update' ) ; 
1222+ 
1223+                 if  ( this . shouldLogQuery )  { 
1224+                     this . logger . info ( `[policy] \`findMany\` ${ this . model } ${ formatObject ( currentSetQuery ) }  ) ; 
1225+                 } 
1226+                 let  candidates  =  await  tx [ this . model ] . findMany ( currentSetQuery ) ; 
1227+ 
1228+                 if  ( additionalChecker )  { 
1229+                     // filter candidates with additional checker and build an id filter 
1230+                     const  r  =  this . buildIdFilterWithAdditionalChecker ( candidates ,  additionalChecker ) ; 
1231+                     candidates  =  r . filteredCandidates ; 
1232+ 
1233+                     // merge id filter into update's where clause 
1234+                     args . where  =  args . where  ? {  AND : [ args . where ,  r . idFilter ]  }  : r . idFilter ; 
1235+                 } 
1236+ 
1237+                 postWriteChecks . push ( 
1238+                     ...candidates . map ( ( preValue )  =>  ( { 
1239+                         model : this . model , 
1240+                         operation : 'postUpdate'  as  PolicyOperationKind , 
1241+                         uniqueFilter : this . policyUtils . getEntityIds ( this . model ,  preValue ) , 
1242+                         preValue : preValueSelect  ? preValue  : undefined , 
1243+                     } ) ) 
1244+                 ) ; 
1245+ 
1246+                 // proceed with the update 
1247+                 if  ( this . shouldLogQuery )  { 
1248+                     this . logger . info ( `[policy] \`updateMany\` in tx for ${ this . model } ${ formatObject ( args ) }  ) ; 
1249+                 } 
1250+                 const  result  =  await  tx [ this . model ] . updateMany ( args ) ; 
1251+ 
1252+                 // run post-write checks 
1253+                 await  this . runPostWriteChecks ( postWriteChecks ,  tx ) ; 
1254+ 
1255+                 return  result ; 
1256+             } ) ; 
12051257        } ) ; 
12061258    } 
12071259
@@ -1328,14 +1380,53 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
13281380            this . policyUtils . tryReject ( this . prisma ,  this . model ,  'delete' ) ; 
13291381
13301382            // inject policy conditions 
1331-             args  =  args   ??   { } ; 
1383+             args  =  clone ( args ) ; 
13321384            this . policyUtils . injectAuthGuardAsWhere ( this . prisma ,  args ,  this . model ,  'delete' ) ; 
13331385
1334-             // conduct the deletion 
1335-             if  ( this . shouldLogQuery )  { 
1336-                 this . logger . info ( `[policy] \`deleteMany\` ${ this . model } ${ formatObject ( args ) }  ) ; 
1386+             const  additionalChecker  =  this . policyUtils . getAdditionalChecker ( this . model ,  'delete' ) ; 
1387+             if  ( additionalChecker )  { 
1388+                 // additional checker exists, need to run deletion inside a transaction 
1389+                 return  this . queryUtils . transaction ( this . prisma ,  async  ( tx )  =>  { 
1390+                     // find the delete candidates, selecting id fields and fields needed for 
1391+                     // running the additional checker 
1392+                     let  candidateSelect  =  this . policyUtils . makeIdSelection ( this . model ) ; 
1393+                     const  additionalCheckerSelector  =  this . policyUtils . getAdditionalCheckerSelector ( 
1394+                         this . model , 
1395+                         'delete' 
1396+                     ) ; 
1397+                     if  ( additionalCheckerSelector )  { 
1398+                         candidateSelect  =  deepmerge ( candidateSelect ,  additionalCheckerSelector ) ; 
1399+                     } 
1400+ 
1401+                     if  ( this . shouldLogQuery )  { 
1402+                         this . logger . info ( 
1403+                             `[policy] \`findMany\` ${ this . model } ${ formatObject ( {  
1404+                                 where : args . where ,  
1405+                                 select : candidateSelect ,  
1406+                             } ) }  `
1407+                         ) ; 
1408+                     } 
1409+                     const  candidates  =  await  tx [ this . model ] . findMany ( {  where : args . where ,  select : candidateSelect  } ) ; 
1410+ 
1411+                     // build a ID filter based on id values filtered by the additional checker 
1412+                     const  {  idFilter }  =  this . buildIdFilterWithAdditionalChecker ( candidates ,  additionalChecker ) ; 
1413+ 
1414+                     // merge the ID filter into the where clause 
1415+                     args . where  =  args . where  ? {  AND : [ args . where ,  idFilter ]  }  : idFilter ; 
1416+ 
1417+                     // finally, conduct the deletion with the combined where clause 
1418+                     if  ( this . shouldLogQuery )  { 
1419+                         this . logger . info ( `[policy] \`deleteMany\` in tx for ${ this . model } ${ formatObject ( args ) }  ) ; 
1420+                     } 
1421+                     return  tx [ this . model ] . deleteMany ( args ) ; 
1422+                 } ) ; 
1423+             }  else  { 
1424+                 // conduct the deletion directly 
1425+                 if  ( this . shouldLogQuery )  { 
1426+                     this . logger . info ( `[policy] \`deleteMany\` ${ this . model } ${ formatObject ( args ) }  ) ; 
1427+                 } 
1428+                 return  this . modelClient . deleteMany ( args ) ; 
13371429            } 
1338-             return  this . modelClient . deleteMany ( args ) ; 
13391430        } ) ; 
13401431    } 
13411432
@@ -1599,5 +1690,17 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
15991690        } 
16001691    } 
16011692
1693+     private  buildIdFilterWithAdditionalChecker ( candidates : any [ ] ,  additionalChecker : AdditionalCheckerFunc )  { 
1694+         const  filteredCandidates  =  candidates . filter ( ( value )  =>  additionalChecker ( {  user : this . context ?. user  } ,  value ) ) ; 
1695+         const  idFields  =  this . policyUtils . getIdFields ( this . model ) ; 
1696+         let  idFilter : any ; 
1697+         if  ( idFields . length  ===  1 )  { 
1698+             idFilter  =  {  [ idFields [ 0 ] . name ] : {  in : filteredCandidates . map ( ( x )  =>  x [ idFields [ 0 ] . name ] )  }  } ; 
1699+         }  else  { 
1700+             idFilter  =  {  AND : filteredCandidates . map ( ( x )  =>  this . policyUtils . getIdFieldValues ( this . model ,  x ) )  } ; 
1701+         } 
1702+         return  {  filteredCandidates,  idFilter } ; 
1703+     } 
1704+ 
16021705    //#endregion 
16031706} 
0 commit comments