Skip to content

Commit

Permalink
mongo engine优化 (#2018)
Browse files Browse the repository at this point in the history
* 优化mongo的method取值逻辑

* 支持显示DML影响行数

* 修复执行报错仍显示正常结束的bug

* 补充单元测试
  • Loading branch information
nick2wang authored Jan 14, 2023
1 parent eba3e33 commit c0f41a4
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 9 deletions.
108 changes: 99 additions & 9 deletions sql/engines/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from pymongo.errors import OperationFailure
from dateutil.parser import parse
from bson.objectid import ObjectId
from datetime import datetime

from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
Expand Down Expand Up @@ -429,7 +428,7 @@ def execute(self, db_name=None, sql=""):
line += 1
logger.debug("执行结果:" + r)
# 如果执行中有错误
rz = r.replace(" ", "").replace('"', "").lower()
rz = r.replace(" ", "").replace('"', "")
tr = 1
if (
r.lower().find("syntaxerror") >= 0
Expand All @@ -438,7 +437,7 @@ def execute(self, db_name=None, sql=""):
or rz.find("ReferenceError") >= 0
or rz.find("getErrorWithCode") >= 0
or rz.find("failedtoconnect") >= 0
or rz.find("Error: field") >= 0
or rz.find("Error:") >= 0
):
tr = 0
if (rz.find("errmsg") >= 0 or tr == 0) and (
Expand All @@ -454,14 +453,36 @@ def execute(self, db_name=None, sql=""):
sql=exec_sql,
)
else:
try:
r = json.loads(r)
except Exception as e:
logger.info(str(e))
finally:
methodStr = exec_sql.split(").")[-1].split("(")[0].strip()
if "." in methodStr:
methodStr = methodStr.split(".")[-1]
if methodStr == "insert":
actual_affected_rows = r.get("nInserted", 0)
elif methodStr in ("insertOne", "insertMany"):
actual_affected_rows = r.count("ObjectId")
elif methodStr == "update":
actual_affected_rows = r.get("nModified", 0)
elif methodStr in ("updateOne", "updateMany"):
actual_affected_rows = r.get("modifiedCount", 0)
elif methodStr in ("deleteOne", "deleteMany"):
actual_affected_rows = r.get("deletedCount", 0)
elif methodStr == "remove":
actual_affected_rows = r.get("nRemoved", 0)
else:
actual_affected_rows = 0
# 把结果转换为ReviewSet
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="执行结束",
errormessage=r,
errormessage=str(r),
execute_time=round(end - start, 6),
actual_affected_rows=0, # todo============这个值需要优化
affected_rows=actual_affected_rows,
sql=exec_sql,
)
execute_result.rows += [result]
Expand Down Expand Up @@ -571,9 +592,9 @@ def execute_check(self, db_name=None, sql=""):
check_result.rows += [result]
continue
else:
methodStr = (
sql_str.split(".")[-1].split("(")[0].strip()
) # 最后一个.和括号(之间的字符串作为方法
methodStr = sql_str.split(").")[-1].split("(")[0].strip()
if "." in methodStr:
methodStr = methodStr.split(".")[-1]
if methodStr in is_exist_premise_method and not is_in:
check_result.error = "文档不存在"
result = ReviewResult(
Expand Down Expand Up @@ -651,6 +672,75 @@ def execute_check(self, db_name=None, sql=""):
sql=check_sql,
execute_time=0,
)
if methodStr == "insertOne":
count = 1
elif methodStr in ("insert", "insertMany"):
insert_str = re.search(
rf"{methodStr}\((.*)\)", sql_str, re.S
).group(1)
first_char = insert_str.replace(" ", "").replace(
"\n", ""
)[0]
if first_char == "{":
count = 1
elif first_char == "[":
insert_values = re.search(
r"\[(.*?)\]", insert_str, re.S
).group(0)
de = JsonDecoder()
insert_values = de.decode(insert_values)
count = len(insert_values)
else:
count = 0
elif methodStr in (
"update",
"updateOne",
"updateMany",
"deleteOne",
"deleteMany",
"remove",
):
if sql_str.find("find(") > 0:
count_sql = sql_str.replace(methodStr, "count")
else:
count_sql = (
sql_str.replace(methodStr, "find") + ".count()"
)
query_dict = self.parse_query_sentence(count_sql)
count_sql = f"""db.getCollection("{query_dict["collection"]}").find({query_dict["condition"]}).count()"""
query_result = self.query(db_name, count_sql)
count = json.loads(query_result.rows[0][0]).get(
"count", 0
)
if (
methodStr == "update"
and "multi:true"
not in sql_str.replace(" ", "")
.replace('"', "")
.replace("'", "")
.replace("\n", "")
) or methodStr in ("deleteOne", "updateOne"):
count = 1 if count > 0 else 0
if methodStr in (
"insertOne",
"insert",
"insertMany",
"update",
"updateOne",
"updateMany",
"deleteOne",
"deleteMany",
"remove",
):
result = ReviewResult(
id=line,
errlevel=0,
stagestatus="Audit completed",
errormessage="检测通过",
affected_rows=count,
sql=check_sql,
execute_time=0,
)
else:
result = ReviewResult(
id=line,
Expand Down Expand Up @@ -1061,7 +1151,7 @@ def parse_tuple(self, cursor, db_name, tb_name, projection=None):
dd = re.findall(re_date, str(value))
for d in dd:
t = int(d.split(":")[1].strip()[:-1])
e = datetime.fromtimestamp(t / 1000)
e = datetime.datetime.fromtimestamp(t / 1000)
value = str(value).replace(
d, e.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
)
Expand Down
50 changes: 50 additions & 0 deletions sql/engines/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,6 +1925,28 @@ def test_execute_check(self, mock_get_all_tables, mock_get_table_conut):
check_result.rows[0].__dict__["errormessage"], row.__dict__["errormessage"]
)

