Skip to content

Commit

Permalink
refactor: prompt system (run-llama#1154)
Browse files Browse the repository at this point in the history
  • Loading branch information
himself65 authored Sep 6, 2024
1 parent 11b3856 commit 0148354
Show file tree
Hide file tree
Showing 48 changed files with 1,309 additions and 823 deletions.
10 changes: 10 additions & 0 deletions .changeset/tall-kangaroos-sleep.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
"@llamaindex/core": patch
"llamaindex": patch
"@llamaindex/core-tests": patch
"llamaindex-loader-example": patch
---

refactor: prompt system

Add `PromptTemplate` module with strong type check.
13 changes: 7 additions & 6 deletions examples/prompts/promptMixin.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import {
Document,
PromptTemplate,
ResponseSynthesizer,
TreeSummarize,
TreeSummarizePrompt,
VectorStoreIndex,
} from "llamaindex";

const treeSummarizePrompt: TreeSummarizePrompt = ({ context, query }) => {
return `Context information from multiple sources is below.
const treeSummarizePrompt: TreeSummarizePrompt = new PromptTemplate({
template: `Context information from multiple sources is below.
---------------------
${context}
{context}
---------------------
Given the information from multiple sources and not prior knowledge.
Answer the query in the style of a Shakespeare play"
Query: ${query}
Answer:`;
};
Query: {query}
Answer:`,
});

