diff --git a/modelscope_agent/tools/base.py b/modelscope_agent/tools/base.py index 048464af..ed6d5b69 100644 --- a/modelscope_agent/tools/base.py +++ b/modelscope_agent/tools/base.py @@ -14,7 +14,8 @@ MODELSCOPE_AGENT_TOKEN_HEADER_NAME) from modelscope_agent.tools.utils.openapi_utils import (execute_api_call, get_parameter_value, - openapi_schema_convert) + openapi_schema_convert, + structure_json) from modelscope_agent.utils.base64_utils import decode_base64_to_files from modelscope_agent.utils.logger import agent_logger as logger from modelscope_agent.utils.utils import has_chinese_chars @@ -594,8 +595,18 @@ def _verify_args(self, params: str, api_info) -> Union[str, dict]: for param in api_info['parameters']: if 'required' in param and param['required']: - if param['name'] not in params_json: - raise ValueError(f'param `{param["name"]}` is required') + + current = {} + current_test = copy.deepcopy(params_json) + parts = param['name'].split('.') + for i, part in enumerate(parts): + if part not in current: + current[part] = {} + current = current[part] + if part not in current_test: + raise ValueError( + f'param `{".".join(parts[:i])}` is required') + current_test = current_test[part] return params_json def _parse_credentials(self, credentials: dict, headers=None): @@ -690,7 +701,7 @@ def call(self, params: str, **kwargs): elif parameter['in'] == 'header': header[parameter['name']] = value else: - data[parameter['name']] = value + data[parameter['name'].split('.')[0]] = value for name, value in path_params.items(): url = url.replace(f'{{{name}}}', f'{value}') diff --git a/modelscope_agent/tools/utils/openapi_utils.py b/modelscope_agent/tools/utils/openapi_utils.py index 2cce9c8e..843a7af1 100644 --- a/modelscope_agent/tools/utils/openapi_utils.py +++ b/modelscope_agent/tools/utils/openapi_utils.py @@ -1,28 +1,30 @@ +import copy import os import jsonref import requests -def execute_api_call(url: str, method: str, headers: dict, params: dict, - data: dict, cookies: dict): +def structure_json(flat_json): + structured = {} - def structure_json(flat_json): - structured = {} + for key, value in flat_json.items(): + parts = key.split('.') + current = structured - for key, value in flat_json.items(): - parts = key.split('.') - current = structured + for i, part in enumerate(parts): + if i == len(parts) - 1: + current[part] = value + else: + if part not in current: + current[part] = {} + current = current[part] - for i, part in enumerate(parts): - if i == len(parts) - 1: - current[part] = value - else: - if part not in current: - current[part] = {} - current = current[part] + return structured - return structured + +def execute_api_call(url: str, method: str, headers: dict, params: dict, + data: dict, cookies: dict): if data: data = structure_json(data) @@ -331,9 +333,22 @@ def openapi_schema_convert(schema: dict, auth: dict = {}): return config_data -def get_parameter_value(parameter: dict, parameters: dict): - if parameter['name'] in parameters: - return parameters[parameter['name']] +def get_parameter_value(parameter: dict, generated_params: dict): + if parameter['name'] in generated_params: + return generated_params[parameter['name']] + elif '.' in parameter['name']: + current = {} + current_test = copy.deepcopy(generated_params) + parts = parameter['name'].split('.') + for i, part in enumerate(parts): + if part not in current: + current[part] = {} + current = current[part] + if part not in current_test: + raise ValueError(f'param `{".".join(parts[:i])}` is required') + current_test = current_test[part] + + return generated_params[parts[0]] elif parameter.get('required', False): raise ValueError(f"Missing required parameter {parameter['name']}") else: