Skip to content

Pr@main@show chat source #250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 25, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-25 11:28

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('application', '0003_application_icon'),
]

operations = [
migrations.AddField(
model_name='applicationaccesstoken',
name='show_source',
field=models.BooleanField(default=False, verbose_name='是否显示知识来源'),
),
]
1 change: 1 addition & 0 deletions apps/application/models/api_key_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ApplicationAccessToken(AppModelMixin):
white_list = ArrayField(verbose_name="白名单列表",
base_field=models.CharField(max_length=128, blank=True)
, default=list)
show_source = models.BooleanField(default=False, verbose_name="是否显示知识来源")

class Meta:
db_table = "application_access_token"
Expand Down
17 changes: 13 additions & 4 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
Expand Down Expand Up @@ -170,7 +170,9 @@ class AccessTokenEditSerializer(serializers.Serializer):
white_list = serializers.ListSerializer(required=False, child=serializers.CharField(required=True,
error_messages=ErrMessage.char(
"白名单")),
error_messages=ErrMessage.list("白名单列表"))
error_messages=ErrMessage.list("白名单列表")),
show_source = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("是否显示知识来源"))

def edit(self, instance: Dict, with_valid=True):
if with_valid:
Expand All @@ -190,6 +192,8 @@ def edit(self, instance: Dict, with_valid=True):
application_access_token.white_active = instance.get("white_active")
if 'white_list' in instance and instance.get('white_list') is not None:
application_access_token.white_list = instance.get('white_list')
if 'show_source' in instance and instance.get('show_source') is not None:
application_access_token.show_source = instance.get('show_source')
application_access_token.save()
return self.one(with_valid=False)

Expand All @@ -210,7 +214,8 @@ def one(self, with_valid=True):
"is_active": application_access_token.is_active,
'access_num': application_access_token.access_num,
'white_active': application_access_token.white_active,
'white_list': application_access_token.white_list
'white_list': application_access_token.white_list,
'show_source': application_access_token.show_source
}

class Authentication(serializers.Serializer):
Expand Down Expand Up @@ -474,8 +479,12 @@ def profile(self, with_valid=True):
self.is_valid()
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application.id).first()
if application_access_token is None:
raise AppUnauthorizedFailed(500, "非法用户")
return ApplicationSerializer.Query.reset_application(
ApplicationSerializer.ApplicationModel(application).data)
{**ApplicationSerializer.ApplicationModel(application).data,
'show_source': application_access_token.show_source})

def edit(self, instance: Dict, with_valid=True):
if with_valid:
Expand Down
21 changes: 16 additions & 5 deletions apps/application/serializers/chat_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from rest_framework import serializers

from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
from application.models.api_key_model import ApplicationAccessToken
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
Expand Down Expand Up @@ -277,17 +278,27 @@ class Meta:
class ChatRecordSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=self.data.get('application_id')).first()
if application_access_token is None:
raise AppApiException(500, '不存在的应用认证信息')
if not application_access_token.show_source:
raise AppApiException(500, '未开启显示知识来源')

def get_chat_record(self):
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
if chat_info is not None:
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()

def one(self, with_valid=True):
Expand Down
2 changes: 2 additions & 0 deletions apps/application/swagger_api/application_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def get_request_body_api():
'white_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING), title="白名单列表",
description="白名单列表"),
'show_source': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否显示知识来源",
description="是否显示知识来源"),
}
)

Expand Down
3 changes: 2 additions & 1 deletion apps/application/views/chat_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,8 @@ class Operate(APIView):
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
RoleConstants.APPLICATION_ACCESS_TOKEN],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
Expand Down
25 changes: 15 additions & 10 deletions ui/src/components/ai-chat/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
<img src="@/assets/icon_robot.svg" style="width: 75%" alt="" />
</AppAvatar>
</div>

<div class="content">
<div class="flex" v-if="!item.answer_text">
<el-card
Expand All @@ -72,7 +73,9 @@

<el-card v-else shadow="always" class="dialog-card">
<MdRenderer :source="item.answer_text"></MdRenderer>
<div v-if="(id && !props.appId && item.write_ed) || log">
<div
v-if="(id && item.write_ed) || (props.data?.show_source && item.write_ed) || log"
>
<el-divider> <el-text type="info">知识来源</el-text> </el-divider>
<div class="mb-8">
<el-space wrap>
Expand Down Expand Up @@ -271,7 +274,7 @@ function openParagraph(row: any, id?: string) {
}

function quickProblemHandle(val: string) {
if (!props.log && !loading.value) {
if (!props.log && !loading.value && props.data?.name && props.data?.model_id) {
// inputValue.value = val
// nextTick(() => {
// quickInputRef.value?.focus()
Expand Down Expand Up @@ -488,7 +491,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean) {
}
})
.then(() => {
return !props.appId && getSourceDetail(chat)
return (id || props.data?.show_source) && getSourceDetail(chat)
})
.finally(() => {
ChatManagement.close(chat.id)
Expand All @@ -505,14 +508,16 @@ function regenerationChart(item: chatType) {
}

function getSourceDetail(row: any) {
logApi.getRecordDetail(id, chartOpenId.value, row.record_id, loading).then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
logApi
.getRecordDetail(id || props.appId, chartOpenId.value, row.record_id, loading)
.then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
})
})
})
return true
}

Expand Down
7 changes: 7 additions & 0 deletions ui/src/views/application-overview/component/LimitDialog.vue
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
<template>
<el-dialog title="访问限制" v-model="dialogVisible">
<el-form label-position="top" ref="limitFormRef" :model="form">
<el-form-item label="显示知识来源" @click.prevent>
<el-switch size="small" v-model="form.show_source"></el-switch>
</el-form-item>
<el-form-item label="客户端提问限制">
<el-input-number
v-model="form.access_num"
Expand Down Expand Up @@ -51,6 +54,7 @@ const emit = defineEmits(['refresh'])

const limitFormRef = ref()
const form = ref<any>({
show_source: false,
access_num: 0,
white_active: true,
white_list: ''
Expand All @@ -62,6 +66,7 @@ const loading = ref(false)
watch(dialogVisible, (bool) => {
if (!bool) {
form.value = {
show_source: false,
access_num: 0,
white_active: true,
white_list: ''
Expand All @@ -70,6 +75,7 @@ watch(dialogVisible, (bool) => {
})

const open = (data: any) => {
form.value.show_source = data.show_source
form.value.access_num = data.access_num
form.value.white_active = data.white_active
form.value.white_list = data.white_list?.length ? data.white_list?.join('\n') : ''
Expand All @@ -81,6 +87,7 @@ const submit = async (formEl: FormInstance | undefined) => {
await formEl.validate((valid, fields) => {
if (valid) {
const obj = {
show_source: form.value.show_source,
white_list: form.value.white_list ? form.value.white_list.split('\n') : [],
white_active: form.value.white_active,
access_num: form.value.access_num
Expand Down