From 94998f3734c5719af372d8606afd6095d49dad8c Mon Sep 17 00:00:00 2001 From: Chuck Cadman Date: Thu, 2 Feb 2023 10:47:09 -0800 Subject: [PATCH] CLN: Put exit_stack inside _query_iterator. --- pandas/io/sql.py | 132 ++++++++++++++++++++++------------------------- 1 file changed, 62 insertions(+), 70 deletions(-) diff --git a/pandas/io/sql.py b/pandas/io/sql.py index b4624a1f4a447..d88decc8601f0 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -75,14 +75,6 @@ # -- Helper functions -def _cleanup_after_generator(generator, exit_stack: ExitStack): - """Does the cleanup after iterating through the generator.""" - try: - yield from generator - finally: - exit_stack.close() - - def _convert_params(sql, params): """Convert SQL and params args to DBAPI2.0 compliant format.""" args = [sql] @@ -1093,6 +1085,7 @@ def insert( def _query_iterator( self, result, + exit_stack: ExitStack, chunksize: str | None, columns, coerce_float: bool = True, @@ -1101,28 +1094,29 @@ def _query_iterator( ): """Return generator through chunked result set.""" has_read_data = False - while True: - data = result.fetchmany(chunksize) - if not data: - if not has_read_data: - yield DataFrame.from_records( - [], columns=columns, coerce_float=coerce_float - ) - break + with exit_stack: + while True: + data = result.fetchmany(chunksize) + if not data: + if not has_read_data: + yield DataFrame.from_records( + [], columns=columns, coerce_float=coerce_float + ) + break - has_read_data = True - self.frame = _convert_arrays_to_dataframe( - data, columns, coerce_float, use_nullable_dtypes - ) + has_read_data = True + self.frame = _convert_arrays_to_dataframe( + data, columns, coerce_float, use_nullable_dtypes + ) - self._harmonize_columns( - parse_dates=parse_dates, use_nullable_dtypes=use_nullable_dtypes - ) + self._harmonize_columns( + parse_dates=parse_dates, use_nullable_dtypes=use_nullable_dtypes + ) - if self.index is not None: - self.frame.set_index(self.index, inplace=True) + if self.index is not None: + self.frame.set_index(self.index, inplace=True) - yield self.frame + yield self.frame def read( self, @@ -1147,16 +1141,14 @@ def read( column_names = result.keys() if chunksize is not None: - return _cleanup_after_generator( - self._query_iterator( - result, - chunksize, - column_names, - coerce_float=coerce_float, - parse_dates=parse_dates, - use_nullable_dtypes=use_nullable_dtypes, - ), + return self._query_iterator( + result, exit_stack, + chunksize, + column_names, + coerce_float=coerce_float, + parse_dates=parse_dates, + use_nullable_dtypes=use_nullable_dtypes, ) else: data = result.fetchall() @@ -1693,6 +1685,7 @@ def read_table( @staticmethod def _query_iterator( result, + exit_stack: ExitStack, chunksize: int, columns, index_col=None, @@ -1703,31 +1696,32 @@ def _query_iterator( ): """Return generator through chunked result set""" has_read_data = False - while True: - data = result.fetchmany(chunksize) - if not data: - if not has_read_data: - yield _wrap_result( - [], - columns, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - dtype=dtype, - use_nullable_dtypes=use_nullable_dtypes, - ) - break + with exit_stack: + while True: + data = result.fetchmany(chunksize) + if not data: + if not has_read_data: + yield _wrap_result( + [], + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + use_nullable_dtypes=use_nullable_dtypes, + ) + break - has_read_data = True - yield _wrap_result( - data, - columns, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - dtype=dtype, - use_nullable_dtypes=use_nullable_dtypes, - ) + has_read_data = True + yield _wrap_result( + data, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + use_nullable_dtypes=use_nullable_dtypes, + ) def read_query( self, @@ -1793,18 +1787,16 @@ def read_query( if chunksize is not None: self.returns_generator = True - return _cleanup_after_generator( - self._query_iterator( - result, - chunksize, - columns, - index_col=index_col, - coerce_float=coerce_float, - parse_dates=parse_dates, - dtype=dtype, - use_nullable_dtypes=use_nullable_dtypes, - ), + return self._query_iterator( + result, self.exit_stack, + chunksize, + columns, + index_col=index_col, + coerce_float=coerce_float, + parse_dates=parse_dates, + dtype=dtype, + use_nullable_dtypes=use_nullable_dtypes, ) else: data = result.fetchall()