Skip to content

Commit

Permalink
fix(converse_transformation.py): handle cross region model name when …
Browse files Browse the repository at this point in the history
…getting openai param support

Fixes #6291
  • Loading branch information
krrishdholakia committed Oct 18, 2024
1 parent 5e381ca commit e80f693
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 9 deletions.
14 changes: 9 additions & 5 deletions litellm/llms/bedrock/chat/converse_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,19 @@ def get_supported_openai_params(self, model: str) -> List[str]:
"response_format",
]

## Filter out 'cross-region' from model name
base_model = self._get_base_model(model)

if (
model.startswith("anthropic")
or model.startswith("mistral")
or model.startswith("cohere")
or model.startswith("meta.llama3-1")
base_model.startswith("anthropic")
or base_model.startswith("mistral")
or base_model.startswith("cohere")
or base_model.startswith("meta.llama3-1")
or base_model.startswith("meta.llama3-2")
):
supported_params.append("tools")

if model.startswith("anthropic") or model.startswith("mistral"):
if base_model.startswith("anthropic") or base_model.startswith("mistral"):
# only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
supported_params.append("tool_choice")

Expand Down
36 changes: 32 additions & 4 deletions tests/llm_translation/test_optional_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,41 @@ def test_bedrock_optional_params_embeddings():
],
)
def test_bedrock_optional_params_completions(model):
litellm.drop_params = True
tools = [
{
"type": "function",
"function": {
"name": "structure_output",
"description": "Send structured output back to the user",
"strict": True,
"parameters": {
"type": "object",
"properties": {
"reasoning": {"type": "string"},
"sentiment": {"type": "string"},
},
"required": ["reasoning", "sentiment"],
"additionalProperties": False,
},
"additionalProperties": False,
},
}
]
optional_params = get_optional_params(
model=model, max_tokens=10, temperature=0.1, custom_llm_provider="bedrock"
model=model,
max_tokens=10,
temperature=0.1,
tools=tools,
custom_llm_provider="bedrock",
)
print(f"optional_params: {optional_params}")
assert len(optional_params) == 3
assert optional_params == {"maxTokens": 10, "stream": False, "temperature": 0.1}
assert len(optional_params) == 4
assert optional_params == {
"maxTokens": 10,
"stream": False,
"temperature": 0.1,
"tools": tools,
}


@pytest.mark.parametrize(
Expand Down

0 comments on commit e80f693

Please sign in to comment.