Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Amazon Titan embedding capability to Bedrock client #115

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions py/genai_client/text_generation/bedrock_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import boto3
import json
import logging
import requests

from .abstract_text_generation_client import AbstractTextGenerationClient
from ..tokenizers.huggingface_tokenizer import HuggingfaceTokenizer
Expand All @@ -9,6 +10,7 @@
MAX_INPUT_TOKENS,
FULL_PROMPT,
AskModelEngineResponse,
EmbeddingsModelEngineResponse,
)

# from langchain_community.llms import Bedrock
Expand Down Expand Up @@ -400,3 +402,35 @@ def ask_call(
raise Exception(f"Error while making request to Bedrock: {e}")

return final_response

def embeddings_call(self, strings_to_embed: list[str]) -> EmbeddingsModelEngineResponse:
    """Generate embeddings for a list of strings via an Amazon Titan embedding model.

    Each string is sent as a separate ``invoke_model`` request with the Titan
    payload shape ``{"inputText": <text>}``; the response body is expected to
    carry an ``embedding`` float array and an ``inputTextTokenCount``.

    Args:
        strings_to_embed: The texts to embed, one embedding is produced per text.

    Returns:
        EmbeddingsModelEngineResponse with ``response`` as a list of float
        vectors (one per input string, in order), ``prompt_tokens`` as the
        total input token count across all requests, and ``response_tokens``
        of 0 (embedding models emit no completion tokens).

    Raises:
        Exception: If any Bedrock request fails, consistent with ask_call's
            error handling.
    """
    embeddings: list[list[float]] = []
    total_input_tokens = 0

    # The client is loop-invariant; create it once rather than per string.
    client = self._get_client()

    for text in strings_to_embed:
        request = json.dumps({"inputText": text})

        try:
            response = client.invoke_model(modelId=self.modelId, body=request)
            response_body = json.loads(response["body"].read())
        except Exception as e:
            # boto3 raises botocore exceptions (never requests.RequestException),
            # so wrap whatever surfaces — matching ask_call's error style.
            raise Exception(f"Error while making request to Bedrock: {e}")

        embedding_array = response_body.get("embedding")
        if embedding_array:
            embeddings.append([float(value) for value in embedding_array])

        # Accumulate across all requests; the previous code overwrote the
        # count each iteration, reporting only the last string's tokens.
        total_input_tokens += response_body.get("inputTextTokenCount") or 0

    # Built once after the loop — the old in-loop construction raised
    # UnboundLocalError at return when the very first request failed.
    return EmbeddingsModelEngineResponse(
        response=embeddings,
        prompt_tokens=total_input_tokens,
        response_tokens=0,
    )
29 changes: 28 additions & 1 deletion src/prerna/engine/impl/model/BedrockEngine.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package prerna.engine.impl.model;


import java.util.List;
import java.util.Map;

import org.apache.logging.log4j.LogManager;
Expand All @@ -9,6 +9,7 @@
import prerna.ds.py.PyUtils;
import prerna.engine.api.ModelTypeEnum;
import prerna.engine.impl.model.responses.AskModelEngineResponse;
import prerna.engine.impl.model.responses.EmbeddingsModelEngineResponse;
import prerna.om.Insight;

public class BedrockEngine extends AbstractPythonModelEngine {
Expand Down Expand Up @@ -46,5 +47,31 @@ public AskModelEngineResponse summarize(String filePath, Insight insight, Map<St
AskModelEngineResponse response = AskModelEngineResponse.fromObject(output);
return response;
}

/**
 * Run the python client's {@code embeddings_call} for the given strings and
 * return the parsed embeddings response.
 *
 * @param stringsToEmbed  the texts to generate embeddings for; must be non-null and non-empty
 * @param insight         the insight context used to execute the python script
 * @return                the embeddings response parsed from the python output
 * @throws IllegalArgumentException if {@code stringsToEmbed} is null or empty
 */
public EmbeddingsModelEngineResponse embeddingsCall(List<String> stringsToEmbed, Insight insight) {
    checkSocketStatus();

    // Guard clause: fail fast before building the python call string.
    if (stringsToEmbed == null || stringsToEmbed.isEmpty()) {
        throw new IllegalArgumentException("Nothing given to embed");
    }

    StringBuilder callMaker = new StringBuilder(this.varName + ".embeddings_call(");
    callMaker.append("strings_to_embed")
            .append("=")
            .append(PyUtils.determineStringType(stringsToEmbed));
    callMaker.append(")");

    classLogger.debug("Running >>>" + callMaker.toString());

    Object output = pyt.runScript(callMaker.toString(), insight);
    return EmbeddingsModelEngineResponse.fromObject(output);
}

}