
strong types for prompts
yisding committed Aug 29, 2023
1 parent d1aa3b7 commit 259fe63
Showing 12 changed files with 92 additions and 76 deletions.
5 changes: 5 additions & 0 deletions .changeset/fair-pets-leave.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Strong types for prompts.
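The whole change in one standalone sketch: each default prompt becomes a plain destructuring function, and its type is derived with typeof, so replacement prompts are checked structurally. Plain TypeScript, no llamaindex imports; the misspelling example is hypothetical:

// The pattern applied throughout packages/core/src/Prompt.ts below.
const defaultTextQaPrompt = ({ context = "", query = "" }) => {
  return `Context information is below.
---------------------
${context}
---------------------
Query: ${query}
Answer:`;
};

// Derived type: (input: { context?: string; query?: string }) => string
type TextQaPrompt = typeof defaultTextQaPrompt;

// A custom prompt must match that shape, so a misspelled key such as
// `qery` is now a compile-time error instead of an empty template slot.
const myTextQaPrompt: TextQaPrompt = ({ context = "", query = "" }) =>
  `Using only this context:\n${context}\nanswer: ${query}`;

console.log(myTextQaPrompt({ context: "Paris is in France.", query: "Where is Paris?" }));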
5 changes: 1 addition & 4 deletions apps/simple/csv.ts
@@ -4,7 +4,6 @@ import {
PapaCSVReader,
ResponseSynthesizer,
serviceContextFromDefaults,
-  SimplePrompt,
VectorStoreIndex,
} from "llamaindex";

@@ -23,9 +22,7 @@ async function main() {
serviceContext,
});

-  const csvPrompt: SimplePrompt = (input) => {
-    const { context = "", query = "" } = input;
-
+  const csvPrompt = ({ context = "", query = "" }) => {
return `The following CSV file is loaded from ${path}
\`\`\`csv
${context}
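The collapsed remainder of this file passes csvPrompt into the response synthesizer. A hedged sketch of that wiring; the CompactAndRefine constructor shape and the ResponseSynthesizer options are assumptions about this era of the API, not shown in the diff:

import {
  CompactAndRefine,
  ResponseSynthesizer,
  serviceContextFromDefaults,
} from "llamaindex";

const serviceContext = serviceContextFromDefaults({});

// Same shape the diff infers for csvPrompt: { context?: string; query?: string }.
const csvPrompt = ({ context = "", query = "" }) =>
  `CSV rows:\n${context}\nQuery: ${query}\nAnswer:`;

// Assumption: CompactAndRefine takes (serviceContext, textQATemplate) and the
// synthesizer accepts a responseBuilder option, as in the full example file.
const responseSynthesizer = new ResponseSynthesizer({
  responseBuilder: new CompactAndRefine(serviceContext, csvPrompt),
});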
9 changes: 1 addition & 8 deletions apps/simple/openai.ts
@@ -1,14 +1,7 @@
import { OpenAI } from "llamaindex";

(async () => {
-  const llm = new OpenAI({
-    model: "gpt-3.5-turbo",
-    temperature: 0.1,
-    additionalChatOptions: { frequency_penalty: 0.1 },
-    additionalSessionOptions: {
-      defaultHeaders: { "X-Test-Header-Please-Ignore": "true" },
-    },
-  });
+  const llm = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 });

// complete api
const response1 = await llm.complete("How are you?");
2 changes: 1 addition & 1 deletion apps/simple/vectorIndexCustomize.ts
@@ -12,7 +12,7 @@ async function main() {
const document = new Document({ text: essay, id_: "essay" });

const serviceContext = serviceContextFromDefaults({
-    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 }),
+    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
});

const index = await VectorStoreIndex.fromDocuments([document], {
5 changes: 1 addition & 4 deletions examples/csv.ts
@@ -4,7 +4,6 @@ import {
PapaCSVReader,
ResponseSynthesizer,
serviceContextFromDefaults,
-  SimplePrompt,
VectorStoreIndex,
} from "llamaindex";

@@ -23,9 +22,7 @@ async function main() {
serviceContext,
});

-  const csvPrompt: SimplePrompt = (input) => {
-    const { context = "", query = "" } = input;
-
+  const csvPrompt = ({ context = "", query = "" }) => {
return `The following CSV file is loaded from ${path}
\`\`\`csv
${context}
6 changes: 4 additions & 2 deletions examples/openai.ts
@@ -2,12 +2,14 @@ import { OpenAI } from "llamaindex";

(async () => {
const llm = new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 });

// complete api
const response1 = await llm.complete("How are you?");
console.log(response1.message.content);

// chat api
-  const response2 = await llm.chat([{ content: "Tell me a joke!", role: "user" }]);
+  const response2 = await llm.chat([
+    { content: "Tell me a joke!", role: "user" },
+  ]);
console.log(response2.message.content);
})();
2 changes: 1 addition & 1 deletion examples/vectorIndexCustomize.ts
@@ -12,7 +12,7 @@ async function main() {
const document = new Document({ text: essay, id_: "essay" });

const serviceContext = serviceContextFromDefaults({
-    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.0 }),
+    llm: new OpenAI({ model: "gpt-3.5-turbo", temperature: 0.1 }),
});

const index = await VectorStoreIndex.fromDocuments([document], {
31 changes: 18 additions & 13 deletions packages/core/src/ChatEngine.ts
@@ -1,17 +1,18 @@
-import { ChatMessage, OpenAI, ChatResponse, LLM } from "./llm/LLM";
-import { v4 as uuidv4 } from "uuid";
import { TextNode } from "./Node";
import {
-  SimplePrompt,
-  contextSystemPrompt,
+  CondenseQuestionPrompt,
+  ContextSystemPrompt,
defaultCondenseQuestionPrompt,
+  defaultContextSystemPrompt,
messagesToHistoryStr,
} from "./Prompt";
import { BaseQueryEngine } from "./QueryEngine";
import { Response } from "./Response";
import { BaseRetriever } from "./Retriever";
import { ServiceContext, serviceContextFromDefaults } from "./ServiceContext";
+import { v4 as uuidv4 } from "uuid";
import { Event } from "./callbacks/CallbackManager";
+import { ChatMessage, LLM, OpenAI } from "./llm/LLM";

/**
* A ChatEngine is used to handle back and forth chats between the application and the LLM.
@@ -70,13 +71,13 @@ export class CondenseQuestionChatEngine implements ChatEngine {
queryEngine: BaseQueryEngine;
chatHistory: ChatMessage[];
serviceContext: ServiceContext;
-  condenseMessagePrompt: SimplePrompt;
+  condenseMessagePrompt: CondenseQuestionPrompt;

constructor(init: {
queryEngine: BaseQueryEngine;
chatHistory: ChatMessage[];
serviceContext?: ServiceContext;
-    condenseMessagePrompt?: SimplePrompt;
+    condenseMessagePrompt?: CondenseQuestionPrompt;
}) {
this.queryEngine = init.queryEngine;
this.chatHistory = init?.chatHistory ?? [];
@@ -92,14 +93,14 @@ export class CondenseQuestionChatEngine implements ChatEngine {
return this.serviceContext.llm.complete(
defaultCondenseQuestionPrompt({
question: question,
-        chat_history: chatHistoryStr,
-      })
+        chatHistory: chatHistoryStr,
+      }),
);
}

async chat(
message: string,
-    chatHistory?: ChatMessage[] | undefined
+    chatHistory?: ChatMessage[] | undefined,
): Promise<Response> {
chatHistory = chatHistory ?? this.chatHistory;

@@ -129,16 +130,20 @@ export class ContextChatEngine implements ChatEngine {
retriever: BaseRetriever;
chatModel: OpenAI;
chatHistory: ChatMessage[];
+  contextSystemPrompt: ContextSystemPrompt;

constructor(init: {
retriever: BaseRetriever;
chatModel?: OpenAI;
chatHistory?: ChatMessage[];
+    contextSystemPrompt?: ContextSystemPrompt;
}) {
this.retriever = init.retriever;
this.chatModel =
init.chatModel ?? new OpenAI({ model: "gpt-3.5-turbo-16k" });
this.chatHistory = init?.chatHistory ?? [];
+    this.contextSystemPrompt =
+      init?.contextSystemPrompt ?? defaultContextSystemPrompt;
}

async chat(message: string, chatHistory?: ChatMessage[] | undefined) {
@@ -151,11 +156,11 @@ export class ContextChatEngine implements ChatEngine {
};
const sourceNodesWithScore = await this.retriever.retrieve(
message,
-      parentEvent
+      parentEvent,
);

const systemMessage: ChatMessage = {
-      content: contextSystemPrompt({
+      content: this.contextSystemPrompt({
context: sourceNodesWithScore
.map((r) => (r.node as TextNode).text)
.join("\n\n"),
@@ -167,15 +172,15 @@

const response = await this.chatModel.chat(
[systemMessage, ...chatHistory],
-      parentEvent
+      parentEvent,
);
chatHistory.push(response.message);

this.chatHistory = chatHistory;

return new Response(
response.message.content,
-      sourceNodesWithScore.map((r) => r.node)
+      sourceNodesWithScore.map((r) => r.node),
);
}

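Net effect of the typed field: a custom context prompt handed to ContextChatEngine is now checked against ContextSystemPrompt at the constructor. A usage sketch; the root-level re-exports and the standalone retriever are assumptions:

import {
  ContextChatEngine,
  type BaseRetriever,
  type ContextSystemPrompt,
} from "llamaindex";

// Must match ({ context?: string }) => string, per defaultContextSystemPrompt.
const strictContextPrompt: ContextSystemPrompt = ({ context = "" }) =>
  `Answer ONLY from the context below.\n---\n${context}\n---`;

declare const retriever: BaseRetriever; // e.g. obtained from a VectorStoreIndex

const engine = new ContextChatEngine({
  retriever,
  contextSystemPrompt: strictContextPrompt, // rejected at compile time if the shape drifts
});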
53 changes: 30 additions & 23 deletions packages/core/src/Prompt.ts
@@ -22,9 +22,7 @@ DEFAULT_TEXT_QA_PROMPT_TMPL = (
)
*/

-export const defaultTextQaPrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
-
+export const defaultTextQaPrompt = ({ context = "", query = "" }) => {
return `Context information is below.
---------------------
${context}
@@ -34,6 +32,8 @@ Query: ${query}
Answer:`;
};

+export type TextQaPrompt = typeof defaultTextQaPrompt;
+
/*
DEFAULT_SUMMARY_PROMPT_TMPL = (
"Write a summary of the following. Try to use only the "
@@ -48,9 +48,7 @@ DEFAULT_SUMMARY_PROMPT_TMPL = (
)
*/

-export const defaultSummaryPrompt: SimplePrompt = (input) => {
-  const { context = "" } = input;
-
+export const defaultSummaryPrompt = ({ context = "" }) => {
return `Write a summary of the following. Try to use only the information provided. Try to include as many key details as possible.
@@ -61,6 +59,8 @@ SUMMARY:"""
`;
};

+export type SummaryPrompt = typeof defaultSummaryPrompt;
+
/*
DEFAULT_REFINE_PROMPT_TMPL = (
"The original query is as follows: {query_str}\n"
@@ -77,9 +77,11 @@ DEFAULT_REFINE_PROMPT_TMPL = (
)
*/

-export const defaultRefinePrompt: SimplePrompt = (input) => {
-  const { query = "", existingAnswer = "", context = "" } = input;
-
+export const defaultRefinePrompt = ({
+  query = "",
+  existingAnswer = "",
+  context = "",
+}) => {
return `The original query is as follows: ${query}
We have provided an existing answer: ${existingAnswer}
We have the opportunity to refine the existing answer (only if needed) with some more context below.
@@ -90,6 +92,8 @@ Given the new context, refine the original answer to better answer the query. If
Refined Answer:`;
};

+export type RefinePrompt = typeof defaultRefinePrompt;
+
/*
DEFAULT_TREE_SUMMARIZE_TMPL = (
"Context information from multiple sources is below.\n"
@@ -103,9 +107,7 @@ DEFAULT_TREE_SUMMARIZE_TMPL = (
)
*/

-export const defaultTreeSummarizePrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
-
+export const defaultTreeSummarizePrompt = ({ context = "", query = "" }) => {
return `Context information from multiple sources is below.
---------------------
${context}
@@ -115,9 +117,9 @@ Query: ${query}
Answer:`;
};

-export const defaultChoiceSelectPrompt: SimplePrompt = (input) => {
-  const { context = "", query = "" } = input;
+export type TreeSummarizePrompt = typeof defaultTreeSummarizePrompt;

+export const defaultChoiceSelectPrompt = ({ context = "", query = "" }) => {
return `A list of documents is shown below. Each document has a number next to it along
with a summary of the document. A question is also provided.
Respond with the numbers of the documents
@@ -149,6 +151,8 @@ Question: ${query}
Answer:`;
};

+export type ChoiceSelectPrompt = typeof defaultChoiceSelectPrompt;
+
/*
PREFIX = """\
Given a user question, and a list of tools, output a list of relevant sub-questions \
@@ -266,9 +270,7 @@ const exampleOutput: SubQuestion[] = [
},
];

-export const defaultSubQuestionPrompt: SimplePrompt = (input) => {
-  const { toolsStr, queryStr } = input;
-
+export const defaultSubQuestionPrompt = ({ toolsStr = "", queryStr = "" }) => {
return `Given a user question, and a list of tools, output a list of relevant sub-questions that when composed can help answer the full user question:
# Example 1
@@ -298,6 +300,8 @@
`;
};

+export type SubQuestionPrompt = typeof defaultSubQuestionPrompt;

// DEFAULT_TEMPLATE = """\
// Given a conversation (between Human and Assistant) and a follow up message from Human, \
// rewrite the message to be a standalone question that captures all relevant context \
@@ -312,9 +316,10 @@
// <Standalone question>
// """

-export const defaultCondenseQuestionPrompt: SimplePrompt = (input) => {
-  const { chatHistory, question } = input;
-
+export const defaultCondenseQuestionPrompt = ({
+  chatHistory = "",
+  question = "",
+}) => {
return `Given a conversation (between Human and Assistant) and a follow up message from Human, rewrite the message to be a standalone question that captures all relevant context from the conversation.
<Chat History>
Expand All @@ -327,6 +332,8 @@ ${question}
`;
};

+export type CondenseQuestionPrompt = typeof defaultCondenseQuestionPrompt;

export function messagesToHistoryStr(messages: ChatMessage[]) {
return messages.reduce((acc, message) => {
acc += acc ? "\n" : "";
@@ -339,11 +346,11 @@ export function messagesToHistoryStr(messages: ChatMessage[]) {
}, "");
}

-export const contextSystemPrompt: SimplePrompt = (input) => {
-  const { context } = input;
-
+export const defaultContextSystemPrompt = ({ context = "" }) => {
return `Context information is below.
---------------------
${context}
---------------------`;
};

+export type ContextSystemPrompt = typeof defaultContextSystemPrompt;
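The chat_history to chatHistory rename above is the kind of drift these derived types now catch: a prompt still destructuring the old snake_case key fails to compile against CondenseQuestionPrompt. Sketch, assuming the type is re-exported from the package root:

import type { CondenseQuestionPrompt } from "llamaindex";

// OK: matches ({ chatHistory?: string; question?: string }) => string.
const terseCondensePrompt: CondenseQuestionPrompt = ({
  chatHistory = "",
  question = "",
}) => `History:\n${chatHistory}\nStandalone rewrite of: ${question}`;

// Does not compile: `chat_history` no longer exists on the input type.
// const stale: CondenseQuestionPrompt = ({ chat_history = "" }) => chat_history;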
6 changes: 3 additions & 3 deletions packages/core/src/QuestionGenerator.ts
@@ -4,7 +4,7 @@ import {
SubQuestionOutputParser,
} from "./OutputParser";
import {
-  SimplePrompt,
+  SubQuestionPrompt,
buildToolsText,
defaultSubQuestionPrompt,
} from "./Prompt";
@@ -28,7 +28,7 @@ export interface BaseQuestionGenerator {
*/
export class LLMQuestionGenerator implements BaseQuestionGenerator {
llm: LLM;
-  prompt: SimplePrompt;
+  prompt: SubQuestionPrompt;
outputParser: BaseOutputParser<StructuredOutput<SubQuestion[]>>;

constructor(init?: Partial<LLMQuestionGenerator>) {
@@ -45,7 +45,7 @@ export class LLMQuestionGenerator implements BaseQuestionGenerator {
this.prompt({
toolsStr,
queryStr,
-      })
+      }),
)
).message.content;

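With prompt typed as SubQuestionPrompt, an override is validated against ({ toolsStr, queryStr }). A type-level sketch only; a real override must still ask for the JSON that SubQuestionOutputParser expects, and the root-level exports are assumptions:

import { LLMQuestionGenerator, type SubQuestionPrompt } from "llamaindex";

// Checked against ({ toolsStr?: string; queryStr?: string }) => string.
const compactSubQuestionPrompt: SubQuestionPrompt = ({
  toolsStr = "",
  queryStr = "",
}) => `Tools:\n${toolsStr}\n\nOutput JSON sub-questions for: ${queryStr}`;

// The constructor takes Partial<LLMQuestionGenerator> (see above), so a
// single field can be overridden.
const generator = new LLMQuestionGenerator({ prompt: compactSubQuestionPrompt });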
