Skip to content

Commit

Permalink
[sqllab] proper, quoted, select * on the server side (#1404)
Browse files Browse the repository at this point in the history
* [sqllab] proper, quoted, select * on the server side

* fixing tests
  • Loading branch information
mistercrunch authored Oct 21, 2016
1 parent 4f886d6 commit 63161b1
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 61 deletions.
2 changes: 0 additions & 2 deletions caravel/assets/javascripts/SqlLab/common.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,4 @@ export const STATE_BSSTYLE_MAP = {
success: 'success',
};

export const DATA_PREVIEW_ROW_COUNT = 100;

export const STATUS_OPTIONS = ['success', 'failed', 'running'];
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,12 @@ class SqlEditorLeftBar extends React.Component {

this.setState({ tableLoading: true });
$.get(url, (data) => {
this.props.actions.mergeTable({
this.props.actions.mergeTable(Object.assign(data, {
dbId: this.props.queryEditor.dbId,
queryEditorId: this.props.queryEditor.id,
name: data.name,
indexes: data.indexes,
schema: qe.schema,
columns: data.columns,
expanded: true,
});
}));
this.setState({ tableLoading: false });
})
.fail(() => {
Expand Down
34 changes: 3 additions & 31 deletions caravel/assets/javascripts/SqlLab/components/TableElement.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import * as Actions from '../actions';
import { ButtonGroup, Well } from 'react-bootstrap';
import shortid from 'shortid';

import { DATA_PREVIEW_ROW_COUNT } from '../common';
import CopyToClipboard from '../../components/CopyToClipboard';
import Link from './Link';
import ModalTrigger from '../../components/ModalTrigger';
Expand All @@ -23,41 +22,14 @@ const defaultProps = {
};

class TableElement extends React.Component {
setSelectStar() {
this.props.actions.queryEditorSetSql(this.props.queryEditor, this.selectStar());
}

selectStar(useStar = false, limit = 0) {
let cols = '';
this.props.table.columns.forEach((col, i) => {
cols += col.name;
if (i < this.props.table.columns.length - 1) {
cols += ', ';
}
});
let tableName = this.props.table.name;
if (this.props.table.schema) {
tableName = this.props.table.schema + '.' + tableName;
}
let sql;
if (useStar) {
sql = `SELECT * FROM ${tableName}`;
} else {
sql = `SELECT ${cols}\nFROM ${tableName}`;
}
if (limit > 0) {
sql += `\nLIMIT ${limit}`;
}
return sql;
}

popSelectStar() {
const qe = {
id: shortid.generate(),
title: this.props.table.name,
dbId: this.props.table.dbId,
autorun: true,
sql: this.selectStar(),
sql: this.props.table.selectStar,
};
this.props.actions.addQueryEditor(qe);
}
Expand All @@ -78,7 +50,7 @@ class TableElement extends React.Component {
dataPreviewModal() {
const query = {
dbId: this.props.queryEditor.dbId,
sql: this.selectStar(true, DATA_PREVIEW_ROW_COUNT),
sql: this.props.table.selectStar,
tableName: this.props.table.name,
sqlEditorId: null,
tab: '',
Expand Down Expand Up @@ -208,7 +180,7 @@ class TableElement extends React.Component {
copyNode={
<a className="fa fa-clipboard pull-left m-l-2" />
}
text={this.selectStar()}
text={table.selectStar}
shouldShowText={false}
tooltipText="Copy SELECT statement to clipboard"
/>
Expand Down
26 changes: 23 additions & 3 deletions caravel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,12 @@ def get_sqla_engine(self, schema=None):
url.database = schema
return create_engine(url, **params)

def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words

def get_quoter(self):
return self.get_sqla_engine().dialect.identifier_preparer.quote

def get_df(self, sql, schema):
eng = self.get_sqla_engine(schema=schema)
cur = eng.execute(sql, schema=schema)
Expand All @@ -701,12 +707,26 @@ def compile_sqla_query(self, qry, schema=None):
compiled = qry.compile(eng, compile_kwargs={"literal_binds": True})
return '{}'.format(compiled)

def select_star(self, table_name, schema=None, limit=1000):
def select_star(
self, table_name, schema=None, limit=100, show_cols=False,
indent=True):
"""Generates a ``select *`` statement in the proper dialect"""
qry = select('*').select_from(text(table_name))
for i in range(10):
print(schema)
quote = self.get_quoter()
fields = '*'
table = self.get_table(table_name, schema=schema)
if show_cols:
fields = [quote(c.name) for c in table.columns]
if schema:
table_name = schema + '.' + table_name
qry = select(fields).select_from(text(table_name))
if limit:
qry = qry.limit(limit)
return self.compile_sqla_query(qry)
sql = self.compile_sqla_query(qry)
if indent:
sql = sqlparse.format(sql, reindent=True)
return sql

def wrap_sql_limit(self, sql, limit=1000):
qry = (
Expand Down
5 changes: 4 additions & 1 deletion caravel/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1968,6 +1968,8 @@ def table(self, database_id, table_name, schema):
tbl = {
'name': table_name,
'columns': cols,
'selectStar': mydb.select_star(
table_name, schema=schema, show_cols=True, indent=True),
'indexes': indexes,
}
return Response(json.dumps(tbl), mimetype="application/json")
Expand All @@ -1988,6 +1990,7 @@ def extra_table_metadata(self, database_id, table_name, schema):
def select_star(self, database_id, table_name):
mydb = db.session.query(
models.Database).filter_by(id=database_id).first()
quote = mydb.get_quoter()
t = mydb.get_table(table_name)

# Prevent exposing column fields to users that cannot access DB.
Expand All @@ -1996,7 +1999,7 @@ def select_star(self, database_id, table_name):
return redirect("/tablemodelview/list/")

fields = ", ".join(
[c.name for c in t.columns] or "*")
[quote(c.name) for c in t.columns] or "*")
s = "SELECT\n{}\nFROM {}".format(fields, table_name)
return self.render_template(
"caravel/ajah.html",
Expand Down
36 changes: 17 additions & 19 deletions tests/celery_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def tearDownClass(cls):
shell=True
)

def run_sql(self, dbid, sql, client_id, cta='false', tmp_table='tmp',
def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp',
async='false'):
self.login()
resp = self.client.post(
'/caravel/sql_json/',
data=dict(
database_id=dbid,
database_id=db_id,
sql=sql,
async=async,
select_as_cta=cta,
Expand All @@ -144,12 +144,11 @@ def run_sql(self, dbid, sql, client_id, cta='false', tmp_table='tmp',

def test_add_limit_to_the_query(self):
session = db.session
db_to_query = session.query(models.Database).filter_by(
id=1).first()
eng = db_to_query.get_sqla_engine()
main_db = self.get_main_database(db.session)
eng = main_db.get_sqla_engine()

select_query = "SELECT * FROM outer_space;"
updated_select_query = db_to_query.wrap_sql_limit(select_query, 100)
updated_select_query = main_db.wrap_sql_limit(select_query, 100)
# Different DB engines have their own spacing while compiling
# the queries, that's why ' '.join(query.split()) is used.
# In addition some of the engines do not include OFFSET 0.
Expand All @@ -159,7 +158,7 @@ def test_add_limit_to_the_query(self):
)

select_query_no_semicolon = "SELECT * FROM outer_space"
updated_select_query_no_semicolon = db_to_query.wrap_sql_limit(
updated_select_query_no_semicolon = main_db.wrap_sql_limit(
select_query_no_semicolon, 100)
self.assertTrue(
"SELECT * FROM (SELECT * FROM outer_space) AS inner_qry "
Expand All @@ -170,29 +169,29 @@ def test_add_limit_to_the_query(self):
multi_line_query = (
"SELECT * FROM planets WHERE\n Luke_Father = 'Darth Vader';"
)
updated_multi_line_query = db_to_query.wrap_sql_limit(multi_line_query, 100)
updated_multi_line_query = main_db.wrap_sql_limit(multi_line_query, 100)
self.assertTrue(
"SELECT * FROM (SELECT * FROM planets WHERE "
"Luke_Father = 'Darth Vader';) AS inner_qry LIMIT 100" in
' '.join(updated_multi_line_query.split())
)

def test_run_sync_query(self):
main_db = db.session.query(models.Database).filter_by(
database_name="main").first()
main_db = self.get_main_database(db.session)
eng = main_db.get_sqla_engine()

db_id = main_db.id
# Case 1.
# Table doesn't exist.
sql_dont_exist = 'SELECT name FROM table_dont_exist'
result1 = self.run_sql(1, sql_dont_exist, "1", cta='true')
result1 = self.run_sql(db_id, sql_dont_exist, "1", cta='true')
self.assertTrue('error' in result1)

# Case 2.
# Table and DB exists, CTA call to the backend.
sql_where = "SELECT name FROM ab_permission WHERE name='can_sql'"
result2 = self.run_sql(
1, sql_where, "2", tmp_table='tmp_table_2', cta='true')
db_id, sql_where, "2", tmp_table='tmp_table_2', cta='true')
self.assertEqual(QueryStatus.SUCCESS, result2['query']['state'])
self.assertEqual([], result2['data'])
self.assertEqual([], result2['columns'])
Expand All @@ -207,7 +206,7 @@ def test_run_sync_query(self):
# Table and DB exists, CTA call to the backend, no data.
sql_empty_result = 'SELECT * FROM ab_user WHERE id=666'
result3 = self.run_sql(
1, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true',)
db_id, sql_empty_result, "3", tmp_table='tmp_table_3', cta='true',)
self.assertEqual(QueryStatus.SUCCESS, result3['query']['state'])
self.assertEqual([], result3['data'])
self.assertEqual([], result3['columns'])
Expand All @@ -216,8 +215,7 @@ def test_run_sync_query(self):
self.assertEqual(QueryStatus.SUCCESS, query3.status)

def test_run_async_query(self):
main_db = db.session.query(models.Database).filter_by(
database_name="main").first()
main_db = self.get_main_database(db.session)
eng = main_db.get_sqla_engine()

# Schedule queries
Expand All @@ -226,7 +224,8 @@ def test_run_async_query(self):
# Table and DB exists, async CTA call to the backend.
sql_where = "SELECT name FROM ab_role WHERE name='Admin'"
result1 = self.run_sql(
1, sql_where, "4", async='true', tmp_table='tmp_async_1', cta='true')
main_db.id, sql_where, "4", async='true', tmp_table='tmp_async_1',
cta='true')
assert result1['query']['state'] in (
QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS)

Expand All @@ -238,7 +237,7 @@ def test_run_async_query(self):
self.assertEqual(QueryStatus.SUCCESS, query1.status)
self.assertEqual([{'name': 'Admin'}], df1.to_dict(orient='records'))
self.assertEqual(QueryStatus.SUCCESS, query1.status)
self.assertTrue("SELECT * \nFROM tmp_async_1" in query1.select_sql)
self.assertTrue("FROM tmp_async_1" in query1.select_sql)
self.assertTrue("LIMIT 666" in query1.select_sql)
self.assertEqual(
"CREATE TABLE tmp_async_1 AS \nSELECT name FROM ab_role "
Expand All @@ -252,8 +251,7 @@ def test_run_async_query(self):
self.assertEqual(True, query1.select_as_cta_used)

def test_get_columns_dict(self):
main_db = db.session.query(models.Database).filter_by(
database_name='main').first()
main_db = self.get_main_database(db.session)
df = main_db.get_df("SELECT * FROM multiformat_time_series", None)
cdf = dataframe.CaravelDataFrame(df)
if main_db.sqlalchemy_uri.startswith('sqlite'):
Expand Down

0 comments on commit 63161b1

Please sign in to comment.