Commit 6b2edc6

Revert "fix: bracket stripping in gemini responses" (#8332)
1 parent 25f3f93 commit 6b2edc6
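This revert switches Gemini streaming back from SSE parsing to raw response streaming: `processGeminiResponse` once again accepts an `AsyncIterable<string>` produced by `streamResponse` instead of a `Response` consumed with `streamSse`. The manual bracket-stripping buffer logic for the JSON array that Gemini's streaming endpoint returns is restored in both `core/llm/llms/Gemini.ts` and `packages/openai-adapters/src/apis/Gemini.ts`, and the call site in `core/llm/llms/VertexAI.ts` is updated to match.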

File tree

core/llm/llms/Gemini.ts
core/llm/llms/VertexAI.ts
packages/openai-adapters/src/apis/Gemini.ts

3 files changed: +144 -89 lines

core/llm/llms/Gemini.ts

Lines changed: 74 additions & 46 deletions
```diff
@@ -1,4 +1,4 @@
-import { streamSse } from "@continuedev/fetch";
+import { streamResponse } from "@continuedev/fetch";
 import { v4 as uuidv4 } from "uuid";
 import {
   AssistantChatMessage,
@@ -312,57 +312,84 @@ class Gemini extends BaseLLM {
   }

   public async *processGeminiResponse(
-    response: Response,
+    stream: AsyncIterable<string>,
   ): AsyncGenerator<ChatMessage> {
-    for await (const chunk of streamSse(response)) {
-      let data: GeminiChatResponse;
-      try {
-        data = JSON.parse(chunk) as GeminiChatResponse;
-      } catch (e) {
-        continue;
+    let buffer = "";
+    for await (const chunk of stream) {
+      buffer += chunk;
+      if (buffer.startsWith("[")) {
+        buffer = buffer.slice(1);
       }
-
-      if ("error" in data) {
-        throw new Error(data.error.message);
+      if (buffer.endsWith("]")) {
+        buffer = buffer.slice(0, -1);
+      }
+      if (buffer.startsWith(",")) {
+        buffer = buffer.slice(1);
       }

-      const contentParts = data?.candidates?.[0]?.content?.parts;
-      if (contentParts) {
-        const textParts: MessagePart[] = [];
-        const toolCalls: ToolCallDelta[] = [];
-
-        for (const part of contentParts) {
-          if ("text" in part) {
-            textParts.push({ type: "text", text: part.text });
-          } else if ("functionCall" in part) {
-            toolCalls.push({
-              type: "function",
-              id: part.functionCall.id ?? uuidv4(),
-              function: {
-                name: part.functionCall.name,
-                arguments:
-                  typeof part.functionCall.args === "string"
-                    ? part.functionCall.args
-                    : JSON.stringify(part.functionCall.args),
-              },
-            });
-          } else {
-            console.warn("Unsupported gemini part type received", part);
-          }
+      const parts = buffer.split("\n,");
+
+      let foundIncomplete = false;
+      for (let i = 0; i < parts.length; i++) {
+        const part = parts[i];
+        let data: GeminiChatResponse;
+        try {
+          data = JSON.parse(part) as GeminiChatResponse;
+        } catch (e) {
+          foundIncomplete = true;
+          continue; // yo!
         }

-        const assistantMessage: AssistantChatMessage = {
-          role: "assistant",
-          content: textParts.length ? textParts : "",
-        };
-        if (toolCalls.length > 0) {
-          assistantMessage.toolCalls = toolCalls;
+        if ("error" in data) {
+          throw new Error(data.error.message);
         }
-        if (textParts.length || toolCalls.length) {
-          yield assistantMessage;
+
+        // In case of max tokens reached, gemini will sometimes return content with no parts, even though that doesn't match the API spec
+        const contentParts = data?.candidates?.[0]?.content?.parts;
+        if (contentParts) {
+          const textParts: MessagePart[] = [];
+          const toolCalls: ToolCallDelta[] = [];
+
+          for (const part of contentParts) {
+            if ("text" in part) {
+              textParts.push({ type: "text", text: part.text });
+            } else if ("functionCall" in part) {
+              toolCalls.push({
+                type: "function",
+                id: part.functionCall.id ?? uuidv4(),
+                function: {
+                  name: part.functionCall.name,
+                  arguments:
+                    typeof part.functionCall.args === "string"
+                      ? part.functionCall.args
+                      : JSON.stringify(part.functionCall.args),
+                },
+              });
+            } else {
+              // Note: function responses shouldn't be streamed, images not supported
+              console.warn("Unsupported gemini part type received", part);
+            }
+          }
+
+          const assistantMessage: AssistantChatMessage = {
+            role: "assistant",
+            content: textParts.length ? textParts : "",
+          };
+          if (toolCalls.length > 0) {
+            assistantMessage.toolCalls = toolCalls;
+          }
+          if (textParts.length || toolCalls.length) {
+            yield assistantMessage;
+          }
+        } else {
+          // Handle the case where the expected data structure is not found
+          console.warn("Unexpected response format:", data);
         }
+      }
+      if (foundIncomplete) {
+        buffer = parts[parts.length - 1];
       } else {
-        console.warn("Unexpected response format:", data);
+        buffer = "";
       }
     }
   }
@@ -387,9 +414,10 @@ class Gemini extends BaseLLM {
       body: JSON.stringify(body),
       signal,
     });
-
-    for await (const chunk of this.processGeminiResponse(response)) {
-      yield chunk;
+    for await (const message of this.processGeminiResponse(
+      streamResponse(response),
+    )) {
+      yield message;
     }
   }
   private async *streamChatBison(
```
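The restored parser treats Gemini's streamed response as one large JSON array rather than a sequence of SSE events. Here is a minimal standalone sketch of the same buffering technique (the function name and generic signature are illustrative, not the project's API): strip the array's enclosing brackets and any leading comma, split the buffer on the `"\n,"` separators Gemini emits between array elements, parse each piece, and carry an incomplete tail over to the next network chunk.

```typescript
// A minimal sketch of the restored buffering technique; like the production
// code, it assumes only the final array element can be incomplete.
async function* parseJsonArrayStream<T>(
  stream: AsyncIterable<string>,
): AsyncGenerator<T> {
  let buffer = "";
  for await (const chunk of stream) {
    buffer += chunk;
    // Gemini streams one big JSON array: '[{...}\n,{...}\n,{...}]'.
    // Strip the enclosing brackets and any leading comma...
    if (buffer.startsWith("[")) buffer = buffer.slice(1);
    if (buffer.endsWith("]")) buffer = buffer.slice(0, -1);
    if (buffer.startsWith(",")) buffer = buffer.slice(1);

    // ...then split on the "\n," separators between array elements.
    const parts = buffer.split("\n,");
    let foundIncomplete = false;
    for (const part of parts) {
      try {
        yield JSON.parse(part) as T;
      } catch {
        // A partially received object: wait for the next chunk.
        foundIncomplete = true;
      }
    }
    // Carry the unparsed tail forward; otherwise start fresh.
    buffer = foundIncomplete ? parts[parts.length - 1] : "";
  }
}

// Example: an object split across two network chunks is still recovered.
async function demo() {
  async function* fakeNetwork() {
    yield '[{"a":1}\n,{"b":';
    yield '2}\n,{"c":3}]';
  }
  for await (const obj of parseJsonArrayStream<object>(fakeNetwork())) {
    console.log(obj); // { a: 1 }, then { b: 2 }, then { c: 3 }
  }
}
void demo();
```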

core/llm/llms/VertexAI.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
 import { AuthClient, GoogleAuth, JWT, auth } from "google-auth-library";

-import { streamSse } from "@continuedev/fetch";
+import { streamResponse, streamSse } from "@continuedev/fetch";
 import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
 import { renderChatMessage, stripImages } from "../../util/messageContent.js";
 import { BaseLLM } from "../index.js";
@@ -287,7 +287,7 @@ class VertexAI extends BaseLLM {
       body: JSON.stringify(body),
       signal,
     });
-    yield* this.geminiInstance.processGeminiResponse(response);
+    yield* this.geminiInstance.processGeminiResponse(streamResponse(response));
   }

   private async *streamChatBison(
```
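Both call sites now wrap the fetch `Response` in `streamResponse` before handing it to `processGeminiResponse`. The real helper lives in `@continuedev/fetch`; as a hedged sketch of the shape such a helper plausibly has (an assumption, not the package's actual code), it decodes the raw body into text chunks instead of parsing SSE frames:

```typescript
// Hypothetical sketch of a streamResponse-style helper; the real
// implementation in @continuedev/fetch may differ. It yields the response
// body as decoded text chunks rather than parsed SSE events.
async function* streamResponseSketch(
  response: Response,
): AsyncGenerator<string> {
  if (!response.body) {
    throw new Error("Response has no body to stream");
  }
  const reader = response.body.getReader();
  const decoder = new TextDecoder("utf-8");
  try {
    while (true) {
      const { done, value } = await reader.read();
      if (done) {
        break;
      }
      // stream: true keeps multi-byte UTF-8 sequences intact across chunks.
      yield decoder.decode(value, { stream: true });
    }
  } finally {
    reader.releaseLock();
  }
}
```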

packages/openai-adapters/src/apis/Gemini.ts

Lines changed: 68 additions & 41 deletions
```diff
@@ -1,4 +1,4 @@
-import { streamSse } from "@continuedev/fetch";
+import { streamResponse } from "@continuedev/fetch";
 import { OpenAI } from "openai/index";
 import {
   ChatCompletion,
@@ -284,58 +284,85 @@ export class GeminiApi implements BaseLlmApi {
   }

   async *handleStreamResponse(response: any, model: string) {
+    let buffer = "";
     let usage: UsageInfo | undefined = undefined;
-    for await (const chunk of streamSse(response as any)) {
-      let data;
-      try {
-        data = JSON.parse(chunk);
-      } catch (e) {
-        continue;
+    for await (const chunk of streamResponse(response as any)) {
+      buffer += chunk;
+      if (buffer.startsWith("[")) {
+        buffer = buffer.slice(1);
       }
-      if (data.error) {
-        throw new Error(data.error.message);
+      if (buffer.endsWith("]")) {
+        buffer = buffer.slice(0, -1);
       }
-
-      if (data.usageMetadata) {
-        usage = {
-          prompt_tokens: data.usageMetadata.promptTokenCount || 0,
-          completion_tokens: data.usageMetadata.candidatesTokenCount || 0,
-          total_tokens: data.usageMetadata.totalTokenCount || 0,
-        };
+      if (buffer.startsWith(",")) {
+        buffer = buffer.slice(1);
       }

-      const contentParts = data?.candidates?.[0]?.content?.parts;
-      if (contentParts) {
-        for (const part of contentParts) {
-          if ("text" in part) {
-            yield chatChunk({
-              content: part.text,
-              model,
-            });
-          } else if ("functionCall" in part) {
-            yield chatChunkFromDelta({
-              model,
-              delta: {
-                tool_calls: [
-                  {
-                    index: 0,
-                    id: part.functionCall.id ?? uuidv4(),
-                    type: "function",
-                    function: {
-                      name: part.functionCall.name,
-                      arguments: JSON.stringify(part.functionCall.args),
+      const parts = buffer.split("\n,");
+
+      let foundIncomplete = false;
+      for (let i = 0; i < parts.length; i++) {
+        const part = parts[i];
+        let data;
+        try {
+          data = JSON.parse(part);
+        } catch (e) {
+          foundIncomplete = true;
+          continue; // yo!
+        }
+        if (data.error) {
+          throw new Error(data.error.message);
+        }
+
+        // Check for usage metadata
+        if (data.usageMetadata) {
+          usage = {
+            prompt_tokens: data.usageMetadata.promptTokenCount || 0,
+            completion_tokens: data.usageMetadata.candidatesTokenCount || 0,
+            total_tokens: data.usageMetadata.totalTokenCount || 0,
+          };
+        }
+
+        // In case of max tokens reached, gemini will sometimes return content with no parts, even though that doesn't match the API spec
+        const contentParts = data?.candidates?.[0]?.content?.parts;
+        if (contentParts) {
+          for (const part of contentParts) {
+            if ("text" in part) {
+              yield chatChunk({
+                content: part.text,
+                model,
+              });
+            } else if ("functionCall" in part) {
+              yield chatChunkFromDelta({
+                model,
+                delta: {
+                  tool_calls: [
+                    {
+                      index: 0,
+                      id: part.functionCall.id ?? uuidv4(),
+                      type: "function",
+                      function: {
+                        name: part.functionCall.name,
+                        arguments: JSON.stringify(part.functionCall.args),
+                      },
                     },
-                },
-              ],
-            },
-          });
+                  ],
+                },
+              });
+            }
           }
+        } else {
+          console.warn("Unexpected response format:", data);
         }
+      }
+      if (foundIncomplete) {
+        buffer = parts[parts.length - 1];
       } else {
-        console.warn("Unexpected response format:", data);
+        buffer = "";
       }
     }

+    // Emit usage at the end if we have it
     if (usage) {
       yield usageChatChunk({
         model,
```
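In the adapter, each parsed Gemini `functionCall` part is converted into an OpenAI-style streaming tool-call delta via `chatChunkFromDelta`. A rough sketch of that mapping as a standalone function follows; the helper itself is internal to the package and the chunk envelope is omitted here, so treat the names and field shapes below as assumptions.

```typescript
import { v4 as uuidv4 } from "uuid";

// Simplified shape of the Gemini part being converted.
interface GeminiFunctionCallPart {
  functionCall: { id?: string; name: string; args: unknown };
}

// Build an OpenAI-style streaming tool-call delta from a Gemini
// functionCall part.
function toToolCallDelta(part: GeminiFunctionCallPart) {
  return {
    tool_calls: [
      {
        index: 0,
        // Gemini does not always supply a call id, so one is generated.
        id: part.functionCall.id ?? uuidv4(),
        type: "function" as const,
        function: {
          name: part.functionCall.name,
          // OpenAI streams tool arguments as a JSON string, so Gemini's
          // structured args object is serialized here.
          arguments: JSON.stringify(part.functionCall.args),
        },
      },
    ],
  };
}

// Example (hypothetical tool name and args, for illustration only):
const delta = toToolCallDelta({
  functionCall: { name: "get_weather", args: { city: "Tokyo" } },
});
console.log(delta.tool_calls[0].function.arguments); // {"city":"Tokyo"}
```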
