Skip to content

fix: 模型添加长字符的加密解密方式 #310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
from common.util.field_message import ErrMessage
from common.util.rsa_util import decrypt
from common.util.rsa_util import rsa_long_decrypt
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
from setting.models import Model, Status
Expand Down Expand Up @@ -225,7 +225,7 @@ def re_open_chat(chat_id: str):
# 对话模型
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
rsa_long_decrypt(model.credential)),
streaming=True)
# 数据集id列表
dataset_id_list = [str(row.dataset_id) for row in
Expand Down
8 changes: 5 additions & 3 deletions apps/application/serializers/chat_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
Expand Down Expand Up @@ -195,7 +195,8 @@ def open(self):
if model is not None:
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
rsa_long_decrypt(
model.credential)),
streaming=True)

chat_id = str(uuid.uuid1())
Expand Down Expand Up @@ -252,7 +253,8 @@ def open(self):
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
rsa_long_decrypt(
model.credential)),
streaming=True)
else:
model = None
Expand Down
52 changes: 52 additions & 0 deletions apps/common/util/rsa_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,55 @@ def decrypt(msg, pri_key: str | None = None):
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8")


def rsa_long_encrypt(message, public_key: str | None = None, length=200):
"""
超长文本加密

:param message: 需要加密的字符串
:param public_key 公钥
:param length: 1024bit的证书用100, 2048bit的证书用 200
:return: 加密后的数据
"""

# 读取公钥
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
passphrase=secret_code))
# 处理:Plaintext is too long. 分段加密
if len(message) <= length:
# 对编码的数据进行加密,并通过base64进行编码
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
else:
rsa_text = []
# 对编码后的数据进行切片,原因:加密长度不能过长
for i in range(0, len(message), length):
cont = message[i:i + length]
# 对切片后的数据进行加密,并新增到text后面
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
# 加密完进行拼接
cipher_text = b''.join(rsa_text)
# base64进行编码
result = base64.b64encode(cipher_text)
return result.decode()


def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
"""
超长文本解密,默认不加密
:param message: 需要解密的数据
:param pri_key: 秘钥
:param length : 1024bit的证书用128,2048bit证书用256位
:return: 解密后的数据
"""

if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
base64_de = base64.b64decode(message)
res = []
for i in range(0, len(base64_de), length):
res.append(cipher.decrypt(base64_de[i:i + length], 0))
return b"".join(res).decode()
10 changes: 5 additions & 5 deletions apps/setting/serializers/provider_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from application.models import Application
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage
from common.util.rsa_util import encrypt, decrypt
from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
from setting.models.model_management import Model, Status
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
Expand Down Expand Up @@ -118,7 +118,7 @@ def is_valid(self, model=None, raise_exception=False):

model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
model_name)
source_model_credential = json.loads(decrypt(model.credential))
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
if credential is not None:
for k in source_encryption_model_credential.keys():
Expand Down Expand Up @@ -170,7 +170,7 @@ def insert(self, user_id, with_valid=False):
model_name = self.data.get('model_name')
model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=encrypt(model_credential_str),
credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name)
model.save()
if status == Status.DOWNLOAD:
Expand All @@ -180,7 +180,7 @@ def insert(self, user_id, with_valid=False):

@staticmethod
def model_to_dict(model: Model):
credential = json.loads(decrypt(model.credential))
credential = json.loads(rsa_long_decrypt(model.credential))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
Expand Down Expand Up @@ -252,7 +252,7 @@ def edit(self, instance: Dict, user_id: str, with_valid=True):
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':
model_credential_str = json.dumps(credential)
model.__setattr__(update_key, encrypt(model_credential_str))
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
else:
model.__setattr__(update_key, instance.get(update_key))
model.save()
Expand Down