From 6ce0bb1f5f72c8553f48fc308026d96014cdb1e1 Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Sat, 10 Feb 2024 09:41:25 +0900 Subject: [PATCH 1/2] feat: use header option. --- .../turu/snowflake/record/record_cursor.py | 10 +++++-- turu-snowflake/tests/turu/test_snowflake.py | 29 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/turu-snowflake/src/turu/snowflake/record/record_cursor.py b/turu-snowflake/src/turu/snowflake/record/record_cursor.py index eee8644..2602011 100644 --- a/turu-snowflake/src/turu/snowflake/record/record_cursor.py +++ b/turu-snowflake/src/turu/snowflake/record/record_cursor.py @@ -8,9 +8,15 @@ def fetch_pandas_all(self, **kwargs) -> "PandasDataFrame": if isinstance(self._recorder, turu.core.record.CsvRecorder): if limit := self._recorder._options.get("limit"): - df.head(limit).to_csv(self._recorder.file, index=False) + record_df = df.head(limit) else: - df.to_csv(self._recorder.file, index=False) + record_df = df + + record_df.to_csv( + self._recorder.file, + index=False, + header=self._recorder._options.get("header", True), + ) return df diff --git a/turu-snowflake/tests/turu/test_snowflake.py b/turu-snowflake/tests/turu/test_snowflake.py index 20bf14a..5fd4bc5 100644 --- a/turu-snowflake/tests/turu/test_snowflake.py +++ b/turu-snowflake/tests/turu/test_snowflake.py @@ -266,6 +266,35 @@ def test_record_pandas_dataframe(self, connection: turu.snowflake.Connection): ).lstrip() ) + @pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed") + def test_record_pandas_dataframe_without_header_option( + self, connection: turu.snowflake.Connection + ): + import pandas as pd # type: ignore[import] + from pandas.testing import assert_frame_equal # type: ignore[import] + + with tempfile.NamedTemporaryFile() as file: + with record_to_csv( + file.name, + connection.execute_map( + pd.DataFrame, "select 1 as ID union all select 2 AS ID" + ), + header=False, + ) as cursor: + expected = pd.DataFrame({"ID": [1, 2]}, dtype="int8") + + assert_frame_equal(cursor.fetch_pandas_all(), expected) + + assert ( + Path(file.name).read_text() + == dedent( + """ + 1 + 2 + """ + ).lstrip() + ) + @pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed") def test_record_pandas_dataframe_with_limit_option( self, connection: turu.snowflake.Connection From 6e27941f45a39c50c20a7d684c961e31cc56b02b Mon Sep 17 00:00:00 2001 From: yassun7010 Date: Sat, 10 Feb 2024 09:44:16 +0900 Subject: [PATCH 2/2] fix: async version. --- .../snowflake/record/async_record_cursor.py | 10 ++- .../tests/turu/test_snowflake_async.py | 70 +++++++++++++++++-- 2 files changed, 72 insertions(+), 8 deletions(-) diff --git a/turu-snowflake/src/turu/snowflake/record/async_record_cursor.py b/turu-snowflake/src/turu/snowflake/record/async_record_cursor.py index 8260b3f..641751b 100644 --- a/turu-snowflake/src/turu/snowflake/record/async_record_cursor.py +++ b/turu-snowflake/src/turu/snowflake/record/async_record_cursor.py @@ -8,9 +8,15 @@ async def fetch_pandas_all(self, **kwargs) -> "PandasDataFrame": if isinstance(self._recorder, turu.core.record.CsvRecorder): if limit := self._recorder._options.get("limit"): - df.head(limit).to_csv(self._recorder.file, index=False) + record_df = df.head(limit) else: - df.to_csv(self._recorder.file, index=False) + record_df = df + + record_df.to_csv( + self._recorder.file, + index=False, + header=self._recorder._options.get("header", True), + ) return df diff --git a/turu-snowflake/tests/turu/test_snowflake_async.py b/turu-snowflake/tests/turu/test_snowflake_async.py index 6c1592e..692e725 100644 --- a/turu-snowflake/tests/turu/test_snowflake_async.py +++ b/turu-snowflake/tests/turu/test_snowflake_async.py @@ -379,14 +379,72 @@ async def test_record_pandas_dataframe( "select 1 as ID union all select 2 AS ID", ), ) as cursor: - expected = pd.DataFrame( - {"ID": [1, 2]}, - dtype="int8", - ) + expected = pd.DataFrame({"ID": [1, 2]}, dtype="int8") + + assert_frame_equal(await cursor.fetch_pandas_all(), expected) + + assert ( + Path(file.name).read_text() + == dedent( + """ + ID + 1 + 2 + """ + ).lstrip() + ) + + @pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed") + @pytest.mark.asyncio + async def test_record_pandas_dataframe_without_header_option( + self, async_connection: turu.snowflake.AsyncConnection + ): + import pandas as pd # type: ignore[import] + from pandas.testing import assert_frame_equal # type: ignore[import] + + with tempfile.NamedTemporaryFile() as file: + async with record_to_csv( + file.name, + await async_connection.execute_map( + pd.DataFrame, + "select 1 as ID union all select 2 AS ID", + ), + header=False, + ) as cursor: + expected = pd.DataFrame({"ID": [1, 2]}, dtype="int8") + + assert_frame_equal(await cursor.fetch_pandas_all(), expected) + + assert ( + Path(file.name).read_text() + == dedent( + """ + 1 + 2 + """ + ).lstrip() + ) + + @pytest.mark.skipif(not USE_PANDAS, reason="pandas is not installed") + @pytest.mark.asyncio + async def test_record_pandas_dataframe_with_limit_option( + self, async_connection: turu.snowflake.AsyncConnection + ): + import pandas as pd # type: ignore[import] + from pandas.testing import assert_frame_equal # type: ignore[import] + + with tempfile.NamedTemporaryFile() as file: + async with record_to_csv( + file.name, + await async_connection.execute_map( + pd.DataFrame, + "select value::integer as ID from table(flatten(ARRAY_GENERATE_RANGE(1, 10)))", + ), + limit=2, + ) as cursor: + expected = pd.DataFrame({"ID": list(range(1, 10))}, dtype="object") assert_frame_equal(await cursor.fetch_pandas_all(), expected) - for row in expected.values: - print(row) assert ( Path(file.name).read_text()