Use download thread to speed up result retrieval #280

Open · wants to merge 1 commit into base: master

trino/client.py: 43 additions and 14 deletions (57 changed lines)
@@ -39,15 +39,17 @@
 import copy
 import functools
 import os
+import queue
 import random
 import re
 import threading
 import urllib.parse
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from datetime import date, datetime, time, timedelta, timezone, tzinfo
 from decimal import Decimal
 from time import sleep
-from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import pytz
 import requests
@@ -684,6 +686,27 @@ def _verify_extra_credential(self, header):
             raise ValueError(f"only ASCII characters are allowed in extra credential '{key}'")
 
 
+class ResultDownloader():
+    def __init__(self):
+        self.queue: queue.Queue = queue.Queue()
+        self.executor: Optional[ThreadPoolExecutor] = None
+
+    def submit(self, fetch_func: Callable[[], List[Any]]):
+        assert self.executor is not None
+        self.executor.submit(self.download_task, fetch_func)
+
+    def download_task(self, fetch_func):
+        self.queue.put(fetch_func())
+
+    def __enter__(self):
+        self.executor = ThreadPoolExecutor(max_workers=1)
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        self.executor.shutdown()
+        self.executor = None
+
+
 class TrinoResult(object):
     """
     Represent the result of a Trino query as an iterator on rows.

Review comment from a project member on the added line "self.queue: queue.Queue = queue.Queue()":
According to the docs, the default maxsize is zero, which means unbounded. When I was testing similar changes in the Go driver, it turned out that the queue size didn't matter, so even a queue of size 1 is fine. Did you test different queue sizes, and if not, can you check if maxsize greater than 1 makes a notable difference?
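One way to answer that question empirically would be to make the queue bound configurable and benchmark a few values. A minimal sketch, assuming a max_queued constructor parameter that is not part of this PR:

# Hypothetical benchmarking variant; max_queued is an assumption, not part of this PR.
import queue
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional


class BoundedResultDownloader:
    def __init__(self, max_queued: int = 1):
        # maxsize=0 means unbounded; maxsize=1 lets the worker stay at most
        # one batch ahead and blocks it until the previous batch is consumed.
        self.queue: queue.Queue = queue.Queue(maxsize=max_queued)
        self.executor: Optional[ThreadPoolExecutor] = None

    def submit(self, fetch_func: Callable[[], List[Any]]) -> None:
        assert self.executor is not None
        self.executor.submit(self.download_task, fetch_func)

    def download_task(self, fetch_func: Callable[[], List[Any]]) -> None:
        # queue.put blocks while the queue is full, throttling the download thread.
        self.queue.put(fetch_func())

    def __enter__(self) -> "BoundedResultDownloader":
        self.executor = ThreadPoolExecutor(max_workers=1)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        self.executor.shutdown()
        self.executor = None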
@@ -711,16 +734,21 @@ def rownumber(self) -> int:
         return self._rownumber
 
     def __iter__(self):
-        # A query only transitions to a FINISHED state when the results are fully consumed:
-        # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
-        while not self._query.finished or self._rows is not None:
-            next_rows = self._query.fetch() if not self._query.finished else None
-            for row in self._rows:
-                self._rownumber += 1
-                logger.debug("row %s", row)
-                yield row
-
-            self._rows = next_rows
+        with ResultDownloader() as result_downloader:
+            # A query only transitions to a FINISHED state when the results are fully consumed:
+            # The reception of the data is acknowledged by calling the next_uri before exposing the data through dbapi.
+            result_downloader.submit(self._query.fetch)
+            while not self._query.finished or self._rows is not None:
+                next_rows = result_downloader.queue.get() if not self._query.finished else None
+                if not self._query.finished:
+                    result_downloader.submit(self._query.fetch)
+
+                self._rows = next_rows
+                for row in self._rows:
+                    self._rownumber += 1
+                    logger.debug("row %s", row)
+                    yield row
 
 
 class TrinoQuery(object):
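The effect of the reworked __iter__ is pipelining: the next fetch is handed to the background thread before the current batch of rows is yielded, so the HTTP round trip overlaps with row consumption. A self-contained toy sketch of the same pattern, with a made-up fake_fetch standing in for TrinoQuery.fetch:

# Toy illustration of the prefetch pattern used above; fake_fetch and the
# hard-coded batches are assumptions for demonstration only.
import queue
import time
from concurrent.futures import ThreadPoolExecutor

batches = [[1, 2], [3, 4], [5, 6]]


def fake_fetch(i):
    time.sleep(0.1)  # stands in for the HTTP round trip
    return batches[i]


def iterate_with_prefetch():
    results: queue.Queue = queue.Queue()
    with ThreadPoolExecutor(max_workers=1) as executor:
        executor.submit(lambda: results.put(fake_fetch(0)))
        for i in range(len(batches)):
            rows = results.get()  # wait for the batch downloaded in the background
            if i + 1 < len(batches):
                # start downloading the next batch before processing this one
                executor.submit(lambda i=i: results.put(fake_fetch(i + 1)))
            for row in rows:
                yield row


if __name__ == "__main__":
    print(list(iterate_with_prefetch()))  # [1, 2, 3, 4, 5, 6]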
@@ -753,7 +781,7 @@ def columns(self):
         while not self._columns and not self.finished and not self.cancelled:
             # Columns are not returned immediately after query is submitted.
             # Continue fetching data until columns information is available and push fetched rows into buffer.
-            self._result.rows += self.fetch()
+            self._result.rows += self.map_rows(self.fetch())
         return self._columns
 
     @property
@@ -802,7 +830,7 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
 
         # Execute should block until at least one row is received or query is finished or cancelled
         while not self.finished and not self.cancelled and len(self._result.rows) == 0:
-            self._result.rows += self.fetch()
+            self._result.rows += self.map_rows(self.fetch())
         return self._result
 
     def _update_state(self, status):
@@ -822,11 +850,12 @@ def fetch(self) -> List[List[Any]]:
         logger.debug(status)
         if status.next_uri is None:
             self._finished = True
+        return status.rows
+
+    def map_rows(self, rows: List[List[Any]]) -> List[List[Any]]:
         if not self._row_mapper:
             return []
-
-        return self._row_mapper.map(status.rows)
+        return self._row_mapper.map(rows)
 
     def cancel(self) -> None:
         """Cancel the current query"""