Skip to content

Commit 6623082

Browse files
authored
feat(huggingface): add HuggingFace provider (#156)
1 parent fc955e2 commit 6623082

21 files changed

+3087
-148
lines changed

deploy/example.config.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { env } from 'cloudflare:workers'
22
import type { Config } from '@deploy/types'
33

44
// These keys can be whatever you want; they are only used to make linking apiKeys to providers type-safe.
5-
type ProviderKeys = 'openai' | 'anthropic' | 'google-vertex' | 'bedrock' | 'groq' | 'azure'
5+
type ProviderKeys = 'openai' | 'anthropic' | 'google-vertex' | 'bedrock' | 'groq' | 'azure' | 'huggingface'
66

77
// projects, users and keys must have numeric keys, using constants here to make it easier to understand
88
// of course, keys must be unique within a type (e.g. project ids must be unique) but users and projects can have the same id
@@ -98,6 +98,12 @@ export const config: Config<ProviderKeys> = {
9898
injectCost: true,
9999
credentials: env.AWS_BEARER_TOKEN_BEDROCK,
100100
},
101+
huggingface: {
102+
providerId: 'huggingface',
103+
baseUrl: 'https://router.huggingface.co/v1',
104+
injectCost: true,
105+
credentials: env.HF_TOKEN,
106+
},
101107
},
102108
// individual apiKeys
103109
apiKeys: {

deploy/example.env.local

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,8 @@ AWS_BEARER_TOKEN_BEDROCK=...
2323
# python -c "import json;print(json.dumps(json.loads(open(input('Service account JSON file path: ')).read())))"
2424
GOOGLE_SERVICE_ACCOUNT_KEY=full service google service account key...
2525

26+
# same for Hugging Face: generate an access token (you would reference it as env.HF_TOKEN in config.ts)
27+
HF_TOKEN=...
28+
2629
# password for viewing /status/
2730
STATUS_AUTH_API_KEY="change-me!"

examples/ex_huggingface.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
from openai import OpenAI
4+
5+
# Read the API key from the PYDANTIC_AI_GATEWAY_API_KEY environment variable
# and fail immediately if it is not set.
api_key = os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
6+
assert api_key is not None
7+
8+
# Point the OpenAI client at the `/huggingface/v1` route served on localhost:8787
# (presumably a locally running gateway — confirm the port matches your dev setup).
client = OpenAI(api_key=api_key, base_url='http://localhost:8787/huggingface/v1')
9+
10+
# Send a single-turn chat completion request.
completion = client.chat.completions.create(
11+
# NOTE(review): the `:hyperbolic` suffix presumably selects the upstream
# inference provider for this model — confirm against the Hugging Face docs.
model='openai/gpt-oss-20b:hyperbolic',
12+
messages=[{'role': 'user', 'content': 'What is the capital of France?'}],
13+
)
14+
15+
# Print the message object of the first returned choice.
print(completion.choices[0].message)

examples/pai_huggingface.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import os
2+
from datetime import date
3+
4+
import logfire
5+
from huggingface_hub import AsyncInferenceClient
6+
from pydantic import BaseModel, field_validator
7+
from pydantic_ai import Agent, __version__
8+
from pydantic_ai.models.huggingface import HuggingFaceModel
9+
from pydantic_ai.providers.huggingface import HuggingFaceProvider
10+
11+
# Configure Logfire under the service name 'testing', then instrument
# pydantic-ai and the aiohttp client (capture_all=True is passed to the
# aiohttp instrumentation).
# NOTE(review): presumably aiohttp is instrumented because the Hugging Face
# async client uses it for HTTP — confirm.
logfire.configure(service_name='testing')
12+
logfire.instrument_pydantic_ai()
13+
logfire.instrument_aiohttp_client(capture_all=True)
14+
# Print the installed pydantic-ai version.
print('pydantic-ai version:', __version__)
15+
16+
17+
# Structured output model for the agent: a person's name, date of birth and city.
# `use_attribute_docstrings=True` is set, so the string literals that follow each
# field are attribute docstrings rather than dead code.
class Person(BaseModel, use_attribute_docstrings=True):
18+
name: str
19+
"""The name of the person."""
20+
dob: date
21+
"""The date of birth of the person. MUST BE A VALID ISO 8601 date."""
22+
city: str
23+
"""The city where the person lives."""
24+
25+
# Validator for `dob`: raises ValueError for any date on or after 1900-01-01,
# otherwise returns the value unchanged.
# NOTE(review): the first parameter is `cls` but there is no explicit
# `@classmethod` under `@field_validator`; the pydantic docs recommend adding it.
# NOTE(review): the check accepts every date before 1900-01-01 (including dates
# earlier than the 1800s), which is looser than the "19th century" wording
# of the error message — confirm this is intended.
@field_validator('dob')
26+
def validate_dob(cls, v: date) -> date:
27+
if v >= date(1900, 1, 1):
28+
raise ValueError('The person must be born in the 19th century')
29+
return v
30+
31+
32+
# Read the API key from PYDANTIC_AI_GATEWAY_API_KEY. The commented-out
# alternative below reads HF_TOKEN instead.
api_key = os.getenv('PYDANTIC_AI_GATEWAY_API_KEY')
33+
# api_key = os.getenv('HF_TOKEN')
34+
# Fail immediately if the selected environment variable is not set.
assert api_key is not None
35+
# Base URL of the `/huggingface` route on localhost:8787. The commented-out
# alternative below sets it to None.
# NOTE(review): presumably None makes the client use its default endpoint — confirm.
base_url = 'http://localhost:8787/huggingface'
36+
# base_url = None
37+
38+
# Build the async Hugging Face client (provider='novita'), wrap it in the
# pydantic-ai provider, and create the model for 'moonshotai/Kimi-K2-Thinking'.
hf_client = AsyncInferenceClient(api_key=api_key, provider='novita', base_url=base_url)
39+
provider = HuggingFaceProvider(hf_client=hf_client)
40+
model = HuggingFaceModel('moonshotai/Kimi-K2-Thinking', provider=provider)
41+
42+
# Agent whose output type is the `Person` model.
person_agent = Agent(
43+
model=model,
44+
output_type=Person,
45+
instructions='Extract information about the person',
46+
)
47+
# Run the agent synchronously on one sentence and print the repr of its output.
# NOTE(review): the year is written as '87; presumably this ambiguity is meant
# to exercise the `dob` validator — confirm.
result = person_agent.run_sync("Samuel lived in London and was born on Jan 28th '87")
48+
print(repr(result.output))

examples/pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ version = "0.1.0"
44
dependencies = [
55
"boto3>=1.40.28",
66
"devtools>=0.12.2",
7-
"logfire[httpx]>=4.3.3",
7+
"logfire[httpx,aiohttp]>=4.3.3",
88
"opentelemetry-instrumentation-botocore>=0.57b0",
9-
"pydantic-ai>=1.10.0",
9+
"pydantic-ai[huggingface]>=1.10.0",
10+
"huggingface-hub<1.0",
1011
"types-boto3[bedrock-runtime]",
1112
"mypy-boto3-bedrock-runtime",
1213
]

gateway/src/api/base.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { Provider as UsageProvider } from '@pydantic/genai-prices'
12
import { extractUsage, findProvider, type Usage } from '@pydantic/genai-prices'
23
import type { GenAIAttributes, GenAIAttributesExtractor } from '../otel/attributes'
34
import type { InputMessages, OutputMessages, TextPart } from '../otel/genai'
@@ -48,6 +49,8 @@ export interface SafeExtractor<RequestBody, ResponseBody, StreamChunk> {
4849
export abstract class BaseAPI<RequestBody, ResponseBody, StreamChunk = JsonData>
4950
implements GenAIAttributesExtractor<RequestBody, ResponseBody>, SafeExtractor<RequestBody, ResponseBody, StreamChunk>
5051
{
52+
private usageProvider: UsageProvider | undefined
53+
5154
/** @apiFlavor: the flavor of the API, used to determine the response model and usage */
5255
apiFlavor: string | undefined = undefined
5356

@@ -57,9 +60,10 @@ export abstract class BaseAPI<RequestBody, ResponseBody, StreamChunk = JsonData>
5760
extractedRequest: ExtractedRequest = {}
5861
extractedResponse: Partial<ExtractedResponse> = {}
5962

60-
constructor(providerId: ProviderID, requestModel?: string) {
63+
constructor(providerId: ProviderID, requestModel?: string, options?: { usageProvider?: UsageProvider }) {
6164
this.providerId = providerId
6265
this.requestModel = requestModel
66+
this.usageProvider = options?.usageProvider
6367
}
6468

6569
requestExtractors: ExtractorConfig<RequestBody, ExtractedRequest> = {}
@@ -86,7 +90,7 @@ export abstract class BaseAPI<RequestBody, ResponseBody, StreamChunk = JsonData>
8690
}
8791

8892
extractUsage(responseBody: ResponseBody | StreamChunk): Usage | undefined {
89-
const provider = findProvider({ providerId: this.providerId })
93+
const provider = this.usageProvider ?? findProvider({ providerId: this.providerId })
9094
// This should never happen because we know the provider ID is valid, but we will throw an error to be safe.
9195
if (!provider) throw new Error(`Provider not found for provider ID: ${this.providerId}`)
9296
const { usage } = extractUsage(provider, responseBody, this.apiFlavor)

gateway/src/providers/default.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,8 @@ export class DefaultProviderProxy {
334334
const url = this.url()
335335

336336
// Validate that it's possible to calculate the price for the request model.
337-
if (requestModel && this.providerProxy.disableKey) {
337+
// HuggingFace is an exception because we will only know the real provider in the response headers.
338+
if (requestModel && this.providerProxy.disableKey && this.providerId() !== 'huggingface') {
338339
const price = calcPrice({ input_tokens: 0, output_tokens: 0 }, requestModel, { provider: this.usageProvider() })
339340
if (!price) {
340341
return { modelNotFound: true, requestModel }
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import { findProvider, type Provider as UsageProvider } from '@pydantic/genai-prices'
2+
import type { ModelAPI } from '../api'
3+
import { ChatCompletionAPI } from '../api/chat'
4+
import { DefaultProviderProxy } from './default'
5+
6+
/**
 * Provider proxy for HuggingFace.
 *
 * Builds a `ChatCompletionAPI` for the 'huggingface' provider ID and derives the
 * pricing (usage) provider from the `x-inference-provider` response header.
 */
export class HuggingFaceProvider extends DefaultProviderProxy {
7+
// Value of the `x-inference-provider` response header, used to build the ID of the
// provider that prices the request. It is null until `responseHeaders()` has run,
// and stays null if the header is absent.
8+
protected provider: string | null = null
9+
10+
/**
 * Creates a `ChatCompletionAPI` for 'huggingface' with no request model, passing
 * the current result of `usageProvider()` as the `usageProvider` option.
 *
 * NOTE(review): `usageProvider()` is evaluated when this method is called; if that
 * happens before `responseHeaders()` has run, the lookup uses the 'unknown'
 * suffix — confirm the call order in `DefaultProviderProxy`.
 */
protected modelAPI(): ModelAPI {
11+
return new ChatCompletionAPI('huggingface', undefined, { usageProvider: this.usageProvider() })
12+
}
13+
14+
/** Always returns the 'chat' API flavor. */
apiFlavor(): string | undefined {
15+
return 'chat'
16+
}
17+
18+
// The upstream `provider` is only known once the response headers have been read, so
// the pricing provider is looked up dynamically as `<providerId>-<provider>`, with
// 'unknown' substituted while `provider` is still null.
19+
protected usageProvider(): UsageProvider | undefined {
20+
return findProvider({ providerId: `${this.providerId()}-${this.provider ?? 'unknown'}` })
21+
}
22+
23+
/**
 * Delegates to the base implementation, then records the `x-inference-provider`
 * header from the incoming `headers` on `this.provider`.
 *
 * @returns the headers produced by `super.responseHeaders(headers)`, unchanged here.
 */
protected responseHeaders(headers: Headers): Headers {
24+
const newHeaders = super.responseHeaders(headers)
25+
this.provider = headers.get('x-inference-provider')
26+
return newHeaders
27+
}
28+
}

gateway/src/providers/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import { BedrockProvider } from './bedrock'
2323
import { DefaultProviderProxy, type ProviderOptions } from './default'
2424
import { GoogleVertexProvider } from './google'
2525
import { GroqProvider } from './groq'
26+
import { HuggingFaceProvider } from './huggingface'
2627
import { OpenAIProvider } from './openai'
2728
import { TestProvider } from './test'
2829

@@ -42,6 +43,8 @@ export function getProvider(providerId: ProviderID): ProviderSig {
4243
return AnthropicProvider
4344
case 'bedrock':
4445
return BedrockProvider
46+
case 'huggingface':
47+
return HuggingFaceProvider
4548
case 'test':
4649
return TestProvider
4750
default:

gateway/src/types.ts

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,15 @@ export interface ApiKeyInfo<ProviderKey extends string = string> {
3838
otelSettings?: OtelSettings
3939
}
4040

41-
export type ProviderID = 'groq' | 'openai' | 'google-vertex' | 'anthropic' | 'test' | 'bedrock' | 'azure'
41+
export type ProviderID =
42+
| 'azure'
43+
| 'groq'
44+
| 'openai'
45+
| 'google-vertex'
46+
| 'anthropic'
47+
| 'test'
48+
| 'bedrock'
49+
| 'huggingface'
4250
// TODO | 'fireworks' | 'mistral' | 'cohere'
4351

4452
const providerIds: Record<ProviderID, boolean> = {
@@ -48,6 +56,7 @@ const providerIds: Record<ProviderID, boolean> = {
4856
anthropic: true,
4957
test: true,
5058
bedrock: true,
59+
huggingface: true,
5160
azure: true,
5261
}
5362

0 commit comments

Comments
 (0)