@patch("sql.engines.mongo.MongoEngine.get_all_tables")
def test_execute_check_include_dot(self, mock_get_all_tables):
sql = """db.job.insert({
fileName: "现金明细20230103075728.xls",
contentType: ".xls",
createdTime: ISODate("2023-01-03T12:05:27.402Z"),
reportDate: ISODate("2023-01-03T12:05:27.402Z"),
updatedTime: ISODate("2023-01-03T12:09:30.88Z")
});;"""
mock_get_all_tables.return_value.rows = "job"
check_result = self.engine.execute_check("some_db", sql)
self.assertEqual(
check_result.rows[0].__dict__["stagestatus"], "Audit completed"
)

@patch("sql.engines.mongo.MongoEngine.get_all_tables")
def test_execute_check_on_dml(self, mock_get_all_tables):
sql = """db.job.insert([{"orderCode":1001},{"orderCode":1002}]);"""
mock_get_all_tables.return_value.rows = "job"
check_result = self.engine.execute_check("some_db", sql)
self.assertEqual(check_result.rows[0].__dict__["affected_rows"], 2)

@patch("sql.engines.mongo.MongoEngine.exec_cmd")
@patch("sql.engines.mongo.MongoEngine.get_master")
def test_execute(self, mock_get_master, mock_exec_cmd):
Expand All @@ -1940,6 +1962,34 @@ def test_execute(self, mock_get_master, mock_exec_cmd):
mock_get_master.assert_called_once()
self.assertEqual(check_result.rows[0].__dict__["errlevel"], 0)

@patch("sql.engines.mongo.MongoEngine.exec_cmd")
@patch("sql.engines.mongo.MongoEngine.get_master")
def test_execute_on_dml(self, mock_get_master, mock_exec_cmd):
sql = """db.job.insertMany([{"title":"test1"},{"title":test2"},{"title":test3"}]);"""
mock_exec_cmd.return_value = """{
"acknowledged" : true,
"insertedIds" : [
ObjectId("63b77b53afab4917dfd48a20"),
ObjectId("63b77b53afab4917dfd48a21"),
ObjectId("63b77b53afab4917dfd48a22")
]
}"""

check_result = self.engine.execute("some_db", sql)
mock_get_master.assert_called_once()
self.assertEqual(check_result.rows[0].__dict__["affected_rows"], 3)

@patch("sql.engines.mongo.MongoEngine.exec_cmd")
@patch("sql.engines.mongo.MongoEngine.get_master")
def test_execute_return_error(self, mock_get_master, mock_exec_cmd):
sql = """db.job.insertMany({"title":"test1"},{"title":test2"},{"title":test3"});"""
mock_exec_cmd.return_value = (
"""uncaught exception: TypeError: documents.map is not a function"""
)
check_result = self.engine.execute("some_db", sql)
mock_get_master.assert_called_once()
self.assertEqual(check_result.rows[0].__dict__["stagestatus"], "异常终止")

def test_fill_query_columns(self):
columns = ["_id", "title", "tags", "likes"]
cursor = [
Expand Down

0 comments on commit c0f41a4

Please sign in to comment.