Commit f8c3a74

Update CerebrasClient to use ai-sdk instead of OpenAI
1 parent cbff109

5 files changed: 61 additions, 317 deletions

lib/index.ts

Lines changed: 1 addition & 0 deletions

@@ -1046,4 +1046,5 @@ export * from "../types/stagehandApiErrors";
 export * from "../types/stagehandErrors";
 export * from "./llm/LLMClient";
 export * from "./llm/aisdk";
+export { CerebrasClient } from "./llm/CerebrasClient";
 export { connectToMCPServer };
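
With this export in place, the client can be imported from the package's public entry point rather than from a deep path. A minimal consumer sketch, assuming the published package name is @browserbasehq/stagehand and that the options shown in the CerebrasClient.ts diff below (modelName, clientOptions, enableCaching) are accepted by the constructor; the model id here is hypothetical:

    import { CerebrasClient } from "@browserbasehq/stagehand";

    // Hypothetical model id; the constructor strips the "cerebras-"
    // prefix before handing the id to the ai-sdk provider (see below).
    const llmClient = new CerebrasClient({
      modelName: "cerebras-llama3.1-8b",
      clientOptions: { apiKey: process.env.CEREBRAS_API_KEY },
      enableCaching: false,
    });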

lib/llm/CerebrasClient.ts

Lines changed: 42 additions & 313 deletions

@@ -1,24 +1,15 @@
-import OpenAI from "openai";
-import type { ClientOptions } from "openai";
-import { zodToJsonSchema } from "zod-to-json-schema";
 import { LogLine } from "../../types/log";
-import { AvailableModel } from "../../types/model";
+import { AvailableModel, ClientOptions } from "../../types/model";
 import { LLMCache } from "../cache/LLMCache";
-import {
-  ChatMessage,
-  CreateChatCompletionOptions,
-  LLMClient,
-  LLMResponse,
-} from "./LLMClient";
-import { CreateChatCompletionResponseError } from "@/types/stagehandErrors";
+import { AISdkClient } from "./aisdk";
+import { LLMClient, CreateChatCompletionOptions, LLMResponse } from "./LLMClient";
+import { createCerebras } from "@ai-sdk/cerebras";
+import { LanguageModel } from "ai";
 
 export class CerebrasClient extends LLMClient {
   public type = "cerebras" as const;
-  private client: OpenAI;
-  private cache: LLMCache | undefined;
-  private enableCaching: boolean;
-  public clientOptions: ClientOptions;
   public hasVision = false;
+  private aisdkClient: AISdkClient;
 
   constructor({
     enableCaching = false,
@@ -36,308 +27,46 @@ export class CerebrasClient extends LLMClient {
   }) {
     super(modelName, userProvidedInstructions);
 
-    // Create OpenAI client with the base URL set to Cerebras API
-    this.client = new OpenAI({
-      baseURL: "https://api.cerebras.ai/v1",
-      apiKey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
-      ...clientOptions,
-    });
-
-    this.cache = cache;
-    this.enableCaching = enableCaching;
-    this.modelName = modelName;
-    this.clientOptions = clientOptions;
-  }
-
-  async createChatCompletion<T = LLMResponse>({
-    options,
-    retries,
-    logger,
-  }: CreateChatCompletionOptions): Promise<T> {
-    const optionsWithoutImage = { ...options };
-    delete optionsWithoutImage.image;
+    // Transform model name to remove cerebras- prefix
+    const cerebrasModelName = modelName.startsWith("cerebras-")
+      ? modelName.split("cerebras-")[1]
+      : modelName;
 
-    logger({
-      category: "cerebras",
-      message: "creating chat completion",
-      level: 2,
-      auxiliary: {
-        options: {
-          value: JSON.stringify(optionsWithoutImage),
-          type: "object",
-        },
-      },
+    // Create Cerebras provider with API key
+    const cerebrasProvider = createCerebras({
+      apiKey: (clientOptions?.apiKey as string) || process.env.CEREBRAS_API_KEY,
     });
 
-    // Try to get cached response
-    const cacheOptions = {
-      model: this.modelName.split("cerebras-")[1],
-      messages: options.messages,
-      temperature: options.temperature,
-      response_model: options.response_model,
-      tools: options.tools,
-      retries: retries,
-    };
-
-    if (this.enableCaching) {
-      const cachedResponse = await this.cache.get<T>(
-        cacheOptions,
-        options.requestId,
-      );
-      if (cachedResponse) {
-        logger({
-          category: "llm_cache",
-          message: "LLM cache hit - returning cached response",
-          level: 1,
-          auxiliary: {
-            cachedResponse: {
-              value: JSON.stringify(cachedResponse),
-              type: "object",
-            },
-            requestId: {
-              value: options.requestId,
-              type: "string",
-            },
-            cacheOptions: {
-              value: JSON.stringify(cacheOptions),
-              type: "object",
-            },
-          },
-        });
-        return cachedResponse as T;
-      }
-    }
-
-    // Format messages for Cerebras API (using OpenAI format)
-    const formattedMessages = options.messages.map((msg: ChatMessage) => {
-      const baseMessage = {
-        content:
-          typeof msg.content === "string"
-            ? msg.content
-            : Array.isArray(msg.content) &&
-                msg.content.length > 0 &&
-                "text" in msg.content[0]
-              ? msg.content[0].text
-              : "",
-      };
-
-      // Cerebras only supports system, user, and assistant roles
-      if (msg.role === "system") {
-        return { ...baseMessage, role: "system" as const };
-      } else if (msg.role === "assistant") {
-        return { ...baseMessage, role: "assistant" as const };
-      } else {
-        // Default to user for any other role
-        return { ...baseMessage, role: "user" as const };
-      }
-    });
-
-    // Format tools if provided
-    let tools = options.tools?.map((tool) => ({
-      type: "function" as const,
-      function: {
-        name: tool.name,
-        description: tool.description,
-        parameters: {
-          type: "object",
-          properties: tool.parameters.properties,
-          required: tool.parameters.required,
-        },
-      },
-    }));
-
-    // Add response model as a tool if provided
-    if (options.response_model) {
-      const jsonSchema = zodToJsonSchema(options.response_model.schema) as {
-        properties?: Record<string, unknown>;
-        required?: string[];
-      };
-      const schemaProperties = jsonSchema.properties || {};
-      const schemaRequired = jsonSchema.required || [];
-
-      const responseTool = {
-        type: "function" as const,
-        function: {
-          name: "print_extracted_data",
-          description:
-            "Prints the extracted data based on the provided schema.",
-          parameters: {
-            type: "object",
-            properties: schemaProperties,
-            required: schemaRequired,
-          },
-        },
-      };
-
-      tools = tools ? [...tools, responseTool] : [responseTool];
-    }
-
-    try {
-      // Use OpenAI client with Cerebras API
-      const apiResponse = await this.client.chat.completions.create({
-        model: this.modelName.split("cerebras-")[1],
-        messages: [
-          ...formattedMessages,
-          // Add explicit instruction to return JSON if we have a response model
-          ...(options.response_model
-            ? [
-                {
-                  role: "system" as const,
-                  content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(
-                    options.response_model.schema,
-                  )}`,
-                },
-              ]
-            : []),
-        ],
-        temperature: options.temperature || 0.7,
-        max_tokens: options.maxTokens,
-        tools: tools,
-        tool_choice: options.tool_choice || "auto",
-      });
-
-      // Format the response to match the expected LLMResponse format
-      const response: LLMResponse = {
-        id: apiResponse.id,
-        object: "chat.completion",
-        created: Date.now(),
-        model: this.modelName.split("cerebras-")[1],
-        choices: [
-          {
-            index: 0,
-            message: {
-              role: "assistant",
-              content: apiResponse.choices[0]?.message?.content || null,
-              tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
-            },
-            finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
-          },
-        ],
-        usage: {
-          prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
-          completion_tokens: apiResponse.usage?.completion_tokens || 0,
-          total_tokens: apiResponse.usage?.total_tokens || 0,
-        },
-      };
-
-      logger({
-        category: "cerebras",
-        message: "response",
-        level: 2,
-        auxiliary: {
-          response: {
-            value: JSON.stringify(response),
-            type: "object",
-          },
-          requestId: {
-            value: options.requestId,
-            type: "string",
-          },
-        },
-      });
-
-      // If we have no response model, just return the entire LLMResponse
-      if (!options.response_model) {
-        if (this.enableCaching) {
-          await this.cache.set(cacheOptions, response, options.requestId);
-        }
-        return response as T;
-      }
-
-      // If we have a response model, parse JSON from tool calls or content
-      const toolCall = response.choices[0]?.message?.tool_calls?.[0];
-      if (toolCall?.function?.arguments) {
-        try {
-          const result = JSON.parse(toolCall.function.arguments);
-          const finalResponse = {
-            data: result,
-            usage: response.usage,
-          };
-          if (this.enableCaching) {
-            await this.cache.set(
-              cacheOptions,
-              finalResponse,
-              options.requestId,
-            );
-          }
-          return finalResponse as T;
-        } catch (e) {
-          logger({
-            category: "cerebras",
-            message: "failed to parse tool call arguments as JSON, retrying",
-            level: 0,
-            auxiliary: {
-              error: {
-                value: e.message,
-                type: "string",
-              },
-            },
-          });
+    // Get the specific model from the provider
+    const cerebrasModel = cerebrasProvider(cerebrasModelName);
+
+    this.aisdkClient = new AISdkClient({
+      model: cerebrasModel as unknown as LanguageModel,
+      logger: (message: LogLine) => {
+        // Transform log messages to use cerebras category
+        const transformedMessage = {
+          ...message,
+          category:
+            message.category === "aisdk" ? "cerebras" : message.category,
+        };
+        // Call the original logger if it exists
+        if (
+          typeof (this as unknown as { logger?: (message: LogLine) => void })
+            .logger === "function"
+        ) {
+          (this as unknown as { logger: (message: LogLine) => void }).logger(
+            transformedMessage,
+          );
         }
-      }
-
-      // If we have content but no tool calls, try to parse the content as JSON
-      const content = response.choices[0]?.message?.content;
-      if (content) {
-        try {
-          const jsonMatch = content.match(/\{[\s\S]*\}/);
-          if (jsonMatch) {
-            const result = JSON.parse(jsonMatch[0]);
-            const finalResponse = {
-              data: result,
-              usage: response.usage,
-            };
-            if (this.enableCaching) {
-              await this.cache.set(
-                cacheOptions,
-                finalResponse,
-                options.requestId,
-              );
-            }
-            return finalResponse as T;
-          }
-        } catch (e) {
-          logger({
-            category: "cerebras",
-            message: "failed to parse content as JSON",
-            level: 0,
-            auxiliary: {
-              error: {
-                value: e.message,
-                type: "string",
-              },
-            },
-          });
-        }
-      }
-
-      // If we still haven't found valid JSON and have retries left, try again
-      if (!retries || retries < 5) {
-        return this.createChatCompletion({
-          options,
-          logger,
-          retries: (retries ?? 0) + 1,
-        });
-      }
+      },
+      enableCaching,
+      cache,
+    });
+  }
 
-      throw new CreateChatCompletionResponseError("Invalid response schema");
-    } catch (error) {
-      logger({
-        category: "cerebras",
-        message: "error creating chat completion",
-        level: 0,
-        auxiliary: {
-          error: {
-            value: error.message,
-            type: "string",
-          },
-          requestId: {
-            value: options.requestId,
-            type: "string",
-          },
-        },
-      });
-      throw error;
-    }
+  async createChatCompletion<T = LLMResponse>(
+    options: CreateChatCompletionOptions,
+  ): Promise<T> {
+    return this.aisdkClient.createChatCompletion<T>(options);
   }
 }
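
After this change the class is a thin adapter: model-name normalization and log-category remapping happen in the constructor, and everything else (message formatting, tool calls, structured output, caching) is delegated to AISdkClient along with the enableCaching and cache options. A minimal sketch of the underlying Vercel AI SDK flow the client now rides on, independent of this codebase; the model id is an assumption, and it requires the @ai-sdk/cerebras and ai packages:

    import { createCerebras } from "@ai-sdk/cerebras";
    import { generateText } from "ai";

    async function main() {
      // Provider factory, keyed the same way the constructor above keys it.
      const cerebras = createCerebras({ apiKey: process.env.CEREBRAS_API_KEY });

      // cerebras("<model-id>") returns the LanguageModel that AISdkClient
      // receives via its `model` option in the diff above.
      const { text } = await generateText({
        model: cerebras("llama3.1-8b"), // assumed id, without the "cerebras-" prefix
        prompt: "Say hello from Cerebras.",
      });
      console.log(text);
    }

    main().catch(console.error);

Delegating to the shared ai-sdk path also keeps retry, caching, and logging behavior in one place instead of being re-implemented per provider, which is what the removed ~300 lines were doing by hand.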
