diff --git a/CHANGELOG.md b/CHANGELOG.md index 8468444f7080..6c1cf336bcbd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `NaN` handling in `SQLDatabase` ([#8479](https://github.com/pyg-team/pytorch_geometric/pull/8479)) - Fixed `CaptumExplainer` in case no `index` is passed ([#8440](https://github.com/pyg-team/pytorch_geometric/pull/8440)) - Fixed `edge_index` construction in the `UPFD` dataset ([#8413](https://github.com/pyg-team/pytorch_geometric/pull/8413)) - Fixed TorchScript support in `AttentionalAggregation` and `DeepSetsAggregation` ([#8406](https://github.com/pyg-team/pytorch_geometric/pull/8406)) diff --git a/test/data/test_database.py b/test/data/test_database.py index 5d7d055fb19a..e2b8131ff716 100644 --- a/test/data/test_database.py +++ b/test/data/test_database.py @@ -1,3 +1,4 @@ +import math import os.path as osp import pytest @@ -17,7 +18,7 @@ @pytest.mark.parametrize('Database', AVAILABLE_DATABASES) @pytest.mark.parametrize('batch_size', [None, 1]) -def test_databases_single_tensor(tmp_path, Database, batch_size): +def test_database_single_tensor(tmp_path, Database, batch_size): kwargs = dict(path=osp.join(tmp_path, 'storage.db')) if Database == SQLiteDatabase: kwargs['name'] = 'test_table' @@ -56,7 +57,7 @@ def test_databases_single_tensor(tmp_path, Database, batch_size): @pytest.mark.parametrize('Database', AVAILABLE_DATABASES) -def test_databases_schema(tmp_path, Database): +def test_database_schema(tmp_path, Database): kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} path = osp.join(tmp_path, 'tuple_storage.db') @@ -70,9 +71,9 @@ def test_databases_schema(tmp_path, Database): 4: object, } - data1 = (1, 0.1, 'a', torch.randn(2, 8), Data(x=torch.randn(1, 8))) - data2 = (2, 0.2, 'b', torch.randn(2, 16), Data(x=torch.randn(2, 8))) - data3 = (3, 0.3, 'c', torch.randn(2, 32), Data(x=torch.randn(3, 8))) + data1 = (1, 0.1, 'a', torch.randn(2, 8), Data(x=torch.randn(8))) + data2 = (2, float('inf'), 'b', torch.randn(2, 16), Data(x=torch.randn(8))) + data3 = (3, float('NaN'), 'c', torch.randn(2, 32), Data(x=torch.randn(8))) db.insert(0, data1) db.multi_insert([1, 2], [data2, data3]) @@ -81,7 +82,10 @@ def test_databases_schema(tmp_path, Database): for out, data in zip([out1, out2, out3], [data1, data2, data3]): assert out[0] == data[0] - assert out[1] == data[1] + if math.isnan(data[1]): + assert math.isnan(out[1]) + else: + assert out[1] == data[1] assert out[2] == data[2] assert torch.equal(out[3], data[3]) assert isinstance(out[4], Data) and len(out[4]) == 1 diff --git a/torch_geometric/data/database.py b/torch_geometric/data/database.py index 2d4912171e1e..ec5850945ec4 100644 --- a/torch_geometric/data/database.py +++ b/torch_geometric/data/database.py @@ -284,8 +284,8 @@ def __init__(self, path: str, name: str, schema: Schema = object): # Create the table (if it does not exist) by mapping the Python schema # to the corresponding SQL schema: sql_schema = ',\n'.join([ - f' {col_name} {self._to_sql_type(type_info)} NOT NULL' for - col_name, type_info in zip(self._col_names, self.schema.values()) + f' {col_name} {self._to_sql_type(type_info)}' for col_name, + type_info in zip(self._col_names, self.schema.values()) ]) query = (f'CREATE TABLE IF NOT EXISTS {self.name} (\n' f' id INTEGER PRIMARY KEY,\n' @@ -409,13 +409,13 @@ def _dummies(self) -> str: def _to_sql_type(self, type_info: Any) -> str: if type_info == int: - return 'INTEGER' + return 'INTEGER NOT NULL' if type_info == float: return 'FLOAT' if type_info == str: - return 'TEXT' + return 'TEXT NOT NULL' else: - return 'BLOB' + return 'BLOB NOT NULL' def _serialize(self, row: Any) -> List[Any]: # Serializes the given input data according to `schema`: @@ -455,7 +455,9 @@ def _deserialize(self, row: Tuple[Any]) -> Any: else: tensor = torch.empty(0, dtype=col_schema.dtype) out_dict[key] = tensor.view(*col_schema.size) - elif col_schema in {int, float, str}: + elif col_schema == float: + out_dict[key] = value if value is not None else float('NaN') + elif col_schema in {int, str}: out_dict[key] = value else: out_dict[key] = pickle.loads(value)