From a585389f2226c985b72f1ab8b736a1ea26970e37 Mon Sep 17 00:00:00 2001
From: Leon Ruggiero
Date: Tue, 5 Nov 2024 15:51:39 +0000
Subject: [PATCH] Improves completion api with new input parameters for LLMs (#23)

* #17 Extend LLM input options
* Add new options to test script, enable togetherai testing and small changes for clarity.
* Remove cohere command (non-r) from tests. Change default test_options to be more conservative to prevent long repetition and have better compatibility for testing.
---
 drivers/src/bedrock/index.ts          | 35 ++++++++++++++++++---
 drivers/src/bedrock/payloads.ts       | 45 +++++++++++++++++++++------
 drivers/src/groq/index.ts             | 12 ++++++-
 drivers/src/mistral/index.ts          |  7 ++++-
 drivers/src/mistral/types.ts          |  3 +-
 drivers/src/openai/index.ts           | 12 ++++++-
 drivers/src/togetherai/index.ts       | 23 ++++++++++++--
 drivers/src/vertexai/models/gemini.ts |  5 +++
 drivers/test/all-models.test.ts       | 40 +++++++++++++++---------
 9 files changed, 146 insertions(+), 36 deletions(-)

diff --git a/drivers/src/bedrock/index.ts b/drivers/src/bedrock/index.ts
index bead67d..82d7fe6 100644
--- a/drivers/src/bedrock/index.ts
+++ b/drivers/src/bedrock/index.ts
@@ -6,7 +6,7 @@
 import { transformAsyncIterator } from "@llumiverse/core/async";
 import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
 import { AwsCredentialIdentity, Provider } from "@smithy/types";
 import mnemonist from "mnemonist";
-import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
+import { AI21JurassicRequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama3RequestPayload, MistralPayload } from "./payloads.js";
 import { forceUploadFile } from "./s3.js";
 const { LRUCache } = mnemonist;
@@ -347,7 +347,8 @@ export class BedrockDriver extends AbstractDriver {
@@ -365,35 +366,55 @@ export class BedrockDriver extends AbstractDriver
diff --git a/drivers/src/mistral/index.ts b/drivers/src/mistral/index.ts
index 640394b..e4d6fa6 100644
--- a/drivers/src/mistral/index.ts
+++ b/drivers/src/mistral/index.ts
@@ -94,8 +94,11 @@ export class MistralAIDriver extends AbstractDriver {
diff --git a/drivers/src/togetherai/index.ts b/drivers/src/togetherai/index.ts
--- a/drivers/src/togetherai/index.ts
+++ b/drivers/src/togetherai/index.ts
+
+        const stop_seq = typeof options.stop_sequence == 'string' ?
+            [options.stop_sequence] : options.stop_sequence ?? [];
+
         const res = await this.fetchClient.post('/v1/completions', {
             payload: {
                 model: options.model,
@@ -37,9 +41,15 @@ export class TogetherAIDriver extends AbstractDriver
                 stop: [
                     "</s>",
-                    "[/INST]"
+                    "[/INST]",
+                    ...stop_seq,
                 ],
             }
         }) as TextCompletion;
@@ -59,6 +69,9 @@ export class TogetherAIDriver extends AbstractDriver {
+        const stop_seq = typeof options.stop_sequence == 'string' ?
+            [options.stop_sequence] : options.stop_sequence ?? [];
+
         const stream = await this.fetchClient.post('/v1/completions', {
             payload: {
                 model: options.model,
@@ -66,10 +79,16 @@ export class TogetherAIDriver extends AbstractDriver
                 stop: [
                     "</s>",
-                    "[/INST]"
+                    "[/INST]",
+                    ...stop_seq,
                 ],
             },
             reader: 'sse'
diff --git a/drivers/src/vertexai/models/gemini.ts b/drivers/src/vertexai/models/gemini.ts
index c9f5b1f..09255d6 100644
--- a/drivers/src/vertexai/models/gemini.ts
+++ b/drivers/src/vertexai/models/gemini.ts
@@ -39,6 +39,11 @@ function getGenerativeModel(driver: VertexAIDriver, options: ExecutionOptions, m
             candidateCount: modelParams?.generationConfig?.candidateCount ?? 1,
             temperature: options.temperature,
             maxOutputTokens: options.max_tokens,
+            topP: options.top_p,
+            topK: options.top_k,
+            frequencyPenalty: options.frequency_penalty,
+            stopSequences: typeof options.stop_sequence === 'string' ?
+                [options.stop_sequence] : options.stop_sequence
         },
     });
diff --git a/drivers/test/all-models.test.ts b/drivers/test/all-models.test.ts
index 820bdcb..0cd6691 100644
--- a/drivers/test/all-models.test.ts
+++ b/drivers/test/all-models.test.ts
@@ -1,4 +1,4 @@
-import { AIModel, AbstractDriver } from '@llumiverse/core';
+import { AIModel, AbstractDriver, ExecutionOptions } from '@llumiverse/core';
 import 'dotenv/config';
 import { GoogleAuth } from 'google-auth-library';
 import { describe, expect, test } from "vitest";
@@ -60,8 +60,8 @@ if (process.env.TOGETHER_API_KEY) {
             apiKey: process.env.TOGETHER_API_KEY as string
         }),
         models: [
-            //"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-            //"mistralai/Mixtral-8x7B-Instruct-v0.1" too slow in tests for now
+            "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+            //"mistralai/Mixtral-8x7B-Instruct-v0.1" //too slow in tests for now
         ]
     }
 )
@@ -115,7 +115,7 @@ if (process.env.BEDROCK_REGION) {
     models: [
         "anthropic.claude-3-sonnet-20240229-v1:0",
         "anthropic.claude-v2:1",
-        "cohere.command-text-v14",
+        //"cohere.command-text-v14", EOL
         "mistral.mixtral-8x7b-instruct-v0:1",
         "cohere.command-r-plus-v1:0",
         "meta.llama3-1-70b-instruct-v1:0"
@@ -165,11 +165,23 @@ if (process.env.WATSONX_API_KEY) {
 
 describe.concurrent.each(drivers)("Driver $name", ({ name, driver, models }) => {
 
-    let fetchedModels: AIModel[]
+    let fetchedModels: AIModel[];
+
+    let test_options: ExecutionOptions = {
+        model: "",
+        max_tokens: 128,
+        temperature: 0.3,
+        top_k: 40,
+        top_p: 0.7, //Some models do not support top_p = 1.0, set to 0.99 or lower.
+        top_logprobs: 5, //Currently not supported, option will be ignored
+        presence_penalty: 0.1, //Cohere Command R does not support using presence & frequency penalty at the same time
+        frequency_penalty: 0.0,
+    };
 
     test(`${name}: list models`, { timeout: TIMEOUT, retry: 1 }, async () => {
         const r = await driver.listModels();
         fetchedModels = r;
+        console.log(r)
         expect(r.length).toBeGreaterThan(0);
     });
 
@@ -191,27 +203,27 @@ describe.concurrent.each(drivers)("Driver $name", ({ name, driver, models }) =>
     });
 
     test.each(models)(`${name}: execute prompt on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => {
-        const r = await driver.execute(testPrompt_color, { model, temperature: 0.5, max_tokens: 1024 });
-        console.debug("Result for " + model, JSON.stringify(r));
+        const r = await driver.execute(testPrompt_color, {...test_options, model: model} as ExecutionOptions);
+        console.log("Result for execute " + model, JSON.stringify(r));
         assertCompletionOk(r, model, driver);
     });
 
     test.each(models)(`${name}: execute prompt with streaming on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => {
-        const r = await driver.stream(testPrompt_color, { model, temperature: 0.5, max_tokens: 1024 })
+        const r = await driver.stream(testPrompt_color, {...test_options, model: model} as ExecutionOptions);
         const out = await assertStreamingCompletionOk(r);
-        console.log("Result for " + model, JSON.stringify(out));
+        console.log("Result for streaming " + model, JSON.stringify(out));
     });
 
     test.each(models)(`${name}: execute prompt with schema on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => {
-        const r = await driver.execute(testPrompt_color, { model, temperature: 0.5, max_tokens: 1024, result_schema: testSchema_color });
-        console.log("Result for " + model, JSON.stringify(r.result));
+        const r = await driver.execute(testPrompt_color, {...test_options, model: model, result_schema: testSchema_color } as ExecutionOptions);
+        console.log("Result for execute with schema " + model, JSON.stringify(r.result));
         assertCompletionOk(r, model, driver);
     });
 
     test.each(models)(`${name}: execute prompt with streaming and schema on %s`, { timeout: TIMEOUT, retry: 3 }, async (model) => {
-        const r = await driver.stream(testPrompt_color, { model, temperature: 0.5, max_tokens: 1024, result_schema: testSchema_color })
+        const r = await driver.stream(testPrompt_color, {...test_options, model: model, result_schema: testSchema_color } as ExecutionOptions);
         const out = await assertStreamingCompletionOk(r, true);
-        console.log("Result for prompt with streaming and schema" + model, JSON.stringify(out));
+        console.log("Result for streaming with schema " + model, JSON.stringify(out));
     });
 
@@ -235,6 +247,4 @@ describe.concurrent.each(drivers)("Driver $name", ({ name, driver, models }) =>
         console.log("Result", r)
         assertCompletionOk(r);
     });
-
-
 });
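
For reviewers, a minimal sketch of how the extended options are meant to be consumed, mirroring the test setup above. The @llumiverse/drivers import path, the TogetherAIDriver constructor shape, the prompt-segment shape, and the "###" stop marker are assumptions for illustration; the option names themselves are the ones this patch threads through the drivers.

// Sketch only: package entry point and prompt shape are assumed, not verified.
import { ExecutionOptions } from "@llumiverse/core";
import { TogetherAIDriver } from "@llumiverse/drivers";

async function main() {
    // Constructor shape mirrors the test setup in all-models.test.ts.
    const driver = new TogetherAIDriver({
        apiKey: process.env.TOGETHER_API_KEY as string,
    });

    // Options threaded through the drivers by this patch. stop_sequence may be
    // a string or a string[]; each driver normalizes it (see stop_seq above).
    const options = {
        model: "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        max_tokens: 128,
        temperature: 0.3,
        top_k: 40,
        top_p: 0.7,             // some providers reject top_p = 1.0
        presence_penalty: 0.1,
        frequency_penalty: 0.0,
        stop_sequence: "###",   // hypothetical stop marker, for illustration
    } as ExecutionOptions;

    // Prompt shape simplified; the tests pass a testPrompt_color fixture.
    const prompt = [{ role: "user", content: "What color is the sky?" }] as any;

    const r = await driver.execute(prompt, options);
    console.log(r.result);
}

main().catch(console.error);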