Skip to content

Commit

Permalink
Start presto cursor backoff
Browse files Browse the repository at this point in the history
  • Loading branch information
erik_ritter committed Jun 26, 2020
1 parent df71fac commit bf0025f
Showing 1 changed file with 30 additions and 8 deletions.
38 changes: 30 additions & 8 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse

import backoff
import pandas as pd
import simplejson as json
from sqlalchemy import Column, literal_column
Expand Down Expand Up @@ -98,6 +99,25 @@ def get_children(column: Dict[str, str]) -> List[Dict[str, str]]:
raise Exception(f"Unknown type {type_}!")


def decimal_fibo(max_value: Optional[float]) -> float:
"""
Generator for fibonaccial decay starting with 0.1 instead of 1.
:param max_value: The maximum value to yield. Once the value in the
true fibonacci sequence exceeds this, the value
of max_value will forever after be yielded.
:return: The next value in the sequence.
"""
a = 0.1
b = 0.1
while True:
if max_value is None or a < max_value:
yield a
a, b = b, a + b
else:
yield max_value


class PrestoEngineSpec(BaseEngineSpec):
engine = "presto"

Expand Down Expand Up @@ -726,30 +746,30 @@ def get_create_view(
return rows[0][0]

@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
"""Updates progress information"""
@backoff.on_predicate(decimal_fibo, jitter=None, max_value=1.0)
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> bool:
query_id = query.id
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()
# poll returns dict -- JSON status information or ``None``
# if the query is done
# https://github.com/dropbox/PyHive/blob/
# b34bdbf51378b3979eaf5eca9e956f06ddc36ca0/pyhive/presto.py#L178
while polled:
if polled:
# Update the object and wait for the kill signal.
stats = polled.get("stats", {})

query = session.query(type(query)).filter_by(id=query_id).one()
if query.status in [QueryStatus.STOPPED, QueryStatus.TIMED_OUT]:
cursor.cancel()
break
return True

if stats:
state = stats.get("state")

# if already finished, then stop polling
if state == "FINISHED":
break
return True

completed_splits = float(stats.get("completedSplits"))
total_splits = float(stats.get("totalSplits"))
Expand All @@ -762,9 +782,11 @@ def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
if progress > query.progress:
query.progress = progress
session.commit()
time.sleep(1)
logger.info("Query %i: Polling the cursor for progress", query_id)
polled = cursor.poll()

# Query isn't done yet, so let's backoff and go for another loop
return False
else:
return True

@classmethod
def _extract_error_message(cls, ex: Exception) -> Optional[str]:
Expand Down

0 comments on commit bf0025f

Please sign in to comment.