Skip to content

Commit

Permalink
Merge pull request #175 from hhyo/oracle
Browse files Browse the repository at this point in the history
新增oracle查询
  • Loading branch information
hhyo authored May 9, 2019
2 parents 2324a47 + 35b53c7 commit 2755119
Show file tree
Hide file tree
Showing 11 changed files with 283 additions and 7 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ pyecharts_snapshot==0.1.7
aliyun-python-sdk-core==2.3.5
aliyun-python-sdk-core-v3==2.5.3
aliyun-python-sdk-rds==2.1.1
cx-Oracle==7.1.3
3 changes: 3 additions & 0 deletions sql/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,6 @@ def get_engine(instance=None):
elif instance.db_type == 'pgsql':
from .pgsql import PgSQLEngine
return PgSQLEngine(instance=instance)
elif instance.db_type == 'oracle':
from .oracle import OracleEngine
return OracleEngine(instance=instance)
15 changes: 14 additions & 1 deletion sql/engines/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import EngineBase
import pyodbc
from .models import ResultSet
from .models import ResultSet, ReviewSet, ReviewResult
from sql.utils.data_masking import brute_mask

logger = logging.getLogger('default')
Expand Down Expand Up @@ -154,6 +154,19 @@ def query_masking(self, db_name=None, sql='', resultset=None):
filtered_result = resultset
return filtered_result

def execute_check(self, db_name=None, sql=''):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
result = ReviewResult(id=1,
errlevel=0,
stagestatus='Audit completed',
errormessage='None',
sql=sql,
affected_rows=0,
execute_time=0, )
check_result.rows += [result]
return check_result

def close(self):
if self.conn:
self.conn.close()
Expand Down
2 changes: 1 addition & 1 deletion sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def query_check(self, db_name=None, sql=''):
sql = sqlparse.split(sql)[0]
result['filtered_sql'] = sql.strip()
except IndexError:
result['has_star'] = True
result['bad_query'] = True
result['msg'] = '没有有效的SQL语句'
if re.match(r"^select|^show|^explain", sql, re.I) is None:
result['bad_query'] = True
Expand Down
200 changes: 200 additions & 0 deletions sql/engines/oracle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# -*- coding: UTF-8 -*-
# https://stackoverflow.com/questions/7942520/relationship-between-catalog-schema-user-and-database-instance
import logging
import traceback
import re
import sqlparse

from . import EngineBase
import cx_Oracle
from .models import ResultSet, ReviewSet, ReviewResult
from sql.utils.data_masking import brute_mask

logger = logging.getLogger('default')


class OracleEngine(EngineBase):

def __init__(self, instance=None):
super(OracleEngine, self).__init__(instance=instance)
self.service_name = instance.service_name
self.sid = instance.sid

def get_connection(self, db_name=None):
if self.conn:
return self.conn
if self.sid:
dsn = cx_Oracle.makedsn(self.host, self.port, self.sid)
self.conn = cx_Oracle.connect(self.user, self.password, dsn=dsn)
elif self.service_name:
dsn = cx_Oracle.makedsn(self.host, self.port, service_name=self.service_name)
self.conn = cx_Oracle.connect(self.user, self.password, dsn=dsn)
else:
self.conn = None
return self.conn

def get_all_databases(self):
"""获取数据库列表, 返回resultSet 供上层调用, 底层实际上是获取oracle的schema列表"""
return self._get_all_schemas()

def _get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "select name from v$database"
result = self.query(sql=sql)
db_list = [row[0] for row in result.rows]
result.rows = db_list
return result

def _get_all_instances(self):
"""获取实例列表, 返回一个ResultSet"""
sql = "select instance_name from v$instance"
result = self.query(sql=sql)
instance_list = [row[0] for row in result.rows]
result.rows = instance_list
return result

def _get_all_schemas(self):
"""
获取模式列表
:return:
"""
result = self.query(sql="select username from sys.dba_users")
schema_list = [row[0] for row in result.rows if row[0] not in ['ORACLE_OCM', 'DIP', 'DBSNMP', 'APPQOSSYS',
'MGMT_VIEW', 'SYS', 'SYSTEM', 'OUTLN']]
result.rows = schema_list
return result

def get_all_tables(self, schema_name):
"""获取table 列表, 返回一个ResultSet"""
sql = f"""select
TABLE_NAME
from dba_tab_privs
where grantee in ('{schema_name}')
union
select
OBJECT_NAME
from dba_objects
WHERE OWNER IN ('{schema_name}') and object_type in ('TABLE')
"""
result = self.query(sql=sql)
tb_list = [row[0] for row in result.rows if row[0] not in ['test']]
result.rows = tb_list
return result

def get_all_columns_by_tb(self, schema_name, tb_name):
"""获取所有字段, 返回一个ResultSet"""
result = self.describe_table(schema_name, tb_name)
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result

