Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: track feature missingness rates #335

Merged
merged 7 commits
Dec 13, 2023
47 changes: 42 additions & 5 deletions src/otg/dataset/l2g_feature_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,25 @@

@dataclass
class L2GFeatureMatrix(Dataset):
"""Dataset with features for Locus to Gene prediction."""
"""Dataset with features for Locus to Gene prediction.

Attributes:
features_list (list[str] | None): List of features to use. If None, all possible features are used.
"""

features_list: list[str] | None = None

def __post_init__(self: L2GFeatureMatrix) -> None:
    """Derive the feature list when one was not supplied.

    Every dataframe column that is not one of the fixed identifier/label
    columns is treated as a feature.
    """
    if not self.features_list:
        non_feature_cols = {"studyLocusId", "geneId", "goldStandardSet"}
        self.features_list = [
            column for column in self._df.columns if column not in non_feature_cols
        ]

@classmethod
def generate_features(
cls: Type[L2GFeatureMatrix],
features_list: list[str],
study_locus: StudyLocus,
study_index: StudyIndex,
variant_gene: V2G,
Expand All @@ -34,6 +48,7 @@ def generate_features(
"""Generate features from the OTG datasets.

Args:
features_list (list[str]): List of features to generate
study_locus (StudyLocus): Study locus dataset
study_index (StudyIndex): Study index dataset
variant_gene (V2G): Variant to gene dataset
Expand Down Expand Up @@ -65,6 +80,7 @@ def generate_features(
fm, ["studyLocusId", "geneId"], "featureName", "featureValue"
),
_schema=cls.get_schema(),
features_list=features_list,
)
raise ValueError("L2G Feature matrix is empty")

Expand All @@ -77,6 +93,26 @@ def get_schema(cls: type[L2GFeatureMatrix]) -> StructType:
"""
return parse_spark_schema("l2g_feature_matrix.json")

def calculate_feature_missingness_rate(
    self: L2GFeatureMatrix,
) -> dict[str, float]:
    """Calculate the proportion of missing values in each feature.

    Returns:
        dict[str, float]: Dictionary of feature names and their missingness rate.

    Raises:
        ValueError: If no features are found, or if the feature matrix has no
            rows (a rate is undefined for an empty dataset).
    """
    # Validate before triggering any Spark action: raising early avoids a
    # wasted full-table count when the feature list is empty.
    if not self.features_list:
        raise ValueError("No features found")

    total_count = self._df.count()
    # Guard the division: an empty matrix previously surfaced as an opaque
    # ZeroDivisionError.
    if total_count == 0:
        raise ValueError("L2G Feature matrix is empty")

    # NOTE(review): one Spark job per feature; acceptable for a small feature
    # set, but could be collapsed into a single aggregation if it gets slow.
    return {
        feature: (self._df.filter(self._df[feature].isNull()).count() / total_count)
        for feature in self.features_list
    }

def fill_na(
self: L2GFeatureMatrix, value: float = 0.0, subset: list[str] | None = None
) -> L2GFeatureMatrix:
Expand All @@ -93,18 +129,19 @@ def fill_na(
return self

def select_features(
    self: L2GFeatureMatrix, features_list: list[str] | None
) -> L2GFeatureMatrix:
    """Select a subset of features from the feature matrix.

    Args:
        features_list (list[str] | None): List of features to select. When
            None, falls back to the instance's own ``features_list``.

    Returns:
        L2GFeatureMatrix: L2G feature matrix dataset

    Raises:
        ValueError: If no features are available to select.
    """
    features_list = features_list or self.features_list
    # Previously a double-None fallback crashed with a TypeError when the
    # None was concatenated with the fixed columns (hidden by a
    # `# type: ignore`); fail with a clear message instead.
    if not features_list:
        raise ValueError("No features to select")
    fixed_cols = ["studyLocusId", "geneId", "goldStandardSet"]
    self.df = self._df.select(fixed_cols + features_list)
    return self

def train_test_split(
Expand Down
8 changes: 6 additions & 2 deletions src/otg/dataset/l2g_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_schema(cls: type[L2GPrediction]) -> StructType:
def from_credible_set(
cls: Type[L2GPrediction],
model_path: str,
features_list: list[str],
study_locus: StudyLocus,
study_index: StudyIndex,
v2g: V2G,
Expand All @@ -53,6 +54,7 @@ def from_credible_set(

Args:
model_path (str): Path to the fitted model
features_list (list[str]): List of features to use for the model
study_locus (StudyLocus): Study locus dataset
study_index (StudyIndex): Study index dataset
v2g (V2G): Variant to gene dataset
Expand All @@ -61,6 +63,7 @@ def from_credible_set(
L2GPrediction: L2G dataset
"""
fm = L2GFeatureMatrix.generate_features(
features_list=features_list,
study_locus=study_locus,
study_index=study_index,
variant_gene=v2g,
Expand All @@ -71,8 +74,9 @@ def from_credible_set(
_df=(
LocusToGeneModel.load_from_disk(
model_path,
features_list=fm.df.drop("studyLocusId", "geneId").columns,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

much cleaner

).predict(fm)
features_list=features_list,
)
.predict(fm)
# the probability of the positive class is the second element inside the probability array
# - this is selected as the L2G probability
.select(
Expand Down
2 changes: 2 additions & 0 deletions src/otg/l2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __post_init__(self: LocusToGeneStep) -> None:
)

fm = L2GFeatureMatrix.generate_features(
features_list=self.features_list,
study_locus=credible_set,
study_index=studies,
variant_gene=v2g,
Expand Down Expand Up @@ -193,6 +194,7 @@ def __post_init__(self: LocusToGeneStep) -> None:
)
predictions = L2GPrediction.from_credible_set(
self.model_path,
self.features_list,
credible_set,
studies,
v2g,
Expand Down
34 changes: 8 additions & 26 deletions src/otg/method/l2g/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,19 @@ def log_to_wandb(
wandb_evaluator.evaluate(results)
## Track feature importance
wandb_run.log({"importances": self.get_feature_importance()})
## Track training set metadata
## Track training set
training_table = wandb.Table(dataframe=training_data.df.toPandas())
wandb_run.log({"trainingSet": training_table})
# Count number of positive and negative labels
gs_counts_dict = {
"goldStandard" + row["goldStandardSet"].capitalize(): row["count"]
for row in training_data.df.groupBy("goldStandardSet").count().collect()
}
wandb_run.log(gs_counts_dict)
training_table = wandb.Table(dataframe=training_data.df.toPandas())
wandb_run.log({"trainingSet": wandb.Table(dataframe=training_table)})
# Missingness rates
wandb_run.log(
{"missingnessRates": training_data.calculate_feature_missingness_rate()}
)

@classmethod
def load_from_disk(
Expand Down Expand Up @@ -218,30 +223,7 @@ def evaluate(
labelCol="label", predictionCol="prediction"
)

print("Evaluating model...") # noqa: T201
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice cleanup

print( # noqa: T201
"... Area under ROC curve:",
binary_evaluator.evaluate(
results, {binary_evaluator.metricName: "areaUnderROC"}
),
)
print( # noqa: T201
"... Area under Precision-Recall curve:",
binary_evaluator.evaluate(
results, {binary_evaluator.metricName: "areaUnderPR"}
),
)
print( # noqa: T201
"... Accuracy:",
multi_evaluator.evaluate(results, {multi_evaluator.metricName: "accuracy"}),
)
print( # noqa: T201
"... F1 score:",
multi_evaluator.evaluate(results, {multi_evaluator.metricName: "f1"}),
)

if wandb_run_name and training_data:
print("Logging to W&B...") # noqa: T201
run = wandb.init(
project=self.wandb_l2g_project_name,
config=hyperparameters,
Expand Down