✨ Support for Featherless.ai as inference provider. #1310

Merged · 13 commits · Apr 24, 2025
Changes from all commits
2 changes: 2 additions & 0 deletions packages/inference/README.md
@@ -48,6 +48,7 @@ You can send inference requests to third-party providers with the inference client.

Currently, we support the following providers:
- [Fal.ai](https://fal.ai)
- [Featherless AI](https://featherless.ai)
- [Fireworks AI](https://fireworks.ai)
- [Hyperbolic](https://hyperbolic.xyz)
- [Nebius](https://studio.nebius.ai)
@@ -78,6 +79,7 @@ When authenticated with a third-party provider key, the request is made directly

Only a subset of models is supported when requesting third-party providers. You can check the list of supported models per pipeline task here:
- [Fal.ai supported models](https://huggingface.co/api/partners/fal-ai/models)
- [Featherless AI supported models](https://huggingface.co/api/partners/featherless-ai/models)
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
- [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
- [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
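For context, a call routed through the new provider looks like any other third-party request with the inference client. A minimal sketch, reusing the model and option names from this PR's tests (the `HF_TOKEN` variable is a placeholder):

```ts
import { chatCompletion } from "@huggingface/inference";

// Minimal sketch: HF_TOKEN stands in for a Hugging Face token or a Featherless AI provider key.
const res = await chatCompletion({
	accessToken: process.env.HF_TOKEN,
	provider: "featherless-ai",
	model: "meta-llama/Llama-3.1-8B-Instruct",
	messages: [{ role: "user", content: "One plus one is equal to" }],
});
console.log(res.choices[0].message?.content);
```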
6 changes: 5 additions & 1 deletion packages/inference/src/lib/getProviderHelper.ts
@@ -2,10 +2,10 @@ import * as BlackForestLabs from "../providers/black-forest-labs";
import * as Cerebras from "../providers/cerebras";
import * as Cohere from "../providers/cohere";
import * as FalAI from "../providers/fal-ai";
import * as FeatherlessAI from "../providers/featherless-ai";
import * as Fireworks from "../providers/fireworks-ai";
import * as Groq from "../providers/groq";
import * as HFInference from "../providers/hf-inference";
import * as Hyperbolic from "../providers/hyperbolic";
import * as Nebius from "../providers/nebius";
import * as Novita from "../providers/novita";
@@ -64,6 +64,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
		"text-to-video": new FalAI.FalAITextToVideoTask(),
		"automatic-speech-recognition": new FalAI.FalAIAutomaticSpeechRecognitionTask(),
	},
	"featherless-ai": {
		conversational: new FeatherlessAI.FeatherlessAIConversationalTask(),
		"text-generation": new FeatherlessAI.FeatherlessAITextGenerationTask(),
	},
	"hf-inference": {
		"text-to-image": new HFInference.HFInferenceTextToImageTask(),
		conversational: new HFInference.HFInferenceConversationalTask(),
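The registration above is what makes the new provider resolvable at request time: the helper for a given provider/task pair is looked up in the `PROVIDERS` registry. A simplified sketch of that lookup (the real `getProviderHelper` in this file handles errors differently):

```ts
import type { InferenceProvider, InferenceTask } from "../types";

// Simplified sketch of how the PROVIDERS registry above is consumed.
function resolveHelper(provider: InferenceProvider, task: InferenceTask) {
	const helper = PROVIDERS[provider]?.[task];
	if (!helper) {
		throw new Error(`Provider "${provider}" does not support task "${task}"`);
	}
	return helper;
}

// resolveHelper("featherless-ai", "conversational") now yields the FeatherlessAIConversationalTask instance.
```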
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -23,6 +23,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
	cerebras: {},
	cohere: {},
	"fal-ai": {},
	"featherless-ai": {},
	"fireworks-ai": {},
	groq: {},
	"hf-inference": {},
52 changes: 52 additions & 0 deletions packages/inference/src/providers/featherless-ai.ts
@@ -0,0 +1,52 @@
import type {
	ChatCompletionOutput,
	TextGenerationInput,
	TextGenerationOutput,
	TextGenerationOutputFinishReason,
} from "@huggingface/tasks";
import { InferenceOutputError } from "../lib/InferenceOutputError";
import type { BodyParams } from "../types";
import { BaseConversationalTask, BaseTextGenerationTask } from "./providerHelper";

interface FeatherlessAITextCompletionOutput extends Omit<ChatCompletionOutput, "choices"> {
	choices: Array<{
		text: string;
		finish_reason: TextGenerationOutputFinishReason;
		seed: number;
		logprobs: unknown;
		index: number;
	}>;
}

const FEATHERLESS_API_BASE_URL = "https://api.featherless.ai";

export class FeatherlessAIConversationalTask extends BaseConversationalTask {
	constructor() {
		super("featherless-ai", FEATHERLESS_API_BASE_URL);
	}
}

export class FeatherlessAITextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("featherless-ai", FEATHERLESS_API_BASE_URL);
	}

	override preparePayload(params: BodyParams<TextGenerationInput>): Record<string, unknown> {
		return {
			...params.args,
			...params.args.parameters,
			model: params.model,
			prompt: params.args.inputs,
		};
	}

	override async getResponse(response: FeatherlessAITextCompletionOutput): Promise<TextGenerationOutput> {
		if (
			typeof response === "object" &&
			"choices" in response &&
			Array.isArray(response?.choices) &&
			typeof response?.model === "string"
		) {
			const completion = response.choices[0];
			return {
				generated_text: completion.text,
			};
		}
		throw new InferenceOutputError("Expected Featherless AI text generation response format");
	}
}
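To make the payload shape concrete: `preparePayload` spreads the args and their parameters flat, then adds the provider's model ID and an OpenAI-style `prompt` field. A worked example, assuming `BodyParams` carries `args` and `model` as used above (the values mirror this PR's textGeneration test):

```ts
const payload = new FeatherlessAITextGenerationTask().preparePayload({
	args: { inputs: "Paris is a city of ", parameters: { temperature: 0, max_tokens: 10 } },
	model: "meta-llama/Meta-Llama-3.1-8B",
});
// payload is roughly:
// {
//   inputs: "Paris is a city of ",                  // from ...params.args
//   parameters: { temperature: 0, max_tokens: 10 }, // from ...params.args
//   temperature: 0,                                 // flattened from ...params.args.parameters
//   max_tokens: 10,                                 // flattened from ...params.args.parameters
//   model: "meta-llama/Meta-Llama-3.1-8B",
//   prompt: "Paris is a city of ",
// }
```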
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
@@ -42,6 +42,7 @@ export const INFERENCE_PROVIDERS = [
	"cerebras",
	"cohere",
	"fal-ai",
	"featherless-ai",
	"fireworks-ai",
	"groq",
	"hf-inference",
73 changes: 73 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -1045,6 +1045,79 @@ describe.skip("InferenceClient", () => {
		TIMEOUT
	);

	describe.concurrent(
		"Featherless",
		() => {
			HARDCODED_MODEL_INFERENCE_MAPPING["featherless-ai"] = {
				"meta-llama/Llama-3.1-8B": {
					providerId: "meta-llama/Meta-Llama-3.1-8B",
					hfModelId: "meta-llama/Llama-3.1-8B",
					task: "text-generation",
					status: "live",
				},
				"meta-llama/Llama-3.1-8B-Instruct": {
					providerId: "meta-llama/Meta-Llama-3.1-8B-Instruct",
					hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
					task: "text-generation",
					status: "live",
				},
			};

			it("chatCompletion", async () => {
				const res = await chatCompletion({
					accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
					model: "meta-llama/Llama-3.1-8B-Instruct",
					provider: "featherless-ai",
					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
					temperature: 0.1,
				});

				expect(res).toBeDefined();
				expect(res.choices).toBeDefined();
				expect(res.choices?.length).toBeGreaterThan(0);

				if (res.choices && res.choices.length > 0) {
					const completion = res.choices[0].message?.content;
					expect(completion).toBeDefined();
					expect(typeof completion).toBe("string");
					expect(completion).toContain("two");
				}
			});

			it("chatCompletion stream", async () => {
				const stream = chatCompletionStream({
					accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
					model: "meta-llama/Llama-3.1-8B-Instruct",
					provider: "featherless-ai",
					messages: [{ role: "user", content: "Complete the equation 1 + 1 = , just the answer" }],
				}) as AsyncGenerator<ChatCompletionStreamOutput>;
				let out = "";
				for await (const chunk of stream) {
					if (chunk.choices && chunk.choices.length > 0) {
						out += chunk.choices[0].delta.content;
					}
				}
				expect(out).toContain("2");
			});

			it("textGeneration", async () => {
				const res = await textGeneration({
					accessToken: env.HF_FEATHERLESS_KEY ?? "dummy",
					model: "meta-llama/Llama-3.1-8B",
					provider: "featherless-ai",
					inputs: "Paris is a city of ",
					parameters: {
						temperature: 0,
						top_p: 0.01,
						max_tokens: 10,
					},
				});
				expect(res).toMatchObject({ generated_text: "2.2 million people, and it is the" });
			});
		},
		TIMEOUT
	);

	describe.concurrent(
		"Replicate",
		() => {