Skip to content

Commit

Permalink
fixup connection function naming
Browse files Browse the repository at this point in the history
  • Loading branch information
jay-bhambhani committed Sep 14, 2023
1 parent 164c0a1 commit 1ed9dee
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions torch_geometric/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from torch_geometric.data.data import Data
from torch_geometric.typing import OptTensor


@dataclass
class GraphLabel:
id: str
type: str


@dataclass
class GraphRow:
id: str
Expand All @@ -36,6 +38,7 @@ def chunk(seq: Iterable, chunk_size: int) -> Generator[list, Any, None]:
return
yield batch


def namedtuple_factory(cursor, row):
"""util function to create a namedtuple Row foe db results"""
fields = [column[0] for column in cursor.description]
Expand All @@ -46,14 +49,15 @@ def namedtuple_factory(cursor, row):
class Database(abc.ABC):

def __init__(self, credentials, *args, **kwargs):
self.cursor = self._get_cursor(credentials)
self.connection = self._get_connection(credentials)

@abc.abstractmethod
def _initialize(self):
"""initialize the database in some way if needed"""
raise NotImplementedError()

@abc.abstractmethod
def insert(elf, labels: Iterable[GraphLabel], values: Iterable[Data], batch_size=10000) -> list[str]:
def insert(self, labels: Iterable[GraphLabel], values: Iterable[Data], batch_size=10000) -> list[str]:
"""insert data into a database"""
raise NotImplementedError()

Expand Down Expand Up @@ -82,7 +86,6 @@ def _initialize(self):
create = """CREATE TABLE ? (id TEXT, type TEXT, x BLOB, edge_index BLOB, edge_attr BLOB, y BLOB, pos BLOB, meta TEXT)"""
self.cursor.execute(create, self.table)


def insert(self, labels: Iterable[GraphLabel], values: Iterable[Data], batch_size=10000) -> list[GraphRow]:
for chunk_data in chunk(zip(labels, values), batch_size):
serialized = [self.serialize_data(label, value) for label, value in chunk_data]
Expand All @@ -102,7 +105,6 @@ def multi_get(self, labels: Iterable[GraphLabel], batch_size=999):
query = f"SELECT * FROM {self.table} WHERE id IN ({','.join('?' * len(chunk_data))})"
self.cursor.execute(query, (label.id for label in chunk_data))


def serialize_data(self, label: GraphLabel, data: Data) -> GraphRow:
return GraphRow(
id=label.id,
Expand All @@ -120,13 +122,10 @@ def _serialize_tensor(t: OptTensor) -> bytes:
buff = io.BytesIO()
torch.save(t, buff)
return buff.getvalue()
def _get_cursor(self, connection):

def _get_connection(self, credentials):
"""a method to get the db cursor to executor SQL"""
con = sqlite3.connect(connection)
con = sqlite3.connect(credentials)
cursor = con.cursor()
cursor.row_factory = namedtuple_factory
return cursor



0 comments on commit 1ed9dee

Please sign in to comment.