def describe_table(self, schema_name, tb_name):
"""return ResultSet"""
# https://www.thepolyglotdeveloper.com/2015/01/find-tables-oracle-database-column-name/
sql = f"""SELECT
column_name,
data_type,
data_length,
nullable,
data_default
FROM all_tab_cols
WHERE table_name = '{tb_name}'
"""
result = self.query(sql=sql)
return result

def query_check(self, db_name=None, sql=''):
# 查询语句的检查、注释去除、切分
result = {'msg': '', 'bad_query': False, 'filtered_sql': sql, 'has_star': False}
keyword_warning = ''
star_patter = r"(^|,| )\*( |\(|$)"
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sql.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
result['filtered_sql'] = re.sub(r';$', '', sql.strip())
sql_lower = sql.lower()
except IndexError:
result['has_star'] = True
result['msg'] = '没有有效的SQL语句'
return result
if re.match(r"^select", sql_lower) is None:
result['bad_query'] = True
result['msg'] = '仅支持^select语法!'
return result
if re.search(star_patter, sql_lower) is not None:
keyword_warning += '禁止使用 * 关键词\n'
result['bad_query'] = True
result['has_star'] = True
if '+' in sql_lower:
keyword_warning += '禁止使用 + 关键词\n'
result['bad_query'] = True
if result.get('bad_query'):
result['msg'] = keyword_warning
return result

def filter_sql(self, sql='', limit_num=0):
sql_lower = sql.lower()
# 对查询sql增加limit限制
if re.match(r"^select", sql_lower):
if sql_lower.find(' rownum ') == -1:
if sql_lower.find(' where ') == -1:
return f"{sql.rstrip(';')} WHERE ROWNUM <= {limit_num}"
else:
return f"{sql.rstrip(';')} AND ROWNUM <= {limit_num}"
return sql.strip()

def query(self, db_name=None, sql='', limit_num=0, close_conn=True):
"""返回 ResultSet """
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection()
cursor = conn.cursor()
# if schema_name:
# cursor.execute(f"ALTER SESSION SET CURRENT_SCHEMA = {schema_name}")
cursor.execute(sql)
if int(limit_num) > 0:
rows = cursor.fetchmany(int(limit_num))
else:
rows = cursor.fetchall()
fields = cursor.description

result_set.column_list = [i[0] for i in fields] if fields else []
result_set.rows = [tuple(x) for x in rows]
result_set.affected_rows = len(result_set.rows)
except Exception as e:
logger.error(f"Oracle 语句执行报错,语句:{sql},错误信息{traceback.format_exc()}")
result_set.error = str(e)
finally:
if close_conn:
self.close()
return result_set

def query_masking(self, schema_name=None, sql='', resultset=None):
"""传入 sql语句, db名, 结果集,
返回一个脱敏后的结果集"""
# 仅对select语句脱敏
if re.match(r"^select", sql, re.I):
filtered_result = brute_mask(resultset)
filtered_result.is_masked = True
else:
filtered_result = resultset
return filtered_result

def execute_check(self, db_name=None, sql=''):
"""上线单执行前的检查, 返回Review set"""
check_result = ReviewSet(full_sql=sql)
result = ReviewResult(id=1,
errlevel=0,
stagestatus='Audit completed',
errormessage='None',
sql=sql,
affected_rows=0,
execute_time=0, )
check_result.rows += [result]
return check_result

def close(self):
if self.conn:
self.conn.close()
self.conn = None
38 changes: 37 additions & 1 deletion sql/engines/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sql.engines.mysql import MysqlEngine
from sql.engines.redis import RedisEngine
from sql.engines.pgsql import PgSQLEngine
from sql.engines.oracle import OracleEngine
from sql.engines.inception import InceptionEngine, _repair_json_str
from sql.models import Instance, SqlWorkflow, SqlWorkflowContent

Expand Down Expand Up @@ -240,7 +241,7 @@ def test_query_check_wrong_sql(self):
wrong_sql = '-- 测试'
check_result = new_engine.query_check(db_name='some_db', sql=wrong_sql)
self.assertDictEqual(check_result,
{'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': '-- 测试', 'has_star': True})
{'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': '-- 测试', 'has_star': False})

def test_query_check_update_sql(self):
new_engine = MysqlEngine(instance=self.ins1)
Expand Down Expand Up @@ -850,3 +851,38 @@ def test_query_not_limit(self, _conn, _cursor, _execute):
new_engine = GoInceptionEngine(instance=self.ins)
query_result = new_engine.query(db_name=0, sql='select 1', limit_num=0)
self.assertIsInstance(query_result, ResultSet)


