Skip to content

Commit

Permalink
Additional Database tests for all kinds of input data (#8057)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Sep 19, 2023
1 parent 0dce891 commit 5f157cd
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 67 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054))
- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054), [#8057](https://github.com/pyg-team/pytorch_geometric/pull/8057))
- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))
- Added the `MixHopConv` layer and an corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))
Expand Down
151 changes: 119 additions & 32 deletions test/data/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,53 @@
import pytest
import torch

from torch_geometric.data.database import RocksDatabase, SQLiteDatabase
from torch_geometric.data import Data
from torch_geometric.data.database import (
RocksDatabase,
SQLiteDatabase,
TensorInfo,
)
from torch_geometric.profile import benchmark
from torch_geometric.testing import withPackage
from torch_geometric.testing import has_package, withPackage

AVAILABLE_DATABASES = []
if has_package('sqlite3'):
AVAILABLE_DATABASES.append(SQLiteDatabase)
if has_package('rocksdict'):
AVAILABLE_DATABASES.append(RocksDatabase)

@withPackage('sqlite3')

@pytest.mark.parametrize('Database', AVAILABLE_DATABASES)
@pytest.mark.parametrize('batch_size', [None, 1])
def test_sqlite_database(tmp_path, batch_size):
path = osp.join(tmp_path, 'sqlite.db')
db = SQLiteDatabase(path, name='test_table')
assert str(db) == 'SQLiteDatabase(0)'
assert len(db) == 0
def test_databases_single_tensor(tmp_path, Database, batch_size):
kwargs = dict(path=osp.join(tmp_path, 'storage.db'))
if Database == SQLiteDatabase:
kwargs['name'] = 'test_table'

db = Database(**kwargs)
assert db.schema == {0: object}

try:
assert len(db) == 0
assert str(db) == f'{Database.__name__}(0)'
except NotImplementedError:
assert str(db) == f'{Database.__name__}()'

data = torch.randn(5)
db.insert(0, data)
assert len(db) == 1
try:
assert len(db) == 1
except NotImplementedError:
pass
assert torch.equal(db.get(0), data)

indices = torch.tensor([1, 2])
data_list = torch.randn(2, 5)
db.multi_insert(indices, data_list, batch_size=batch_size)
assert len(db) == 3

try:
assert len(db) == 3
except NotImplementedError:
pass
out_list = db.multi_get(indices, batch_size=batch_size)
assert isinstance(out_list, list)
assert len(out_list) == 2
Expand All @@ -35,35 +59,98 @@ def test_sqlite_database(tmp_path, batch_size):
db.close()


@withPackage('rocksdict')
@pytest.mark.parametrize('batch_size', [None, 1])
def test_rocks_database(tmp_path, batch_size):
path = osp.join(tmp_path, 'rocks.db')
db = RocksDatabase(path)
assert str(db) == 'RocksDatabase()'
with pytest.raises(NotImplementedError):
len(db)

data = torch.randn(5)
db.insert(0, data)
assert torch.equal(db.get(0), data)
@pytest.mark.parametrize('Database', AVAILABLE_DATABASES)
def test_databases_schema(tmp_path, Database):
kwargs = dict(name='test_table') if Database == SQLiteDatabase else {}

path = osp.join(tmp_path, 'tuple_storage.db')
schema = (int, float, str, dict(dtype=torch.float, size=(2, -1)), object)
db = Database(path, schema=schema, **kwargs)
assert db.schema == {
0: int,
1: float,
2: str,
3: TensorInfo(dtype=torch.float, size=(2, -1)),
4: object,
}

data1 = (1, 0.1, 'a', torch.randn(2, 8), Data(x=torch.randn(1, 8)))
data2 = (2, 0.2, 'b', torch.randn(2, 16), Data(x=torch.randn(2, 8)))
data3 = (3, 0.3, 'c', torch.randn(2, 32), Data(x=torch.randn(3, 8)))
db.insert(0, data1)
db.multi_insert([1, 2], [data2, data3])

out1 = db.get(0)
out2, out3 = db.multi_get([1, 2])

