Skip to content

Commit

Permalink
refactor: merge the two extract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Jul 18, 2024
1 parent 0448561 commit 0ea0f3d
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 51 deletions.
75 changes: 27 additions & 48 deletions js/ai/src/extract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,25 @@
import JSON5 from 'json5';
import { Allow, parse } from 'partial-json';

/**
 * Leniently parses a possibly incomplete JSON string with partial-json
 * (passing `Allow.ALL`) and re-serializes the parsed value as a JSON string.
 */
const parsePartialJson = (jsonString: string) => {
  const lenientValue = parse(jsonString, Allow.ALL);
  return JSON.stringify(lenientValue);
};
/**
 * Leniently parses a JSON string that may be incomplete.
 *
 * The text is first parsed by partial-json with `Allow.ALL`; the resulting
 * value is then serialized with `JSON.stringify` and read back through
 * `JSON5.parse`.
 *
 * @param jsonString - Text containing complete or incomplete JSON.
 * @returns The parsed value, typed as `T`. No runtime validation against `T`
 *   is performed here.
 */
export function parsePartialJson<T = unknown>(jsonString: string): T {
  const partialValue = parse(jsonString, Allow.ALL);
  const normalizedJson = JSON.stringify(partialValue);
  return JSON5.parse<T>(normalizedJson);
}

/**
* Extracts JSON from a string, using lenient parsing rules to improve the likelihood of successful extraction.
*/
export function extractJson<T = unknown>(text: string): T | null {
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: true
): T;
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: false
): T | null;
export function extractJson<T = unknown>(
text: string,
throwOnBadJson?: boolean
): T | null {
let openingChar: '{' | '[' | undefined;
let closingChar: '}' | ']' | undefined;
let startPos: number | undefined;
Expand Down Expand Up @@ -52,54 +64,21 @@ export function extractJson<T = unknown>(text: string): T | null {
}

if (startPos !== undefined && nestingCount > 0) {
// If an incomplete JSON structure is detected
try {
return JSON5.parse(text.substring(startPos) + (closingChar || '')) as T;
} catch (e) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
}
}
throw new Error(`No JSON object or array found in model output: ${text}`);
}

export function extractAndUntruncateJson<T = unknown>(text: string): T | null {
let openingChar: '{' | '[' | undefined;
let closingChar: '}' | ']' | undefined;
let startPos: number | undefined;
let nestingCount = 0;

for (let i = 0; i < text.length; i++) {
const char = text[i].replace(/\u00A0/g, ' ');

if (!openingChar && (char === '{' || char === '[')) {
openingChar = char;
closingChar = char === '{' ? '}' : ']';
startPos = i;
nestingCount++;
} else if (char === openingChar) {
// Increment nesting for matching opening character
nestingCount++;
} else if (char === closingChar) {
// Decrement nesting for matching closing character
nestingCount--;
if (!nestingCount) {
// Reached end of target element
try {
return JSON5.parse(
parsePartialJson(text.substring(startPos || 0, i + 1))
) as T;
} catch {
throw new Error(text.substring(startPos || 0, i + 1));
}
// Parse the incomplete JSON structure using partial-json for lenient parsing
// Note: partial-json automatically handles adding the closing character
return parsePartialJson<T>(text.substring(startPos));
} catch {
// If parsing fails, throw an error
if (throwOnBadJson) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
}
return null; // Return null if no JSON structure is found }
}
}

if (startPos !== undefined && nestingCount > 0) {
try {
return JSON5.parse(parsePartialJson(text.substring(startPos))) as T;
} catch {
throw new Error(text.substring(startPos));
}
if (throwOnBadJson) {
throw new Error(`Invalid JSON extracted from model output: ${text}`);
}
return null as T;
return null; // Return null if no JSON structure is found
}
6 changes: 3 additions & 3 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
} from '@genkit-ai/core/schema';
import { z } from 'zod';
import { DocumentData } from './document.js';
import { extractAndUntruncateJson, extractJson } from './extract.js';
import { extractJson } from './extract.js';
import {
CandidateData,
GenerateRequest,
Expand Down Expand Up @@ -74,7 +74,7 @@ export class Message<T = unknown> implements MessageData {
*
* @returns The structured output contained in the message.
*/
output(): T | null {
output(): T {
return this.data() || extractJson<T>(this.text());
}

Expand Down Expand Up @@ -417,7 +417,7 @@ export class GenerateResponseChunk<T = unknown>
const accumulatedText = this.accumulatedChunks
.map((chunk) => chunk.content.map((part) => part.text || '').join(''))
.join('');
return extractAndUntruncateJson(accumulatedText);
return extractJson<T>(accumulatedText, false);
}

toJSON(): GenerateResponseChunkData {
Expand Down
10 changes: 10 additions & 0 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,16 @@ describe('GenerateResponseChunk', () => {
accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: {"baz": "qux`],
correctJson: { foo: { bar: { baz: 'qux' } } },
},
{
should: 'handle array nested in object',
accumulatedChunksTexts: [`{"foo": ["bar`],
correctJson: { foo: ['bar'] },
},
{
should: 'handle array nested in object with multiple chunks',
accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: ["baz`],
correctJson: { foo: { bar: ['baz'] } },
},
];

for (const test of testCases) {
Expand Down

0 comments on commit 0ea0f3d

Please sign in to comment.