Skip to content

Commit

Permalink
Fix handling of None database (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
genzgd authored Jan 17, 2023
1 parent a8020d8 commit 097415b
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# ClickHouse Connect ChangeLog

## 0.5.2, 2023-01-17

### Bug fix
* Fix issue where client database is set to None (this normally only happens when deleting the initial database)

## 0.5.1, 2023-01-16

### Bug fix
Expand Down
2 changes: 1 addition & 1 deletion clickhouse_connect/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.5.1
0.5.2
9 changes: 6 additions & 3 deletions clickhouse_connect/driver/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ def query_arrow(self,
:return: PyArrow.Table
"""
settings = dict_copy(settings)
settings['database'] = self.database
if self.database:
settings['database'] = self.database
if arrow_str_setting in self.server_settings and arrow_str_setting not in settings:
settings[arrow_str_setting] = '1' if use_strings else '0'
return to_arrow(self.raw_query(query, parameters, settings, 'Arrow'))
Expand Down Expand Up @@ -411,7 +412,8 @@ def insert_arrow(self, table: str, arrow_table, database: str = None, settings:
:param settings: Optional dictionary of ClickHouse settings (key/string values)
:return: No return, throws an exception if the insert fails
"""
full_table = table if '.' in table else f'{database or self.database}.{table}'
database = database or self.database
full_table = table if '.' in table or not database else f'{database}.{table}'
column_names, insert_block = arrow_buffer(arrow_table)
self.raw_insert(full_table, column_names, insert_block, settings, 'Arrow')

Expand Down Expand Up @@ -439,7 +441,8 @@ def create_insert_context(self,
:param data: Initial dataset for insert
:return Reusable insert context
"""
full_table = table if '.' in table else f'{database or self.database}.{table}'
database = database or self.database
full_table = table if '.' in table or not database else f'{database}.{table}'
column_defs = []
if column_types is None:
describe_result = self.query(f'DESCRIBE TABLE {full_table}')
Expand Down
10 changes: 7 additions & 3 deletions clickhouse_connect/driver/httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def _prep_query(self, context: QueryContext):

def _query_with_context(self, context: QueryContext) -> QueryResult:
headers = {'Content-Type': 'text/plain; charset=utf-8'}
params = {'database': self.database}
params = {}
if self.database:
params['database'] = self.database
params.update(context.bind_params)
params.update(self._validate_settings(context.settings))
if columns_only_re.search(context.uncommented_query):
Expand Down Expand Up @@ -248,7 +250,9 @@ def raw_insert(self, table: str,
if compression:
headers['Content-Encoding'] = compression
cols = f" ({', '.join([quote_identifier(x) for x in column_names])})" if column_names is not None else ''
params = {'query': f'INSERT INTO {table}{cols} FORMAT {write_format}', 'database': self.database}
params = {'query': f'INSERT INTO {table}{cols} FORMAT {write_format}'}
if self.database:
params['database'] = self.database
params.update(self._validate_settings(settings or {}))
response = self._raw_request(insert_block, params, headers, error_handler=status_handler)
logger.debug('Insert response code: %d, content: %s', response.status, response.data)
Expand Down Expand Up @@ -277,7 +281,7 @@ def command(self,
payload = cmd
elif cmd:
params['query'] = cmd
if use_database:
if use_database and self.database:
params['database'] = self.database
params.update(self._validate_settings(settings or {}))
method = 'POST' if payload else 'GET'
Expand Down
32 changes: 28 additions & 4 deletions tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,35 @@ def test_command(test_client: Client):
assert version.startswith('2')


def test_none_database(test_client: Client):
old_db = test_client.database
test_db = test_client.command('select database()')
assert test_db == old_db
try:
test_client.database = None
query_result = test_client.query('SELECT * FROM system.tables')
with query_result:
pass
test_db = test_client.command('select database()')
assert test_db == 'default'
test_client.database = old_db
test_db = test_client.command('select database()')
assert test_db == old_db
finally:
test_client.database = old_db


def test_insert(test_client: Client, test_table_engine: str):
test_client.command('DROP TABLE IF EXISTS test_system_insert')
test_client.command(f'CREATE TABLE test_system_insert AS system.tables Engine {test_table_engine} ORDER BY name')
tables_result = test_client.query('SELECT * from system.tables')
test_client.insert(table='test_system_insert', column_names='*', data=tables_result.result_set)
old_db = test_client.database
test_client.database = None
try:
test_client.command('DROP TABLE IF EXISTS default.test_system_insert')
test_client.command(f'CREATE TABLE default.test_system_insert AS system.tables Engine {test_table_engine} ORDER BY name')
tables_result = test_client.query('SELECT * from system.tables')
test_client.insert(table='test_system_insert', column_names='*', data=tables_result.result_set)
test_client.command('DROP TABLE IF EXISTS default.test_system_insert')
finally:
test_client.database = old_db


def test_raw_insert(test_client: Client, test_table_engine: str):
Expand Down

0 comments on commit 097415b

Please sign in to comment.