Skip to content

Commit

Permalink
fix: ai import bug (#710)
Browse files Browse the repository at this point in the history
  • Loading branch information
mizy authored Dec 13, 2023
1 parent f477936 commit 1c360d8
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 39 deletions.
20 changes: 5 additions & 15 deletions app/pages/Setting/index.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { useCallback, useEffect, useState } from 'react';
import { useCallback, useEffect } from 'react';
import { observer } from 'mobx-react-lite';
import { Button, Col, Form, Input, InputNumber, Row, Select, Switch, message } from 'antd';
import { useI18n } from '@vesoft-inc/i18n';
Expand All @@ -15,15 +15,13 @@ const Setting = observer(() => {
const { global, llm } = useStore();
const { appSetting, saveAppSetting } = global;
const [form] = useForm();
const [apiType, setApiType] = useState('openai');
useEffect(() => {
initForm();
}, []);

const initForm = async () => {
await llm.fetchConfig();
form.setFieldsValue(llm.config);
setApiType(llm.config.apiType);
};

const updateAppSetting = useCallback(async (param: Partial<any['beta']>) => {
Expand Down Expand Up @@ -109,13 +107,7 @@ const Setting = observer(() => {
<div className={styles.tips}>{intl.get('setting.llmImportDesc')}</div>
<Form form={form} layout="vertical" style={{ marginTop: 20 }}>
<Form.Item label="API type" name="apiType" required={true}>
<Select
onChange={(value) => {
setApiType(value);
}}
defaultValue="openai"
style={{ width: 120 }}
>
<Select defaultValue="openai" style={{ width: 120 }}>
<Select.Option value="openai">OpenAI</Select.Option>
<Select.Option value="qwen">Aliyun</Select.Option>
</Select>
Expand All @@ -126,11 +118,9 @@ const Setting = observer(() => {
<Form.Item label="key" name="key">
<Input type="password" />
</Form.Item>
{apiType === 'qwen' && (
<Form.Item label="model" name="model" required={true}>
<Input placeholder="qwen-max" />
</Form.Item>
)}
<Form.Item label="model" name="model">
<Input />
</Form.Item>
<Form.Item label={intl.get('setting.maxTextLength')} name="maxContextLength" required={true}>
<InputNumber min={0} />
</Form.Item>
Expand Down
41 changes: 21 additions & 20 deletions app/stores/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import * as ngqlDoc from '@app/utils/ngql';
import schema from './schema';
import rootStore from '.';

export const matchPrompt = `Use NebulaGraph match knowledge to help me answer question.
export const matchPrompt = `I want you to be a NebulaGraph database asistant.
There are below document.
----
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
Expand All @@ -25,6 +27,7 @@ diff
---
> MATCH (p:person)-[:directed]->(m:movie) WHERE m.movie.name == 'The Godfather'
> RETURN p.person.name;
---
Question:{query_str}
`;
export const llmImportPrompt = `As a knowledge graph AI importer, your task is to extract useful data from the following text:
Expand All @@ -46,14 +49,12 @@ The name of the nodes should be an actual object and a noun.
Result:
`;

export const docFinderPrompt = `Assume your are doc finder,from the following graph database book categories:
export const docFinderPrompt = `Assuming you are a document navigator, within the following categories related to graph database books:
"{category_string}"
user current space is: {space_name}
find top two useful categories to solve the question:"{query_str}",
don't explain, if you can't find, return "Sorry".
just return the two combined categories, separated by ',' is:`;
please identify the most two relevant categories that could address the question: "{query_str}".,
Please just return the two categories as a comma-separated list without any other word`;

export const text2queryPrompt = `Assume you are a NebulaGraph AI chat asistant to help user write NGQL.
export const text2queryPrompt = `Assume you are a NebulaGraph database AI chat asistant to help user write NGQL with NebulaGraph.
You have access to the following information:
the user space schema is:
----
Expand Down Expand Up @@ -227,10 +228,8 @@ class LLM {
console.log(prompt);
await ws.runChat({
req: {
temperature: 0.5,
stream: true,
max_tokens: 20,

messages: [
...historyMessages,
{
Expand Down Expand Up @@ -274,18 +273,19 @@ class LLM {
let prompt = matchPrompt; // default use text2cypher
if (this.mode !== 'text2cypher') {
text = text.replaceAll('"', "'");
const docPrompt = docFinderPrompt
.replace('{category_string}', ngqlDoc.NGQLCategoryString)
.replace('{query_str}', text)
.replace('{space_name}', rootStore.console.currentSpace);
console.log(docPrompt);
const res = (await ws.runChat({
req: {
temperature: 0.5,
stream: false,
max_tokens: 20,
messages: [
{
role: 'user',
content: docFinderPrompt
.replace('{category_string}', ngqlDoc.NGQLCategoryString)
.replace('{query_str}', text)
.replace('{space_name}', rootStore.console.currentSpace),
content: docPrompt,
},
],
},
Expand All @@ -297,19 +297,19 @@ class LLM {
.replaceAll(/\s|"|\\/g, '')
.split(',');
console.log('select doc url:', paths);
if (ngqlDoc.ngqlMap[paths[0]]) {
let doc = ngqlDoc.ngqlMap[paths[0]].content;
if (paths[0] !== 'sorry') {
let doc = ngqlDoc.ngqlMap[paths[0]]?.content;
if (!doc) {
doc = '';
}
const doc2 = ngqlDoc.ngqlMap[paths[1]].content;
const doc2 = ngqlDoc.ngqlMap[paths[1]]?.content;
if (doc2) {
doc += doc2;
doc =
doc.slice(0, this.config.maxContextLength / 2) + `\n` + doc2.slice(0, this.config.maxContextLength / 2);
}
doc = doc.replaceAll(/\n\n+/g, '');
if (doc.length) {
console.log('docString:', doc);
prompt = text2queryPrompt.replace('{doc}', doc.slice(0, this.config.maxContextLength));
prompt = text2queryPrompt.replace('{doc}', doc);
}
}
}
Expand All @@ -327,6 +327,7 @@ class LLM {
schemaPrompt += `\nuser console ngql context: ${rootStore.console.currentGQL}`;
}
prompt = prompt.replace('{schema}', schemaPrompt);
console.log(prompt);
return prompt;
}

Expand Down
5 changes: 4 additions & 1 deletion app/utils/ngql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ export const ngqlDoc = (ngqlJson as { url: string; content: string; title: strin
if (urlTransformerMap[item.title]) {
item.title = urlTransformerMap[item.title];
}
item.title = item.title.replaceAll(' ', '');
item.title = item.title
.split(' ')
.map((word) => word[0].toUpperCase() + word.slice(1))
.join('');
item.content = item.content.replace(/nebula>/g, '');

return item;
Expand Down
14 changes: 11 additions & 3 deletions server/api/studio/pkg/llm/importjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,16 @@ func (i *ImportJob) MakeGQLFile(filePath string) ([]string, error) {
if valueStr != "" {
valueStr += ","
}
valueStr += fmt.Sprintf(`"%v"`, value)
if strings.Contains(strings.ToLower(field.DataType), "string") {
valueStr += fmt.Sprintf(`"%v"`, value)
} else {
valueStr += fmt.Sprintf(`%v`, value)
}
}

gql := fmt.Sprintf("INSERT VERTEX `%s` ({props}) VALUES \"%s\":({value});", typ, name)
gql = strings.ReplaceAll(gql, "{props}", propsStr)
gql = strings.ReplaceAll(gql, "{value}", propsStr)
gql = strings.ReplaceAll(gql, "{value}", valueStr)
gqls = append(gqls, gql)
}

Expand Down Expand Up @@ -508,7 +512,11 @@ func (i *ImportJob) MakeGQLFile(filePath string) ([]string, error) {
if propsValue != "" {
propsValue += ","
}
propsValue += fmt.Sprintf("\"%v\"", value)
if strings.Contains(strings.ToLower(field.DataType), "string") {
propsValue += fmt.Sprintf(`"%v"`, value)
} else {
propsValue += fmt.Sprintf(`%v`, value)
}
}
gql := fmt.Sprintf("INSERT EDGE `%s` (%s) VALUES \"%s\"->\"%s\":(%s);", dst.EdgeType, propsName, dst.Src, dst.Dst, propsValue)
gqls = append(gqls, gql)
Expand Down
5 changes: 5 additions & 0 deletions server/api/studio/pkg/llm/transformer/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ type OpenAI struct {
}

func (o *OpenAI) HandleRequest(req map[string]any, config *db.LLMConfig) (*http.Request, error) {
configs := make(map[string]any)
err := json.Unmarshal([]byte(config.Config), &configs)
if err == nil {
req["model"] = configs["model"]
}
// Convert the request parameters to a JSON string
reqJSON, err := json.Marshal(req)
if err != nil {
Expand Down

0 comments on commit 1c360d8

Please sign in to comment.