Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use threadpools for filling the cache and listing schemas (#2127) #2157

Merged
merged 2 commits into from
Feb 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
- Fix an issue where dbt rendered source test args, fix issue where dbt ran an extra compile pass over the wrapped SQL. ([#2114](https://github.com/fishtown-analytics/dbt/issues/2114), [#2150](https://github.com/fishtown-analytics/dbt/pull/2150))
- Set more upper bounds for jinja2,requests, and idna dependencies, upgrade snowflake-connector-python ([#2147](https://github.com/fishtown-analytics/dbt/issues/2147), [#2151](https://github.com/fishtown-analytics/dbt/pull/2151))

### Under the hood
- Parallelize filling the cache and listing schemas in each database during startup ([#2127](https://github.com/fishtown-analytics/dbt/issues/2127), [#2157](https://github.com/fishtown-analytics/dbt/pull/2157))

Contributors:
- [@bubbomb](https://github.com/bubbomb) ([#2080](https://github.com/fishtown-analytics/dbt/pull/2080))
- [@sonac](https://github.com/sonac) ([#2078](https://github.com/fishtown-analytics/dbt/pull/2078))
Expand Down
37 changes: 23 additions & 14 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import abc
from concurrent.futures import ThreadPoolExecutor, as_completed
from concurrent.futures import Future # noqa - we use this for typing only
from concurrent.futures import as_completed, Future
from contextlib import contextmanager
from datetime import datetime
from typing import (
Expand All @@ -27,7 +26,7 @@
from dbt.exceptions import warn_or_error
from dbt.node_types import NodeType
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.utils import filter_null_values
from dbt.utils import filter_null_values, executor

from dbt.adapters.base.connections import BaseConnectionManager, Connection
from dbt.adapters.base.meta import AdapterMeta, available
Expand Down Expand Up @@ -358,23 +357,35 @@ def _get_cache_schemas(
# databases
return info_schema_name_map

def _list_relations_get_connection(
self, db: BaseRelation, schema: str
) -> List[BaseRelation]:
with self.connection_named(f'list_{db.database}_{schema}'):
return self.list_relations_without_caching(db, schema)

def _relations_cache_for_schemas(self, manifest: Manifest) -> None:
"""Populate the relations cache for the given schemas. Returns an
iterable of the schemas populated, as strings.
"""
if not dbt.flags.USE_CACHE:
return

info_schema_name_map = self._get_cache_schemas(manifest,
exec_only=True)
for db, schema in info_schema_name_map.search():
for relation in self.list_relations_without_caching(db, schema):
self.cache.add(relation)
schema_map = self._get_cache_schemas(manifest, exec_only=True)
with executor(self.config) as tpe:
futures: List[Future[List[BaseRelation]]] = [
tpe.submit(self._list_relations_get_connection, db, schema)
for db, schema in schema_map.search()
]
for future in as_completed(futures):
# if we can't read the relations we need to just raise anyway,
# so just call future.result() and let that raise on failure
for relation in future.result():
self.cache.add(relation)

# it's possible that there were no relations in some schemas. We want
# to insert the schemas we query into the cache's `.schemas` attribute
# so we can check it later
self.cache.update_schemas(info_schema_name_map.schemas_searched())
self.cache.update_schemas(schema_map.schemas_searched())

def set_relations_cache(
self, manifest: Manifest, clear: bool = False
Expand Down Expand Up @@ -1047,13 +1058,11 @@ def _get_one_catalog(
def get_catalog(
self, manifest: Manifest
) -> Tuple[agate.Table, List[Exception]]:
# snowflake is super slow. split it out into the specified threads
num_threads = self.config.threads
schema_map = self._get_cache_schemas(manifest)

with ThreadPoolExecutor(max_workers=num_threads) as executor:
futures = [
executor.submit(self._get_one_catalog, info, schemas, manifest)
with executor(self.config) as tpe:
futures: List[Future[agate.Table]] = [
tpe.submit(self._get_one_catalog, info, schemas, manifest)
for info, schemas in schema_map.items() if len(schemas) > 0
]
catalogs, exceptions = catch_as_completed(futures)
Expand Down
46 changes: 38 additions & 8 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import time
from concurrent.futures import as_completed
from datetime import datetime
from multiprocessing.dummy import Pool as ThreadPool
from typing import Optional, Dict, List, Set, Tuple
from typing import Optional, Dict, List, Set, Tuple, Iterable

from dbt.task.base import ConfiguredTask
from dbt.adapters.factory import get_adapter
Expand Down Expand Up @@ -374,7 +375,9 @@ def interpret_results(self, results):
failures = [r for r in results if r.error or r.fail]
return len(failures) == 0

def get_model_schemas(self, selected_uids):
def get_model_schemas(
self, selected_uids: Iterable[str]
) -> Set[Tuple[str, str]]:
if self.manifest is None:
raise InternalException('manifest was None in get_model_schemas')

Expand All @@ -387,19 +390,46 @@ def get_model_schemas(self, selected_uids):

return schemas

def create_schemas(self, adapter, selected_uids):
def create_schemas(self, adapter, selected_uids: Iterable[str]):
required_schemas = self.get_model_schemas(selected_uids)
required_databases = set(db for db, _ in required_schemas)

existing_schemas_lowered: Set[Tuple[str, str]] = set()
for db in required_databases:
existing_schemas_lowered.update(
(db.lower(), s.lower()) for s in adapter.list_schemas(db))

for db, schema in required_schemas:
if (db.lower(), schema.lower()) not in existing_schemas_lowered:
def list_schemas(db: str) -> List[Tuple[str, str]]:
with adapter.connection_named(f'list_{db}'):
return [
(db.lower(), s.lower())
for s in adapter.list_schemas(db)
]

def create_schema(db: str, schema: str) -> None:
with adapter.connection_named(f'create_{db}_{schema}'):
adapter.create_schema(db, schema)

list_futures = []
create_futures = []

with dbt.utils.executor(self.config) as tpe:
list_futures = [
tpe.submit(list_schemas, db) for db in required_databases
]

for ls_future in as_completed(list_futures):
existing_schemas_lowered.update(ls_future.result())

for db, schema in required_schemas:
db_schema = (db.lower(), schema.lower())
if db_schema not in existing_schemas_lowered:
existing_schemas_lowered.add(db_schema)
create_futures.append(
tpe.submit(create_schema, db, schema)
)

for create_future in as_completed(create_futures):
# trigger/re-raise any excceptions while creating schemas
create_future.result()

def get_result(self, results, elapsed_time, generated_at):
return ExecutionResult(
results=results,
Expand Down
47 changes: 47 additions & 0 deletions core/dbt/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import concurrent.futures
import copy
import datetime
import decimal
Expand All @@ -8,6 +9,7 @@
import json
import os
from enum import Enum
from typing_extensions import Protocol
from typing import (
Tuple, Type, Any, Optional, TypeVar, Dict, Union, Callable
)
Expand Down Expand Up @@ -489,3 +491,48 @@ def format_bytes(num_bytes):
num_bytes /= 1024.0

return "> 1024 TB"


# a little concurrent.futures.Executor for single-threaded mode
class SingleThreadedExecutor(concurrent.futures.Executor):
def submit(*args, **kwargs):
# this basic pattern comes from concurrent.futures.Executor itself,
# but without handling the `fn=` form.
if len(args) >= 2:
self, fn, *args = args
elif not args:
raise TypeError(
"descriptor 'submit' of 'SingleThreadedExecutor' object needs "
"an argument"
)
else:
raise TypeError(
'submit expected at least 1 positional argument, '
'got %d' % (len(args) - 1)
)
fut = concurrent.futures.Future()
try:
result = fn(*args, **kwargs)
except Exception as exc:
fut.set_exception(exc)
else:
fut.set_result(result)
return fut


class ThreadedArgs(Protocol):
single_threaded: bool


class HasThreadingConfig(Protocol):
args: ThreadedArgs
threads: Optional[int]


def executor(config: HasThreadingConfig) -> concurrent.futures.Executor:
if config.args.single_threaded:
return SingleThreadedExecutor()
else:
return concurrent.futures.ThreadPoolExecutor(
max_workers=config.threads
)
1 change: 0 additions & 1 deletion test/integration/001_simple_copy_test/test_simple_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def test__postgres__simple_copy_with_materialized_views(self):
select * from {schema}.unrelated_materialized_view
)
'''.format(schema=self.unique_schema()))

results = self.run_dbt(["seed"])
self.assertEqual(len(results), 1)
results = self.run_dbt()
Expand Down
1 change: 1 addition & 0 deletions test/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def normalize(path):

class Obj:
which = 'blah'
single_threaded = False


def mock_connection(name):
Expand Down