Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

在线查询 支持AI根据描述生成查询语句 #2726

Merged
merged 7 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions archery/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@
# https://docs.djangoproject.com/en/4.0/ref/settings/#csrf-trusted-origins
CSRF_TRUSTED_ORIGINS = env("CSRF_TRUSTED_ORIGINS")

# 用于请求 OpenAI
OPENAI_API_KEY = env("OPENAI_API_KEY")
OPENAI_BASE_URL = env("OPENAI_BASE_URL")
DEFAULT_CHAT_MODEL = env("DEFAULT_CHAT_MODEL")
QSummerY marked this conversation as resolved.
Show resolved Hide resolved

# 解决nginx部署跳转404
USE_X_FORWARDED_HOST = True

Expand Down
26 changes: 26 additions & 0 deletions common/utils/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from openai import OpenAI
from archery import settings
import logging


logger = logging.getLogger("default")
openai_client = OpenAI(base_url=settings.OPENAI_BASE_URL, api_key=settings.OPENAI_API_KEY)


def request_chat_completion(messages, model=settings.DEFAULT_CHAT_MODEL, **kwargs):
"""openai_client """
completion = openai_client.chat.completions.create(
model=model, messages=messages, **kwargs
)
return completion


def generate_sql_by_openai(db_type: str, table_schema: str, query_desc: str):
tips = f'你是一个熟悉 {db_type} 的工程师, 我会给你一些基本信息和要求, 你会生成一个查询语句给我使用, 不要返回任何注释和序号, 仅返回查询语句'
messages = [dict(role='user', content=f"{tips}: {table_schema}\n{query_desc}")]
logger.info(messages)
try:
res = request_chat_completion(messages)
return res.choices[0].message.content
except Exception as e:
raise ValueError(f"请求openai生成查询语句失败: {e}")
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,5 @@ mozilla-django-oidc==3.0.0
django-auth-dingding==0.0.3
django-cas-ng==4.3.0
cassandra-driver
httpx
OpenAI
39 changes: 39 additions & 0 deletions sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from django.http import HttpResponse
from common.config import SysConfig
from common.utils.extend_json_encoder import ExtendJSONEncoder, ExtendJSONEncoderFTime
from common.utils.openai import generate_sql_by_openai
from common.utils.timer import FuncTimer
from sql.query_privileges import query_priv_check
from sql.utils.resource_group import user_instances
Expand Down Expand Up @@ -313,3 +314,41 @@ def kill_query_conn(instance_id, thread_id):
instance = Instance.objects.get(pk=instance_id)
query_engine = get_engine(instance)
query_engine.kill_connection(thread_id)


@permission_required("sql.menu_sqlquery", raise_exception=True)
def generate_sql(request):
"""
利用AI生成查询SQL, 传入数据基本结构和查询描述
:param request:
:return:
"""
db_type = request.POST.get("db_type")
query_desc = request.POST.get("query_desc")
if not db_type or not query_desc:
return HttpResponse(json.dumps({"status": 1, "msg": "db_type or query_desc不存在", "data": []}), content_type="application/json")

instance_name = request.POST.get("instance_name")
try:
instance = Instance.objects.get(instance_name=instance_name)
except Instance.DoesNotExist:
return HttpResponse(json.dumps({"status": 1, "msg": "实例不存在", "data": []}), content_type="application/json")
db_name = request.POST.get("db_name")
schema_name = request.POST.get("schema_name")
tb_name = request.POST.get("tb_name")

result = {"status": 0, "msg": "ok", "data": ""}
try:
query_engine = get_engine(instance=instance)
query_result = query_engine.describe_table(
db_name, tb_name, schema_name=schema_name
)
# 有些不存在表结构, 例如 redis
if len(query_result.rows) != 0:
result["data"] = generate_sql_by_openai(db_type, query_result.rows[0][-1], query_desc)
else:
result["data"] = generate_sql_by_openai(db_type, "", query_desc)
except Exception as msg:
result["status"] = 1
result["msg"] = str(msg)
return HttpResponse(json.dumps(result), content_type="application/json")
39 changes: 39 additions & 0 deletions sql/templates/sqlquery.html
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ <h4 class="modal-title text-danger">收藏语句</h4>
<option value={{ sql.id }}>{{ sql.alias }}</option>
{% endfor %}
</select>
<input class="form-control" id="generateDesc" placeholder="AI 查询描述"></input>
<input id="btn-generatesql" type="button" class="btn btn-info" value="生成SQL"/>
</div>
<div class="panel-body">
<form id="form-sqlquery" action="/sqlquery/" method="post" class="form-horizontal" role="form">
Expand Down Expand Up @@ -624,6 +626,12 @@ <h4 class="modal-title text-danger">收藏语句</h4>
return result;
}

//提交AI生成sql语句请求
$("#btn-generatesql").click(function () {
generatesql();
}
);

//先做表单验证,验证成功再成功提交查询请求
$("#btn-sqlquery").click(function () {
dosqlquery();
Expand Down Expand Up @@ -1023,6 +1031,37 @@ <h4 class="modal-title text-danger">收藏语句</h4>
});
}

function generatesql() {
var optgroup = $('#instance_name :selected').parent().attr('label');
const data = {
db_type: optgroup,
instance_name: $("#instance_name").val(),
db_name: $("#db_name").val(),
schema_name: $("#schema_name").val(),
tb_name: $("#table_name").val(),
query_desc: $("#generateDesc").val(),
}
//提交请求
$.ajax({
type: "post",
url: "/query/generate_sql/",
dataType: "json",
data: data,
complete: function () {
$('input[type=button]').removeClass('disabled');
$('input[type=button]').prop('disabled', false);
optgroup_control();
},
success: function (data) {
editor.setValue(data["data"]);
editor.clearSelection();
},
error: function (XMLHttpRequest, textStatus, errorThrown) {
alert(errorThrown);
}
});
}

function dosqlquery() {
if (sqlquery_validate()) {
$('input[type=button]').addClass('disabled');
Expand Down
1 change: 1 addition & 0 deletions sql/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@
path("query/querylog/", query.querylog),
path("query/querylog_audit/", query.querylog_audit),
path("query/favorite/", query.favorite),
path("query/generate_sql/", query.generate_sql),
path("query/explain/", sql.sql_optimize.explain),
path("query/applylist/", sql.query_privileges.query_priv_apply_list),
path("query/userprivileges/", sql.query_privileges.user_query_priv),
Expand Down
Loading