Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explicit return types for plugins that define models, etc #397

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions js/plugins/chroma/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import {
defineIndexer,
defineRetriever,
Document,
IndexerAction,
indexerRef,
RetrieverAction,
retrieverRef,
} from '@genkit-ai/ai/retriever';
import { genkitPlugin, PluginProvider } from '@genkit-ai/core';
Expand Down Expand Up @@ -120,7 +122,7 @@ export function chromaRetriever<
createCollectionIfMissing?: boolean;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
}): RetrieverAction<z.ZodOptional<typeof ChromaRetrieverOptionsSchema>> {
Copy link
Member Author

@MichaelDoyle MichaelDoyle Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seemed a little funky to me. Should we just make this ChromaRetrieverOptionsSchema?

It looks like we're possibly double-wrapping in an Optional? https://github.com/firebase/genkit/blob/main/js/ai/src/retriever.ts#L164-L168

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I first made these abstractions in js/ai, I wanted to make as few assumptions as possible. That meant that the options should not be ZodOptional since we do not know what kind of retriever they are building and what options are required vs optional.

IMO, the /ai abstractions should not add a .optional() on the schema and the providers can add them in the custom schema they define (ChromaRetrieverOptionsSchema in this case).

That being said, this code has through many changes and its evident that it no longer holds to that thought. The problem here does not seem to be double wrapping though -- there is no Optional on the ChromaRetrieverOptionsSchema object and it is added by RetrieverAction.

const { embedder, collectionName, embedderOptions } = params;
return defineRetriever(
{
Expand Down Expand Up @@ -191,7 +193,7 @@ export function chromaIndexer<
createCollectionIfMissing?: boolean;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
}): IndexerAction<typeof ChromaIndexerOptionsSchema> {
const { collectionName, embedder, embedderOptions } = {
...params,
};
Expand Down
6 changes: 4 additions & 2 deletions js/plugins/dev-local-vectorstore/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import {
defineRetriever,
Document,
DocumentData,
IndexerAction,
indexerRef,
RetrieverAction,
retrieverRef,
} from '@genkit-ai/ai/retriever';
import { genkitPlugin, PluginProvider } from '@genkit-ai/core';
Expand Down Expand Up @@ -173,7 +175,7 @@ export function configureDevLocalRetriever<
indexName: string;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
}): RetrieverAction<typeof CommonRetrieverOptionsSchema> {
const { embedder, embedderOptions } = params;
const vectorstore = defineRetriever(
{
Expand Down Expand Up @@ -209,7 +211,7 @@ export function configureDevLocalIndexer<
indexName: string;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
}): IndexerAction<z.ZodTypeAny> {
const { embedder, embedderOptions } = params;
const vectorstore = defineIndexer(
{ name: `devLocalVectorstore/${params.indexName}` },
Expand Down
8 changes: 6 additions & 2 deletions js/plugins/googleai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
* limitations under the License.
*/

import { defineEmbedder, embedderRef } from '@genkit-ai/ai/embedder';
import {
defineEmbedder,
EmbedderAction,
embedderRef,
} from '@genkit-ai/ai/embedder';
import { EmbedContentRequest, GoogleGenerativeAI } from '@google/generative-ai';
import { string, z } from 'zod';
import { PluginOptions } from './index.js';
Expand Down Expand Up @@ -60,7 +64,7 @@ export const SUPPORTED_MODELS = {
export function textEmbeddingGeckoEmbedder(
name: string,
options: PluginOptions
) {
): EmbedderAction<typeof TextEmbeddingGeckoConfigSchema> {
let apiKey =
options?.apiKey ||
process.env.GOOGLE_GENAI_API_KEY ||
Expand Down
13 changes: 7 additions & 6 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ export const geminiUltra = modelRef({

export const SUPPORTED_V1_MODELS: Record<
string,
ModelReference<z.ZodTypeAny>
ModelReference<typeof GeminiConfigSchema>
> = {
'gemini-pro': geminiPro,
'gemini-pro-vision': geminiProVision,
Expand All @@ -159,7 +159,7 @@ export const SUPPORTED_V1_MODELS: Record<

export const SUPPORTED_V15_MODELS: Record<
string,
ModelReference<z.ZodTypeAny>
ModelReference<typeof GeminiConfigSchema>
> = {
'gemini-1.5-pro-latest': gemini15Pro,
'gemini-1.5-flash-latest': gemini15Flash,
Expand All @@ -172,7 +172,7 @@ const SUPPORTED_MODELS = {

function toGeminiRole(
role: MessageData['role'],
model?: ModelReference<z.ZodTypeAny>
model?: ModelReference<typeof GeminiConfigSchema>
): string {
switch (role) {
case 'user':
Expand Down Expand Up @@ -331,7 +331,7 @@ function fromGeminiPart(part: GeminiPart): Part {

export function toGeminiMessage(
message: MessageData,
model?: ModelReference<z.ZodTypeAny>
model?: ModelReference<typeof GeminiConfigSchema>
): GeminiMessage {
return {
role: toGeminiRole(message.role, model),
Expand Down Expand Up @@ -387,7 +387,7 @@ export function googleAIModel(
apiKey?: string,
apiVersion?: string,
baseUrl?: string
): ModelAction {
): ModelAction<typeof GeminiConfigSchema> {
const modelName = `googleai/${name}`;

if (!apiKey) {
Expand All @@ -400,7 +400,8 @@ export function googleAIModel(
);
}

const model: ModelReference<z.ZodTypeAny> = SUPPORTED_MODELS[name];
const model: ModelReference<typeof GeminiConfigSchema> =
SUPPORTED_MODELS[name];
if (!model) throw new Error(`Unsupported model: ${name}`);

const middleware: ModelMiddleware[] = [];
Expand Down
3 changes: 2 additions & 1 deletion js/plugins/ollama/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import {
GenerationCommonConfigSchema,
getBasicUsageStats,
MessageData,
ModelAction,
} from '@genkit-ai/ai/model';
import { genkitPlugin, Plugin } from '@genkit-ai/core';
import { logger } from '@genkit-ai/core/logging';
Expand Down Expand Up @@ -63,7 +64,7 @@ function ollamaModel(
model: ModelDefinition,
serverAddress: string,
requestHeaders?: RequestHeaders
) {
): ModelAction<typeof GenerationCommonConfigSchema> {
return defineModel(
{
name: `ollama/${model.name}`,
Expand Down
6 changes: 4 additions & 2 deletions js/plugins/pinecone/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import {
defineIndexer,
defineRetriever,
Document,
IndexerAction,
indexerRef,
RetrieverAction,
retrieverRef,
} from '@genkit-ai/ai/retriever';
import { genkitPlugin, PluginProvider } from '@genkit-ai/core';
Expand Down Expand Up @@ -130,7 +132,7 @@ export function configurePineconeRetriever<
textKey?: string;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
}): RetrieverAction<typeof PineconeRetrieverOptionsSchema> {
const { indexId, embedder, embedderOptions } = {
...params,
};
Expand Down Expand Up @@ -185,7 +187,7 @@ export function configurePineconeIndexer<
textKey?: string;
embedder: EmbedderArgument<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}) {
}): IndexerAction<z.ZodOptional<typeof PineconeIndexerOptionsSchema>> {
Copy link
Member Author

@MichaelDoyle MichaelDoyle Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seemed a little funky to me. Should we just make this PineconeIndexerOptionsSchema?

It looks like we're possibly double-wrapping in an Optional? https://github.com/firebase/genkit/blob/main/js/ai/src/embedder.ts#L74-L78

const { indexId, embedder, embedderOptions } = {
...params,
};
Expand Down
3 changes: 2 additions & 1 deletion js/plugins/vertexai/src/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import {
GenerateResponseData,
GenerationCommonConfigSchema,
Part as GenkitPart,
ModelAction,
ModelReference,
defineModel,
getBasicUsageStats,
Expand Down Expand Up @@ -96,7 +97,7 @@ export function anthropicModel(
modelName: string,
projectId: string,
region: string
) {
): ModelAction<typeof GenerationCommonConfigSchema> {
const client = new AnthropicVertex({
region,
projectId,
Expand Down
8 changes: 6 additions & 2 deletions js/plugins/vertexai/src/embedder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import {
defineEmbedder,
EmbedderAction,
embedderRef,
EmbedderReference,
} from '@genkit-ai/ai/embedder';
Expand Down Expand Up @@ -119,7 +120,10 @@ export const textEmbeddingGeckoMultilingual001 = embedderRef({

export const textEmbeddingGecko = textEmbeddingGecko003;

export const SUPPORTED_EMBEDDER_MODELS: Record<string, EmbedderReference> = {
export const SUPPORTED_EMBEDDER_MODELS: Record<
string,
EmbedderReference<typeof TextEmbeddingGeckoConfigSchema>
> = {
'textembedding-gecko@003': textEmbeddingGecko003,
'textembedding-gecko@002': textEmbeddingGecko002,
'textembedding-gecko@001': textEmbeddingGecko001,
Expand Down Expand Up @@ -147,7 +151,7 @@ export function textEmbeddingGeckoEmbedder(
name: string,
client: GoogleAuth,
options: PluginOptions
) {
): EmbedderAction<typeof TextEmbeddingGeckoConfigSchema> {
const embedder = SUPPORTED_EMBEDDER_MODELS[name];
// TODO: Figure out how to allow different versions while still sharing a single implementation.
const predict = predictModel<EmbeddingInstance, EmbeddingPrediction>(
Expand Down
22 changes: 16 additions & 6 deletions js/plugins/vertexai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -149,13 +149,19 @@ export const gemini15Flash = modelRef({
configSchema: GeminiConfigSchema,
});

export const SUPPORTED_V1_MODELS = {
export const SUPPORTED_V1_MODELS: Record<
string,
ModelReference<typeof GeminiConfigSchema>
> = {
'gemini-1.0-pro': geminiPro,
'gemini-1.0-pro-vision': geminiProVision,
// 'gemini-ultra': geminiUltra,
};

export const SUPPORTED_V15_MODELS = {
export const SUPPORTED_V15_MODELS: Record<
string,
ModelReference<typeof GeminiConfigSchema>
> = {
'gemini-1.5-pro': gemini15Pro,
'gemini-1.5-flash': gemini15Flash,
'gemini-1.5-pro-preview': gemini15ProPreview,
Expand All @@ -169,7 +175,7 @@ export const SUPPORTED_GEMINI_MODELS = {

function toGeminiRole(
role: MessageData['role'],
model?: ModelReference<z.ZodTypeAny>
model?: ModelReference<typeof GeminiConfigSchema>
): string {
switch (role) {
case 'user':
Expand Down Expand Up @@ -271,7 +277,7 @@ export function toGeminiSystemInstruction(message: MessageData): Content {

export function toGeminiMessage(
message: MessageData,
model?: ModelReference<z.ZodTypeAny>
model?: ModelReference<typeof GeminiConfigSchema>
): Content {
return {
role: toGeminiRole(message.role, model),
Expand Down Expand Up @@ -441,10 +447,14 @@ const convertSchemaProperty = (property) => {
/**
*
*/
export function geminiModel(name: string, vertex: VertexAI): ModelAction {
export function geminiModel(
name: string,
vertex: VertexAI
): ModelAction<typeof GeminiConfigSchema> {
const modelName = `vertexai/${name}`;

const model: ModelReference<z.ZodTypeAny> = SUPPORTED_GEMINI_MODELS[name];
const model: ModelReference<typeof GeminiConfigSchema> =
SUPPORTED_GEMINI_MODELS[name];
if (!model) throw new Error(`Unsupported model: ${name}`);

const middlewares: ModelMiddleware[] = [];
Expand Down
7 changes: 5 additions & 2 deletions js/plugins/vertexai/src/imagen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
GenerateRequest,
GenerationCommonConfigSchema,
getBasicUsageStats,
ModelAction,
modelRef,
} from '@genkit-ai/ai/model';
import { GoogleAuth } from 'google-auth-library';
Expand All @@ -39,7 +40,6 @@ const ImagenConfigSchema = GenerationCommonConfigSchema.extend({
/** Any non-negative integer you provide to make output images deterministic. Providing the same seed number always results in the same output images. Accepted integer values: 1 - 2147483647. */
seed: z.number().optional(),
});
type ImagenConfig = z.infer<typeof ImagenConfigSchema>;

export const imagen2 = modelRef({
name: 'vertexai/imagen2',
Expand Down Expand Up @@ -109,7 +109,10 @@ interface ImagenInstance {
/**
*
*/
export function imagen2Model(client: GoogleAuth, options: PluginOptions) {
export function imagen2Model(
client: GoogleAuth,
options: PluginOptions
): ModelAction<typeof ImagenConfigSchema> {
const predict = predictModel<
ImagenInstance,
ImagenPrediction,
Expand Down
12 changes: 6 additions & 6 deletions js/plugins/vertexai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
* limitations under the License.
*/

import { ModelReference } from '@genkit-ai/ai/model';
import { genkitPlugin, Plugin } from '@genkit-ai/core';
import { ModelAction, ModelReference } from '@genkit-ai/ai/model';
import { Plugin, genkitPlugin } from '@genkit-ai/core';
import { VertexAI } from '@google-cloud/vertexai';
import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library';
import {
SUPPORTED_ANTHROPIC_MODELS,
anthropicModel,
claude3Haiku,
claude3Opus,
claude3Sonnet,
SUPPORTED_ANTHROPIC_MODELS,
} from './anthropic.js';
import {
SUPPORTED_EMBEDDER_MODELS,
Expand All @@ -42,18 +42,19 @@ import {
vertexEvaluators,
} from './evaluation.js';
import {
SUPPORTED_GEMINI_MODELS,
gemini15Flash,
gemini15FlashPreview,
gemini15Pro,
gemini15ProPreview,
geminiModel,
geminiPro,
geminiProVision,
SUPPORTED_GEMINI_MODELS,
} from './gemini.js';
import { imagen2, imagen2Model } from './imagen.js';

export {
VertexAIEvaluationMetricType as VertexAIEvaluationMetricType,
claude3Haiku,
claude3Opus,
claude3Sonnet,
Expand All @@ -71,7 +72,6 @@ export {
textEmbeddingGecko003,
textEmbeddingGeckoMultilingual001,
textMultilingualEmbedding002,
VertexAIEvaluationMetricType as VertexAIEvaluationMetricType,
};

export interface PluginOptions {
Expand Down Expand Up @@ -120,7 +120,7 @@ export const vertexAI: Plugin<[PluginOptions] | []> = genkitPlugin(
? options.evaluation.metrics
: [];

const models = [
const models: ModelAction<any>[] = [
imagen2Model(authClient, { projectId, location }),
...Object.keys(SUPPORTED_GEMINI_MODELS).map((name) =>
geminiModel(name, vertexClient)
Expand Down
Loading