Skip to content

Commit

Permalink
ES与OpenSearch数据源添加SQL语句支持-v0.6-beta (hhyo#2761)
Browse files Browse the repository at this point in the history
Co-authored-by: 王飞 <fei.wang@xgo.one>
  • Loading branch information
feiazifeiazi and 王飞 committed Oct 18, 2024
1 parent eaf8a62 commit 5c66142
Show file tree
Hide file tree
Showing 3 changed files with 296 additions and 87 deletions.
277 changes: 190 additions & 87 deletions sql/engines/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import traceback
from opensearchpy import OpenSearch
import simplejson as json
import sqlparse
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from common.config import SysConfig
Expand All @@ -19,24 +20,28 @@
class QueryParamsSearch:
def __init__(
self,
index: str,
path: str,
params: str,
method: str,
size: int,
index: str = None,
path: str = None,
params: str = None,
method: str = None,
size: int = 100,
sql: str = None,
query_body: dict = None,
):
self.index = index
self.path = path
self.index = index if index is not None else ""
self.path = path if path is not None else ""
self.method = method if method is not None else ""
self.params = params
self.method = method
self.size = size
# 确保 query_body 不为 None
self.sql = sql if sql is not None else ""
self.query_body = query_body if query_body is not None else {}


class ElasticsearchEngineBase(EngineBase):
"""Elasticsearch、OpenSearch等Search父类实现"""
"""
Elasticsearch、OpenSearch等Search父类实现
如果2者方法差异不大,可以在父类用if else实现。如果差异大,建议在子类实现。
"""

def __init__(self, instance=None):
self.db_separator = "__" # 设置分隔符
Expand Down Expand Up @@ -178,15 +183,25 @@ def describe_table(self, db_name, tb_name, **kwargs):

def query_check(self, db_name=None, sql=""):
"""语句检查"""
result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
# 使用正则表达式去除开头的空白字符和换行符
tripped_sql = re.sub(r"^\s+", "", sql)
result["filtered_sql"] = tripped_sql
lower_sql = tripped_sql.lower()
result = {
"msg": "语句检查通过。",
"bad_query": False,
"filtered_sql": sql,
"has_star": False,
}
sql = sql.rstrip(";").strip()
result["filtered_sql"] = sql
# 检查是否以 'get' 或 'select' 开头
if lower_sql.startswith("get ") or lower_sql.startswith("select "):
result["msg"] = "语句检查通过。"
result["bad_query"] = False
if re.match(r"^get", sql, re.I):
pass
elif re.match(r"^select", sql, re.I):
try:
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
result["filtered_sql"] = sql.strip()
except IndexError:
result["bad_query"] = True
result["msg"] = "没有有效的SQL语句。"
else:
result["msg"] = (
"语句检查失败:语句必须以 'get' 或 'select' 开头。示例查询:GET /dmp__iv/_search、select * from dmp__iv limit 10;"
Expand All @@ -195,8 +210,40 @@ def query_check(self, db_name=None, sql=""):
return result

def filter_sql(self, sql="", limit_num=0):
"""过滤 SQL 语句"""
return sql.strip()
"""过滤 SQL 语句。
对查询sql增加limit限制,limit n 或 limit n,n 或 limit n offset n统一改写成limit n
此方法SQL部分的逻辑copy的mysql实现。
"""
#
sql = sql.rstrip(";").strip()
if re.match(r"^get", sql, re.I):
pass
elif re.match(r"^select", sql, re.I):
# LIMIT N
limit_n = re.compile(r"limit\s+(\d+)\s*$", re.I)
# LIMIT M OFFSET N
limit_offset = re.compile(r"limit\s+(\d+)\s+offset\s+(\d+)\s*$", re.I)
# LIMIT M,N
offset_comma_limit = re.compile(r"limit\s+(\d+)\s*,\s*(\d+)\s*$", re.I)
if limit_n.search(sql):
sql_limit = limit_n.search(sql).group(1)
limit_num = min(int(limit_num), int(sql_limit))
sql = limit_n.sub(f"limit {limit_num};", sql)
elif limit_offset.search(sql):
sql_limit = limit_offset.search(sql).group(1)
sql_offset = limit_offset.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
sql = limit_offset.sub(f"limit {limit_num} offset {sql_offset};", sql)
elif offset_comma_limit.search(sql):
sql_offset = offset_comma_limit.search(sql).group(1)
sql_limit = offset_comma_limit.search(sql).group(2)
limit_num = min(int(limit_num), int(sql_limit))
sql = offset_comma_limit.sub(f"limit {sql_offset},{limit_num};", sql)
else:
sql = f"{sql} limit {limit_num};"
else:
sql = f"{sql};"
return sql

def query(
self,
Expand Down Expand Up @@ -236,6 +283,56 @@ def query(
result_set.column_list = []
result_set.rows = []
result_set.affected_rows = 0
elif query_params.sql and self.name == "Elasticsearch":
query_body = {"query": query_params.sql}
response = self.conn.sql.query(body=query_body)
# 提取列名和行数据
columns = response.get("columns", [])
rows = response.get("rows", [])
# 获取字段名作为列名
column_list = [col["name"] for col in columns]

# 处理查询结果,将列表和字典转换为 JSON 字符串。列名可能是重复的。
formatted_rows = []
for row in rows:
# 创建字典,将列名和对应的行值关联
formatted_row = []
for col_name, value in zip(column_list, row):
# 如果字段是列表或字典,将其转换为 JSON 字符串
if isinstance(value, (list, dict)):
formatted_row.append(json.dumps(value))
else:
formatted_row.append(value)
formatted_rows.append(formatted_row)
# 构建结果集
result_set.rows = formatted_rows
result_set.column_list = column_list
elif query_params.sql and self.name == "OpenSearch":
query_body = {"query": query_params.sql}
response = self.conn.transport.perform_request(
method="POST", url="/_opendistro/_sql", body=query_body
)
# 提取列名和行数据
columns = response.get("schema", [])
rows = response.get("datarows", [])
# 获取字段名作为列名
column_list = [col["name"] for col in columns]

# 处理查询结果,将列表和字典转换为 JSON 字符串。列名可能是重复的。
formatted_rows = []
for row in rows:
# 创建字典,将列名和对应的行值关联
formatted_row = []
for col_name, value in zip(column_list, row):
# 如果字段是列表或字典,将其转换为 JSON 字符串
if isinstance(value, (list, dict)):
formatted_row.append(json.dumps(value))
else:
formatted_row.append(value)
formatted_rows.append(formatted_row)
# 构建结果集
result_set.rows = formatted_rows
result_set.column_list = column_list
else:
# 执行搜索查询
response = self.conn.search(
Expand Down Expand Up @@ -297,76 +394,82 @@ def parse_es_select_query_to_query_params(
) -> QueryParamsSearch:
"""解析 search query 字符串为 QueryParamsSearch 对象"""

# 解析查询字符串
lines = search_query_str.splitlines()
method_line = lines[0].strip()
query_params = QueryParamsSearch()
sql = search_query_str.rstrip(";").strip()
if re.match(r"^get", sql, re.I):
# 解析查询字符串
lines = search_query_str.splitlines()
method_line = lines[0].strip()

query_body = "\n".join(lines[1:]).strip()
# 如果 query_body 为空,使用默认查询体
if not query_body:
query_body = json.dumps({"query": {"match_all": {}}})
query_body = "\n".join(lines[1:]).strip()
# 如果 query_body 为空,使用默认查询体
if not query_body:
query_body = json.dumps({"query": {"match_all": {}}})

# 确保 query_body 是有效的 JSON
try:
json_body = json.loads(query_body)
except json.JSONDecodeError as json_err:
raise ValueError(f"query_body:{query_body} 无法转为Json格式。{json_err},")

# 提取方法和路径
method, path_with_params = method_line.split(maxsplit=1)
# 确保路径以 '/' 开头
if not path_with_params.startswith("/"):
path_with_params = "/" + path_with_params

# 分离路径和查询参数
path, params_str = (
path_with_params.split("?", 1)
if "?" in path_with_params
else (path_with_params, "")
)
params = {}
if params_str:
for pair in params_str.split("&"):
if "=" in pair:
key, value = pair.split("=", 1)
else:
key = pair
value = ""
params[key] = value
index_pattern = ""
# 判断路径类型并提取索引模式
if path.startswith("/_cat/indices"):
# _cat API 路径
path_parts = path.split("/")
if len(path_parts) > 3:
index_pattern = path_parts[3]
if not index_pattern:
index_pattern = "*"
elif "/_search" in path:
# 默认情况,处理常规索引路径
# 提取索引名称
path_parts = path.split("/")
if len(path_parts) > 1:
index_pattern = path_parts[1]

if not index_pattern:
raise Exception("未找到索引名称。")

size = limit_num if limit_num > 0 else 100
# 检查 JSON 中是否已经有 size,如果没有就设置
if "size" not in json_body:
json_body["size"] = size

# 构建 QueryParams 对象
query_params = QueryParamsSearch(
index=index_pattern,
path=path_with_params,
params=params,
method=method,
size=size,
query_body=json_body,
)
# 确保 query_body 是有效的 JSON
try:
json_body = json.loads(query_body)
except json.JSONDecodeError as json_err:
raise ValueError(
f"query_body:{query_body} 无法转为Json格式。{json_err},"
)

# 提取方法和路径
method, path_with_params = method_line.split(maxsplit=1)
# 确保路径以 '/' 开头
if not path_with_params.startswith("/"):
path_with_params = "/" + path_with_params

# 分离路径和查询参数
path, params_str = (
path_with_params.split("?", 1)
if "?" in path_with_params
else (path_with_params, "")
)
params = {}
if params_str:
for pair in params_str.split("&"):
if "=" in pair:
key, value = pair.split("=", 1)
else:
key = pair
value = ""
params[key] = value
index_pattern = ""
# 判断路径类型并提取索引模式
if path.startswith("/_cat/indices"):
# _cat API 路径
path_parts = path.split("/")
if len(path_parts) > 3:
index_pattern = path_parts[3]
if not index_pattern:
index_pattern = "*"
elif "/_search" in path:
# 默认情况,处理常规索引路径
# 提取索引名称
path_parts = path.split("/")
if len(path_parts) > 1:
index_pattern = path_parts[1]

if not index_pattern:
raise Exception("未找到索引名称。")

size = limit_num if limit_num > 0 else 100
# 检查 JSON 中是否已经有 size,如果没有就设置
if "size" not in json_body:
json_body["size"] = size

# 构建 QueryParams 对象
query_params = QueryParamsSearch(
index=index_pattern,
path=path_with_params,
params=params,
method=method,
size=size,
query_body=json_body,
)
elif re.match(r"^select", sql, re.I):
query_params = QueryParamsSearch(sql=sql)
return query_params


Expand Down
Loading

0 comments on commit 5c66142

Please sign in to comment.