From 1ed9dee96a7b9453962b2c44f9a432b14e88cd08 Mon Sep 17 00:00:00 2001 From: jay-bhambhani Date: Thu, 14 Sep 2023 02:09:50 +0000 Subject: [PATCH] fixup connection function naming --- torch_geometric/data/database.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/torch_geometric/data/database.py b/torch_geometric/data/database.py index b1b756918cdb..3fdc235b4f67 100644 --- a/torch_geometric/data/database.py +++ b/torch_geometric/data/database.py @@ -11,11 +11,13 @@ from torch_geometric.data.data import Data from torch_geometric.typing import OptTensor + @dataclass class GraphLabel: id: str type: str + @dataclass class GraphRow: id: str @@ -36,6 +38,7 @@ def chunk(seq: Iterable, chunk_size: int) -> Generator[list, Any, None]: return yield batch + def namedtuple_factory(cursor, row): """util function to create a namedtuple Row foe db results""" fields = [column[0] for column in cursor.description] @@ -46,14 +49,15 @@ def namedtuple_factory(cursor, row): class Database(abc.ABC): def __init__(self, credentials, *args, **kwargs): - self.cursor = self._get_cursor(credentials) + self.connection = self._get_connection(credentials) @abc.abstractmethod def _initialize(self): """initialize the database in some way if needed""" + raise NotImplementedError() @abc.abstractmethod - def insert(elf, labels: Iterable[GraphLabel], values: Iterable[Data], batch_size=10000) -> list[str]: + def insert(self, labels: Iterable[GraphLabel], values: Iterable[Data], batch_size=10000) -> list[str]: """insert data into a database""" raise NotImplementedError() @@ -82,7 +86,6 @@ def _initialize(self): create = """CREATE TABLE ? (id TEXT, type TEXT, x BLOB, edge_index BLOB, edge_attr BLOB, y BLOB, pos BLOB, meta TEXT)""" self.cursor.execute(create, self.table) - def insert(self, labels: Iterable[GraphLabel], values: Iterable[Data], batch_size=10000) -> list[GraphRow]: for chunk_data in chunk(zip(labels, values), batch_size): serialized = [self.serialize_data(label, value) for label, value in chunk_data] @@ -102,7 +105,6 @@ def multi_get(self, labels: Iterable[GraphLabel], batch_size=999): query = f"SELECT * FROM {self.table} WHERE id IN ({','.join('?' * len(chunk_data))})" self.cursor.execute(query, (label.id for label in chunk_data)) - def serialize_data(self, label: GraphLabel, data: Data) -> GraphRow: return GraphRow( id=label.id, @@ -120,13 +122,10 @@ def _serialize_tensor(t: OptTensor) -> bytes: buff = io.BytesIO() torch.save(t, buff) return buff.getvalue() - - def _get_cursor(self, connection): + + def _get_connection(self, credentials): """a method to get the db cursor to executor SQL""" - con = sqlite3.connect(connection) + con = sqlite3.connect(credentials) cursor = con.cursor() cursor.row_factory = namedtuple_factory return cursor - - -