Skip to content

Commit 28b7998

Browse files
committed
fix: fix tests
1 parent 8035fce commit 28b7998

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

tests/test_arrow_table_queries.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from chdb import session
1111
from urllib.request import urlretrieve
1212

13-
if os.path.exists(".test_chdb_arrow_table"):
14-
shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True)
15-
sess = session.Session(".test_chdb_arrow_table")
13+
# Clean up and create the session in setUpClass (once per test class) instead of globally at import time
1614

1715
class TestChDBArrowTable(unittest.TestCase):
1816
@classmethod
@@ -33,11 +31,16 @@ def setUpClass(cls):
3331

3432
print(f"Loaded Arrow table: {cls.num_rows} rows, {cls.num_columns} columns, {cls.table_size} bytes")
3533

34+
if os.path.exists(".test_chdb_arrow_table"):
35+
shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True)
36+
cls.sess = session.Session(".test_chdb_arrow_table")
37+
3638
@classmethod
3739
def tearDownClass(cls):
3840
# Clean up session directory
3941
if os.path.exists(".test_chdb_arrow_table"):
4042
shutil.rmtree(".test_chdb_arrow_table", ignore_errors=True)
43+
cls.sess.close()
4144

4245
def setUp(self):
4346
pass
@@ -54,23 +57,23 @@ def test_arrow_table_basic_info(self):
5457
def test_arrow_table_count(self):
5558
"""Test counting rows in Arrow table"""
5659
my_arrow_table = self.arrow_table
57-
result = sess.query("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)", "CSV")
60+
result = self.sess.query("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)", "CSV")
5861
lines = str(result).strip().split('\n')
5962
count = int(lines[0])
6063
self.assertEqual(count, self.num_rows, f"Count should match table rows: {self.num_rows}")
6164

6265
def test_arrow_table_schema(self):
6366
"""Test querying Arrow table schema information"""
6467
my_arrow_table = self.arrow_table
65-
result = sess.query("DESCRIBE Python(my_arrow_table)", "CSV")
68+
result = self.sess.query("DESCRIBE Python(my_arrow_table)", "CSV")
6669
# print(result)
6770
self.assertIn('WatchID', str(result))
6871
self.assertIn('URLHash', str(result))
6972

7073
def test_arrow_table_limit(self):
7174
"""Test LIMIT queries on Arrow table"""
7275
my_arrow_table = self.arrow_table
73-
result = sess.query("SELECT * FROM Python(my_arrow_table) LIMIT 5", "CSV")
76+
result = self.sess.query("SELECT * FROM Python(my_arrow_table) LIMIT 5", "CSV")
7477
lines = str(result).strip().split('\n')
7578
self.assertEqual(len(lines), 5, "Should have 5 data rows")
7679

@@ -82,7 +85,7 @@ def test_arrow_table_select_columns(self):
8285
first_col = schema.field(0).name
8386
second_col = schema.field(1).name if len(schema) > 1 else first_col
8487

85-
result = sess.query(f"SELECT {first_col}, {second_col} FROM Python(my_arrow_table) LIMIT 3", "CSV")
88+
result = self.sess.query(f"SELECT {first_col}, {second_col} FROM Python(my_arrow_table) LIMIT 3", "CSV")
8689
lines = str(result).strip().split('\n')
8790
self.assertEqual(len(lines), 3, "Should have 3 data rows")
8891

@@ -96,7 +99,7 @@ def test_arrow_table_where_clause(self):
9699
numeric_col = field.name
97100
break
98101

99-
result = sess.query(f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE {numeric_col} > 1", "CSV")
102+
result = self.sess.query(f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE {numeric_col} > 1", "CSV")
100103
lines = str(result).strip().split('\n')
101104
count = int(lines[0])
102105
self.assertEqual(count, 1000000)
@@ -111,7 +114,7 @@ def test_arrow_table_group_by(self):
111114
string_col = field.name
112115
break
113116

114-
result = sess.query(f"SELECT {string_col}, COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY {string_col} ORDER BY cnt DESC LIMIT 5", "CSV")
117+
result = self.sess.query(f"SELECT {string_col}, COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY {string_col} ORDER BY cnt DESC LIMIT 5", "CSV")
115118
lines = str(result).strip().split('\n')
116119
self.assertEqual(len(lines), 5)
117120

@@ -125,7 +128,7 @@ def test_arrow_table_aggregations(self):
125128
numeric_col = field.name
126129
break
127130

128-
result = sess.query(f"SELECT AVG({numeric_col}) as avg_val, MIN({numeric_col}) as min_val, MAX({numeric_col}) as max_val FROM Python(my_arrow_table)", "CSV")
131+
result = self.sess.query(f"SELECT AVG({numeric_col}) as avg_val, MIN({numeric_col}) as min_val, MAX({numeric_col}) as max_val FROM Python(my_arrow_table)", "CSV")
129132
lines = str(result).strip().split('\n')
130133
self.assertEqual(len(lines), 1)
131134

@@ -135,14 +138,14 @@ def test_arrow_table_order_by(self):
135138
# Use first column for ordering
136139
first_col = self.arrow_table.schema.field(0).name
137140

138-
result = sess.query(f"SELECT {first_col} FROM Python(my_arrow_table) ORDER BY {first_col} LIMIT 10", "CSV")
141+
result = self.sess.query(f"SELECT {first_col} FROM Python(my_arrow_table) ORDER BY {first_col} LIMIT 10", "CSV")
139142
lines = str(result).strip().split('\n')
140143
self.assertEqual(len(lines), 10)
141144

142145
def test_arrow_table_subquery(self):
143146
"""Test subqueries with Arrow table"""
144147
my_arrow_table = self.arrow_table
145-
result = sess.query("""
148+
result = self.sess.query("""
146149
SELECT COUNT(*) as total_count
147150
FROM (
148151
SELECT * FROM Python(my_arrow_table)
@@ -161,7 +164,7 @@ def test_arrow_table_multiple_tables(self):
161164
# Create a smaller subset table
162165
subset_table = my_arrow_table.slice(0, min(100, my_arrow_table.num_rows))
163166

164-
result = sess.query("""
167+
result = self.sess.query("""
165168
SELECT
166169
(SELECT COUNT(*) FROM Python(my_arrow_table)) as full_count,
167170
(SELECT COUNT(*) FROM Python(subset_table)) as subset_count

0 commit comments

Comments
 (0)