diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index 54e5c14624..446238e013 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -199,11 +199,10 @@ def explain_check(self, check_result, db_name=None, line=0, statement=''): if self.server_version >= (21, 1, 2): explain_result = self.query(db_name=db_name, sql=f"explain ast {statement}") if explain_result.error: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回未通过检查SQL', - errormessage=f'explain语法检查错误:{explain_result.error}', - sql=statement) + stagestatus='驳回未通过检查SQL', + errormessage=f'explain语法检查错误:{explain_result.error}', + sql=statement) return result def execute_check(self, db_name=None, sql=''): @@ -222,14 +221,12 @@ def execute_check(self, db_name=None, sql=''): statement = statement.rstrip(';') # 禁用语句 if re.match(r"^select|^show", statement.lower()): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持语句', errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', sql=statement) # 高危语句 elif critical_ddl_regex and p.match(statement.strip().lower()): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回高危SQL', errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!', @@ -245,7 +242,6 @@ def execute_check(self, db_name=None, sql=''): table_exist = self.get_table_engine(table_name)['status'] if table_exist == 1: if not table_engine.endswith('MergeTree') and table_engine not in ('Merge', 'Distributed'): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持SQL', errormessage='ALTER TABLE仅支持*MergeTree,Merge以及Distributed等引擎表!', @@ -254,7 +250,6 @@ def execute_check(self, db_name=None, sql=''): # delete与update语句,实际是alter语句的变种 if re.match(r"^alter\s+table\s+(.+?)\s+(delete|update)\s+", statement.lower()): if not table_engine.endswith('MergeTree'): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持SQL', errormessage='DELETE与UPDATE仅支持*MergeTree引擎表!', @@ -264,7 +259,6 @@ def execute_check(self, db_name=None, sql=''): else: result = self.explain_check(check_result, db_name, line, statement) else: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='表不存在', errormessage=f'表 {table_name} 不存在!', @@ -281,7 +275,6 @@ def execute_check(self, db_name=None, sql=''): table_exist = self.get_table_engine(table_name)['status'] if table_exist == 1: if table_engine in ('View', 'File,', 'URL', 'Buffer', 'Null'): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持SQL', errormessage='TRUNCATE不支持View,File,URL,Buffer和Null表引擎!', @@ -289,7 +282,6 @@ def execute_check(self, db_name=None, sql=''): else: result = self.explain_check(check_result, db_name, line, statement) else: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='表不存在', errormessage=f'表 {table_name} 不存在!', @@ -310,13 +302,11 @@ def execute_check(self, db_name=None, sql=''): affected_rows=0, execute_time=0, ) else: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='表不存在', errormessage=f'表 {table_name} 不存在!', sql=statement) else: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持SQL', errormessage='INSERT语法不正确!', @@ -330,12 +320,13 @@ def execute_check(self, db_name=None, sql=''): if get_syntax_type(statement, parser=False, db_type='mysql') == 'DDL': check_result.syntax_type = 1 check_result.rows += [result] - - # 遇到禁用和高危语句直接返回 - if check_result.is_critical: - check_result.error_count += 1 - return check_result line += 1 + # 统计警告和错误数量 + for r in check_result.rows: + if r.errlevel == 1: + check_result.warning_count += 1 + if r.errlevel == 2: + check_result.error_count += 1 return check_result def execute_workflow(self, workflow): diff --git a/sql/engines/mongo.py b/sql/engines/mongo.py index e3edaeba86..1968179b03 100644 --- a/sql/engines/mongo.py +++ b/sql/engines/mongo.py @@ -430,7 +430,6 @@ def execute_check(self, db_name=None, sql=''): alert = "" if is_in: check_result.error = "文档已经存在" - check_result.error_count += 1 result = ReviewResult(id=line, errlevel=2, stagestatus='文档已经存在', errormessage='文档已经存在!', @@ -444,7 +443,6 @@ def execute_check(self, db_name=None, sql=''): methodStr = sql_str.split('(')[0].split('.')[-1].strip() # 最后一个.和括号(之间的字符串作为方法 if methodStr in is_exist_premise_method and not is_in: check_result.error = "文档不存在" - check_result.error_count += 1 result = ReviewResult(id=line, errlevel=2, stagestatus='文档不存在', errormessage=f'文档不存在,不能进行{methodStr}操作!', @@ -492,7 +490,6 @@ def execute_check(self, db_name=None, sql=''): sql=check_sql, execute_time=0) else: - check_result.error_count += 1 result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持语句', errormessage='仅支持DML和DDL语句,如需查询请使用数据库查询功能!', @@ -500,7 +497,6 @@ def execute_check(self, db_name=None, sql=''): else: check_result.error = "语法错误" - check_result.error_count += 1 result = ReviewResult(id=line, errlevel=2, stagestatus='语法错误', errormessage='请检查语句的正确性或(){} },{是否正确匹配!', @@ -511,6 +507,12 @@ def execute_check(self, db_name=None, sql=''): check_result.column_list = ['Result'] # 审核结果的列名 check_result.checked = True check_result.warning = self.warning + # 统计警告和错误数量 + for r in check_result.rows: + if r.errlevel == 1: + check_result.warning_count += 1 + if r.errlevel == 2: + check_result.error_count += 1 return check_result def get_connection(self, db_name=None): diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index 5fc369a720..c82136acb2 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -297,8 +297,8 @@ def query_check(self, db_name=None, sql=''): result['bad_query'] = True result['msg'] = explain_result.error # 不应该查看mysql.user表 - if re.match('.*(\\s)+(mysql|`mysql`)(\\s)*\\.(\\s)*(user|`user`)((\\s)*|;).*',sql.lower().replace('\n','')) or\ - (db_name=="mysql" and re.match('.*(\\s)+(user|`user`)((\\s)*|;).*',sql.lower().replace('\n',''))): + if re.match('.*(\\s)+(mysql|`mysql`)(\\s)*\\.(\\s)*(user|`user`)((\\s)*|;).*', sql.lower().replace('\n', '')) or \ + (db_name == "mysql" and re.match('.*(\\s)+(user|`user`)((\\s)*|;).*', sql.lower().replace('\n', ''))): result['bad_query'] = True result['msg'] = '您无权查看该表' @@ -348,62 +348,36 @@ def execute_check(self, db_name=None, sql=''): """上线单执行前的检查, 返回Review set""" # 进行Inception检查,获取检测结果 try: - inc_check_result = self.inc_engine.execute_check(instance=self.instance, db_name=db_name, sql=sql) + check_result = self.inc_engine.execute_check(instance=self.instance, db_name=db_name, sql=sql) except Exception as e: logger.debug(f"{self.inc_engine.name}检测语句报错:错误信息{traceback.format_exc()}") raise RuntimeError(f"{self.inc_engine.name}检测语句报错,请注意检查系统配置中{self.inc_engine.name}配置,错误信息:\n{e}") # 判断Inception检测结果 - if inc_check_result.error: - logger.debug(f"{self.inc_engine.name}检测语句报错:错误信息{inc_check_result.error}") - raise RuntimeError(f"{self.inc_engine.name}检测语句报错,错误信息:\n{inc_check_result.error}") + if check_result.error: + logger.debug(f"{self.inc_engine.name}检测语句报错:错误信息{check_result.error}") + raise RuntimeError(f"{self.inc_engine.name}检测语句报错,错误信息:\n{check_result.error}") # 禁用/高危语句检查 - check_critical_result = ReviewSet(full_sql=sql) - line = 1 critical_ddl_regex = self.config.get('critical_ddl_regex', '') p = re.compile(critical_ddl_regex) - check_critical_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML - - for row in inc_check_result.rows: + for row in check_result.rows: statement = row.sql # 去除注释 statement = remove_comments(statement, db_type='mysql') # 禁用语句 if re.match(r"^select", statement.lower()): - check_critical_result.is_critical = True - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回不支持语句', - errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', - sql=statement) + check_result.error_count += 1 + row.stagestatus = '驳回不支持语句' + row.errlevel = 2 + row.errormessage = '仅支持DML和DDL语句,查询语句请使用SQL查询功能!' # 高危语句 elif critical_ddl_regex and p.match(statement.strip().lower()): - check_critical_result.is_critical = True - result = ReviewResult(id=line, errlevel=2, - stagestatus='驳回高危SQL', - errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!', - sql=statement) - # 正常语句 - else: - result = ReviewResult(id=line, errlevel=0, - stagestatus='Audit completed', - errormessage='None', - sql=statement, - affected_rows=0, - execute_time=0, ) - - # 没有找出DDL语句的才继续执行此判断 - if check_critical_result.syntax_type == 2: - if get_syntax_type(statement, parser=False, db_type='mysql') == 'DDL': - check_critical_result.syntax_type = 1 - check_critical_result.rows += [result] - - # 遇到禁用和高危语句直接返回 - if check_critical_result.is_critical: - check_critical_result.error_count += 1 - return check_critical_result - line += 1 - return inc_check_result + check_result.error_count += 1 + row.stagestatus = '驳回高危SQL' + row.errlevel = 2 + row.errormessage = '禁止提交匹配' + critical_ddl_regex + '条件的语句!' + return check_result def execute_workflow(self, workflow): """执行上线单,返回Review set""" diff --git a/sql/engines/oracle.py b/sql/engines/oracle.py index aab93ea52c..537ffbdd94 100644 --- a/sql/engines/oracle.py +++ b/sql/engines/oracle.py @@ -341,21 +341,21 @@ def object_name_check(self, db_name=None, object_name=''): schema_name = object_name.split('.')[0] object_name = object_name.split('.')[1] if '"' in schema_name: - schema_name = schema_name.replace( '"','' ) + schema_name = schema_name.replace('"', '') if '"' in object_name: - object_name = object_name.replace( '"','' ) + object_name = object_name.replace('"', '') else: object_name = object_name.upper() else: schema_name = schema_name.upper() if '"' in object_name: - object_name = object_name.replace( '"','' ) + object_name = object_name.replace('"', '') else: object_name = object_name.upper() else: schema_name = db_name if '"' in object_name: - object_name = object_name.replace( '"','' ) + object_name = object_name.replace('"', '') else: object_name = object_name.upper() sql = f""" SELECT object_name FROM all_objects WHERE OWNER = '{schema_name}' and OBJECT_NAME = '{object_name}' """ @@ -369,26 +369,26 @@ def object_name_check(self, db_name=None, object_name=''): def get_sql_first_object_name(sql=''): """获取sql文本中的object_name""" object_name = '' - if re.match(r"^create\s+table\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+table\s(.+?)(\s|\()", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+index\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+index\s(.+?)\s", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+unique\s+index\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+unique\s+index\s(.+?)\s", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+sequence\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+sequence\s(.+?)(\s|$)", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^alter\s+table\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^alter\s+table\s(.+?)\s", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+function\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+function\s(.+?)(\s|\()", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+view\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+view\s(.+?)\s", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+procedure\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+procedure\s(.+?)\s", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+package\s+body", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+package\s+body\s(.+?)\s", sql, re.M|re.IGNORECASE).group(1) - elif re.match(r"^create\s+package\s", sql, re.M|re.IGNORECASE): - object_name = re.match(r"^create\s+package\s(.+?)\s", sql, re.M|re.IGNORECASE).group(1) + if re.match(r"^create\s+table\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+table\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+index\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+unique\s+index\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+unique\s+index\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+sequence\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+sequence\s(.+?)(\s|$)", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^alter\s+table\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^alter\s+table\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+function\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+function\s(.+?)(\s|\()", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+view\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+view\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+procedure\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+procedure\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+package\s+body", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+package\s+body\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) + elif re.match(r"^create\s+package\s", sql, re.M | re.IGNORECASE): + object_name = re.match(r"^create\s+package\s(.+?)\s", sql, re.M | re.IGNORECASE).group(1) else: return object_name.strip() return object_name.strip() @@ -435,7 +435,8 @@ def get_dml_table(sql='', object_name_list=None, db_name=''): else: return False elif re.match(r"^insert\s", sql): - table_name = re.match(r"^insert\s+((into)|(all\s+into)|(all\s+when\s(.+?)into))\s+(.+?)(\(|\s)", sql, re.M).group(6) + table_name = re.match(r"^insert\s+((into)|(all\s+into)|(all\s+when\s(.+?)into))\s+(.+?)(\(|\s)", sql, + re.M).group(6) if '.' not in table_name: table_name = f"{db_name}.{table_name}" if table_name in object_name_list: @@ -560,7 +561,6 @@ def query(self, db_name=None, sql='', limit_num=0, close_conn=True, **kwargs): self.close() return result_set - def query_masking(self, db_name=None, sql='', resultset=None): """简单字段脱敏规则, 仅对select有效""" if re.match(r"^select", sql, re.I): @@ -594,28 +594,24 @@ def execute_check(self, db_name=None, sql='', close_conn=True): sql_nolower = sqlitem.statement.rstrip(';') # 禁用语句 if re.match(r"^select|^with|^explain", sql_lower): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持语句', errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', sql=sqlitem.statement) # 高危语句 elif critical_ddl_regex and p.match(sql_lower.strip()): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回高危SQL', errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!', sql=sqlitem.statement) # 驳回未带where数据修改语句,如确实需做全部删除或更新,显示的带上where 1=1 elif re.match(r"^update((?!where).)*$|^delete((?!where).)*$", sql_lower): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回未带where数据修改', errormessage='数据修改需带where条件!', sql=sqlitem.statement) # 驳回事务控制,会话控制SQL elif re.match(r"^set|^rollback|^exit", sql_lower): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='SQL中不能包含^set|^rollback|^exit', errormessage='SQL中不能包含^set|^rollback|^exit', @@ -651,7 +647,6 @@ def execute_check(self, db_name=None, sql='', close_conn=True): else: result_set = self.explain_check(db_name=db_name, sql=sqlitem.statement, close_conn=False) if result_set['msg']: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='explain语法检查未通过!', errormessage=result_set['msg'], @@ -673,14 +668,13 @@ def execute_check(self, db_name=None, sql='', close_conn=True): if '"' not in object_name: object_name = object_name.upper() else: - schema_name = ( '"' + db_name + '"' ) + schema_name = ('"' + db_name + '"') if '"' not in object_name: object_name = object_name.upper() object_name = f"""{schema_name}.{object_name}""" if self.object_name_check(db_name=db_name, object_name=object_name) or object_name in object_name_list: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus=f"""{object_name}对象已经存在!""", errormessage=f"""{object_name}对象已经存在!""", @@ -749,14 +743,13 @@ def execute_check(self, db_name=None, sql='', close_conn=True): if '"' not in object_name: object_name = object_name.upper() else: - schema_name = ( '"' + db_name + '"' ) + schema_name = ('"' + db_name + '"') if '"' not in object_name: object_name = object_name.upper() object_name = f"""{schema_name}.{object_name}""" if not self.object_name_check(db_name=db_name, object_name=object_name) and object_name not in object_name_list: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus=f"""{object_name}对象不存在!""", errormessage=f"""{object_name}对象不存在!""", @@ -787,14 +780,13 @@ def execute_check(self, db_name=None, sql='', close_conn=True): if '"' not in object_name: object_name = object_name.upper() else: - schema_name = ( '"' + db_name + '"' ) + schema_name = ('"' + db_name + '"') if '"' not in object_name: object_name = object_name.upper() object_name = f"""{schema_name}.{object_name}""" if self.object_name_check(db_name=db_name, object_name=object_name) or object_name in object_name_list: - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus=f"""{object_name}对象已经存在!""", errormessage=f"""{object_name}对象已经存在!""", @@ -826,10 +818,6 @@ def execute_check(self, db_name=None, sql='', close_conn=True): if get_syntax_type(sql=sqlitem.statement, db_type='oracle') == 'DDL': check_result.syntax_type = 1 check_result.rows += [result] - # 遇到禁用和高危语句直接返回,提高效率 - if check_result.is_critical: - check_result.error_count += 1 - return check_result line += 1 except Exception as e: logger.warning(f"Oracle 语句执行报错,第{line}个SQL:{sqlitem.statement},错误信息{traceback.format_exc()}") @@ -837,6 +825,12 @@ def execute_check(self, db_name=None, sql='', close_conn=True): finally: if close_conn: self.close() + # 统计警告和错误数量 + for r in check_result.rows: + if r.errlevel == 1: + check_result.warning_count += 1 + if r.errlevel == 2: + check_result.error_count += 1 return check_result def execute_workflow(self, workflow, close_conn=True): @@ -867,16 +861,16 @@ def execute_workflow(self, workflow, close_conn=True): statement = sqlitem.statement if sqlitem.stmt_type == "SQL": statement = statement.rstrip(';') - #如果是DDL的工单,获取对象的原定义,并保存到sql_rollback.undo_sql - #需要授权 grant execute on dbms_metadata to xxxxx + # 如果是DDL的工单,获取对象的原定义,并保存到sql_rollback.undo_sql + # 需要授权 grant execute on dbms_metadata to xxxxx if workflow.syntax_type == 1: - object_name=self.get_sql_first_object_name(statement) - back_obj_sql=f"""select dbms_metadata.get_ddl(object_type,object_name,owner) + object_name = self.get_sql_first_object_name(statement) + back_obj_sql = f"""select dbms_metadata.get_ddl(object_type,object_name,owner) from all_objects where (object_name=upper( '{object_name}' ) or OBJECT_NAME = '{sqlitem.object_name}') and owner='{workflow.db_name}' """ cursor.execute(back_obj_sql) - metdata_back_flag=self.metdata_backup(workflow, cursor,statement) + metdata_back_flag = self.metdata_backup(workflow, cursor, statement) with FuncTimer() as t: if statement != '': @@ -1019,7 +1013,7 @@ def backup(self, workflow, cursor, begin_time, end_time): conn.close() return True - def metdata_backup(self, workflow, cursor ,redo_sql): + def metdata_backup(self, workflow, cursor, redo_sql): """ :param workflow: 工单对象,作为备份记录与工单的关联列 :param cursor: 执行SQL的当前会话游标,保存metadata diff --git a/sql/engines/pgsql.py b/sql/engines/pgsql.py index 1444116554..f6197b06ef 100644 --- a/sql/engines/pgsql.py +++ b/sql/engines/pgsql.py @@ -207,14 +207,12 @@ def execute_check(self, db_name=None, sql=''): statement = sqlparse.format(statement, strip_comments=True) # 禁用语句 if re.match(r"^select", statement.lower()): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回不支持语句', errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', sql=statement) # 高危语句 elif critical_ddl_regex and p.match(statement.strip().lower()): - check_result.is_critical = True result = ReviewResult(id=line, errlevel=2, stagestatus='驳回高危SQL', errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!', @@ -232,12 +230,13 @@ def execute_check(self, db_name=None, sql=''): if get_syntax_type(statement) == 'DDL': check_result.syntax_type = 1 check_result.rows += [result] - - # 遇到禁用和高危语句直接返回,提高效率 - if check_result.is_critical: - check_result.error_count += 1 - return check_result line += 1 + # 统计警告和错误数量 + for r in check_result.rows: + if r.errlevel == 1: + check_result.warning_count += 1 + if r.errlevel == 2: + check_result.error_count += 1 return check_result def execute_workflow(self, workflow, close_conn=True): diff --git a/sql/engines/tests.py b/sql/engines/tests.py index 4806c7d4d1..be3a86b0cd 100644 --- a/sql/engines/tests.py +++ b/sql/engines/tests.py @@ -400,7 +400,7 @@ def test_execute_check_select_sql(self, _inception_engine): errormessage='None', sql=sql, affected_rows=0, - execute_time=0, ) + execute_time='', ) row = ReviewResult(id=1, errlevel=2, stagestatus='驳回不支持语句', errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!', @@ -423,7 +423,7 @@ def test_execute_check_critical_sql(self, _inception_engine): errormessage='None', sql=sql, affected_rows=0, - execute_time=0, ) + execute_time='', ) row = ReviewResult(id=1, errlevel=2, stagestatus='驳回高危SQL', errormessage='禁止提交匹配' + '^|update' + '条件的语句!', diff --git a/sql/utils/sql_review.py b/sql/utils/sql_review.py index 6aea3775fe..7dfa52fba8 100644 --- a/sql/utils/sql_review.py +++ b/sql/utils/sql_review.py @@ -118,8 +118,7 @@ def can_timingtask(user, workflow_id): def can_cancel(user, workflow_id): """ 判断用户当前是否是可终止, - 审核中的工单,审核人和提交人可终止 - 审核通过但未执行的工单,有执行权限的用户终止 + 审核中、审核通过的的工单,审核人和提交人可终止 :param user: :param workflow_id: :return: @@ -129,11 +128,9 @@ def can_cancel(user, workflow_id): # 审核中的工单,审核人和提交人可终止 if workflow_detail.status == 'workflow_manreviewing': from sql.utils.workflow_audit import Audit - if Audit.can_review(user, workflow_id, 2) or user.username == workflow_detail.engineer: - result = True - # 审核通过但未执行的工单,执行人可以打回 - if workflow_detail.status in ['workflow_review_pass', 'workflow_timingtask']: - result = True if can_execute(user, workflow_id) else False + return any([Audit.can_review(user, workflow_id, 2), user.username == workflow_detail.engineer]) + elif workflow_detail.status in ['workflow_review_pass', 'workflow_timingtask']: + return any([can_execute(user, workflow_id), user.username == workflow_detail.engineer]) return result diff --git a/sql/utils/tests.py b/sql/utils/tests.py index 6121c0144d..6ce7b97510 100644 --- a/sql/utils/tests.py +++ b/sql/utils/tests.py @@ -434,18 +434,18 @@ def test_can_cancel_true_for_execute_user(self, _can_execute): self.assertTrue(r) @patch('sql.utils.sql_review.can_execute') - def test_can_cancel_false(self, _can_execute): + def test_can_cancel_true_for_submit_user(self, _can_execute): """ - 测试是否能取消,审核通过但未执行的工单,无执行权限的用户无法终止 + 测试是否能取消,审核通过但未执行的工单,提交人可终止 :return: """ # 修改工单为workflow_review_pass,当前登录用户为提交人 self.wf1.status = 'workflow_review_pass' self.wf1.engineer = self.user.username self.wf1.save(update_fields=('status', 'engineer')) - _can_execute.return_value = False + _can_execute.return_value = True r = can_cancel(user=self.user, workflow_id=self.wfc1.workflow_id) - self.assertFalse(r) + self.assertTrue(r) def test_on_correct_time_period(self): """