fix: ai import bug #710

Merged: 1 commit, Dec 13, 2023
20 changes: 5 additions & 15 deletions app/pages/Setting/index.tsx
@@ -1,4 +1,4 @@
-import { useCallback, useEffect, useState } from 'react';
+import { useCallback, useEffect } from 'react';
import { observer } from 'mobx-react-lite';
import { Button, Col, Form, Input, InputNumber, Row, Select, Switch, message } from 'antd';
import { useI18n } from '@vesoft-inc/i18n';
@@ -15,15 +15,13 @@ const Setting = observer(() => {
const { global, llm } = useStore();
const { appSetting, saveAppSetting } = global;
const [form] = useForm();
-const [apiType, setApiType] = useState('openai');
useEffect(() => {
initForm();
}, []);

const initForm = async () => {
await llm.fetchConfig();
form.setFieldsValue(llm.config);
-setApiType(llm.config.apiType);
};

const updateAppSetting = useCallback(async (param: Partial<any['beta']>) => {
@@ -109,13 +107,7 @@ const Setting = observer(() => {
<div className={styles.tips}>{intl.get('setting.llmImportDesc')}</div>
<Form form={form} layout="vertical" style={{ marginTop: 20 }}>
<Form.Item label="API type" name="apiType" required={true}>
-<Select
-onChange={(value) => {
-setApiType(value);
-}}
-defaultValue="openai"
-style={{ width: 120 }}
->
+<Select defaultValue="openai" style={{ width: 120 }}>
<Select.Option value="openai">OpenAI</Select.Option>
<Select.Option value="qwen">Aliyun</Select.Option>
</Select>
@@ -126,11 +118,9 @@
<Form.Item label="key" name="key">
<Input type="password" />
</Form.Item>
-{apiType === 'qwen' && (
-<Form.Item label="model" name="model" required={true}>
-<Input placeholder="qwen-max" />
-</Form.Item>
-)}
+<Form.Item label="model" name="model">
+<Input />
+</Form.Item>
<Form.Item label={intl.get('setting.maxTextLength')} name="maxContextLength" required={true}>
<InputNumber min={0} />
</Form.Item>
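Note on the Setting page change: the component no longer mirrors the provider into local React state, and the model field is rendered for every API type, so the antd form instance is the single source of truth. A minimal sketch, not part of this PR, of how a provider-dependent field could still be derived from form state via antd's shouldUpdate render prop if it were ever needed again:

```tsx
import { Form, Input } from 'antd';

// Hypothetical sketch (field names follow the form above): reading apiType straight
// from the form instance replaces the removed useState-based tracking.
export const QwenModelField = () => (
  <Form.Item noStyle shouldUpdate={(prev, cur) => prev.apiType !== cur.apiType}>
    {({ getFieldValue }) =>
      getFieldValue('apiType') === 'qwen' ? (
        <Form.Item label="model" name="model">
          <Input placeholder="qwen-max" />
        </Form.Item>
      ) : null
    }
  </Form.Item>
);
```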
41 changes: 21 additions & 20 deletions app/stores/llm.ts
@@ -6,7 +6,9 @@ import * as ngqlDoc from '@app/utils/ngql';
import schema from './schema';
import rootStore from '.';

-export const matchPrompt = `Use NebulaGraph match knowledge to help me answer question.
+export const matchPrompt = `I want you to be a NebulaGraph database asistant.
+There are below document.
+----
Use only the provided relationship types and properties in the schema.
Do not use any other relationship types or properties that are not provided.
Schema:
@@ -25,6 +27,7 @@ diff
---
> MATCH (p:person)-[:directed]->(m:movie) WHERE m.movie.name == 'The Godfather'
> RETURN p.person.name;
+---
Question:{query_str}
`;
export const llmImportPrompt = `As a knowledge graph AI importer, your task is to extract useful data from the following text:
@@ -46,14 +49,12 @@ The name of the nodes should be an actual object and a noun.
Result:
`;

-export const docFinderPrompt = `Assume your are doc finder,from the following graph database book categories:
+export const docFinderPrompt = `Assuming you are a document navigator, within the following categories related to graph database books:
"{category_string}"
-user current space is: {space_name}
-find top two useful categories to solve the question:"{query_str}",
-don't explain, if you can't find, return "Sorry".
-just return the two combined categories, separated by ',' is:`;
+please identify the most two relevant categories that could address the question: "{query_str}".,
+Please just return the two categories as a comma-separated list without any other word`;

-export const text2queryPrompt = `Assume you are a NebulaGraph AI chat asistant to help user write NGQL.
+export const text2queryPrompt = `Assume you are a NebulaGraph database AI chat asistant to help user write NGQL with NebulaGraph.
You have access to the following information:
the user space schema is:
----
@@ -227,10 +228,8 @@ class LLM {
console.log(prompt);
await ws.runChat({
req: {
-temperature: 0.5,
stream: true,
max_tokens: 20,

messages: [
...historyMessages,
{
@@ -274,18 +273,19 @@ class LLM {
let prompt = matchPrompt; // default use text2cypher
if (this.mode !== 'text2cypher') {
text = text.replaceAll('"', "'");
+const docPrompt = docFinderPrompt
+.replace('{category_string}', ngqlDoc.NGQLCategoryString)
+.replace('{query_str}', text)
+.replace('{space_name}', rootStore.console.currentSpace);
+console.log(docPrompt);
const res = (await ws.runChat({
req: {
-temperature: 0.5,
stream: false,
max_tokens: 20,
messages: [
{
role: 'user',
-content: docFinderPrompt
-.replace('{category_string}', ngqlDoc.NGQLCategoryString)
-.replace('{query_str}', text)
-.replace('{space_name}', rootStore.console.currentSpace),
+content: docPrompt,
},
],
},
@@ -297,19 +297,19 @@
.replaceAll(/\s|"|\\/g, '')
.split(',');
console.log('select doc url:', paths);
-if (ngqlDoc.ngqlMap[paths[0]]) {
-let doc = ngqlDoc.ngqlMap[paths[0]].content;
+if (paths[0] !== 'sorry') {
+let doc = ngqlDoc.ngqlMap[paths[0]]?.content;
if (!doc) {
doc = '';
}
-const doc2 = ngqlDoc.ngqlMap[paths[1]].content;
+const doc2 = ngqlDoc.ngqlMap[paths[1]]?.content;
if (doc2) {
-doc += doc2;
+doc =
+doc.slice(0, this.config.maxContextLength / 2) + `\n` + doc2.slice(0, this.config.maxContextLength / 2);
}
doc = doc.replaceAll(/\n\n+/g, '');
if (doc.length) {
console.log('docString:', doc);
-prompt = text2queryPrompt.replace('{doc}', doc.slice(0, this.config.maxContextLength));
+prompt = text2queryPrompt.replace('{doc}', doc);
}
}
}
@@ -327,6 +327,7 @@ class LLM {
schemaPrompt += `\nuser console ngql context: ${rootStore.console.currentGQL}`;
}
prompt = prompt.replace('{schema}', schemaPrompt);
+console.log(prompt);
return prompt;
}

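The reworked getDocPrompt flow above builds the doc-finder prompt once (docPrompt), tolerates missing ngqlMap entries via optional chaining, and budgets the retrieved documentation against maxContextLength. A standalone restatement of the truncation rule (helper name and signature are illustrative, not from the diff):

```ts
// When two documents are matched, each contributes at most half of the configured
// context budget; with a single document the content is passed through as-is,
// mirroring the removal of the doc.slice(0, maxContextLength) call above.
function combineDocs(doc1: string | undefined, doc2: string | undefined, maxContextLength: number): string {
  let doc = doc1 ?? '';
  if (doc2) {
    doc = doc.slice(0, maxContextLength / 2) + '\n' + doc2.slice(0, maxContextLength / 2);
  }
  return doc;
}
```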
5 changes: 4 additions & 1 deletion app/utils/ngql.ts
@@ -14,7 +14,10 @@ export const ngqlDoc = (ngqlJson as { url: string; content: string; title: strin
if (urlTransformerMap[item.title]) {
item.title = urlTransformerMap[item.title];
}
-item.title = item.title.replaceAll(' ', '');
+item.title = item.title
+.split(' ')
+.map((word) => word[0].toUpperCase() + word.slice(1))
+.join('');
item.content = item.content.replace(/nebula>/g, '');

return item;
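The title normalization above now removes spaces and capitalizes each word, so the category names returned by the doc finder can be matched against ngqlMap keys; previously only the spaces were removed and the original casing was kept. An illustrative example (sample title only; real titles come from ngql.json):

```ts
// 'match clause' -> 'MatchClause'
const title = 'match clause';
const key = title
  .split(' ')
  .map((word) => word[0].toUpperCase() + word.slice(1))
  .join(''); // 'MatchClause'
```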
14 changes: 11 additions & 3 deletions server/api/studio/pkg/llm/importjob.go
@@ -475,12 +475,16 @@ func (i *ImportJob) MakeGQLFile(filePath string) ([]string, error) {
if valueStr != "" {
valueStr += ","
}
-valueStr += fmt.Sprintf(`"%v"`, value)
+if strings.Contains(strings.ToLower(field.DataType), "string") {
+valueStr += fmt.Sprintf(`"%v"`, value)
+} else {
+valueStr += fmt.Sprintf(`%v`, value)
+}
}

gql := fmt.Sprintf("INSERT VERTEX `%s` ({props}) VALUES \"%s\":({value});", typ, name)
gql = strings.ReplaceAll(gql, "{props}", propsStr)
-gql = strings.ReplaceAll(gql, "{value}", propsStr)
+gql = strings.ReplaceAll(gql, "{value}", valueStr)
gqls = append(gqls, gql)
}

@@ -508,7 +512,11 @@ func (i *ImportJob) MakeGQLFile(filePath string) ([]string, error) {
if propsValue != "" {
propsValue += ","
}
-propsValue += fmt.Sprintf("\"%v\"", value)
+if strings.Contains(strings.ToLower(field.DataType), "string") {
+propsValue += fmt.Sprintf(`"%v"`, value)
+} else {
+propsValue += fmt.Sprintf(`%v`, value)
+}
}
gql := fmt.Sprintf("INSERT EDGE `%s` (%s) VALUES \"%s\"->\"%s\":(%s);", dst.EdgeType, propsName, dst.Src, dst.Dst, propsValue)
gqls = append(gqls, gql)
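The importjob.go change does two things: property values are quoted only when the schema type is a string type, and the INSERT template's {value} placeholder is now filled with valueStr instead of propsStr (the property-name list), which is the core of the import bug. A hedged TypeScript restatement of the quoting rule (helper name is illustrative; the real implementation is the Go code above):

```ts
// Only string-typed properties are wrapped in double quotes; numerics stay bare,
// e.g. ("Tom", 29) instead of ("Tom", "29").
function formatPropValue(dataType: string, value: unknown): string {
  return dataType.toLowerCase().includes('string') ? `"${value}"` : `${value}`;
}

// Resulting statement shape for a vertex with (name: string, age: int); the
// property-name list comes from propsStr, which is built outside these hunks:
// INSERT VERTEX `person` (name, age) VALUES "Tom":("Tom", 29);
```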
5 changes: 5 additions & 0 deletions server/api/studio/pkg/llm/transformer/openai.go
@@ -20,6 +20,11 @@ type OpenAI struct {
}

func (o *OpenAI) HandleRequest(req map[string]any, config *db.LLMConfig) (*http.Request, error) {
+configs := make(map[string]any)
+err := json.Unmarshal([]byte(config.Config), &configs)
+if err == nil {
+req["model"] = configs["model"]
+}
// Convert the request parameters to a JSON string
reqJSON, err := json.Marshal(req)
if err != nil {
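The OpenAI transformer now reads the model out of the stored LLM config (a JSON string) and copies it into the outgoing request before serialization, so the model chosen on the Setting page is what actually gets sent. A hedged TypeScript sketch of the effect (function and variable names are illustrative; the real implementation is the Go code above):

```ts
// configJson is the provider settings saved from the Setting form as a JSON string,
// e.g. '{"model":"qwen-max"}' (example value). When it parses, its model is copied
// onto the request body before the body is serialized and sent to the provider.
function applyConfiguredModel(req: Record<string, unknown>, configJson: string): string {
  try {
    const configs = JSON.parse(configJson) as { model?: string };
    req.model = configs.model; // mirrors req["model"] = configs["model"] in the Go code
  } catch {
    // malformed config is ignored, as the Go code does when json.Unmarshal fails
  }
  return JSON.stringify(req); // serialized request sent to the LLM endpoint
}
```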