|
1 | | -from modelscope.pipelines import pipeline |
2 | | -from .tool import Tool |
| 1 | +from modelscope_agent.tools.base import BaseTool, register_tool |
| 2 | +import json |
| 3 | +import requests |
| 4 | +import os |
3 | 5 |
|
4 | | - |
5 | | -class ModelscopePipelineTool(Tool): |
6 | | - |
7 | | - default_model: str = '' |
8 | | - task: str = '' |
9 | | - model_revision = None |
| 6 | +@register_tool('pipeline') |
| 7 | +class ModelscopePipelineTool(BaseTool): |
| 8 | + API_URL = "" |
| 9 | + API_KEY = "" |
10 | 10 |
|
11 | 11 | def __init__(self, cfg): |
12 | | - |
| 12 | + """ |
| 13 | + 初始化一个ModelscopePipelineTool类 |
| 14 | + Initialize a ModelscopePipelineTool class. |
| 15 | + 参数: |
| 16 | + cfg (Dict[str, object]): 配置字典,包含了初始化对象所需要的参数 |
| 17 | + """ |
13 | 18 | super().__init__(cfg) |
14 | | - self.model = self.cfg.get('model', None) or self.default_model |
15 | | - self.model_revision = self.cfg.get('model_revision', |
16 | | - None) or self.model_revision |
17 | | - |
18 | | - self.pipeline_params = self.cfg.get('pipeline_params', {}) |
19 | | - self.pipeline = None |
20 | | - self.is_initialized = False |
21 | | - |
22 | | - def setup(self): |
23 | | - |
24 | | - # only initialize when this tool is really called to save memory |
25 | | - if not self.is_initialized: |
26 | | - self.pipeline = pipeline( |
27 | | - task=self.task, |
28 | | - model=self.model, |
29 | | - model_revision=self.model_revision, |
30 | | - **self.pipeline_params) |
31 | | - self.is_initialized = True |
32 | | - |
33 | | - def _local_call(self, *args, **kwargs): |
34 | | - |
35 | | - self.setup() |
| 19 | + self.API_URL = self.cfg.get(self.name, {}).get('url',None) or self.API_URL |
| 20 | + self.API_KEY = os.getenv('MODELSCOPE_API_KEY', None) or self.API_KEY |
| 21 | + |
| 22 | + |
| 23 | + def call(self, params: str, **kwargs) -> str: |
| 24 | + params = self._verify_args(params) |
| 25 | + data = json.dumps(params) |
| 26 | + headers = {"Authorization": f"Bearer {self.API_KEY}"} |
| 27 | + response = requests.request("POST", self.API_URL, headers=headers,data=data) |
| 28 | + result = json.loads(response.content.decode("utf-8")) |
| 29 | + return result |
36 | 30 |
|
37 | | - parsed_args, parsed_kwargs = self._local_parse_input(*args, **kwargs) |
38 | | - origin_result = self.pipeline(*parsed_args, **parsed_kwargs) |
39 | | - final_result = self._parse_output(origin_result, remote=False) |
40 | | - return final_result |
0 commit comments