From 83fabc917075b33fa5bb470cf5128400a3ba7eb7 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke <kshitij.aranke@dbtlabs.com> Date: Mon, 8 May 2023 06:08:37 -0700 Subject: [PATCH 1/3] fix #7390: push down limit filtering to adapter --- .../unreleased/Fixes-20230508-060926.yaml | 6 ++++++ core/dbt/adapters/base/impl.py | 5 +++-- core/dbt/adapters/sql/connections.py | 11 ++++++---- core/dbt/task/show.py | 11 +++++----- tests/functional/show/test_show.py | 21 +++++++++++++++++++ 5 files changed, 42 insertions(+), 12 deletions(-) create mode 100644 .changes/unreleased/Fixes-20230508-060926.yaml diff --git a/.changes/unreleased/Fixes-20230508-060926.yaml b/.changes/unreleased/Fixes-20230508-060926.yaml new file mode 100644 index 00000000000..764e628d27b --- /dev/null +++ b/.changes/unreleased/Fixes-20230508-060926.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: push down limit filtering to adapter +time: 2023-05-08T06:09:26.455524-07:00 +custom: + Author: aranke + Issue: "7390" diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index b2614555d01..cdad9f57fc3 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -274,7 +274,7 @@ def connection_for(self, node: ResultNode) -> Iterator[None]: @available.parse(lambda *a, **k: ("", empty_table())) def execute( - self, sql: str, auto_begin: bool = False, fetch: bool = False + self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None ) -> Tuple[AdapterResponse, agate.Table]: """Execute the given SQL. This is a thin wrapper around ConnectionManager.execute. @@ -283,10 +283,11 @@ def execute( :param bool auto_begin: If set, and dbt is not currently inside a transaction, automatically begin one. :param bool fetch: If set, fetch results. + :param Optional[int] limit: If set, only fetch n number of rows :return: A tuple of the query status and results (empty if fetch=False). :rtype: Tuple[AdapterResponse, agate.Table] """ - return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch) + return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch, limit=limit) @available.parse(lambda *a, **k: []) def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]: diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py index 88e4a30d0b6..464c07871a0 100644 --- a/core/dbt/adapters/sql/connections.py +++ b/core/dbt/adapters/sql/connections.py @@ -118,13 +118,16 @@ def process_results( return [dict(zip(column_names, row)) for row in rows] @classmethod - def get_result_from_cursor(cls, cursor: Any) -> agate.Table: + def get_result_from_cursor(cls, cursor: Any, limit: Optional[int]) -> agate.Table: data: List[Any] = [] column_names: List[str] = [] if cursor.description is not None: column_names = [col[0] for col in cursor.description] - rows = cursor.fetchall() + if limit: + rows = cursor.fetchmany(limit) + else: + rows = cursor.fetchall() data = cls.process_results(column_names, rows) return dbt.clients.agate_helper.table_from_data_flat(data, column_names) @@ -138,13 +141,13 @@ def data_type_code_to_name(cls, type_code: Union[int, str]) -> str: ) def execute( - self, sql: str, auto_begin: bool = False, fetch: bool = False + self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None ) -> Tuple[AdapterResponse, agate.Table]: sql = self._add_query_comment(sql) _, cursor = self.add_query(sql, auto_begin) response = self.get_response(cursor) if fetch: - table = self.get_result_from_cursor(cursor) + table = self.get_result_from_cursor(cursor, limit) else: table = dbt.clients.agate_helper.empty_table() return response, table diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py index 58e0948bcfb..4a2518db792 100644 --- a/core/dbt/task/show.py +++ b/core/dbt/task/show.py @@ -17,8 +17,12 @@ def __init__(self, config, adapter, node, node_index, num_nodes): def execute(self, compiled_node, manifest): start_time = time.time() + + # Allow passing in -1 (or any negative number) to get all rows + limit = None if self.config.args.limit < 0 else self.config.args.limit + adapter_response, execute_result = self.adapter.execute( - compiled_node.compiled_code, fetch=True + compiled_node.compiled_code, fetch=True, limit=limit ) end_time = time.time() @@ -61,13 +65,8 @@ def task_end_messages(self, results): ) for result in matched_results: - # Allow passing in -1 (or any negative number) to get all rows table = result.agate_table - if self.args.limit >= 0: - table = table.limit(self.args.limit) - result.agate_table = table - # Hack to get Agate table output as string output = io.StringIO() if self.args.output == "json": diff --git a/tests/functional/show/test_show.py b/tests/functional/show/test_show.py index c5684197ec5..3b8890940af 100644 --- a/tests/functional/show/test_show.py +++ b/tests/functional/show/test_show.py @@ -85,3 +85,24 @@ def test_second_ephemeral_model(self, project): ["show", "--inline", models__second_ephemeral_model] ) assert "col_hundo" in log_output + + def test_limit_default(self, project): + run_dbt(["build"]) + (results, log_output) = run_dbt_and_capture( + ["show", "--inline", models__second_ephemeral_model] + ) + assert len(results.results[0].agate_table) == 5 + + def test_limit_positive_number(self, project): + run_dbt(["build"]) + (results, log_output) = run_dbt_and_capture( + ["show", "--inline", models__second_ephemeral_model, "--limit", 3] + ) + assert len(results.results[0].agate_table) == 3 + + def test_limit_negative_number(self, project): + run_dbt(["build"]) + (results, log_output) = run_dbt_and_capture( + ["show", "--inline", models__second_ephemeral_model, "--limit", -1] + ) + assert len(results.results[0].agate_table) == 7 From 01303b14844b21653060207c7029d5046cc24419 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke <kshitij.aranke@dbtlabs.com> Date: Mon, 8 May 2023 14:55:39 -0700 Subject: [PATCH 2/3] parametrize test --- tests/functional/show/test_show.py | 31 ++++++++++++------------------ 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/tests/functional/show/test_show.py b/tests/functional/show/test_show.py index 3b8890940af..ad6ae0eb78d 100644 --- a/tests/functional/show/test_show.py +++ b/tests/functional/show/test_show.py @@ -86,23 +86,16 @@ def test_second_ephemeral_model(self, project): ) assert "col_hundo" in log_output - def test_limit_default(self, project): + @pytest.mark.parametrize( + "args,expected", + [ + ([], 5), # default limit + (["--limit", 3], 3), # fetch 3 rows + (["--limit", -1], 7), # fetch all rows + ], + ) + def test_limit(self, project, args, expected): run_dbt(["build"]) - (results, log_output) = run_dbt_and_capture( - ["show", "--inline", models__second_ephemeral_model] - ) - assert len(results.results[0].agate_table) == 5 - - def test_limit_positive_number(self, project): - run_dbt(["build"]) - (results, log_output) = run_dbt_and_capture( - ["show", "--inline", models__second_ephemeral_model, "--limit", 3] - ) - assert len(results.results[0].agate_table) == 3 - - def test_limit_negative_number(self, project): - run_dbt(["build"]) - (results, log_output) = run_dbt_and_capture( - ["show", "--inline", models__second_ephemeral_model, "--limit", -1] - ) - assert len(results.results[0].agate_table) == 7 + dbt_args = ["show", "--inline", models__second_ephemeral_model, *args] + results, log_output = run_dbt_and_capture(dbt_args) + assert len(results.results[0].agate_table) == expected From b178ce045fcbd3533b2941bc5df2493ea455bd05 Mon Sep 17 00:00:00 2001 From: Kshitij Aranke <kshitij.aranke@dbtlabs.com> Date: Mon, 8 May 2023 21:39:24 -0700 Subject: [PATCH 3/3] remove whitespace indent --- tests/functional/show/test_show.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/functional/show/test_show.py b/tests/functional/show/test_show.py index 639cc69b8df..5990106872f 100644 --- a/tests/functional/show/test_show.py +++ b/tests/functional/show/test_show.py @@ -100,7 +100,7 @@ def test_limit(self, project, args, expected): dbt_args = ["show", "--inline", models__second_ephemeral_model, *args] results, log_output = run_dbt_and_capture(dbt_args) assert len(results.results[0].agate_table) == expected - + def test_seed(self, project): (results, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"]) assert "Previewing node 'sample_seed'" in log_output