Skip to content

Commit

Permalink
Add option to evaluate different n_knots (#49)
Browse files Browse the repository at this point in the history
* Add method for evaluating diffent choices of n_knots.

Method that evaluates different choices for n_knots by fitting GAMs
on a subset of the genes for all given options for n_knots.
The different GAMs are evaluated by comparing the AICs.
Additionally, there is the possibility to plot different statistics
of the AIC.

Adds method returning the AIC of a GAM

---------

Co-authored-by: Philipp Weiler <weiler.philipp@gmail.com>
  • Loading branch information
KemperNiklas and WeilerP authored Mar 15, 2023
1 parent 485ed09 commit 8f212f7
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ anndata
rpy2
scipy
conorm
seaborn
9 changes: 9 additions & 0 deletions tradeseq/gam/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, gam):
"""
self._gam = gam
self.covariance_matrix: np.ndarray = _get_covariance_matrix(gam)
self.aic = _get_aic(gam)[0]

def predict(
self,
Expand Down Expand Up @@ -80,6 +81,14 @@ def _get_covariance_matrix(gam) -> np.ndarray:
return covariance


def _get_aic(gam) -> int:
np_cv_rules = default_converter + numpy2ri.converter + pandas2ri.converter
with localconverter(np_cv_rules):
ro.globalenv["gam"] = gam
aic = ro.r("gam$aic")
return aic


def _assign_pseudotimes(pseudotimes: np.ndarray):
"""Assign pseudotimes in R.
Expand Down
110 changes: 110 additions & 0 deletions tradeseq/gam/_gam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from anndata import AnnData
from conorm import tmm_norm_factors
from scipy.sparse import issparse
Expand Down Expand Up @@ -183,6 +184,17 @@ def get_covariance(self, gene_id: int) -> np.ndarray:

return self._model[gene_id].covariance_matrix

def get_aic(self) -> List[float]:
"""Get Akaike information criterion (AIC) for each fitted GAM.
Returns
-------
List of AICs
"""
self.check_is_fitted()

return [model.aic for model in self._model]

def plot(
self,
gene_id: int,
Expand Down Expand Up @@ -557,6 +569,104 @@ def fit(self, family: str = "nb", n_knots: int = 6, n_jobs: int = 1):
n_jobs,
)

def evaluate_n_knots(
self,
n_knots_options: List[int],
family: str = "nb",
n_vars: int = 500,
n_jobs: int = 1,
plot: bool = True,
) -> pd.DataFrame:
"""Evaluate different choices for number of knots.
Parameters
----------
n_knots_options
List of different options for number of knots (usual choices for are between 3 and 10).
family
Family of probability distributions that is used for fitting the GAM. Defaults to the negative binomial
distributions. Can be any family available in mgcv.gam.
n_vars
Number of randomly sampled genes that are used for the evaluation.
n_jobs
Number of jobs that are used for fitting. If n_jobs > 2, the R library biocParallel is used for fitting the
GAMs in parallel.
plot
Boolean indicating whether plots evaluating the different choices for number of knots should be shown.
Returns
-------
Pandas DataFrame containing AIC of the sampled genes for the different choices for n_knots and the mean AIC,
the relative mean AIC and the number of knots that have the optimal AIC for this value of n_knots.
"""
if any(n_knots < 3 for n_knots in n_knots_options):
raise RuntimeError(
"Cannot fit with fewer than 3 knots, please increase the number of knots."
)

aic = []
var_ind_sample = np.random.randint(0, self._adata.n_vars, size=(n_vars,))
gam = GAM(
self._adata[:, var_ind_sample],
self._n_lineages,
self._time_key,
self._weights_key,
self._offset_key,
self._layer_key,
)

for n_knots in n_knots_options:
gam.fit(family, n_knots, n_jobs)
aic.append(gam.get_aic())

var_names = self._adata.var_names[var_ind_sample]
result = pd.DataFrame(aic, index=n_knots_options, columns=var_names)
result["Number of knots"] = n_knots_options
result["Mean AIC"] = result[var_names].mean(axis=1)
result["Mean Relative AIC"] = (
result[var_names] / result[var_names].iloc[0]
).mean(axis=1)
result["Number of Genes with optimal n_knots"] = (
result[var_names] == result[var_names].min(axis=0)
).sum(axis=1)

if plot:
fig, axs = plt.subplots(ncols=4)

sns.boxplot(
data=(result[var_names] - result[var_names].mean(axis=0)).T, ax=axs[0]
)
axs[0].set_xlabel("Number of knots")
axs[0].set_ylabel("Deviation from gene-wise average AIC")

sns.scatterplot(data=result, x="Number of knots", y="Mean AIC", ax=axs[1])
sns.lineplot(data=result, x="Number of knots", y="Mean AIC", ax=axs[1])

sns.scatterplot(
data=result, x="Number of knots", y="Mean Relative AIC", ax=axs[2]
)
sns.lineplot(
data=result, x="Number of knots", y="Mean Relative AIC", ax=axs[2]
)

sns.scatterplot(
data=result,
x="Number of knots",
y="Number of Genes with optimal n_knots",
ax=axs[3],
)
sns.lineplot(
data=result,
x="Number of knots",
y="Number of Genes with optimal n_knots",
ax=axs[3],
)

fig.tight_layout(pad=3.0)
plt.show()

return result


def _indices_to_indicator_matrix(indices: np.ndarray, n_indices: int):
"""Compute indicator matrice from indices.
Expand Down

0 comments on commit 8f212f7

Please sign in to comment.