DOC: custom models #325

Merged 3 commits on Aug 7, 2023
104 changes: 104 additions & 0 deletions README.md
@@ -210,6 +210,110 @@ $ xinference list --all
- If you want to use an Apple Metal GPU for acceleration, please choose the q4_0 or q4_1 quantization methods.
- The `llama-2-chat` 70B ggmlv3 model currently supports only q4_0 quantization.

## Custom models \[Experimental\]
Custom models are currently an experimental feature and are expected to be officially released in version v0.2.0.

Define a custom model based on the following template:
```python
custom_model = {
"version": 1,
# model name. must start with a letter or a
# digit, and can only contain letters, digits,
# underscores, or dashes.
"model_name": "nsql-2B",
# supported languages
"model_lang": [
"en"
],
# model abilities. could be "embed", "generate"
# and "chat".
"model_ability": [
"generate"
],
# model specifications.
"model_specs": [
{
# model format.
"model_format": "pytorch",
"model_size_in_billions": 2,
# quantizations.
"quantizations": [
"4-bit",
"8-bit",
"none"
],
# hugging face model ID.
"model_id": "NumbersStation/nsql-2B"
}
],
# prompt style, required by chat models.
# for more details, see: xinference/model/llm/tests/test_utils.py
"prompt_style": None
}
```
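The naming constraint in the template's comment can be checked up front. The following is a small sketch; the regex is inferred from the comment above, not taken from xinference's own validation code:

```python
import re

# Inferred from the template comment: the name must start with a letter
# or a digit, followed only by letters, digits, underscores, or dashes.
MODEL_NAME_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_-]*$")

def is_valid_model_name(name: str) -> bool:
    """Return True if the name satisfies the documented constraint."""
    return MODEL_NAME_PATTERN.fullmatch(name) is not None
```

For example, `is_valid_model_name("nsql-2B")` passes, while a name starting with an underscore or containing spaces does not.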

Register the custom model:
```python
import json

from xinference.client import Client

# replace with real xinference endpoint
endpoint = "http://localhost:9997"
client = Client(endpoint)
client.register_model(model_type="LLM", model=json.dumps(custom_model), persist=False)
```
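Note that `register_model` takes a JSON string rather than the dict itself, which is why the example serializes with `json.dumps`. In particular, Python's `None` for `prompt_style` becomes JSON `null` and round-trips back to `None`. A minimal sketch of that serialization, using a trimmed-down version of the template:

```python
import json

# A trimmed-down version of the custom model template above.
custom_model = {"version": 1, "model_name": "nsql-2B", "prompt_style": None}

# Python None serializes to JSON null...
payload = json.dumps(custom_model)

# ...and deserializes back to None.
restored = json.loads(payload)
```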

Load the custom model:
```python
uid = client.launch_model(model_name='nsql-2B')
```

Run the custom model:
```python
text = """CREATE TABLE work_orders (
ID NUMBER,
CREATED_AT TEXT,
COST FLOAT,
INVOICE_AMOUNT FLOAT,
IS_DUE BOOLEAN,
IS_OPEN BOOLEAN,
IS_OVERDUE BOOLEAN,
COUNTRY_NAME TEXT,
)

-- Using valid SQLite, answer the following questions for the tables provided above.

-- how many work orders are open?

SELECT"""

model = client.get_model(model_uid=uid)
model.generate(prompt=text)
```

Result:
```json
{
"id":"aeb5c87a-352e-11ee-89ad-9af9f16816c5",
"object":"text_completion",
"created":1691418511,
"model":"3b912fc4-352e-11ee-8e66-9af9f16816c5",
"choices":[
{
"text":" COUNT(*) FROM work_orders WHERE IS_OPEN = '1';",
"index":0,
"logprobs":"None",
"finish_reason":"stop"
}
],
"usage":{
"prompt_tokens":117,
"completion_tokens":17,
"total_tokens":134
}
}
```
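Because the prompt ends with `SELECT`, the full statement is that final keyword plus the returned completion. A sketch of pulling the text out of a response shaped like the one above (the dict is trimmed to the relevant fields):

```python
# A response shaped like the example result above, trimmed to the
# fields this snippet actually reads.
result = {
    "choices": [
        {
            "text": " COUNT(*) FROM work_orders WHERE IS_OPEN = '1';",
            "index": 0,
            "finish_reason": "stop",
        }
    ],
}

# The model continues the prompt, so the completion starts mid-statement.
completion = result["choices"][0]["text"]

# Re-attach the "SELECT" the prompt ended with to get a full query.
sql = "SELECT" + completion
```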

## PyTorch Model Best Practices

104 changes: 104 additions & 0 deletions README_zh_CN.md
@@ -201,6 +201,110 @@ $ xinference list --all
- If you want to use an Apple Metal GPU for acceleration, please choose the q4_0 or q4_1 quantization methods.
- The `llama-2-chat` 70B ggmlv3 model currently supports only q4_0 quantization.

## Custom Models \[Experimental\]
Custom models are currently an experimental feature and are expected to be officially released in v0.2.0.

Before adding a custom model, fill in the model configuration based on the following template:
```python
custom_model = {
"version": 1,
# model name. must start with a letter or a
# digit, and can only contain letters, digits,
# underscores, or dashes.
"model_name": "nsql-2B",
# supported languages
"model_lang": [
"en"
],
# model abilities. could be "embed", "generate"
# and "chat".
"model_ability": [
"generate"
],
# model specifications.
"model_specs": [
{
# model format.
"model_format": "pytorch",
"model_size_in_billions": 2,
# quantizations.
"quantizations": [
"4-bit",
"8-bit",
"none"
],
# hugging face model ID.
"model_id": "NumbersStation/nsql-2B"
}
],
# prompt style, required by chat models.
# for more details, see: xinference/model/llm/tests/test_utils.py
"prompt_style": None
}
```

Register the custom model:
```python
import json

from xinference.client import Client

# replace with real xinference endpoint
endpoint = "http://localhost:9997"
client = Client(endpoint)
client.register_model(model_type="LLM", model=json.dumps(custom_model), persist=False)
```

Load the model:
```python
uid = client.launch_model(model_name='nsql-2B')
```

Run inference:
```python
text = """CREATE TABLE work_orders (
ID NUMBER,
CREATED_AT TEXT,
COST FLOAT,
INVOICE_AMOUNT FLOAT,
IS_DUE BOOLEAN,
IS_OPEN BOOLEAN,
IS_OVERDUE BOOLEAN,
COUNTRY_NAME TEXT,
)

-- Using valid SQLite, answer the following questions for the tables provided above.

-- how many work orders are open?

SELECT"""

model = client.get_model(model_uid=uid)
model.generate(prompt=text)
```

Result:
```json
{
"id":"aeb5c87a-352e-11ee-89ad-9af9f16816c5",
"object":"text_completion",
"created":1691418511,
"model":"3b912fc4-352e-11ee-8e66-9af9f16816c5",
"choices":[
{
"text":" COUNT(*) FROM work_orders WHERE IS_OPEN = '1';",
"index":0,
"logprobs":"None",
"finish_reason":"stop"
}
],
"usage":{
"prompt_tokens":117,
"completion_tokens":17,
"total_tokens":134
}
}
```

## PyTorch Model Best Practices
