
Bedrock mistral #268

Merged · 9 commits · Mar 29, 2024
src/providers/bedrock/chatComplete.ts (159 additions, 0 deletions)

@@ -17,6 +17,8 @@ import {
BedrockLlamaStreamChunk,
BedrockTitanCompleteResponse,
BedrockTitanStreamChunk,
BedrockMistralCompleteResponse,
BedrockMistralStreamChunk,
} from './complete';
import { BedrockErrorResponse } from './embed';

@@ -268,6 +270,57 @@ export const BedrockLLamaChatCompleteConfig: ProviderConfig = {
},
};

export const BedrockMistralChatCompleteConfig: ProviderConfig = {
messages: {
param: 'prompt',
required: true,
transform: (params: Params) => {
let prompt: string = '';
if (params.messages) {
let messages: Message[] = params.messages;
messages.forEach((msg, index) => {
if (index === 0 && msg.role === 'system') {
prompt += `system: ${msg.content}\n`;
} else if (msg.role === 'user') {
prompt += `user: ${msg.content}\n`;
} else if (msg.role === 'assistant') {
prompt += `assistant: ${msg.content}\n`;
} else {
prompt += `${msg.role}: ${msg.content}\n`;
}
});
prompt += 'assistant:';
}
return prompt;
},
},
max_tokens: {
param: 'max_tokens',
default: 20,
min: 1,
},
temperature: {
param: 'temperature',
default: 0.75,
min: 0,
max: 5,
},
top_p: {
param: 'top_p',
default: 0.75,
min: 0,
max: 1,
},
top_k: {
param: 'top_k',
default: 0,
max: 200,
},
stop: {
param: 'stop',
},
};
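
As a quick sanity check on the transform above, here is a minimal, self-contained sketch (not part of the PR; the Message type is simplified to role/content, and the user/assistant branches are collapsed since they produce the same output as the generic branch) showing the prompt string a short conversation serializes to:

type SimpleMessage = { role: string; content: string };

// Mirrors the messages-to-prompt logic in BedrockMistralChatCompleteConfig.
const toMistralPrompt = (messages: SimpleMessage[]): string => {
  let prompt = '';
  messages.forEach((msg, index) => {
    if (index === 0 && msg.role === 'system') {
      prompt += `system: ${msg.content}\n`;
    } else {
      prompt += `${msg.role}: ${msg.content}\n`;
    }
  });
  return prompt + 'assistant:';
};

// toMistralPrompt([
//   { role: 'system', content: 'Be terse.' },
//   { role: 'user', content: 'Hello!' },
// ])
// => 'system: Be terse.\nuser: Hello!\nassistant:'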

const transformTitanGenerationConfig = (params: Params) => {
const generationConfig: Record<string, any> = {};
if (params['temperature']) {
@@ -908,3 +961,109 @@ export const BedrockCohereChatCompleteStreamChunkTransform: (
],
})}\n\n`;
};

export const BedrockMistralChatCompleteResponseTransform: (
response: BedrockMistralCompleteResponse | BedrockErrorResponse,
responseStatus: number,
responseHeaders: Headers
) => ChatCompletionResponse | ErrorResponse = (
response,
responseStatus,
responseHeaders
) => {
if (responseStatus !== 200) {
const errorResponse = BedrockErrorResponseTransform(
response as BedrockErrorResponse
);
if (errorResponse) return errorResponse;
}

if ('outputs' in response) {
const prompt_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0;
const completion_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Output-Token-Count')) || 0;
return {
id: Date.now().toString(),
object: 'chat.completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
index: 0,
message: {
role: 'assistant',
content: response.outputs[0].text,
},
finish_reason: response.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens: prompt_tokens,
completion_tokens: completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
};
}

return generateInvalidProviderResponseError(response, BEDROCK);
};

export const BedrockMistralChatCompleteStreamChunkTransform: (
response: string,
fallbackId: string
) => string | string[] = (responseChunk, fallbackId) => {
let chunk = responseChunk.trim();
chunk = chunk.replace(/^data: /, '');
chunk = chunk.trim();
const parsedChunk: BedrockMistralStreamChunk = JSON.parse(chunk);

// the final Mistral chunk carries a stop_reason; emit it with usage metrics and close the stream
if (parsedChunk.outputs[0].stop_reason) {
return [
`data: ${JSON.stringify({
id: fallbackId,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
index: 0,
delta: {},
finish_reason: parsedChunk.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount,
completion_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
total_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount +
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
},
})}\n\n`,
`data: [DONE]\n\n`,
];
}

return `data: ${JSON.stringify({
id: fallbackId,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
index: 0,
delta: {
role: 'assistant',
content: parsedChunk.outputs[0].text,
},
finish_reason: null,
},
],
})}\n\n`;
};
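
For context, the final event of a Bedrock Mistral stream carries both a stop_reason and the invocation metrics, and the transform above fans it out into a usage chunk plus the [DONE] sentinel. A rough illustration (field values are made up; intermediate events typically carry text with a null stop_reason and no metrics):

const finalEvent = JSON.stringify({
  outputs: [{ text: '', stop_reason: 'stop' }],
  'amazon-bedrock-invocationMetrics': {
    inputTokenCount: 12,
    outputTokenCount: 34,
    invocationLatency: 520,
    firstByteLatency: 80,
  },
});

// BedrockMistralChatCompleteStreamChunkTransform(finalEvent, 'portkey-123')
// returns two SSE strings: a chat.completion.chunk with finish_reason 'stop'
// and usage { prompt_tokens: 12, completion_tokens: 34, total_tokens: 46 },
// followed by 'data: [DONE]\n\n'.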
src/providers/bedrock/complete.ts (152 additions, 0 deletions)

@@ -125,6 +125,37 @@ export const BedrockLLamaCompleteConfig: ProviderConfig = {
},
};

export const BedrockMistralCompleteConfig: ProviderConfig = {
prompt: {
param: 'prompt',
required: true,
},
max_tokens: {
param: 'max_tokens',
default: 20,
min: 1,
},
temperature: {
param: 'temperature',
default: 0.75,
min: 0,
max: 5,
},
top_p: {
param: 'top_p',
default: 0.75,
min: 0,
max: 1,
},
top_k: {
param: 'top_k',
default: 0,
},
stop: {
param: 'stop',
},
};

const transformTitanGenerationConfig = (params: Params) => {
const generationConfig: Record<string, any> = {};
if (params['temperature']) {
@@ -754,3 +785,124 @@ export const BedrockCohereCompleteStreamChunkTransform: (
],
})}\n\n`;
};

export interface BedrockMistralStreamChunk {
outputs: {
text: string;
stop_reason: string | null;
}[];
'amazon-bedrock-invocationMetrics': {
inputTokenCount: number;
outputTokenCount: number;
invocationLatency: number;
firstByteLatency: number;
};
}

export const BedrockMistralCompleteStreamChunkTransform: (
response: string,
fallbackId: string
) => string | string[] = (responseChunk, fallbackId) => {
const chunk = responseChunk.trim();
const parsedChunk: BedrockMistralStreamChunk = JSON.parse(chunk);

if (parsedChunk.outputs[0].stop_reason) {
return [
`data: ${JSON.stringify({
id: fallbackId,
object: 'text_completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
text: parsedChunk.outputs[0].text,
index: 0,
logprobs: null,
finish_reason: parsedChunk.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount,
completion_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
total_tokens:
parsedChunk['amazon-bedrock-invocationMetrics'].inputTokenCount +
parsedChunk['amazon-bedrock-invocationMetrics'].outputTokenCount,
},
})}\n\n`,
`data: [DONE]\n\n`,
];
}

return `data: ${JSON.stringify({
id: fallbackId,
object: 'text_completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
text: parsedChunk.outputs[0].text,
index: 0,
logprobs: null,
finish_reason: null,
},
],
})}\n\n`;
};

export interface BedrockMistralCompleteResponse {
outputs: {
text: string;
stop_reason: string;
}[];
}
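
For reference, a non-streaming Bedrock Mistral body matching this interface looks roughly like the following (illustrative values); note that token usage is not in the body and is instead read from the X-Amzn-Bedrock-*-Token-Count response headers by the transform below:

const sampleBody: BedrockMistralCompleteResponse = {
  outputs: [
    { text: 'The capital of France is Paris.', stop_reason: 'stop' },
  ],
};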

export const BedrockMistralCompleteResponseTransform: (
response: BedrockMistralCompleteResponse | BedrockErrorResponse,
responseStatus: number,
responseHeaders: Headers
) => CompletionResponse | ErrorResponse = (
response,
responseStatus,
responseHeaders
) => {
if (responseStatus !== 200) {
const errorResponse = BedrockErrorResponseTransform(
response as BedrockErrorResponse
);
if (errorResponse) return errorResponse;
}

if ('outputs' in response) {
const prompt_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Input-Token-Count')) || 0;
const completion_tokens =
Number(responseHeaders.get('X-Amzn-Bedrock-Output-Token-Count')) || 0;
return {
id: Date.now().toString(),
object: 'text_completion',
created: Math.floor(Date.now() / 1000),
model: '',
provider: BEDROCK,
choices: [
{
text: response.outputs[0].text,
index: 0,
logprobs: null,
finish_reason: response.outputs[0].stop_reason,
},
],
usage: {
prompt_tokens: prompt_tokens,
completion_tokens: completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
};
}

return generateInvalidProviderResponseError(response, BEDROCK);
};
src/providers/bedrock/index.ts (19 additions, 0 deletions)

@@ -17,6 +17,9 @@ import {
BedrockTitanChatCompleteResponseTransform,
BedrockTitanChatCompleteStreamChunkTransform,
BedrockTitanChatompleteConfig,
BedrockMistralChatCompleteConfig,
BedrockMistralChatCompleteResponseTransform,
BedrockMistralChatCompleteStreamChunkTransform,
} from './chatComplete';
import {
BedrockAI21CompleteConfig,
@@ -30,6 +33,9 @@ import {
BedrockLLamaCompleteConfig,
BedrockLlamaCompleteResponseTransform,
BedrockLlamaCompleteStreamChunkTransform,
BedrockMistralCompleteConfig,
BedrockMistralCompleteResponseTransform,
BedrockMistralCompleteStreamChunkTransform,
BedrockTitanCompleteConfig,
BedrockTitanCompleteResponseTransform,
BedrockTitanCompleteStreamChunkTransform,
@@ -91,6 +97,19 @@ const BedrockConfig: ProviderConfigs = {
chatComplete: BedrockLlamaChatCompleteResponseTransform,
},
};
case 'mistral':
return {
complete: BedrockMistralCompleteConfig,
chatComplete: BedrockMistralChatCompleteConfig,
api: BedrockAPIConfig,
responseTransforms: {
'stream-complete': BedrockMistralCompleteStreamChunkTransform,
complete: BedrockMistralCompleteResponseTransform,
'stream-chatComplete':
BedrockMistralChatCompleteStreamChunkTransform,
chatComplete: BedrockMistralChatCompleteResponseTransform,
},
};
case 'amazon':
return {
complete: BedrockTitanCompleteConfig,
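
Usage note: with the 'mistral' case registered above, Bedrock Mistral models (Bedrock model IDs carry a vendor prefix, e.g. mistral.mistral-7b-instruct-v0:2; the exact ID here is illustrative) resolve to the Mistral configs and transforms. A hedged sketch, since the surrounding getConfig/switch plumbing is elided from this diff:

// Hypothetical accessor name; the diff only shows the switch cases.
// const config = getBedrockProviderConfig('mistral');
// config.complete === BedrockMistralCompleteConfig;            // true
// config.chatComplete === BedrockMistralChatCompleteConfig;    // true
// config.responseTransforms['stream-chatComplete']
//   === BedrockMistralChatCompleteStreamChunkTransform;        // true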