fix #7390: push down limit filtering to adapter #7545

Merged · 4 commits · May 9, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230508-060926.yaml
@@ -0,0 +1,6 @@
kind: Fixes
body: push down limit filtering to adapter
time: 2023-05-08T06:09:26.455524-07:00
custom:
Author: aranke
Issue: "7390"
5 changes: 3 additions & 2 deletions core/dbt/adapters/base/impl.py
@@ -274,7 +274,7 @@ def connection_for(self, node: ResultNode) -> Iterator[None]:

@available.parse(lambda *a, **k: ("", empty_table()))
def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
) -> Tuple[AdapterResponse, agate.Table]:
"""Execute the given SQL. This is a thin wrapper around
ConnectionManager.execute.
@@ -283,10 +283,11 @@ def execute(
:param bool auto_begin: If set, and dbt is not currently inside a
transaction, automatically begin one.
:param bool fetch: If set, fetch results.
:param Optional[int] limit: If set, only fetch n number of rows
:return: A tuple of the query status and results (empty if fetch=False).
:rtype: Tuple[AdapterResponse, agate.Table]
"""
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)
return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch, limit=limit)

@available.parse(lambda *a, **k: [])
def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
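To illustrate the signature change on `BaseAdapter.execute`: `limit` is optional and defaults to `None`, so existing call sites keep their behavior, while callers that only need a preview can cap the fetch. A minimal usage sketch; the `adapter` instance, the SQL string, and the `preview` helper are hypothetical and not part of this diff:

```python
from typing import Optional

def preview(adapter, sql: str, limit: Optional[int] = 5):
    # fetch=True materializes the result into an agate.Table; limit caps how
    # many rows the connection manager pulls from the cursor (None = all rows).
    response, table = adapter.execute(sql, fetch=True, limit=limit)
    print(f"{response}: fetched {len(table.rows)} row(s)")
    return table
```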
11 changes: 7 additions & 4 deletions core/dbt/adapters/sql/connections.py
@@ -118,13 +118,16 @@ def process_results(
return [dict(zip(column_names, row)) for row in rows]

@classmethod
def get_result_from_cursor(cls, cursor: Any) -> agate.Table:
def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> agate.Table:
data: List[Any] = []
column_names: List[str] = []

if cursor.description is not None:
column_names = [col[0] for col in cursor.description]
rows = cursor.fetchall()
if limit:
rows = cursor.fetchmany(limit)
else:
rows = cursor.fetchall()
data = cls.process_results(column_names, rows)

return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
@@ -138,13 +141,13 @@ def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
)

def execute(
self, sql: str, auto_begin: bool = False, fetch: bool = False
self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
) -> Tuple[AdapterResponse, agate.Table]:
sql = self._add_query_comment(sql)
_, cursor = self.add_query(sql, auto_begin)
response = self.get_response(cursor)
if fetch:
table = self.get_result_from_cursor(cursor)
table = self.get_result_from_cursor(cursor, limit)
else:
table = dbt.clients.agate_helper.empty_table()
return response, table
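To make the new `get_result_from_cursor` branch concrete, here is a self-contained sketch of the same logic against a plain DB-API cursor (sqlite3 is used only for illustration; real adapters supply their own cursors). Because the check is truthiness-based, `limit=0` falls through to `fetchall()` just like `None`:

```python
import sqlite3
from typing import Any, List, Optional

def fetch_rows(cursor: Any, limit: Optional[int]) -> List[Any]:
    # Mirrors the added branch: a truthy limit uses DB-API fetchmany(limit),
    # otherwise every remaining row is pulled with fetchall().
    if limit:
        return cursor.fetchmany(limit)
    return cursor.fetchall()

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE t (id INTEGER)")
cur.executemany("INSERT INTO t VALUES (?)", [(i,) for i in range(7)])

cur.execute("SELECT id FROM t")
assert len(fetch_rows(cur, 3)) == 3      # limited fetch
cur.execute("SELECT id FROM t")
assert len(fetch_rows(cur, None)) == 7   # unlimited fetch
```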
11 changes: 5 additions & 6 deletions core/dbt/task/show.py
@@ -19,8 +19,12 @@ def __init__(self, config, adapter, node, node_index, num_nodes):

def execute(self, compiled_node, manifest):
start_time = time.time()

# Allow passing in -1 (or any negative number) to get all rows
limit = None if self.config.args.limit < 0 else self.config.args.limit

adapter_response, execute_result = self.adapter.execute(
compiled_node.compiled_code, fetch=True
compiled_node.compiled_code, fetch=True, limit=limit
)
end_time = time.time()

@@ -66,13 +70,8 @@ def task_end_messages(self, results):
)

for result in matched_results:
# Allow passing in -1 (or any negative number) to get all rows
table = result.agate_table

if self.args.limit >= 0:
table = table.limit(self.args.limit)
result.agate_table = table

# Hack to get Agate table output as string
output = io.StringIO()
if self.args.output == "json":
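With the limit pushed down to the adapter, the agate table returned by `execute` already contains at most `limit` rows, which is why the post-hoc `table.limit(...)` block in `task_end_messages` could be removed. The CLI convention is unchanged; a sketch of the mapping applied in `ShowTask.execute` (the `resolve_limit` helper name is illustrative, not part of the diff):

```python
from typing import Optional

def resolve_limit(cli_limit: int) -> Optional[int]:
    # -1 (or any negative number) means "no limit"; it is passed to the
    # adapter as None so the connection manager falls back to fetchall().
    return None if cli_limit < 0 else cli_limit

assert resolve_limit(-1) is None  # fetch every row
assert resolve_limit(3) == 3      # fetch at most 3 rows
```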
14 changes: 14 additions & 0 deletions tests/functional/show/test_show.py
@@ -87,6 +87,20 @@ def test_second_ephemeral_model(self, project):
)
assert "col_hundo" in log_output

@pytest.mark.parametrize(
"args,expected",
[
([], 5), # default limit
(["--limit", 3], 3), # fetch 3 rows
(["--limit", -1], 7), # fetch all rows
],
)
def test_limit(self, project, args, expected):
run_dbt(["build"])
dbt_args = ["show", "--inline", models__second_ephemeral_model, *args]
results, log_output = run_dbt_and_capture(dbt_args)
assert len(results.results[0].agate_table) == expected

def test_seed(self, project):
(results, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"])
assert "Previewing node 'sample_seed'" in log_output
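For completeness, the behavior the parametrized test pins down (a default of 5 rows, an explicit cap, and -1 for all rows) can also be exercised programmatically. A hedged sketch, assuming dbt-core 1.5+ where `dbtRunner` is available and that it is run inside an initialized dbt project; the inline SQL is illustrative:

```python
from dbt.cli.main import dbtRunner  # programmatic CLI entry point (dbt 1.5+)

runner = dbtRunner()
inline_sql = "select 1 as id union all select 2 union all select 3"

# Default preview limit (5 rows, per the parametrized test above).
runner.invoke(["show", "--inline", inline_sql])

# Explicit cap: the adapter fetches at most 2 rows via fetchmany(2).
runner.invoke(["show", "--inline", inline_sql, "--limit", "2"])

# Negative value: the limit becomes None and fetchall() returns every row.
runner.invoke(["show", "--inline", inline_sql, "--limit", "-1"])
```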