async function main() {
const documents = new Document({
Expand Down
14 changes: 8 additions & 6 deletions examples/readers/src/csv.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
CompactAndRefine,
OpenAI,
PromptTemplate,
ResponseSynthesizer,
Settings,
VectorStoreIndex,
Expand All @@ -18,14 +19,15 @@ async function main() {
// Split text and create embeddings. Store them in a VectorStoreIndex
const index = await VectorStoreIndex.fromDocuments(documents);

const csvPrompt = ({ context = "", query = "" }) => {
return `The following CSV file is loaded from ${path}
const csvPrompt = new PromptTemplate({
templateVars: ["query", "context"],
template: `The following CSV file is loaded from ${path}
\`\`\`csv
${context}
{context}
\`\`\`
Given the CSV file, generate me Typescript code to answer the question: ${query}. You can use built in NodeJS functions but avoid using third party libraries.
`;
};
Given the CSV file, generate me Typescript code to answer the question: {query}. You can use built in NodeJS functions but avoid using third party libraries.
`,
});

const responseSynthesizer = new ResponseSynthesizer({
responseBuilder: new CompactAndRefine(undefined, csvPrompt),
Expand Down
5 changes: 4 additions & 1 deletion examples/readers/src/llamaparse-json.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {
ImageNode,
LlamaParseReader,
OpenAI,
PromptTemplate,
VectorStoreIndex,
} from "llamaindex";
import { createMessageContent } from "llamaindex/synthesizers/utils";
Expand Down Expand Up @@ -50,7 +51,9 @@ async function getImageTextDocs(

for (const imageDict of imageDicts) {
const imageDoc = new ImageNode({ image: imageDict.path });
const prompt = () => `Describe the image as alt text`;
const prompt = new PromptTemplate({
template: `Describe the image as alt text`,
});
const message = await createMessageContent(prompt, [imageDoc]);

const response = await llm.complete({
Expand Down
17 changes: 16 additions & 1 deletion packages/core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@
"types": "./dist/utils/index.d.ts",
"default": "./dist/utils/index.js"
}
},
"./prompts": {
"require": {
"types": "./dist/prompts/index.d.cts",
"default": "./dist/prompts/index.cjs"
},
"import": {
"types": "./dist/prompts/index.d.ts",
"default": "./dist/prompts/index.js"
},
"default": {
"types": "./dist/prompts/index.d.ts",
"default": "./dist/prompts/index.js"
}
}
},
"files": [
Expand All @@ -132,7 +146,8 @@
"devDependencies": {
"ajv": "^8.17.1",
"bunchee": "5.3.2",
"natural": "^8.0.1"
"natural": "^8.0.1",
"python-format-js": "^1.4.3"
},
"dependencies": {
"@llamaindex/env": "workspace:*",
Expand Down
225 changes: 225 additions & 0 deletions packages/core/src/prompts/base.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import format from "python-format-js";
import type { ChatMessage } from "../llms";
import type { BaseOutputParser, Metadata } from "../schema";
import { objectEntries } from "../utils";
import { PromptType } from "./prompt-type";

type MappingFn<TemplatesVar extends string[] = string[]> = (
options: Record<TemplatesVar[number], string>,
) => string;

export type BasePromptTemplateOptions<
TemplatesVar extends readonly string[],
Vars extends readonly string[],
> = {
metadata?: Metadata;
templateVars?:
| TemplatesVar
// loose type for better type inference
| readonly string[];
options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>;
outputParser?: BaseOutputParser;
templateVarMappings?: Partial<
Record<Vars[number] | (string & {}), TemplatesVar[number] | (string & {})>
>;
functionMappings?: Partial<
Record<TemplatesVar[number] | (string & {}), MappingFn>
>;
};

export abstract class BasePromptTemplate<
const TemplatesVar extends readonly string[] = string[],
const Vars extends readonly string[] = string[],
> {
metadata: Metadata = {};
templateVars: Set<string> = new Set();
options: Partial<Record<TemplatesVar[number] | (string & {}), string>> = {};
outputParser?: BaseOutputParser;
templateVarMappings: Partial<
Record<Vars[number] | (string & {}), TemplatesVar[number] | (string & {})>
> = {};
functionMappings: Partial<
Record<TemplatesVar[number] | (string & {}), MappingFn>
> = {};

protected constructor(
options: BasePromptTemplateOptions<TemplatesVar, Vars>,
) {
const {
metadata,
templateVars,
outputParser,
templateVarMappings,
functionMappings,
} = options;
if (metadata) {
this.metadata = metadata;
}
if (templateVars) {
this.templateVars = new Set(templateVars);
}
if (options.options) {
this.options = options.options;
}
this.outputParser = outputParser;
if (templateVarMappings) {
this.templateVarMappings = templateVarMappings;
}
if (functionMappings) {
this.functionMappings = functionMappings;
}
}

protected mapTemplateVars(
options: Record<TemplatesVar[number] | (string & {}), string>,
) {
const templateVarMappings = this.templateVarMappings;
return Object.fromEntries(
objectEntries(options).map(([k, v]) => [templateVarMappings[k] || k, v]),
);
}

protected mapFunctionVars(
options: Record<TemplatesVar[number] | (string & {}), string>,
) {
const functionMappings = this.functionMappings;
const newOptions = {} as Record<TemplatesVar[number], string>;
for (const [k, v] of objectEntries(functionMappings)) {
newOptions[k] = v!(options);
}

for (const [k, v] of objectEntries(options)) {
if (!(k in newOptions)) {
newOptions[k] = v;
}
}

return newOptions;
}

protected mapAllVars(
options: Record<TemplatesVar[number] | (string & {}), string>,
): Record<string, string> {
const newOptions = this.mapFunctionVars(options);
return this.mapTemplateVars(newOptions);
}

abstract partialFormat(
options: Partial<Record<TemplatesVar[number] | (string & {}), string>>,
): BasePromptTemplate<TemplatesVar, Vars>;

abstract format(
options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>,
): string;

abstract formatMessages(
options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>,
): ChatMessage[];

abstract get template(): string;
}

type Permutation<T, K = T> = [T] extends [never]
? []
: K extends K
? [K, ...Permutation<Exclude<T, K>>]
: never;

type Join<T extends any[], U extends string> = T extends [infer F, ...infer R]
? R["length"] extends 0
? `${F & string}`
: `${F & string}${U}${Join<R, U>}`
: never;

type WrapStringWithBracket<T extends string> = `{${T}}`;

export type StringTemplate<Var extends readonly string[]> =
Var["length"] extends 0
? string
: Var["length"] extends number
? number extends Var["length"]
? string
: `${string}${Join<Permutation<WrapStringWithBracket<Var[number]>>, `${string}`>}${string}`
: never;

export type PromptTemplateOptions<
TemplatesVar extends readonly string[],
Vars extends readonly string[],
Template extends StringTemplate<TemplatesVar>,
> = BasePromptTemplateOptions<TemplatesVar, Vars> & {
template: Template;
promptType?: PromptType;
};

export class PromptTemplate<
const TemplatesVar extends readonly string[] = string[],
const Vars extends readonly string[] = string[],
const Template extends
StringTemplate<TemplatesVar> = StringTemplate<TemplatesVar>,
> extends BasePromptTemplate<TemplatesVar, Vars> {
#template: Template;
promptType: PromptType;

constructor(options: PromptTemplateOptions<TemplatesVar, Vars, Template>) {
const { template, promptType, ...rest } = options;
super(rest);
this.#template = template;
this.promptType = promptType ?? PromptType.custom;
}

partialFormat(
options: Partial<Record<TemplatesVar[number] | (string & {}), string>>,
): PromptTemplate<TemplatesVar, Vars, Template> {
const prompt = new PromptTemplate({
template: this.template,
templateVars: [...this.templateVars],
options: this.options,
outputParser: this.outputParser,
templateVarMappings: this.templateVarMappings,
functionMappings: this.functionMappings,
metadata: this.metadata,
promptType: this.promptType,
});

prompt.options = {
...prompt.options,
...options,
};

return prompt;
}

format(
options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>,
): string {
const allOptions = {
...this.options,
...options,
} as Record<TemplatesVar[number], string>;

const mappedAllOptions = this.mapAllVars(allOptions);

const prompt = format(this.template, mappedAllOptions);

if (this.outputParser) {
return this.outputParser.format(prompt);
}
return prompt;
}

formatMessages(
options?: Partial<Record<TemplatesVar[number] | (string & {}), string>>,
): ChatMessage[] {
const prompt = this.format(options);
return [
{
role: "user",
content: prompt,
},
];
}

get template(): Template {
return this.#template;
}
}
33 changes: 33 additions & 0 deletions packages/core/src/prompts/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
export { BasePromptTemplate, PromptTemplate } from "./base";
export type {
BasePromptTemplateOptions,
PromptTemplateOptions,
StringTemplate,
} from "./base";
export { PromptMixin, type ModuleRecord, type PromptsRecord } from "./mixin";
export {
anthropicSummaryPrompt,
anthropicTextQaPrompt,
defaultChoiceSelectPrompt,
defaultCondenseQuestionPrompt,
defaultContextSystemPrompt,
defaultKeywordExtractPrompt,
defaultQueryKeywordExtractPrompt,
defaultRefinePrompt,
defaultSubQuestionPrompt,
defaultSummaryPrompt,
defaultTextQAPrompt,
defaultTreeSummarizePrompt,
} from "./prompt";
export type {
ChoiceSelectPrompt,
CondenseQuestionPrompt,
ContextSystemPrompt,
KeywordExtractPrompt,
QueryKeywordExtractPrompt,
RefinePrompt,
SubQuestionPrompt,
SummaryPrompt,
TextQAPrompt,
TreeSummarizePrompt,
} from "./prompt";
Loading

0 comments on commit 0148354

Please sign in to comment.