Skip to content

Commit

Permalink
Add mock data
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 25, 2022
1 parent 7acb34c commit d8309f4
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 4 deletions.
5 changes: 4 additions & 1 deletion chemicalx/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@
"DrugFeatureSet",
"DrugPairBatch",
"LabeledTriples",
# Datasets
# Abstract datasets
"dataset_resolver",
"DatasetLoader",
"RemoteDatasetLoader",
"LocalDatasetLoader",
# Datasets
"DrugbankDDI",
"TwoSides",
"DrugComb",
Expand Down
5 changes: 3 additions & 2 deletions chemicalx/data/datasetloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from abc import ABC, abstractmethod
from functools import lru_cache
from itertools import chain
from pathlib import Path
from textwrap import dedent
from typing import Dict, Mapping, Optional, Sequence, Tuple, cast

Expand Down Expand Up @@ -290,9 +291,9 @@ def __init__(self):
class LocalDatasetLoader(DatasetLoader, ABC):
"""A dataset loader that processes and caches data locally."""

def __init__(self):
def __init__(self, directory: Optional[Path] = None):
"""Instantiate the local dataset loader."""
self.directory = pystow.join("chemicalx", self.__class__.__name__.lower())
self.directory = directory or pystow.join("chemicalx", self.__class__.__name__.lower())
self.drugs_path = self.directory.joinpath(DRUG_FILE_NAME)
self.contexts_path = self.directory.joinpath("context.tsv")
self.labels_path = self.directory.joinpath(LABELS_FILE_NAME)
Expand Down
45 changes: 44 additions & 1 deletion tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
"""Tests for datasets."""

import pathlib
import unittest
from typing import ClassVar

from chemicalx.data import DatasetLoader, DrugbankDDI, DrugComb, DrugCombDB, TwoSides
from chemicalx.data import (
DatasetLoader,
DrugbankDDI,
DrugComb,
DrugCombDB,
LocalDatasetLoader,
TwoSides,
)

HERE = pathlib.Path(__file__).parent.resolve()


class TestDatasetLoader(LocalDatasetLoader):
    """A mock dataset loader whose preprocessing step does nothing."""

    # The class name starts with "Test" and the class inherits an ``__init__``,
    # so pytest would try to collect it as a test class and emit a collection
    # warning. It is a helper, not a test, so opt out of collection explicitly.
    __test__ = False

    def preprocess(self):
        """A mock preprocessing function.

        Intentionally empty: this loader is only meant to be pointed at a
        directory of ready-made fixture files, so there is nothing to generate.
        """


class TestMock(unittest.TestCase):
    """A test case for the mock dataset.

    The checks use ``unittest`` assertion methods rather than bare ``assert``
    statements so that they still run under ``python -O`` and report both the
    actual and the expected value on failure.
    """

    loader: DatasetLoader

    def setUp(self) -> None:
        """Set up the test case.

        A fresh loader is created for every test and pointed at the
        ``test_dataset`` directory that sits next to this test module.
        """
        self.loader = TestDatasetLoader(directory=HERE.joinpath("test_dataset"))

    def test_get_context_features(self):
        """Test the number of context features."""
        self.assertEqual(self.loader.num_contexts, 2)
        self.assertEqual(self.loader.context_channels, 5)

    def test_get_drug_features(self):
        """Test the number of drug features."""
        self.assertEqual(self.loader.num_drugs, 2)
        self.assertEqual(self.loader.drug_channels, 4)

    def test_get_labeled_triples(self):
        """Test the shape of the labeled triples."""
        self.assertEqual(self.loader.num_labeled_triples, 2)
        labeled_triples = self.loader.get_labeled_triples()
        # Expected shape is (rows, columns); presumably two fixture triples
        # with the columns drug_1, drug_2, context and label.
        self.assertEqual(labeled_triples.data.shape, (2, 4))


class TestDrugComb(unittest.TestCase):
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_dataset/context.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
A2058 1.0 2.0 0.0 0.0 0.0
A2780 2.0 1.0 0.0 0.0 0.0
20 changes: 20 additions & 0 deletions tests/unit/test_dataset/drug_set.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"5-FU": {
"smiles": "O=c1[nH]cc(F)c(=O)[nH]1",
"features": [
0,
1,
1,
0
]
},
"ABT-888": {
"smiles": "CC1(c2nc3c(C(N)=O)cccc3[nH]2)CCCN1",
"features": [
0,
1,
0,
1
]
}
}
3 changes: 3 additions & 0 deletions tests/unit/test_dataset/labeled_triples.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
drug_1 drug_2 context label
5-FU ABT-888 A2058 7.6935301658
5-FU ABT-888 A2780 7.7780530600999995

0 comments on commit d8309f4

Please sign in to comment.