From 5f157cd263d7af959ba975ac26dea00f176ac74d Mon Sep 17 00:00:00 2001
From: Matthias Fey
Date: Tue, 19 Sep 2023 08:46:07 +0200
Subject: [PATCH] Additional `Database` tests for all kinds of input data (#8057)

---
 CHANGELOG.md                          |   2 +-
 test/data/test_database.py            | 151 ++++++++++++++++++++------
 torch_geometric/data/database.py      |  42 ++++---
 torch_geometric/testing/__init__.py   |   2 +
 torch_geometric/testing/decorators.py |  42 +++----
 5 files changed, 172 insertions(+), 67 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0017c0643ed7..98e0776c0714 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054))
+- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054), [#8057](https://github.com/pyg-team/pytorch_geometric/pull/8057))
 - Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
 - Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
 - Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
diff --git a/test/data/test_database.py b/test/data/test_database.py
index 63a603aca0b5..7d3b99a257ef 100644
--- a/test/data/test_database.py
+++ b/test/data/test_database.py
@@ -3,29 +3,53 @@
 import pytest
 import torch

-from torch_geometric.data.database import RocksDatabase, SQLiteDatabase
+from torch_geometric.data import Data
+from torch_geometric.data.database import (
+    RocksDatabase,
+    SQLiteDatabase,
+    TensorInfo,
+)
 from torch_geometric.profile import benchmark
-from torch_geometric.testing import withPackage
+from torch_geometric.testing import has_package, withPackage

+AVAILABLE_DATABASES = []
+if has_package('sqlite3'):
+    AVAILABLE_DATABASES.append(SQLiteDatabase)
+if has_package('rocksdict'):
+    AVAILABLE_DATABASES.append(RocksDatabase)

-@withPackage('sqlite3')
+
+@pytest.mark.parametrize('Database', AVAILABLE_DATABASES)
 @pytest.mark.parametrize('batch_size', [None, 1])
-def test_sqlite_database(tmp_path, batch_size):
-    path = osp.join(tmp_path, 'sqlite.db')
-    db = SQLiteDatabase(path, name='test_table')
-    assert str(db) == 'SQLiteDatabase(0)'
-    assert len(db) == 0
+def test_databases_single_tensor(tmp_path, Database, batch_size):
+    kwargs = dict(path=osp.join(tmp_path, 'storage.db'))
+    if Database == SQLiteDatabase:
+        kwargs['name'] = 'test_table'
+
+    db = 
Database(**kwargs) + assert db.schema == {0: object} + + try: + assert len(db) == 0 + assert str(db) == f'{Database.__name__}(0)' + except NotImplementedError: + assert str(db) == f'{Database.__name__}()' data = torch.randn(5) db.insert(0, data) - assert len(db) == 1 + try: + assert len(db) == 1 + except NotImplementedError: + pass assert torch.equal(db.get(0), data) indices = torch.tensor([1, 2]) data_list = torch.randn(2, 5) db.multi_insert(indices, data_list, batch_size=batch_size) - assert len(db) == 3 - + try: + assert len(db) == 3 + except NotImplementedError: + pass out_list = db.multi_get(indices, batch_size=batch_size) assert isinstance(out_list, list) assert len(out_list) == 2 @@ -35,35 +59,98 @@ def test_sqlite_database(tmp_path, batch_size): db.close() -@withPackage('rocksdict') -@pytest.mark.parametrize('batch_size', [None, 1]) -def test_rocks_database(tmp_path, batch_size): - path = osp.join(tmp_path, 'rocks.db') - db = RocksDatabase(path) - assert str(db) == 'RocksDatabase()' - with pytest.raises(NotImplementedError): - len(db) - - data = torch.randn(5) - db.insert(0, data) - assert torch.equal(db.get(0), data) +@pytest.mark.parametrize('Database', AVAILABLE_DATABASES) +def test_databases_schema(tmp_path, Database): + kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} + + path = osp.join(tmp_path, 'tuple_storage.db') + schema = (int, float, str, dict(dtype=torch.float, size=(2, -1)), object) + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 0: int, + 1: float, + 2: str, + 3: TensorInfo(dtype=torch.float, size=(2, -1)), + 4: object, + } + + data1 = (1, 0.1, 'a', torch.randn(2, 8), Data(x=torch.randn(1, 8))) + data2 = (2, 0.2, 'b', torch.randn(2, 16), Data(x=torch.randn(2, 8))) + data3 = (3, 0.3, 'c', torch.randn(2, 32), Data(x=torch.randn(3, 8))) + db.insert(0, data1) + db.multi_insert([1, 2], [data2, data3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, data in zip([out1, out2, out3], [data1, data2, data3]): + assert out[0] == data[0] + assert out[1] == data[1] + assert out[2] == data[2] + assert torch.equal(out[3], data[3]) + assert isinstance(out[4], Data) and len(out[4]) == 1 + assert torch.equal(out[4].x, data[4].x) - indices = torch.tensor([1, 2]) - data_list = torch.randn(2, 5) - db.multi_insert(indices, data_list, batch_size=batch_size) + db.close() - out_list = db.multi_get(indices, batch_size=batch_size) - assert isinstance(out_list, list) - assert len(out_list) == 2 - assert torch.equal(out_list[0], data_list[0]) - assert torch.equal(out_list[1], data_list[1]) + path = osp.join(tmp_path, 'dict_storage.db') + schema = { + 'int': int, + 'float': float, + 'str': str, + 'tensor': dict(dtype=torch.float, size=(2, -1)), + 'data': object + } + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 'int': int, + 'float': float, + 'str': str, + 'tensor': TensorInfo(dtype=torch.float, size=(2, -1)), + 'data': object, + } + + data1 = { + 'int': 1, + 'float': 0.1, + 'str': 'a', + 'tensor': torch.randn(2, 8), + 'data': Data(x=torch.randn(1, 8)), + } + data2 = { + 'int': 2, + 'float': 0.2, + 'str': 'b', + 'tensor': torch.randn(2, 16), + 'data': Data(x=torch.randn(2, 8)), + } + data3 = { + 'int': 3, + 'float': 0.3, + 'str': 'c', + 'tensor': torch.randn(2, 32), + 'data': Data(x=torch.randn(3, 8)), + } + db.insert(0, data1) + db.multi_insert([1, 2], [data2, data3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, data in zip([out1, out2, out3], [data1, data2, data3]): + assert 
out['int'] == data['int'] + assert out['float'] == data['float'] + assert out['str'] == data['str'] + assert torch.equal(out['tensor'], data['tensor']) + assert isinstance(out['data'], Data) and len(out['data']) == 1 + assert torch.equal(out['data'].x, data['data'].x) db.close() @withPackage('sqlite3') def test_database_syntactic_sugar(tmp_path): - path = osp.join(tmp_path, 'sqlite.db') + path = osp.join(tmp_path, 'storage.db') db = SQLiteDatabase(path, name='test_table') data = torch.randn(5, 16) diff --git a/torch_geometric/data/database.py b/torch_geometric/data/database.py index 992354672e89..c180eac124d6 100644 --- a/torch_geometric/data/database.py +++ b/torch_geometric/data/database.py @@ -183,6 +183,8 @@ def __init__(self, path: str, name: str, schema: Schema = object): self.connect() + # Create the table (if it does not exist) by mapping the Python schema + # to the corresponding SQL schema: sql_schema = ',\n'.join([ f' {col_name} {self._to_sql_type(type_info)} NOT NULL' for col_name, type_info in zip(self._col_names, self.schema.values()) @@ -215,7 +217,7 @@ def insert(self, index: int, data: Any): query = (f'INSERT INTO {self.name} ' f'(id, {self._joined_col_names}) ' f'VALUES (?, {self._dummies})') - self.cursor.execute(query, (index, self._serialize(data))) + self.cursor.execute(query, (index, *self._serialize(data))) def _multi_insert( self, @@ -225,12 +227,13 @@ def _multi_insert( if isinstance(indices, Tensor): indices = indices.tolist() - data_list = [self._serialize(data) for data in data_list] + data_list = [(index, *self._serialize(data)) + for index, data in zip(indices, data_list)] query = (f'INSERT INTO {self.name} ' f'(id, {self._joined_col_names}) ' f'VALUES (?, {self._dummies})') - self.cursor.executemany(query, zip(indices, data_list)) + self.cursor.executemany(query, data_list) def get(self, index: int) -> Any: query = (f'SELECT {self._joined_col_names} FROM {self.name} ' @@ -249,7 +252,7 @@ def multi_get( elif isinstance(indices, Tensor): indices = indices.tolist() - # We first create a temporary ID table to then perform an INNER JOIN. + # We create a temporary ID table to then perform an INNER JOIN. # This avoids having a long IN clause and guarantees sorted outputs: join_table_name = f'{self.name}__join__{uuid4().hex}' query = (f'CREATE TABLE {join_table_name} (\n' @@ -314,24 +317,33 @@ def _to_sql_type(self, type_info: Any) -> str: else: return 'BLOB' - def _serialize(self, row: Any) -> Union[Any, List[Any]]: - out_list: List[Any] = [] + def _serialize(self, row: Any) -> List[Any]: + # Serializes the given input data according to `schema`: + # * {int, float, str}: Use as they are. + # * torch.Tensor: Convert into the raw byte string + # * object: Dump via pickle + # If we find a `torch.Tensor` that is not registered as such in + # `schema`, we modify the schema in-place for improved efficiency. 
+        out: List[Any] = []
         for key, col in self._to_dict(row).items():
             if isinstance(self.schema[key], TensorInfo):
-                out = row.numpy().tobytes()
+                out.append(col.numpy().tobytes())
             elif isinstance(col, Tensor):
                 self.schema[key] = TensorInfo(dtype=col.dtype)
-                out = row.numpy().tobytes()
+                out.append(col.numpy().tobytes())
             elif self.schema[key] in {int, float, str}:
-                out = col
+                out.append(col)
             else:
-                out = pickle.dumps(col)
+                out.append(pickle.dumps(col))

-            out_list.append(out)
-
-        return out_list if len(out_list) > 1 else out_list[0]
+        return out

     def _deserialize(self, row: Tuple[Any]) -> Any:
+        # Deserializes the DB data according to `schema`:
+        # * {int, float, str}: Use as they are.
+        # * torch.Tensor: Load raw byte string with `dtype` and `size`
+        #   information from `schema`
+        # * object: Load via pickle
         out_dict = {}
         for i, (key, col_schema) in enumerate(self.schema.items()):
             if isinstance(col_schema, TensorInfo):
@@ -342,12 +354,14 @@
             else:
                 out_dict[key] = pickle.loads(row[i])

+        # In case `0` exists as an integer in the schema, this means that the
+        # schema was passed as either a single entry or a tuple:
         if 0 in self.schema:
             if len(self.schema) == 1:
                 return out_dict[0]
             else:
                 return tuple(out_dict.values())
-        else:
+        else:  # Otherwise, return the dictionary as it is:
             return out_dict
diff --git a/torch_geometric/testing/__init__.py b/torch_geometric/testing/__init__.py
index 0a3131046e50..f80b66093d76 100644
--- a/torch_geometric/testing/__init__.py
+++ b/torch_geometric/testing/__init__.py
@@ -9,6 +9,7 @@
     onlyOnline,
     onlyGraphviz,
     onlyNeighborSampler,
+    has_package,
     withPackage,
     withCUDA,
     disableExtensions,
@@ -29,6 +30,7 @@
     'onlyOnline',
     'onlyGraphviz',
     'onlyNeighborSampler',
+    'has_package',
     'withPackage',
     'withCUDA',
     'disableExtensions',
diff --git a/torch_geometric/testing/decorators.py b/torch_geometric/testing/decorators.py
index 29845e02cfed..6e517ae4b14a 100644
--- a/torch_geometric/testing/decorators.py
+++ b/torch_geometric/testing/decorators.py
@@ -125,29 +125,31 @@ def onlyNeighborSampler(func: Callable):
     )(func)


+def has_package(package: str) -> bool:
+    r"""Returns :obj:`True` in case :obj:`package` is installed."""
+    if '|' in package:
+        return any(has_package(p) for p in package.split('|'))
+
+    req = Requirement(package)
+    if find_spec(req.name) is None:
+        return False
+    module = import_module(req.name)
+    if not hasattr(module, '__version__'):
+        return True
+
+    version = module.__version__
+    # `req.specifier` does not support `.dev` suffixes, e.g., for
+    # `pyg_lib==0.1.0.dev*`, so we manually drop them:
+    if '.dev' in version:
+        version = '.'.join(version.split('.dev')[:-1])
+
+    return version in req.specifier
+
+
 def withPackage(*args) -> Callable:
     r"""A decorator to skip tests if certain packages are not installed.
Also supports version specification.""" - def is_installed(package: str) -> bool: - if '|' in package: - return any(is_installed(p) for p in package.split('|')) - - req = Requirement(package) - if find_spec(req.name) is None: - return False - module = import_module(req.name) - if not hasattr(module, '__version__'): - return True - - version = module.__version__ - # `req.specifier` does not support `.dev` suffixes, e.g., for - # `pyg_lib==0.1.0.dev*`, so we manually drop them: - if '.dev' in version: - version = '.'.join(version.split('.dev')[:-1]) - - return version in req.specifier - - na_packages = set(package for package in args if not is_installed(package)) + na_packages = set(package for package in args if not has_package(package)) def decorator(func: Callable) -> Callable: import pytest
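
For reference, the `Database` API exercised by the new tests boils down to the following usage pattern. This is a minimal sketch, assuming a PyTorch Geometric installation that includes this patch; the database path and table name below are illustrative:

    import torch

    from torch_geometric.data import Data
    from torch_geometric.data.database import SQLiteDatabase
    from torch_geometric.testing import has_package

    if has_package('sqlite3'):
        # A schema may mix primitive types, tensors of a fixed `dtype` and
        # `size` (where `-1` marks a variable-length dimension), and
        # arbitrary picklable objects:
        schema = {
            'id': int,
            'tensor': dict(dtype=torch.float, size=(2, -1)),
            'data': object,
        }
        db = SQLiteDatabase('storage.db', name='example_table', schema=schema)

        # Single and batched inserts:
        db.insert(0, dict(id=1, tensor=torch.randn(2, 8),
                          data=Data(x=torch.randn(1, 8))))
        db.multi_insert([1, 2], [
            dict(id=2, tensor=torch.randn(2, 16),
                 data=Data(x=torch.randn(2, 8))),
            dict(id=3, tensor=torch.randn(2, 32),
                 data=Data(x=torch.randn(3, 8))),
        ])

        out = db.get(0)              # A dictionary matching the schema.
        outs = db.multi_get([1, 2])  # A list of such dictionaries.
        db.close()

Note that `RocksDatabase` does not implement `__len__()`, which is why the tests above wrap `len(db)` checks in `try`/`except NotImplementedError` blocks.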