Skip to content

Commit

Permalink
Merge branch 'master' into 663-move-cnv-frequencies
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbrenas authored Dec 2, 2024
2 parents 8c7b2c6 + 1ba4714 commit b174fd6
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 15 deletions.
12 changes: 11 additions & 1 deletion malariagen_data/anoph/frq_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Parameter definitions for functions computing and plotting allele frequencies."""

from typing import Literal
from typing import Literal, List, Optional, Tuple, Union

import xarray as xr
from typing_extensions import Annotated, TypeAlias
Expand Down Expand Up @@ -70,3 +70,13 @@
bool,
"Include columns with allele counts and number of non-missing allele calls (nobs).",
]

taxa: TypeAlias = Annotated[
Optional[Union[str, List[str], Tuple[str, ...]]],
"The taxon or taxa to restrict the dataset to.",
]

areas: TypeAlias = Annotated[
Optional[Union[str, List[str], Tuple[str, ...]]],
"The area or areas to restrict the dataset to.",
]
14 changes: 14 additions & 0 deletions malariagen_data/anoph/snp_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,8 @@ def plot_frequencies_time_series(
legend_sizing: plotly_params.legend_sizing = "constant",
show: plotly_params.show = True,
renderer: plotly_params.renderer = None,
taxa: frq_params.taxa = None,
areas: frq_params.areas = None,
**kwargs,
) -> plotly_params.figure:
# Handle title.
Expand All @@ -947,6 +949,18 @@ def plot_frequencies_time_series(
df_cohorts = ds[cohort_vars].to_dataframe()
df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore

# If specified, restrict the dataframe by taxa.
if isinstance(taxa, str):
df_cohorts = df_cohorts[df_cohorts["taxon"] == taxa]
elif isinstance(taxa, (list, tuple)):
df_cohorts = df_cohorts[df_cohorts["taxon"].isin(taxa)]

# If specified, restrict the dataframe by areas.
if isinstance(areas, str):
df_cohorts = df_cohorts[df_cohorts["area"] == areas]
elif isinstance(areas, (list, tuple)):
df_cohorts = df_cohorts[df_cohorts["area"].isin(areas)]

# Extract variant labels.
variant_labels = ds["variant_label"].values

Expand Down
66 changes: 52 additions & 14 deletions notebooks/plot_frequencies_space_time.ipynb
Original file line number Diff line number Diff line change
@@ -1,14 +1,22 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "47f669f3",
"metadata": {},
"outputs": [],
"source": [
"import malariagen_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f820bc66-2fb2-4ca2-9b54-824e50d61a0a",
"metadata": {},
"outputs": [],
"source": [
"import malariagen_data\n",
"\n",
"ag3 = malariagen_data.Ag3(\n",
" \"simplecache::gs://vo_agam_release_master_us_central1\",\n",
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
Expand All @@ -23,8 +31,6 @@
"metadata": {},
"outputs": [],
"source": [
"import malariagen_data\n",
"\n",
"af1 = malariagen_data.Af1(\n",
" \"simplecache::gs://vo_afun_release_master_us_central1\",\n",
" simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
Expand Down Expand Up @@ -69,6 +75,26 @@
"ag3.plot_frequencies_time_series(ds, height=500, width=1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "790c99e8",
"metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, taxa=\"gambiae\", height=500, width=1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1bfc7298",
"metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, taxa=(\"gambiae\", \"arabiensis\"), height=500, width=1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -252,6 +278,26 @@
"ag3.plot_frequencies_time_series(ds, height=900, width=900)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e16ab3fe",
"metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, areas=\"BF-09\", height=400, width=900)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "26af27a1",
"metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, areas=(\"BF-09\", \"TZ-25\"), height=400, width=900)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -336,19 +382,11 @@
"source": [
"af1.plot_frequencies_interactive_map(ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a512b459",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "mgen_data_py3.11",
"language": "python",
"name": "python3"
},
Expand All @@ -362,7 +400,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.5"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
94 changes: 94 additions & 0 deletions tests/anoph/test_snp_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,6 +1528,100 @@ def test_plot_frequencies_time_series(
assert isinstance(fig, go.Figure)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_time_series_with_taxa(
fixture,
api: AnophelesSnpFrequencyAnalysis,
):
# Pick test parameters at random.
all_sample_sets = api.sample_sets()["sample_set"].to_list()
sample_sets = random.choice(all_sample_sets)
site_mask = random.choice(api.site_mask_ids + (None,))
transcript = random_transcript(api=api).name
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
period_by = random.choice(["year", "quarter", "month"])

# Pick a random taxon and taxa from valid taxa.
sample_sets_taxa = (
api.sample_metadata(sample_sets=sample_sets)["taxon"].dropna().unique().tolist()
)
taxon = random.choice(sample_sets_taxa)
taxa = random.sample(sample_sets_taxa, random.randint(1, len(sample_sets_taxa)))

# Compute SNP frequencies.
ds = api.snp_allele_frequencies_advanced(
transcript=transcript,
area_by=area_by,
period_by=period_by,
sample_sets=sample_sets,
min_cohort_size=1, # Don't exclude any samples.
site_mask=site_mask,
)

# Trim things down a bit for speed.
ds = ds.isel(variants=slice(0, 100))

# Plot with taxon.
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)

# Test taxon plot.
assert isinstance(fig, go.Figure)

# Plot with taxa.
fig = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)

# Test taxa plot.
assert isinstance(fig, go.Figure)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_time_series_with_areas(
fixture,
api: AnophelesSnpFrequencyAnalysis,
):
# Pick test parameters at random.
all_sample_sets = api.sample_sets()["sample_set"].to_list()
sample_sets = random.choice(all_sample_sets)
site_mask = random.choice(api.site_mask_ids + (None,))
transcript = random_transcript(api=api).name
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
period_by = random.choice(["year", "quarter", "month"])

# Compute SNP frequencies.
ds = api.snp_allele_frequencies_advanced(
transcript=transcript,
area_by=area_by,
period_by=period_by,
sample_sets=sample_sets,
min_cohort_size=1, # Don't exclude any samples.
site_mask=site_mask,
)

# Trim things down a bit for speed.
ds = ds.isel(variants=slice(0, 100))

# Extract cohorts into a DataFrame.
cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
df_cohorts = ds[cohort_vars].to_dataframe()

# Pick a random area and areas from valid areas.
cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
area = random.choice(cohorts_areas)
areas = random.sample(cohorts_areas, random.randint(1, len(cohorts_areas)))

# Plot with area.
fig = api.plot_frequencies_time_series(ds, show=False, areas=area)

# Test areas plot.
assert isinstance(fig, go.Figure)

# Plot with areas.
fig = api.plot_frequencies_time_series(ds, show=False, areas=areas)

# Test area plot.
assert isinstance(fig, go.Figure)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_interactive_map(
fixture,
Expand Down

0 comments on commit b174fd6

Please sign in to comment.