Push the `dbt show --limit` row cap down to the adapter, so that the cursor is
asked for at most `limit` rows instead of fetching everything and trimming the
agate table afterwards.

Reviewer amendments relative to the originally proposed patch:
  * SQLConnectionManager.get_result_from_cursor: `limit` now defaults to None,
    so existing one-argument callers keep working, and `limit=0` yields zero
    rows instead of silently falling back to fetchall().
  * BaseAdapter.execute: reworded the `limit` docstring entry.
  * test_limit: added a `--limit 0` case.
The `index` lines of the amended files still carry the original blob ids.

diff --git a/.changes/unreleased/Fixes-20230508-060926.yaml b/.changes/unreleased/Fixes-20230508-060926.yaml
new file mode 100644
index 00000000000..764e628d27b
--- /dev/null
+++ b/.changes/unreleased/Fixes-20230508-060926.yaml
@@ -0,0 +1,6 @@
+kind: Fixes
+body: push down limit filtering to adapter
+time: 2023-05-08T06:09:26.455524-07:00
+custom:
+  Author: aranke
+  Issue: "7390"
diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py
index a150f8296b8..7bb16c7ea4e 100644
--- a/core/dbt/adapters/base/impl.py
+++ b/core/dbt/adapters/base/impl.py
@@ -274,7 +274,7 @@ def connection_for(self, node: ResultNode) -> Iterator[None]:
 
     @available.parse(lambda *a, **k: ("", empty_table()))
     def execute(
-        self, sql: str, auto_begin: bool = False, fetch: bool = False
+        self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
     ) -> Tuple[AdapterResponse, agate.Table]:
         """Execute the given SQL. This is a thin wrapper around
         ConnectionManager.execute.
@@ -283,10 +283,11 @@ def execute(
         :param bool auto_begin: If set, and dbt is not currently inside a
             transaction, automatically begin one.
         :param bool fetch: If set, fetch results.
+        :param Optional[int] limit: If set, fetch at most this many rows; None fetches all.
         :return: A tuple of the query status and results (empty if fetch=False).
         :rtype: Tuple[AdapterResponse, agate.Table]
         """
-        return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch)
+        return self.connections.execute(sql=sql, auto_begin=auto_begin, fetch=fetch, limit=limit)
 
     @available.parse(lambda *a, **k: [])
     def get_column_schema_from_query(self, sql: str) -> List[BaseColumn]:
diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py
index 88e4a30d0b6..464c07871a0 100644
--- a/core/dbt/adapters/sql/connections.py
+++ b/core/dbt/adapters/sql/connections.py
@@ -118,13 +118,20 @@ def process_results(
         return [dict(zip(column_names, row)) for row in rows]
 
     @classmethod
-    def get_result_from_cursor(cls, cursor: Any) -> agate.Table:
+    def get_result_from_cursor(cls, cursor: Any, limit: Optional[int] = None) -> agate.Table:
         data: List[Any] = []
         column_names: List[str] = []
 
         if cursor.description is not None:
             column_names = [col[0] for col in cursor.description]
-            rows = cursor.fetchall()
+            if limit is None:
+                rows = cursor.fetchall()
+            elif limit == 0:
+                # PEP 249 leaves fetchmany(0) unspecified, so honour an
+                # explicit zero-row limit without asking the driver.
+                rows = []
+            else:
+                rows = cursor.fetchmany(limit)
             data = cls.process_results(column_names, rows)
 
         return dbt.clients.agate_helper.table_from_data_flat(data, column_names)
@@ -138,13 +145,13 @@ def data_type_code_to_name(cls, type_code: Union[int, str]) -> str:
         )
 
     def execute(
-        self, sql: str, auto_begin: bool = False, fetch: bool = False
+        self, sql: str, auto_begin: bool = False, fetch: bool = False, limit: Optional[int] = None
     ) -> Tuple[AdapterResponse, agate.Table]:
         sql = self._add_query_comment(sql)
         _, cursor = self.add_query(sql, auto_begin)
         response = self.get_response(cursor)
         if fetch:
-            table = self.get_result_from_cursor(cursor)
+            table = self.get_result_from_cursor(cursor, limit)
         else:
             table = dbt.clients.agate_helper.empty_table()
         return response, table
diff --git a/core/dbt/task/show.py b/core/dbt/task/show.py
index 0ea7160bfa7..4718e2da352 100644
--- a/core/dbt/task/show.py
+++ b/core/dbt/task/show.py
@@ -19,8 +19,12 @@ def __init__(self, config, adapter, node, node_index, num_nodes):
 
     def execute(self, compiled_node, manifest):
         start_time = time.time()
+
+        # Allow passing in -1 (or any negative number) to get all rows
+        limit = None if self.config.args.limit < 0 else self.config.args.limit
+
         adapter_response, execute_result = self.adapter.execute(
-            compiled_node.compiled_code, fetch=True
+            compiled_node.compiled_code, fetch=True, limit=limit
         )
         end_time = time.time()
 
@@ -66,13 +70,8 @@ def task_end_messages(self, results):
             )
 
         for result in matched_results:
-            # Allow passing in -1 (or any negative number) to get all rows
             table = result.agate_table
 
-            if self.args.limit >= 0:
-                table = table.limit(self.args.limit)
-                result.agate_table = table
-
             # Hack to get Agate table output as string
             output = io.StringIO()
             if self.args.output == "json":
diff --git a/tests/functional/show/test_show.py b/tests/functional/show/test_show.py
index 4cda22935d0..5990106872f 100644
--- a/tests/functional/show/test_show.py
+++ b/tests/functional/show/test_show.py
@@ -87,6 +87,21 @@ def test_second_ephemeral_model(self, project):
         )
         assert "col_hundo" in log_output
 
+    @pytest.mark.parametrize(
+        "args,expected",
+        [
+            ([], 5),  # default limit
+            (["--limit", 3], 3),  # fetch 3 rows
+            (["--limit", 0], 0),  # fetch no rows
+            (["--limit", -1], 7),  # fetch all rows
+        ],
+    )
+    def test_limit(self, project, args, expected):
+        run_dbt(["build"])
+        dbt_args = ["show", "--inline", models__second_ephemeral_model, *args]
+        results, log_output = run_dbt_and_capture(dbt_args)
+        assert len(results.results[0].agate_table) == expected
+
     def test_seed(self, project):
         (results, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"])
         assert "Previewing node 'sample_seed'" in log_output