Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize documentation style with darglint #70

Merged
merged 2 commits into from
Feb 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 17 additions & 28 deletions chemicalx/data/batchgenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,13 @@ def __init__(
):
"""Initialize a batch generator.

Args:
batch_size: Number of drug pairs per batch.
context_features: Indicator whether the batch should include biological context features.
drug_features: Indicator whether the batch should include drug features.
drug_molecules: Indicator whether the batch should include drug molecules
context_feature_set: A context feature set for feature generation.
drug_feature_set: A drug feature set for feature generation.
labeled_triples: A labeled triples object used to generate batches.
:param batch_size: Number of drug pairs per batch.
:param context_features: Indicator whether the batch should include biological context features.
:param drug_features: Indicator whether the batch should include drug features.
:param drug_molecules: Indicator whether the batch should include drug molecules.
:param context_feature_set: A context feature set for feature generation.
:param drug_feature_set: A drug feature set for feature generation.
:param labeled_triples: A labeled triples object used to generate batches.
"""
self.batch_size = batch_size
self.context_features = context_features
Expand All @@ -53,10 +52,8 @@ def __init__(
def _get_context_features(self, context_identifiers: Iterable[str]) -> Optional[torch.FloatTensor]:
"""Get the context features as a matrix.

Args:
context_identifiers (pd.Series): The context identifiers of interest.
Returns:
context_features (torch.FloatTensor): The matrix of biological context features.
:param context_identifiers: The context identifiers of interest.
:returns: The matrix of biological context features.
"""
if not self.context_features or self.context_feature_set is None:
return None
Expand All @@ -65,10 +62,8 @@ def _get_context_features(self, context_identifiers: Iterable[str]) -> Optional[
def _get_drug_features(self, drug_identifiers: Iterable[str]) -> Optional[torch.FloatTensor]:
"""Get the global drug features as a matrix.

Args:
drug_identifiers: The drug identifiers of interest.
Returns:
drug_features: The matrix of drug features.
:param drug_identifiers: The drug identifiers of interest.
:returns: The matrix of drug features.
"""
if not self.drug_features or self.drug_feature_set is None:
return None
Expand All @@ -77,10 +72,8 @@ def _get_drug_features(self, drug_identifiers: Iterable[str]) -> Optional[torch.
def _get_drug_molecules(self, drug_identifiers: Iterable[str]) -> Optional[PackedGraph]:
"""Get the molecular structure of drugs.

Args:
drug_identifiers: The drug identifiers of interest.
Returns:
molecules: The molecules diagonally batched together for message passing.
:param drug_identifiers: The drug identifiers of interest.
:returns: The molecules diagonally batched together for message passing.
"""
if not self.drug_molecules or self.drug_feature_set is None:
return None
Expand All @@ -90,21 +83,17 @@ def _get_drug_molecules(self, drug_identifiers: Iterable[str]) -> Optional[Packe
def _transform_labels(cls, labels: Sequence[float]) -> torch.FloatTensor:
"""Transform the labels from a chunk of the labeled triples frame.

Args:
labels: The drug pair binary labels.
Returns:
labels : The label target vector as a column vector.
:param labels: The drug pair binary labels.
:returns: The label target vector as a column vector.
"""
return torch.FloatTensor(np.array(labels).reshape(-1, 1))

def generate_batch(self, batch_frame: pd.DataFrame) -> DrugPairBatch:
"""
Generate a batch of drug features, molecules, context features and labels for a set of pairs.

Args:
batch_frame (pd.DataFrame): The labeled pairs of interest.
Returns:
batch (DrugPairBatch): A batch of tensors for the pairs.
:param batch_frame: The labeled pairs of interest.
:returns: A batch of tensors for the pairs.
"""
drug_features_left = self._get_drug_features(batch_frame["drug_1"])
drug_molecules_left = self._get_drug_molecules(batch_frame["drug_1"])
Expand Down
6 changes: 2 additions & 4 deletions chemicalx/data/contextfeatureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ def from_dict(cls, data: Mapping[str, Sequence[float]]) -> "ContextFeatureSet":
def get_feature_matrix(self, contexts: Iterable[str]) -> torch.FloatTensor:
"""Get the feature matrix for a list of contexts.

Args:
contexts: A list of context identifiers.
Return:
features: A matrix of context features.
:param contexts: A list of context identifiers.
:returns: A matrix of context features.
"""
return torch.cat([self.data[context] for context in contexts])
94 changes: 27 additions & 67 deletions chemicalx/data/datasetloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,14 @@ def get_generator(
) -> BatchGenerator:
"""Initialize a batch generator.

Args:
batch_size: Number of drug pairs per batch.
context_features: Indicator whether the batch should include biological context features.
drug_features: Indicator whether the batch should include drug features.
drug_molecules: Indicator whether the batch should include drug molecules
labels: Indicator whether the batch should include drug pair labels.
labeled_triples: A labeled triples object used to generate batches. If none is given, will use
all triples from the dataset.
:param batch_size: Number of drug pairs per batch.
:param context_features: Indicator whether the batch should include biological context features.
:param drug_features: Indicator whether the batch should include drug features.
:param drug_molecules: Indicator whether the batch should include drug molecules.
:param labeled_triples:
A labeled triples object used to generate batches. If none is given,
all triples from the dataset are used.
:returns: A batch generator
"""
return BatchGenerator(
batch_size=batch_size,
Expand All @@ -96,12 +96,7 @@ def get_generator(

@abstractmethod
def get_context_features(self) -> ContextFeatureSet:
"""
Get the context feature set.

Returns:
: The ContextFeatureSet of the dataset of interest.
"""
"""Get the context feature set."""

@property
def num_contexts(self) -> int:
Expand All @@ -115,12 +110,7 @@ def context_channels(self) -> int:

@abstractmethod
def get_drug_features(self):
"""
Get the drug feature set.

Returns:
: The DrugFeatureSet of the dataset of interest.
"""
"""Get the drug feature set."""

@property
def num_drugs(self) -> int:
Expand All @@ -134,12 +124,7 @@ def drug_channels(self) -> int:

@abstractmethod
def get_labeled_triples(self) -> LabeledTriples:
"""
Get the labeled triples file from the storage.

Returns:
: The labeled triples in the dataset.
"""
"""Get the labeled triples file from the storage."""

@property
def num_labeled_triples(self) -> int:
Expand Down Expand Up @@ -168,73 +153,53 @@ class RemoteDatasetLoader(DatasetLoader):
def __init__(self, dataset_name: str):
"""Instantiate the dataset loader.

Args:
dataset_name (str): The name of the dataset.
:param dataset_name: The name of the dataset.
"""
self.base_url = "https://raw.githubusercontent.com/AstraZeneca/chemicalx/main/dataset"
self.dataset_name = dataset_name
assert dataset_name in ["drugcombdb", "drugcomb", "twosides", "drugbankddi"]

def generate_path(self, file_name: str) -> str:
"""
Generate a complete url for a dataset file.
"""Generate a complete URL for a dataset file.

Args:
file_name (str): Name of the data file.
Returns:
data_path (str): The complete url to the dataset.
:param file_name: Name of the data file.
:returns: The complete URL of the dataset file.
"""
data_path = "/".join([self.base_url, self.dataset_name, file_name])
return data_path

def load_raw_json_data(self, path: str) -> Dict:
"""
Load a raw JSON dataset at the given path.
"""Load a raw JSON dataset at the given path.

Args:
path (str): The path to the JSON file.
Returns:
raw_data (dict): A dictionary with the data.
:param path: The path to the JSON file.
:returns: A dictionary with the data.
"""
with urllib.request.urlopen(path) as url:
raw_data = json.loads(url.read().decode())
return raw_data

def load_raw_csv_data(self, path: str) -> pd.DataFrame:
"""
Load a CSV dataset at the given path.
"""Load a CSV dataset at the given path.

Args:
path (str): The path to the triples CSV file.
Returns:
raw_data (pd.DataFrame): A pandas DataFrame with the data.
:param path: The path to the triples CSV file.
:returns: A pandas DataFrame with the data.
"""
data_bytes = urllib.request.urlopen(path).read()
types = {"drug_1": str, "drug_2": str, "context": str, "label": float}
raw_data = pd.read_csv(io.BytesIO(data_bytes), encoding="utf8", sep=",", dtype=types)
return raw_data

@lru_cache(maxsize=1)
def get_context_features(self):
"""
Get the context feature set.

Returns:
: The ContextFeatureSet of the dataset of interest.
"""
def get_context_features(self) -> ContextFeatureSet:
"""Get the context feature set."""
path = self.generate_path("context_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {k: torch.FloatTensor(np.array(v).reshape(1, -1)) for k, v in raw_data.items()}
return ContextFeatureSet(raw_data)

@lru_cache(maxsize=1)
def get_drug_features(self):
"""
Get the drug feature set.

Returns:
: The DrugFeatureSet of the dataset of interest.
"""
def get_drug_features(self) -> DrugFeatureSet:
"""Get the drug feature set."""
path = self.generate_path("drug_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {
Expand All @@ -244,13 +209,8 @@ def get_drug_features(self):
return DrugFeatureSet.from_dict(raw_data)

@lru_cache(maxsize=1)
def get_labeled_triples(self):
"""
Get the labeled triples file from the storage.

Returns:
: The labeled triples in the dataset.
"""
def get_labeled_triples(self) -> LabeledTriples:
"""Get the labeled triples file from the storage."""
path = self.generate_path("labeled_triples.csv")
df = self.load_raw_csv_data(path)
return LabeledTriples(df)
Expand Down
12 changes: 4 additions & 8 deletions chemicalx/data/drugfeatureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,15 @@ def from_dict(cls, data: Dict[str, Dict]) -> "DrugFeatureSet":
def get_feature_matrix(self, drugs: Iterable[str]) -> torch.FloatTensor:
"""Get the drug feature matrix for a list of drugs.

Args:
drugs: A list of drug identifiers.
Return:
: A matrix of drug features.
:param drugs: A list of drug identifiers.
:returns: A matrix of drug features.
"""
return torch.cat([self.data[drug]["features"] for drug in drugs])

def get_molecules(self, drugs: Iterable[str]) -> PackedGraph:
"""Get the molecular structures.

Args:
drugs: A list of drug identifiers.
Return:
: The molecules batched together for message passing.
:param drugs: A list of drug identifiers.
:returns: The molecules batched together for message passing.
"""
return Graph.pack([self.data[drug]["molecule"] for drug in drugs])
Loading