Skip to content

Commit

Permalink
feat(hooks): add "portable" generation mode
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 committed Nov 14, 2024
1 parent 716091e commit 7b81985
Show file tree
Hide file tree
Showing 17 changed files with 266 additions and 37 deletions.
3 changes: 1 addition & 2 deletions packages/plugins/swr/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ function generateModelHooks(
const fileName = paramCase(model.name);
const sf = project.createSourceFile(path.join(outDir, `${fileName}.ts`), undefined, { overwrite: true });

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(outDir, options);
sf.addImportDeclaration({
namedImports: ['Prisma'],
Expand Down Expand Up @@ -261,6 +259,7 @@ function generateIndex(project: Project, outDir: string, models: DataModel[]) {
const sf = project.createSourceFile(path.join(outDir, 'index.ts'), undefined, { overwrite: true });
sf.addStatements(models.map((d) => `export * from './${paramCase(d.name)}';`));
sf.addStatements(`export { Provider } from '@zenstackhq/swr/runtime';`);
sf.addStatements(`export { default as metadata } from './__model_meta';`);
}

function generateQueryHook(
Expand Down
36 changes: 33 additions & 3 deletions packages/plugins/tanstack-query/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
import { DataModel, DataModelFieldType, Model, isEnum, isTypeDef } from '@zenstackhq/sdk/ast';
import { getPrismaClientImportSpec, supportCreateMany, type DMMF } from '@zenstackhq/sdk/prisma';
import { paramCase } from 'change-case';
import fs from 'fs';
import { lowerCaseFirst } from 'lower-case-first';
import path from 'path';
import { Project, SourceFile, VariableDeclarationKind } from 'ts-morph';
Expand Down Expand Up @@ -45,6 +46,14 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
outDir = resolvePath(outDir, options);
ensureEmptyDir(outDir);

if (options.portable && typeof options.portable !== 'boolean') {
throw new PluginError(
name,
`Invalid value for "portable" option: ${options.portable}, a boolean value is expected`
);
}
const portable = options.portable ?? false;

await generateModelMeta(project, models, typeDefs, {
output: path.join(outDir, '__model_meta.ts'),
generateAttributes: false,
Expand All @@ -61,6 +70,10 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
generateModelHooks(target, version, project, outDir, dataModel, mapping, options);
});

if (portable) {
generateBundledTypes(project, outDir, options);
}

await saveProject(project);
return { warnings };
}
Expand Down Expand Up @@ -333,9 +346,7 @@ function generateModelHooks(
const fileName = paramCase(model.name);
const sf = project.createSourceFile(path.join(outDir, `${fileName}.ts`), undefined, { overwrite: true });

sf.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(outDir, options);
const prismaImport = options.portable ? './__types' : getPrismaClientImportSpec(outDir, options);
sf.addImportDeclaration({
namedImports: ['Prisma', model.name],
isTypeOnly: true,
Expand Down Expand Up @@ -584,6 +595,7 @@ function generateIndex(
sf.addStatements(`export { SvelteQueryContextKey, setHooksContext } from '${runtimeImportBase}/svelte';`);
break;
}
sf.addStatements(`export { default as metadata } from './__model_meta';`);
}

function makeGetContext(target: TargetFramework) {
Expand Down Expand Up @@ -724,3 +736,21 @@ function makeMutationOptions(target: string, returnType: string, argsType: strin
function makeRuntimeImportBase(version: TanStackVersion) {
return `@zenstackhq/tanstack-query/runtime${version === 'v5' ? '-v5' : ''}`;
}

function generateBundledTypes(project: Project, outDir: string, options: PluginOptions) {
if (!options.prismaClientDtsPath) {
throw new PluginError(name, `Unable to determine the location of PrismaClient types`);
}

// copy PrismaClient index.d.ts
const content = fs.readFileSync(options.prismaClientDtsPath, 'utf-8');
project.createSourceFile(path.join(outDir, '__types.d.ts'), content, { overwrite: true });

// "runtime/library.d.ts" is referenced by Prisma's DTS, and it's generated into Prisma's output
// folder if a custom output is specified; if not, it's referenced from '@prisma/client'
const libraryDts = path.join(path.dirname(options.prismaClientDtsPath), 'runtime', 'library.d.ts');
if (fs.existsSync(libraryDts)) {
const content = fs.readFileSync(libraryDts, 'utf-8');
project.createSourceFile(path.join(outDir, 'runtime', 'library.d.ts'), content, { overwrite: true });
}
}
153 changes: 153 additions & 0 deletions packages/plugins/tanstack-query/tests/portable.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/// <reference types="@types/jest" />

import { loadSchema, normalizePath } from '@zenstackhq/testtools';
import path from 'path';
import tmp from 'tmp';

describe('Tanstack Query Plugin Portable Tests', () => {
it('supports portable for standard prisma client', async () => {
await loadSchema(
`
plugin tanstack {
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
output = '$projectRoot/hooks'
target = 'react'
portable = true
}
model User {
id Int @id @default(autoincrement())
email String
posts Post[]
}
model Post {
id Int @id @default(autoincrement())
title String
author User @relation(fields: [authorId], references: [id])
authorId Int
}
`,
{
provider: 'postgresql',
pushDb: false,
extraDependencies: ['react@18.2.0', '@types/react@18.2.0', '@tanstack/react-query@5.56.x'],
copyDependencies: [path.resolve(__dirname, '../dist')],
compile: true,
extraSourceFiles: [
{
name: 'main.ts',
content: `
import { useFindUniqueUser } from './hooks';
const { data } = useFindUniqueUser({ where: { id: 1 }, include: { posts: true } });
console.log(data?.email);
console.log(data?.posts[0].title);
`,
},
],
}
);
});

it('supports portable for custom prisma client output', async () => {
const t = tmp.dirSync({ unsafeCleanup: true });
const projectDir = t.name;

await loadSchema(
`
datasource db {
provider = 'postgresql'
url = env('DATABASE_URL')
}
generator client {
provider = 'prisma-client-js'
output = '$projectRoot/myprisma'
}
plugin tanstack {
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
output = '$projectRoot/hooks'
target = 'react'
portable = true
}
model User {
id Int @id @default(autoincrement())
email String
posts Post[]
}
model Post {
id Int @id @default(autoincrement())
title String
author User @relation(fields: [authorId], references: [id])
authorId Int
}
`,
{
provider: 'postgresql',
pushDb: false,
extraDependencies: ['react@18.2.0', '@types/react@18.2.0', '@tanstack/react-query@5.56.x'],
copyDependencies: [path.resolve(__dirname, '../dist')],
compile: true,
addPrelude: false,
projectDir,
prismaLoadPath: `${projectDir}/myprisma`,
extraSourceFiles: [
{
name: 'main.ts',
content: `
import { useFindUniqueUser } from './hooks';
const { data } = useFindUniqueUser({ where: { id: 1 }, include: { posts: true } });
console.log(data?.email);
console.log(data?.posts[0].title);
`,
},
],
}
);
});

it('supports portable for logical client', async () => {
await loadSchema(
`
plugin tanstack {
provider = '${normalizePath(path.resolve(__dirname, '../dist'))}'
output = '$projectRoot/hooks'
target = 'react'
portable = true
}
model Base {
id Int @id @default(autoincrement())
createdAt DateTime @default(now())
type String
@@delegate(type)
}
model User extends Base {
email String
}
`,
{
provider: 'postgresql',
pushDb: false,
extraDependencies: ['react@18.2.0', '@types/react@18.2.0', '@tanstack/react-query@5.56.x'],
copyDependencies: [path.resolve(__dirname, '../dist')],
compile: true,
extraSourceFiles: [
{
name: 'main.ts',
content: `
import { useFindUniqueUser } from './hooks';
const { data } = useFindUniqueUser({ where: { id: 1 } });
console.log(data?.email);
console.log(data?.createdAt);
`,
},
],
}
);
});
});
2 changes: 0 additions & 2 deletions packages/plugins/trpc/src/client-helper/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ export function generateClientTypingForModel(
}
);

sf.addStatements([`/* eslint-disable */`]);

generateImports(clientType, sf, options, version);

// generate a `ClientType` interface that contains typing for query/mutation operations
Expand Down
5 changes: 0 additions & 5 deletions packages/plugins/trpc/src/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ function createAppRouter(
overwrite: true,
});

appRouter.addStatements('/* eslint-disable */');

const prismaImport = getPrismaClientImportSpec(path.dirname(indexFile), options);

if (version === 'v10') {
Expand Down Expand Up @@ -274,8 +272,6 @@ function generateModelCreateRouter(
overwrite: true,
});

modelRouter.addStatements('/* eslint-disable */');

if (version === 'v10') {
modelRouter.addImportDeclarations([
{
Expand Down Expand Up @@ -386,7 +382,6 @@ function createHelper(outDir: string) {
overwrite: true,
});

sf.addStatements('/* eslint-disable */');
sf.addStatements(`import { TRPCError } from '@trpc/server';`);
sf.addStatements(`import { isPrismaClientKnownRequestError } from '${RUNTIME_PACKAGE}';`);

Expand Down
7 changes: 5 additions & 2 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ export class PluginRunner {
let dmmf: DMMF.Document | undefined = undefined;
let shortNameMap: Map<string, string> | undefined;
let prismaClientPath = '@prisma/client';
let prismaClientDtsPath: string | undefined = undefined;

const project = createProject();
for (const { name, description, run, options: pluginOptions } of corePlugins) {
const options = { ...pluginOptions, prismaClientPath };
Expand Down Expand Up @@ -165,6 +167,7 @@ export class PluginRunner {
if (r.prismaClientPath) {
// use the prisma client path returned by the plugin
prismaClientPath = r.prismaClientPath;
prismaClientDtsPath = r.prismaClientDtsPath;
}
}

Expand All @@ -173,13 +176,13 @@ export class PluginRunner {

// run user plugins
for (const { name, description, run, options: pluginOptions } of userPlugins) {
const options = { ...pluginOptions, prismaClientPath };
const options = { ...pluginOptions, prismaClientPath, prismaClientDtsPath };
const r = await this.runPlugin(
name,
description,
run,
runnerOptions,
options,
options as PluginOptions,
dmmf,
shortNameMap,
project,
Expand Down
9 changes: 7 additions & 2 deletions packages/schema/src/plugins/enhancer/enhance/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ export class EnhancerGenerator {
private readonly outDir: string
) {}

async generate(): Promise<{ dmmf: DMMF.Document | undefined }> {
async generate(): Promise<{ dmmf: DMMF.Document | undefined; newPrismaClientDtsPath: string | undefined }> {
let dmmf: DMMF.Document | undefined;

const prismaImport = getPrismaClientImportSpec(this.outDir, this.options);
Expand Down Expand Up @@ -128,7 +128,12 @@ ${
await this.saveSourceFile(enhanceTs);
}

return { dmmf };
return {
dmmf,
newPrismaClientDtsPath: prismaTypesFixed
? path.resolve(this.outDir, LOGICAL_CLIENT_GENERATION_PATH, 'index-fixed.d.ts')
: undefined,
};
}

private getZodImport() {
Expand Down
4 changes: 2 additions & 2 deletions packages/schema/src/plugins/enhancer/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const run: PluginFunction = async (model, options, _dmmf, globalOptions) => {

await generateModelMeta(model, options, project, outDir);
await generatePolicy(model, options, project, outDir);
const { dmmf } = await new EnhancerGenerator(model, options, project, outDir).generate();
const { dmmf, newPrismaClientDtsPath } = await new EnhancerGenerator(model, options, project, outDir).generate();

let prismaClientPath: string | undefined;
if (dmmf) {
Expand All @@ -44,7 +44,7 @@ const run: PluginFunction = async (model, options, _dmmf, globalOptions) => {
}
}

return { dmmf, warnings: [], prismaClientPath };
return { dmmf, warnings: [], prismaClientPath, prismaClientDtsPath: newPrismaClientDtsPath };
};

export default run;
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ export class PolicyGenerator {

async generate(project: Project, model: Model, output: string) {
const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true });
sf.addStatements('/* eslint-disable */');

this.writeImports(model, output, sf);

Expand Down
Loading

0 comments on commit 7b81985

Please sign in to comment.