Skip to content

Commit

Permalink
fix bugs for nested openapi
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhangpurdue committed Nov 13, 2024
1 parent c4a6218 commit ad9bce4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 22 deletions.
19 changes: 15 additions & 4 deletions modelscope_agent/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
MODELSCOPE_AGENT_TOKEN_HEADER_NAME)
from modelscope_agent.tools.utils.openapi_utils import (execute_api_call,
get_parameter_value,
openapi_schema_convert)
openapi_schema_convert,
structure_json)
from modelscope_agent.utils.base64_utils import decode_base64_to_files
from modelscope_agent.utils.logger import agent_logger as logger
from modelscope_agent.utils.utils import has_chinese_chars
Expand Down Expand Up @@ -594,8 +595,18 @@ def _verify_args(self, params: str, api_info) -> Union[str, dict]:

for param in api_info['parameters']:
if 'required' in param and param['required']:
if param['name'] not in params_json:
raise ValueError(f'param `{param["name"]}` is required')

current = {}
current_test = copy.deepcopy(params_json)
parts = param['name'].split('.')
for i, part in enumerate(parts):
if part not in current:
current[part] = {}
current = current[part]
if part not in current_test:
raise ValueError(
f'param `{".".join(parts[:i])}` is required')
current_test = current_test[part]
return params_json

def _parse_credentials(self, credentials: dict, headers=None):
Expand Down Expand Up @@ -690,7 +701,7 @@ def call(self, params: str, **kwargs):
elif parameter['in'] == 'header':
header[parameter['name']] = value
else:
data[parameter['name']] = value
data[parameter['name'].split('.')[0]] = value

for name, value in path_params.items():
url = url.replace(f'{{{name}}}', f'{value}')
Expand Down
51 changes: 33 additions & 18 deletions modelscope_agent/tools/utils/openapi_utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
import copy
import os

import jsonref
import requests


def execute_api_call(url: str, method: str, headers: dict, params: dict,
data: dict, cookies: dict):
def structure_json(flat_json):
structured = {}

def structure_json(flat_json):
structured = {}
for key, value in flat_json.items():
parts = key.split('.')
current = structured

for key, value in flat_json.items():
parts = key.split('.')
current = structured
for i, part in enumerate(parts):
if i == len(parts) - 1:
current[part] = value
else:
if part not in current:
current[part] = {}
current = current[part]

for i, part in enumerate(parts):
if i == len(parts) - 1:
current[part] = value
else:
if part not in current:
current[part] = {}
current = current[part]
return structured

return structured

def execute_api_call(url: str, method: str, headers: dict, params: dict,
data: dict, cookies: dict):

if data:
data = structure_json(data)
Expand Down Expand Up @@ -331,9 +333,22 @@ def openapi_schema_convert(schema: dict, auth: dict = {}):
return config_data


def get_parameter_value(parameter: dict, parameters: dict):
if parameter['name'] in parameters:
return parameters[parameter['name']]
def get_parameter_value(parameter: dict, generated_params: dict):
if parameter['name'] in generated_params:
return generated_params[parameter['name']]
elif '.' in parameter['name']:
current = {}
current_test = copy.deepcopy(generated_params)
parts = parameter['name'].split('.')
for i, part in enumerate(parts):
if part not in current:
current[part] = {}
current = current[part]
if part not in current_test:
raise ValueError(f'param `{".".join(parts[:i])}` is required')
current_test = current_test[part]

return generated_params[parts[0]]
elif parameter.get('required', False):
raise ValueError(f"Missing required parameter {parameter['name']}")
else:
Expand Down

0 comments on commit ad9bce4

Please sign in to comment.