Skip to content

Commit

Permalink
feat: add rerank models to the project infiniflow#724 infiniflow#162 (i…
Browse files Browse the repository at this point in the history
…nfiniflow#966)

### What problem does this PR solve?

Vector similarity weight is displayed incorrectly infiniflow#965
feat: add rerank models to the project infiniflow#724 infiniflow#162
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
  • Loading branch information
cike8899 authored May 29, 2024
1 parent d4d8c89 commit 0233335
Show file tree
Hide file tree
Showing 15 changed files with 132 additions and 25 deletions.
57 changes: 57 additions & 0 deletions web/src/components/rerank.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { LlmModelType } from '@/constants/knowledge';
import { useTranslate } from '@/hooks/commonHooks';
import { useSelectLlmOptionsByModelType } from '@/hooks/llmHooks';
import { Form, Select, Slider } from 'antd';

/** Shape of the form values contributed by the rerank controls in this file. */
type FieldType = {
// Value of the `rerank_id` form field, which is bound to the rerank model select below.
rerank_id?: string;
// Value of the `top_k` form field, which is bound to a slider ranging from 1 to 2048.
top_k?: number;
};

/**
 * Form item for choosing a rerank model.
 *
 * The selection is stored under the `rerank_id` form field. The available
 * options are the `LlmModelType.Rerank` entry of the map returned by
 * `useSelectLlmOptionsByModelType()`, and the select can be cleared.
 */
export const RerankItem = () => {
  const { t } = useTranslate('knowledgeDetails');
  const rerankOptions = useSelectLlmOptionsByModelType()[LlmModelType.Rerank];

  return (
    <Form.Item
      name={'rerank_id'}
      label={t('rerankModel')}
      tooltip={t('rerankTip')}
    >
      <Select
        allowClear
        placeholder={t('rerankPlaceholder')}
        options={rerankOptions}
      />
    </Form.Item>
  );
};

/**
 * Rerank settings group: the rerank model select plus a Top-K slider.
 *
 * The Top-K slider (field `top_k`, range 1-2048, initial value 1024) is only
 * rendered while the `rerank_id` field holds a truthy value. The wrapping
 * `Form.Item` lists `rerank_id` in `dependencies`, so the render-prop below is
 * re-evaluated when that field changes.
 */
const Rerank = () => {
  const { t } = useTranslate('knowledgeDetails');

  return (
    <>
      <RerankItem />
      <Form.Item noStyle dependencies={['rerank_id']}>
        {({ getFieldValue }) =>
          // `&&` short-circuits to the falsy field value when no model is chosen,
          // which results in nothing being rendered for the Top-K slider.
          getFieldValue('rerank_id') && (
            <Form.Item<FieldType>
              name={'top_k'}
              label={t('topK')}
              tooltip={t('topKTip')}
              initialValue={1024}
            >
              <Slider min={1} max={2048} />
            </Form.Item>
          )
        }
      </Form.Item>
    </>
  );
};

