feat: add hugging face llm (run-llama#796)
thucpn authored May 2, 2024
1 parent 8aeb8ae commit d10533e
Showing 7 changed files with 9,221 additions and 11,915 deletions.
5 changes: 5 additions & 0 deletions .changeset/hot-tools-pay.md
@@ -0,0 +1,5 @@
---
"llamaindex": patch
---

feat: add hugging face llm
22 changes: 22 additions & 0 deletions examples/huggingface/chat.ts
@@ -0,0 +1,22 @@
import { HuggingFaceInferenceAPI } from "llamaindex";

(async () => {
  if (!process.env.HUGGING_FACE_TOKEN) {
    throw new Error("Please set the HUGGING_FACE_TOKEN environment variable.");
  }
  const hf = new HuggingFaceInferenceAPI({
    accessToken: process.env.HUGGING_FACE_TOKEN,
    model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
  });
  const result = await hf.chat({
    messages: [
      { content: "You want to talk in rhymes.", role: "system" },
      {
        content:
          "How much wood would a woodchuck chuck if a woodchuck could chuck wood?",
        role: "user",
      },
    ],
  });
  console.log(result);
})();
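
The example above uses the non-streaming overload. Passing `stream: true` selects the streaming overload, which resolves to an `AsyncIterable<ChatResponseChunk>` (see `streamChat` in `huggingface.ts` below). A minimal sketch under the same assumptions (`HUGGING_FACE_TOKEN` set, same model); this variant is not part of the commit:

import { HuggingFaceInferenceAPI } from "llamaindex";

(async () => {
  if (!process.env.HUGGING_FACE_TOKEN) {
    throw new Error("Please set the HUGGING_FACE_TOKEN environment variable.");
  }
  const hf = new HuggingFaceInferenceAPI({
    accessToken: process.env.HUGGING_FACE_TOKEN,
    model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
  });
  // stream: true selects the streaming overload of chat()
  const stream = await hf.chat({
    messages: [{ content: "Say hello in a rhyme.", role: "user" }],
    stream: true,
  });
  for await (const chunk of stream) {
    process.stdout.write(chunk.delta); // each chunk carries a text delta
  }
})();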
File renamed without changes.
1 change: 1 addition & 0 deletions packages/core/package.json
@@ -10,6 +10,7 @@
"@datastax/astra-db-ts": "^1.0.1",
"@google/generative-ai": "^0.8.0",
"@grpc/grpc-js": "^1.10.6",
"@huggingface/inference": "^2.6.7",
"@llamaindex/cloud": "0.0.5",
"@llamaindex/env": "workspace:*",
"@mistralai/mistralai": "^0.1.3",
141 changes: 141 additions & 0 deletions packages/core/src/llm/huggingface.ts
@@ -0,0 +1,141 @@
import {
  HfInference,
  type Options as HfInferenceOptions,
} from "@huggingface/inference";
import { BaseLLM } from "./base.js";
import type {
  ChatMessage,
  ChatResponse,
  ChatResponseChunk,
  LLMChatParamsNonStreaming,
  LLMChatParamsStreaming,
  LLMMetadata,
  ToolCallLLMMessageOptions,
} from "./types.js";
import { streamConverter, wrapLLMEvent } from "./utils.js";

const DEFAULT_PARAMS = {
  temperature: 0.1,
  topP: 1,
  maxTokens: undefined,
  contextWindow: 3900,
};

export type HFConfig = Partial<typeof DEFAULT_PARAMS> &
  HfInferenceOptions & {
    model: string;
    accessToken: string;
    endpoint?: string;
  };

/**
 * Wrapper on Hugging Face's Inference API.
 * API Docs: https://huggingface.co/docs/huggingface.js/inference/README
 * List of tasks with models: huggingface.co/api/tasks
 *
 * Note that the Conversational API is not yet supported by the Inference API;
 * Hugging Face recommends using the text generation API instead.
 * See: https://github.com/huggingface/huggingface.js/issues/586#issuecomment-2024059308
 */
export class HuggingFaceInferenceAPI extends BaseLLM {
  model: string;
  temperature: number;
  topP: number;
  maxTokens?: number;
  contextWindow: number;
  hf: HfInference;

  constructor(init: HFConfig) {
    super();
    const {
      model,
      temperature,
      topP,
      maxTokens,
      contextWindow,
      accessToken,
      endpoint,
      ...hfInferenceOpts
    } = init;
    this.hf = new HfInference(accessToken, hfInferenceOpts);
    this.model = model;
    this.temperature = temperature ?? DEFAULT_PARAMS.temperature;
    this.topP = topP ?? DEFAULT_PARAMS.topP;
    this.maxTokens = maxTokens ?? DEFAULT_PARAMS.maxTokens;
    this.contextWindow = contextWindow ?? DEFAULT_PARAMS.contextWindow;
    if (endpoint) this.hf.endpoint(endpoint);
  }

  get metadata(): LLMMetadata {
    return {
      model: this.model,
      temperature: this.temperature,
      topP: this.topP,
      maxTokens: this.maxTokens,
      contextWindow: this.contextWindow,
      tokenizer: undefined,
    };
  }

  chat(
    params: LLMChatParamsStreaming,
  ): Promise<AsyncIterable<ChatResponseChunk>>;
  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
  @wrapLLMEvent
  async chat(
    params: LLMChatParamsStreaming | LLMChatParamsNonStreaming,
  ): Promise<AsyncIterable<ChatResponseChunk> | ChatResponse<object>> {
    if (params.stream) return this.streamChat(params);
    return this.nonStreamChat(params);
  }

  private messagesToPrompt(messages: ChatMessage<ToolCallLLMMessageOptions>[]) {
    let prompt = "";
    for (const message of messages) {
      if (message.role === "system") {
        prompt += `<|system|>\n${message.content}</s>\n`;
      } else if (message.role === "user") {
        prompt += `<|user|>\n${message.content}</s>\n`;
      } else if (message.role === "assistant") {
        prompt += `<|assistant|>\n${message.content}</s>\n`;
      }
    }
    // ensure we start with a system prompt, insert blank if needed
    if (!prompt.startsWith("<|system|>\n")) {
      prompt = "<|system|>\n</s>\n" + prompt;
    }
    // add final assistant prompt
    prompt = prompt + "<|assistant|>\n";
    return prompt;
  }

  protected async nonStreamChat(
    params: LLMChatParamsNonStreaming,
  ): Promise<ChatResponse> {
    const res = await this.hf.textGeneration({
      model: this.model,
      inputs: this.messagesToPrompt(params.messages),
      parameters: this.metadata,
    });
    return {
      raw: res,
      message: {
        content: res.generated_text,
        role: "assistant",
      },
    };
  }

  protected async *streamChat(
    params: LLMChatParamsStreaming,
  ): AsyncIterable<ChatResponseChunk> {
    const stream = this.hf.textGenerationStream({
      model: this.model,
      inputs: this.messagesToPrompt(params.messages),
      parameters: this.metadata,
    });
    yield* streamConverter(stream, (chunk) => ({
      delta: chunk.token.text,
      raw: chunk,
    }));
  }
}
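
For the two messages in the chat example above, `messagesToPrompt` builds the prompt below: each turn is wrapped in a role tag and terminated with `</s>`, a blank `<|system|>` turn would be prepended if none were supplied, and the trailing `<|assistant|>` tag cues the model to generate. (Worked example, not part of the commit.)

<|system|>
You want to talk in rhymes.</s>
<|user|>
How much wood would a woodchuck chuck if a woodchuck could chuck wood?</s>
<|assistant|>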
1 change: 1 addition & 0 deletions packages/core/src/llm/index.ts
@@ -7,6 +7,7 @@ export {
export { FireworksLLM } from "./fireworks.js";
export { GEMINI_MODEL, Gemini } from "./gemini.js";
export { Groq } from "./groq.js";
export { HuggingFaceInferenceAPI } from "./huggingface.js";
export {
  ALL_AVAILABLE_MISTRAL_MODELS,
  MistralAI,