Skip to content

Commit

Permalink
Added defaultModel to config that allows setting a global default mod…
Browse files Browse the repository at this point in the history
…el and config (#98)
  • Loading branch information
pavelgj authored May 10, 2024
1 parent f55e294 commit 8d871f0
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 15 deletions.
27 changes: 22 additions & 5 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import {
Action,
config as genkitConfig,
GenkitError,
runWithStreamingCallback,
StreamingCallback,
Expand Down Expand Up @@ -431,7 +432,7 @@ export interface GenerateOptions<
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
> {
/** A model name (e.g. `vertexai/gemini-1.0-pro`) or reference. */
model: ModelArgument<CustomOptions>;
model?: ModelArgument<CustomOptions>;
/** The prompt for which to generate a response. Can be a string for a simple text prompt or one or more parts for multi-modal prompts. */
prompt: string | Part | Part[];
/** Retrieved documents to be used as context for this generation. */
Expand Down Expand Up @@ -479,9 +480,25 @@ const isValidCandidate = (
});
};

async function resolveModel(
model: ModelAction<any> | ModelReference<any> | string
): Promise<ModelAction> {
async function resolveModel(options: GenerateOptions): Promise<ModelAction> {
let model = options.model;
if (!model) {
if (genkitConfig?.options?.defaultModel) {
model =
typeof genkitConfig.options.defaultModel.name === 'string'
? genkitConfig.options.defaultModel.name
: genkitConfig.options.defaultModel.name.name;
if (
(!options.config || Object.keys(options.config).length === 0) &&
genkitConfig.options.defaultModel.config
) {
// use configured global config
options.config = genkitConfig.options.defaultModel.config;
}
} else {
throw new Error('Unable to resolve model.');
}
}
if (typeof model === 'string') {
return (await lookupAction(`/model/${model}`)) as ModelAction;
} else if (model.hasOwnProperty('info')) {
Expand Down Expand Up @@ -537,7 +554,7 @@ export async function generate<
): Promise<GenerateResponse<z.infer<O>>> {
const resolvedOptions: GenerateOptions<O, CustomOptions> =
await Promise.resolve(options);
const model = await resolveModel(resolvedOptions.model);
const model = await resolveModel(resolvedOptions);
if (!model) {
throw new Error(`Model ${JSON.stringify(resolvedOptions.model)} not found`);
}
Expand Down
4 changes: 4 additions & 0 deletions js/core/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ export interface ConfigOptions {
logLevel?: 'error' | 'warn' | 'info' | 'debug';
promptDir?: string;
telemetry?: TelemetryOptions;
defaultModel?: {
name: string | { name: string };
config?: Record<string, any>;
};
}

class Config {
Expand Down
8 changes: 0 additions & 8 deletions js/plugins/dotprompt/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,14 +161,6 @@ export class Dotprompt<Variables = unknown> implements PromptMetadata {
private _generateOptions(
options: PromptGenerateOptions<Variables>
): GenerateOptions {
if (!options.model && !this.model) {
throw new GenkitError({
source: 'Dotprompt',
message: 'Must supply `model` in prompt metadata or generate options.',
status: 'INVALID_ARGUMENT',
});
}

const messages = this.renderMessages(options.input);
return {
model: options.model || this.model!,
Expand Down
6 changes: 6 additions & 0 deletions js/samples/rag/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ export default configureGenkit({
},
]),
],
defaultModel: {
name: geminiPro,
config: {
temperature: 0.6,
},
},
flowStateStore: 'firebase',
traceStore: 'firebase',
enableTracingAndMetrics: true,
Expand Down
2 changes: 0 additions & 2 deletions js/samples/rag/src/prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@
*/

import { defineDotprompt } from '@genkit-ai/dotprompt';
import { geminiPro } from '@genkit-ai/vertexai';
import * as z from 'zod';

// Define a prompt that includes the retrieved context documents

export const augmentedPrompt = defineDotprompt(
{
name: 'augmentedPrompt',
model: geminiPro,
input: z.object({
context: z.array(z.string()),
question: z.string(),
Expand Down

0 comments on commit 8d871f0

Please sign in to comment.