export default Rerank;
2 changes: 1 addition & 1 deletion web/src/components/similarity-slider/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const SimilaritySlider = ({ isTooltipShown = false }: IProps) => {
<Form.Item<FieldType>
label={t('vectorSimilarityWeight')}
name={'vector_similarity_weight'}
initialValue={0.3}
initialValue={1 - 0.3}
tooltip={isTooltipShown && t('vectorSimilarityWeightTip')}
>
<Slider max={1} step={0.01} />
Expand Down
1 change: 1 addition & 0 deletions web/src/constants/knowledge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export enum LlmModelType {
Chat = 'chat',
Image2text = 'image2text',
Speech2text = 'speech2text',
Rerank = 'rerank',
}

export enum KnowledgeSearchParams {
Expand Down
1 change: 1 addition & 0 deletions web/src/hooks/llmHooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export const useSelectLlmOptionsByModelType = () => {
[LlmModelType.Speech2text]: groupOptionsByModelType(
LlmModelType.Speech2text,
),
[LlmModelType.Rerank]: groupOptionsByModelType(LlmModelType.Rerank),
};
};

Expand Down
2 changes: 2 additions & 0 deletions web/src/interfaces/database/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ export interface IDialog {
tenant_id: string;
update_date: string;
update_time: number;
vector_similarity_weight: number;
similarity_threshold: number;
}

export interface IConversation {
Expand Down
17 changes: 11 additions & 6 deletions web/src/locales/en.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,15 +93,12 @@ export default {
progressMsg: 'Progress Msg',
testingDescription:
'Final step! After success, leave the rest to Infiniflow AI.',
topK: 'Top K',
topKTip:
"For the computaion cost, not all the retrieved chunk will be computed vector cosine similarity with query. The bigger the 'Top K' is, the higher the recall rate is, the slower the retrieval speed is.",
similarityThreshold: 'Similarity threshold',
similarityThresholdTip:
"We use hybrid similarity score to evaluate distance between two lines of text. It's weighted keywords similarity and vector cosine similarity. If the similarity between query and chunk is less than this threshold, the chunk will be filtered out.",
vectorSimilarityWeight: 'Vector similarity weight',
vectorSimilarityWeight: 'Keywords similarity weight',
vectorSimilarityWeightTip:
"We use hybrid similarity score to evaluate distance between two lines of text. It's weighted keywords similarity and vector cosine similarity. The sum of both weights is 1.0.",
" We use hybrid similarity score to evaluate distance between two lines of text. It's weighted keywords similarity and vector cosine similarity or rerank score(0~1). The sum of both weights is 1.0.",
testText: 'Test text',
testTextPlaceholder: 'Please input your question!',
testingLabel: 'Testing',
Expand Down Expand Up @@ -143,6 +140,11 @@ export default {
chunk: 'Chunk',
bulk: 'Bulk',
cancel: 'Cancel',
rerankModel: 'Rerank Model',
rerankPlaceholder: 'Please select',
rerankTip: `If it's empty, it uses embeddings of query and chunks to compute vector cosine similarity. Otherwise, it uses rerank score in place of vector cosine similarity.`,
topK: 'Top-K',
topKTip: `K chunks will be fed into rerank models.`,
},
knowledgeConfiguration: {
titleDescription:
Expand Down Expand Up @@ -465,6 +467,8 @@ The above is the content you need to summarize.`,
sequence2txtModel: 'Sequence2txt model',
sequence2txtModelTip:
'The default ASR model all the newly created knowledgebase will use. Use this model to translate voices to corresponding text.',
rerankModel: 'Rerank Model',
rerankModelTip: `The default rerank model is used to rerank chunks retrieved by users' questions.`,
workspace: 'Workspace',
upgrade: 'Upgrade',
addLlmTitle: 'Add LLM',
Expand All @@ -477,7 +481,8 @@ The above is the content you need to summarize.`,
baseUrlNameMessage: 'Please input your base url!',
vision: 'Does it support Vision?',
ollamaLink: 'How to integrate {{name}}',
volcModelNameMessage: 'Please input your model name! Format: {"ModelName":"EndpointID"}',
volcModelNameMessage:
'Please input your model name! Format: {"ModelName":"EndpointID"}',
addVolcEngineAK: 'VOLC ACCESS_KEY',
volcAKMessage: 'Please input your VOLC_ACCESS_KEY',
addVolcEngineSK: 'VOLC SECRET_KEY',
Expand Down
14 changes: 9 additions & 5 deletions web/src/locales/zh-traditional.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,12 @@ export default {
processDuration: '過程持續時間',
progressMsg: '進度消息',
testingDescription: '最後一步!成功後,剩下的就交給Infiniflow AI吧。',
topK: 'top k',
topKTip:
'對於計算成本,並非所有檢索到的塊都會計算與查詢的向量餘弦相似度。Top K越大,召回率越高,檢索速度越慢。',
similarityThreshold: '相似度閾值',
similarityThresholdTip:
'我們使用混合相似度得分來評估兩行文本之間的距離。它是加權關鍵詞相似度和向量餘弦相似度。如果查詢和塊之間的相似度小於此閾值,則該塊將被過濾掉。',
vectorSimilarityWeight: '向量相似度權重',
vectorSimilarityWeight: '關鍵字相似度權重',
vectorSimilarityWeightTip:
'我們使用混合相似度得分來評估兩行文本之間的距離。它是加權關鍵詞相似度和向量餘弦相似度。兩個權重之和為 1.0。',
'我們使用混合相似性評分來評估兩行文本之間的距離。它是加權關鍵字相似性和矢量餘弦相似性或rerank得分(0〜1)。兩個權重的總和為1.0。',
testText: '測試文本',
testTextPlaceholder: '請輸入您的問題!',
testingLabel: '測試',
Expand Down Expand Up @@ -139,6 +136,11 @@ export default {
chunk: '解析塊',
bulk: '批量',
cancel: '取消',
rerankModel: 'rerank模型',
rerankPlaceholder: '請選擇',
rerankTip: `如果是空的。它使用查詢和塊的嵌入來構成矢量餘弦相似性。否則,它使用rerank評分代替矢量餘弦相似性。`,
topK: 'Top-K',
topKTip: `K塊將被送入Rerank型號。`,
},
knowledgeConfiguration: {
titleDescription: '在這裡更新您的知識庫詳細信息,尤其是解析方法。',
Expand Down Expand Up @@ -429,6 +431,8 @@ export default {
sequence2txtModel: 'sequence2Txt模型',
sequence2txtModelTip:
'所有新創建的知識庫都將使用默認的 ASR 模型。使用此模型將語音翻譯為相應的文本。',
rerankModel: 'rerank模型',
rerankModelTip: `默認的重讀模型用於用戶問題檢索到重讀塊。`,
workspace: '工作空間',
upgrade: '升級',
addLlmTitle: '添加Llm',
Expand Down
14 changes: 9 additions & 5 deletions web/src/locales/zh.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,12 @@ export default {
processDuration: '过程持续时间',
progressMsg: '进度消息',
testingDescription: '最后一步! 成功后,剩下的就交给Infiniflow AI吧。',
topK: 'Top K',
topKTip:
'对于计算成本,并非所有检索到的块都会计算与查询的向量余弦相似度。 Top K越大,召回率越高,检索速度越慢。',
similarityThreshold: '相似度阈值',
similarityThresholdTip:
'我们使用混合相似度得分来评估两行文本之间的距离。 它是加权关键词相似度和向量余弦相似度。 如果查询和块之间的相似度小于此阈值,则该块将被过滤掉。',
vectorSimilarityWeight: '向量相似度权重',
vectorSimilarityWeight: '关键字相似度权重',
vectorSimilarityWeightTip:
'我们使用混合相似度得分来评估两行文本之间的距离。 它是加权关键词相似度和向量余弦相似度。 两个权重之和为 1.0。',
'我们使用混合相似性评分来评估两行文本之间的距离。它是加权关键字相似性和矢量余弦相似性或rerank得分(0〜1)。两个权重的总和为1.0。',
testText: '测试文本',
testTextPlaceholder: '请输入您的问题!',
testingLabel: '测试',
Expand Down Expand Up @@ -140,6 +137,11 @@ export default {
chunk: '解析块',
bulk: '批量',
cancel: '取消',
rerankModel: 'Rerank模型',
rerankPlaceholder: '请选择',
rerankTip: `如果是空的。它使用查询和块的嵌入来构成矢量余弦相似性。否则,它使用rerank评分代替矢量余弦相似性。`,
topK: 'Top-K',
topKTip: `K块将被送入Rerank型号。`,
},
knowledgeConfiguration: {
titleDescription: '在这里更新您的知识库详细信息,尤其是解析方法。',
Expand Down Expand Up @@ -446,6 +448,8 @@ export default {
sequence2txtModel: 'Sequence2txt模型',
sequence2txtModelTip:
'所有新创建的知识库都将使用默认的 ASR 模型。 使用此模型将语音翻译为相应的文本。',
rerankModel: 'Rerank模型',
rerankModelTip: `默认的重读模型用于用户问题检索到重读块。`,
workspace: '工作空间',
upgrade: '升级',
addLlmTitle: '添加 LLM',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ const KnowledgeTesting = () => {

const handleTesting = async () => {
const values = await form.validateFields();
testChunk(values);
testChunk({
...values,
vector_similarity_weight: 1 - values.vector_similarity_weight,
});
};

useEffect(() => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ import SimilaritySlider from '@/components/similarity-slider';
import { Button, Card, Divider, Flex, Form, Input } from 'antd';
import { FormInstance } from 'antd/lib';

import Rerank from '@/components/rerank';
import { useTranslate } from '@/hooks/commonHooks';
import { useFetchLlmList } from '@/hooks/llmHooks';
import { useOneNamespaceEffectsLoading } from '@/hooks/storeHooks';
import { useEffect } from 'react';
import styles from './index.less';

type FieldType = {
Expand All @@ -23,6 +26,11 @@ const TestingControl = ({ form, handleTesting }: IProps) => {
'testDocumentChunk',
]);
const { t } = useTranslate('knowledgeDetails');
const fetchLlmList = useFetchLlmList();

useEffect(() => {
fetchLlmList();
}, [fetchLlmList]);

const buttonDisabled =
!question || (typeof question === 'string' && question.trim() === '');
Expand All @@ -37,6 +45,7 @@ const TestingControl = ({ form, handleTesting }: IProps) => {
<section>
<Form name="testing" layout="vertical" form={form}>
<SimilaritySlider isTooltipShown></SimilaritySlider>
<Rerank></Rerank>
<Card size="small" title={t('testText')}>
<Form.Item<FieldType>
name={'question'}
Expand Down
11 changes: 11 additions & 0 deletions web/src/pages/chat/chat-configuration-modal/hooks.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { useFetchLlmList } from '@/hooks/llmHooks';
import {
useFetchTenantInfo,
useSelectTenantInfo,
Expand All @@ -16,3 +17,13 @@ export const useFetchModelId = (visible: boolean) => {

return tenantInfo?.llm_id ?? '';
};

/**
 * Calls the fetcher returned by `useFetchLlmList()` whenever `visible` is true.
 *
 * The effect re-runs when either `visible` or the fetcher identity changes,
 * and does nothing while `visible` is false.
 *
 * @param visible - Whether the owning UI (e.g. a modal) is currently shown.
 */
export const useFetchLlmModelOnVisible = (visible: boolean) => {
  const loadLlmList = useFetchLlmList();

  useEffect(() => {
    // Skip the request entirely while hidden.
    if (!visible) {
      return;
    }

    loadLlmList();
  }, [loadLlmList, visible]);
};
6 changes: 5 additions & 1 deletion web/src/pages/chat/chat-configuration-modal/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { variableEnabledFieldMap } from '../constants';
import { IPromptConfigParameters } from '../interface';
import { excludeUnEnabledVariables } from '../utils';
import AssistantSetting from './assistant-setting';
import { useFetchModelId } from './hooks';
import { useFetchLlmModelOnVisible, useFetchModelId } from './hooks';
import ModelSetting from './model-setting';
import PromptEngine from './prompt-engine';

Expand Down Expand Up @@ -92,6 +92,7 @@ const ChatConfigurationModal = ({
const finalValues = {
dialog_id: initialDialog.id,
...nextValues,
vector_similarity_weight: 1 - nextValues.vector_similarity_weight,
prompt_config: {
...nextValues.prompt_config,
parameters: promptEngineRef.current,
Expand All @@ -115,6 +116,8 @@ const ChatConfigurationModal = ({
form.resetFields();
};

useFetchLlmModelOnVisible(visible);

const title = (
<Flex gap={16}>
<ChatConfigurationAtom></ChatConfigurationAtom>
Expand Down Expand Up @@ -142,6 +145,7 @@ const ChatConfigurationModal = ({
settledModelVariableMap[ModelVariableType.Precise],
icon: fileList,
llm_id: initialDialog.llm_id ?? modelId,
vector_similarity_weight: 1 - initialDialog.vector_similarity_weight,
});
}
}, [initialDialog, form, visible, modelId]);
Expand Down
8 changes: 3 additions & 5 deletions web/src/pages/chat/chat-configuration-modal/model-setting.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { useEffect } from 'react';
import { ISegmentedContentProps } from '../interface';

import { useTranslate } from '@/hooks/commonHooks';
import { useFetchLlmList, useSelectLlmOptions } from '@/hooks/llmHooks';
import { useSelectLlmOptionsByModelType } from '@/hooks/llmHooks';
import { Variable } from '@/interfaces/database/chat';
import { variableEnabledFieldMap } from '../constants';
import styles from './index.less';
Expand All @@ -30,7 +30,7 @@ const ModelSetting = ({
value: x,
}));

const modelOptions = useSelectLlmOptions();
const modelOptions = useSelectLlmOptionsByModelType();

const handleParametersChange = (value: ModelVariableType) => {
const variable = settledModelVariableMap[value];
Expand All @@ -56,8 +56,6 @@ const ModelSetting = ({
}
}, [form, initialLlmSetting, visible]);

useFetchLlmList(LlmModelType.Chat);

return (
<section
className={classNames({
Expand All @@ -70,7 +68,7 @@ const ModelSetting = ({
tooltip={t('modelTip')}
rules={[{ required: true, message: t('modelMessage') }]}
>
<Select options={modelOptions} showSearch />
<Select options={modelOptions[LlmModelType.Chat]} showSearch />
</Form.Item>
<Divider></Divider>
<Form.Item
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import {
} from '../interface';
import { EditableCell, EditableRow } from './editable-cell';

import Rerank from '@/components/rerank';
import { useTranslate } from '@/hooks/commonHooks';
import { useSelectPromptConfigParameters } from '../hooks';
import styles from './index.less';
Expand Down Expand Up @@ -172,7 +173,7 @@ const PromptEngine = (
>
<Slider max={30} />
</Form.Item>

<Rerank></Rerank>
<section className={classNames(styles.variableContainer)}>
<Row align={'middle'} justify="end">
<Col span={7} className={styles.variableAlign}>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ const SystemModelSettingModal = ({
>
<Select options={allOptions[LlmModelType.Speech2text]} />
</Form.Item>
<Form.Item
label={t('rerankModel')}
name="rerank_id"
tooltip={t('rerankModelTip')}
>
<Select options={allOptions[LlmModelType.Rerank]} />
</Form.Item>
</Form>
</Modal>
);
Expand Down

0 comments on commit 0233335

Please sign in to comment.