class TestOracle(TestCase):
"""Oracle 测试"""
def setUp(self):
self.ins = Instance.objects.create(instance_name='some_ins', type='slave', db_type='oracle',
host='some_host', port=3306, user='ins_user', password='some_pass',
sid='some_id')
self.wf = SqlWorkflow.objects.create(
workflow_name='some_name',
group_id=1,
group_name='g1',
engineer_display='',
audit_auth_groups='some_group',
create_time=datetime.now() - timedelta(days=1),
status='workflow_finish',
is_backup=True,
instance=self.ins,
db_name='some_db',
syntax_type=1
)
SqlWorkflowContent.objects.create(workflow=self.wf)

def tearDown(self):
self.ins.delete()
SqlWorkflow.objects.all().delete()
SqlWorkflowContent.objects.all().delete()

@patch('cx_Oracle.makedsn')
@patch('cx_Oracle.connect')
def test_get_connection(self, _connect, _makedsn):
new_engine = OracleEngine(self.ins)
new_engine.get_connection()
_connect.assert_called_once()
_makedsn.assert_called_once()
9 changes: 6 additions & 3 deletions sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class Meta:
('mysql', 'MySQL'),
('mssql', 'MsSQL'),
('redis', 'Redis'),
('pgsql', 'PgSQL'),)
('pgsql', 'PgSQL'),
('oracle', 'Oracle'),)


class Instance(models.Model):
Expand All @@ -87,6 +88,8 @@ class Instance(models.Model):
port = models.IntegerField('端口', default=0)
user = models.CharField('用户名', max_length=100, default='', blank=True)
password = models.CharField('密码', max_length=300, default='', blank=True)
service_name = models.CharField('Oracle service name', max_length=50, null=True, blank=True)
sid = models.CharField('Oracle sid', max_length=50, null=True, blank=True)
create_time = models.DateTimeField('创建时间', auto_now_add=True)
update_time = models.DateTimeField('更新时间', auto_now=True)

Expand Down Expand Up @@ -324,8 +327,8 @@ class QueryPrivilegesApply(models.Model):
user_name = models.CharField('申请人', max_length=30)
user_display = models.CharField('申请人中文名', max_length=50, default='')
instance = models.ForeignKey(Instance, on_delete=models.CASCADE)
db_list = models.TextField('数据库') # 逗号分隔的数据库列表
table_list = models.TextField('表') # 逗号分隔的表列表
db_list = models.TextField('数据库', default='') # 逗号分隔的数据库列表
table_list = models.TextField('表', default='') # 逗号分隔的表列表
valid_date = models.DateField('有效时间')
limit_num = models.IntegerField('行数限制', default=100)
priv_type = models.IntegerField('权限类型', choices=((1, 'DATABASE'), (2, 'TABLE'),), default=0)
Expand Down
1 change: 1 addition & 0 deletions sql/templates/instance.html
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<option value="mssql">MsSQL</option>
<option value="redis">Redis</option>
<option value="pgsql">PgSQL</option>
<option value="oracle">Oracle</option>
</select>
</div>
<div class="form-group">
Expand Down
6 changes: 6 additions & 0 deletions sql/templates/queryapplylist.html
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@ <h4 class="modal-title" id="myModalLabel">申请数据库查询权限</h4>
class="selectpicker show-tick form-control bs-select-hidden" data-live-search="true"
data-placeholder="请选择实例:" required>
<option value="is-empty" disabled="" selected="selected">请选择实例:</option>
// TODO 使用models中的choices 渲染
<optgroup id="optgroup-mysql" label="MySQL"></optgroup>
<optgroup id="optgroup-mssql" label="MsSQL"></optgroup>
<optgroup id="optgroup-redis" label="Redis"></optgroup>
<optgroup id="optgroup-pgsql" label="PgSQL"></optgroup>
<optgroup id="optgroup-oracle" label="Oracle"></optgroup>

</select>
</div>
<div class="form-group">
Expand Down Expand Up @@ -154,6 +157,7 @@ <h4 class="modal-title text-danger">工单日志</h4>
$("#optgroup-mssql").empty();
$("#optgroup-redis").empty();
$("#optgroup-pgsql").empty();
$("#optgroup-oracle").empty();
for (var i = 0; i < result.length; i++) {
var instance = "<option value=\"" + result[i]['instance_name'] + "\">" + result[i]['instance_name'] + "</option>";
if (result[i]['db_type'] === 'mysql') {
Expand All @@ -164,6 +168,8 @@ <h4 class="modal-title text-danger">工单日志</h4>
$("#optgroup-redis").append(instance);
} else if (result[i]['db_type'] === 'pgsql') {
$("#optgroup-pgsql").append(instance);
} else if (result[i]['db_type'] === 'oracle') {
$("#optgroup-oracle").append(instance);
}
}
$('#instance_name').selectpicker('render');
Expand Down
Loading

0 comments on commit 2755119

Please sign in to comment.