Skip to content

Commit

Permalink
fix(policy): properly handle array-form of upsert payload
Browse files Browse the repository at this point in the history
Fixes #1080
  • Loading branch information
ymc9 committed Mar 8, 2024
1 parent 4dd7aa0 commit 9daff09
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 39 deletions.
32 changes: 22 additions & 10 deletions packages/runtime/src/cross/nested-write-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import type { FieldInfo, ModelMeta } from './model-meta';
import { resolveField } from './model-meta';
import { MaybePromise, PrismaWriteActionType, PrismaWriteActions } from './types';
import { enumerate, getModelFields } from './utils';
import { getModelFields } from './utils';

type NestingPathItem = { field?: FieldInfo; model: string; where: any; unique: boolean };

Expand Down Expand Up @@ -155,7 +155,7 @@ export class NestedWriteVisitor {
// visit payload
switch (action) {
case 'create':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, {});
let callbackResult: any;
if (this.callback.create) {
Expand Down Expand Up @@ -183,7 +183,7 @@ export class NestedWriteVisitor {
break;

case 'connectOrCreate':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.connectOrCreate) {
Expand All @@ -198,7 +198,7 @@ export class NestedWriteVisitor {

case 'connect':
if (this.callback.connect) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item, true);
await this.callback.connect(model, item, newContext);
}
Expand All @@ -210,7 +210,7 @@ export class NestedWriteVisitor {
// if relation is to-many, the payload is a unique filter object
// if relation is to-one, the payload can only be boolean `true`
if (this.callback.disconnect) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item, typeof item === 'object');
await this.callback.disconnect(model, item, newContext);
}
Expand All @@ -225,7 +225,7 @@ export class NestedWriteVisitor {
break;

case 'update':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.update) {
Expand All @@ -244,7 +244,7 @@ export class NestedWriteVisitor {
break;

case 'updateMany':
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.updateMany) {
Expand All @@ -258,7 +258,7 @@ export class NestedWriteVisitor {
break;

case 'upsert': {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, item.where);
let callbackResult: any;
if (this.callback.upsert) {
Expand All @@ -278,7 +278,7 @@ export class NestedWriteVisitor {

case 'delete': {
if (this.callback.delete) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
await this.callback.delete(model, item, newContext);
}
Expand All @@ -288,7 +288,7 @@ export class NestedWriteVisitor {

case 'deleteMany':
if (this.callback.deleteMany) {
for (const item of enumerate(data)) {
for (const item of this.enumerateReverse(data)) {
const newContext = pushNewContext(field, model, toplevel ? item.where : item);
await this.callback.deleteMany(model, item, newContext);
}
Expand Down Expand Up @@ -336,4 +336,16 @@ export class NestedWriteVisitor {
}
}
}

// enumerate a (possible) array in reverse order, so that the enumeration
// callback can safely delete the current item
private *enumerateReverse(data: any) {
if (Array.isArray(data)) {
for (let i = data.length - 1; i >= 0; i--) {
yield data[i];
}
} else {
yield data;
}
}
}
53 changes: 35 additions & 18 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,29 +343,19 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
}

if (context.parent.connect) {
// if the payload parent already has a "connect" clause, merge it
if (Array.isArray(context.parent.connect)) {
context.parent.connect.push(args.where);
} else {
context.parent.connect = [context.parent.connect, args.where];
}
} else {
// otherwise, create a new "connect" clause
context.parent.connect = args.where;
}
this.mergeToParent(context.parent, 'connect', args.where);
// record the key of connected entities so we can avoid validating them later
connectedEntities.add(getEntityKey(model, existing));
} else {
// create case
pushIdFields(model, context);

// create a new "create" clause at the parent level
context.parent.create = args.create;
this.mergeToParent(context.parent, 'create', args.create);
}

// remove the connectOrCreate clause
delete context.parent['connectOrCreate'];
this.removeFromParent(context.parent, 'connectOrCreate', args);

// return false to prevent visiting the nested payload
return false;
Expand Down Expand Up @@ -895,7 +885,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await _create(model, args, context);

// remove it from the update payload
delete context.parent.create;
this.removeFromParent(context.parent, 'create', args);

// don't visit payload
return false;
Expand Down Expand Up @@ -928,22 +918,23 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await _registerPostUpdateCheck(model, uniqueFilter);

// convert upsert to update
context.parent.update = {
const convertedUpdate = {
where: args.where,
data: this.validateUpdateInputSchema(model, args.update),
};
delete context.parent.upsert;
this.mergeToParent(context.parent, 'update', convertedUpdate);
this.removeFromParent(context.parent, 'upsert', args);

// continue visiting the new payload
return context.parent.update;
return convertedUpdate;
} else {
// create case

// process the entire create subtree separately
await _create(model, args.create, context);

// remove it from the update payload
delete context.parent.upsert;
this.removeFromParent(context.parent, 'upsert', args);

// don't visit payload
return false;
Expand Down Expand Up @@ -1390,5 +1381,31 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
return requireField(this.modelMeta, fieldInfo.type, fieldInfo.backLink);
}

private mergeToParent(parent: any, key: string, value: any) {
if (parent[key]) {
if (Array.isArray(parent[key])) {
parent[key].push(value);
} else {
parent[key] = [parent[key], value];
}
} else {
parent[key] = value;
}
}

private removeFromParent(parent: any, key: string, data: any) {
if (parent[key] === data) {
delete parent[key];
} else if (Array.isArray(parent[key])) {
const idx = parent[key].indexOf(data);
if (idx >= 0) {
parent[key].splice(idx, 1);
if (parent[key].length === 0) {
delete parent[key];
}
}
}
}

//#endregion
}
25 changes: 14 additions & 11 deletions tests/integration/tests/regression/issue-1078.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { loadSchema } from '@zenstackhq/testtools';

describe('issue 1078', () => {
it('regression', async () => {
const { prisma, enhance } = await loadSchema(
const { enhance } = await loadSchema(
`
model Counter {
id String @id
Expand All @@ -12,21 +12,25 @@ describe('issue 1078', () => {
@@validate(value >= 0)
@@allow('all', true)
}
}
`
);

const db = enhance();

const counter = await db.counter.create({
data: { id: '1', name: 'It should create', value: 1 },
});
await expect(
db.counter.create({
data: { id: '1', name: 'It should create', value: 1 },
})
).toResolveTruthy();

//! This query fails validation
const updated = await db.counter.update({
where: { id: '1' },
data: { name: 'It should update' },
});
await expect(
db.counter.update({
where: { id: '1' },
data: { name: 'It should update' },
})
).toResolveTruthy();
});

it('read', async () => {
Expand All @@ -37,8 +41,7 @@ describe('issue 1078', () => {
title String @allow('read', true, true)
content String
}
`,
{ logPrismaQuery: true }
`
);

const db = enhance();
Expand Down
133 changes: 133 additions & 0 deletions tests/integration/tests/regression/issue-1080.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('issue 1080', () => {
it('regression', async () => {
const { enhance } = await loadSchema(
`
model Project {
id String @id @unique @default(uuid())
Fields Field[]
@@allow('all', true)
}
model Field {
id String @id @unique @default(uuid())
name String
Project Project @relation(fields: [projectId], references: [id])
projectId String
@@allow('all', true)
}
`,
{ logPrismaQuery: true }
);

const db = enhance();

const project = await db.project.create({
include: { Fields: true },
data: {
Fields: {
create: [{ name: 'first' }, { name: 'second' }],
},
},
});

let updated = await db.project.update({
where: { id: project.id },
include: { Fields: true },
data: {
Fields: {
upsert: [
{
where: { id: project.Fields[0].id },
create: { name: 'first1' },
update: { name: 'first1' },
},
{
where: { id: project.Fields[1].id },
create: { name: 'second1' },
update: { name: 'second1' },
},
],
},
},
});
expect(updated).toMatchObject({
Fields: expect.arrayContaining([
expect.objectContaining({ name: 'first1' }),
expect.objectContaining({ name: 'second1' }),
]),
});

updated = await db.project.update({
where: { id: project.id },
include: { Fields: true },
data: {
Fields: {
upsert: {
where: { id: project.Fields[0].id },
create: { name: 'first2' },
update: { name: 'first2' },
},
},
},
});
expect(updated).toMatchObject({
Fields: expect.arrayContaining([
expect.objectContaining({ name: 'first2' }),
expect.objectContaining({ name: 'second1' }),
]),
});

updated = await db.project.update({
where: { id: project.id },
include: { Fields: true },
data: {
Fields: {
upsert: {
where: { id: project.Fields[0].id },
create: { name: 'first3' },
update: { name: 'first3' },
},
update: {
where: { id: project.Fields[1].id },
data: { name: 'second3' },
},
},
},
});
expect(updated).toMatchObject({
Fields: expect.arrayContaining([
expect.objectContaining({ name: 'first3' }),
expect.objectContaining({ name: 'second3' }),
]),
});

updated = await db.project.update({
where: { id: project.id },
include: { Fields: true },
data: {
Fields: {
upsert: {
where: { id: 'non-exist' },
create: { name: 'third1' },
update: { name: 'third1' },
},
update: {
where: { id: project.Fields[1].id },
data: { name: 'second4' },
},
},
},
});
expect(updated).toMatchObject({
Fields: expect.arrayContaining([
expect.objectContaining({ name: 'first3' }),
expect.objectContaining({ name: 'second4' }),
expect.objectContaining({ name: 'third1' }),
]),
});
});
});

0 comments on commit 9daff09

Please sign in to comment.