Skip to content

Commit

Permalink
Azure open ai services setup (#3)
Browse files Browse the repository at this point in the history
* Setup support for Azure OpenAI Services

* Fixed OpenAI api changes since package upgrade

* Fixed issues after upgrading OpenAI package

* Update env example to include new env vars

* Maybe a fix?

---------

Co-authored-by: Keith Gibson <keithwgibson74@gmail.com>
  • Loading branch information
EdwardPrentice and pureit-dev authored Dec 14, 2023
1 parent 8e4a772 commit 440e4ad
Show file tree
Hide file tree
Showing 16 changed files with 1,490 additions and 626 deletions.
15 changes: 14 additions & 1 deletion .env.local.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Chatbot UI
DEFAULT_MODEL=gpt-3.5-turbo
OPENAI_API_KEY=YOUR_KEY
OPENAI_API_KEY= # Either OpenAI or Azure OpenAI key
NEXT_PUBLIC_DEFAULT_SYSTEM_PROMPT="You are ChatGPT, a large language model trained by OpenAI. Follow the user's instructions carefully. Respond using markdown."

# Specify url to a json file that list available plugins.
Expand Down Expand Up @@ -42,8 +42,21 @@ COGNITO_CLIENT_ID=
COGNITO_CLIENT_SECRET=
COGNITO_ISSUER=

# Azure Auth
AZURE_AD_CLIENT_ID=
AZURE_AD_TENANT_ID=
AZURE_AD_CLIENT_SECRET=

# Audit Log
AUDIT_LOG_ENABLED=false

# For Debugging
DEBUG_AGENT_LLM_LOGGING=true


# Azure OpenAI Service
OPENAI_API_HOST= # https://<deployment_name>.openai.azure.com
OPENAI_API_TYPE= # openai or azure
OPENAI_API_VERSION= # e.g. 2023-07-01-preview
AZURE_DEPLOYMENT_ID_EMBEDDINGS= # Your embeddings deployment name
AZURE_DEPLOYMENT_ID= # Your deployment name
2 changes: 1 addition & 1 deletion .github/workflows/run-test-suite.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
timeout-minutes: 5
container:
image: node:16
image: node:20

steps:
- name: Checkout code
Expand Down
125 changes: 114 additions & 11 deletions agent/agent.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,120 @@
import { Plugin, PluginResult, ReactAgentResult } from '@/types/agent';

import { DebugCallbackHandler, stripQuotes } from './agentUtil';
import { DebugCallbackHandler as ImportedDebugCallbackHandler, stripQuotes } from './agentUtil';
import { TaskExecutionContext } from './plugins/executor';
import { listToolsBySpecifiedPlugins } from './plugins/list';
import prompts from './prompts/agent';

import chalk from 'chalk';
import { CallbackManager } from 'langchain/callbacks';
import { CallbackManager, NewTokenIndices } from 'langchain/callbacks';
import { PromptTemplate } from 'langchain/prompts';
import { ChatCompletionRequestMessage, Configuration, OpenAIApi } from 'openai';
import { OpenAI } from 'openai';
import { getOpenAIApi } from '@/utils/server/openai';
import { OpenAIError } from '@/utils/server';
import { saveLlmUsage } from '@/utils/server/llmUsage';
import { BaseCallbackHandler } from 'langchain/callbacks';
import { SerializedFields } from '@langchain/core/dist/load/map_keys';
import { HandleLLMNewTokenCallbackFields } from 'langchain/dist/callbacks/base';
import { Document } from 'langchain/dist/document';
import { Serialized, SerializedNotImplemented } from 'langchain/dist/load/serializable';
import { LLMResult, BaseMessage, ChainValues, AgentAction, AgentFinish } from 'langchain/dist/schema';

// NOTE(review): Stub adapter for langchain's BaseCallbackHandler contract.
// Every serialization accessor and lifecycle hook below — except
// handleLLMStart — throws 'Method not implemented.'. Presumably only
// handleLLMStart is ever registered/invoked at runtime; TODO confirm no
// other callback can fire, since any such call would throw.
class DebugCallbackHandler implements BaseCallbackHandler {
// Marks this handler as excluded from langchain's serialize/load machinery.
lc_serializable: boolean = false;
// --- Serialization metadata accessors (unused; throw if called) ---
get lc_namespace(): ["langchain_core", "callbacks", string] {
throw new Error('Method not implemented.');
}
get lc_secrets(): { [key: string]: string; } | undefined {
throw new Error('Method not implemented.');
}
get lc_attributes(): { [key: string]: string; } | undefined {
throw new Error('Method not implemented.');
}
get lc_aliases(): { [key: string]: string; } | undefined {
throw new Error('Method not implemented.');
}
get lc_id(): string[] {
throw new Error('Method not implemented.');
}
// --- Fields required by the BaseCallbackHandler shape. Declared with the
// definite-assignment assertion (!) but never initialized here; reading
// them yields undefined — TODO confirm langchain never reads these.
lc_kwargs!: SerializedFields;
name!: string;
ignoreLLM!: boolean;
ignoreChain!: boolean;
ignoreAgent!: boolean;
ignoreRetriever!: boolean;
awaitHandlers!: boolean;
// --- Structural methods (unused; throw if called) ---
copy(): BaseCallbackHandler {
throw new Error('Method not implemented.');
}
toJSON(): Serialized {
throw new Error('Method not implemented.');
}
toJSONNotImplemented(): SerializedNotImplemented {
throw new Error('Method not implemented.');
}
// --- Lifecycle hooks (all unimplemented; throw if invoked) ---
handleLLMNewToken?(token: string, idx: NewTokenIndices, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined, fields?: HandleLLMNewTokenCallbackFields | undefined) {
throw new Error('Method not implemented.');
}
handleLLMError?(err: any, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined) {
throw new Error('Method not implemented.');
}
handleLLMEnd?(output: LLMResult, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined) {
throw new Error('Method not implemented.');
}
handleChatModelStart?(llm: Serialized, messages: BaseMessage[][], runId: string, parentRunId?: string | undefined, extraParams?: Record<string, unknown> | undefined, tags?: string[] | undefined, metadata?: Record<string, unknown> | undefined, name?: string | undefined) {
throw new Error('Method not implemented.');
}
handleChainStart?(chain: Serialized, inputs: ChainValues, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined, metadata?: Record<string, unknown> | undefined, runType?: string | undefined, name?: string | undefined) {
throw new Error('Method not implemented.');
}
handleChainError?(err: any, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined, kwargs?: { inputs?: Record<string, unknown> | undefined; } | undefined) {
throw new Error('Method not implemented.');
}
handleChainEnd?(outputs: ChainValues, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined, kwargs?: { inputs?: Record<string, unknown> | undefined; } | undefined) {
throw new Error('Method not implemented.');
}
handleToolStart?(tool: Serialized, input: string, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined, metadata?: Record<string, unknown> | undefined, name?: string | undefined) {
throw new Error('Method not implemented.');
}
handleToolError?(err: any, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined) {
throw new Error('Method not implemented.');
}
handleToolEnd?(output: string, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined) {
throw new Error('Method not implemented.');
}
handleText?(text: string, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined): void | Promise<void> {
throw new Error('Method not implemented.');
}
handleAgentAction?(action: AgentAction, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined): void | Promise<void> {
throw new Error('Method not implemented.');
}
handleAgentEnd?(action: AgentFinish, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined): void | Promise<void> {
throw new Error('Method not implemented.');
}
handleRetrieverStart?(retriever: Serialized, query: string, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined, metadata?: Record<string, unknown> | undefined, name?: string | undefined) {
throw new Error('Method not implemented.');
}
handleRetrieverEnd?(documents: Document<Record<string, any>>[], runId: string, parentRunId?: string | undefined, tags?: string[] | undefined) {
throw new Error('Method not implemented.');
}
handleRetrieverError?(err: any, runId: string, parentRunId?: string | undefined, tags?: string[] | undefined) {
throw new Error('Method not implemented.');
}
// The only hook with a (deliberately empty) implementation: a no-op
// placeholder left by the author — see the '// implementation goes here'
// marker below.
handleLLMStart(
llm: Serialized,
prompts: string[],
runId: string,
parentRunId?: string,
extraParams?: Record<string, unknown>,
tags?: string[],
metadata?: Record<string, unknown>,
name?: string
): any {
// implementation goes here
}

// implement other required methods here
}

const setupCallbackManager = (verbose: boolean): void => {
const callbackManager = new CallbackManager();
Expand Down Expand Up @@ -84,7 +187,7 @@ const createMessages = async (
tools: Plugin[],
pluginResults: PluginResult[],
input: string
): Promise<ChatCompletionRequestMessage[]> => {
): Promise<OpenAI.Chat.ChatCompletionMessageParam[]> => {
const { sytemPrompt, userPrompt } = createPrompts();
const agentScratchpad = createAgentScratchpad(pluginResults);
const toolDescriptions = tools
Expand Down Expand Up @@ -113,7 +216,7 @@ const createMessages = async (
];
};

const logVerboseRequest = (messages: ChatCompletionRequestMessage[]): void => {
const logVerboseRequest = (messages: OpenAI.Chat.ChatCompletionMessageParam[]): void => {
console.log(chalk.greenBright('LLM Request:'));
for (const message of messages) {
console.log(chalk.blue(message.role + ': ') + message.content);
Expand Down Expand Up @@ -149,7 +252,7 @@ export const executeNotConversationalReactAgent = async (
}
let result;
try {
result = await openai.createChatCompletion({
result = await openai.chat.completions.create({
model: context.model.id,
messages,
temperature: 0.0,
Expand All @@ -163,15 +266,15 @@ export const executeNotConversationalReactAgent = async (
}

await saveLlmUsage(context.userId, context.model.id, "agent", {
prompt: result.data.usage!.prompt_tokens,
completion: result.data.usage!.completion_tokens,
total: result.data.usage!.total_tokens
prompt: result.usage!.prompt_tokens,
completion: result.usage!.completion_tokens,
total: result.usage!.total_tokens
})

const responseText = result.data.choices[0].message?.content;
const responseText = result.choices[0].message?.content;
const ellapsed = Date.now() - start;
if (verbose) {
logVerboseResponse(ellapsed, responseText);
logVerboseResponse(ellapsed, responseText ?? undefined);
}
const output = parseResult(tools, responseText!);
return output;
Expand Down
53 changes: 29 additions & 24 deletions agent/agentConvo.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import { OpenAIError } from '@/utils/server';
import { saveLlmUsage } from '@/utils/server/llmUsage';
import { getOpenAIApi } from '@/utils/server/openai';

import { Plugin, PluginResult, ReactAgentResult } from '@/types/agent';
import { Message } from '@/types/chat';

import {
DebugCallbackHandler,
createAgentHistory,
messagesToOpenAIMessages,
} from './agentUtil';
import { createAgentHistory, messagesToOpenAIMessages } from './agentUtil';
import { TaskExecutionContext } from './plugins/executor';
import { listToolsBySpecifiedPlugins } from './plugins/list';
import prompts from './prompts/agentConvo';

import chalk from 'chalk';
import { CallbackManager } from 'langchain/callbacks';
import { BaseCallbackHandler } from 'langchain/callbacks';
import { PromptTemplate } from 'langchain/prompts';
import { ChatCompletionRequestMessage, Configuration, OpenAIApi } from 'openai';
import { getOpenAIApi } from '@/utils/server/openai';
import { OpenAIError } from '@/utils/server';
import { saveLlmUsage } from '@/utils/server/llmUsage';
import OpenAI from 'openai';

// NOTE(review): Local placeholder that shadows the DebugCallbackHandler
// previously imported from './agentUtil' (import removed in this change).
// It only declares a `name` field (definite-assignment, never initialized
// here); all callback behavior comes from BaseCallbackHandler — TODO
// confirm this stub is intentional and not a temporary shim.
class DebugCallbackHandler extends BaseCallbackHandler {
name!: string;
// implementation of DebugCallbackHandler class
}

const setupCallbackManager = (verbose: boolean): void => {
const callbackManager = new CallbackManager();
Expand Down Expand Up @@ -53,8 +56,8 @@ const createPrompts = (): {
const createToolResponse = async (
pluginResults: PluginResult[],
toolResponsePrompt: PromptTemplate,
): Promise<ChatCompletionRequestMessage[]> => {
let toolResponse: ChatCompletionRequestMessage[] = [];
): Promise<OpenAI.Chat.ChatCompletionMessageParam[]> => {
let toolResponse: OpenAI.Chat.ChatCompletionMessageParam[] = [];
if (pluginResults.length > 0) {
for (const actionResult of pluginResults) {
const toolResponseContent = await toolResponsePrompt.format({
Expand All @@ -77,7 +80,7 @@ const createFormattedPrompts = async (
userPrompt: PromptTemplate,
input: string,
toolDescriptions: string,
toolResponse: ChatCompletionRequestMessage[],
toolResponse: OpenAI.Chat.ChatCompletionMessageParam[],
): Promise<{ systemContent: string; userContent: string }> => {
const systemContent = await sytemPrompt.format({
locale: context.locale,
Expand All @@ -102,7 +105,7 @@ const createMessages = async (
history: Message[],
modelId: string,
input: string,
): Promise<ChatCompletionRequestMessage[]> => {
): Promise<OpenAI.Chat.ChatCompletionMessageParam[]> => {
const { sytemPrompt, formatPrompt, userPrompt, toolResponsePrompt } =
createPrompts();
const toolResponse = await createToolResponse(
Expand Down Expand Up @@ -143,7 +146,9 @@ const createMessages = async (
];
};

const logVerboseRequest = (messages: ChatCompletionRequestMessage[]): void => {
const logVerboseRequest = (
messages: OpenAI.Chat.ChatCompletionMessageParam[],
): void => {
console.log(chalk.greenBright('LLM Request:'));
for (const message of messages) {
console.log(chalk.blue(message.role + ': ') + message.content);
Expand Down Expand Up @@ -188,7 +193,7 @@ export const executeReactAgent = async (
}
let result;
try {
result = await openai.createChatCompletion({
result = await openai.chat.completions.create({
model: modelId,
messages,
temperature: 0.0,
Expand All @@ -197,20 +202,20 @@ export const executeReactAgent = async (
} catch (error: any) {
if (error.response) {
const { message, type, param, code } = error.response.data.error;
throw new OpenAIError(message, type, param, code)
} else throw error
throw new OpenAIError(message, type, param, code);
} else throw error;
}

await saveLlmUsage(context.userId, context.model.id, "agentConv", {
prompt: result.data.usage!.prompt_tokens,
completion: result.data.usage!.completion_tokens,
total: result.data.usage!.total_tokens
})
await saveLlmUsage(context.userId, context.model.id, 'agentConv', {
prompt: result.usage!.prompt_tokens,
completion: result.usage!.completion_tokens,
total: result.usage!.total_tokens,
});

const responseText = result.data.choices[0].message?.content;
const responseText = result.choices[0].message?.content;
const ellapsed = Date.now() - start;
if (verbose) {
logVerboseResponse(ellapsed, responseText);
logVerboseResponse(ellapsed, responseText || undefined);
}
const output = parseResult(tools, responseText!);
return output;
Expand Down
22 changes: 14 additions & 8 deletions agent/agentUtil.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ import { Message } from '@/types/chat';

import { Tiktoken } from 'tiktoken';
import chalk from 'chalk';
import { ConsoleCallbackHandler } from 'langchain/callbacks';
import { ConsoleCallbackHandler, Run } from 'langchain/callbacks';
import { LLMResult } from 'langchain/dist/schema';
import { ChatCompletionRequestMessage } from 'openai';
import OpenAI from 'openai';
import { Serialized } from 'langchain/dist/load/serializable';

const strip = (str: string, c: string) => {
const m = str.match(new RegExp(`^${c}(.*)${c}$`));
Expand All @@ -24,23 +25,28 @@ export class DebugCallbackHandler extends ConsoleCallbackHandler {
alwaysVerbose: boolean = true;
llmStartTime: number = 0;
async handleLLMStart(
llm: {
name: string;
},
llm: Serialized,
prompts: string[],
runId: string,
): Promise<void> {
parentRunId?: string,
extraParams?: Map<any, any>,
tags?: string[],
metadata?: Map<any, any>,
name?: string,
): Promise<Run> {
this.llmStartTime = Date.now();
console.log(chalk.greenBright('handleLLMStart ============'));
console.log(prompts[0]);
console.log('');
return Promise.resolve({} as Run);
}
async handleLLMEnd(output: LLMResult, runId: string) {
async handleLLMEnd(output: LLMResult, runId: string): Promise<Run> {
const duration = Date.now() - this.llmStartTime;
console.log(chalk.greenBright('handleLLMEnd =============='));
console.log(`ellapsed: ${duration / 1000} sec.`);
console.log(output.generations[0][0].text);
console.log('');
return Promise.resolve({} as Run);
}
async handleText(text: string): Promise<void> {
console.log(chalk.greenBright('handleText =========='));
Expand Down Expand Up @@ -69,7 +75,7 @@ export const createAgentHistory = (

export const messagesToOpenAIMessages = (
messages: Message[],
): ChatCompletionRequestMessage[] => {
): OpenAI.Chat.ChatCompletionMessageParam[] => {
return messages.map((msg) => {
return {
role: msg.role,
Expand Down
4 changes: 2 additions & 2 deletions agent/plugins/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ export default {
const openai = getOpenAIApi(model.azureDeploymentId)
let answerRes;
try {
answerRes = await openai.createChatCompletion({
answerRes = await openai.chat.completions.create({
model: model.id,
messages: [
{
Expand All @@ -151,7 +151,7 @@ export default {
} else throw error
}

const { choices, usage } = answerRes.data;
const { choices, usage } = answerRes;
const answer = choices[0].message!.content!;
encoding.free();

Expand Down
Loading

0 comments on commit 440e4ad

Please sign in to comment.