Skip to content

Commit

Permalink
Support Command R on Bedrock (#12)
Browse files Browse the repository at this point in the history
* feat: Support for Cohere Command R on Bedrock (fixes #11)
  • Loading branch information
ebarroca authored May 20, 2024
1 parent 6620438 commit 533b7e5
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 61 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/node.js.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
# This workflow will do a clean installation of node dependencies, cache/restore them, build the source code and run tests across different versions of node
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-nodejs

name: Build+Test

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
workflow_dispatch:

permissions:
id-token: write
Expand Down
73 changes: 19 additions & 54 deletions drivers/src/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { transformAsyncIterator } from "@llumiverse/core/async";
import { ClaudeMessagesPrompt, formatClaudePrompt } from "@llumiverse/core/formatters";
import { AwsCredentialIdentity, Provider } from "@smithy/types";
import mnemonist from "mnemonist";
import { AI21RequestPayload, AmazonRequestPayload, ClaudeRequestPayload, CohereCommandRPayload, CohereRequestPayload, LLama2RequestPayload, MistralPayload } from "./payloads.js";
import { forceUploadFile } from "./s3.js";

const { LRUCache } = mnemonist;
Expand Down Expand Up @@ -100,13 +101,16 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
// LLAMA2
return [result.generation, result.stop_reason]; // comes in correct format (stop, length)
} else if (result.generations) {
// COHERE
// Cohere
return [result.generations[0].text, cohereFinishReason(result.generations[0].finish_reason)];
} else if (result.chat_history) {
//Cohere Command R
return [result.text, cohereFinishReason(result.finish_reason)];
} else if (result.completions) {
// AI21
return [result.completions[0].data?.text, a21FinishReason(result.completions[0].finishReason?.reason)];
} else if (result.content) {
// anthropic claude
// Claude
//if last prompt.messages is {, add { to the response
const p = prompt as ClaudeMessagesPrompt;
const lastMessage = (p as ClaudeMessagesPrompt).messages[p.messages.length - 1];
Expand Down Expand Up @@ -174,7 +178,7 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
async requestCompletionStream(prompt: string, options: ExecutionOptions): Promise<AsyncIterable<string>> {
const payload = this.preparePayload(prompt, options);
const executor = this.getExecutor();
console.log("Requesting completion stream", JSON.stringify(payload));
console.log("Requesting completion with Streaming for model " + options.model, JSON.stringify(payload));
return executor.invokeModelWithResponseStream({
modelId: options.model,
contentType: "application/json",
Expand All @@ -188,10 +192,13 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP

return transformAsyncIterator(res.body, (stream: ResponseStream) => {
const segment = JSON.parse(decoder.decode(stream.chunk?.bytes));
//console.log("Debug Segment for model " + options.model, JSON.stringify(segment));
if (segment.delta) { // presumably Anthropic Claude messages-API stream delta — TODO confirm
return segment.delta.text || '';
} else if (segment.completion) { // presumably Anthropic Claude legacy text-completion stream — TODO confirm
return segment.completion;
} else if (segment.text) { //cohere
return segment.text;
} else if (segment.completions) {
return segment.completions[0].data?.text;
} else if (segment.generation) {
Expand Down Expand Up @@ -247,12 +254,19 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP
temperature: options.temperature,
maxTokens: options.max_tokens,
} as AI21RequestPayload;
} else if (contains(options.model, "cohere")) {
} else if (contains(options.model, "command-r-plus")) {
return {
message: prompt as string,
max_tokens: options.max_tokens,
temperature: options.temperature,
} as CohereCommandRPayload;

}
else if (contains(options.model, "cohere")) {
return {
prompt: prompt,
temperature: options.temperature,
max_tokens: options.max_tokens,
p: 0.9,
} as CohereRequestPayload;
} else if (contains(options.model, "amazon")) {
return {
Expand Down Expand Up @@ -459,55 +473,6 @@ export class BedrockDriver extends AbstractDriver<BedrockDriverOptions, BedrockP



/** Request payload for Meta Llama 2 text-generation models on Bedrock. */
interface LLama2RequestPayload {
    prompt: string;
    temperature: number;
    top_p?: number;
    /** Maximum number of tokens to generate. */
    max_gen_len: number;
}

/**
 * Request payload for Anthropic Claude models on Bedrock (messages API).
 * Extends ClaudeMessagesPrompt with Bedrock/Anthropic sampling controls.
 */
interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
    anthropic_version: "bedrock-2023-05-31";
    max_tokens: number;
    prompt: string;
    temperature?: number;
    top_p?: number;
    top_k?: number;
    /** Sequences that stop generation when produced (list, not a 1-tuple). */
    stop_sequences?: string[];
}

/** Request payload for AI21 Jurassic-2 models (camelCase per the AI21 API). */
interface AI21RequestPayload {
    prompt: string;
    temperature: number;
    maxTokens: number;
}

/** Request payload for Cohere Command (non-Command-R) text models. */
interface CohereRequestPayload {
    prompt: string;
    temperature: number;
    max_tokens?: number;
    /** Nucleus sampling probability mass (top-p). */
    p?: number;
}

/** Request payload for Amazon Titan text models. */
interface AmazonRequestPayload {
    inputText: string;
    textGenerationConfig: {
        temperature: number;
        topP: number;
        maxTokenCount: number;
        /** Sequences that stop generation when produced (list, not a 1-tuple). */
        stopSequences: string[];
    };
}

/** Request payload for Mistral instruct models on Bedrock. */
interface MistralPayload {
    prompt: string;
    temperature: number;
    max_tokens: number;
    top_p?: number;
    top_k?: number;
}


function jobInfo(job: GetModelCustomizationJobCommandOutput, jobId: string): TrainingJob {
const jobStatus = job.status;
let status = TrainingJobStatus.running;
Expand Down
67 changes: 67 additions & 0 deletions drivers/src/bedrock/payloads.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { ClaudeMessagesPrompt } from "@llumiverse/core/formatters";

/** Request payload for Meta Llama 2 text-generation models on Bedrock. */
export interface LLama2RequestPayload {
    prompt: string;
    temperature: number;
    top_p?: number;
    /** Maximum number of tokens to generate. */
    max_gen_len: number;
}
/**
 * Request payload for Anthropic Claude models on Bedrock (messages API).
 * Extends ClaudeMessagesPrompt with Bedrock/Anthropic sampling controls.
 */
export interface ClaudeRequestPayload extends ClaudeMessagesPrompt {
    anthropic_version: "bedrock-2023-05-31";
    max_tokens: number;
    prompt: string;
    temperature?: number;
    top_p?: number;
    top_k?: number;
    /**
     * Sequences that stop generation when produced. Was typed `[string]`
     * (a one-element tuple) — the Anthropic API accepts a list, so
     * `string[]` is the intended, backward-compatible type.
     */
    stop_sequences?: string[];
}
/** Request payload for AI21 Jurassic-2 models (camelCase fields per the AI21 API). */
export interface AI21RequestPayload {
    prompt: string;
    temperature: number;
    maxTokens: number;
}
/** Request payload for Cohere Command (non-Command-R) text models on Bedrock. */
export interface CohereRequestPayload {
    prompt: string;
    temperature: number;
    max_tokens?: number;
    /** Nucleus sampling probability mass (top-p). */
    p?: number;
}
/** Request payload for Amazon Titan text models on Bedrock. */
export interface AmazonRequestPayload {
    inputText: string;
    textGenerationConfig: {
        temperature: number;
        /** Nucleus sampling probability mass. */
        topP: number;
        maxTokenCount: number;
        /**
         * Sequences that stop generation. Was typed `[string]` (a one-element
         * tuple); the Titan API accepts a list, so `string[]` is the intended,
         * backward-compatible type.
         */
        stopSequences: string[];
    };
}
/** Request payload for Mistral instruct models on Bedrock. */
export interface MistralPayload {
    prompt: string;
    temperature: number;
    max_tokens: number;
    top_p?: number;
    top_k?: number;
}

/**
 * Request payload for Cohere Command R / Command R+ chat models on Bedrock.
 * Member separators, indentation, and spacing normalized to match the other
 * payload interfaces in this file.
 */
export interface CohereCommandRPayload {
    /** The user message the model should respond to. */
    message: string;
    /** Prior conversation turns, oldest first. */
    chat_history?: {
        role: 'USER' | 'CHATBOT';
        message: string;
    }[];
    /** Grounding documents the model may draw on. */
    documents?: { title: string; snippet: string }[];
    /** When true, the model only generates search queries instead of answering. */
    search_queries_only?: boolean;
    /** System-style preamble prepended to the conversation. */
    preamble?: string;
    max_tokens: number;
    temperature?: number;
    /** Nucleus sampling probability mass (top-p). */
    p?: number;
    /** Top-k sampling cutoff. */
    k?: number;
    prompt_truncation?: string;
    frequency_penalty?: number;
    presence_penalty?: number;
    /** Seed for deterministic sampling, where supported. */
    seed?: number;
    /** When true, the response echoes the full prompt sent to the model. */
    return_prompt?: boolean;
    stop_sequences?: string[];
    raw_prompting?: boolean;
}
5 changes: 3 additions & 2 deletions drivers/test/all-models.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ if (process.env.BEDROCK_REGION) {
"cohere.command-text-v14",
"ai21.j2-mid-v1",
"mistral.mixtral-8x7b-instruct-v0:1",
"cohere.command-r-plus-v1:0"
]
}
)
Expand Down Expand Up @@ -155,14 +156,14 @@ describe.concurrent.each(drivers)("Driver $name", ({ name, driver, models }) =>

test.each(models)(`${name}: execute prompt with streaming on %s`, async (model) => {
const r = await driver.stream(testPrompt_color, { model, temperature: 0.8, max_tokens: 1024 })
//console.log(JSON.stringify(r));
//console.log("Result for " + model, JSON.stringify(r));
await assertStreamingCompletionOk(r);
}, TIMEOUT);

test.each(models)(`${name}: execute prompt with schema on %s`, async (model) => {
console.log("Executing with schema", testPrompt_color)
const r = await driver.execute(testPrompt_color, { model, temperature: 0.8, max_tokens: 1024, resultSchema: testSchema_color });
//console.log(JSON.stringify(r));
//console.log("Result for " + model, JSON.stringify(r));
assertCompletionOk(r);
}, TIMEOUT);

Expand Down

0 comments on commit 533b7e5

Please sign in to comment.