10
10
from chdb import session
11
11
from urllib .request import urlretrieve
12
12
13
- if os .path .exists (".test_chdb_arrow_table" ):
14
- shutil .rmtree (".test_chdb_arrow_table" , ignore_errors = True )
15
- sess = session .Session (".test_chdb_arrow_table" )
13
+ # Clean up and create session in the test methods instead of globally
16
14
17
15
class TestChDBArrowTable (unittest .TestCase ):
18
16
@classmethod
@@ -33,11 +31,16 @@ def setUpClass(cls):
33
31
34
32
print (f"Loaded Arrow table: { cls .num_rows } rows, { cls .num_columns } columns, { cls .table_size } bytes" )
35
33
34
+ if os .path .exists (".test_chdb_arrow_table" ):
35
+ shutil .rmtree (".test_chdb_arrow_table" , ignore_errors = True )
36
+ cls .sess = session .Session (".test_chdb_arrow_table" )
37
+
36
38
@classmethod
37
39
def tearDownClass (cls ):
38
40
# Clean up session directory
39
41
if os .path .exists (".test_chdb_arrow_table" ):
40
42
shutil .rmtree (".test_chdb_arrow_table" , ignore_errors = True )
43
+ cls .sess .close ()
41
44
42
45
def setUp (self ):
43
46
pass
@@ -54,23 +57,23 @@ def test_arrow_table_basic_info(self):
54
57
def test_arrow_table_count (self ):
55
58
"""Test counting rows in Arrow table"""
56
59
my_arrow_table = self .arrow_table
57
- result = sess .query ("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)" , "CSV" )
60
+ result = self . sess .query ("SELECT COUNT(*) as row_count FROM Python(my_arrow_table)" , "CSV" )
58
61
lines = str (result ).strip ().split ('\n ' )
59
62
count = int (lines [0 ])
60
63
self .assertEqual (count , self .num_rows , f"Count should match table rows: { self .num_rows } " )
61
64
62
65
def test_arrow_table_schema (self ):
63
66
"""Test querying Arrow table schema information"""
64
67
my_arrow_table = self .arrow_table
65
- result = sess .query ("DESCRIBE Python(my_arrow_table)" , "CSV" )
68
+ result = self . sess .query ("DESCRIBE Python(my_arrow_table)" , "CSV" )
66
69
# print(result)
67
70
self .assertIn ('WatchID' , str (result ))
68
71
self .assertIn ('URLHash' , str (result ))
69
72
70
73
def test_arrow_table_limit (self ):
71
74
"""Test LIMIT queries on Arrow table"""
72
75
my_arrow_table = self .arrow_table
73
- result = sess .query ("SELECT * FROM Python(my_arrow_table) LIMIT 5" , "CSV" )
76
+ result = self . sess .query ("SELECT * FROM Python(my_arrow_table) LIMIT 5" , "CSV" )
74
77
lines = str (result ).strip ().split ('\n ' )
75
78
self .assertEqual (len (lines ), 5 , "Should have 5 data rows" )
76
79
@@ -82,7 +85,7 @@ def test_arrow_table_select_columns(self):
82
85
first_col = schema .field (0 ).name
83
86
second_col = schema .field (1 ).name if len (schema ) > 1 else first_col
84
87
85
- result = sess .query (f"SELECT { first_col } , { second_col } FROM Python(my_arrow_table) LIMIT 3" , "CSV" )
88
+ result = self . sess .query (f"SELECT { first_col } , { second_col } FROM Python(my_arrow_table) LIMIT 3" , "CSV" )
86
89
lines = str (result ).strip ().split ('\n ' )
87
90
self .assertEqual (len (lines ), 3 , "Should have 3 data rows" )
88
91
@@ -96,7 +99,7 @@ def test_arrow_table_where_clause(self):
96
99
numeric_col = field .name
97
100
break
98
101
99
- result = sess .query (f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE { numeric_col } > 1" , "CSV" )
102
+ result = self . sess .query (f"SELECT COUNT(*) FROM Python(my_arrow_table) WHERE { numeric_col } > 1" , "CSV" )
100
103
lines = str (result ).strip ().split ('\n ' )
101
104
count = int (lines [0 ])
102
105
self .assertEqual (count , 1000000 )
@@ -111,7 +114,7 @@ def test_arrow_table_group_by(self):
111
114
string_col = field .name
112
115
break
113
116
114
- result = sess .query (f"SELECT { string_col } , COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY { string_col } ORDER BY cnt DESC LIMIT 5" , "CSV" )
117
+ result = self . sess .query (f"SELECT { string_col } , COUNT(*) as cnt FROM Python(my_arrow_table) GROUP BY { string_col } ORDER BY cnt DESC LIMIT 5" , "CSV" )
115
118
lines = str (result ).strip ().split ('\n ' )
116
119
self .assertEqual (len (lines ), 5 )
117
120
@@ -125,7 +128,7 @@ def test_arrow_table_aggregations(self):
125
128
numeric_col = field .name
126
129
break
127
130
128
- result = sess .query (f"SELECT AVG({ numeric_col } ) as avg_val, MIN({ numeric_col } ) as min_val, MAX({ numeric_col } ) as max_val FROM Python(my_arrow_table)" , "CSV" )
131
+ result = self . sess .query (f"SELECT AVG({ numeric_col } ) as avg_val, MIN({ numeric_col } ) as min_val, MAX({ numeric_col } ) as max_val FROM Python(my_arrow_table)" , "CSV" )
129
132
lines = str (result ).strip ().split ('\n ' )
130
133
self .assertEqual (len (lines ), 1 )
131
134
@@ -135,14 +138,14 @@ def test_arrow_table_order_by(self):
135
138
# Use first column for ordering
136
139
first_col = self .arrow_table .schema .field (0 ).name
137
140
138
- result = sess .query (f"SELECT { first_col } FROM Python(my_arrow_table) ORDER BY { first_col } LIMIT 10" , "CSV" )
141
+ result = self . sess .query (f"SELECT { first_col } FROM Python(my_arrow_table) ORDER BY { first_col } LIMIT 10" , "CSV" )
139
142
lines = str (result ).strip ().split ('\n ' )
140
143
self .assertEqual (len (lines ), 10 )
141
144
142
145
def test_arrow_table_subquery (self ):
143
146
"""Test subqueries with Arrow table"""
144
147
my_arrow_table = self .arrow_table
145
- result = sess .query ("""
148
+ result = self . sess .query ("""
146
149
SELECT COUNT(*) as total_count
147
150
FROM (
148
151
SELECT * FROM Python(my_arrow_table)
@@ -161,7 +164,7 @@ def test_arrow_table_multiple_tables(self):
161
164
# Create a smaller subset table
162
165
subset_table = my_arrow_table .slice (0 , min (100 , my_arrow_table .num_rows ))
163
166
164
- result = sess .query ("""
167
+ result = self . sess .query ("""
165
168
SELECT
166
169
(SELECT COUNT(*) FROM Python(my_arrow_table)) as full_count,
167
170
(SELECT COUNT(*) FROM Python(subset_table)) as subset_count
0 commit comments