Skip to content

Commit

Permalink
fix python style
Browse files Browse the repository at this point in the history
  • Loading branch information
warren830 authored and crazywoola committed Dec 11, 2024
1 parent d8edd32 commit 7ad6d89
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 84 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"],
aws_access_key_id = (credentials["aws_access_key_id"],)
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# 使用 AKSK 方式
client = boto3.client(
service_name=service_name,
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key
aws_secret_access_key=aws_secret_access_key,
)
else:
# 使用 IAM 角色方式
Expand Down
116 changes: 58 additions & 58 deletions api/core/model_runtime/model_providers/bedrock/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,16 +84,16 @@ def _find_model_info(model_id):
return None

def _code_block_mode_wrapper(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
callbacks: Optional[list[Callback]] = None,
) -> Union[LLMResult, Generator]:
"""
Code block mode wrapper for invoking large language model
Expand All @@ -118,15 +118,15 @@ def _code_block_mode_wrapper(
return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

def _invoke(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
Expand All @@ -153,15 +153,15 @@ def _invoke(
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

def _generate_with_converse(
self,
model_info: dict,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
tools: Optional[list[PromptMessageTool]] = None,
self,
model_info: dict,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
tools: Optional[list[PromptMessageTool]] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model with converse API
Expand Down Expand Up @@ -220,7 +220,7 @@ def _generate_with_converse(
raise InvokeError(str(ex))

def _handle_converse_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm chat response
Expand Down Expand Up @@ -284,11 +284,11 @@ def _extract_tool_use(self, content: dict) -> tuple[str, dict]:
return text, tool_use

def _handle_converse_stream_response(
self,
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
self,
model: str,
credentials: dict,
response: dict,
prompt_messages: list[PromptMessage],
) -> Generator:
"""
Handle llm chat stream response
Expand Down Expand Up @@ -370,7 +370,7 @@ def _handle_converse_stream_response(
raise InvokeError(str(ex))

def _convert_converse_api_model_parameters(
self, model_parameters: dict, stop: Optional[list[str]] = None
self, model_parameters: dict, stop: Optional[list[str]] = None
) -> tuple[dict, dict]:
inference_config = {}
additional_model_fields = {}
Expand Down Expand Up @@ -496,11 +496,11 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
return message_dict

def get_num_tokens(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage] | str,
tools: Optional[list[PromptMessageTool]] = None,
) -> int:
"""
Get number of tokens for given prompt messages
Expand Down Expand Up @@ -563,7 +563,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
raise CredentialsValidateFailedError(str(ex))

def _convert_one_message_to_text(
self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None
self, message: PromptMessage, model_prefix: str, model_name: Optional[str] = None
) -> str:
"""
Convert a single message to a string.
Expand Down Expand Up @@ -594,7 +594,7 @@ def _convert_one_message_to_text(
return message_text

def _convert_messages_to_prompt(
self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None
self, messages: list[PromptMessage], model_prefix: str, model_name: Optional[str] = None
) -> str:
"""
Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
Expand All @@ -616,12 +616,12 @@ def _convert_messages_to_prompt(
return text.rstrip()

def _create_payload(
self,
model: str,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
self,
model: str,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
):
"""
Create payload for bedrock api call depending on model provider
Expand Down Expand Up @@ -654,14 +654,14 @@ def _create_payload(
return payload

def _generate(
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
prompt_messages: list[PromptMessage],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
user: Optional[str] = None,
) -> Union[LLMResult, Generator]:
"""
Invoke large language model
Expand Down Expand Up @@ -716,7 +716,7 @@ def _generate(
return self._handle_generate_response(model, credentials, response, prompt_messages)

def _handle_generate_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
Handle llm response
Expand Down Expand Up @@ -767,7 +767,7 @@ def _handle_generate_response(
return result

def _handle_generate_stream_response(
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
self, model: str, credentials: dict, response: dict, prompt_messages: list[PromptMessage]
) -> Generator:
"""
Handle llm stream response
Expand Down
16 changes: 8 additions & 8 deletions api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ class BedrockRerankModel(RerankModel):
"""

def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@

class BedrockTextEmbeddingModel(TextEmbeddingModel):
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
Expand Down Expand Up @@ -132,12 +132,12 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
}

def _create_payload(
self,
model_prefix: str,
texts: list[str],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
self,
model_prefix: str,
texts: list[str],
model_parameters: dict,
stop: Optional[list[str]] = None,
stream: bool = True,
):
"""
Create payload for bedrock api call depending on model provider
Expand Down Expand Up @@ -202,10 +202,10 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I
return InvokeError(error_msg)

def _invoke_bedrock_embedding(
self,
model: str,
bedrock_runtime,
body: dict,
self,
model: str,
bedrock_runtime,
body: dict,
):
accept = "application/json"
content_type = "application/json"
Expand Down

0 comments on commit 7ad6d89

Please sign in to comment.