From 724f6d7be292a6c8b28983284f10b7cdc8a718cc Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 19 Nov 2025 11:35:55 +0000 Subject: [PATCH] feat(huggingface): add HuggingFace provider --- deploy/example.config.ts | 8 +++++++- deploy/example.env.local | 3 +++ gateway/src/providers/huggingface.ts | 11 +++++++++++ gateway/src/providers/index.ts | 3 +++ gateway/src/types.ts | 3 ++- 5 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 gateway/src/providers/huggingface.ts diff --git a/deploy/example.config.ts b/deploy/example.config.ts index fdf857a..202539e 100644 --- a/deploy/example.config.ts +++ b/deploy/example.config.ts @@ -2,7 +2,7 @@ import { env } from 'cloudflare:workers' import type { Config } from '@deploy/types' // can be whatever you want, just used to make linking apiKeys to providers typesafe. -type ProviderKeys = 'a' | 'b' | 'c' | 'd' | 'e' +type ProviderKeys = 'a' | 'b' | 'c' | 'd' | 'e' | 'huggingface' // projects, users and keys must have numeric keys, using constants here to make it easier to understand // of course, keys must be unique within a type (e.g. project ids must be unique) but users and projects can have the same id @@ -67,6 +67,12 @@ export const config: Config = { injectCost: true, credentials: env.AWS_BEARER_TOKEN_BEDROCK, }, + huggingface: { + providerId: 'huggingface', + baseUrl: 'https://api-inference.huggingface.co', + injectCost: true, + credentials: env.HF_TOKEN, + }, }, // routing groups for load balancing and fallback routingGroups: { diff --git a/deploy/example.env.local b/deploy/example.env.local index 06f16fb..e862077 100644 --- a/deploy/example.env.local +++ b/deploy/example.env.local @@ -20,5 +20,8 @@ AWS_BEARER_TOKEN_BEDROCK=... # python -c "import json;print(json.dumps(json.loads(open(input('Service account JSON file path: ')).read())))" GOOGLE_SERVICE_ACCOUNT_KEY=full service google service account key... +# same for Hugging Face, generate a token (you would use env.HF_TOKEN in config.ts) +HF_TOKEN=... + # password for viewing /status/ STATUS_AUTH_API_KEY="change-me!" diff --git a/gateway/src/providers/huggingface.ts b/gateway/src/providers/huggingface.ts new file mode 100644 index 0000000..130ad8d --- /dev/null +++ b/gateway/src/providers/huggingface.ts @@ -0,0 +1,11 @@ +import type { ModelAPI } from '../api' +import { ChatCompletionAPI } from '../api/chat' +import { DefaultProviderProxy } from './default' + +export class HuggingFaceProvider extends DefaultProviderProxy { + defaultBaseUrl = 'https://api-inference.huggingface.co' + + protected modelAPI(): ModelAPI { + return new ChatCompletionAPI('huggingface') + } +} diff --git a/gateway/src/providers/index.ts b/gateway/src/providers/index.ts index a18a5db..40e66dd 100644 --- a/gateway/src/providers/index.ts +++ b/gateway/src/providers/index.ts @@ -22,6 +22,7 @@ import { BedrockProvider } from './bedrock' import { DefaultProviderProxy, type ProviderOptions } from './default' import { GoogleVertexProvider } from './google' import { GroqProvider } from './groq' +import { HuggingFaceProvider } from './huggingface' import { OpenAIProvider } from './openai' import { TestProvider } from './test' @@ -39,6 +40,8 @@ export function getProvider(providerId: ProviderID): ProviderSig { return AnthropicProvider case 'bedrock': return BedrockProvider + case 'huggingface': + return HuggingFaceProvider case 'test': return TestProvider default: diff --git a/gateway/src/types.ts b/gateway/src/types.ts index 5676308..beda3f0 100644 --- a/gateway/src/types.ts +++ b/gateway/src/types.ts @@ -38,7 +38,7 @@ export interface ApiKeyInfo { otelSettings?: OtelSettings } -export type ProviderID = 'groq' | 'openai' | 'google-vertex' | 'anthropic' | 'test' | 'bedrock' +export type ProviderID = 'groq' | 'openai' | 'google-vertex' | 'anthropic' | 'test' | 'bedrock' | 'huggingface' // TODO | 'azure' | 'fireworks' | 'mistral' | 'cohere' const providerIds: Record = { @@ -48,6 +48,7 @@ const providerIds: Record = { anthropic: true, test: true, bedrock: true, + huggingface: true, } export const providerIdsArray = Object.keys(providerIds) as ProviderID[]