Skip to content

Commit

Permalink
feat: allow for sync references to prompt files (#696)
Browse files Browse the repository at this point in the history
* feat: allow for sync references to prompt files

* docs: updated dotprompt docs with sync promptRef
  • Loading branch information
cabljac authored Jul 29, 2024
1 parent 04577c9 commit 72e0629
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 14 deletions.
18 changes: 9 additions & 9 deletions docs/dotprompt.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,19 @@ You are the world's most welcoming AI assistant and are currently working at {{l
Greet a guest{{#if name}} named {{name}}{{/if}}{{#if style}} in the style of {{style}}{{/if}}.
```

To use this prompt, install the `dotprompt` plugin, and import the `prompt` function from
To use this prompt, install the `dotprompt` plugin, and import the `promptRef` function from
the `@genkit-ai/dotprompt` library:

```ts
import { dotprompt, prompt } from '@genkit-ai/dotprompt';
import { dotprompt, promptRef } from '@genkit-ai/dotprompt';

configureGenkit({ plugins: [dotprompt()] });
```

Then, load the prompt using `prompt('file_name')`:
Then, load the prompt using `promptRef('file_name')`:

```ts
const greetingPrompt = await prompt('greeting');
const greetingPrompt = promptRef('greeting');

const result = await greetingPrompt.generate({
input: {
Expand Down Expand Up @@ -176,9 +176,9 @@ registered Zod schema. You can then utilize the schema to strongly type the
output of a Dotprompt:

```ts
import { prompt } from "@genkit-ai/dotprompt";
import { promptRef } from "@genkit-ai/dotprompt";
const myPrompt = await prompt("myPrompt");
const myPrompt = promptRef("myPrompt");
const result = await myPrompt.generate<typeof MySchema>({...});
Expand Down Expand Up @@ -229,7 +229,7 @@ When generating a prompt with structured output, use the `output()` helper to
retrieve and validate it:

```ts
const createMenuPrompt = await prompt('create_menu');
const createMenuPrompt = promptRef('create_menu');
const menu = await createMenuPrompt.generate({
input: {
Expand Down Expand Up @@ -332,7 +332,7 @@ The URL can be `https://` or base64-encoded `data:` URIs for "inline" image
usage. In code, this would be:

```ts
const describeImagePrompt = await prompt('describe_image');
const describeImagePrompt = promptRef('describe_image');
const result = await describeImagePrompt.generate({
input: {
Expand Down Expand Up @@ -425,7 +425,7 @@ Pro would perform better, you might create two files:
To use a prompt variant, specify the `variant` option when loading:

```ts
const myPrompt = await prompt('my_prompt', { variant: 'gemini15pro' });
const myPrompt = promptRef('my_prompt', { variant: 'gemini15pro' });
```

The name of the variant is included in the metadata of generation traces, so you
Expand Down
9 changes: 8 additions & 1 deletion js/plugins/dotprompt/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {

import { readFileSync } from 'fs';
import { basename } from 'path';
import { defineDotprompt, Dotprompt } from './prompt.js';
import { defineDotprompt, Dotprompt, DotpromptRef } from './prompt.js';
import { loadPromptFolder, lookupPrompt } from './registry.js';

export { defineHelper, definePartial } from './template.js';
Expand Down Expand Up @@ -57,6 +57,13 @@ export async function prompt<Variables = unknown>(
return (await lookupPrompt(name, options?.variant)) as Dotprompt<Variables>;
}

export function promptRef<Variables = unknown>(
name: string,
options?: { variant?: string; dir?: string }
): DotpromptRef<Variables> {
return new DotpromptRef(name, options);
}

export function loadPromptFile(path: string): Dotprompt {
return Dotprompt.parse(
basename(path).split('.')[0],
Expand Down
47 changes: 45 additions & 2 deletions js/plugins/dotprompt/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import {
toFrontmatter,
toMetadata,
} from './metadata.js';
import { registryDefinitionKey } from './registry.js';
import { lookupPrompt, registryDefinitionKey } from './registry.js';
import { compile } from './template.js';

export type PromptData = PromptFrontmatter & { template: string };
Expand Down Expand Up @@ -175,7 +175,7 @@ export class Dotprompt<Variables = unknown> implements PromptMetadata {
});
return {
model: options.model || this.model!,
config: { ...this.config, ...options.config } || {},
config: { ...this.config, ...options.config },
history: messages.slice(0, messages.length - 1),
prompt: messages[messages.length - 1].content,
context: options.context,
Expand Down Expand Up @@ -210,6 +210,49 @@ export class Dotprompt<Variables = unknown> implements PromptMetadata {
}
}

export class DotpromptRef<Variables = unknown> {
name: string;
variant?: string;
dir?: string;
private _prompt?: Dotprompt<Variables>;

constructor(
name: string,
options?: {
variant?: string;
dir?: string;
}
) {
this.name = name;
this.variant = options?.variant;
this.dir = options?.dir;
}

async loadPrompt(): Promise<Dotprompt<Variables>> {
if (this._prompt) return this._prompt;
this._prompt = (await lookupPrompt(
this.name,
this.variant,
this.dir
)) as Dotprompt<Variables>;
return this._prompt;
}

async generate<O extends z.ZodTypeAny = z.ZodTypeAny>(
opt: PromptGenerateOptions<Variables>
): Promise<GenerateResponse<z.infer<O>>> {
const prompt = await this.loadPrompt();
return prompt.generate<O>(opt);
}

async render<O extends z.ZodTypeAny = z.ZodTypeAny>(
opt: PromptGenerateOptions<Variables>
): Promise<GenerateOptions<z.ZodTypeAny, O>> {
const prompt = await this.loadPrompt();
return prompt.render<O>(opt);
}
}

export function defineDotprompt<V extends z.ZodTypeAny = z.ZodTypeAny>(
options: PromptMetadata<V>,
template: string
Expand Down
127 changes: 126 additions & 1 deletion js/plugins/dotprompt/tests/prompt_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import { toJsonSchema, ValidationError } from '@genkit-ai/core/schema';
import z from 'zod';
import { registerPluginProvider } from '../../../core/src/registry.js';
import { defineJsonSchema, defineSchema } from '../../../core/src/schema.js';
import { defineDotprompt, Dotprompt, prompt } from '../src/index.js';
import { defineDotprompt, Dotprompt, prompt, promptRef } from '../src/index.js';
import { PromptMetadata } from '../src/metadata.js';

function registerDotprompt() {
Expand Down Expand Up @@ -251,3 +251,128 @@ output:
});
});
});

describe('DotpromptRef', () => {
it('Should load a prompt correctly', async () => {
registerDotprompt();
defineDotprompt(
{
name: 'promptName',
model: 'echo',
},
`This is a prompt.`
);

const ref = promptRef('promptName');

const p = await ref.loadPrompt();

const isDotprompt = p instanceof Dotprompt;

assert.equal(isDotprompt, true);
assert.equal(p.template, 'This is a prompt.');
});

it('Should generate output correctly using DotpromptRef', async () => {
registerDotprompt();
defineDotprompt(
{
name: 'generatePrompt',
model: 'echo',
},
`Hello {{name}}, this is a test prompt.`
);

const ref = promptRef('generatePrompt');
const response = await ref.generate({ input: { name: 'Alice' } });

assert.equal(response.text(), 'Hello Alice, this is a test prompt.');
});

it('Should render correctly using DotpromptRef', async () => {
registerDotprompt();
defineDotprompt(
{
name: 'renderPrompt',
model: 'echo',
},
`Hi {{name}}, welcome to the system.`
);

const ref = promptRef('renderPrompt');
const rendered = await ref.render({ input: { name: 'Bob' } });

assert.deepStrictEqual(rendered.prompt, [
{ text: 'Hi Bob, welcome to the system.' },
]);
});

it('Should handle invalid schema input in DotpromptRef', async () => {
registerDotprompt();
defineDotprompt(
{
name: 'invalidSchemaPromptRef',
model: 'echo',
input: {
jsonSchema: {
properties: { foo: { type: 'boolean' } },
required: ['foo'],
},
},
},
`This is the prompt with foo={{foo}}.`
);

const ref = promptRef('invalidSchemaPromptRef');

await assert.rejects(async () => {
await ref.generate({ input: { foo: 'not_a_boolean' } });
}, ValidationError);
});

it('Should support streamingCallback in DotpromptRef', async () => {
registerDotprompt();
defineDotprompt(
{
name: 'streamingCallbackPrompt',
model: 'echo',
},
`Hello {{name}}, streaming test.`
);

const ref = promptRef('streamingCallbackPrompt');

const streamingCallback = (chunk) => console.log(chunk);
const options = {
input: { name: 'Charlie' },
streamingCallback,
returnToolRequests: true,
};

const rendered = await ref.render(options);

assert.strictEqual(rendered.streamingCallback, streamingCallback);
assert.strictEqual(rendered.returnToolRequests, true);
});

it('Should cache loaded prompt in DotpromptRef', async () => {
registerDotprompt();
defineDotprompt(
{
name: 'cacheTestPrompt',
model: 'echo',
},
`This is a prompt for cache test.`
);

const ref = promptRef('cacheTestPrompt');
const firstLoad = await ref.loadPrompt();
const secondLoad = await ref.loadPrompt();

assert.strictEqual(
firstLoad,
secondLoad,
'Loaded prompts should be identical (cached).'
);
});
});
2 changes: 1 addition & 1 deletion js/plugins/dotprompt/tests/template_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ describe('compile', () => {
],
},
{
should: 'insert a blank ouptut section when helper provided',
should: 'insert a blank output section when helper provided',
input: {},
template: `before{{section "output"}}after`,
options: {
Expand Down

0 comments on commit 72e0629

Please sign in to comment.