From 76ddebd309fd57df61d9c24ab6c9dcedb789518b Mon Sep 17 00:00:00 2001 From: Jacob Cable Date: Mon, 16 Sep 2024 16:29:06 +0100 Subject: [PATCH] Revert "feat(js/plugins/ollama): Ollama embeddings (#807)" This reverts commit fb1c2845226fc4b3eb4b6bd528303cd3085cabbc. --- docs/errors/no_new_actions_at_runtime.md | 22 -- genkit-tools/cli/src/commands/eval-flow.ts | 2 +- genkit-tools/common/src/eval/index.ts | 7 +- .../common/src/eval/localFileDatasetStore.ts | 222 ------------ .../common/src/eval/localFileEvalStore.ts | 24 +- genkit-tools/common/src/server/router.ts | 53 +-- genkit-tools/common/src/types/apis.ts | 22 +- genkit-tools/common/src/types/eval.ts | 74 +--- genkit-tools/common/src/types/model.ts | 2 +- .../common/tests/eval/exporter_test.ts | 2 +- .../tests/eval/localFileDatasetStore_test.ts | 290 ---------------- .../tests/eval/localFileEvalStore_test.ts | 15 +- js/ai/src/generate.ts | 210 +++++++++--- js/ai/src/generateAction.ts | 301 ---------------- js/ai/src/model.ts | 5 +- js/ai/tests/generate/generate_test.ts | 27 +- js/core/src/action.ts | 34 +- js/core/src/plugin.ts | 8 +- js/core/src/registry.ts | 323 +++++++----------- js/core/tests/registry_test.ts | 210 +----------- js/flow/src/flow.ts | 2 - js/flow/src/utils.ts | 3 +- js/plugins/dotprompt/src/template.ts | 14 +- js/plugins/dotprompt/tests/prompt_test.ts | 19 -- js/plugins/google-cloud/tests/logs_test.ts | 6 +- js/plugins/ollama/package.json | 7 +- js/plugins/ollama/src/embeddings.ts | 106 ------ js/plugins/ollama/src/index.ts | 13 - .../ollama/tests/embeddings_live_test.ts | 59 ---- js/plugins/ollama/tests/embeddings_test.ts | 135 -------- .../vertexai/src/openai_compatibility.ts | 2 +- js/pnpm-lock.yaml | 17 +- js/testapps/flow-simple-ai/src/index.ts | 39 +-- 33 files changed, 351 insertions(+), 1924 deletions(-) delete mode 100644 docs/errors/no_new_actions_at_runtime.md delete mode 100644 genkit-tools/common/src/eval/localFileDatasetStore.ts delete mode 100644 genkit-tools/common/tests/eval/localFileDatasetStore_test.ts delete mode 100644 js/ai/src/generateAction.ts delete mode 100644 js/plugins/ollama/src/embeddings.ts delete mode 100644 js/plugins/ollama/tests/embeddings_live_test.ts delete mode 100644 js/plugins/ollama/tests/embeddings_test.ts diff --git a/docs/errors/no_new_actions_at_runtime.md b/docs/errors/no_new_actions_at_runtime.md deleted file mode 100644 index cfed39eac..000000000 --- a/docs/errors/no_new_actions_at_runtime.md +++ /dev/null @@ -1,22 +0,0 @@ -# No new actions at runtime error - -Defining new actions at runtime is not allowed. - -✅ DO: - -```ts -const prompt = defineDotprompt({...}) - -const flow = defineFlow({...}, async (input) => { - await prompt.generate(...); -}) -``` - -❌ DON'T: - -```ts -const flow = defineFlow({...}, async (input) => { - const prompt = defineDotprompt({...}) - prompt.generate(...); -}) -``` diff --git a/genkit-tools/cli/src/commands/eval-flow.ts b/genkit-tools/cli/src/commands/eval-flow.ts index b56eeaac2..3245fca41 100644 --- a/genkit-tools/cli/src/commands/eval-flow.ts +++ b/genkit-tools/cli/src/commands/eval-flow.ts @@ -180,7 +180,7 @@ export const evalFlow = new Command('eval:flow') const evalRun = { key: { - actionRef: `/flow/${flowName}`, + actionId: flowName, evalRunId, createdAt: new Date().toISOString(), }, diff --git a/genkit-tools/common/src/eval/index.ts b/genkit-tools/common/src/eval/index.ts index 51b157ae3..28f2374b9 100644 --- a/genkit-tools/common/src/eval/index.ts +++ b/genkit-tools/common/src/eval/index.ts @@ -14,8 +14,7 @@ * limitations under the License. */ -import { DatasetStore, EvalStore } from '../types/eval'; -import { LocalFileDatasetStore } from './localFileDatasetStore'; +import { EvalStore } from '../types/eval'; import { LocalFileEvalStore } from './localFileEvalStore'; export { EvalFlowInput, EvalFlowInputSchema } from '../types/eval'; export * from './exporter'; @@ -25,7 +24,3 @@ export function getEvalStore(): EvalStore { // TODO: This should provide EvalStore, based on tools config. return LocalFileEvalStore.getEvalStore(); } - -export function getDatasetStore(): DatasetStore { - return LocalFileDatasetStore.getDatasetStore(); -} diff --git a/genkit-tools/common/src/eval/localFileDatasetStore.ts b/genkit-tools/common/src/eval/localFileDatasetStore.ts deleted file mode 100644 index c5a1948f2..000000000 --- a/genkit-tools/common/src/eval/localFileDatasetStore.ts +++ /dev/null @@ -1,222 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import crypto from 'crypto'; -import fs from 'fs'; -import { readFile, rm, writeFile } from 'fs/promises'; -import path from 'path'; -import { v4 as uuidv4 } from 'uuid'; -import { CreateDatasetRequest, UpdateDatasetRequest } from '../types/apis'; -import { - Dataset, - DatasetMetadata, - DatasetStore, - EvalFlowInputSchema, -} from '../types/eval'; -import { logger } from '../utils/logger'; - -/** - * A local, file-based DatasetStore implementation. - */ -export class LocalFileDatasetStore implements DatasetStore { - private readonly storeRoot; - private readonly indexFile; - private readonly INDEX_DELIMITER = '\n'; - private static cachedDatasetStore: LocalFileDatasetStore | null = null; - - private constructor(storeRoot: string) { - this.storeRoot = storeRoot; - this.indexFile = this.getIndexFilePath(); - fs.mkdirSync(this.storeRoot, { recursive: true }); - if (!fs.existsSync(this.indexFile)) { - fs.writeFileSync(path.resolve(this.indexFile), ''); - } - logger.info( - `Initialized local file dataset store at root: ${this.storeRoot}` - ); - } - - static getDatasetStore() { - if (!this.cachedDatasetStore) { - this.cachedDatasetStore = new LocalFileDatasetStore( - this.generateRootPath() - ); - } - return this.cachedDatasetStore; - } - - static reset() { - this.cachedDatasetStore = null; - } - - async createDataset(req: CreateDatasetRequest): Promise { - return this.createDatasetInternal(req.data, req.displayName); - } - - private async createDatasetInternal( - data: Dataset, - displayName?: string - ): Promise { - const datasetId = this.generateDatasetId(); - const filePath = path.resolve( - this.storeRoot, - this.generateFileName(datasetId) - ); - - if (fs.existsSync(filePath)) { - logger.error(`Dataset already exists at ` + filePath); - throw new Error( - `Create dataset failed: file already exists at {$filePath}` - ); - } - - logger.info(`Saving Dataset to ` + filePath); - await writeFile(filePath, JSON.stringify(data)); - - const now = new Date().toString(); - const metadata = { - datasetId, - size: Array.isArray(data) ? data.length : data.samples.length, - version: 1, - displayName: displayName, - createTime: now, - updateTime: now, - }; - - let metadataMap = await this.getMetadataMap(); - metadataMap[datasetId] = metadata; - - logger.debug( - `Saving DatasetMetadata for ID ${datasetId} to ` + - path.resolve(this.indexFile) - ); - - await writeFile(path.resolve(this.indexFile), JSON.stringify(metadataMap)); - return metadata; - } - - async updateDataset(req: UpdateDatasetRequest): Promise { - const datasetId = req.datasetId; - const filePath = path.resolve( - this.storeRoot, - this.generateFileName(datasetId) - ); - if (!fs.existsSync(filePath)) { - throw new Error(`Update dataset failed: dataset not found`); - } - - let metadataMap = await this.getMetadataMap(); - const prevMetadata = metadataMap[datasetId]; - if (!prevMetadata) { - throw new Error(`Update dataset failed: dataset metadata not found`); - } - - logger.info(`Updating Dataset at ` + filePath); - await writeFile(filePath, JSON.stringify(req.patch)); - - const now = new Date().toString(); - const newMetadata = { - datasetId: datasetId, - size: Array.isArray(req.patch) - ? req.patch.length - : req.patch.samples.length, - version: prevMetadata.version + 1, - displayName: req.displayName, - createTime: prevMetadata.createTime, - updateTime: now, - }; - - logger.debug( - `Updating DatasetMetadata for ID ${datasetId} at ` + - path.resolve(this.indexFile) - ); - // Replace the metadata object in the metadata map - metadataMap[datasetId] = newMetadata; - await writeFile(path.resolve(this.indexFile), JSON.stringify(metadataMap)); - - return newMetadata; - } - - async getDataset(datasetId: string): Promise { - const filePath = path.resolve( - this.storeRoot, - this.generateFileName(datasetId) - ); - if (!fs.existsSync(filePath)) { - throw new Error(`Dataset not found for dataset ID {$id}`); - } - return await readFile(filePath, 'utf8').then((data) => - EvalFlowInputSchema.parse(JSON.parse(data)) - ); - } - - async listDatasets(): Promise { - return this.getMetadataMap().then((metadataMap) => { - let metadatas = []; - - for (var key in metadataMap) { - metadatas.push(metadataMap[key]); - } - return metadatas; - }); - } - - async deleteDataset(datasetId: string): Promise { - const filePath = path.resolve( - this.storeRoot, - this.generateFileName(datasetId) - ); - await rm(filePath); - - let metadataMap = await this.getMetadataMap(); - delete metadataMap[datasetId]; - - logger.debug( - `Deleting DatasetMetadata for ID ${datasetId} in ` + - path.resolve(this.indexFile) - ); - await writeFile(path.resolve(this.indexFile), JSON.stringify(metadataMap)); - } - - private static generateRootPath(): string { - const rootHash = crypto - .createHash('md5') - .update(process.cwd() || 'unknown') - .digest('hex'); - return path.resolve(process.cwd(), `.genkit/${rootHash}/datasets`); - } - - private generateDatasetId(): string { - return uuidv4(); - } - - private generateFileName(datasetId: string): string { - return `${datasetId}.json`; - } - - private getIndexFilePath(): string { - return path.resolve(this.storeRoot, 'index.json'); - } - - private async getMetadataMap(): Promise { - if (!fs.existsSync(this.indexFile)) { - return Promise.resolve({} as any); - } - return await readFile(path.resolve(this.indexFile), 'utf8').then((data) => - JSON.parse(data) - ); - } -} diff --git a/genkit-tools/common/src/eval/localFileEvalStore.ts b/genkit-tools/common/src/eval/localFileEvalStore.ts index e9a1c44ea..9d1706631 100644 --- a/genkit-tools/common/src/eval/localFileEvalStore.ts +++ b/genkit-tools/common/src/eval/localFileEvalStore.ts @@ -61,7 +61,10 @@ export class LocalFileEvalStore implements EvalStore { } async save(evalRun: EvalRun): Promise { - const fileName = this.generateFileName(evalRun.key.evalRunId); + const fileName = this.generateFileName( + evalRun.key.evalRunId, + evalRun.key.actionId + ); logger.info( `Saving EvalRun ${evalRun.key.evalRunId} to ` + @@ -82,10 +85,13 @@ export class LocalFileEvalStore implements EvalStore { ); } - async load(evalRunId: string): Promise { + async load( + evalRunId: string, + actionId?: string + ): Promise { const filePath = path.resolve( this.storeRoot, - this.generateFileName(evalRunId) + this.generateFileName(evalRunId, actionId) ); if (!fs.existsSync(filePath)) { return undefined; @@ -111,8 +117,8 @@ export class LocalFileEvalStore implements EvalStore { logger.debug(`Found keys: ${JSON.stringify(keys)}`); - if (query?.filter?.actionRef) { - keys = keys.filter((key) => key.actionRef === query?.filter?.actionRef); + if (query?.filter?.actionId) { + keys = keys.filter((key) => key.actionId === query?.filter?.actionId); logger.debug(`Filtered keys: ${JSON.stringify(keys)}`); } @@ -121,8 +127,12 @@ export class LocalFileEvalStore implements EvalStore { }; } - private generateFileName(evalRunId: string): string { - return `${evalRunId}.json`; + private generateFileName(evalRunId: string, actionId?: string): string { + if (!actionId) { + return `${evalRunId}.json`; + } + + return `${actionId?.replace('/', '_')}-${evalRunId}.json`; } private getIndexFilePath(): string { diff --git a/genkit-tools/common/src/server/router.ts b/genkit-tools/common/src/server/router.ts index 853aa35d1..5f1859db8 100644 --- a/genkit-tools/common/src/server/router.ts +++ b/genkit-tools/common/src/server/router.ts @@ -14,8 +14,7 @@ * limitations under the License. */ import { initTRPC, TRPCError } from '@trpc/server'; -import { z } from 'zod'; -import { getDatasetStore, getEvalStore } from '../eval'; +import { getEvalStore } from '../eval'; import { Runner } from '../runner/runner'; import { GenkitToolsError } from '../runner/types'; import { Action } from '../types/action'; @@ -191,8 +190,9 @@ export const TOOLS_SERVER_ROUTER = (runner: Runner) => .output(evals.EvalRunSchema) .query(async ({ input }) => { const parts = input.name.split('/'); - const evalRunId = parts[1]; - const evalRun = await getEvalStore().load(evalRunId); + const evalRunId = parts[3]; + const actionId = parts[1] !== '-' ? parts[1] : undefined; + const evalRun = await getEvalStore().load(evalRunId, actionId); if (!evalRun) { throw new TRPCError({ code: 'NOT_FOUND', @@ -202,51 +202,6 @@ export const TOOLS_SERVER_ROUTER = (runner: Runner) => return evalRun; }), - /** Retrieves all eval datasets */ - listDatasets: loggedProcedure - .input(z.void()) - .output(z.array(evals.DatasetMetadataSchema)) - .query(async () => { - const response = await getDatasetStore().listDatasets(); - return response; - }), - - /** Retrieves an existing dataset */ - getDataset: loggedProcedure - .input(z.string()) - .output(evals.EvalFlowInputSchema) - .query(async ({ input }) => { - const response = await getDatasetStore().getDataset(input); - return response; - }), - - /** Creates a new dataset */ - createDataset: loggedProcedure - .input(apis.CreateDatasetRequestSchema) - .output(evals.DatasetMetadataSchema) - .query(async ({ input }) => { - const response = await getDatasetStore().createDataset(input); - return response; - }), - - /** Updates an exsting dataset */ - updateDataset: loggedProcedure - .input(apis.UpdateDatasetRequestSchema) - .output(evals.DatasetMetadataSchema) - .query(async ({ input }) => { - const response = await getDatasetStore().updateDataset(input); - return response; - }), - - /** Deletes an exsting dataset */ - deleteDataset: loggedProcedure - .input(z.string()) - .output(z.void()) - .query(async ({ input }) => { - const response = await getDatasetStore().deleteDataset(input); - return response; - }), - /** Send a screen view analytics event */ sendPageView: t.procedure .input(apis.PageViewSchema) diff --git a/genkit-tools/common/src/types/apis.ts b/genkit-tools/common/src/types/apis.ts index 24c40e341..6ed3e4961 100644 --- a/genkit-tools/common/src/types/apis.ts +++ b/genkit-tools/common/src/types/apis.ts @@ -15,7 +15,7 @@ */ import { z } from 'zod'; -import { EvalFlowInputSchema, EvalRunKeySchema } from './eval'; +import { EvalRunKeySchema } from './eval'; import { FlowStateSchema } from './flow'; import { GenerationCommonConfigSchema, @@ -113,7 +113,7 @@ export type PageView = z.infer; export const ListEvalKeysRequestSchema = z.object({ filter: z .object({ - actionRef: z.string().optional(), + actionId: z.string().optional(), }) .optional(), }); @@ -127,22 +127,8 @@ export const ListEvalKeysResponseSchema = z.object({ export type ListEvalKeysResponse = z.infer; export const GetEvalRunRequestSchema = z.object({ - // Eval run name in the form evalRuns/{evalRunId} + // Eval run name in the form actions/{action}/evalRun/{evalRun} + // where `action` can be blank e.g. actions/-/evalRun/{evalRun} name: z.string(), }); export type GetEvalRunRequest = z.infer; - -export const CreateDatasetRequestSchema = z.object({ - data: EvalFlowInputSchema, - displayName: z.string().optional(), -}); - -export type CreateDatasetRequest = z.infer; - -export const UpdateDatasetRequestSchema = z.object({ - /** Supports upsert */ - patch: EvalFlowInputSchema, - datasetId: z.string(), - displayName: z.string().optional(), -}); -export type UpdateDatasetRequest = z.infer; diff --git a/genkit-tools/common/src/types/eval.ts b/genkit-tools/common/src/types/eval.ts index 303a24b88..75556d048 100644 --- a/genkit-tools/common/src/types/eval.ts +++ b/genkit-tools/common/src/types/eval.ts @@ -15,12 +15,7 @@ */ import { z } from 'zod'; -import { - CreateDatasetRequest, - ListEvalKeysRequest, - ListEvalKeysResponse, - UpdateDatasetRequest, -} from './apis'; +import { ListEvalKeysRequest, ListEvalKeysResponse } from './apis'; /** * This file defines schema and types that are used by the Eval store. @@ -53,12 +48,6 @@ export const EvalFlowInputSchema = z.union([ ]); export type EvalFlowInput = z.infer; -/** - * Alias for EvalFlowInput to be used in the DatasetStore related APIs. - * We may deprecate EvalFlowInput in favor of this in the future. - */ -export type Dataset = z.infer; - /** * A record that is ready for evaluation. * @@ -99,8 +88,7 @@ export type EvalResult = z.infer; * A unique identifier for an Evaluation Run. */ export const EvalRunKeySchema = z.object({ - actionRef: z.string().optional(), - datasetId: z.string().optional(), + actionId: z.string().optional(), evalRunId: z.string(), createdAt: z.string(), }); @@ -125,7 +113,7 @@ export const EvalRunSchema = z.object({ export type EvalRun = z.infer; /** - * Eval store persistence interface. + * Eval dataset store persistence interface. */ export interface EvalStore { /** @@ -137,8 +125,9 @@ export interface EvalStore { /** * Load a single EvalRun from storage * @param evalRunId the ID of the EvalRun + * @param actionId (optional) the ID of the action used to generate output. */ - load(evalRunId: string): Promise; + load(evalRunId: string, actionId?: string): Promise; /** * List the keys of all EvalRuns from storage @@ -146,56 +135,3 @@ export interface EvalStore { */ list(query?: ListEvalKeysRequest): Promise; } - -/** - * Metadata for Dataset objects containing version, create and update time, etc. - */ -export const DatasetMetadataSchema = z.object({ - /** autogenerated */ - datasetId: z.string(), - size: z.number(), - /** 1 for v1, 2 for v2, etc */ - version: z.number(), - displayName: z.string().optional(), - createTime: z.string(), - updateTime: z.string(), -}); -export type DatasetMetadata = z.infer; - -/** - * Eval dataset store persistence interface. - */ -export interface DatasetStore { - /** - * Create new dataset with the given data - * @param req create requeest with the data - * @returns dataset metadata - */ - createDataset(req: CreateDatasetRequest): Promise; - - /** - * Update dataset - * @param req update requeest with new data - * @returns updated dataset metadata - */ - updateDataset(req: UpdateDatasetRequest): Promise; - - /** - * Get existing dataset - * @param datasetId the ID of the dataset - * @returns dataset ready for inference - */ - getDataset(datasetId: string): Promise; - - /** - * List all existing datasets - * @returns array of dataset metadata objects - */ - listDatasets(): Promise; - - /** - * Delete existing dataset - * @param datasetId the ID of the dataset - */ - deleteDataset(datasetId: string): Promise; -} diff --git a/genkit-tools/common/src/types/model.ts b/genkit-tools/common/src/types/model.ts index efc0f2db4..3bfc2f444 100644 --- a/genkit-tools/common/src/types/model.ts +++ b/genkit-tools/common/src/types/model.ts @@ -14,7 +14,7 @@ * limitations under the License. */ import { z } from 'zod'; -import { DocumentDataSchema } from './document'; +import { DocumentDataSchema } from './document.js'; // // IMPORTANT: Keep this file in sync with genkit/ai/src/model.ts! diff --git a/genkit-tools/common/tests/eval/exporter_test.ts b/genkit-tools/common/tests/eval/exporter_test.ts index fc81efbff..8955661f1 100644 --- a/genkit-tools/common/tests/eval/exporter_test.ts +++ b/genkit-tools/common/tests/eval/exporter_test.ts @@ -80,7 +80,7 @@ const EVAL_RESULTS: EvalResult[] = [ ]; const EVAL_RUN_KEY: EvalRunKey = { - actionRef: 'flow/myAwesomeFlow', + actionId: 'flow/myAwesomeFlow', evalRunId: 'abc1234', createdAt: new Date().toISOString(), }; diff --git a/genkit-tools/common/tests/eval/localFileDatasetStore_test.ts b/genkit-tools/common/tests/eval/localFileDatasetStore_test.ts deleted file mode 100644 index d3f1691fa..000000000 --- a/genkit-tools/common/tests/eval/localFileDatasetStore_test.ts +++ /dev/null @@ -1,290 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - afterEach, - beforeEach, - describe, - expect, - it, - jest, -} from '@jest/globals'; -import fs from 'fs'; -import { LocalFileDatasetStore } from '../../src/eval/localFileDatasetStore'; -import { - CreateDatasetRequestSchema, - UpdateDatasetRequestSchema, -} from '../../src/types/apis'; -import { DatasetStore } from '../../src/types/eval'; - -const FAKE_TIME = new Date('2024-02-03T12:05:33.243Z'); - -const SAMPLE_DATASET_1_V1 = { - samples: [ - { - input: 'Cats are evil', - reference: 'Sorry no reference', - }, - { - input: 'Dogs are beautiful', - }, - ], -}; - -const SAMPLE_DATASET_1_V2 = { - samples: [ - { - input: 'Cats are evil', - reference: 'Sorry no reference', - }, - { - input: 'Dogs are angels', - }, - { - input: 'Dogs are also super cute', - }, - ], -}; - -const SAMPLE_DATASET_ID_1 = '12345678'; -const SAMPLE_DATASET_NAME_1 = 'dataset-1'; - -const SAMPLE_DATASET_METADATA_1_V1 = { - datasetId: SAMPLE_DATASET_ID_1, - size: 2, - version: 1, - displayName: SAMPLE_DATASET_NAME_1, - createTime: FAKE_TIME.toString(), - updateTime: FAKE_TIME.toString(), -}; -const SAMPLE_DATASET_METADATA_1_V2 = { - datasetId: SAMPLE_DATASET_ID_1, - size: 3, - version: 2, - displayName: SAMPLE_DATASET_NAME_1, - createTime: FAKE_TIME.toString(), - updateTime: FAKE_TIME.toString(), -}; - -const CREATE_DATASET_REQUEST = CreateDatasetRequestSchema.parse({ - data: SAMPLE_DATASET_1_V1, - displayName: SAMPLE_DATASET_NAME_1, -}); - -const UPDATE_DATASET_REQUEST = UpdateDatasetRequestSchema.parse({ - patch: SAMPLE_DATASET_1_V2, - datasetId: SAMPLE_DATASET_ID_1, - displayName: SAMPLE_DATASET_NAME_1, -}); - -const SAMPLE_DATASET_ID_2 = '22345678'; -const SAMPLE_DATASET_NAME_2 = 'dataset-2'; - -const SAMPLE_DATASET_METADATA_2 = { - datasetId: SAMPLE_DATASET_ID_2, - size: 5, - version: 1, - displayName: SAMPLE_DATASET_NAME_2, - createTime: FAKE_TIME.toString(), - updateTime: FAKE_TIME.toString(), -}; - -jest.mock('crypto', () => { - return { - createHash: jest.fn().mockReturnThis(), - update: jest.fn().mockReturnThis(), - digest: jest.fn(() => 'store-root'), - }; -}); - -jest.mock('uuid', () => ({ - v4: () => SAMPLE_DATASET_ID_1, -})); - -jest.useFakeTimers({ advanceTimers: true }); -jest.setSystemTime(FAKE_TIME); - -describe('localFileDatasetStore', () => { - let DatasetStore: DatasetStore; - - beforeEach(() => { - LocalFileDatasetStore.reset(); - DatasetStore = LocalFileDatasetStore.getDatasetStore() as DatasetStore; - }); - - afterEach(() => { - jest.restoreAllMocks(); - }); - - describe('createDataset', () => { - it('writes and updates index for new dataset', async () => { - fs.promises.writeFile = jest.fn(async () => Promise.resolve(undefined)); - fs.promises.appendFile = jest.fn(async () => Promise.resolve(undefined)); - // For index file reads - fs.promises.readFile = jest.fn(async () => - Promise.resolve(JSON.stringify({}) as any) - ); - - const datasetMetadata = await DatasetStore.createDataset( - CREATE_DATASET_REQUEST - ); - - expect(fs.promises.writeFile).toHaveBeenCalledTimes(2); - expect(fs.promises.writeFile).toHaveBeenNthCalledWith( - 1, - expect.stringContaining('datasets/12345678.json'), - JSON.stringify(CREATE_DATASET_REQUEST.data) - ); - const metadataMap = { - [SAMPLE_DATASET_ID_1]: SAMPLE_DATASET_METADATA_1_V1, - }; - expect(fs.promises.writeFile).toHaveBeenNthCalledWith( - 2, - expect.stringContaining('datasets/index.json'), - JSON.stringify(metadataMap) - ); - expect(datasetMetadata).toMatchObject(SAMPLE_DATASET_METADATA_1_V1); - }); - - it('fails request if dataset already exists', async () => { - fs.existsSync = jest.fn(() => true); - - expect(async () => { - await DatasetStore.createDataset(CREATE_DATASET_REQUEST); - }).rejects.toThrow(); - - expect(fs.promises.writeFile).toBeCalledTimes(0); - }); - }); - - describe('updateDataset', () => { - it('succeeds for existing dataset', async () => { - fs.existsSync = jest.fn(() => true); - let metadataMap = { - [SAMPLE_DATASET_ID_1]: SAMPLE_DATASET_METADATA_1_V1, - [SAMPLE_DATASET_ID_2]: SAMPLE_DATASET_METADATA_2, - }; - // For index file reads - fs.promises.readFile = jest.fn(async () => - Promise.resolve(JSON.stringify(metadataMap) as any) - ); - fs.promises.writeFile = jest.fn(async () => Promise.resolve(undefined)); - fs.promises.appendFile = jest.fn(async () => Promise.resolve(undefined)); - - const datasetMetadata = await DatasetStore.updateDataset( - UPDATE_DATASET_REQUEST - ); - - expect(fs.promises.writeFile).toHaveBeenCalledTimes(2); - expect(fs.promises.writeFile).toHaveBeenNthCalledWith( - 1, - expect.stringContaining('datasets/12345678.json'), - JSON.stringify(SAMPLE_DATASET_1_V2) - ); - const updatedMetadataMap = { - [SAMPLE_DATASET_ID_1]: SAMPLE_DATASET_METADATA_1_V2, - [SAMPLE_DATASET_ID_2]: SAMPLE_DATASET_METADATA_2, - }; - expect(fs.promises.writeFile).toHaveBeenNthCalledWith( - 2, - expect.stringContaining('datasets/index.json'), - JSON.stringify(updatedMetadataMap) - ); - expect(datasetMetadata).toMatchObject(SAMPLE_DATASET_METADATA_1_V2); - }); - - it('fails for non existing dataset', async () => { - fs.existsSync = jest.fn(() => false); - - expect(async () => { - await DatasetStore.updateDataset(UPDATE_DATASET_REQUEST); - }).rejects.toThrow(); - - expect(fs.promises.writeFile).toBeCalledTimes(0); - }); - }); - - describe('listDatasets', () => { - it('succeeds for zero datasets', async () => { - fs.existsSync = jest.fn(() => false); - - const metadatas = await DatasetStore.listDatasets(); - - expect(metadatas).toMatchObject([]); - }); - - it('succeeds for existing datasets', async () => { - fs.existsSync = jest.fn(() => true); - const metadataMap = { - [SAMPLE_DATASET_ID_1]: SAMPLE_DATASET_METADATA_1_V1, - [SAMPLE_DATASET_ID_2]: SAMPLE_DATASET_METADATA_2, - }; - fs.promises.readFile = jest.fn(async () => - Promise.resolve(JSON.stringify(metadataMap) as any) - ); - - const metadatas = await DatasetStore.listDatasets(); - - expect(metadatas).toMatchObject([ - SAMPLE_DATASET_METADATA_1_V1, - SAMPLE_DATASET_METADATA_2, - ]); - }); - }); - - describe('getDataset', () => { - it('succeeds for existing dataset', async () => { - fs.existsSync = jest.fn(() => true); - fs.promises.readFile = jest.fn(async () => - Promise.resolve(JSON.stringify(SAMPLE_DATASET_1_V1) as any) - ); - - const fetchedDataset = await DatasetStore.getDataset(SAMPLE_DATASET_ID_1); - - expect(fetchedDataset).toMatchObject(SAMPLE_DATASET_1_V1); - }); - - it('fails for non existing dataset', async () => { - // TODO: Implement this. - }); - }); - - describe('deleteDataset', () => { - it('deletes dataset and updates index', async () => { - fs.promises.rm = jest.fn(async () => Promise.resolve()); - let metadataMap = { - [SAMPLE_DATASET_ID_1]: SAMPLE_DATASET_METADATA_1_V1, - [SAMPLE_DATASET_ID_2]: SAMPLE_DATASET_METADATA_2, - }; - fs.promises.readFile = jest.fn(async () => - Promise.resolve(JSON.stringify(metadataMap) as any) - ); - - await DatasetStore.deleteDataset(SAMPLE_DATASET_ID_1); - - expect(fs.promises.rm).toHaveBeenCalledWith( - expect.stringContaining('datasets/12345678.json') - ); - let updatedMetadataMap = { - [SAMPLE_DATASET_ID_2]: SAMPLE_DATASET_METADATA_2, - }; - expect(fs.promises.writeFile).toHaveBeenCalledWith( - expect.stringContaining('datasets/index.json'), - JSON.stringify(updatedMetadataMap) - ); - }); - }); -}); diff --git a/genkit-tools/common/tests/eval/localFileEvalStore_test.ts b/genkit-tools/common/tests/eval/localFileEvalStore_test.ts index 72fb8758e..6885c84c3 100644 --- a/genkit-tools/common/tests/eval/localFileEvalStore_test.ts +++ b/genkit-tools/common/tests/eval/localFileEvalStore_test.ts @@ -86,7 +86,7 @@ const METRICS_METADATA = { const EVAL_RUN_WITH_ACTION = EvalRunSchema.parse({ key: { - actionRef: 'flow/tellMeAJoke', + actionId: 'flow/tellMeAJoke', evalRunId: 'abc1234', createdAt: new Date().toISOString(), }, @@ -125,7 +125,7 @@ describe('localFileEvalStore', () => { await evalStore.save(EVAL_RUN_WITH_ACTION); expect(fs.promises.writeFile).toHaveBeenCalledWith( - `/tmp/.genkit/store-root/evals/abc1234.json`, + `/tmp/.genkit/store-root/evals/flow_tellMeAJoke-abc1234.json`, JSON.stringify(EVAL_RUN_WITH_ACTION) ); expect(fs.promises.appendFile).toHaveBeenCalledWith( @@ -155,7 +155,8 @@ describe('localFileEvalStore', () => { Promise.resolve(JSON.stringify(EVAL_RUN_WITH_ACTION) as any) ); const fetchedEvalRun = await evalStore.load( - EVAL_RUN_WITH_ACTION.key.evalRunId + EVAL_RUN_WITH_ACTION.key.evalRunId, + EVAL_RUN_WITH_ACTION.key.actionId ); expect(fetchedEvalRun).toMatchObject(EVAL_RUN_WITH_ACTION); }); @@ -166,7 +167,8 @@ describe('localFileEvalStore', () => { Promise.resolve(JSON.stringify(EVAL_RUN_WITHOUT_ACTION) as any) ); const fetchedEvalRun = await evalStore.load( - EVAL_RUN_WITHOUT_ACTION.key.evalRunId + EVAL_RUN_WITHOUT_ACTION.key.evalRunId, + EVAL_RUN_WITHOUT_ACTION.key.actionId ); expect(fetchedEvalRun).toMatchObject(EVAL_RUN_WITHOUT_ACTION); }); @@ -175,7 +177,8 @@ describe('localFileEvalStore', () => { fs.existsSync = jest.fn(() => false); const fetchedEvalRun = await evalStore.load( - EVAL_RUN_WITH_ACTION.key.evalRunId + EVAL_RUN_WITH_ACTION.key.evalRunId, + EVAL_RUN_WITH_ACTION.key.actionId ); expect(fetchedEvalRun).toBeUndefined(); }); @@ -205,7 +208,7 @@ describe('localFileEvalStore', () => { ); const fetchedEvalKeys = await evalStore.list({ - filter: { actionRef: EVAL_RUN_WITH_ACTION.key.actionRef }, + filter: { actionId: EVAL_RUN_WITH_ACTION.key.actionId }, }); const expectedKeys = { evalRunKeys: [EVAL_RUN_WITH_ACTION.key] }; diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 8d412942e..ebda1c322 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -22,15 +22,14 @@ import { StreamingCallback, } from '@genkit-ai/core'; import { lookupAction } from '@genkit-ai/core/registry'; -import { toJsonSchema, validateSchema } from '@genkit-ai/core/schema'; +import { + parseSchema, + toJsonSchema, + validateSchema, +} from '@genkit-ai/core/schema'; import { z } from 'zod'; import { DocumentData } from './document.js'; import { extractJson } from './extract.js'; -import { - generateAction, - GenerateUtilParamSchema, - inferRoleFromParts, -} from './generateAction.js'; import { CandidateData, GenerateRequest, @@ -43,11 +42,16 @@ import { ModelArgument, ModelReference, Part, - ToolDefinition, + Role, ToolRequestPart, ToolResponsePart, } from './model.js'; -import { resolveTools, ToolArgument, toToolDefinition } from './tool.js'; +import { + resolveTools, + ToolAction, + ToolArgument, + toToolDefinition, +} from './tool.js'; /** * Message represents a single role's contribution to a generation. Each message @@ -421,6 +425,27 @@ export class GenerateResponseChunk } } +function getRoleFromPart(part: Part): Role { + if (part.toolRequest !== undefined) return 'model'; + if (part.toolResponse !== undefined) return 'tool'; + if (part.text !== undefined) return 'user'; + if (part.media !== undefined) return 'user'; + if (part.data !== undefined) return 'user'; + throw new Error('No recognized fields in content'); +} + +function inferRoleFromParts(parts: Part[]): Role { + const uniqueRoles = new Set(); + for (const part of parts) { + const role = getRoleFromPart(part); + uniqueRoles.add(role); + if (uniqueRoles.size > 1) { + throw new Error('Contents contain mixed roles'); + } + } + return Array.from(uniqueRoles)[0]; +} + export async function toGenerateRequest( options: GenerateOptions ): Promise { @@ -492,6 +517,29 @@ export interface GenerateOptions< streamingCallback?: StreamingCallback; } +const isValidCandidate = ( + candidate: CandidateData, + tools: Action[] +): boolean => { + // Check if tool calls are vlaid + const toolCalls = candidate.message.content.filter( + (part) => !!part.toolRequest + ); + + // make sure every tool called exists and has valid input + return toolCalls.every((toolCall) => { + const tool = tools?.find( + (tool) => tool.__action.name === toolCall.toolRequest?.name + ); + if (!tool) return false; + const { valid } = validateSchema(toolCall.toolRequest?.input, { + schema: tool.__action.inputSchema, + jsonSchema: tool.__action.inputJsonSchema, + }); + return valid; + }); +}; + async function resolveModel(options: GenerateOptions): Promise { let model = options.model; if (!model) { @@ -556,6 +604,7 @@ export class NoValidCandidatesError extends GenkitError { * @param options The options for this generation request. * @returns The generated response based on the provided parameters. */ + export async function generate< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, @@ -571,51 +620,120 @@ export async function generate< throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`); } - // convert tools to action refs (strings). - let tools: (string | ToolDefinition)[] | undefined; - if (resolvedOptions.tools) { - tools = resolvedOptions.tools.map((t) => { - if (typeof t === 'string') { - return `/tool/${t}`; - } else if ((t as Action).__action) { - return `/${(t as Action).__action.metadata?.type}/${(t as Action).__action.name}`; - } else if (t.name) { - return `/tool/${t.name}`; - } + let tools: ToolAction[] | undefined; + if (resolvedOptions.tools?.length) { + if (!model.__action.metadata?.model.supports?.tools) { throw new Error( - `Unable to determine type of of tool: ${JSON.stringify(t)}` + `Model ${JSON.stringify(resolvedOptions.model)} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` ); + } + tools = await resolveTools(resolvedOptions.tools); + } + + const request = await toGenerateRequest(resolvedOptions); + + const accumulatedChunks: GenerateResponseChunkData[] = []; + + const response = await runWithStreamingCallback( + resolvedOptions.streamingCallback + ? (chunk: GenerateResponseChunkData) => { + // Store accumulated chunk data + accumulatedChunks.push(chunk); + if (resolvedOptions.streamingCallback) { + resolvedOptions.streamingCallback!( + new GenerateResponseChunk(chunk, accumulatedChunks) + ); + } + } + : undefined, + async () => new GenerateResponse>(await model(request), request) + ); + + // throw NoValidCandidates if all candidates are blocked or + if ( + !response.candidates.some((c) => + ['stop', 'length'].includes(c.finishReason) + ) + ) { + throw new NoValidCandidatesError({ + message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, + response, }); } - const params: z.infer = { - model: model.__action.name, - prompt: resolvedOptions.prompt, - context: resolvedOptions.context, - history: resolvedOptions.history, - tools, - candidates: resolvedOptions.candidates, - config: resolvedOptions.config, - output: resolvedOptions.output && { - format: resolvedOptions.output.format, - jsonSchema: resolvedOptions.output.schema - ? toJsonSchema({ - schema: resolvedOptions.output.schema, - jsonSchema: resolvedOptions.output.jsonSchema, - }) - : resolvedOptions.output.jsonSchema, - }, - returnToolRequests: resolvedOptions.returnToolRequests, - }; + if (resolvedOptions.output?.schema || resolvedOptions.output?.jsonSchema) { + // find a candidate with valid output schema + const candidateErrors = response.candidates.map((c) => { + // don't validate messages that have no text or data + if (c.text() === '' && c.data() === null) return null; - return await runWithStreamingCallback( - resolvedOptions.streamingCallback, - async () => - new GenerateResponse( - await generateAction(params), - await toGenerateRequest(resolvedOptions) - ) + try { + parseSchema(c.output(), { + jsonSchema: resolvedOptions.output?.jsonSchema, + schema: resolvedOptions.output?.schema, + }); + return null; + } catch (e) { + return e as Error; + } + }); + // if all candidates have a non-null error... + if (candidateErrors.every((c) => !!c)) { + throw new NoValidCandidatesError({ + message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, + response, + detail: { + candidateErrors: candidateErrors, + }, + }); + } + } + + // Pick the first valid candidate. + let selected: Candidate> | undefined; + for (const candidate of response.candidates) { + if (isValidCandidate(candidate, tools || [])) { + selected = candidate; + break; + } + } + + if (!selected) { + throw new Error('No valid candidates found'); + } + + const toolCalls = selected.message.content.filter( + (part) => !!part.toolRequest + ); + if (resolvedOptions.returnToolRequests || toolCalls.length === 0) { + return response; + } + const toolResponses: ToolResponsePart[] = await Promise.all( + toolCalls.map(async (part) => { + if (!part.toolRequest) { + throw Error( + 'Tool request expected but not provided in tool request part' + ); + } + const tool = tools?.find( + (tool) => tool.__action.name === part.toolRequest?.name + ); + if (!tool) { + throw Error('Tool not found'); + } + return { + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: await tool(part.toolRequest?.input), + }, + }; + }) ); + resolvedOptions.history = request.messages; + resolvedOptions.history.push(selected.message); + resolvedOptions.prompt = toolResponses; + return await generate(resolvedOptions); } export type GenerateStreamOptions< diff --git a/js/ai/src/generateAction.ts b/js/ai/src/generateAction.ts deleted file mode 100644 index 2938d1ab2..000000000 --- a/js/ai/src/generateAction.ts +++ /dev/null @@ -1,301 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - Action, - defineAction, - getStreamingCallback, - runWithStreamingCallback, -} from '@genkit-ai/core'; -import { lookupAction } from '@genkit-ai/core/registry'; -import { - parseSchema, - toJsonSchema, - validateSchema, -} from '@genkit-ai/core/schema'; -import { z } from 'zod'; -import { DocumentDataSchema } from './document.js'; -import { - Candidate, - GenerateResponse, - GenerateResponseChunk, - NoValidCandidatesError, -} from './generate.js'; -import { - CandidateData, - GenerateRequest, - GenerateResponseChunkData, - GenerateResponseSchema, - MessageData, - MessageSchema, - ModelAction, - Part, - PartSchema, - Role, - ToolDefinitionSchema, - ToolResponsePart, -} from './model.js'; -import { ToolAction, toToolDefinition } from './tool.js'; - -export const GenerateUtilParamSchema = z.object({ - /** A model name (e.g. `vertexai/gemini-1.0-pro`). */ - model: z.string(), - /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ - prompt: z.union([z.string(), PartSchema, z.array(PartSchema)]), - /** Retrieved documents to be used as context for this generation. */ - context: z.array(DocumentDataSchema).optional(), - /** Conversation history for multi-turn prompting when supported by the underlying model. */ - history: z.array(MessageSchema).optional(), - /** List of registered tool names for this generation if supported by the underlying model. */ - tools: z.array(z.union([z.string(), ToolDefinitionSchema])).optional(), - /** Number of candidate messages to generate. */ - candidates: z.number().optional(), - /** Configuration for the generation request. */ - config: z.any().optional(), - /** Configuration for the desired output of the request. Defaults to the model's default output if unspecified. */ - output: z - .object({ - format: z - .union([z.literal('text'), z.literal('json'), z.literal('media')]) - .optional(), - jsonSchema: z.any().optional(), - }) - .optional(), - /** When true, return tool calls for manual processing instead of automatically resolving them. */ - returnToolRequests: z.boolean().optional(), -}); - -export const generateAction = defineAction( - { - actionType: 'util', - name: 'generate', - inputSchema: GenerateUtilParamSchema, - outputSchema: GenerateResponseSchema, - }, - async (input) => { - const model = (await lookupAction(`/model/${input.model}`)) as ModelAction; - if (!model) { - throw new Error(`Model ${input.model} not found`); - } - - let tools: ToolAction[] | undefined; - if (input.tools?.length) { - if (!model.__action.metadata?.model.supports?.tools) { - throw new Error( - `Model ${input.model} does not support tools, but some tools were supplied to generate(). Please call generate() without tools if you would like to use this model.` - ); - } - tools = await Promise.all( - input.tools.map(async (toolRef) => { - if (typeof toolRef === 'string') { - const tool = (await lookupAction(toolRef)) as ToolAction; - if (!tool) { - throw new Error(`Tool ${toolRef} not found`); - } - return tool; - } - throw ''; - }) - ); - } - - const request = await actionToGenerateRequest(input, tools); - - const accumulatedChunks: GenerateResponseChunkData[] = []; - - const streamingCallback = getStreamingCallback(); - const response = await runWithStreamingCallback( - streamingCallback - ? (chunk: GenerateResponseChunkData) => { - // Store accumulated chunk data - accumulatedChunks.push(chunk); - if (streamingCallback) { - streamingCallback!( - new GenerateResponseChunk(chunk, accumulatedChunks) - ); - } - } - : undefined, - async () => new GenerateResponse(await model(request)) - ); - - // throw NoValidCandidates if all candidates are blocked or - if ( - !response.candidates.some((c) => - ['stop', 'length'].includes(c.finishReason) - ) - ) { - throw new NoValidCandidatesError({ - message: `All candidates returned finishReason issues: ${JSON.stringify(response.candidates.map((c) => c.finishReason))}`, - response, - }); - } - - if (input.output?.jsonSchema && !response.toolRequests()?.length) { - // find a candidate with valid output schema - const candidateErrors = response.candidates.map((c) => { - // don't validate messages that have no text or data - if (c.text() === '' && c.data() === null) return null; - - try { - parseSchema(c.output(), { - jsonSchema: input.output?.jsonSchema, - }); - return null; - } catch (e) { - return e as Error; - } - }); - // if all candidates have a non-null error... - if (candidateErrors.every((c) => !!c)) { - throw new NoValidCandidatesError({ - message: `Generation resulted in no candidates matching provided output schema.${candidateErrors.map((e, i) => `\n\nCandidate[${i}] ${e!.toString()}`)}`, - response, - detail: { - candidateErrors: candidateErrors, - }, - }); - } - } - - // Pick the first valid candidate. - let selected: Candidate | undefined; - for (const candidate of response.candidates) { - if (isValidCandidate(candidate, tools || [])) { - selected = candidate; - break; - } - } - - if (!selected) { - throw new Error('No valid candidates found'); - } - - const toolCalls = selected.message.content.filter( - (part) => !!part.toolRequest - ); - if (input.returnToolRequests || toolCalls.length === 0) { - return response.toJSON(); - } - const toolResponses: ToolResponsePart[] = await Promise.all( - toolCalls.map(async (part) => { - if (!part.toolRequest) { - throw Error( - 'Tool request expected but not provided in tool request part' - ); - } - const tool = tools?.find( - (tool) => tool.__action.name === part.toolRequest?.name - ); - if (!tool) { - throw Error('Tool not found'); - } - return { - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: await tool(part.toolRequest?.input), - }, - }; - }) - ); - const nextRequest = { - ...input, - history: [...request.messages, selected.message], - prompt: toolResponses, - }; - return await generateAction(nextRequest); - } -); - -async function actionToGenerateRequest( - options: z.infer, - resolvedTools?: ToolAction[] -): Promise { - const promptMessage: MessageData = { role: 'user', content: [] }; - if (typeof options.prompt === 'string') { - promptMessage.content.push({ text: options.prompt }); - } else if (Array.isArray(options.prompt)) { - promptMessage.role = inferRoleFromParts(options.prompt); - promptMessage.content.push(...(options.prompt as Part[])); - } else { - promptMessage.role = inferRoleFromParts([options.prompt]); - promptMessage.content.push(options.prompt); - } - const messages: MessageData[] = [...(options.history || []), promptMessage]; - - const out = { - messages, - candidates: options.candidates, - config: options.config, - context: options.context, - tools: resolvedTools?.map((tool) => toToolDefinition(tool)) || [], - output: { - format: - options.output?.format || - (options.output?.jsonSchema ? 'json' : 'text'), - schema: toJsonSchema({ - jsonSchema: options.output?.jsonSchema, - }), - }, - }; - if (!out.output.schema) delete out.output.schema; - return out; -} - -const isValidCandidate = ( - candidate: CandidateData, - tools: Action[] -): boolean => { - // Check if tool calls are vlaid - const toolCalls = candidate.message.content.filter( - (part) => !!part.toolRequest - ); - - // make sure every tool called exists and has valid input - return toolCalls.every((toolCall) => { - const tool = tools?.find( - (tool) => tool.__action.name === toolCall.toolRequest?.name - ); - if (!tool) return false; - const { valid } = validateSchema(toolCall.toolRequest?.input, { - schema: tool.__action.inputSchema, - jsonSchema: tool.__action.inputJsonSchema, - }); - return valid; - }); -}; - -export function inferRoleFromParts(parts: Part[]): Role { - const uniqueRoles = new Set(); - for (const part of parts) { - const role = getRoleFromPart(part); - uniqueRoles.add(role); - if (uniqueRoles.size > 1) { - throw new Error('Contents contain mixed roles'); - } - } - return Array.from(uniqueRoles)[0]; -} - -function getRoleFromPart(part: Part): Role { - if (part.toolRequest !== undefined) return 'model'; - if (part.toolResponse !== undefined) return 'tool'; - if (part.text !== undefined) return 'user'; - if (part.media !== undefined) return 'user'; - if (part.data !== undefined) return 'user'; - throw new Error('No recognized fields in content'); -} diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 161040d68..d98bf9c67 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -163,12 +163,11 @@ export const ToolDefinitionSchema = z.object({ description: z.string(), inputSchema: z .record(z.any()) - .describe('Valid JSON Schema representing the input of the tool.') - .nullish(), + .describe('Valid JSON Schema representing the input of the tool.'), outputSchema: z .record(z.any()) .describe('Valid JSON Schema describing the output of the tool.') - .nullish(), + .optional(), }); export type ToolDefinition = z.infer; diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index 67173d6da..2388c401b 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -17,7 +17,7 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; import { z } from 'zod'; -import { GenerateResponseChunk, generate } from '../../src/generate'; +import { GenerateResponseChunk } from '../../src/generate'; import { Candidate, GenerateOptions, @@ -25,7 +25,7 @@ import { Message, toGenerateRequest, } from '../../src/generate.js'; -import { GenerateResponseChunkData, defineModel } from '../../src/model'; +import { GenerateResponseChunkData } from '../../src/model'; import { CandidateData, GenerateRequest, @@ -581,26 +581,3 @@ describe('GenerateResponseChunk', () => { } }); }); - -const echo = defineModel( - { name: 'echo', supports: { tools: true } }, - async (input) => ({ - candidates: [ - { index: 0, message: input.messages[0], finishReason: 'stop' }, - ], - }) -); - -describe('generate', () => { - it('should preserve the request in the returned response, enabling toHistory()', async () => { - const response = await generate({ - model: echo, - prompt: 'Testing toHistory', - }); - - assert.deepEqual( - response.toHistory().map((m) => m.content[0].text), - ['Testing toHistory', 'Testing toHistory'] - ); - }); -}); diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 559a36420..301988966 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -17,12 +17,7 @@ import { JSONSchema7 } from 'json-schema'; import { AsyncLocalStorage } from 'node:async_hooks'; import * as z from 'zod'; -import { - ActionType, - initializeAllPlugins, - lookupPlugin, - registerAction, -} from './registry.js'; +import { ActionType, lookupPlugin, registerAction } from './registry.js'; import { parseSchema } from './schema.js'; import { SPAN_TYPE_ATTR, @@ -207,16 +202,9 @@ export function defineAction< }, fn: (input: z.infer) => Promise> ): Action { - if (isInRuntimeContext()) { - throw new Error( - 'Cannot define new actions at runtime.\n' + - 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' - ); - } - const act = action(config, async (i: I): Promise> => { + const act = action(config, (i: I): Promise> => { setCustomMetadataAttributes({ subtype: config.actionType }); - await initializeAllPlugins(); - return await runInActionRuntimeContext(() => fn(i)); + return fn(i); }); act.__action.actionType = config.actionType; registerAction(config.actionType, act); @@ -250,19 +238,3 @@ export function getStreamingCallback(): StreamingCallback | undefined { } return cb; } - -const runtimeCtxAls = new AsyncLocalStorage(); - -/** - * Checks whether the caller is currently in the runtime context of an action. - */ -export function isInRuntimeContext() { - return !!runtimeCtxAls.getStore(); -} - -/** - * Execute the provided function in the action runtime context. - */ -export function runInActionRuntimeContext(fn: () => R) { - return runtimeCtxAls.run('runtime', fn); -} diff --git a/js/core/src/plugin.ts b/js/core/src/plugin.ts index 97955ce00..a21a3a5f3 100644 --- a/js/core/src/plugin.ts +++ b/js/core/src/plugin.ts @@ -15,7 +15,7 @@ */ import { z } from 'zod'; -import { Action, isInRuntimeContext } from './action.js'; +import { Action } from './action.js'; import { FlowStateStore } from './flowTypes.js'; import { LoggerConfig, TelemetryConfig } from './telemetryTypes.js'; import { TraceStore } from './tracing.js'; @@ -60,12 +60,6 @@ export function genkitPlugin( pluginName: string, initFn: T ): Plugin> { - if (isInRuntimeContext()) { - throw new Error( - 'Cannot define new plugins at runtime.\n' + - 'See: https://github.com/firebase/genkit/blob/main/docs/errors/no_new_actions_at_runtime.md' - ); - } return (...args: Parameters) => ({ name: pluginName, initializer: async () => { diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 75daeedb7..94647c92b 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -25,7 +25,45 @@ import { TraceStore } from './tracing/types.js'; export type AsyncProvider = () => Promise; -const REGISTRY_KEY = 'genkit__REGISTRY'; +const ACTIONS_BY_ID = 'genkit__ACTIONS_BY_ID'; +const TRACE_STORES_BY_ENV = 'genkit__TRACE_STORES_BY_ENV'; +const FLOW_STATE_STORES_BY_ENV = 'genkit__FLOW_STATE_STORES_BY_ENV'; +const PLUGINS_BY_NAME = 'genkit__PLUGINS_BY_NAME'; +const SCHEMAS_BY_NAME = 'genkit__SCHEMAS_BY_NAME'; + +function actionsById(): Record> { + if (global[ACTIONS_BY_ID] === undefined) { + global[ACTIONS_BY_ID] = {}; + } + return global[ACTIONS_BY_ID]; +} +function traceStoresByEnv(): Record> { + if (global[TRACE_STORES_BY_ENV] === undefined) { + global[TRACE_STORES_BY_ENV] = {}; + } + return global[TRACE_STORES_BY_ENV]; +} +function flowStateStoresByEnv(): Record> { + if (global[FLOW_STATE_STORES_BY_ENV] === undefined) { + global[FLOW_STATE_STORES_BY_ENV] = {}; + } + return global[FLOW_STATE_STORES_BY_ENV]; +} +function pluginsByName(): Record { + if (global[PLUGINS_BY_NAME] === undefined) { + global[PLUGINS_BY_NAME] = {}; + } + return global[PLUGINS_BY_NAME]; +} +function schemasByName(): Record< + string, + { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } +> { + if (global[SCHEMAS_BY_NAME] === undefined) { + global[SCHEMAS_BY_NAME] = {}; + } + return global[SCHEMAS_BY_NAME]; +} /** * Type of a runnable action. @@ -39,18 +77,22 @@ export type ActionType = | 'flow' | 'model' | 'prompt' - | 'util' | 'tool'; /** * Looks up a registry key (action type and key) in the registry. */ -export function lookupAction< +export async function lookupAction< I extends z.ZodTypeAny, O extends z.ZodTypeAny, R extends Action, >(key: string): Promise { - return getRegistryInstance().lookupAction(key); + // If we don't see the key in the registry we try to initialize the plugin first. + const pluginName = parsePluginName(key); + if (!actionsById()[key] && pluginName) { + await initializePlugin(pluginName); + } + return actionsById()[key] as R; } function parsePluginName(registryKey: string) { @@ -68,23 +110,26 @@ export function registerAction( type: ActionType, action: Action ) { - return getRegistryInstance().registerAction(type, action); + logger.info(`Registering ${type}: ${action.__action.name}`); + const key = `/${type}/${action.__action.name}`; + if (actionsById().hasOwnProperty(key)) { + logger.warn( + `WARNING: ${key} already has an entry in the registry. Overwriting.` + ); + } + actionsById()[key] = action; } type ActionsRecord = Record>; -/** - * Initialize all plugins in the registry. - */ -export async function initializeAllPlugins() { - await getRegistryInstance().initializeAllPlugins(); -} - /** * Returns all actions in the registry. */ -export function listActions(): Promise { - return getRegistryInstance().listActions(); +export async function listActions(): Promise { + for (const pluginName of Object.keys(pluginsByName())) { + await initializePlugin(pluginName); + } + return Object.assign({}, actionsById()); } /** @@ -94,14 +139,27 @@ export function registerTraceStore( env: string, traceStoreProvider: AsyncProvider ) { - return getRegistryInstance().registerTraceStore(env, traceStoreProvider); + traceStoresByEnv()[env] = traceStoreProvider; } +const traceStoresByEnvCache: Record> = {}; + /** * Looks up the trace store for the given environment. */ -export function lookupTraceStore(env: string): Promise { - return getRegistryInstance().lookupTraceStore(env); +export async function lookupTraceStore( + env: string +): Promise { + if (!traceStoresByEnv()[env]) { + return undefined; + } + const cached = traceStoresByEnvCache[env]; + if (!cached) { + const newStore = traceStoresByEnv()[env](); + traceStoresByEnvCache[env] = newStore; + return newStore; + } + return cached; } /** @@ -111,48 +169,68 @@ export function registerFlowStateStore( env: string, flowStateStoreProvider: AsyncProvider ) { - return getRegistryInstance().registerFlowStateStore( - env, - flowStateStoreProvider - ); + flowStateStoresByEnv()[env] = flowStateStoreProvider; } +const flowStateStoresByEnvCache: Record> = {}; /** * Looks up the flow state store for the given environment. */ export async function lookupFlowStateStore( env: string ): Promise { - return getRegistryInstance().lookupFlowStateStore(env); + if (!flowStateStoresByEnv()[env]) { + return undefined; + } + const cached = flowStateStoresByEnvCache[env]; + if (!cached) { + const newStore = flowStateStoresByEnv()[env](); + flowStateStoresByEnvCache[env] = newStore; + return newStore; + } + return cached; } /** * Registers a flow state store for the given environment. */ export function registerPluginProvider(name: string, provider: PluginProvider) { - return getRegistryInstance().registerPluginProvider(name, provider); + let cached; + pluginsByName()[name] = { + name: provider.name, + initializer: () => { + if (cached) { + return cached; + } + cached = provider.initializer(); + return cached; + }, + }; } export function lookupPlugin(name: string) { - return getRegistryInstance().lookupFlowStateStore(name); + return pluginsByName()[name]; } /** - * Initialize plugin -- calls the plugin initialization function. + * */ export async function initializePlugin(name: string) { - return getRegistryInstance().initializePlugin(name); + if (pluginsByName()[name]) { + return await pluginsByName()[name].initializer(); + } + return undefined; } export function registerSchema( name: string, data: { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } ) { - return getRegistryInstance().registerSchema(name, data); + schemasByName()[name] = data; } export function lookupSchema(name: string) { - return getRegistryInstance().lookupSchema(name); + return schemasByName()[name]; } /** @@ -163,187 +241,14 @@ if (process.env.GENKIT_ENV === 'dev') { } export function __hardResetRegistryForTesting() { - delete global[REGISTRY_KEY]; - global[REGISTRY_KEY] = new Registry(); -} - -export class Registry { - private actionsById: Record> = {}; - private traceStoresByEnv: Record> = {}; - private flowStateStoresByEnv: Record> = - {}; - private pluginsByName: Record = {}; - private schemasByName: Record< - string, - { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } - > = {}; - - private traceStoresByEnvCache: Record> = {}; - private flowStateStoresByEnvCache: Record> = {}; - private allPluginsInitialized = false; - - constructor(public parent?: Registry) {} - - static withCurrent() { - return new Registry(getRegistryInstance()); - } - - static withParent(parent: Registry) { - return new Registry(parent); - } - - async lookupAction< - I extends z.ZodTypeAny, - O extends z.ZodTypeAny, - R extends Action, - >(key: string): Promise { - // If we don't see the key in the registry we try to initialize the plugin first. - const pluginName = parsePluginName(key); - if (!this.actionsById[key] && pluginName) { - await this.initializePlugin(pluginName); - } - return (this.actionsById[key] as R) || this.parent?.lookupAction(key); - } - - registerAction( - type: ActionType, - action: Action - ) { - logger.info(`Registering ${type}: ${action.__action.name}`); - const key = `/${type}/${action.__action.name}`; - if (this.actionsById.hasOwnProperty(key)) { - logger.warn( - `WARNING: ${key} already has an entry in the registry. Overwriting.` - ); - } - this.actionsById[key] = action; - } - - async listActions(): Promise { - await this.initializeAllPlugins(); - return { - ...(await this.parent?.listActions()), - ...this.actionsById, - }; - } - - async initializeAllPlugins() { - if (this.allPluginsInitialized) { - return; - } - for (const pluginName of Object.keys(this.pluginsByName)) { - await initializePlugin(pluginName); - } - this.allPluginsInitialized = true; - } - - registerTraceStore( - env: string, - traceStoreProvider: AsyncProvider - ) { - this.traceStoresByEnv[env] = traceStoreProvider; - } - - async lookupTraceStore(env: string): Promise { - return ( - (await this.lookupOverlaidTraceStore(env)) || - this.parent?.lookupTraceStore(env) - ); - } - - private async lookupOverlaidTraceStore( - env: string - ): Promise { - if (!this.traceStoresByEnv[env]) { - return undefined; - } - const cached = this.traceStoresByEnvCache[env]; - if (!cached) { - const newStore = this.traceStoresByEnv[env](); - this.traceStoresByEnvCache[env] = newStore; - return newStore; - } - return cached; - } - - registerFlowStateStore( - env: string, - flowStateStoreProvider: AsyncProvider - ) { - this.flowStateStoresByEnv[env] = flowStateStoreProvider; - } - - async lookupFlowStateStore(env: string): Promise { - return ( - (await this.lookupOverlaidFlowStateStore(env)) || - this.parent?.lookupFlowStateStore(env) - ); - } - - private async lookupOverlaidFlowStateStore( - env: string - ): Promise { - if (!this.flowStateStoresByEnv[env]) { - return undefined; - } - const cached = this.flowStateStoresByEnvCache[env]; - if (!cached) { - const newStore = this.flowStateStoresByEnv[env](); - this.flowStateStoresByEnvCache[env] = newStore; - return newStore; - } - return cached; - } - - registerPluginProvider(name: string, provider: PluginProvider) { - this.allPluginsInitialized = false; - let cached; - let isInitialized = false; - this.pluginsByName[name] = { - name: provider.name, - initializer: () => { - if (isInitialized) { - return cached; - } - cached = provider.initializer(); - isInitialized = true; - return cached; - }, - }; - } - - lookupPlugin(name: string) { - return this.pluginsByName[name] || this.parent?.lookupPlugin(name); - } - - async initializePlugin(name: string) { - if (this.pluginsByName[name]) { - return await this.pluginsByName[name].initializer(); - } - return undefined; - } - - registerSchema( - name: string, - data: { schema?: z.ZodTypeAny; jsonSchema?: JSONSchema } - ) { - this.schemasByName[name] = data; - } - - lookupSchema(name: string) { - return this.schemasByName[name] || this.parent?.lookupSchema(name); - } -} - -// global regustry instance -global[REGISTRY_KEY] = new Registry(); - -/** Returns the current registry instance. */ -export function getRegistryInstance(): Registry { - return global[REGISTRY_KEY]; + delete global[ACTIONS_BY_ID]; + delete global[TRACE_STORES_BY_ENV]; + delete global[FLOW_STATE_STORES_BY_ENV]; + delete global[PLUGINS_BY_NAME]; + deleteAll(flowStateStoresByEnvCache); + deleteAll(traceStoresByEnvCache); } -/** Sets global registry instance. */ -export function setRegistryInstance(reg: Registry) { - global[REGISTRY_KEY] = reg; +function deleteAll(map: Record) { + Object.keys(map).forEach((key) => delete map[key]); } diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 23c165c6e..c969eba29 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -15,10 +15,9 @@ */ import assert from 'node:assert'; -import { afterEach, beforeEach, describe, it } from 'node:test'; +import { beforeEach, describe, it } from 'node:test'; import { action } from '../src/action.js'; import { - Registry, __hardResetRegistryForTesting, listActions, lookupAction, @@ -26,9 +25,8 @@ import { registerPluginProvider, } from '../src/registry.js'; -describe('global registry', () => { +describe('registry', () => { beforeEach(__hardResetRegistryForTesting); - afterEach(__hardResetRegistryForTesting); describe('listActions', () => { it('returns all registered actions', async () => { @@ -171,207 +169,3 @@ describe('global registry', () => { assert.strictEqual(await lookupAction('/model/foo/something'), undefined); }); }); - -describe('registry class', () => { - var registry: Registry; - beforeEach(() => { - registry = new Registry(); - }); - - describe('listActions', () => { - it('returns all registered actions', async () => { - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registry.registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registry.registerAction('model', barSomethingAction); - - assert.deepEqual(await registry.listActions(), { - '/model/foo_something': fooSomethingAction, - '/model/bar_something': barSomethingAction, - }); - }); - - it('returns all registered actions by plugins', async () => { - registry.registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registry.registerAction('model', fooSomethingAction); - return {}; - }, - }); - const fooSomethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - registry.registerAction('custom', fooSomethingAction); - registry.registerPluginProvider('bar', { - name: 'bar', - async initializer() { - registry.registerAction('model', barSomethingAction); - return {}; - }, - }); - const barSomethingAction = action( - { - name: { - pluginId: 'bar', - actionId: 'something', - }, - }, - async () => null - ); - registry.registerAction('custom', barSomethingAction); - - assert.deepEqual(await registry.listActions(), { - '/custom/foo/something': fooSomethingAction, - '/custom/bar/something': barSomethingAction, - }); - }); - - it('returns all registered actions, including parent', async () => { - const child = Registry.withParent(registry); - - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registry.registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - child.registerAction('model', barSomethingAction); - - assert.deepEqual(await child.listActions(), { - '/model/foo_something': fooSomethingAction, - '/model/bar_something': barSomethingAction, - }); - assert.deepEqual(await registry.listActions(), { - '/model/foo_something': fooSomethingAction, - }); - }); - }); - - describe('lookupAction', () => { - it('initializes plugin for action first', async () => { - let fooInitialized = false; - registry.registerPluginProvider('foo', { - name: 'foo', - async initializer() { - fooInitialized = true; - return {}; - }, - }); - let barInitialized = false; - registry.registerPluginProvider('bar', { - name: 'bar', - async initializer() { - barInitialized = true; - return {}; - }, - }); - - await registry.lookupAction('/model/foo/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, false); - - await registry.lookupAction('/model/bar/something'); - - assert.strictEqual(fooInitialized, true); - assert.strictEqual(barInitialized, true); - }); - - it('returns registered action', async () => { - const fooSomethingAction = action( - { name: 'foo_something' }, - async () => null - ); - registry.registerAction('model', fooSomethingAction); - const barSomethingAction = action( - { name: 'bar_something' }, - async () => null - ); - registry.registerAction('model', barSomethingAction); - - assert.strictEqual( - await registry.lookupAction('/model/foo_something'), - fooSomethingAction - ); - assert.strictEqual( - await registry.lookupAction('/model/bar_something'), - barSomethingAction - ); - }); - - it('returns action registered by plugin', async () => { - registry.registerPluginProvider('foo', { - name: 'foo', - async initializer() { - registry.registerAction('model', somethingAction); - return {}; - }, - }); - const somethingAction = action( - { - name: { - pluginId: 'foo', - actionId: 'something', - }, - }, - async () => null - ); - - assert.strictEqual( - await registry.lookupAction('/model/foo/something'), - somethingAction - ); - }); - - it('returns undefined for unknown action', async () => { - assert.strictEqual( - await registry.lookupAction('/model/foo/something'), - undefined - ); - }); - - it('should lookup parent registry when child missing action', async () => { - const childRegistry = new Registry(registry); - - const fooAction = action({ name: 'foo' }, async () => null); - registry.registerAction('model', fooAction); - - assert.strictEqual(await registry.lookupAction('/model/foo'), fooAction); - assert.strictEqual( - await childRegistry.lookupAction('/model/foo'), - fooAction - ); - }); - - it('registration on the child registry should not modify parent', async () => { - const childRegistry = Registry.withParent(registry); - - assert.strictEqual(childRegistry.parent, registry); - - const fooAction = action({ name: 'foo' }, async () => null); - childRegistry.registerAction('model', fooAction); - - assert.strictEqual(await registry.lookupAction('/model/foo'), undefined); - assert.strictEqual( - await childRegistry.lookupAction('/model/foo'), - fooAction - ); - }); - }); -}); diff --git a/js/flow/src/flow.ts b/js/flow/src/flow.ts index 186918de6..09b533d24 100644 --- a/js/flow/src/flow.ts +++ b/js/flow/src/flow.ts @@ -28,7 +28,6 @@ import { StreamingCallback, } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { initializeAllPlugins } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; import { newTrace, @@ -391,7 +390,6 @@ export class Flow< labels: Record | undefined ) { const startTimeMs = performance.now(); - await initializeAllPlugins(); await runWithActiveContext(ctx, async () => { let traceContext; if (ctx.state.traceContext) { diff --git a/js/flow/src/utils.ts b/js/flow/src/utils.ts index 411019517..774d81c4c 100644 --- a/js/flow/src/utils.ts +++ b/js/flow/src/utils.ts @@ -14,7 +14,6 @@ * limitations under the License. */ -import { runInActionRuntimeContext } from '@genkit-ai/core'; import { AsyncLocalStorage } from 'node:async_hooks'; import { v4 as uuidv4 } from 'uuid'; import z from 'zod'; @@ -46,7 +45,7 @@ export function runWithActiveContext( ctx: Context, fn: () => R ) { - return ctxAsyncLocalStorage.run(ctx, () => runInActionRuntimeContext(fn)); + return ctxAsyncLocalStorage.run(ctx, fn); } /** diff --git a/js/plugins/dotprompt/src/template.ts b/js/plugins/dotprompt/src/template.ts index 10e07cba4..020f9fad9 100644 --- a/js/plugins/dotprompt/src/template.ts +++ b/js/plugins/dotprompt/src/template.ts @@ -115,15 +115,11 @@ function toMessages( ) return messages; - if (messages.at(-1)?.role === 'user') { - return [ - ...messages.slice(0, -1), - ...options.history, - messages.at(-1), - ] as MessageData[]; - } - - return [...messages, ...options.history] as MessageData[]; + return [ + ...messages.slice(0, -1), + ...options.history, + messages.at(-1), + ] as MessageData[]; } const PART_REGEX = /(<<>>/g; diff --git a/js/plugins/dotprompt/tests/prompt_test.ts b/js/plugins/dotprompt/tests/prompt_test.ts index ac012a55d..c9a386141 100644 --- a/js/plugins/dotprompt/tests/prompt_test.ts +++ b/js/plugins/dotprompt/tests/prompt_test.ts @@ -105,25 +105,6 @@ describe('Prompt', () => { assert.strictEqual(rendered.streamingCallback, streamingCallback); assert.strictEqual(rendered.returnToolRequests, true); }); - - it('should support system prompt with history', async () => { - const prompt = testPrompt(`{{ role "system" }}Testing system {{name}}`); - - const rendered = await prompt.render({ - input: { name: 'Michael' }, - history: [ - { role: 'user', content: [{ text: 'history 1' }] }, - { role: 'model', content: [{ text: 'history 2' }] }, - { role: 'user', content: [{ text: 'history 3' }] }, - ], - }); - assert.deepStrictEqual(rendered.history, [ - { role: 'system', content: [{ text: 'Testing system Michael' }] }, - { role: 'user', content: [{ text: 'history 1' }] }, - { role: 'model', content: [{ text: 'history 2' }] }, - ]); - assert.deepStrictEqual(rendered.prompt, [{ text: 'history 3' }]); - }); }); describe('#generate', () => { diff --git a/js/plugins/google-cloud/tests/logs_test.ts b/js/plugins/google-cloud/tests/logs_test.ts index 002546878..596c3bb50 100644 --- a/js/plugins/google-cloud/tests/logs_test.ts +++ b/js/plugins/google-cloud/tests/logs_test.ts @@ -312,19 +312,19 @@ describe('GoogleCloudLogs', () => { const logMessages = await getLogs(1, 100, logLines); assert.equal( logMessages.includes( - '[info] Config[testFlow > sub1 > sub2 > generate > testModel, testModel]' + '[info] Config[testFlow > sub1 > sub2 > testModel, testModel]' ), true ); assert.equal( logMessages.includes( - '[info] Input[testFlow > sub1 > sub2 > generate > testModel, testModel]' + '[info] Input[testFlow > sub1 > sub2 > testModel, testModel]' ), true ); assert.equal( logMessages.includes( - '[info] Output[testFlow > sub1 > sub2 > generate > testModel, testModel]' + '[info] Output[testFlow > sub1 > sub2 > testModel, testModel]' ), true ); diff --git a/js/plugins/ollama/package.json b/js/plugins/ollama/package.json index 006229c68..1d5b88726 100644 --- a/js/plugins/ollama/package.json +++ b/js/plugins/ollama/package.json @@ -17,9 +17,7 @@ "compile": "tsup-node", "build:clean": "rm -rf ./lib", "build": "npm-run-all build:clean check compile", - "build:watch": "tsup-node --watch", - "test": "find tests -name '*_test.ts' ! -name '*_live_test.ts' -exec node --import tsx --test {} +", - "test:live": "node --import tsx --test tests/*_test.ts" + "build:watch": "tsup-node --watch" }, "repository": { "type": "git", @@ -28,9 +26,6 @@ }, "author": "genkit", "license": "Apache-2.0", - "dependencies": { - "zod": "^3.22.4" - }, "peerDependencies": { "@genkit-ai/ai": "workspace:*", "@genkit-ai/core": "workspace:*" diff --git a/js/plugins/ollama/src/embeddings.ts b/js/plugins/ollama/src/embeddings.ts deleted file mode 100644 index c1371e54f..000000000 --- a/js/plugins/ollama/src/embeddings.ts +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { defineEmbedder } from '@genkit-ai/ai/embedder'; -import { logger } from '@genkit-ai/core/logging'; -import z from 'zod'; -import { OllamaPluginParams } from './index.js'; - -// Define the schema for Ollama embedding configuration -export const OllamaEmbeddingConfigSchema = z.object({ - modelName: z.string(), - serverAddress: z.string(), -}); -export type OllamaEmbeddingConfig = z.infer; - -// Define the structure of the request and response for embedding -interface OllamaEmbeddingInstance { - content: string; -} - -interface OllamaEmbeddingPrediction { - embedding: number[]; -} - -interface DefineOllamaEmbeddingParams { - name: string; - modelName: string; - dimensions: number; - options: OllamaPluginParams; -} - -export function defineOllamaEmbedder({ - name, - modelName, - dimensions, - options, -}: DefineOllamaEmbeddingParams) { - return defineEmbedder( - { - name, - configSchema: OllamaEmbeddingConfigSchema, // Use the Zod schema directly here - info: { - // TODO: do we want users to be able to specify the label when they call this method directly? - label: 'Ollama Embedding - ' + modelName, - dimensions, - supports: { - // TODO: do any ollama models support other modalities? - input: ['text'], - }, - }, - }, - async (input, _config) => { - const serverAddress = options.serverAddress; - - const responses = await Promise.all( - input.map(async (i) => { - const requestPayload = { - model: modelName, - prompt: i.text(), - }; - let res: Response; - try { - console.log('MODEL NAME: ', modelName); - res = await fetch(`${serverAddress}/api/embeddings`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify(requestPayload), - }); - } catch (e) { - logger.error('Failed to fetch Ollama embedding'); - throw new Error(`Error fetching embedding from Ollama: ${e}`); - } - - if (!res.ok) { - logger.error('Failed to fetch Ollama embedding'); - throw new Error( - `Error fetching embedding from Ollama: ${res.statusText}` - ); - } - - const responseData = (await res.json()) as OllamaEmbeddingPrediction; - return responseData; - }) - ); - - return { - embeddings: responses, - }; - } - ); -} diff --git a/js/plugins/ollama/src/index.ts b/js/plugins/ollama/src/index.ts index 87296ee28..7a7bbb636 100644 --- a/js/plugins/ollama/src/index.ts +++ b/js/plugins/ollama/src/index.ts @@ -25,7 +25,6 @@ import { } from '@genkit-ai/ai/model'; import { genkitPlugin, Plugin } from '@genkit-ai/core'; import { logger } from '@genkit-ai/core/logging'; -import { defineOllamaEmbedder } from './embeddings'; type ApiType = 'chat' | 'generate'; @@ -38,11 +37,8 @@ type RequestHeaders = type ModelDefinition = { name: string; type?: ApiType }; -type EmbeddingModelDefinition = { name: string; dimensions: number }; - export interface OllamaPluginParams { models: ModelDefinition[]; - embeddingModels?: EmbeddingModelDefinition[]; /** * ollama server address. */ @@ -55,19 +51,10 @@ export const ollama: Plugin<[OllamaPluginParams]> = genkitPlugin( 'ollama', async (params: OllamaPluginParams) => { const serverAddress = params?.serverAddress; - return { models: params.models.map((model) => ollamaModel(model, serverAddress, params.requestHeaders) ), - embedders: params.embeddingModels?.map((model) => - defineOllamaEmbedder({ - name: `${ollama}/model.name`, - modelName: model.name, - dimensions: model.dimensions, - options: params, - }) - ), }; } ); diff --git a/js/plugins/ollama/tests/embeddings_live_test.ts b/js/plugins/ollama/tests/embeddings_live_test.ts deleted file mode 100644 index 8cbab30ef..000000000 --- a/js/plugins/ollama/tests/embeddings_live_test.ts +++ /dev/null @@ -1,59 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { embed } from '@genkit-ai/ai/embedder'; -import assert from 'node:assert'; -import { describe, it } from 'node:test'; - -import { defineOllamaEmbedder } from '../src/embeddings.js'; // Adjust the import path as necessary -import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary - -// Utility function to parse command-line arguments -function parseArgs() { - const args = process.argv.slice(2); - const serverAddress = - args.find((arg) => arg.startsWith('--server-address='))?.split('=')[1] || - 'http://localhost:11434'; - const modelName = - args.find((arg) => arg.startsWith('--model-name='))?.split('=')[1] || - 'nomic-embed-text'; - return { serverAddress, modelName }; -} - -const { serverAddress, modelName } = parseArgs(); - -describe('defineOllamaEmbedder - Live Tests', () => { - const options: OllamaPluginParams = { - models: [{ name: modelName }], - serverAddress, - }; - - it('should successfully return embeddings', async () => { - const embedder = defineOllamaEmbedder({ - name: 'live-test-embedder', - modelName: 'nomic-embed-text', - dimensions: 768, - options, - }); - - const result = await embed({ - embedder, - content: 'Hello, world!', - }); - - assert.strictEqual(result.length, 768); - }); -}); diff --git a/js/plugins/ollama/tests/embeddings_test.ts b/js/plugins/ollama/tests/embeddings_test.ts deleted file mode 100644 index 10d2407e7..000000000 --- a/js/plugins/ollama/tests/embeddings_test.ts +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { embed } from '@genkit-ai/ai/embedder'; -import assert from 'node:assert'; -import { describe, it } from 'node:test'; -import { - OllamaEmbeddingConfigSchema, - defineOllamaEmbedder, -} from '../src/embeddings.js'; // Adjust the import path as necessary -import { OllamaPluginParams } from '../src/index.js'; // Adjust the import path as necessary - -// Mock fetch to simulate API responses -global.fetch = async (input: RequestInfo | URL, options?: RequestInit) => { - const url = typeof input === 'string' ? input : input.toString(); - - if (url.includes('/api/embedding')) { - if (options?.body && JSON.stringify(options.body).includes('fail')) { - return { - ok: false, - statusText: 'Internal Server Error', - json: async () => ({}), - } as Response; - } - return { - ok: true, - json: async () => ({ - embedding: [0.1, 0.2, 0.3], // Example embedding values - }), - } as Response; - } - - throw new Error('Unknown API endpoint'); -}; - -describe('defineOllamaEmbedder', () => { - const options: OllamaPluginParams = { - models: [{ name: 'test-model' }], - serverAddress: 'http://localhost:3000', - }; - - it('should successfully return embeddings', async () => { - const embedder = defineOllamaEmbedder({ - name: 'test-embedder', - modelName: 'test-model', - dimensions: 123, - options, - }); - - const result = await embed({ - embedder, - content: 'Hello, world!', - }); - assert.deepStrictEqual(result, [0.1, 0.2, 0.3]); - }); - - it('should handle API errors correctly', async () => { - const embedder = defineOllamaEmbedder({ - name: 'test-embedder', - modelName: 'test-model', - dimensions: 123, - options, - }); - - await assert.rejects( - async () => { - await embed({ - embedder, - content: 'fail', - }); - }, - (error) => { - // Check if error is an instance of Error - assert(error instanceof Error); - - assert.strictEqual( - error.message, - 'Error fetching embedding from Ollama: Internal Server Error' - ); - return true; - } - ); - }); - - it('should validate the embedding configuration schema', async () => { - const validConfig = { - modelName: 'test-model', - serverAddress: 'http://localhost:3000', - }; - - const invalidConfig = { - modelName: 123, // Invalid type - serverAddress: 'http://localhost:3000', - }; - - // Valid configuration should pass - assert.doesNotThrow(() => { - OllamaEmbeddingConfigSchema.parse(validConfig); - }); - - // Invalid configuration should throw - assert.throws(() => { - OllamaEmbeddingConfigSchema.parse(invalidConfig); - }); - }); - - it('should throw an error if the fetch response is not ok', async () => { - const embedder = defineOllamaEmbedder({ - name: 'test-embedder', - modelName: 'test-model', - dimensions: 123, - options, - }); - - await assert.rejects(async () => { - await embed({ - embedder, - content: 'fail', - }); - }, new Error('Error fetching embedding from Ollama: Internal Server Error')); - }); -}); diff --git a/js/plugins/vertexai/src/openai_compatibility.ts b/js/plugins/vertexai/src/openai_compatibility.ts index 6d89565b6..8e1d27ca6 100644 --- a/js/plugins/vertexai/src/openai_compatibility.ts +++ b/js/plugins/vertexai/src/openai_compatibility.ts @@ -75,7 +75,7 @@ function toOpenAiTool(tool: ToolDefinition): ChatCompletionTool { type: 'function', function: { name: tool.name, - parameters: tool.inputSchema || undefined, + parameters: tool.inputSchema, }, }; } diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index cfdb90660..1c97b84df 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -509,7 +509,7 @@ importers: version: link:../../flow '@langchain/community': specifier: ^0.0.53 - version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) + version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) '@langchain/core': specifier: ^0.1.61 version: 0.1.61 @@ -518,7 +518,7 @@ importers: version: 1.9.0 langchain: specifier: ^0.1.36 - version: 0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1) + version: 0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1) zod: specifier: ^3.22.4 version: 3.22.4 @@ -547,9 +547,6 @@ importers: '@genkit-ai/core': specifier: workspace:* version: link:../../core - zod: - specifier: ^3.22.4 - version: 3.22.4 devDependencies: '@types/node': specifier: ^20.11.16 @@ -1132,7 +1129,7 @@ importers: version: link:../../plugins/vertexai '@langchain/community': specifier: ^0.0.53 - version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) + version: 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) '@langchain/core': specifier: ^0.1.61 version: 0.1.61 @@ -1150,7 +1147,7 @@ importers: version: link:../../plugins/ollama langchain: specifier: ^0.1.36 - version: 0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1) + version: 0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1) pdf-parse: specifier: ^1.1.1 version: 1.1.1 @@ -5633,7 +5630,7 @@ snapshots: '@js-sdsl/ordered-map@4.4.2': {} - '@langchain/community@0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)': + '@langchain/community@0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2)': dependencies: '@langchain/core': 0.1.61 '@langchain/openai': 0.0.28(encoding@0.1.13) @@ -7993,10 +7990,10 @@ snapshots: kuler@2.0.0: {} - langchain@0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1): + langchain@0.1.36(@google-cloud/storage@7.10.1(encoding@0.1.13))(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(fast-xml-parser@4.3.6)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(handlebars@4.7.8)(ignore@5.3.1)(jsonwebtoken@9.0.2)(pdf-parse@1.1.1): dependencies: '@anthropic-ai/sdk': 0.9.1(encoding@0.1.13) - '@langchain/community': 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.2.0(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) + '@langchain/community': 0.0.53(@pinecone-database/pinecone@2.2.0)(chromadb@1.8.1(encoding@0.1.13)(openai@4.53.0(encoding@0.1.13)))(encoding@0.1.13)(firebase-admin@12.3.1(encoding@0.1.13))(google-auth-library@8.9.0(encoding@0.1.13))(jsonwebtoken@9.0.2) '@langchain/core': 0.1.61 '@langchain/openai': 0.0.28(encoding@0.1.13) '@langchain/textsplitters': 0.0.0 diff --git a/js/testapps/flow-simple-ai/src/index.ts b/js/testapps/flow-simple-ai/src/index.ts index 0b64d41d4..934fda755 100644 --- a/js/testapps/flow-simple-ai/src/index.ts +++ b/js/testapps/flow-simple-ai/src/index.ts @@ -14,7 +14,8 @@ * limitations under the License. */ -import { defineTool, generate, generateStream, retrieve } from '@genkit-ai/ai'; +import { generate, generateStream, retrieve } from '@genkit-ai/ai'; +import { defineTool } from '@genkit-ai/ai/tool'; import { configureGenkit } from '@genkit-ai/core'; import { dotprompt, prompt } from '@genkit-ai/dotprompt'; import { defineFirestoreRetriever, firebase } from '@genkit-ai/firebase'; @@ -428,7 +429,6 @@ export const invalidOutput = defineFlow( } ); -import { MessageSchema } from '@genkit-ai/ai/model'; import { GoogleAIFileManager } from '@google/generative-ai/server'; const fileManager = new GoogleAIFileManager( process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY! @@ -465,38 +465,3 @@ export const fileApi = defineFlow( return result.text(); } ); - -export const testTools = [ - // test a tool with no input / output schema - defineTool( - { name: 'getColor', description: 'gets a random color' }, - async () => { - const colors = [ - 'red', - 'orange', - 'yellow', - 'blue', - 'green', - 'indigo', - 'violet', - ]; - return colors[Math.floor(Math.random() * colors.length)]; - } - ), -]; - -export const toolTester = defineFlow( - { - name: 'toolTester', - inputSchema: z.string(), - outputSchema: z.array(MessageSchema), - }, - async (query) => { - const result = await generate({ - model: gemini15Flash, - prompt: query, - tools: testTools, - }); - return result.toHistory(); - } -);