From bcdd9e59350f95806ce979645cc20ea8949e2539 Mon Sep 17 00:00:00 2001 From: Michael Moore Date: Thu, 12 Sep 2024 16:40:02 -0400 Subject: [PATCH 1/4] Added embeddings_call to bedrock client --- .../text_generation/bedrock_client.py | 34 +++++++++++++++++++ .../engine/impl/model/BedrockEngine.java | 27 +++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/py/genai_client/text_generation/bedrock_client.py b/py/genai_client/text_generation/bedrock_client.py index f27f8e2481e..507cfa0856c 100644 --- a/py/genai_client/text_generation/bedrock_client.py +++ b/py/genai_client/text_generation/bedrock_client.py @@ -1,6 +1,7 @@ import boto3 import json import logging +import requests from .abstract_text_generation_client import AbstractTextGenerationClient from ..tokenizers.huggingface_tokenizer import HuggingfaceTokenizer @@ -9,6 +10,7 @@ MAX_INPUT_TOKENS, FULL_PROMPT, AskModelEngineResponse, + EmbeddingsModelEngineResponse, ) # from langchain_community.llms import Bedrock @@ -400,3 +402,35 @@ def ask_call( raise Exception(f"Error while making request to Bedrock: {e}") return final_response + + def embeddings_call(self, strings_to_embed:list[str]) -> EmbeddingsModelEngineResponse: + embeddings_list = [] + embeddings = [] + + for text in strings_to_embed: + json_obj = {"inputText": text} + request = json.dumps(json_obj) + + try: + client = self._get_client() + + response = client.invoke_model( + modelId=self.modelId, body=request + ) + response_body = json.loads(response['body'].read()) + embedding_array = response_body.get("embedding") + + if embedding_array: + embeddings_list = [float(value) for value in embedding_array] + embeddings.append(embeddings_list) + + model_engine_response = EmbeddingsModelEngineResponse( + response=embeddings, + prompt_tokens=response_body.get("inputTextTokenCount"), + response_tokens=0 + ) + + except requests.RequestException as e: + print(f"An error occurred in bedrock embeddings_call: {e}") + + return model_engine_response diff --git a/src/prerna/engine/impl/model/BedrockEngine.java b/src/prerna/engine/impl/model/BedrockEngine.java index ab06207b7b8..2f7d818539b 100644 --- a/src/prerna/engine/impl/model/BedrockEngine.java +++ b/src/prerna/engine/impl/model/BedrockEngine.java @@ -9,6 +9,7 @@ import prerna.ds.py.PyUtils; import prerna.engine.api.ModelTypeEnum; import prerna.engine.impl.model.responses.AskModelEngineResponse; +import prerna.engine.impl.model.responses.EmbeddingsModelEngineResponse; import prerna.om.Insight; public class BedrockEngine extends AbstractPythonModelEngine { @@ -46,5 +47,31 @@ public AskModelEngineResponse summarize(String filePath, Insight insight, Map stringsToEmbed, Insight insight) { + checkSocketStatus(); + + StringBuilder callMaker = new StringBuilder(this.varName + ".embeddings_call("); + if (stringsToEmbed != null && !stringsToEmbed.isEmpty()) { + callMaker.append("strings_to_embed") + .append("=") + .append(PyUtils.determineStringType(stringsToEmbed)); + } else { + throw new IllegalArgumentException("Nothing given to embed"); + } + + callMaker.append(")"); + classLogger.debug("Running >>>" + callMaker.toString()); + + Object output = pyt.runScript(callMaker.toString(), insight); + AskModelEngineResponse response = AskModelEngineResponse.fromObject(output); + return response; + } } From 2577f9869a4aeb91859c90052133328aad2913ca Mon Sep 17 00:00:00 2001 From: Michael Moore Date: Thu, 12 Sep 2024 17:37:52 -0400 Subject: [PATCH 2/4] fixed missing import in BedrockEngine --- src/prerna/engine/impl/model/BedrockEngine.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prerna/engine/impl/model/BedrockEngine.java b/src/prerna/engine/impl/model/BedrockEngine.java index 2f7d818539b..7e94b42110f 100644 --- a/src/prerna/engine/impl/model/BedrockEngine.java +++ b/src/prerna/engine/impl/model/BedrockEngine.java @@ -1,6 +1,6 @@ package prerna.engine.impl.model; - +import java.util.List; import java.util.Map; import org.apache.logging.log4j.LogManager; From 37b36978d8b07dcc0a09899b571ee1cc096328d4 Mon Sep 17 00:00:00 2001 From: Michael Moore Date: Thu, 12 Sep 2024 17:41:32 -0400 Subject: [PATCH 3/4] Fixed return type for embeddingsCall --- src/prerna/engine/impl/model/BedrockEngine.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prerna/engine/impl/model/BedrockEngine.java b/src/prerna/engine/impl/model/BedrockEngine.java index 7e94b42110f..f8ddf1ce3c1 100644 --- a/src/prerna/engine/impl/model/BedrockEngine.java +++ b/src/prerna/engine/impl/model/BedrockEngine.java @@ -70,7 +70,7 @@ public EmbeddingsModelEngineResponse embeddingsCall(List stringsToEmbed, classLogger.debug("Running >>>" + callMaker.toString()); Object output = pyt.runScript(callMaker.toString(), insight); - AskModelEngineResponse response = AskModelEngineResponse.fromObject(output); + EmbeddingsModelEngineResponse response = AskModelEngineResponse.fromObject(output); return response; } From 42f82c9191d2d3025da246c43587f865b60357d8 Mon Sep 17 00:00:00 2001 From: Michael Moore Date: Thu, 12 Sep 2024 17:45:07 -0400 Subject: [PATCH 4/4] Fixed error in bedrock engine --- src/prerna/engine/impl/model/BedrockEngine.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/prerna/engine/impl/model/BedrockEngine.java b/src/prerna/engine/impl/model/BedrockEngine.java index f8ddf1ce3c1..04a07bf7693 100644 --- a/src/prerna/engine/impl/model/BedrockEngine.java +++ b/src/prerna/engine/impl/model/BedrockEngine.java @@ -70,7 +70,7 @@ public EmbeddingsModelEngineResponse embeddingsCall(List stringsToEmbed, classLogger.debug("Running >>>" + callMaker.toString()); Object output = pyt.runScript(callMaker.toString(), insight); - EmbeddingsModelEngineResponse response = AskModelEngineResponse.fromObject(output); + EmbeddingsModelEngineResponse response = EmbeddingsModelEngineResponse.fromObject(output); return response; }