Skip to content

Commit

Permalink
matthias suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
jay-bhambhani committed Sep 15, 2023
1 parent 1ed9dee commit f713360
Showing 1 changed file with 15 additions and 23 deletions.
38 changes: 15 additions & 23 deletions torch_geometric/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
import io
import itertools
import json
from typing import Iterable, Generator, Optional, Any

import sqlite3
Expand All @@ -12,21 +13,17 @@
from torch_geometric.typing import OptTensor


class GraphLabel:
    """Identifier for a stored graph, with optional free-form attributes.

    This is a plain class rather than a dataclass: the attributes are set
    dynamically in ``__init__``, so dataclass-generated ``__repr__``/``__eq__``
    over declared fields would fail whenever an optional field such as
    ``type`` was not supplied.

    Args:
        id: Key under which the graph is stored and looked up.
        *args: Accepted for call compatibility and ignored.
        **kwargs: Extra attributes (for example ``type``), set on the
            instance under the given names.
    """

    def __init__(self, id: str, *args, **kwargs):
        self.id = id
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self) -> str:
        # Built from the attributes actually present on the instance.
        fields = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{self.__class__.__name__}({fields})"

@dataclass
class GraphRow:
    """One database row: a graph id plus its serialized attributes.

    The two fields correspond to the ``id`` and ``data`` columns used by the
    table's ``INSERT`` statement.
    """

    # Row key; taken from the label the graph was stored under.
    id: str
    # Attribute name -> serialized payload (``None`` for a missing attribute).
    data: Optional[dict[str, Optional[bytes]]]


def chunk(seq: Iterable, chunk_size: int) -> Generator[list, Any, None]:
Expand Down Expand Up @@ -83,21 +80,20 @@ def __init__(self, credentials, table='pyg_database') -> None:
super().__init__(credentials)

def _initialize(self):
    """Create the backing table with ``id`` and ``data`` text columns.

    Safe to call repeatedly: the table is only created when it does not
    already exist.
    """
    # SQL placeholders bind values, not identifiers, so the table name has to
    # be interpolated into the statement (as the other queries here do).
    create = f"CREATE TABLE IF NOT EXISTS {self.table} (id TEXT, data TEXT)"
    self.cursor.execute(create)

def insert(
    self,
    labels: Iterable[GraphLabel],
    values: Iterable[Data],
    batch_size=10000,
) -> list[GraphRow]:
    """Serialize and store graphs, ``batch_size`` rows per statement batch.

    Args:
        labels: Labels supplying the row ids, paired positionally with
            ``values``.
        values: Graph objects to serialize and store.
        batch_size: Maximum number of rows handed to one ``executemany``.

    Returns:
        The serialized rows, in insertion order.
    """
    query = f"""
    INSERT INTO {self.table} (id, data)
    VALUES (?, ?)"""
    inserted: list[GraphRow] = []
    for chunk_data in chunk(zip(labels, values), batch_size):
        serialized = [
            self.serialize_data(label, value) for label, value in chunk_data
        ]
        # GraphRow is a dataclass: fields are attributes, not mapping keys.
        # NOTE(review): json.dumps raises TypeError on bytes values, and
        # GraphRow.data is annotated as holding bytes — confirm what
        # _serialize_tensor returns, or encode the payload before dumping.
        params = [(row.id, json.dumps(row.data)) for row in serialized]
        self.cursor.executemany(query, params)
        inserted.extend(serialized)
    return inserted

def get(self, label: "GraphLabel"):
    """Fetch the stored row whose ``id`` column equals ``label.id``.

    Args:
        label: Label whose ``id`` attribute identifies the row.

    Returns:
        Whatever ``cursor.fetchone()`` yields for the query: the first
        matching row, or ``None`` when nothing matches (sqlite3 semantics).
    """
    query = f"""SELECT * FROM {self.table} where id = ?"""
    # The trailing comma matters: ``(label.id)`` is just the string itself,
    # which sqlite3 would treat as one bound parameter per character.
    self.cursor.execute(query, (label.id,))
    return self.cursor.fetchone()

def multi_get(self, labels: Iterable[GraphLabel], batch_size=999):
Expand All @@ -106,14 +102,10 @@ def multi_get(self, labels: Iterable[GraphLabel], batch_size=999):
self.cursor.execute(query, (label.id for label in chunk_data))

def serialize_data(self, label: GraphLabel, data: Data) -> GraphRow:
    """Build the storable row for ``data`` under the id of ``label``.

    Tensor-valued attributes are passed through ``self._serialize_tensor``;
    every other value (including ``None``) is kept unchanged.

    Args:
        label: Label whose ``id`` becomes the row id.
        data: Object whose instance attributes are serialized.

    Returns:
        A ``GraphRow`` holding the id and the attribute mapping.
    """
    # isinstance() needs a concrete class; a typing alias such as OptTensor
    # (presumably Optional[Tensor] — confirm in torch_geometric.typing) is
    # rejected at runtime.
    from torch import Tensor

    # NOTE(review): vars(data) returns the raw instance __dict__; confirm this
    # exposes the graph attributes (x, edge_index, ...) for Data objects,
    # otherwise iterate via the object's public mapping API instead.
    row_dict = {
        key: self._serialize_tensor(value) if isinstance(value, Tensor) else value
        for key, value in vars(data).items()
    }
    # GraphRow only has `id` and `data`; splatting label/data attributes into
    # its constructor would raise on any other keyword.
    return GraphRow(id=label.id, data=row_dict)

@staticmethod
Expand Down

0 comments on commit f713360

Please sign in to comment.