diff --git a/caravel/utils.py b/caravel/utils.py index 26e05b77e94dd..b4296c92eeef9 100644 --- a/caravel/utils.py +++ b/caravel/utils.py @@ -319,6 +319,21 @@ def json_int_dttm_ser(obj): return obj +def error_msg_from_exception(e): + """Translate exception into error message + Database have different ways to handle exception. This function attempts + to make sense of the exception object and construct a human readable + sentence. + """ + msg = '' + if hasattr(e, 'message'): + if (type(e.message) is dict): + msg = e.message.get('message') + elif e.message: + msg = "{}".format(e.message) + return msg or '{}'.format(e) + + def markdown(s, markup_wrap=False): s = s or '' s = md(s, [ diff --git a/caravel/views.py b/caravel/views.py index a1dfb3fe19218..fe1984254d165 100755 --- a/caravel/views.py +++ b/caravel/views.py @@ -1214,6 +1214,7 @@ def select_star(self, database_id, table_name): @log_this def runsql(self): """Runs arbitrary sql and returns and html table""" + # TODO deprecate in favor on `sql_json` session = db.session() limit = 1000 data = json.loads(request.form.get('data')) @@ -1225,7 +1226,7 @@ def runsql(self): not self.can_access( 'all_datasource_access', 'all_datasource_access')): raise utils.CaravelSecurityException(_( - "This view requires the `all_datasource_access` permission")) + "SQL Lab requires the `all_datasource_access` permission")) content = "" if mydb: eng = mydb.get_sqla_engine() @@ -1253,6 +1254,59 @@ def runsql(self): session.commit() return content + @has_access + @expose("/sql_json/", methods=['POST', 'GET']) + @log_this + def sql_json(self): + """Runs arbitrary sql and returns and json""" + session = db.session() + limit = 1000 + sql = request.form.get('sql') + database_id = request.form.get('database_id') + mydb = session.query(models.Database).filter_by(id=database_id).first() + + if ( + not self.can_access( + 'all_datasource_access', 'all_datasource_access')): + raise utils.CaravelSecurityException(_( + "This view requires the `all_datasource_access` permission")) + + error_msg = "" + if not mydb: + error_msg = "The database selected doesn't seem to exist" + else: + eng = mydb.get_sqla_engine() + if limit: + sql = sql.strip().strip(';') + qry = ( + select('*') + .select_from(TextAsFrom(text(sql), ['*']).alias('inner_qry')) + .limit(limit) + ) + sql = str(qry.compile(eng, compile_kwargs={"literal_binds": True})) + try: + df = pd.read_sql_query(sql=sql, con=eng) + df = df.fillna(0) # TODO make sure NULL + except Exception as e: + logging.exception(e) + error_msg = utils.error_msg_from_exception(e) + + session.commit() + if error_msg: + return Response( + json.dumps({ + 'error': error_msg, + }), + status=500, + mimetype="application/json") + else: + data = { + 'columns': [c for c in df.columns], + 'data': df.to_dict(orient='records'), + } + return json.dumps(data, default=utils.json_int_dttm_ser, allow_nan=False) + + @has_access @expose("/refresh_datasources/") def refresh_datasources(self): diff --git a/tests/core_tests.py b/tests/core_tests.py index ef957ec46e2c4..b6f9154b38833 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -226,6 +226,26 @@ def test_gamma(self): resp = self.client.get('/dashboardmodelview/list/') assert "List Dashboard" in resp.data.decode('utf-8') + def run_sql(self, sql): + self.login(username='admin') + dbid = ( + db.session.query(models.Database) + .filter_by(database_name="main") + .first().id + ) + resp = self.client.post( + '/caravel/sql_json/', + data=dict(database_id=dbid, sql=sql), + ) + return json.loads(resp.data.decode('utf-8')) + + def test_sql_json(self): + data = self.run_sql("SELECT * FROM ab_user") + assert len(data['data']) > 0 + + data = self.run_sql("SELECT * FROM unexistant_table") + assert len(data['error']) > 0 + def test_public_user_dashboard_access(self): # Try access before adding appropriate permissions. self.revoke_public_access('birth_names')