Commit 0f410b0

zzhangpurdue and Zhicheng Zhang authored

[do not merge] Refactor/unittest (#264)

* update unittest
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update requirements
* update ci
* update ci
* update ci
* update ci
* update ci
* update ci
* update ci
* update ci
* update ci
* update ci
* merge gradio4
* update ci
* update ci
* update ci
* update ci
* pass unit test
* merge and pass unit test

---------

Co-authored-by: Zhicheng Zhang <zhangzhicheng.zzc@alibaba-inc.com>
1 parent 42f94cd commit 0f410b0


46 files changed (+424, -550 lines); only a subset of the changed files is shown below.

.dev_scripts/dockerci.sh

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+#!/bin/bash
+
+
+# install dependencies for ci
+pip install torch
+export CODE_INTERPRETER_WORK_DIR=${GITHUB_WORKSPACE}
+echo "${CODE_INTERPRETER_WORK_DIR}"
+
+# cp file
+cp tests/samples/luoli15.jpg "${CODE_INTERPRETER_WORK_DIR}/luoli15.jpg"
+ls "${CODE_INTERPRETER_WORK_DIR}"
+
+# run ci
+pytest
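
The script exports CODE_INTERPRETER_WORK_DIR and copies a sample image into it before running pytest. A minimal sketch of how a test could pick that path up; the test name and skip guard are hypothetical and not part of this commit:

```python
import os

import pytest

# Hypothetical test sketch: dockerci.sh exports CODE_INTERPRETER_WORK_DIR and
# copies tests/samples/luoli15.jpg into it, so a test can resolve the sample there.
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR', os.getcwd())
SAMPLE = os.path.join(WORK_DIR, 'luoli15.jpg')


@pytest.mark.skipif(not os.path.exists(SAMPLE), reason='sample image not prepared')
def test_sample_image_is_available():
    assert os.path.getsize(SAMPLE) > 0
```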

.github/workflows/citest.yaml

Lines changed: 12 additions & 4 deletions

@@ -40,7 +40,8 @@ jobs:
   unittest:
     # The type of runner that the job will run on
     runs-on: ubuntu-latest
-    timeout-minutes: 240
+    timeout-minutes: 20
+    environment: testci
     steps:
     - uses: actions/checkout@v3

@@ -51,14 +52,21 @@
         path: ~/.cache/pip
         key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}

-    - name: Set up Python 3.8
+    - name: Set up Python 3.10.13
      uses: actions/setup-python@v3
      with:
-        python-version: 3.8
+        python-version: "3.10.13"

    - name: Install dependencies
      if: steps.cache.outputs.cache-hit != 'true'
      run: pip install -r requirements.txt

    - name: Run tests
-      run: pytest
+      env:
+        AMAP_TOKEN: ${{ secrets.AMAP_TOKEN }}
+        BING_SEARCH_V7_SUBSCRIPTION_KEY: ${{ secrets.BING_SEARCH_V7_SUBSCRIPTION_KEY }}
+        DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
+        MODELSCOPE_API_TOKEN: ${{ secrets.MODELSCOPE_API_TOKEN }}
+
+      shell: bash
+      run: bash .dev_scripts/dockerci.sh
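
The four repository secrets are surfaced to the test run as environment variables. A small sketch, assuming a test guards on one of them so it can skip cleanly when run locally without credentials (the test itself is hypothetical, not from this commit):

```python
import os

import pytest


# Hypothetical sketch: the workflow exports DASHSCOPE_API_KEY (and the other
# secrets) into the environment; tests that need it can skip when it is absent.
@pytest.mark.skipif(not os.getenv('DASHSCOPE_API_KEY'),
                    reason='DASHSCOPE_API_KEY is not configured')
def test_dashscope_key_present():
    assert os.getenv('DASHSCOPE_API_KEY')
```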

modelscope_agent/agent.py

Lines changed: 10 additions & 2 deletions

@@ -86,12 +86,20 @@ def _register_tool(self, tool: Union[str, Dict]):
         """
         Instantiate the global tool for the agent

+        Args:
+            tool: the tool should be either in a string format with name as value
+            and in a dict format, example
+            (1) When str: amap_weather
+            (2) When dict: {'amap_weather': {'token': 'xxx'}}
+
+        Returns:
+
         """
         tool_name = tool
         tool_cfg = {}
         if isinstance(tool, Dict):
-            tool_name = tool['name']
-            tool_cfg = tool
+            tool_name = next(iter(tool))
+            tool_cfg = tool[tool_name]
         if tool_name not in TOOL_REGISTRY:
             raise NotImplementedError
         if tool not in self.function_list:
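
With this change a dict-form tool is keyed by its name instead of carrying a separate 'name' field. A standalone sketch of the two accepted formats and how the new lines parse them; parse_tool is an illustrative helper, not part of the codebase:

```python
# Illustrative helper mirroring the new parsing logic in _register_tool.
def parse_tool(tool):
    tool_name = tool
    tool_cfg = {}
    if isinstance(tool, dict):
        tool_name = next(iter(tool))  # the first key is the tool name
        tool_cfg = tool[tool_name]    # its value is that tool's config
    return tool_name, tool_cfg


print(parse_tool('amap_weather'))                      # ('amap_weather', {})
print(parse_tool({'amap_weather': {'token': 'xxx'}}))  # ('amap_weather', {'token': 'xxx'})
```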

modelscope_agent/llm/__init__.py

Lines changed: 2 additions & 4 deletions

@@ -2,10 +2,8 @@

 from .base import LLM_REGISTRY, BaseChatModel
 from .custom import CustomLLM
-from .dashscope import DashScopeLLM
-from .dashscope_qwen import QwenChatAtDS
-from .modelscope import ModelScopeLLM
-from .modelscope_chatglm import ModelScopeChatGLM
+from .dashscope import DashScopeLLM, QwenChatAtDS
+from .modelscope import ModelScopeChatGLM, ModelScopeLLM
 from .openai import OpenAi

modelscope_agent/llm/base.py

Lines changed: 3 additions & 1 deletion

@@ -43,7 +43,9 @@ def chat(self,
             assert isinstance(prompt, str)
             messages = [{'role': 'user', 'content': prompt}]
         else:
-            assert prompt is None, 'Do not pass agents and messages at the same time.'
+            assert prompt is None, 'Do not pass prompt and messages at the same time.'
+
+        assert len(messages) > 0, 'messages list must not be empty'

         if stream:
            return self._chat_stream(messages, stop=stop, **kwargs)
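
With the reworded assertion and the new emptiness check, chat() accepts either a prompt string or a non-empty messages list, never both. A standalone mirror of that contract under those assumptions (not the library code itself):

```python
from typing import Dict, List, Optional


# Standalone mirror of the argument checks added above.
def validate_chat_args(prompt: Optional[str] = None,
                       messages: Optional[List[Dict]] = None) -> List[Dict]:
    if messages is None:
        assert isinstance(prompt, str)
        messages = [{'role': 'user', 'content': prompt}]
    else:
        assert prompt is None, 'Do not pass prompt and messages at the same time.'

    assert len(messages) > 0, 'messages list must not be empty'
    return messages


print(validate_chat_args(prompt='hello'))
print(validate_chat_args(messages=[{'role': 'user', 'content': 'hi'}]))
```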

modelscope_agent/llm/dashscope.py

Lines changed: 61 additions & 0 deletions

@@ -96,3 +96,64 @@ def _chat_no_stream(self,
                 response.message,
             )
             return err
+
+
+@register_llm('dashscope_qwen')
+class QwenChatAtDS(DashScopeLLM):
+    """
+    qwen_model from dashscope
+    """
+
+    def chat_with_raw_prompt(self,
+                             prompt: str,
+                             stop: Optional[List[str]] = None,
+                             **kwargs) -> str:
+        if prompt == '':
+            return ''
+        stop = stop or []
+        top_p = kwargs.get('top_p', 0.8)
+
+        response = dashscope.Generation.call(
+            self.model,
+            prompt=prompt,  # noqa
+            stop_words=[{
+                'stop_str': word,
+                'mode': 'exclude'
+            } for word in stop],
+            top_p=top_p,
+            result_format='message',
+            stream=False,
+            use_raw_prompt=True,
+        )
+        if response.status_code == HTTPStatus.OK:
+            # with open('debug.json', 'w', encoding='utf-8') as writer:
+            #     writer.write(json.dumps(response, ensure_ascii=False))
+            return response.output.choices[0].message.content
+        else:
+            err = 'Error code: %s, error message: %s' % (
+                response.code,
+                response.message,
+            )
+            return err
+
+    def build_raw_prompt(self, messages):
+        im_start = '<|im_start|>'
+        im_end = '<|im_end|>'
+        if messages[0]['role'] == 'system':
+            sys = messages[0]['content']
+            prompt = f'{im_start}system\n{sys}{im_end}'
+        else:
+            prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
+
+        for message in messages:
+            if message['role'] == 'user':
+                query = message['content'].lstrip('\n').rstrip()
+                prompt += f'\n{im_start}user\n{query}{im_end}'
+            elif message['role'] == 'assistant':
+                response = message['content'].lstrip('\n').rstrip()
+                prompt += f'\n{im_start}assistant\n{response}{im_end}'
+
+        # add one empty reply for the last round of assistant
+        assert prompt.endswith(f'\n{im_start}assistant\n{im_end}')
+        prompt = prompt[:-len(f'{im_end}')]
+        return prompt
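
build_raw_prompt renders the messages into the ChatML-style layout used for raw-prompt generation; per the trailing assert, the last message is expected to be an empty assistant turn so the prompt ends with an open assistant block. A standalone rendering of that layout for a one-turn conversation, a sketch of the same logic rather than the class itself:

```python
im_start, im_end = '<|im_start|>', '<|im_end|>'
messages = [
    {'role': 'user', 'content': 'What is the weather in Beijing?'},
    {'role': 'assistant', 'content': ''},  # empty reply for the round to be generated
]

prompt = f'{im_start}system\nYou are a helpful assistant.{im_end}'
for message in messages:
    text = message['content'].strip()
    prompt += f"\n{im_start}{message['role']}\n{text}{im_end}"

prompt = prompt[:-len(im_end)]  # leave the final assistant block open for generation
print(prompt)
```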

modelscope_agent/llm/dashscope_qwen.py

Lines changed: 0 additions & 67 deletions
This file was deleted.

modelscope_agent/llm/modelscope.py

Lines changed: 41 additions & 5 deletions

@@ -2,11 +2,6 @@
 import sys
 from typing import Dict, Iterator, List, Optional

-import torch
-from swift import Swift
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from modelscope import GenerationConfig, snapshot_download
 from .base import BaseChatModel, register_llm


@@ -18,6 +13,15 @@ class ModelScopeLLM(BaseChatModel):

     def __init__(self, model: str, model_server: str, **kwargs):
         super().__init__(model, model_server)
+        try:
+            import torch
+        except ImportError:
+            raise ImportError(
+                'Please install torch first: `pip install torch` or refer https://pytorch.org/ '
+            )
+
+        from modelscope import (AutoModelForCausalLM, AutoTokenizer,
+                                GenerationConfig, snapshot_download)

         # Download model based on model version
         self.model_version = kwargs.get('model_version', None)

@@ -59,6 +63,11 @@ def __init__(self, model: str, model_server: str, **kwargs):
             self.model_dir, trust_remote_code=True)

         if self.use_lora:
+            try:
+                from swift import Swift
+            except ImportError:
+                raise ImportError(
+                    'Please install swift first: `pip install ms-swift`')
             self.load_from_lora()

         if self.use_raw_generation_config:

@@ -112,3 +121,30 @@ def _inference(self, prompt: str) -> str:
         response = self.tokenizer.decode(result)

         return response
+
+
+@register_llm('modelscope_chatglm')
+class ModelScopeChatGLM(ModelScopeLLM):
+
+    def _inference(self, prompt: str) -> str:
+        device = self.model.device
+        input_ids = self.tokenizer(
+            prompt, return_tensors='pt').input_ids.to(device)
+        input_len = input_ids.shape[1]
+
+        eos_token_id = [
+            self.tokenizer.eos_token_id,
+            self.tokenizer.get_command('<|user|>'),
+            self.tokenizer.get_command('<|observation|>')
+        ]
+        result = self.model.generate(
+            input_ids=input_ids,
+            generation_config=self.generation_cfg,
+            eos_token_id=eos_token_id)
+
+        result = result[0].tolist()[input_len:]
+        response = self.tokenizer.decode(result)
+        # handle the case where '<|user|>' or '<|observation|>' tokens are generated
+        response = response.split('<|user|>')[0].split('<|observation|>')[0]
+
+        return response
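
Alongside the lazy torch/swift imports, the new ModelScopeChatGLM trims anything the model emits after the role tokens. A tiny standalone demo of that trimming step (the sample string is made up):

```python
# Standalone demo of the stop-token trimming in ModelScopeChatGLM._inference:
# everything after '<|user|>' or '<|observation|>' is discarded.
raw = 'It is sunny in Beijing today.<|user|>And tomorrow?'
trimmed = raw.split('<|user|>')[0].split('<|observation|>')[0]
print(trimmed)  # -> 'It is sunny in Beijing today.'
```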

modelscope_agent/llm/modelscope_chatglm.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

modelscope_agent/tools/base.py

Lines changed: 4 additions & 1 deletion

@@ -23,7 +23,10 @@ class BaseTool(ABC):
     parameters: List[Dict]

     def __init__(self, cfg: Optional[Dict] = {}):
-        self.cfg = cfg or {}
+        """
+        :param schema: Format of tools, default to oai format, in case there is a need for other formats
+        """
+        self.cfg = cfg.get(self.name, {})

         self.schema = self.cfg.get('schema', 'oai')
         self.function = self._build_function()
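
After this change the cfg passed to a tool is expected to be a dict keyed by tool name, with each tool picking out its own entry. A hypothetical subclass sketch of that lookup; DummyWeatherTool is illustrative only and adds a None-guard that the diff above does not have:

```python
# Hypothetical sketch of the new per-tool cfg lookup.
class DummyWeatherTool:
    name = 'amap_weather'

    def __init__(self, cfg=None):
        cfg = cfg or {}
        self.cfg = cfg.get(self.name, {})  # pick this tool's own section
        self.schema = self.cfg.get('schema', 'oai')


tool = DummyWeatherTool(cfg={'amap_weather': {'token': 'xxx'}})
print(tool.cfg)     # {'token': 'xxx'}
print(tool.schema)  # 'oai'
```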
