This repository was archived by the owner on Mar 14, 2024. It is now read-only.

Sql example #248

Open · wants to merge 30 commits into base: main
11 changes: 11 additions & 0 deletions README.md
@@ -111,6 +111,17 @@ The outputs will be stored next to the inputs in the `data/FB15k` directory.

This simple utility is only suitable for small graphs that fit entirely in memory. To handle larger data you will have to implement your own custom preprocessor.

### Checking the data

We advise checking the edgelist files with PBG's checking script, which catches common, hard-to-debug errors that our developers have run into before. Run it by invoking the following:

```bash
torchbiggraph_check \
torchbiggraph/examples/configs/fb15k_config_cpu.py
```

This command raises run-time errors with informative messages if it encounters problems. If there are no errors, you will see "Found no errors in the input" logged.

### Training

The `torchbiggraph_train` command is used to launch training. The training parameters are tucked away in a configuration file, whose path is given to the command. They can, however, be overridden from the command line with the `--param` flag. The sample config is used for both training and evaluation, so we will have to use the override to specify the edge set to use.
1 change: 1 addition & 0 deletions setup.cfg
@@ -53,6 +53,7 @@ parquet = parquet

[options.entry_points]
console_scripts =
torchbiggraph_check = torchbiggraph.check:main
torchbiggraph_config = torchbiggraph.config:main
torchbiggraph_eval = torchbiggraph.eval:main
torchbiggraph_example_fb15k = torchbiggraph.examples.fb15k:main
110 changes: 110 additions & 0 deletions torchbiggraph/check.py
@@ -0,0 +1,110 @@
import argparse
import logging

import torch

from torchbiggraph.config import ConfigFileLoader
from torchbiggraph.graph_storages import EDGE_STORAGES, ENTITY_STORAGES
from torchbiggraph.train_cpu import IterationManager, get_num_edge_chunks
from torchbiggraph.types import Bucket
from torchbiggraph.util import (
    EmbeddingHolder,
    set_logging_verbosity,
    setup_logging,
)


logger = logging.getLogger("torchbiggraph")

class Checker:
    def __init__(self, config):
        # Load the entity count of every entity type and partition so that
        # edge indices can be bounds-checked against them.
        entity_storage = ENTITY_STORAGES.make_instance(config.entity_path)
        entity_counts = {}
        for entity, econf in config.entities.items():
            entity_counts[entity] = []
            for part in range(econf.num_partitions):
                entity_counts[entity].append(entity_storage.load_count(entity, part))
        self.entity_counts = entity_counts
        self.config = config
        self.holder = EmbeddingHolder(config)

    def check_all_edges(self):
        num_edge_chunks = get_num_edge_chunks(self.config)

        iteration_manager = IterationManager(
            1,
            self.config.edge_paths,
            num_edge_chunks,
            iteration_idx=0,
        )
        edge_storage = EDGE_STORAGES.make_instance(iteration_manager.edge_path)

        # Check every edge chunk of every bucket.
        for _, _, edge_chunk_idx in iteration_manager:
            for lhs in range(self.holder.nparts_lhs):
                for rhs in range(self.holder.nparts_rhs):
                    cur_b = Bucket(lhs, rhs)
                    logger.info(
                        f"Checking edge chunk {edge_chunk_idx} "
                        f"for edges_{cur_b.lhs}_{cur_b.rhs}.h5"
                    )
                    edges = edge_storage.load_chunk_of_edges(
                        cur_b.lhs,
                        cur_b.rhs,
                        edge_chunk_idx,
                        iteration_manager.num_edge_chunks,
                        shared=True,
                    )
                    self.check_edge_chunk(cur_b, edges)

    def check_edge_chunk(self, cur_b, edges):
        rhs = edges.rhs.to_tensor()
        lhs = edges.lhs.to_tensor()

        # Check LHS: every left-hand side index must be smaller than the number
        # of entities of its relation's LHS type in this bucket's partition.
        rel_lhs_entity_counts = torch.tensor(
            [self.entity_counts[r.lhs][cur_b.lhs] for r in self.config.relations]
        )
        edge_lhs_entity_count = rel_lhs_entity_counts[edges.rel]
        if (lhs >= edge_lhs_entity_count).any():
            _, worst_edge_idx = (lhs - edge_lhs_entity_count).max(0)
            raise RuntimeError(
                f"edge {worst_edge_idx} has LHS entity of "
                f"{lhs[worst_edge_idx]} but rel "
                f"{edges.rel[worst_edge_idx]} only has "
                f"{edge_lhs_entity_count[worst_edge_idx]} entities "
                f"with r.name: {self.config.relations[edges.rel[worst_edge_idx]].name}. "
                "Preprocessing bug?"
            )

        # Check RHS: the same bounds check for the right-hand side.
        rel_rhs_entity_counts = torch.tensor(
            [self.entity_counts[r.rhs][cur_b.rhs] for r in self.config.relations]
        )
        edge_rhs_entity_count = rel_rhs_entity_counts[edges.rel]
        if (rhs >= edge_rhs_entity_count).any():
            _, worst_edge_idx = (rhs - edge_rhs_entity_count).max(0)
            raise RuntimeError(
                f"edge {worst_edge_idx} has RHS entity of "
                f"{rhs[worst_edge_idx]} but rel "
                f"{edges.rel[worst_edge_idx]} only has "
                f"{edge_rhs_entity_count[worst_edge_idx]} entities "
                f"with r.name: {self.config.relations[edges.rel[worst_edge_idx]].name}. "
                "Preprocessing bug?"
            )

def main():
    # setup.cfg registers this function as the torchbiggraph_check entry point.
    parser = argparse.ArgumentParser(
        description="Script to check for user errors in a PBG input config and data. "
        "It checks that each entity index is within range for the entity type "
        "specified by the config relation. Preprocessing or config bugs can break "
        "this assumption, and may lead to errors or crashes during training."
    )
    parser.add_argument("config", help="Path to config file")
    parser.add_argument("-p", "--param", action="append", nargs="*")
    opt = parser.parse_args()

    loader = ConfigFileLoader()
    config = loader.load_config(opt.config, opt.param)

    set_logging_verbosity(config.verbose)
    setup_logging(config.verbose)

    Checker(config).check_all_edges()
    logger.info("Found no errors in the input")


if __name__ == "__main__":
    main()
56 changes: 56 additions & 0 deletions torchbiggraph/examples/sql_end2end/README.md
@@ -0,0 +1,56 @@
# SQL End-to-End Example

This is intended as a simple end-to-end example of how to get a SQL edgelist
table into the format that PyTorch BigGraph expects, using SQL queries. It's
implemented in SQLite for portability, but similar techniques scale to billions
of edges on cloud databases such as BigQuery or Snowflake. The pipeline
can be split into three components:

1. Data preparation
2. Data verification/checking
3. Training

To run the pipeline, you'll first need to download the edges.csv file,
available HERE (TODO: INSERT LINK). This graph was constructed by
taking the [ogbl-citation2](https://github.com/snap-stanford/ogb) graph and
adding edges for both paper citations and publication years. While this graph
may not make a huge amount of sense, it is intended largely for pedagogical
purposes.

In the data preparation stage, we first load the graph into a SQLite database
and then transform and partition it. The transformation first partitions the
entities, then generates a mapping between the graph ids and the per-type
ordinal ids that PBG expects, and finally writes out all the files required
for training, including the config file. By keeping track of the vertex types,
we're able to verify our mappings in a fully self-consistent fashion.
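
As an illustration of the id-mapping step, here is a minimal sketch using Python's
built-in `sqlite3` module. The table and column names (`edges`, `src`, `src_type`,
`dst`, `dst_type`, `rel`, `nodes`, `node_ids`) are hypothetical and need not match
what `data_prep.py` actually uses; the point is the pattern of assigning dense,
per-type ordinal ids with a window function and joining them back onto the edge list.

```python
import sqlite3

conn = sqlite3.connect("graph.db")  # hypothetical database file
cur = conn.cursor()

# Collect every distinct (type, graph_id) pair into a node table and assign a
# dense per-type ordinal id; ROW_NUMBER() requires SQLite >= 3.25.
cur.executescript(
    """
    CREATE TABLE IF NOT EXISTS nodes AS
    SELECT src_type AS type, src AS graph_id FROM edges
    UNION
    SELECT dst_type AS type, dst AS graph_id FROM edges;

    CREATE TABLE IF NOT EXISTS node_ids AS
    SELECT
        type,
        graph_id,
        ROW_NUMBER() OVER (PARTITION BY type ORDER BY graph_id) - 1 AS ordinal_id
    FROM nodes;

    -- Index the join key so the remapping join below stays fast.
    CREATE INDEX IF NOT EXISTS node_ids_idx ON node_ids (type, graph_id);
    """
)

# Rewrite the edge list in terms of the per-type ordinal ids that PBG expects.
cur.execute(
    """
    SELECT s.ordinal_id, e.rel, d.ordinal_id
    FROM edges e
    JOIN node_ids s ON s.type = e.src_type AND s.graph_id = e.src
    JOIN node_ids d ON d.type = e.dst_type AND d.graph_id = e.dst
    """
)
for lhs, rel, rhs in cur.fetchmany(5):
    print(lhs, rel, rhs)
conn.close()
```

For simplicity this sketch ignores partitioning; with multiple partitions the
ordinal ids would be assigned per (type, partition) rather than per type.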

Once the data has been prepared, we're ready to embed the graph. We do this by
passing the generated config to `torchbiggraph_train`:

```bash
torchbiggraph_train \
path/to/generated/config.py
```

The `data_prep.py` script will also compute the approximate amount of shared memory
that will be needed for training. If training demands more than the available
shared memory, you'll need to regenerate your data with more partitions. If you're
seeing a bus error or an OOM-kill message in the kernel ring buffer even though
your machine has enough RAM, verify that `/dev/shm` is large enough to accommodate
your embedding table.
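
As a rough back-of-the-envelope check (not the exact formula `data_prep.py` uses),
the dominant cost is the slice of the embedding table held in shared memory at
once: roughly entities-per-partition times embedding dimension times 4 bytes for
float32, for the two partitions of the current bucket. A small sketch with
made-up numbers, assuming Linux:

```python
import shutil

# Hypothetical numbers: substitute your own entity count, partition count and
# embedding dimension from the generated config.
num_entities = 3_000_000
num_partitions = 4
dimension = 400
bytes_per_float = 4  # float32

# Training roughly holds the LHS and RHS partitions of the current bucket in
# shared memory (plus any unpartitioned entity types, ignored here).
entities_per_partition = num_entities / num_partitions
approx_shm_bytes = 2 * entities_per_partition * dimension * bytes_per_float

total, _, free = shutil.disk_usage("/dev/shm")
print(f"approx. embedding memory needed: {approx_shm_bytes / 2**30:.1f} GiB")
print(f"/dev/shm capacity:               {total / 2**30:.1f} GiB")
if approx_shm_bytes > free:
    print("Consider regenerating the data with more partitions.")
```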

## Extensions

A few changes would be needed to use this at scale in a production environment.
First, this pipeline is brittle and simplistic. For production workloads it's
probably better to use a tool like dbt or Dataflow to create independent tables
in parallel. It's also important to be careful with indices to keep the joins
performant.

When it comes time to map your buckets to HDF5 files, it's almost certainly more
performant to dump them to chunked Parquet/Avro files and merge those files
in parallel. Every company's compute infrastructure is different enough that this
piece of code will have to be custom-written. Fortunately, it can be written once
and reused.
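
As one possible shape for that merge step, here is a minimal sketch using
`pyarrow` and `h5py`. The Parquet column names and directory layout are
hypothetical, and the HDF5 dataset names and `format_version` attribute are
assumptions on our part; verify them against the files produced by PBG's own
importers (see `torchbiggraph/graph_storages.py`) before relying on this.

```python
import h5py
import pyarrow.parquet as pq

# Hypothetical layout: one directory of Parquet chunks per bucket, with integer
# columns lhs, rel, rhs already remapped to per-type ordinal ids.
lhs_part, rhs_part = 0, 0
bucket_dir = f"chunks/bucket_{lhs_part}_{rhs_part}"

# Reading a directory merges all chunk files into a single table.
table = pq.read_table(bucket_dir, columns=["lhs", "rel", "rhs"])

# Write the merged bucket as edges_<lhs>_<rhs>.h5 in the edge_paths directory.
with h5py.File(f"edges_{lhs_part}_{rhs_part}.h5", "w") as hf:
    hf.attrs["format_version"] = 1  # assumed attribute, see note above
    hf.create_dataset("lhs", data=table.column("lhs").to_numpy())
    hf.create_dataset("rhs", data=table.column("rhs").to_numpy())
    hf.create_dataset("rel", data=table.column("rel").to_numpy())
```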