Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions packages/types/src/providers/cerebras.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export const cerebrasModels = {
contextWindow: 131072,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 0,
outputPrice: 0,
description: "Highly intelligent general purpose model with up to 1,000 tokens/s",
Expand All @@ -20,6 +21,7 @@ export const cerebrasModels = {
contextWindow: 64000,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 0,
outputPrice: 0,
description: "Intelligent model with ~1400 tokens/s",
Expand All @@ -29,6 +31,7 @@ export const cerebrasModels = {
contextWindow: 64000,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 0,
outputPrice: 0,
description: "Powerful model with ~2600 tokens/s",
Expand All @@ -38,6 +41,7 @@ export const cerebrasModels = {
contextWindow: 64000,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 0,
outputPrice: 0,
description: "SOTA coding performance with ~2500 tokens/s",
Expand All @@ -47,6 +51,7 @@ export const cerebrasModels = {
contextWindow: 64000,
supportsImages: false,
supportsPromptCache: false,
supportsNativeTools: true,
inputPrice: 0,
outputPrice: 0,
description:
Expand Down
4 changes: 3 additions & 1 deletion src/api/providers/base-provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ export abstract class BaseProvider implements ApiHandler {

/**
* Converts an array of tools to be compatible with OpenAI's strict mode.
* Filters for function tools and applies schema conversion to their parameters.
* Filters for function tools, applies schema conversion to their parameters,
* and ensures all tools have consistent strict: true values.
*/
protected convertToolsForOpenAI(tools: any[] | undefined): any[] | undefined {
if (!tools) {
Expand All @@ -33,6 +34,7 @@ export abstract class BaseProvider implements ApiHandler {
...tool,
function: {
...tool.function,
strict: true,
parameters: this.convertToolSchemaForOpenAI(tool.function.parameters),
},
}
Expand Down
153 changes: 79 additions & 74 deletions src/api/providers/cerebras.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,68 +16,6 @@ import { t } from "../../i18n"
const CEREBRAS_BASE_URL = "https://api.cerebras.ai/v1"
const CEREBRAS_DEFAULT_TEMPERATURE = 0

/**
 * Removes thinking tokens from text to prevent model confusion when processing conversation history.
 * This is crucial because models can get confused by their own thinking tokens in input.
 */
function stripThinkingTokens(text: string): string {
	// Match only "innermost" <think>...</think> blocks (blocks containing no
	// other think tags) and strip them repeatedly until the text is stable.
	// A single non-greedy pass would mishandle nesting: for
	// "<think>a<think>b</think>c</think>" it stops at the first "</think>"
	// and leaves a dangling "c</think>" behind.
	const innermostThinkBlock = /<think>(?:(?!<\/?think>)[\s\S])*<\/think>/g
	let result = text
	let previous: string
	do {
		previous = result
		result = result.replace(innermostThinkBlock, "")
	} while (result !== previous)
	return result.trim()
}

/**
 * Flattens OpenAI message content to simple strings that Cerebras can handle.
 * Cerebras doesn't support complex content arrays like OpenAI does.
 */
function flattenMessageContent(content: any): string {
	// Plain strings pass through untouched.
	if (typeof content === "string") {
		return content
	}

	// Content arrays are reduced to one newline-joined string; empty pieces are dropped.
	if (Array.isArray(content)) {
		const pieces: string[] = []
		for (const part of content) {
			if (typeof part === "string") {
				pieces.push(part)
			} else if (part.type === "text") {
				pieces.push(part.text || "")
			} else if (part.type === "image_url") {
				// Placeholder for images since Cerebras doesn't support images
				pieces.push("[Image]")
			}
			// Any other part type contributes nothing.
		}
		return pieces.filter(Boolean).join("\n")
	}

	// Fallback for any other content types (null/undefined become "").
	return String(content || "")
}

/**
 * Converts OpenAI messages to Cerebras-compatible format with simple string content.
 * Also strips thinking tokens from assistant messages to prevent model confusion.
 */
function convertToCerebrasMessages(openaiMessages: any[]): Array<{ role: string; content: string }> {
	const converted: Array<{ role: string; content: string }> = []

	for (const msg of openaiMessages) {
		const flattened = flattenMessageContent(msg.content)

		// Strip thinking tokens from assistant messages to prevent confusion
		const content = msg.role === "assistant" ? stripThinkingTokens(flattened) : flattened

		// Remove messages whose content collapsed to whitespace/empty
		if (content.trim() !== "") {
			converted.push({ role: msg.role, content })
		}
	}

	return converted
}

export class CerebrasHandler extends BaseProvider implements SingleCompletionHandler {
private apiKey: string
private providerModels: typeof cerebrasModels
Expand Down Expand Up @@ -106,26 +44,70 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
}
}

/**
 * Override convertToolSchemaForOpenAI to remove unsupported schema fields for Cerebras.
 * Cerebras doesn't support minItems/maxItems in array schemas with strict mode.
 */
protected override convertToolSchemaForOpenAI(schema: any): any {
	// Run the base conversion first, then post-process the result.
	return this.stripUnsupportedSchemaFields(super.convertToolSchemaForOpenAI(schema))
}

/**
 * Recursively strips unsupported schema fields for Cerebras.
 * Cerebras strict mode doesn't support minItems, maxItems on arrays.
 *
 * Walks `properties`, `items` (single-schema and tuple forms),
 * `additionalProperties` sub-schemas, and `anyOf`/`oneOf`/`allOf` branches so
 * array constraints nested anywhere in the schema are removed — previously
 * only `properties` and an object-form `items` were visited, letting nested
 * minItems/maxItems leak through.
 *
 * Returns a new object; the input schema is never mutated.
 */
private stripUnsupportedSchemaFields(schema: any): any {
	// Primitives (and null/undefined) pass through unchanged.
	if (!schema || typeof schema !== "object") {
		return schema
	}

	const result = { ...schema }

	// Remove unsupported array constraints (type may be a string or a type list).
	if (result.type === "array" || (Array.isArray(result.type) && result.type.includes("array"))) {
		delete result.minItems
		delete result.maxItems
	}

	// Recursively process properties
	if (result.properties) {
		const newProps = { ...result.properties }
		for (const key of Object.keys(newProps)) {
			newProps[key] = this.stripUnsupportedSchemaFields(newProps[key])
		}
		result.properties = newProps
	}

	// Recursively process array items — JSON Schema allows a single schema or a
	// tuple (array of schemas) here.
	if (result.items) {
		result.items = Array.isArray(result.items)
			? result.items.map((item: any) => this.stripUnsupportedSchemaFields(item))
			: this.stripUnsupportedSchemaFields(result.items)
	}

	// additionalProperties may itself be a schema object (or a boolean, which we leave alone).
	if (result.additionalProperties && typeof result.additionalProperties === "object") {
		result.additionalProperties = this.stripUnsupportedSchemaFields(result.additionalProperties)
	}

	// Composition keywords can nest array schemas as well.
	for (const keyword of ["anyOf", "oneOf", "allOf"]) {
		if (Array.isArray(result[keyword])) {
			result[keyword] = result[keyword].map((sub: any) => this.stripUnsupportedSchemaFields(sub))
		}
	}

	return result
}

async *createMessage(
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const {
id: model,
info: { maxTokens: max_tokens },
} = this.getModel()
const { id: model, info: modelInfo } = this.getModel()
const max_tokens = modelInfo.maxTokens
const supportsNativeTools = modelInfo.supportsNativeTools ?? false
const temperature = this.options.modelTemperature ?? CEREBRAS_DEFAULT_TEMPERATURE

// Convert Anthropic messages to OpenAI format, then flatten for Cerebras
// This will automatically strip thinking tokens from assistant messages
// Check if we should use native tool calling
const useNativeTools =
supportsNativeTools && metadata?.tools && metadata.tools.length > 0 && metadata?.toolProtocol !== "xml"

// Convert Anthropic messages to OpenAI format (Cerebras is OpenAI-compatible)
const openaiMessages = convertToOpenAiMessages(messages)
const cerebrasMessages = convertToCerebrasMessages(openaiMessages)

// Prepare request body following Cerebras API specification exactly
const requestBody = {
const requestBody: Record<string, any> = {
model,
messages: [{ role: "system", content: systemPrompt }, ...cerebrasMessages],
messages: [{ role: "system", content: systemPrompt }, ...openaiMessages],
stream: true,
// Use max_completion_tokens (Cerebras-specific parameter)
...(max_tokens && max_tokens > 0 && max_tokens <= 32768 ? { max_completion_tokens: max_tokens } : {}),
Expand All @@ -135,6 +117,10 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan
temperature: Math.max(0, Math.min(1.5, temperature)),
}
: {}),
// Native tool calling support
...(useNativeTools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
...(useNativeTools && metadata.tool_choice && { tool_choice: metadata.tool_choice }),
...(useNativeTools && { parallel_tool_calls: metadata?.parallelToolCalls ?? false }),
}

try {
Expand Down Expand Up @@ -216,16 +202,31 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan

const parsed = JSON.parse(jsonStr)

const delta = parsed.choices?.[0]?.delta

// Handle text content - parse for thinking tokens
if (parsed.choices?.[0]?.delta?.content) {
const content = parsed.choices[0].delta.content
if (delta?.content) {
const content = delta.content

// Use XmlMatcher to parse <think>...</think> tags
for (const chunk of matcher.update(content)) {
yield chunk
}
}

// Handle tool calls in stream - emit partial chunks for NativeToolCallParser
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
}

// Handle usage information if available
if (parsed.usage) {
inputTokens = parsed.usage.prompt_tokens || 0
Expand All @@ -248,7 +249,11 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHan

// Provide token usage estimate if not available from API
if (inputTokens === 0 || outputTokens === 0) {
const inputText = systemPrompt + cerebrasMessages.map((m) => m.content).join("")
const inputText =
systemPrompt +
openaiMessages
.map((m: any) => (typeof m.content === "string" ? m.content : JSON.stringify(m.content)))
.join("")
inputTokens = inputTokens || Math.ceil(inputText.length / 4) // Rough estimate: 4 chars per token
outputTokens = outputTokens || Math.ceil((max_tokens || 1000) / 10) // Rough estimate
}
Expand Down
Loading