Skip to content

Commit

Permalink
Merge pull request #1685 from BerriAI/litellm_bedrock_emb_input
Browse files Browse the repository at this point in the history
[Fix] Graceful rejection of token input for AWS Embeddings API
  • Loading branch information
ishaan-jaff authored Jan 30, 2024
2 parents c53ad87 + f941c57 commit 2686ec0
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
18 changes: 15 additions & 3 deletions litellm/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,11 @@ def _embedding_func_single(
encoding=None,
logging_obj=None,
):
if isinstance(input, str) is False:
raise BedrockError(
message="Bedrock Embedding API input must be type str | List[str]",
status_code=400,
)
# logic for parsing in - calling - parsing out model embedding calls
## FORMAT EMBEDDING INPUT ##
provider = model.split(".")[0]
Expand Down Expand Up @@ -795,7 +800,8 @@ def embedding(
aws_role_name=aws_role_name,
aws_session_name=aws_session_name,
)
if type(input) == str:
if isinstance(input, str):
## Embedding Call
embeddings = [
_embedding_func_single(
model,
Expand All @@ -805,8 +811,8 @@ def embedding(
logging_obj=logging_obj,
)
]
else:
## Embedding Call
elif isinstance(input, list):
## Embedding Call - assuming this is a List[str]
embeddings = [
_embedding_func_single(
model,
Expand All @@ -817,6 +823,12 @@ def embedding(
)
for i in input
] # [TODO]: make these parallel calls
else:
# enters this branch if input = int, ex. input=2
raise BedrockError(
message="Bedrock Embedding API input must be type str | List[str]",
status_code=400,
)

## Populate OpenAI compliant dictionary
embedding_response = []
Expand Down
19 changes: 19 additions & 0 deletions litellm/tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,25 @@ def test_bedrock_embedding_cohere():
# test_bedrock_embedding_cohere()


def test_demo_tokens_as_input_to_embeddings_fails_for_titan():
litellm.set_verbose = True

with pytest.raises(
litellm.BadRequestError,
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
):
litellm.embedding(model="amazon.titan-embed-text-v1", input=[[1]])

with pytest.raises(
litellm.BadRequestError,
match="BedrockException - Bedrock Embedding API input must be type str | List[str]",
):
litellm.embedding(
model="amazon.titan-embed-text-v1",
input=[1],
)


# comment out hf tests - since hf endpoints are unstable
def test_hf_embedding():
try:
Expand Down

0 comments on commit 2686ec0

Please sign in to comment.