Revise the Docstrings
henrykironde committed Aug 9, 2024
1 parent 3f777a1 commit adc89ac
Showing 14 changed files with 286 additions and 296 deletions.
docs/dataset_structure.md (1 change: 1 addition & 0 deletions)

```diff
@@ -1,2 +1,3 @@
 # Dataset structure
+The organization of this dataset was inspired by the WILDS benchmark
 
```
milliontrees/common/data_loaders.py (16 changes: 8 additions & 8 deletions)

```diff
@@ -12,8 +12,8 @@ def get_train_loader(loader,
                      distinct_groups=True,
                      n_groups_per_batch=None,
                      **loader_kwargs):
-    """
-    Constructs and returns the data loader for training.
+    """Constructs and returns the data loader for training.
+
     Args:
         - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders,
           which first samples groups and then samples a fixed number of examples belonging
@@ -87,8 +87,8 @@ def get_train_loader(loader,


 def get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs):
-    """
-    Constructs and returns the data loader for evaluation.
+    """Constructs and returns the data loader for evaluation.
+
     Args:
         - loader (str): Loader type. 'standard' for standard loaders.
         - dataset (milliontreesDataset or milliontreesSubset): Data
@@ -108,10 +108,10 @@ def get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs):


 class GroupSampler:
-    """
-    Constructs batches by first sampling groups,
-    then sampling data from those groups.
-    It drops the last batch if it's incomplete.
+    """Constructs batches by first sampling groups, then sampling data from
+    those groups.
+
+    It drops the last batch if it's incomplete.
     """

     def __init__(self, group_ids, batch_size, n_groups_per_batch,
```
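The 'group' loader described in these docstrings samples groups first and then a fixed number of examples from each sampled group, dropping any incomplete final batch. A minimal, self-contained sketch of that sampling scheme, for illustration only: this is not the package's GroupSampler implementation, and the helper name `group_batches` is invented here.

```python
import torch

def group_batches(group_ids, batch_size, n_groups_per_batch, n_batches):
    """Yield index batches: sample groups first, then examples within them."""
    assert batch_size % n_groups_per_batch == 0, \
        "batch size must split evenly across the sampled groups"
    n_per_group = batch_size // n_groups_per_batch
    unique_groups = torch.unique(group_ids)
    for _ in range(n_batches):
        # Sample distinct groups for this batch.
        perm = torch.randperm(len(unique_groups))[:n_groups_per_batch]
        batch = []
        for g in unique_groups[perm]:
            members = (group_ids == g).nonzero(as_tuple=True)[0]
            # Sample with replacement so small groups can still fill their quota.
            picks = members[torch.randint(len(members), (n_per_group,))]
            batch.append(picks)
        yield torch.cat(batch)

group_ids = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2])
for idx in group_batches(group_ids, batch_size=4, n_groups_per_batch=2, n_batches=3):
    print(idx.tolist())  # e.g. [7, 5, 3, 4]: two examples each from two groups
```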
milliontrees/common/grouper.py (12 changes: 5 additions & 7 deletions)

```diff
@@ -10,20 +10,18 @@


 class Grouper:
-    """
-    Groupers group data points together based on their metadata.
-    They are used for training and evaluation,
-    e.g., to measure the accuracies of different groups of data.
+    """Groupers group data points together based on their metadata.
+
+    They are used for training and evaluation, e.g., to measure the
+    accuracies of different groups of data.
     """

     def __init__(self):
         raise NotImplementedError

     @property
     def n_groups(self):
-        """
-        The number of groups defined by this Grouper.
-        """
+        """The number of groups defined by this Grouper."""
         return self._n_groups

     def metadata_to_group(self, metadata, return_counts=False):
```
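Since Grouper.__init__ raises NotImplementedError, the class is an abstract interface: subclasses supply `_n_groups` and implement `metadata_to_group`. A toy subclass to make that contract concrete; `SingleColumnGrouper` is invented for this sketch, and the real package presumably ships its own subclasses (e.g. a WILDS-style combinatorial grouper).

```python
import torch
from milliontrees.common.grouper import Grouper

class SingleColumnGrouper(Grouper):
    """Groups data points by one integer-valued metadata column."""

    def __init__(self, metadata_column, n_groups):
        # Intentionally skip super().__init__(), which raises NotImplementedError.
        self._column = metadata_column
        self._n_groups = n_groups  # read back by the inherited n_groups property

    def metadata_to_group(self, metadata, return_counts=False):
        groups = metadata[:, self._column].long()
        if return_counts:
            return groups, torch.bincount(groups, minlength=self._n_groups)
        return groups

metadata = torch.tensor([[0, 3], [1, 3], [2, 5]])  # column 1 = e.g. a site id
grouper = SingleColumnGrouper(metadata_column=1, n_groups=6)
print(grouper.n_groups)                     # 6
print(grouper.metadata_to_group(metadata))  # tensor([3, 3, 5])
```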
milliontrees/common/metrics/all_metrics.py (20 changes: 9 additions & 11 deletions)

```diff
@@ -25,10 +25,8 @@ def binary_logits_to_score(logits):


 def multiclass_logits_to_pred(logits):
-    """
-    Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions
-    by taking an argmax at the last dimension
-    """
+    """Takes multi-class logits of size (batch_size, ..., n_classes) and
+    returns predictions by taking an argmax at the last dimension."""
     assert logits.dim() > 1
     return logits.argmax(-1)

@@ -326,7 +324,8 @@ def __init__(self, name=None):


 class PrecisionAtRecall(Metric):
-    """Given a specific model threshold, determine the precision score achieved"""
+    """Given a specific model threshold, determine the precision score
+    achieved."""

     def __init__(self, threshold, score_fn=None, name=None):
         self.score_fn = score_fn
@@ -346,8 +345,9 @@ def worst(self, metrics):


 class DummyMetric(Metric):
-    """
-    For testing purposes. This Metric always returns -1.
+    """For testing purposes.
+
+    This Metric always returns -1.
     """

     def __init__(self, prediction_fn=None, name=None):
@@ -370,10 +370,8 @@ def worst(self, metrics):


 class DetectionAccuracy(ElementwiseMetric):
-    """
-    Given a specific Intersection over union threshold,
-    determine the accuracy achieved for a one-class detector
-    """
+    """Given a specific Intersection over union threshold, determine the
+    accuracy achieved for a one-class detector."""

     def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None):
         self.iou_threshold = iou_threshold
```
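Two of the revised docstrings in runnable form. The argmax snippet mirrors multiclass_logits_to_pred exactly as shown in the diff; the IoU helper is a generic sketch of the quantity DetectionAccuracy thresholds at `iou_threshold` (0.5 by default), not the package's internal implementation.

```python
import torch

logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 0.9, 0.5]])  # (batch_size, n_classes)
print(logits.argmax(-1))                  # tensor([0, 1]): one prediction per row

def box_iou(a, b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    union = ((a[2] - a[0]) * (a[3] - a[1]) +
             (b[2] - b[0]) * (b[3] - b[1]) - inter)
    return inter / union

# A predicted tree box counts as a detection only when IoU clears the threshold.
print(box_iou([0, 0, 10, 10], [5, 5, 15, 15]))  # ~0.143, below the 0.5 default
```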
milliontrees/common/metrics/loss.py (25 changes: 15 additions & 10 deletions)

```diff
@@ -12,8 +12,9 @@ def __init__(self, loss_fn, name=None):
         super().__init__(name=name)

     def _compute(self, y_pred, y_true):
-        """
-        Helper for computing element-wise metric, implemented for each metric
+        """Helper for computing element-wise metric, implemented for each
+        metric.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
@@ -23,8 +24,9 @@ def _compute(self, y_pred, y_true):
         return self.loss_fn(y_pred, y_true)

     def worst(self, metrics):
-        """
-        Given a list/numpy array/Tensor of metrics, computes the worst-case metric
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
+
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
         Output:
@@ -42,8 +44,9 @@ def __init__(self, loss_fn, name=None):
         super().__init__(name=name)

     def _compute_element_wise(self, y_pred, y_true):
-        """
-        Helper for computing element-wise metric, implemented for each metric
+        """Helper for computing element-wise metric, implemented for each
+        metric.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
@@ -53,8 +56,9 @@ def _compute_element_wise(self, y_pred, y_true):
         return self.loss_fn(y_pred, y_true)

     def worst(self, metrics):
-        """
-        Given a list/numpy array/Tensor of metrics, computes the worst-case metric
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
+
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
         Output:
@@ -81,8 +85,9 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true):
         return flattened_loss

     def worst(self, metrics):
-        """
-        Given a list/numpy array/Tensor of metrics, computes the worst-case metric
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
+
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
         Output:
```
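These wrappers adapt a plain loss function to the Metric interface; keeping per-element values is what makes group-wise and worst-case reporting possible. A hedged usage sketch: the class name ElementwiseLoss and its inheritance from ElementwiseMetric are assumptions based on the method names in this diff (only `__init__(self, loss_fn, name=None)` and the helpers are visible), as is worst() reducing a loss to its maximum.

```python
import torch
import torch.nn as nn
from milliontrees.common.metrics.loss import ElementwiseLoss  # assumed export

# reduction='none' keeps one loss value per example instead of averaging.
loss = ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none'),
                       name='cross_entropy')

y_pred = torch.randn(8, 3)          # logits for 8 examples, 3 classes
y_true = torch.randint(0, 3, (8,))  # integer labels

# Inherited from ElementwiseMetric (see metric.py below): per-example values.
per_example = loss.compute_element_wise(y_pred, y_true, return_dict=False)
print(per_example.shape)            # torch.Size([8])

# For a loss, the worst case across a set of values would be the maximum.
print(loss.worst(per_example))      # assumption: equivalent to per_example.max()
```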
milliontrees/common/metrics/metric.py (84 changes: 39 additions & 45 deletions)

```diff
@@ -4,17 +4,15 @@


 class Metric:
-    """
-    Parent class for metrics.
-    """
+    """Parent class for metrics."""

     def __init__(self, name):
         self._name = name

     def _compute(self, y_pred, y_true):
-        """
-        Helper function for computing the metric.
-        Subclasses should implement this.
+        """Helper function for computing the metric. Subclasses should
+        implement this.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
@@ -24,8 +22,9 @@ def _compute(self, y_pred, y_true):
         return NotImplementedError

     def worst(self, metrics):
-        """
-        Given a list/numpy array/Tensor of metrics, computes the worst-case metric
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
+
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
         Output:
@@ -35,46 +34,42 @@ def worst(self, metrics):

     @property
     def name(self):
-        """
-        Metric name.
-        Used to name the key in the results dictionaries returned by the metric.
+        """Metric name.
+
+        Used to name the key in the results dictionaries returned by the
+        metric.
         """
         return self._name

     @property
     def agg_metric_field(self):
-        """
-        The name of the key in the results dictionary returned by Metric.compute().
-        This should correspond to the aggregate metric computed on all of y_pred and y_true,
-        in contrast to a group-wise evaluation.
+        """The name of the key in the results dictionary returned by
+        Metric.compute().
+
+        This should correspond to the aggregate metric computed on all
+        of y_pred and y_true, in contrast to a group-wise evaluation.
         """
         return f'{self.name}_all'

     def group_metric_field(self, group_idx):
-        """
-        The name of the keys corresponding to individual group evaluations
-        in the results dictionary returned by Metric.compute_group_wise().
-        """
+        """The name of the keys corresponding to individual group evaluations
+        in the results dictionary returned by Metric.compute_group_wise()."""
         return f'{self.name}_group:{group_idx}'

     @property
     def worst_group_metric_field(self):
-        """
-        The name of the keys corresponding to the worst-group metric
-        in the results dictionary returned by Metric.compute_group_wise().
-        """
+        """The name of the keys corresponding to the worst-group metric in the
+        results dictionary returned by Metric.compute_group_wise()."""
         return f'{self.name}_wg'

     def group_count_field(self, group_idx):
-        """
-        The name of the keys corresponding to each group's count
-        in the results dictionary returned by Metric.compute_group_wise().
-        """
+        """The name of the keys corresponding to each group's count in the
+        results dictionary returned by Metric.compute_group_wise()."""
         return f'count_group:{group_idx}'

     def compute(self, y_pred, y_true, return_dict=True):
-        """
-        Computes metric. This is a wrapper around _compute.
+        """Computes metric. This is a wrapper around _compute.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
@@ -98,8 +93,8 @@ def compute(self, y_pred, y_true, return_dict=True):
         return agg_metric

     def compute_group_wise(self, y_pred, y_true, g, n_groups, return_dict=True):
-        """
-        Computes metrics for each group. This is a wrapper around _compute.
+        """Computes metrics for each group. This is a wrapper around _compute.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
@@ -146,13 +141,12 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups):


 class ElementwiseMetric(Metric):
-    """
-    Averages.
-    """
+    """Averages."""

     def _compute_element_wise(self, y_pred, y_true):
-        """
-        Helper for computing element-wise metric, implemented for each metric
+        """Helper for computing element-wise metric, implemented for each
+        metric.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
@@ -162,8 +156,9 @@ def _compute_element_wise(self, y_pred, y_true):
         raise NotImplementedError

     def worst(self, metrics):
-        """
-        Given a list/numpy array/Tensor of metrics, computes the worst-case metric
+        """Given a list/numpy array/Tensor of metrics, computes the worst-case
+        metric.
+
         Args:
             - metrics (Tensor, numpy array, or list): Metrics
         Output:
@@ -172,8 +167,8 @@ def worst(self, metrics):
         raise NotImplementedError

     def _compute(self, y_pred, y_true):
-        """
-        Helper function for computing the metric.
+        """Helper function for computing the metric.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
@@ -193,14 +188,13 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups):

     @property
     def agg_metric_field(self):
-        """
-        The name of the key in the results dictionary returned by Metric.compute().
-        """
+        """The name of the key in the results dictionary returned by
+        Metric.compute()."""
         return f'{self.name}_avg'

     def compute_element_wise(self, y_pred, y_true, return_dict=True):
-        """
-        Computes element-wise metric
+        """Computes element-wise metric.
+
         Args:
             - y_pred (Tensor): Predicted targets or model output
             - y_true (Tensor): True targets
```
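The field-name helpers above pin down a flat results-dictionary convention, which is worth seeing in one place. The dictionary below is hand-built to show the keys a metric named 'detection_acc' would produce for two groups; it is an illustration, not real output. Note that ElementwiseMetric overrides the aggregate key from `{name}_all` to `{name}_avg`.

```python
name = 'detection_acc'

results = {
    f'{name}_all': 0.71,      # agg_metric_field: aggregate over all examples
    f'{name}_group:0': 0.80,  # group_metric_field(0)
    'count_group:0': 120,     # group_count_field(0)
    f'{name}_group:1': 0.55,  # group_metric_field(1)
    'count_group:1': 45,      # group_count_field(1)
    f'{name}_wg': 0.55,       # worst_group_metric_field: the worst group
}
print(results)
```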
milliontrees/common/utils.py (17 changes: 10 additions & 7 deletions)

```diff
@@ -62,11 +62,11 @@ def split_into_groups(g):


 def get_counts(g, n_groups):
-    """
-    This differs from split_into_groups in how it handles missing groups.
-    get_counts always returns a count Tensor of length n_groups,
-    whereas split_into_groups returns a unique_counts Tensor
-    whose length is the number of unique groups present in g.
+    """This differs from split_into_groups in how it handles missing groups.
+
+    get_counts always returns a count Tensor of length n_groups, whereas
+    split_into_groups returns a unique_counts Tensor whose length is the number
+    of unique groups present in g.
     Args:
         - g (Tensor): Vector of groups
     Returns:
@@ -138,8 +138,11 @@ def shuffle_arr(arr, seed=None):


 def threshold_at_recall(y_pred, y_true, global_recall=60):
-    """ Calculate the model threshold to use to achieve a desired global_recall level. Assumes that
-    y_true is a vector of the true binary labels."""
+    """Calculate the model threshold to use to achieve a desired global_recall
+    level.
+
+    Assumes that y_true is a vector of the true binary labels.
+    """
     return np.percentile(y_pred[y_true == 1], 100 - global_recall)


```
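Because the body of threshold_at_recall appears in full above, a usage example is fully grounded in the diff: with global_recall=60, the threshold is the 40th percentile of the scores the positive examples received, so roughly 60% of positives score at or above it.

```python
import numpy as np

y_pred = np.array([0.9, 0.8, 0.7, 0.4, 0.3, 0.2, 0.6, 0.1])  # model scores
y_true = np.array([1,   1,   1,   1,   1,   0,   0,   0])    # binary labels

threshold = np.percentile(y_pred[y_true == 1], 100 - 60)
print(threshold)  # 0.58 (linear interpolation over [0.3, 0.4, 0.7, 0.8, 0.9])

recall = (y_pred[y_true == 1] >= threshold).mean()
print(recall)     # 0.6: three of the five positives clear the threshold
```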