From e64495e800a3b1c88b5313b86563d8ff340044fb Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 9 May 2024 16:25:59 -0400 Subject: [PATCH 1/5] Added setGlobalDefaultModel that allows setting a global default model and config --- js/ai/src/generate.ts | 40 ++++++++++++++++++++++++++---- js/ai/src/index.ts | 1 + js/plugins/dotprompt/src/prompt.ts | 8 ------ js/samples/rag/src/pdf_rag.ts | 7 +++++- 4 files changed, 42 insertions(+), 14 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 5031cfb643..380b1ebb70 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -431,7 +431,7 @@ export interface GenerateOptions< CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, > { /** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */ - model: ModelArgument; + model?: ModelArgument; /** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */ prompt: string | Part | Part[]; /** Retrieved documents to be used as context for this generation. */ @@ -479,9 +479,39 @@ const isValidCandidate = ( }); }; -async function resolveModel( - model: ModelAction | ModelReference | string -): Promise { +const DEFAULT_MODEL_GLOBAL_KEY = 'genkit_ai__defaultModel'; + +interface GlobalModelRef { + model: ModelArgument; + config?: any; +} + +export function setGlobalDefaultModel< + CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, +>(model: ModelArgument, config?: z.infer) { + global[DEFAULT_MODEL_GLOBAL_KEY] = { + model, + config, + } as GlobalModelRef; +} + +async function resolveModel(options: GenerateOptions): Promise { + let model = options.model; + if (!model) { + const globalModel = global[DEFAULT_MODEL_GLOBAL_KEY] as GlobalModelRef; + if (globalModel) { + model = globalModel.model; + if ( + (!options.config || Object.keys(options.config).length === 0) && + globalModel.config + ) { + // use configured global config + options.config = globalModel.config; + } + } else { + throw new Error('Unable to resolve model.'); + } + } if (typeof model === 'string') { return (await lookupAction(`/model/${model}`)) as ModelAction; } else if (model.hasOwnProperty('info')) { @@ -537,7 +567,7 @@ export async function generate< ): Promise>> { const resolvedOptions: GenerateOptions = await Promise.resolve(options); - const model = await resolveModel(resolvedOptions.model); + const model = await resolveModel(resolvedOptions); if (!model) { throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`); } diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index 06df9efec3..fb300918ac 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -30,6 +30,7 @@ export { Message, generate, generateStream, + setGlobalDefaultModel, toGenerateRequest, } from './generate.js'; export { PromptAction, definePrompt, renderPrompt } from './prompt.js'; diff --git a/js/plugins/dotprompt/src/prompt.ts b/js/plugins/dotprompt/src/prompt.ts index 12ff7600f5..b0fc9c4a06 100644 --- a/js/plugins/dotprompt/src/prompt.ts +++ b/js/plugins/dotprompt/src/prompt.ts @@ -161,14 +161,6 @@ export class Dotprompt implements PromptMetadata { private _generateOptions( options: PromptGenerateOptions ): GenerateOptions { - if (!options.model && !this.model) { - throw new GenkitError({ - source: 'Dotprompt', - message: 'Must supply `model` in prompt metadata or generate options.', - status: 'INVALID_ARGUMENT', - }); - } - const messages = this.renderMessages(options.input); return { model: options.model || this.model!, diff --git a/js/samples/rag/src/pdf_rag.ts b/js/samples/rag/src/pdf_rag.ts index 1638939651..f50c0fb47b 100644 --- a/js/samples/rag/src/pdf_rag.ts +++ b/js/samples/rag/src/pdf_rag.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { generate } from '@genkit-ai/ai'; +import { generate, setGlobalDefaultModel } from '@genkit-ai/ai'; import { Document, index, retrieve } from '@genkit-ai/ai/retriever'; import { devLocalIndexerRef, @@ -33,6 +33,11 @@ export const pdfChatRetriever = devLocalRetrieverRef('pdfQA'); export const pdfChatIndexer = devLocalIndexerRef('pdfQA'); +setGlobalDefaultModel(geminiPro, { + temperature: 0.6, + stopSequences: ['sorry'], +}); + // Define a simple RAG flow, we will evaluate this flow export const pdfQA = defineFlow( { From 7eb695af5a99c133619cefbe5576e49aab6c38ff Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 9 May 2024 21:05:15 -0400 Subject: [PATCH 2/5] moved to config --- js/ai/src/generate.ts | 29 ++++++++--------------------- js/ai/src/index.ts | 1 - js/core/src/config.ts | 4 ++++ js/samples/rag/src/index.ts | 6 ++++++ js/samples/rag/src/pdf_rag.ts | 7 +------ js/samples/rag/src/prompt.ts | 2 -- 6 files changed, 19 insertions(+), 30 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 380b1ebb70..3f425d9f6f 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -16,6 +16,7 @@ import { Action, + config as genkitConfig, GenkitError, runWithStreamingCallback, StreamingCallback, @@ -479,34 +480,20 @@ const isValidCandidate = ( }); }; -const DEFAULT_MODEL_GLOBAL_KEY = 'genkit_ai__defaultModel'; - -interface GlobalModelRef { - model: ModelArgument; - config?: any; -} - -export function setGlobalDefaultModel< - CustomOptions extends z.ZodTypeAny = z.ZodTypeAny, ->(model: ModelArgument, config?: z.infer) { - global[DEFAULT_MODEL_GLOBAL_KEY] = { - model, - config, - } as GlobalModelRef; -} - async function resolveModel(options: GenerateOptions): Promise { let model = options.model; if (!model) { - const globalModel = global[DEFAULT_MODEL_GLOBAL_KEY] as GlobalModelRef; - if (globalModel) { - model = globalModel.model; + if (genkitConfig?.options.defaultModel) { + model = + typeof genkitConfig?.options.defaultModel.name === 'string' + ? genkitConfig?.options.defaultModel.name + : genkitConfig?.options.defaultModel.name.name; if ( (!options.config || Object.keys(options.config).length === 0) && - globalModel.config + genkitConfig?.options.defaultModel.config ) { // use configured global config - options.config = globalModel.config; + options.config = genkitConfig?.options.defaultModel.config; } } else { throw new Error('Unable to resolve model.'); diff --git a/js/ai/src/index.ts b/js/ai/src/index.ts index fb300918ac..06df9efec3 100644 --- a/js/ai/src/index.ts +++ b/js/ai/src/index.ts @@ -30,7 +30,6 @@ export { Message, generate, generateStream, - setGlobalDefaultModel, toGenerateRequest, } from './generate.js'; export { PromptAction, definePrompt, renderPrompt } from './prompt.js'; diff --git a/js/core/src/config.ts b/js/core/src/config.ts index 1381af8bff..65391d3749 100644 --- a/js/core/src/config.ts +++ b/js/core/src/config.ts @@ -42,6 +42,10 @@ export interface ConfigOptions { logLevel?: 'error' | 'warn' | 'info' | 'debug'; promptDir?: string; telemetry?: TelemetryOptions; + defaultModel?: { + name: string | { name: string }; + config?: Record; + }; } class Config { diff --git a/js/samples/rag/src/index.ts b/js/samples/rag/src/index.ts index ffc5b98823..890756b22f 100644 --- a/js/samples/rag/src/index.ts +++ b/js/samples/rag/src/index.ts @@ -77,6 +77,12 @@ export default configureGenkit({ }, ]), ], + defaultModel: { + name: geminiPro, + config: { + temperature: 0.6, + }, + }, flowStateStore: 'firebase', traceStore: 'firebase', enableTracingAndMetrics: true, diff --git a/js/samples/rag/src/pdf_rag.ts b/js/samples/rag/src/pdf_rag.ts index f50c0fb47b..1638939651 100644 --- a/js/samples/rag/src/pdf_rag.ts +++ b/js/samples/rag/src/pdf_rag.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { generate, setGlobalDefaultModel } from '@genkit-ai/ai'; +import { generate } from '@genkit-ai/ai'; import { Document, index, retrieve } from '@genkit-ai/ai/retriever'; import { devLocalIndexerRef, @@ -33,11 +33,6 @@ export const pdfChatRetriever = devLocalRetrieverRef('pdfQA'); export const pdfChatIndexer = devLocalIndexerRef('pdfQA'); -setGlobalDefaultModel(geminiPro, { - temperature: 0.6, - stopSequences: ['sorry'], -}); - // Define a simple RAG flow, we will evaluate this flow export const pdfQA = defineFlow( { diff --git a/js/samples/rag/src/prompt.ts b/js/samples/rag/src/prompt.ts index fec34b3ab6..c77c281303 100644 --- a/js/samples/rag/src/prompt.ts +++ b/js/samples/rag/src/prompt.ts @@ -15,7 +15,6 @@ */ import { defineDotprompt } from '@genkit-ai/dotprompt'; -import { geminiPro } from '@genkit-ai/vertexai'; import * as z from 'zod'; // Define a prompt that includes the retrieved context documents @@ -23,7 +22,6 @@ import * as z from 'zod'; export const augmentedPrompt = defineDotprompt( { name: 'augmentedPrompt', - model: geminiPro, input: z.object({ context: z.array(z.string()), question: z.string(), From 50a70e9d0edbe462fcb540b0e1d3a389cb254551 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 9 May 2024 21:09:37 -0400 Subject: [PATCH 3/5] unnecessary ? --- js/ai/src/generate.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 3f425d9f6f..8f04bf5e57 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -485,9 +485,9 @@ async function resolveModel(options: GenerateOptions): Promise { if (!model) { if (genkitConfig?.options.defaultModel) { model = - typeof genkitConfig?.options.defaultModel.name === 'string' - ? genkitConfig?.options.defaultModel.name - : genkitConfig?.options.defaultModel.name.name; + typeof genkitConfig.options.defaultModel.name === 'string' + ? genkitConfig.options.defaultModel.name + : genkitConfig.options.defaultModel.name.name; if ( (!options.config || Object.keys(options.config).length === 0) && genkitConfig?.options.defaultModel.config From 9fbc7eb0dae6cd598d587ca16bfecccb8e336f7f Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 9 May 2024 21:09:59 -0400 Subject: [PATCH 4/5] more unnecessary ? --- js/ai/src/generate.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 8f04bf5e57..03af256956 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -490,10 +490,10 @@ async function resolveModel(options: GenerateOptions): Promise { : genkitConfig.options.defaultModel.name.name; if ( (!options.config || Object.keys(options.config).length === 0) && - genkitConfig?.options.defaultModel.config + genkitConfig.options.defaultModel.config ) { // use configured global config - options.config = genkitConfig?.options.defaultModel.config; + options.config = genkitConfig.options.defaultModel.config; } } else { throw new Error('Unable to resolve model.'); From cb0e542ea5386e6a95319ee7acf7ab112d4f55ad Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 9 May 2024 21:10:07 -0400 Subject: [PATCH 5/5] just in case ? --- js/ai/src/generate.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 03af256956..53523106a6 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -483,7 +483,7 @@ const isValidCandidate = ( async function resolveModel(options: GenerateOptions): Promise { let model = options.model; if (!model) { - if (genkitConfig?.options.defaultModel) { + if (genkitConfig?.options?.defaultModel) { model = typeof genkitConfig.options.defaultModel.name === 'string' ? genkitConfig.options.defaultModel.name