for out, data in zip([out1, out2, out3], [data1, data2, data3]):
assert out[0] == data[0]
assert out[1] == data[1]
assert out[2] == data[2]
assert torch.equal(out[3], data[3])
assert isinstance(out[4], Data) and len(out[4]) == 1
assert torch.equal(out[4].x, data[4].x)

indices = torch.tensor([1, 2])
data_list = torch.randn(2, 5)
db.multi_insert(indices, data_list, batch_size=batch_size)
db.close()

out_list = db.multi_get(indices, batch_size=batch_size)
assert isinstance(out_list, list)
assert len(out_list) == 2
assert torch.equal(out_list[0], data_list[0])
assert torch.equal(out_list[1], data_list[1])
path = osp.join(tmp_path, 'dict_storage.db')
schema = {
'int': int,
'float': float,
'str': str,
'tensor': dict(dtype=torch.float, size=(2, -1)),
'data': object
}
db = Database(path, schema=schema, **kwargs)
assert db.schema == {
'int': int,
'float': float,
'str': str,
'tensor': TensorInfo(dtype=torch.float, size=(2, -1)),
'data': object,
}

data1 = {
'int': 1,
'float': 0.1,
'str': 'a',
'tensor': torch.randn(2, 8),
'data': Data(x=torch.randn(1, 8)),
}
data2 = {
'int': 2,
'float': 0.2,
'str': 'b',
'tensor': torch.randn(2, 16),
'data': Data(x=torch.randn(2, 8)),
}
data3 = {
'int': 3,
'float': 0.3,
'str': 'c',
'tensor': torch.randn(2, 32),
'data': Data(x=torch.randn(3, 8)),
}
db.insert(0, data1)
db.multi_insert([1, 2], [data2, data3])

out1 = db.get(0)
out2, out3 = db.multi_get([1, 2])

for out, data in zip([out1, out2, out3], [data1, data2, data3]):
assert out['int'] == data['int']
assert out['float'] == data['float']
assert out['str'] == data['str']
assert torch.equal(out['tensor'], data['tensor'])
assert isinstance(out['data'], Data) and len(out['data']) == 1
assert torch.equal(out['data'].x, data['data'].x)

db.close()


@withPackage('sqlite3')
def test_database_syntactic_sugar(tmp_path):
path = osp.join(tmp_path, 'sqlite.db')
path = osp.join(tmp_path, 'storage.db')
db = SQLiteDatabase(path, name='test_table')

data = torch.randn(5, 16)
Expand Down
42 changes: 28 additions & 14 deletions torch_geometric/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def __init__(self, path: str, name: str, schema: Schema = object):

self.connect()

