Merge pull request #1773 from getredash/patches
Split refresh schemas into separate tasks and add a timeout.
arikfr authored May 18, 2017
2 parents 3650617 + 3807510 commit 40a8187
Showing 3 changed files with 44 additions and 34 deletions.
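What changed: previously refresh_schemas looped over every data source and called get_schema(refresh=True) inline, so a single hung source could stall the entire pass. After this commit the loop only enqueues a dedicated refresh_schema Celery task per data source on a "schemas" queue; each task carries a 60-second soft time limit and reports success/timeout/error to statsd.

The sketch below illustrates the Celery soft-timeout mechanism the new task relies on. It is illustrative only: the app name, broker URL, and slow_refresh task are hypothetical, not part of the Redash code.

import time

from celery import Celery
from celery.exceptions import SoftTimeLimitExceeded

app = Celery('sketch', broker='redis://localhost:6379/0')  # hypothetical app/broker


@app.task(soft_time_limit=60)
def slow_refresh():
    # After 60 seconds Celery raises SoftTimeLimitExceeded *inside* the task,
    # giving it a chance to log and exit cleanly instead of being killed.
    try:
        time.sleep(120)  # stand-in for a slow get_schema(refresh=True) call
        return 'finished'
    except SoftTimeLimitExceeded:
        return 'timed out'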
28 changes: 14 additions & 14 deletions redash/models.py
@@ -479,6 +479,20 @@ def create_with_group(cls, *args, **kwargs):
         db.session.add_all([data_source, data_source_group])
         return data_source
 
+    @classmethod
+    def all(cls, org, group_ids=None):
+        data_sources = cls.query.filter(cls.org == org).order_by(cls.id.asc())
+
+        if group_ids:
+            data_sources = data_sources.join(DataSourceGroup).filter(
+                DataSourceGroup.group_id.in_(group_ids))
+
+        return data_sources
+
+    @classmethod
+    def get_by_id(cls, _id):
+        return cls.query.filter(cls.id == _id).one()
+
     def get_schema(self, refresh=False):
         key = "data_source:schema:{}".format(self.id)
 
@@ -536,24 +550,10 @@ def update_group_permission(self, group, view_only):
     def query_runner(self):
         return get_query_runner(self.type, self.options)
 
-    @classmethod
-    def get_by_id(cls, _id):
-        return cls.query.filter(cls.id == _id).one()
-
     @classmethod
     def get_by_name(cls, name):
         return cls.query.filter(cls.name == name).one()
 
-    @classmethod
-    def all(cls, org, group_ids=None):
-        data_sources = cls.query.filter(cls.org == org).order_by(cls.id.asc())
-
-        if group_ids:
-            data_sources = data_sources.join(DataSourceGroup).filter(
-                DataSourceGroup.group_id.in_(group_ids))
-
-        return data_sources
-
     #XXX examine call sites to see if a regular SQLA collection would work better
     @property
     def groups(self):
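Note: the models.py hunks only relocate the all() and get_by_id() classmethods higher up in the DataSource class; their bodies are unchanged. get_by_id() is what the new task uses to load its data source, roughly like this (a sketch, not taken from the diff):

ds = models.DataSource.get_by_id(data_source_id)  # .one() raises NoResultFound for an unknown id
schema = ds.get_schema(refresh=True)              # force a refresh instead of returning the cached copy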
34 changes: 21 additions & 13 deletions redash/tasks/queries.py
@@ -5,6 +5,7 @@
 
 import redis
 
+from celery.exceptions import SoftTimeLimitExceeded
 from celery.result import AsyncResult
 from celery.utils.log import get_task_logger
 from redash import models, redis_connection, settings, statsd_client, utils
@@ -234,9 +235,7 @@ def enqueue_query(query, data_source, user_id, scheduled_query=None, metadata={}
     queue_name = data_source.queue_name
     scheduled_query_id = None
 
-    result = execute_query.apply_async(args=(
-        query, data_source.id, metadata, user_id,
-        scheduled_query_id),
+    result = execute_query.apply_async(args=(query, data_source.id, metadata, user_id, scheduled_query_id),
                                        queue=queue_name)
     job = QueryTask(async_result=result)
     tracker = QueryTaskTracker.create(
@@ -342,14 +341,30 @@ def cleanup_query_results():
     logger.info("Deleted %d unused query results.", deleted_count)
 
 
+@celery.task(name="redash.tasks.refresh_schema", soft_time_limit=60)
+def refresh_schema(data_source_id):
+    ds = models.DataSource.get_by_id(data_source_id)
+    logger.info(u"task=refresh_schema state=start ds_id=%s", ds.id)
+    start_time = time.time()
+    try:
+        ds.get_schema(refresh=True)
+        logger.info(u"task=refresh_schema state=finished ds_id=%s runtime=%.2f", ds.id, time.time() - start_time)
+        statsd_client.incr('refresh_schema.success')
+    except SoftTimeLimitExceeded:
+        logger.info(u"task=refresh_schema state=timeout ds_id=%s runtime=%.2f", ds.id, time.time() - start_time)
+        statsd_client.incr('refresh_schema.timeout')
+    except Exception:
+        logger.warning(u"Failed refreshing schema for the data source: %s", ds.name, exc_info=1)
+        statsd_client.incr('refresh_schema.error')
+        logger.info(u"task=refresh_schema state=failed ds_id=%s runtime=%.2f", ds.id, time.time() - start_time)
+
+
 @celery.task(name="redash.tasks.refresh_schemas")
 def refresh_schemas():
     """
     Refreshes the data sources schemas.
     """
-
     blacklist = [int(ds_id) for ds_id in redis_connection.smembers('data_sources:schema:blacklist') if ds_id]
-
     global_start_time = time.time()
 
     logger.info(u"task=refresh_schemas state=start")
@@ -360,14 +375,7 @@ def refresh_schemas():
         elif ds.id in blacklist:
             logger.info(u"task=refresh_schema state=skip ds_id=%s reason=blacklist", ds.id)
         else:
-            logger.info(u"task=refresh_schema state=start ds_id=%s", ds.id)
-            start_time = time.time()
-            try:
-                ds.get_schema(refresh=True)
-                logger.info(u"task=refresh_schema state=finished ds_id=%s runtime=%.2f", ds.id, time.time() - start_time)
-            except Exception:
-                logger.exception(u"Failed refreshing schema for the data source: %s", ds.name)
-                logger.info(u"task=refresh_schema state=failed ds_id=%s runtime=%.2f", ds.id, time.time() - start_time)
+            refresh_schema.apply_async(args=(ds.id,), queue="schemas")
 
     logger.info(u"task=refresh_schemas state=finish total_runtime=%.2f", time.time() - global_start_time)

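Operational note: refresh_schema.apply_async(..., queue="schemas") routes the new tasks to a dedicated queue, so a Celery worker must consume that queue or the refreshes will sit unprocessed. The routing could also be declared centrally; a hedged sketch using Celery 3.x-era setting names (illustrative, not part of this commit — importing the app from redash.worker is an assumption about the surrounding codebase):

from kombu import Queue

from redash.worker import celery  # assumed location of the Celery app instance

celery.conf.update(
    CELERY_QUEUES=(Queue('celery'), Queue('schemas')),
    CELERY_ROUTES={'redash.tasks.refresh_schema': {'queue': 'schemas'}},
)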
16 changes: 9 additions & 7 deletions tests/tasks/test_refresh_schemas.py
@@ -1,25 +1,27 @@
 import datetime
-from mock import patch, call, ANY
+
+from mock import ANY, call, patch
 from tests import BaseTestCase
+
 from redash.tasks import refresh_schemas
 
 
 class TestRefreshSchemas(BaseTestCase):
     def test_calls_refresh_of_all_data_sources(self):
         self.factory.data_source  # trigger creation
-        with patch('redash.models.DataSource.get_schema') as get_schema:
+        with patch('redash.tasks.queries.refresh_schema.apply_async') as refresh_job:
             refresh_schemas()
-            get_schema.assert_called_with(refresh=True)
+            refresh_job.assert_called()
 
     def test_skips_paused_data_sources(self):
         self.factory.data_source.pause()
 
-        with patch('redash.models.DataSource.get_schema') as get_schema:
+        with patch('redash.tasks.queries.refresh_schema.apply_async') as refresh_job:
             refresh_schemas()
-            get_schema.assert_not_called()
+            refresh_job.assert_not_called()
 
         self.factory.data_source.resume()
 
-        with patch('redash.models.DataSource.get_schema') as get_schema:
+        with patch('redash.tasks.queries.refresh_schema.apply_async') as refresh_job:
             refresh_schemas()
-            get_schema.assert_called_with(refresh=True)
+            refresh_job.assert_called()

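The tests now patch refresh_schema.apply_async instead of DataSource.get_schema: since the refresh itself runs in a separate task, refresh_schemas can only be observed enqueueing work. A stricter variant could pin the exact enqueue arguments, at the cost of coupling the test to the queue name (a sketch, not part of the commit):

with patch('redash.tasks.queries.refresh_schema.apply_async') as refresh_job:
    refresh_schemas()
    # Assert the exact fan-out call, including the dedicated queue.
    refresh_job.assert_called_once_with(args=(self.factory.data_source.id,), queue="schemas")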