From fc008a191f6e323398278ea0bf985e85822582e4 Mon Sep 17 00:00:00 2001 From: Christina Holland Date: Wed, 11 Sep 2024 12:37:27 -0700 Subject: [PATCH] Add ability to set modelParams on getGenerativeModelFromCachedContent() (#254) --- .changeset/tame-lizards-kiss.md | 5 + common/api-review/generative-ai.api.md | 2 +- ...eai.getgenerativemodelfromcachedcontent.md | 3 +- .../main/generative-ai.googlegenerativeai.md | 2 +- src/gen-ai.test.ts | 91 ++++++++++++++++++- src/gen-ai.ts | 34 +++++++ 6 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 .changeset/tame-lizards-kiss.md diff --git a/.changeset/tame-lizards-kiss.md b/.changeset/tame-lizards-kiss.md new file mode 100644 index 00000000..27c821f1 --- /dev/null +++ b/.changeset/tame-lizards-kiss.md @@ -0,0 +1,5 @@ +--- +"@google/generative-ai": minor +--- + +Add ability to set modelParams (generationConfig, safetySettings) on getGenerativeModelFromCachedContent(). diff --git a/common/api-review/generative-ai.api.md b/common/api-review/generative-ai.api.md index 448d096d..7bc7d3f4 100644 --- a/common/api-review/generative-ai.api.md +++ b/common/api-review/generative-ai.api.md @@ -481,7 +481,7 @@ export class GoogleGenerativeAI { // (undocumented) apiKey: string; getGenerativeModel(modelParams: ModelParams, requestOptions?: RequestOptions): GenerativeModel; - getGenerativeModelFromCachedContent(cachedContent: CachedContent, requestOptions?: RequestOptions): GenerativeModel; + getGenerativeModelFromCachedContent(cachedContent: CachedContent, modelParams?: Partial, requestOptions?: RequestOptions): GenerativeModel; } // @public diff --git a/docs/reference/main/generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md b/docs/reference/main/generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md index 1f24640f..354cd072 100644 --- a/docs/reference/main/generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md +++ b/docs/reference/main/generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md @@ -9,7 +9,7 @@ Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from pr **Signature:** ```typescript -getGenerativeModelFromCachedContent(cachedContent: CachedContent, requestOptions?: RequestOptions): GenerativeModel; +getGenerativeModelFromCachedContent(cachedContent: CachedContent, modelParams?: Partial, requestOptions?: RequestOptions): GenerativeModel; ``` ## Parameters @@ -17,6 +17,7 @@ getGenerativeModelFromCachedContent(cachedContent: CachedContent, requestOptions | Parameter | Type | Description | | --- | --- | --- | | cachedContent | [CachedContent](./generative-ai.cachedcontent.md) | | +| modelParams | Partial<[ModelParams](./generative-ai.modelparams.md)> | _(Optional)_ | | requestOptions | [RequestOptions](./generative-ai.requestoptions.md) | _(Optional)_ | **Returns:** diff --git a/docs/reference/main/generative-ai.googlegenerativeai.md b/docs/reference/main/generative-ai.googlegenerativeai.md index 0f3aa381..38bad256 100644 --- a/docs/reference/main/generative-ai.googlegenerativeai.md +++ b/docs/reference/main/generative-ai.googlegenerativeai.md @@ -29,5 +29,5 @@ export declare class GoogleGenerativeAI | Method | Modifiers | Description | | --- | --- | --- | | [getGenerativeModel(modelParams, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodel.md) | | Gets a [GenerativeModel](./generative-ai.generativemodel.md) instance for the provided model name. | -| [getGenerativeModelFromCachedContent(cachedContent, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md) | | Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from provided content cache. | +| [getGenerativeModelFromCachedContent(cachedContent, modelParams, requestOptions)](./generative-ai.googlegenerativeai.getgenerativemodelfromcachedcontent.md) | | Creates a [GenerativeModel](./generative-ai.generativemodel.md) instance from provided content cache. | diff --git a/src/gen-ai.test.ts b/src/gen-ai.test.ts index 9e8972e9..fd17aa4e 100644 --- a/src/gen-ai.test.ts +++ b/src/gen-ai.test.ts @@ -18,17 +18,104 @@ import { ModelParams } from "../types"; import { GenerativeModel, GoogleGenerativeAI } from "./gen-ai"; import { expect } from "chai"; +const fakeContents = [{ role: "user", parts: [{ text: "hello" }] }]; + +const fakeCachedContent = { + model: "my-model", + name: "mycachename", + contents: fakeContents, +}; + describe("GoogleGenerativeAI", () => { - it("genGenerativeInstance throws if no model is provided", () => { + it("getGenerativeModel throws if no model is provided", () => { const genAI = new GoogleGenerativeAI("apikey"); expect(() => genAI.getGenerativeModel({} as ModelParams)).to.throw( "Must provide a model name", ); }); - it("genGenerativeInstance gets a GenerativeModel", () => { + it("getGenerativeModel gets a GenerativeModel", () => { const genAI = new GoogleGenerativeAI("apikey"); const genModel = genAI.getGenerativeModel({ model: "my-model" }); expect(genModel).to.be.an.instanceOf(GenerativeModel); expect(genModel.model).to.equal("models/my-model"); }); + it("getGenerativeModelFromCachedContent gets a GenerativeModel", () => { + const genAI = new GoogleGenerativeAI("apikey"); + const genModel = + genAI.getGenerativeModelFromCachedContent(fakeCachedContent); + expect(genModel).to.be.an.instanceOf(GenerativeModel); + expect(genModel.model).to.equal("models/my-model"); + expect(genModel.cachedContent).to.eql(fakeCachedContent); + }); + it("getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams", () => { + const genAI = new GoogleGenerativeAI("apikey"); + const genModel = genAI.getGenerativeModelFromCachedContent( + fakeCachedContent, + { generationConfig: { temperature: 0 } }, + ); + expect(genModel).to.be.an.instanceOf(GenerativeModel); + expect(genModel.model).to.equal("models/my-model"); + expect(genModel.generationConfig.temperature).to.equal(0); + expect(genModel.cachedContent).to.eql(fakeCachedContent); + }); + it("getGenerativeModelFromCachedContent gets a GenerativeModel merged with modelParams with overlapping keys", () => { + const genAI = new GoogleGenerativeAI("apikey"); + const genModel = genAI.getGenerativeModelFromCachedContent( + fakeCachedContent, + { model: "my-model", generationConfig: { temperature: 0 } }, + ); + expect(genModel).to.be.an.instanceOf(GenerativeModel); + expect(genModel.model).to.equal("models/my-model"); + expect(genModel.generationConfig.temperature).to.equal(0); + expect(genModel.cachedContent).to.eql(fakeCachedContent); + }); + it("getGenerativeModelFromCachedContent throws if no name", () => { + const genAI = new GoogleGenerativeAI("apikey"); + expect(() => + genAI.getGenerativeModelFromCachedContent({ + model: "my-model", + contents: fakeContents, + }), + ).to.throw("Cached content must contain a `name` field."); + }); + it("getGenerativeModelFromCachedContent throws if no model", () => { + const genAI = new GoogleGenerativeAI("apikey"); + expect(() => + genAI.getGenerativeModelFromCachedContent({ + name: "cachename", + contents: fakeContents, + }), + ).to.throw("Cached content must contain a `model` field."); + }); + it("getGenerativeModelFromCachedContent throws if mismatched model", () => { + const genAI = new GoogleGenerativeAI("apikey"); + expect(() => + genAI.getGenerativeModelFromCachedContent( + { + name: "cachename", + model: "my-model", + contents: fakeContents, + }, + { model: "your-model" }, + ), + ).to.throw( + `Different value for "model" specified in modelParams (your-model) and cachedContent (my-model)`, + ); + }); + it("getGenerativeModelFromCachedContent throws if mismatched systemInstruction", () => { + const genAI = new GoogleGenerativeAI("apikey"); + expect(() => + genAI.getGenerativeModelFromCachedContent( + { + name: "cachename", + model: "my-model", + contents: fakeContents, + systemInstruction: "hi", + }, + { model: "models/my-model", systemInstruction: "yo" }, + ), + ).to.throw( + `Different value for "systemInstruction" specified in modelParams (yo) and cachedContent (hi)`, + ); + }); }); diff --git a/src/gen-ai.ts b/src/gen-ai.ts index 2cc34703..f65f489d 100644 --- a/src/gen-ai.ts +++ b/src/gen-ai.ts @@ -53,6 +53,7 @@ export class GoogleGenerativeAI { */ getGenerativeModelFromCachedContent( cachedContent: CachedContent, + modelParams?: Partial, requestOptions?: RequestOptions, ): GenerativeModel { if (!cachedContent.name) { @@ -65,7 +66,40 @@ export class GoogleGenerativeAI { "Cached content must contain a `model` field.", ); } + + /** + * Not checking tools and toolConfig for now as it would require a deep + * equality comparison and isn't likely to be a common case. + */ + const disallowedDuplicates: Array = + ["model", "systemInstruction"]; + + for (const key of disallowedDuplicates) { + if ( + modelParams?.[key] && + cachedContent[key] && + modelParams?.[key] !== cachedContent[key] + ) { + if (key === "model") { + const modelParamsComp = modelParams.model.startsWith("models/") + ? modelParams.model.replace("models/", "") + : modelParams.model; + const cachedContentComp = cachedContent.model.startsWith("models/") + ? cachedContent.model.replace("models/", "") + : cachedContent.model; + if (modelParamsComp === cachedContentComp) { + continue; + } + } + throw new GoogleGenerativeAIRequestInputError( + `Different value for "${key}" specified in modelParams` + + ` (${modelParams[key]}) and cachedContent (${cachedContent[key]})`, + ); + } + } + const modelParamsFromCache: ModelParams = { + ...modelParams, model: cachedContent.model, tools: cachedContent.tools, toolConfig: cachedContent.toolConfig,