fix: move prompt to config & add addon prompt #714

Merged · 1 commit · Dec 18, 2023

2 changes: 1 addition & 1 deletion app/config/locale/en-US.ts
@@ -557,7 +557,7 @@ export default {
filePath: 'File Path',
importGraphSpace: 'Import Graph Space',
exportNGQLFilePath: 'Export NGQL File Path',
-  prompt: 'Prompt',
+  attachPrompt: 'Attach Prompt',
next: 'Next',
url: 'URL',
previous: 'Previous',
2 changes: 1 addition & 1 deletion app/config/locale/zh-CN.ts
@@ -537,7 +537,7 @@ export default {
filePath: '文件路径',
importGraphSpace: '导入图空间',
exportNGQLFilePath: '导出 NGQL 文件路径',
-  prompt: '提示',
+  attachPrompt: '附加提示',
next: '下一步',
previous: '上一步',
start: '开始',
9 changes: 3 additions & 6 deletions app/pages/Import/AIImport/Create.tsx
@@ -4,7 +4,6 @@ import { Button, Form, Input, Modal, Radio, Select, message } from 'antd';
import { observer } from 'mobx-react-lite';
import Icon from '@app/components/Icon';
import { useEffect, useMemo, useState } from 'react';
- import { llmImportPrompt } from '@app/stores/llm';
import { getByteLength } from '@app/utils/function';
import { post } from '@app/utils/http';
import styles from './index.module.less';
@@ -30,7 +29,7 @@ const Create = observer((props: { visible: boolean; onCancel: () => void }) => {
form.resetFields();
form.setFieldsValue({
type: 'file',
-   promptTemplate: llmImportPrompt,
+   userPrompt: '',
});
setTokens(null);
}, [props.visible]);
@@ -63,11 +62,9 @@ const Create = observer((props: { visible: boolean; onCancel: () => void }) => {

const onConfirm = async () => {
const values = form.getFieldsValue();
- const schema = await llm.getSpaceSchema(space);
post('/api/llm/import/job')({
type,
...values,
-   spaceSchemaString: schema,
}).then((res) => {
if (res.code === 0) {
message.success(intl.get('common.success'));
@@ -152,8 +149,8 @@ const Create = observer((props: { visible: boolean; onCancel: () => void }) => {
<Form.Item required label={intl.get('llm.exportNGQLFilePath')}>
<Input disabled value={llm.config.gqlPath} />
</Form.Item>
- <Form.Item required={true} label={intl.get('llm.prompt')} name="promptTemplate">
-   <Input.TextArea style={{ height: 200 }} />
+ <Form.Item label={intl.get('llm.attachPrompt')} name="userPrompt">
+   <Input.TextArea />
</Form.Item>
</Form>

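
With the base template gone from the client, the submission no longer ships a prompt or a pre-serialized schema. A minimal sketch of the new job payload, based only on fields visible in this diff (the `space` and `file` values are illustrative assumptions):

```ts
// Hypothetical payload for post('/api/llm/import/job') after this PR.
// promptTemplate and spaceSchemaString are intentionally absent: the server
// now builds the final prompt from its configured PromptTemplate and
// resolves the space schema itself.
const payload = {
  type: 'file',              // import source type, from the form defaults
  space: 'basketballplayer', // illustrative space name (form value)
  file: '/data/players.csv', // illustrative file path (form value)
  userPrompt: '',            // the new optional addon prompt
};
```
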
2 changes: 1 addition & 1 deletion app/pages/Import/AIImport/index.module.less
@@ -57,7 +57,7 @@

.tokenNum {
position: absolute;
-   top: 200px;
+   top: 180px;
right: 20px;
display: flex;
align-items: center;
2 changes: 1 addition & 1 deletion app/pages/LLMBot/chat.tsx
@@ -24,7 +24,7 @@ function Chat() {
const newMessages = [
...messages,
{ role: 'user', content: currentInput },
- { role: 'user', content: currentInput },
- { role: 'assistant', content: '', status: 'pending' },
+ { role: 'user', content: currentInput },
+ { role: 'assistant', content: '', status: 'pending' }, // the 'assistant' role string can't be changed
];
llm.update({
currentInput: '',
81 changes: 28 additions & 53 deletions app/stores/llm.ts
@@ -30,26 +30,13 @@ diff
---
Question:{query_str}
`;
- export const llmImportPrompt = `As a knowledge graph AI importer, your task is to extract useful data from the following text:
- ----text
- {text}
- ----
-
- the knowledge graph has the following schema, and node names must be real objects:
- ----graph schema
- {spaceSchema}
- ----
-
- Return the results directly, without explanation or comments. The results should be in the following JSON format:
- {
-   "nodes":[{ "name":string,"type":string,"props":object }],
-   "edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
- }
- The name of each node should be an actual object and a noun.
- Result:
- `;

- export const docFinderPrompt = `The task is to identify the two best words from "{category_string}"\n that answer the question "{query_str}" for NebulaGraph database.The output should be a a comma-separated list of these two words.Don't explain anything.`;
+ export const docFinderPrompt = `The task is to identify the top 2 most relevant categories from
+ \`\`\`categories
+ {category_string}
+ \`\`\`
+ that answer the question "{query_str}", given the user's previous questions "{history_str}", for the NebulaGraph database.
+ The output should be a comma-separated list like "category1,category2"; don't explain anything`;

export const text2queryPrompt = `Assuming you are a NebulaGraph database AI assistant, your role is to assist users in crafting NGQL queries with NebulaGraph. You have access to the following details:
the user space schema is:
@@ -137,45 +124,28 @@ class LLM {
}

async getSpaceSchema(space: string) {
- let finalPrompt: any = {
-   currentUsedSpaceName: space,
- };
+ const finalPrompt = `The user's current graph space is: ${space} \n`;
if (this.config.features.includes('spaceSchema')) {
await schema.switchSpace(space);
await schema.getTagList();
await schema.getEdgeList();
const tagList = schema.tagList;
const edgeList = schema.edgeList;
- finalPrompt = {
-   ...finalPrompt,
-   vidType: schema.spaceVidType,
-   nodeTypes: tagList.map((item) => {
-     return {
-       type: item.name,
-       props: item.fields.map((item) => {
-         return {
-           name: item.Field,
-           dataType: item.Type,
-           nullable: (item as any).Null === 'YES',
-         };
-       }),
-     };
-   }),
-   edgeTypes: edgeList.map((item) => {
-     return {
-       type: item.name,
-       props: item.fields.map((item) => {
-         return {
-           name: item.Field,
-           dataType: item.Type,
-           nullable: (item as any).Null === 'YES',
-         };
-       }),
-     };
-   }),
- };
+ let nodeSchemaString = '';
+ let edgeSchemaString = '';
+ tagList.forEach((item) => {
+   nodeSchemaString += `NodeType ${item.name} (${item.fields
+     .map((field) => `${field.Field}:${field.Type}`)
+     .join(' ')})\n`;
+ });
+ edgeList.forEach((item) => {
+   edgeSchemaString += `EdgeType ${item.name} (${item.fields
+     .map((field) => `${field.Field}:${field.Type}`)
+     .join(' ')})\n`;
+ });
+ return finalPrompt + nodeSchemaString + edgeSchemaString;
}
- return JSON.stringify(finalPrompt);
+ return finalPrompt;
}

async getAgentPrompt(query_str: string, historyMessages: any, callback: (res: any) => void) {
@@ -270,17 +240,22 @@ class LLM {
let prompt = this.mode === 'text2cypher' ? matchPrompt : text2queryPrompt;
if (this.mode !== 'text2cypher') {
text = text.replaceAll('"', "'");
+ const history = historyMessages
+   .filter((item) => item.role === 'user')
+   .map((item) => item.content)
+   .join(',');
const docPrompt = docFinderPrompt
.replace('{category_string}', ngqlDoc.NGQLCategoryString)
.replace('{query_str}', text)
+   .replace('{history_str}', history)
.replace('{space_name}', rootStore.console.currentSpace);
+ console.log(docPrompt);
const res = (await ws.runChat({
req: {
stream: false,
-   max_tokens: 20,
+   max_tokens: 40,
top_p: 0.8,
messages: [
-   ...historyMessages,
{
role: 'user',
content: docPrompt,
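
The rewritten `getSpaceSchema` emits a compact plain-text schema instead of a JSON dump, which spends fewer prompt tokens on the same information. A sketch of the expected output for a hypothetical space (tag, edge, and property names are illustrative):

```ts
// Illustrative return value of llm.getSpaceSchema('basketballplayer')
// after this PR, assuming one tag and one edge type in the space:
const schemaString =
  `The user's current graph space is: basketballplayer \n` +
  'NodeType player (name:string age:int64)\n' +
  'EdgeType follow (degree:int64)\n';
// The previous implementation returned JSON.stringify({ currentUsedSpaceName,
// vidType, nodeTypes: [...], edgeTypes: [...] }), which is markedly more
// verbose for the same schema.
```
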
2 changes: 1 addition & 1 deletion package.json
@@ -23,7 +23,7 @@
"@vesoft-inc/force-graph": "2.0.7",
"@vesoft-inc/i18n": "^1.0.1",
"@vesoft-inc/icons": "^1.2.0",
"@vesoft-inc/nebula-explain-graph": "^1.0.2-beta.2",
"@vesoft-inc/nebula-explain-graph": "^1.0.2-beta.6",
"@vesoft-inc/veditor": "^4.4.12",
"antd": "^5.8.4",
"axios": "^0.23.0",
31 changes: 16 additions & 15 deletions server/api/studio/cmd/ai_importer.go
@@ -20,9 +20,8 @@ import (

type Config struct {
LLMJob struct {
-   Space          string
-   File           string
-   PromptTemplate string
+   Space string
+   File  string
}
Auth struct {
Address string
@@ -36,8 +35,9 @@ type Config struct {
APIType db.APIType
ContextLengthLimit int
}
-   GQLBatchSize int `json:",default=100"`
-   MaxBlockSize int `json:",default=0"`
+   GQLBatchSize   int    `json:",default=100"`
+   MaxBlockSize   int    `json:",default=0"`
+   PromptTemplate string `json:",default="`
}

func main() {
@@ -55,10 +55,9 @@ func main() {
CacheNodes: make(map[string]llm.Node),
CacheEdges: make(map[string]map[string]llm.Edge),
LLMJob: &db.LLMJob{
-   JobID:          fmt.Sprintf("%d", time.Now().UnixNano()),
-   Space:          c.LLMJob.Space,
-   File:           c.LLMJob.File,
-   PromptTemplate: c.LLMJob.PromptTemplate,
+   JobID: fmt.Sprintf("%d", time.Now().UnixNano()),
+   Space: c.LLMJob.Space,
+   File:  c.LLMJob.File,
},
AuthData: &auth.AuthData{
Address: c.Auth.Address,
@@ -75,13 +74,15 @@ func main() {
}
studioConfig := config.Config{
LLM: struct {
-   GQLPath      string `json:",default=./data/llm"`
-   GQLBatchSize int    `json:",default=100"`
-   MaxBlockSize int    `json:",default=0"`
+   GQLPath        string `json:",default=./data/llm"`
+   GQLBatchSize   int    `json:",default=100"`
+   MaxBlockSize   int    `json:",default=0"`
+   PromptTemplate string `json:",default="`
}{
-   GQLPath:      *outputPath,
-   GQLBatchSize: c.GQLBatchSize,
-   MaxBlockSize: c.MaxBlockSize,
+   GQLPath:        *outputPath,
+   GQLBatchSize:   c.GQLBatchSize,
+   MaxBlockSize:   c.MaxBlockSize,
+   PromptTemplate: c.PromptTemplate,
},
}
studioConfig.InitConfig()
28 changes: 14 additions & 14 deletions server/api/studio/etc/ai-importer.yaml
@@ -1,7 +1,19 @@
LLMJob:
Space: "" #space name
File: "" #file path,support pdf,txt,json,csv and other text format
PromptTemplate: |
Auth:
Address: "127.0.0.1" # nebula graphd address
Port: 9669
Username: "root"
Password: "nebula"
LLMConfig:
URL: "" # openai api url
Key: "" # openai api key
APIType: "openai"
ContextLengthLimit: 1024
MaxBlockSize: 0 # max request block num
GQLBatchSize: 100 # max gql batch size
PromptTemplate: |
As a knowledge graph AI importer, your task is to extract useful data from the following text:
----text
{text}
@@ -18,16 +30,4 @@ LLMJob:
"edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
}
The name of each node should be an actual object and a noun.
- Result:
- Auth:
-   Address: "127.0.0.1" # nebula graphd address
-   Port: 9669
-   Username: "root"
-   Password: "nebula"
- LLMConfig:
-   URL: "" # openai api url
-   Key: "" # openai api key
-   APIType: "openai"
-   ContextLengthLimit: 1024
- MaxBlockSize: 0 # max request block num
- GQLBatchSize: 100 # max gql batch size
+ Result:
19 changes: 18 additions & 1 deletion server/api/studio/etc/studio-api.yaml
@@ -62,4 +62,21 @@ DB:
LLM:
GQLPath: "./data/llm"
GQLBatchSize: 100
-  MaxBlockSize: 0
+  MaxBlockSize: 0
+  PromptTemplate: |
+    As a knowledge graph AI importer, your task is to extract useful data from the following text:
+    ```text
+    {text}
+    ```
+    the knowledge graph has the following schema, and node names must be real objects:
+    ```graph-schema
+    {spaceSchema}
+    ```
+    {userPrompt}
+    Return the results directly, without explanation or comments. The results should be in the following JSON format:
+    {
+      "nodes":[{ "name":string,"type":string,"props":object }],
+      "edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
+    }
+    The name of each node should be an actual object and a noun.
+    Result:
27 changes: 24 additions & 3 deletions server/api/studio/internal/config/config.go
@@ -10,6 +10,23 @@ import (

var configIns *Config

+ var PromptTemplate = `As a knowledge graph AI importer, your task is to extract useful data from the following text:` +
+   "\n```text\n" +
+   `{text}` +
+   "\n```\n" +
+   `the knowledge graph has the following schema, and node names must be real objects:` +
+   "\n```graph-schema\n" +
+   `{spaceSchema}` +
+   "\n```\n" +
+   `{userPrompt}
+ Return the results directly, without explanation or comments. The results should be in the following JSON format:
+ {
+   "nodes":[{ "name":string,"type":string,"props":object }],
+   "edges":[{ "src":string,"dst":string,"edgeType":string,"props":object }]
+ }
+ The name of each node should be an actual object and a noun.
+ Result:`
+
func GetConfig() *Config {
return configIns
}
@@ -65,9 +82,10 @@
}

LLM struct {
-   GQLPath      string `json:",default=./data/llm"`
-   GQLBatchSize int    `json:",default=100"`
-   MaxBlockSize int    `json:",default=0"`
+   GQLPath        string `json:",default=./data/llm"`
+   GQLBatchSize   int    `json:",default=100"`
+   MaxBlockSize   int    `json:",default=0"`
+   PromptTemplate string `json:",default="`
}
}

@@ -117,6 +135,9 @@ func (c *Config) Complete() {
if c.LLM.MaxBlockSize == 0 {
c.LLM.MaxBlockSize = 1024 * 1024 * 1024
}
+ if c.LLM.PromptTemplate == "" {
+   c.LLM.PromptTemplate = PromptTemplate
+ }
}

func (c *Config) InitConfig() error {
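
The template keeps `{text}`, `{spaceSchema}`, and `{userPrompt}` as literal placeholders. The substitution code is not part of this diff, so the following is only a sketch of how the final prompt is presumably assembled; `fillPromptTemplate` is a hypothetical helper, not an API from this repo:

```ts
// Hypothetical helper showing how the configured PromptTemplate's placeholders
// could be filled for one block of an import job. Not the repo's actual code.
function fillPromptTemplate(
  template: string,
  vars: { text: string; spaceSchema: string; userPrompt: string },
): string {
  return template
    .replace('{text}', vars.text)               // one chunk of the input file
    .replace('{spaceSchema}', vars.spaceSchema) // serialized graph schema
    .replace('{userPrompt}', vars.userPrompt);  // the addon prompt, may be empty
}
```
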
2 changes: 1 addition & 1 deletion server/api/studio/internal/model/db.go
@@ -8,14 +8,14 @@ import (
"time"

"github.com/pkg/errors"
"github.com/vesoft-inc/nebula-studio/server/api/studio/internal/config"
"github.com/zeromicro/go-zero/core/logx"
"go.uber.org/zap"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"

"github.com/vesoft-inc/nebula-studio/server/api/studio/internal/config"
dbutil "github.com/vesoft-inc/nebula-studio/server/api/studio/pkg/db"
)

Expand Down