Skip to content

Commit

Permalink
feat(prompt): prompt.stream(), options.endpoint (#57)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: `prompt()` is now throwing an error if the API responds
with an error status code
  • Loading branch information
gr2m authored Sep 5, 2024
1 parent f8eec8f commit 9533a1f
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 58 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,22 @@ await prompt({
});
```
#### `prompt.stream(message, options)`
Works the same way as `prompt()`, but resolves with a `stream` key instead of a `message` key.
```js
import { prompt } from "@copilot-extensions/preview-sdk";

const { requestId, stream } = prompt.stream("What is the capital of France?", {
token: process.env.TOKEN,
});

for await (const chunk of stream) {
console.log(new TextDecoder().decode(chunk));
}
```
### `getFunctionCalls()`
Convenience metthod if a result from a `prompt()` call includes function calls.
Expand Down
18 changes: 16 additions & 2 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ export interface OpenAICompatibilityPayload {
export interface CopilotMessage {
role: string;
content: string;
copilot_references: MessageCopilotReference[];
copilot_references?: MessageCopilotReference[];
copilot_confirmations?: MessageCopilotConfirmation[];
tool_calls?: {
function: {
Expand Down Expand Up @@ -300,8 +300,9 @@ export interface PromptFunction {
}

export type PromptOptions = {
model?: ModelName;
token: string;
endpoint?: string;
model?: ModelName;
tools?: PromptFunction[];
messages?: InteropMessage[];
request?: {
Expand All @@ -314,12 +315,25 @@ export type PromptResult = {
message: CopilotMessage;
};

export type PromptStreamResult = {
requestId: string;
stream: ReadableStream<Uint8Array>;
};

// https://stackoverflow.com/a/69328045
type WithRequired<T, K extends keyof T> = T & { [P in K]-?: T[P] };

interface PromptInterface {
(userPrompt: string, options: PromptOptions): Promise<PromptResult>;
(options: WithRequired<PromptOptions, "messages">): Promise<PromptResult>;
stream: PromptStreamInterface;
}

interface PromptStreamInterface {
(userPrompt: string, options: PromptOptions): Promise<PromptStreamResult>;
(
options: WithRequired<PromptOptions, "messages">,
): Promise<PromptStreamResult>;
}

interface GetFunctionCallsInterface {
Expand Down
21 changes: 18 additions & 3 deletions index.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,6 @@ export function getUserConfirmationTest(payload: CopilotRequestPayload) {

export async function promptTest() {
const result = await prompt("What is the capital of France?", {
model: "gpt-4",
token: "secret",
});

Expand All @@ -311,7 +310,6 @@ export async function promptTest() {

// with custom fetch
await prompt("What is the capital of France?", {
model: "gpt-4",
token: "secret",
request: {
fetch: () => {},
Expand All @@ -327,7 +325,6 @@ export async function promptTest() {

export async function promptWithToolsTest() {
await prompt("What is the capital of France?", {
model: "gpt-4",
token: "secret",
tools: [
{
Expand Down Expand Up @@ -366,6 +363,24 @@ export async function promptWithoutMessageButMessages() {
});
}

export async function otherPromptOptionsTest() {
const result = await prompt("What is the capital of France?", {
token: "secret",
model: "gpt-4",
endpoint: "https://api.githubcopilot.com",
});
}

export async function promptStreamTest() {
const result = await prompt.stream("What is the capital of France?", {
model: "gpt-4",
token: "secret",
});

expectType<string>(result.requestId);
expectType<ReadableStream<Uint8Array>>(result.stream);
}

export async function getFunctionCallsTest(
promptResponsePayload: PromptResult,
) {
Expand Down
114 changes: 81 additions & 33 deletions lib/prompt.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
// @ts-check

/** @type {import('..').PromptInterface} */
export async function prompt(userPrompt, promptOptions) {
const options = typeof userPrompt === "string" ? promptOptions : userPrompt;

const promptFetch = options.request?.fetch || fetch;
const modelName = options.model || "gpt-4";
function parsePromptArguments(userPrompt, promptOptions) {
const { request: requestOptions, ...options } =
typeof userPrompt === "string" ? promptOptions : userPrompt;

const promptFetch = requestOptions?.fetch || fetch;
const model = options.model || "gpt-4";
const endpoint =
options.endpoint || "https://api.githubcopilot.com/chat/completions";

const systemMessage = options.tools
? "You are a helpful assistant. Use the supplied tools to assist the user."
: "You are a helpful assistant.";
const toolsChoice = options.tools ? "auto" : undefined;

const messages = [
{
Expand All @@ -29,44 +34,87 @@ export async function prompt(userPrompt, promptOptions) {
});
}

const response = await promptFetch(
"https://api.githubcopilot.com/chat/completions",
{
method: "POST",
headers: {
accept: "application/json",
"content-type": "application/json; charset=UTF-8",
"user-agent": "copilot-extensions/preview-sdk.js",
authorization: `Bearer ${options.token}`,
},
body: JSON.stringify({
messages: messages,
model: modelName,
toolChoice: options.tools ? "auto" : undefined,
tools: options.tools,
}),
}
);
return [promptFetch, { ...options, messages, model, endpoint, toolsChoice }];
}

if (response.ok) {
const data = await response.json();
async function sendPromptRequest(promptFetch, options) {
const { endpoint, token, ...payload } = options;
const method = "POST";
const headers = {
accept: "application/json",
"content-type": "application/json; charset=UTF-8",
"user-agent": "copilot-extensions/preview-sdk.js",
authorization: `Bearer ${token}`,
};

return {
requestId: response.headers.get("x-request-id"),
message: data.choices[0].message,
};
const response = await promptFetch(endpoint, {
method,
headers,
body: JSON.stringify(payload),
});

if (response.ok) {
return response;
}

const body = await response.text();
console.log({ body });

throw Object.assign(
new Error(
`[@copilot-extensions/preview-sdk] An error occured with the chat completions API`,
),
{
name: "PromptError",
request: {
method: "POST",
url: endpoint,
headers: {
...headers,
authorization: `Bearer [REDACTED]`,
},
body: payload,
},
response: {
status: response.status,
headers: [...response.headers],
body: body,
},
},
);
}
export async function prompt(userPrompt, promptOptions) {
const [promptFetch, options] = parsePromptArguments(
userPrompt,
promptOptions,
);
const response = await sendPromptRequest(promptFetch, options);
const requestId = response.headers.get("x-request-id");

const data = await response.json();

return {
requestId: requestId,
message: {
role: "Sssistant",
content: `Sorry, an error occured with the chat completions API. (Status: ${response.status}, request ID: ${requestId})`,
},
requestId,
message: data.choices[0].message,
};
}

prompt.stream = async function promptStream(userPrompt, promptOptions) {
const [promptFetch, options] = parsePromptArguments(
userPrompt,
promptOptions,
);
const response = await sendPromptRequest(promptFetch, {
...options,
stream: true,
});

return {
requestId: response.headers.get("x-request-id"),
stream: response.body,
};
};

/** @type {import('..').GetFunctionCallsInterface} */
export function getFunctionCalls(payload) {
const functionCalls = payload.message.tool_calls;
Expand Down
Loading

0 comments on commit 9533a1f

Please sign in to comment.