CellTypeGraph is a new graph benchmark for node classification.
The benchmark is distilled from 84 Arabidopsis ovule segmentations, and the task is to classify each cell with its specific cell type. We represent each specimen as a graph, where each cell is a node and any two adjacent cells are connected by an edge. This Python package comes with a PyTorch DataLoader and pre-computed node and edge features; the features can be fully customized and modified. The source data for the CellTypeGraph benchmark can also be downloaded manually from zenodo.org.
In the package we also include evaluation code and examples.
To see our most recent results, check out the leaderboard page in the repository wiki.
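The graph representation described above (each cell a node, adjacent cells joined by an edge) can be illustrated with a minimal sketch. This is not the package's own pipeline: it assumes a toy 2D label image in which every cell has an integer label and `0` is background, whereas the real segmentations are volumetric. The function name `segmentation_to_graph` is hypothetical.

```python
from typing import List, Set, Tuple


def segmentation_to_graph(
    labels: List[List[int]], background: int = 0
) -> Tuple[List[int], Set[Tuple[int, int]]]:
    """Build a cell adjacency graph from a 2D label image (illustrative only).

    Every non-background label becomes a node; two labels that touch along
    a pixel edge (4-connectivity) are joined by an undirected edge.
    """
    nodes = set()
    edges = set()
    n_rows = len(labels)
    n_cols = len(labels[0]) if n_rows else 0
    for r in range(n_rows):
        for c in range(n_cols):
            a = labels[r][c]
            if a == background:
                continue
            nodes.add(a)
            # Look only right and down so each neighbouring pixel pair is visited once.
            for dr, dc in ((0, 1), (1, 0)):
                rr, cc = r + dr, c + dc
                if rr < n_rows and cc < n_cols:
                    b = labels[rr][cc]
                    if b != background and b != a:
                        # Store edges with the smaller label first to avoid duplicates.
                        edges.add((min(a, b), max(a, b)))
    return sorted(nodes), edges


# Toy segmentation with three cells (labels 1, 2, 3) and one background pixel.
toy_segmentation = [
    [1, 1, 2, 2],
    [1, 1, 2, 2],
    [3, 3, 3, 0],
]
nodes, edges = segmentation_to_graph(toy_segmentation)
print(nodes)
print(sorted(edges))
```

In this toy image all three cells touch each other, so the resulting graph is a triangle.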
- Linux
- Anaconda / miniconda
- python >= 3.8
- tqdm
- h5py
- requests
- pyyaml
- numba
- pytorch
- torchmetrics
- pytorch-geometric
- class_resolver
- for cuda 11.3
conda create -n ctg -c rusty1s -c pytorch -c conda-forge -c lcerrone ctg-benchmark cudatoolkit=11.3
- for cuda 10.2
conda create -n ctg -c rusty1s -c pytorch -c conda-forge -c lcerrone ctg-benchmark cudatoolkit=10.2
- for cpu only
conda create -n ctg -c rusty1s -c pytorch -c conda-forge -c lcerrone ctg-benchmark cpuonly
- A simple GCN training example can be found in examples.
- create CellTypeGraph cross validation loader
from ctg_benchmark.loaders import get_cross_validation_loaders
loaders_dict = get_cross_validation_loaders(root='./ctg_data/')
where loaders_dict is a dictionary that contains 5 pairs of training and validation data loaders, one pair per split.
for split, loader_dict in loaders_dict.items():
    train_loader = loader_dict['train']
    val_loader = loader_dict['val']
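To make the idea of 5 train/validation splits concrete, here is a minimal sketch of a plain k-fold partition over specimen indices. It is an assumption that the benchmark follows a standard k-fold scheme; the actual assignment of specimens to splits is defined by the package and may differ. The helper `k_fold_splits` is hypothetical and is not part of `ctg_benchmark`.

```python
from typing import Dict, List


def k_fold_splits(sample_ids: List[int], n_splits: int = 5) -> Dict[int, Dict[str, List[int]]]:
    """Partition sample ids into n_splits folds; each fold is used once for validation."""
    # Round-robin assignment keeps the fold sizes within one sample of each other.
    folds = [sample_ids[i::n_splits] for i in range(n_splits)]
    splits = {}
    for split in range(n_splits):
        val_ids = folds[split]
        train_ids = [s for i, fold in enumerate(folds) if i != split for s in fold]
        splits[split] = {"train": sorted(train_ids), "val": val_ids}
    return splits


specimen_ids = list(range(84))  # 84 specimens, matching the size of the benchmark
splits = k_fold_splits(specimen_ids, n_splits=5)
for split, ids in splits.items():
    print(split, len(ids["train"]), len(ids["val"]))
```

Under this scheme every specimen appears in exactly one validation fold, so averaging the validation scores over the splits uses each specimen once.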
- Alternatively, for quicker experimentation, one can create a simple train/val/test split as:
from ctg_benchmark.loaders import get_split_loaders
loader = get_split_loaders(root='./ctg_data/',)
print(loader['train'], loader['val'], loader['test'])
- Simple evaluation: For evaluation, ctg_benchmark.evaluation.NodeClassificationMetrics conveniently wraps several metrics as implemented in torchmetrics. Single-class results can be aggregated using ctg_benchmark.evaluation.aggregate_class
import torch
from ctg_benchmark.evaluation import NodeClassificationMetrics, aggregate_class
eval_metrics = NodeClassificationMetrics(num_classes=9)
predictions = torch.randint(9, (1000,))
target = torch.randint(9, (1000,))
results = eval_metrics.compute_metrics(predictions, target)
class_average_accuracy, _ = aggregate_class(results['accuracy_class'], index=7)
print(f"global accuracy: {results['accuracy_micro']: .3f}")
print(f"class average accuracy: {class_average_accuracy: .3f}")
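The difference between global accuracy and class-average accuracy matters when classes are imbalanced. The sketch below computes both in plain Python. It is not the implementation used by `NodeClassificationMetrics` or `aggregate_class`: the function `accuracy_summary` and the key `accuracy_class_average` are hypothetical, and per-class accuracy is assumed here to mean the fraction of nodes of a class that are predicted correctly.

```python
from typing import Dict, List


def accuracy_summary(predictions: List[int], target: List[int], num_classes: int) -> Dict[str, object]:
    """Compute global (micro) accuracy and the unweighted mean of per-class accuracies."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for p, t in zip(predictions, target):
        total[t] += 1
        if p == t:
            correct[t] += 1
    # Classes that never occur in the target get NaN and are skipped in the average.
    per_class = [correct[c] / total[c] if total[c] else float("nan") for c in range(num_classes)]
    present = [acc for acc, n in zip(per_class, total) if n > 0]
    return {
        "accuracy_micro": sum(correct) / len(target),
        "accuracy_class": per_class,
        "accuracy_class_average": sum(present) / len(present),
    }


# Imbalanced toy example: class 0 dominates and is always predicted correctly.
toy_target = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2]
toy_predictions = [0, 0, 0, 0, 0, 0, 1, 0, 0, 2]
summary = accuracy_summary(toy_predictions, toy_target, num_classes=3)
print(f"global accuracy: {summary['accuracy_micro']:.3f}")
print(f"class average accuracy: {summary['accuracy_class_average']:.3f}")
```

Here 8 of 10 nodes are correct overall, but the two rare classes are each only half right, so the class-average accuracy is lower than the global accuracy.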
- Change default features, features processing or add new features: We did our best to make the CellTypeGraph benchmark flexible and easy to extend. Since we compute several incommensurable features, we needed a way to select and process every feature independently; for more details see.
- Load points samples, for more details see.
- Manual download, for more details see.
- To get details on the benchmark, run the following script
- To reproduce the GCN results, check out the following script
- To reproduce all baseline results, plots and additional experiments, check out the plantcelltype repository.
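The idea of selecting and processing every feature independently, mentioned in the list above, can be sketched as a small configuration that maps each selected feature to its own transform. This is only an illustration of the concept and does not reproduce the package's configuration format: the feature names (`volume`, `surface`, `depth`), the transforms, and `build_node_features` are all hypothetical.

```python
from statistics import mean, pstdev
from typing import Callable, Dict, List

Column = List[float]


def z_score(values: Column) -> Column:
    """Standardize a feature to zero mean and unit variance."""
    mu, sigma = mean(values), pstdev(values)
    return [(v - mu) / sigma if sigma > 0 else 0.0 for v in values]


def min_max(values: Column) -> Column:
    """Rescale a feature to the range [0, 1]."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo) if hi > lo else 0.0 for v in values]


def build_node_features(
    raw: Dict[str, Column], config: Dict[str, Callable[[Column], Column]]
) -> List[List[float]]:
    """Apply each configured transform to its feature and stack the results per node."""
    # Only features named in the config are selected; the rest are ignored.
    columns = [transform(raw[name]) for name, transform in config.items()]
    return [list(row) for row in zip(*columns)]


# Hypothetical raw features for three nodes, on very different scales.
raw_features = {
    "volume": [120.0, 80.0, 100.0],
    "surface": [30.0, 10.0, 20.0],
    "depth": [1.0, 2.0, 3.0],
}
# Select two features and give each its own processing step.
feature_config = {"volume": z_score, "surface": min_max}
node_features = build_node_features(raw_features, feature_config)
print(node_features)
```

Keeping the transform next to the feature name means that features with different units can be normalized differently before they are concatenated into one node feature vector.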
@inproceedings{cerrone2022celltypegraph,
  title={CellTypeGraph: A New Geometric Computer Vision Benchmark},
  author={Cerrone, Lorenzo and Vijayan, Athul and Mody, Tejasvinee and Schneitz, Kay and Hamprecht, Fred A},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={20897--20907},
  year={2022}
}