# Create the table (if it does not exist) by mapping the Python schema
# to the corresponding SQL schema:
sql_schema = ',\n'.join([
f' {col_name} {self._to_sql_type(type_info)} NOT NULL' for
col_name, type_info in zip(self._col_names, self.schema.values())
Expand Down Expand Up @@ -215,7 +217,7 @@ def insert(self, index: int, data: Any):
query = (f'INSERT INTO {self.name} '
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.execute(query, (index, self._serialize(data)))
self.cursor.execute(query, (index, *self._serialize(data)))

def _multi_insert(
self,
Expand All @@ -225,12 +227,13 @@ def _multi_insert(
if isinstance(indices, Tensor):
indices = indices.tolist()

data_list = [self._serialize(data) for data in data_list]
data_list = [(index, *self._serialize(data))
for index, data in zip(indices, data_list)]

query = (f'INSERT INTO {self.name} '
f'(id, {self._joined_col_names}) '
f'VALUES (?, {self._dummies})')
self.cursor.executemany(query, zip(indices, data_list))
self.cursor.executemany(query, data_list)

def get(self, index: int) -> Any:
query = (f'SELECT {self._joined_col_names} FROM {self.name} '
Expand All @@ -249,7 +252,7 @@ def multi_get(
elif isinstance(indices, Tensor):
indices = indices.tolist()

# We first create a temporary ID table to then perform an INNER JOIN.
# We create a temporary ID table to then perform an INNER JOIN.
# This avoids having a long IN clause and guarantees sorted outputs:
join_table_name = f'{self.name}__join__{uuid4().hex}'
query = (f'CREATE TABLE {join_table_name} (\n'
Expand Down Expand Up @@ -314,24 +317,33 @@ def _to_sql_type(self, type_info: Any) -> str:
else:
return 'BLOB'

def _serialize(self, row: Any) -> Union[Any, List[Any]]:
out_list: List[Any] = []
def _serialize(self, row: Any) -> List[Any]:
# Serializes the given input data according to `schema`:
# * {int, float, str}: Use as they are.
# * torch.Tensor: Convert into the raw byte string
# * object: Dump via pickle
# If we find a `torch.Tensor` that is not registered as such in
# `schema`, we modify the schema in-place for improved efficiency.
out: List[Any] = []
for key, col in self._to_dict(row).items():
if isinstance(self.schema[key], TensorInfo):
out = row.numpy().tobytes()
out.append(col.numpy().tobytes())
elif isinstance(col, Tensor):
self.schema[key] = TensorInfo(dtype=col.dtype)
out = row.numpy().tobytes()
out.append(col.numpy().tobytes())
elif self.schema[key] in {int, float, str}:
out = col
out.append(col)
else:
out = pickle.dumps(col)
out.append(pickle.dumps(col))

out_list.append(out)

return out_list if len(out_list) > 1 else out_list[0]
return out

def _deserialize(self, row: Tuple[Any]) -> Any:
# Deserializes the DB data according to `schema`:
# * {int, float, str}: Use as they are.
# * torch.Tensor: Load raw byte string with `dtype` and `size`
# information from `schema`
# * object: Load via pickle
out_dict = {}
for i, (key, col_schema) in enumerate(self.schema.items()):
if isinstance(col_schema, TensorInfo):
Expand All @@ -342,12 +354,14 @@ def _deserialize(self, row: Tuple[Any]) -> Any:
else:
out_dict[key] = pickle.loads(row[i])

# In case `0` exists as integer in the schema, this means that the
# schema was passed as either a single entry or a tuple:
if 0 in self.schema:
if len(self.schema) == 1:
return out_dict[0]
else:
return tuple(out_dict.values())
else:
else: # Otherwise, return the dictionary as it is:
return out_dict


Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
onlyOnline,
onlyGraphviz,
onlyNeighborSampler,
has_package,
withPackage,
withCUDA,
disableExtensions,
Expand All @@ -29,6 +30,7 @@
'onlyOnline',
'onlyGraphviz',
'onlyNeighborSampler',
'has_package',
'withPackage',
'withCUDA',
'disableExtensions',
Expand Down
42 changes: 22 additions & 20 deletions torch_geometric/testing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,29 +125,31 @@ def onlyNeighborSampler(func: Callable):
)(func)


def has_package(package: str) -> bool:
r"""Returns :obj:`True` in case :obj:`package` is installed."""
if '|' in package:
return any(has_package(p) for p in package.split('|'))

req = Requirement(package)
if find_spec(req.name) is None:
return False
module = import_module(req.name)
if not hasattr(module, '__version__'):
return True

version = module.__version__
# `req.specifier` does not support `.dev` suffixes, e.g., for
# `pyg_lib==0.1.0.dev*`, so we manually drop them:
if '.dev' in version:
version = '.'.join(version.split('.dev')[:-1])

return version in req.specifier


def withPackage(*args) -> Callable:
r"""A decorator to skip tests if certain packages are not installed.
Also supports version specification."""
def is_installed(package: str) -> bool:
if '|' in package:
return any(is_installed(p) for p in package.split('|'))

req = Requirement(package)
if find_spec(req.name) is None:
return False
module = import_module(req.name)
if not hasattr(module, '__version__'):
return True

version = module.__version__
# `req.specifier` does not support `.dev` suffixes, e.g., for
# `pyg_lib==0.1.0.dev*`, so we manually drop them:
if '.dev' in version:
version = '.'.join(version.split('.dev')[:-1])

return version in req.specifier

na_packages = set(package for package in args if not is_installed(package))
na_packages = set(package for package in args if not has_package(package))

def decorator(func: Callable) -> Callable:
import pytest
Expand Down

0 comments on commit 5f157cd

Please sign in to comment.