feat: Generate problem support for generating unfinished paragraphs #2299
```diff
@@ -1,5 +1,8 @@
-from typing import Dict
+from typing import Dict, List
+
+from langchain_core.messages import BaseMessage, get_buffer_string

 from common.config.tokenizer_manage_config import TokenizerManage
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI

@@ -18,3 +21,15 @@ def new_instance(model_type, model_name, model_credential: Dict[str, object], **
             stream_usage=True,
             **optional_params,
         )
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+        return self.usage_metadata.get('input_tokens', 0)
+
+    def get_num_tokens(self, text: str) -> int:
+        if self.usage_metadata is None or self.usage_metadata == {}:
+            tokenizer = TokenizerManage.get_tokenizer()
+            return len(tokenizer.encode(text))
+        return self.get_last_generation_info().get('output_tokens', 0)
```
Review comment:

The provided code is mostly clean and follows typical Python practices for a class that extends a base model with additional functionality related to message processing and token counting. However, there are a few improvements and optimizations that can be made. Here's an enhanced version of the code with these considerations:

```python
from typing import Dict, List

import langchain_core.messages as lc_messages

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


class CustomBaseChatOpenAI(BaseChatOpenAI):
    def __init__(self, model_type, model_name, model_credential: Dict[str, object],
                 max_total_tokens: int = 3072, **optional_params):
        # max_total_tokens is a parameter here: reading self.max_total_tokens
        # before super().__init__() runs would raise AttributeError.
        super().__init__(
            model_type=model_type,
            model_name=model_name,
            max_seq_len=None,
            max_total_tokens=max_total_tokens,
            max_context_window=3072,
            streaming=True,
            **optional_params,
        )

    # Note: LangChain itself calls get_num_tokens_from_messages / get_num_tokens;
    # if these renames are adopted, keep aliases with the original names.
    def count_tokens_in_messages(self, messages: List[lc_messages.BaseMessage]) -> int:
        """Count the total number of tokens in a list of chat messages."""
        if self.usage_metadata is None or self.usage_metadata == {}:
            # No provider-reported usage yet: fall back to a local tokenizer.
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(lc_messages.get_buffer_string([m]))) for m in messages)
        return self.usage_metadata.get('input_tokens', 0)

    def get_num_tokens_for_text(self, text: str) -> int:
        """Get the number of tokens required to encode a given input text."""
        if self.usage_metadata is None or self.usage_metadata == {}:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))
        return self.get_last_generation_info().get('output_tokens', 0)


# Example usage in another module
# instance = CustomBaseChatOpenAI(...)
```

Key improvements:

- Descriptive method names (`count_tokens_in_messages`, `get_num_tokens_for_text`) and docstrings on both token-counting methods.
- A single aliased import (`lc_messages`) that makes the origin of `BaseMessage` and `get_buffer_string` explicit.
- A generator expression inside `sum(...)` instead of building an intermediate list.

These changes make the code easier to understand and maintain, improving its overall quality. Additionally, they follow best practices recommended by PEP 8 and adhere to naming conventions commonly used in modern software development.
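For context on the fallback path in both versions: when `usage_metadata` is empty, tokens are counted locally by encoding each message's buffer string one at a time. A self-contained sketch with a stand-in tokenizer (since `TokenizerManage` is MaxKB-internal):

```python
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string


class FakeTokenizer:
    """Stand-in for TokenizerManage.get_tokenizer(): one token per whitespace-separated word."""

    def encode(self, text: str):
        return text.split()


def count_tokens(messages) -> int:
    tokenizer = FakeTokenizer()
    # get_buffer_string([m]) renders one message as e.g. "Human: Hello there"
    return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)


messages = [HumanMessage(content='Hello there'), AIMessage(content='Hi! How can I help?')]
print(count_tokens(messages))  # 9 with this fake word-level tokenizer
```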
```diff
@@ -48,6 +48,16 @@
         type="textarea"
       />
     </el-form-item>
+    <el-form-item :label="$t('views.problem.relateParagraph.selectParagraph')" prop="state">
+      <el-radio-group v-model="state" class="radio-block">
+        <el-radio value="error" size="large" class="mb-16">{{
+          $t('views.document.form.selectVectorization.error')
+        }}</el-radio>
+        <el-radio value="all" size="large">{{
+          $t('views.document.form.selectVectorization.all')
+        }}</el-radio>
+      </el-radio-group>
+    </el-form-item>
   </el-form>
 </div>
 <template #footer>

@@ -87,7 +97,11 @@ const dialogVisible = ref<boolean>(false)
 const modelOptions = ref<any>(null)
 const idList = ref<string[]>([])
 const apiType = ref('') // 'document' or 'paragraph'
+
+const state = ref<'all' | 'error'>('error')
+const stateMap = {
+  all: ['0', '1', '2', '3', '4', '5', 'n'],
+  error: ['0', '1', '3', '4', '5', 'n']
+}
 const FormRef = ref()
 const userId = user.userInfo?.id as string
 const form = ref(prompt.get(userId))

@@ -131,14 +145,22 @@ const submitHandle = async (formEl: FormInstance) => {
   // Save the prompt
   prompt.save(user.userInfo?.id as string, form.value)
   if (apiType.value === 'paragraph') {
-    const data = { ...form.value, paragraph_id_list: idList.value }
+    const data = {
+      ...form.value,
+      paragraph_id_list: idList.value,
+      state_list: stateMap[state.value]
+    }
     paragraphApi.batchGenerateRelated(id, documentId, data, loading).then(() => {
       MsgSuccess(t('views.document.generateQuestion.successMessage'))
       emit('refresh')
       dialogVisible.value = false
     })
   } else if (apiType.value === 'document') {
-    const data = { ...form.value, document_id_list: idList.value }
+    const data = {
+      ...form.value,
+      document_id_list: idList.value,
+      state_list: stateMap[state.value]
+    }
     documentApi.batchGenerateRelated(id, data, loading).then(() => {
       MsgSuccess(t('views.document.generateQuestion.successMessage'))
       emit('refresh')
```
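For readers skimming the diff: the new radio group feeds `stateMap`, which translates the user's choice into the list of paragraph states sent to the backend as `state_list`. The `error` option is simply `all` minus state `'2'` (presumably the "already generated successfully" state; the diff itself does not name the states). The same mapping, in plain Python for illustration:

```python
STATE_MAP = {
    'all': ['0', '1', '2', '3', '4', '5', 'n'],
    'error': ['0', '1', '3', '4', '5', 'n'],  # everything except state '2'
}

# The only difference between the two choices is state '2'
assert set(STATE_MAP['all']) - set(STATE_MAP['error']) == {'2'}
```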
Review comment:

Here are some suggestions to optimize and ensure correctness of the provided code:

1. Type the reactive values: give `state` and `stateMap` a shared type, and the request payload an explicit shape, instead of building untyped objects inline.
2. Validate and handle errors: wrap form validation and the API calls in `try`/`catch` so failures are logged and surfaced to the user rather than silently ignored.
3. Build the shared payload once; only the id-list key differs between the paragraph and document branches.

Here's the revised version based on these suggestions:

```typescript
import { ref } from 'vue'
// FormInstance and ElMessageBox both come from element-plus, which this form already uses
import { ElMessageBox, type FormInstance } from 'element-plus'
import { useUserStore } from '@/stores/user-store'
import type { Prompt } from '@/types/prompt'

// Placeholder for the full form shape; extend with the remaining fields as needed
interface DocumentForm extends Prompt {
  // other form fields...
}

const user = useUserStore()

const dialogVisible = ref<boolean>(false)
const modelOptions = ref<any>(null)
const idList = ref<string[]>([])
const apiType = ref('') // 'document' or 'paragraph'

// A string-valued enum keeps the radio-group binding (v-model="state") working
enum State {
  All = 'all',
  Error = 'error'
}

const state = ref<State>(State.Error)
const stateMap: Record<State, string[]> = {
  [State.All]: ['0', '1', '2', '3', '4', '5', 'n'],
  [State.Error]: ['0', '1', '3', '4', '5', 'n']
}

const FormRef = ref<FormInstance>()
const userId = user.userInfo?.id as string
const form = ref<DocumentForm>(prompt.get(userId))

async function submitHandle(formEl: FormInstance | undefined) {
  try {
    if (!formEl) throw new Error('Form element is not found')
    await formEl.validate()

    // Save the prompt
    prompt.save(user.userInfo?.id as string, form.value)

    // Shared payload; only the id-list key differs per API type
    const base: DocumentForm & { state_list: string[] } = {
      ...form.value,
      state_list: stateMap[state.value]
    }

    if (apiType.value === 'paragraph') {
      await paragraphApi.batchGenerateRelated(
        id,
        documentId,
        { ...base, paragraph_id_list: idList.value },
        loading
      )
      MsgSuccess(t('views.document.generateQuestion.successMessage'))
      emit('refresh')
      dialogVisible.value = false
    } else if (apiType.value === 'document') {
      await documentApi.batchGenerateRelated(
        id,
        { ...base, document_id_list: idList.value },
        loading
      )
      MsgSuccess(t('views.document.generateQuestion.successMessage'))
      emit('refresh')
    }
  } catch (err: any) {
    console.error(err)
    ElMessageBox.alert(`${err.message}`, t('tips'), {
      confirmButtonText: t('common.ok')
    }).catch(() => {})
  }
}
```

Key changes:

- `state` and `stateMap` share a string-valued `State` enum, so an invalid key is a compile-time error.
- Submission is wrapped in `try`/`catch`, with form validation up front and an alert dialog on failure.
- The payload is typed and assembled once, then extended with `paragraph_id_list` or `document_id_list` per branch.

These changes should make the code cleaner, more maintainable, and potentially handle edge cases better.
Review comment:

Your code looks generally well-structured and clean. However, there are a few recommendations for improvement:

1. **String literal quotes:** Consistently use either single (`'`) or double (`"`) quotes for string literals to avoid any potential issues.
2. **Error handling in the `generate_related` method:** The method should handle cases where `model_id` is missing gracefully.
3. **Comments and readability:** Add comments explaining what each section of the methods does, especially complex logic like conditional blocks and function calls.
4. **Import statements:** Ensure all import statements are included at the top of the file, making it easier to understand dependencies.
Here's an improved version of your code with these suggestions:
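As a sketch of the second point only, a graceful `model_id` guard might look like the following; the function shape, the `MissingModelError` type, and the `request_data` dict are illustrative stand-ins, not MaxKB's actual handler:

```python
from typing import Dict, List


class MissingModelError(ValueError):
    """Illustrative error type: raised when no model is supplied for generation."""


def generate_related(request_data: Dict[str, object]) -> List[str]:
    # Fail fast with a clear message instead of a late AttributeError downstream.
    model_id = request_data.get('model_id')
    if not model_id:
        raise MissingModelError('model_id is required to generate related problems')

    # Default to the "error" state list used by the frontend when none is sent.
    state_list = request_data.get('state_list', ['0', '1', '3', '4', '5', 'n'])

    # ... look up the model by model_id and generate questions for the
    # paragraphs whose state is in state_list ...
    return []
```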
Key improvements center on the `generate_related` method: validate `model_id` before use, keep string quoting consistent, and consolidate imports at the top of the file.

Let me know if you need further adjustments!