@@ -1,24 +1,15 @@
-import OpenAI from "openai";
-import type { ClientOptions } from "openai";
-import { zodToJsonSchema } from "zod-to-json-schema";
 import { LogLine } from "../../types/log";
-import { AvailableModel } from "../../types/model";
+import { AvailableModel, ClientOptions } from "../../types/model";
 import { LLMCache } from "../cache/LLMCache";
-import {
-  ChatMessage,
-  CreateChatCompletionOptions,
-  LLMClient,
-  LLMResponse,
-} from "./LLMClient";
-import { CreateChatCompletionResponseError } from "@/types/stagehandErrors";
+import { AISdkClient } from "./aisdk";
+import { LLMClient, CreateChatCompletionOptions, LLMResponse } from "./LLMClient";
+import { createCerebras } from "@ai-sdk/cerebras";
+import { LanguageModel } from "ai";
 
 export class CerebrasClient extends LLMClient {
   public type = "cerebras" as const;
-  private client: OpenAI;
-  private cache: LLMCache | undefined;
-  private enableCaching: boolean;
-  public clientOptions: ClientOptions;
   public hasVision = false;
+  private aisdkClient: AISdkClient;
 
   constructor({
     enableCaching = false,
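The import changes above swap the hand-rolled OpenAI-compatible client for the Vercel AI SDK's Cerebras provider. A minimal sketch of the pattern the new code relies on (the model id here is a placeholder, not one confirmed by the diff):

import { createCerebras } from "@ai-sdk/cerebras";

// The provider is a factory: calling it with a model id yields an AI SDK
// LanguageModel that an AISdkClient can drive.
const cerebras = createCerebras({ apiKey: process.env.CEREBRAS_API_KEY });
const model = cerebras("llama3.1-8b"); // placeholder model id

The constructor hunk below builds exactly this provider/model pair and wraps it in an AISdkClient.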
@@ -36,308 +27,46 @@ export class CerebrasClient extends LLMClient {
   }) {
     super(modelName, userProvidedInstructions);
 
-    // Create OpenAI client with the base URL set to Cerebras API
-    this.client = new OpenAI({
-      baseURL: "https://api.cerebras.ai/v1",
-      apiKey: clientOptions?.apiKey || process.env.CEREBRAS_API_KEY,
-      ...clientOptions,
-    });
-
-    this.cache = cache;
-    this.enableCaching = enableCaching;
-    this.modelName = modelName;
-    this.clientOptions = clientOptions;
-  }
-
-  async createChatCompletion<T = LLMResponse>({
-    options,
-    retries,
-    logger,
-  }: CreateChatCompletionOptions): Promise<T> {
-    const optionsWithoutImage = { ...options };
-    delete optionsWithoutImage.image;
+    // Transform model name to remove cerebras- prefix
+    const cerebrasModelName = modelName.startsWith("cerebras-")
+      ? modelName.split("cerebras-")[1]
+      : modelName;
 
-    logger({
-      category: "cerebras",
-      message: "creating chat completion",
-      level: 2,
-      auxiliary: {
-        options: {
-          value: JSON.stringify(optionsWithoutImage),
-          type: "object",
-        },
-      },
+    // Create Cerebras provider with API key
+    const cerebrasProvider = createCerebras({
+      apiKey: (clientOptions?.apiKey as string) || process.env.CEREBRAS_API_KEY,
     });
 
-    // Try to get cached response
-    const cacheOptions = {
-      model: this.modelName.split("cerebras-")[1],
-      messages: options.messages,
-      temperature: options.temperature,
-      response_model: options.response_model,
-      tools: options.tools,
-      retries: retries,
-    };
-
-    if (this.enableCaching) {
-      const cachedResponse = await this.cache.get<T>(
-        cacheOptions,
-        options.requestId,
-      );
-      if (cachedResponse) {
-        logger({
-          category: "llm_cache",
-          message: "LLM cache hit - returning cached response",
-          level: 1,
-          auxiliary: {
-            cachedResponse: {
-              value: JSON.stringify(cachedResponse),
-              type: "object",
-            },
-            requestId: {
-              value: options.requestId,
-              type: "string",
-            },
-            cacheOptions: {
-              value: JSON.stringify(cacheOptions),
-              type: "object",
-            },
-          },
-        });
-        return cachedResponse as T;
-      }
-    }
-
-    // Format messages for Cerebras API (using OpenAI format)
-    const formattedMessages = options.messages.map((msg: ChatMessage) => {
-      const baseMessage = {
-        content:
-          typeof msg.content === "string"
-            ? msg.content
-            : Array.isArray(msg.content) &&
-                msg.content.length > 0 &&
-                "text" in msg.content[0]
-              ? msg.content[0].text
-              : "",
-      };
-
-      // Cerebras only supports system, user, and assistant roles
-      if (msg.role === "system") {
-        return { ...baseMessage, role: "system" as const };
-      } else if (msg.role === "assistant") {
-        return { ...baseMessage, role: "assistant" as const };
-      } else {
-        // Default to user for any other role
-        return { ...baseMessage, role: "user" as const };
-      }
-    });
-
-    // Format tools if provided
-    let tools = options.tools?.map((tool) => ({
-      type: "function" as const,
-      function: {
-        name: tool.name,
-        description: tool.description,
-        parameters: {
-          type: "object",
-          properties: tool.parameters.properties,
-          required: tool.parameters.required,
-        },
-      },
-    }));
-
-    // Add response model as a tool if provided
-    if (options.response_model) {
-      const jsonSchema = zodToJsonSchema(options.response_model.schema) as {
-        properties?: Record<string, unknown>;
-        required?: string[];
-      };
-      const schemaProperties = jsonSchema.properties || {};
-      const schemaRequired = jsonSchema.required || [];
-
-      const responseTool = {
-        type: "function" as const,
-        function: {
-          name: "print_extracted_data",
-          description:
-            "Prints the extracted data based on the provided schema.",
-          parameters: {
-            type: "object",
-            properties: schemaProperties,
-            required: schemaRequired,
-          },
-        },
-      };
-
-      tools = tools ? [...tools, responseTool] : [responseTool];
-    }
-
-    try {
-      // Use OpenAI client with Cerebras API
-      const apiResponse = await this.client.chat.completions.create({
-        model: this.modelName.split("cerebras-")[1],
-        messages: [
-          ...formattedMessages,
-          // Add explicit instruction to return JSON if we have a response model
-          ...(options.response_model
-            ? [
-                {
-                  role: "system" as const,
-                  content: `IMPORTANT: Your response must be valid JSON that matches this schema: ${JSON.stringify(
-                    options.response_model.schema,
-                  )}`,
-                },
-              ]
-            : []),
-        ],
-        temperature: options.temperature || 0.7,
-        max_tokens: options.maxTokens,
-        tools: tools,
-        tool_choice: options.tool_choice || "auto",
-      });
-
-      // Format the response to match the expected LLMResponse format
-      const response: LLMResponse = {
-        id: apiResponse.id,
-        object: "chat.completion",
-        created: Date.now(),
-        model: this.modelName.split("cerebras-")[1],
-        choices: [
-          {
-            index: 0,
-            message: {
-              role: "assistant",
-              content: apiResponse.choices[0]?.message?.content || null,
-              tool_calls: apiResponse.choices[0]?.message?.tool_calls || [],
-            },
-            finish_reason: apiResponse.choices[0]?.finish_reason || "stop",
-          },
-        ],
-        usage: {
-          prompt_tokens: apiResponse.usage?.prompt_tokens || 0,
-          completion_tokens: apiResponse.usage?.completion_tokens || 0,
-          total_tokens: apiResponse.usage?.total_tokens || 0,
-        },
-      };
-
-      logger({
-        category: "cerebras",
-        message: "response",
-        level: 2,
-        auxiliary: {
-          response: {
-            value: JSON.stringify(response),
-            type: "object",
-          },
-          requestId: {
-            value: options.requestId,
-            type: "string",
-          },
-        },
-      });
-
-      // If we have no response model, just return the entire LLMResponse
-      if (!options.response_model) {
-        if (this.enableCaching) {
-          await this.cache.set(cacheOptions, response, options.requestId);
-        }
-        return response as T;
-      }
-
-      // If we have a response model, parse JSON from tool calls or content
-      const toolCall = response.choices[0]?.message?.tool_calls?.[0];
-      if (toolCall?.function?.arguments) {
-        try {
-          const result = JSON.parse(toolCall.function.arguments);
-          const finalResponse = {
-            data: result,
-            usage: response.usage,
-          };
-          if (this.enableCaching) {
-            await this.cache.set(
-              cacheOptions,
-              finalResponse,
-              options.requestId,
-            );
-          }
-          return finalResponse as T;
-        } catch (e) {
-          logger({
-            category: "cerebras",
-            message: "failed to parse tool call arguments as JSON, retrying",
-            level: 0,
-            auxiliary: {
-              error: {
-                value: e.message,
-                type: "string",
-              },
-            },
-          });
+    // Get the specific model from the provider
+    const cerebrasModel = cerebrasProvider(cerebrasModelName);
+
+    this.aisdkClient = new AISdkClient({
+      model: cerebrasModel as unknown as LanguageModel,
+      logger: (message: LogLine) => {
+        // Transform log messages to use cerebras category
+        const transformedMessage = {
+          ...message,
+          category:
+            message.category === "aisdk" ? "cerebras" : message.category,
+        };
+        // Call the original logger if it exists
+        if (
+          typeof (this as unknown as { logger?: (message: LogLine) => void })
+            .logger === "function"
+        ) {
+          (this as unknown as { logger: (message: LogLine) => void }).logger(
+            transformedMessage,
+          );
         }
-      }
-
-      // If we have content but no tool calls, try to parse the content as JSON
-      const content = response.choices[0]?.message?.content;
-      if (content) {
-        try {
-          const jsonMatch = content.match(/\{[\s\S]*\}/);
-          if (jsonMatch) {
-            const result = JSON.parse(jsonMatch[0]);
-            const finalResponse = {
-              data: result,
-              usage: response.usage,
-            };
-            if (this.enableCaching) {
-              await this.cache.set(
-                cacheOptions,
-                finalResponse,
-                options.requestId,
-              );
-            }
-            return finalResponse as T;
-          }
-        } catch (e) {
-          logger({
-            category: "cerebras",
-            message: "failed to parse content as JSON",
-            level: 0,
-            auxiliary: {
-              error: {
-                value: e.message,
-                type: "string",
-              },
-            },
-          });
-        }
-      }
-
-      // If we still haven't found valid JSON and have retries left, try again
-      if (!retries || retries < 5) {
-        return this.createChatCompletion({
-          options,
-          logger,
-          retries: (retries ?? 0) + 1,
-        });
-      }
+      },
+      enableCaching,
+      cache,
+    });
+  }
 
-      throw new CreateChatCompletionResponseError("Invalid response schema");
-    } catch (error) {
-      logger({
-        category: "cerebras",
-        message: "error creating chat completion",
-        level: 0,
-        auxiliary: {
-          error: {
-            value: error.message,
-            type: "string",
-          },
-          requestId: {
-            value: options.requestId,
-            type: "string",
-          },
-        },
-      });
-      throw error;
-    }
+  async createChatCompletion<T = LLMResponse>(
+    options: CreateChatCompletionOptions,
+  ): Promise<T> {
+    return this.aisdkClient.createChatCompletion<T>(options);
   }
 }
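After the change, CerebrasClient is a thin wrapper: the constructor strips the "cerebras-" prefix from the model name, builds the provider model, and hands caching and logging over to an AISdkClient, while createChatCompletion simply delegates. A minimal usage sketch, assuming the constructor options visible in the diff (enableCaching, cache, modelName, clientOptions, userProvidedInstructions; the fields elided by the hunk header are inferred, not confirmed) and placeholder model and request ids:

import { CerebrasClient } from "./CerebrasClient"; // assumed file path

const client = new CerebrasClient({
  modelName: "cerebras-llama3.1-8b", // the "cerebras-" prefix is stripped before the provider call
  clientOptions: { apiKey: process.env.CEREBRAS_API_KEY },
  enableCaching: false,
  cache: undefined,
  userProvidedInstructions: undefined,
});

// Delegates to AISdkClient.createChatCompletion; the { options, logger } shape
// follows what the old implementation destructured from CreateChatCompletionOptions.
const response = await client.createChatCompletion({
  options: {
    messages: [{ role: "user", content: "Say hello" }],
    requestId: "req-1", // placeholder
  },
  logger: (line) => console.log(line.message),
});

Note that the Cerebras-specific behavior removed here (the print_extracted_data response tool, JSON extraction from message content, and the retry loop) is now delegated to AISdkClient, and log lines tagged "aisdk" are re-tagged "cerebras" by the wrapper so downstream consumers still see the cerebras category.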