Skip to content

Commit

Permalink
test: add tests for streaming json
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Jun 26, 2024
1 parent e23f3b9 commit 9161a80
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 3 deletions.
40 changes: 40 additions & 0 deletions js/ai/src/extract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import JSON5 from 'json5';
import untruncateJson from 'untruncate-json';

/**
* Extracts JSON from string with lenient parsing rules to improve likelihood of successful extraction.
Expand Down Expand Up @@ -56,3 +57,42 @@ export function extractJson<T = unknown>(text: string): T | null {
}
throw new Error(`No JSON object or array found in model output: ${text}`);
}

export function extractAndUntruncateJson<T = unknown>(text: string): T | null {
let openingChar: '{' | '[' | undefined;
let closingChar: '}' | ']' | undefined;
let startPos: number | undefined;
let nestingCount = 0;

for (let i = 0; i < text.length; i++) {
const char = text[i].replace(/\u00A0/g, ' ');

if (!openingChar && (char === '{' || char === '[')) {
openingChar = char;
closingChar = char === '{' ? '}' : ']';
startPos = i;
nestingCount++;
} else if (char === openingChar) {
// Increment nesting for matching opening character
nestingCount++;
} else if (char === closingChar) {
// Decrement nesting for matching closing character
nestingCount--;
if (!nestingCount) {
// Reached end of target element
return JSON5.parse(
untruncateJson(text.substring(startPos || 0, i + 1))
) as T;
}
}
}

if (startPos !== undefined && nestingCount > 0) {
try {
return JSON5.parse(untruncateJson(text.substring(startPos))) as T;
} catch (e) {
return null;
}
}
return {} as T;
}
5 changes: 2 additions & 3 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,9 @@ import {
} from '@genkit-ai/core';
import { lookupAction } from '@genkit-ai/core/registry';
import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema';
import untruncateJson from 'untruncate-json';
import { z } from 'zod';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
import { extractAndUntruncateJson, extractJson } from './extract.js';
import {
CandidateData,
GenerateRequest,
Expand Down Expand Up @@ -415,7 +414,7 @@ export class GenerateResponseChunk<T = unknown>
const accumulatedText = this.accumulatedChunks
.map((chunk) => chunk.content.map((part) => part.text || '').join(''))
.join('');
return extractJson(untruncateJson(accumulatedText));
return extractAndUntruncateJson(accumulatedText);
}

toJSON(): GenerateResponseChunkData {
Expand Down
60 changes: 60 additions & 0 deletions js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import assert from 'node:assert';
import { describe, it } from 'node:test';
import { z } from 'zod';
import { GenerateResponseChunk } from '../../src/generate';
import {
Candidate,
GenerateOptions,
GenerateResponse,
Message,
toGenerateRequest,
} from '../../src/generate.js';
import { GenerateResponseChunkData } from '../../src/model';
import {
CandidateData,
GenerateRequest,
Expand Down Expand Up @@ -506,3 +508,61 @@ describe('toGenerateRequest', () => {
});
}
});

describe('GenerateResponseChunk', () => {
describe('#output()', () => {
const testCases = [
{
should: 'parse ``` correctly',
accumulatedChunksTexts: ['```'],
correctJson: {},
},
{
should: 'parse valid json correctly',
accumulatedChunksTexts: [`{"foo":"bar"}`],
correctJson: { foo: 'bar' },
},
{
should: 'handle missing closing brace',
accumulatedChunksTexts: [`{"foo":"bar"`],
correctJson: { foo: 'bar' },
},
{
should: 'handle missing closing bracket in nested object',
accumulatedChunksTexts: [`{"foo": {"bar": "baz"`],
correctJson: { foo: { bar: 'baz' } },
},
{
should: 'handle multiple chunks',
accumulatedChunksTexts: [`{"foo": {"bar"`, `: "baz`],
correctJson: { foo: { bar: 'baz' } },
},
{
should: 'handle multiple chunks with nested objects',
accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: {"baz": "qux`],
correctJson: { foo: { bar: { baz: 'qux' } } },
},
];

for (const test of testCases) {
if (test.should) {
it(test.should, () => {
const accumulatedChunks: GenerateResponseChunkData[] =
test.accumulatedChunksTexts.map((text, index) => ({
index,
content: [{ text }],
}));

const chunkData = accumulatedChunks[accumulatedChunks.length - 1];

const responseChunk: GenerateResponseChunk =
new GenerateResponseChunk(chunkData, accumulatedChunks);

const output = responseChunk.output();

assert.deepStrictEqual(output, test.correctJson);
});
}
}
});
});

0 comments on commit 9161a80

Please sign in to comment.