Skip to content

Commit

Permalink
Tried to move random_transcript to conftest. WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbrenas committed Dec 11, 2024
1 parent 5a57346 commit 06be149
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 253 deletions.
8 changes: 8 additions & 0 deletions tests/anoph/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,6 +1013,14 @@ def contigs(self) -> Tuple[str, ...]:
def random_contig(self):
    """Pick one entry of ``self.contigs`` at random and return it."""
    available = self.contigs
    return choice(available)

def random_transcript_id(self):
    """Return the ID of a randomly chosen mRNA feature.

    Rows of ``self.genome_features`` whose ``type`` is ``"mRNA"`` are
    selected, and the ``ID`` entry is read from each row's
    semicolon-separated ``key=value`` ``attributes`` string. The entry is
    looked up by key rather than by position, so it is found wherever it
    sits in the string; rows without an ``ID`` entry are ignored.

    Returns:
        One of the collected transcript IDs, picked with ``choice``.

    Raises:
        IndexError: If no mRNA row yields an ID (raised by ``choice`` on
            an empty list).
    """
    df_transcripts = self.genome_features.query("type == 'mRNA'")
    transcript_ids = []
    for attributes in df_transcripts.loc[:, "attributes"]:
        for field in attributes.split(";"):
            key, sep, value = field.strip().partition("=")
            # Only a well-formed "ID=<value>" field counts; stop at the first.
            if sep and key == "ID":
                transcript_ids.append(value)
                break
    return choice(transcript_ids)

def random_region_str(self, region_size=None):
contig = self.random_contig()
contig_size = self.contig_sizes[contig]
Expand Down
12 changes: 5 additions & 7 deletions tests/anoph/test_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from malariagen_data import ag3 as _ag3
from malariagen_data.anoph.snp_frq import AnophelesSnpFrequencyAnalysis

from .test_snp_frq import random_transcript


@pytest.fixture
def ag3_sim_api(ag3_sim_fixture):
Expand Down Expand Up @@ -83,7 +81,7 @@ def test_plot_frequencies_heatmap(
sample_sets = random.choice(all_sample_sets)
site_mask = random.choice(api.site_mask_ids + (None,))
min_cohort_size = random.randint(0, 2)
transcript = random_transcript(api=api).name
transcript = fixture.random_transcript_id()
cohorts = random.choice(
["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
)
Expand Down Expand Up @@ -128,7 +126,7 @@ def test_plot_frequencies_time_series(
sample_sets = random.choice(all_sample_sets)
site_mask = random.choice(api.site_mask_ids + (None,))
min_cohort_size = random.randint(0, 2)
transcript = random_transcript(api=api).name
transcript = fixture.random_transcript_id()
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
period_by = random.choice(["year", "quarter", "month"])

Expand Down Expand Up @@ -179,7 +177,7 @@ def test_plot_frequencies_time_series_with_taxa(
all_sample_sets = api.sample_sets()["sample_set"].to_list()
sample_sets = random.choice(all_sample_sets)
site_mask = random.choice(api.site_mask_ids + (None,))
transcript = random_transcript(api=api).name
transcript = fixture.random_transcript_id()
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
period_by = random.choice(["year", "quarter", "month"])

Expand Down Expand Up @@ -225,7 +223,7 @@ def test_plot_frequencies_time_series_with_areas(
all_sample_sets = api.sample_sets()["sample_set"].to_list()
sample_sets = random.choice(all_sample_sets)
site_mask = random.choice(api.site_mask_ids + (None,))
transcript = random_transcript(api=api).name
transcript = fixture.random_transcript_id()
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
period_by = random.choice(["year", "quarter", "month"])

Expand Down Expand Up @@ -276,7 +274,7 @@ def test_plot_frequencies_interactive_map(
sample_sets = random.choice(all_sample_sets)
site_mask = random.choice(api.site_mask_ids + (None,))
min_cohort_size = random.randint(0, 2)
transcript = random_transcript(api=api).name
transcript = fixture.random_transcript_id()
area_by = random.choice(["country", "admin1_iso", "admin2_name"])
period_by = random.choice(["year", "quarter", "month"])

Expand Down
246 changes: 0 additions & 246 deletions tests/anoph/test_snp_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from pytest_cases import parametrize_with_cases
import xarray as xr
from numpy.testing import assert_allclose, assert_array_equal
import plotly.graph_objects as go # type: ignore

from malariagen_data import af1 as _af1
from malariagen_data import ag3 as _ag3
Expand Down Expand Up @@ -1429,248 +1428,3 @@ def test_allele_frequencies_advanced_with_dup_samples(
api=api,
sample_sets=sample_sets,
)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_heatmap(
    fixture,
    api: AnophelesSnpFrequencyAnalysis,
):
    """Check ``plot_frequencies_heatmap`` on SNP and amino acid frequency tables.

    Parameters are drawn at random, both frequency dataframes are plotted,
    and the ``max_len`` and ``index`` arguments are exercised, including the
    cases that are expected to raise ``ValueError``.
    """
    # Pick test parameters at random.
    all_sample_sets = api.sample_sets()["sample_set"].to_list()
    sample_sets = random.choice(all_sample_sets)
    site_mask = random.choice(api.site_mask_ids + (None,))
    min_cohort_size = random.randint(0, 2)
    transcript = random_transcript(api=api).name
    cohorts = random.choice(
        ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
    )

    # Set up call params, shared by the SNP and amino acid computations.
    params = dict(
        transcript=transcript,
        cohorts=cohorts,
        min_cohort_size=min_cohort_size,
        site_mask=site_mask,
        sample_sets=sample_sets,
    )

    # Test SNP allele frequencies.
    df_snp = api.snp_allele_frequencies(**params)
    fig = api.plot_frequencies_heatmap(df_snp, show=False, max_len=None)
    assert isinstance(fig, go.Figure)

    # Test amino acid change allele frequencies.
    df_aa = api.aa_allele_frequencies(**params)
    fig = api.plot_frequencies_heatmap(df_aa, show=False, max_len=None)
    assert isinstance(fig, go.Figure)

    # Test max_len behaviour: a limit one below the number of rows must be
    # rejected.
    with pytest.raises(ValueError):
        api.plot_frequencies_heatmap(df_snp, show=False, max_len=len(df_snp) - 1)

    # Test index parameter - if None, should use dataframe index.
    fig = api.plot_frequencies_heatmap(df_snp, show=False, index=None, max_len=None)
    # Previously this result was never checked; verify a figure comes back.
    assert isinstance(fig, go.Figure)

    # The "contig" column is not unique across rows, so it is rejected as an
    # index.
    with pytest.raises(ValueError):
        api.plot_frequencies_heatmap(df_snp, show=False, index="contig", max_len=None)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_time_series(
    fixture,
    api: AnophelesSnpFrequencyAnalysis,
):
    """Check that time series plots are produced for SNP and amino acid data.

    Randomly chosen parameters feed the two ``*_allele_frequencies_advanced``
    computations; each result is trimmed to its first 100 variants and
    plotted, and the plot must be a plotly ``Figure``.
    """
    # Randomised inputs.
    sample_sets = random.choice(api.sample_sets()["sample_set"].to_list())
    site_mask = random.choice(api.site_mask_ids + (None,))
    min_cohort_size = random.randint(0, 2)
    transcript = random_transcript(api=api).name
    area_by = random.choice(["country", "admin1_iso", "admin2_name"])
    period_by = random.choice(["year", "quarter", "month"])

    # Keyword arguments common to both frequency computations.
    shared = dict(
        transcript=transcript,
        area_by=area_by,
        period_by=period_by,
        sample_sets=sample_sets,
        min_cohort_size=min_cohort_size,
    )

    # SNP frequencies (the only call that also takes the site mask),
    # cut down to the first 100 variants for speed.
    ds_snp = api.snp_allele_frequencies_advanced(site_mask=site_mask, **shared)
    ds_snp = ds_snp.isel(variants=slice(0, 100))
    snp_fig = api.plot_frequencies_time_series(ds_snp, show=False)
    assert isinstance(snp_fig, go.Figure)

    # Amino acid change frequencies, trimmed the same way.
    ds_aa = api.aa_allele_frequencies_advanced(**shared)
    ds_aa = ds_aa.isel(variants=slice(0, 100))
    aa_fig = api.plot_frequencies_time_series(ds_aa, show=False)
    assert isinstance(aa_fig, go.Figure)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_time_series_with_taxa(
    fixture,
    api: AnophelesSnpFrequencyAnalysis,
):
    """Check time series plotting when the ``taxa`` argument is supplied.

    The plot is requested once with a single taxon and once with a list of
    taxa, both drawn from the taxa present in the chosen sample set.
    """
    # Randomised inputs.
    sample_sets = random.choice(api.sample_sets()["sample_set"].to_list())
    site_mask = random.choice(api.site_mask_ids + (None,))
    transcript = random_transcript(api=api).name
    area_by = random.choice(["country", "admin1_iso", "admin2_name"])
    period_by = random.choice(["year", "quarter", "month"])

    # Taxa present in the selected sample set, missing values removed.
    df_samples = api.sample_metadata(sample_sets=sample_sets)
    sample_sets_taxa = df_samples["taxon"].dropna().unique().tolist()
    taxon = random.choice(sample_sets_taxa)
    n_taxa = random.randint(1, len(sample_sets_taxa))
    taxa = random.sample(sample_sets_taxa, n_taxa)

    # SNP frequencies with min_cohort_size=1 so that no samples are
    # excluded, cut down to the first 100 variants for speed.
    ds = api.snp_allele_frequencies_advanced(
        transcript=transcript,
        area_by=area_by,
        period_by=period_by,
        sample_sets=sample_sets,
        min_cohort_size=1,
        site_mask=site_mask,
    ).isel(variants=slice(0, 100))

    # Single taxon.
    fig_taxon = api.plot_frequencies_time_series(ds, show=False, taxa=taxon)
    assert isinstance(fig_taxon, go.Figure)

    # List of taxa.
    fig_taxa = api.plot_frequencies_time_series(ds, show=False, taxa=taxa)
    assert isinstance(fig_taxa, go.Figure)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_time_series_with_areas(
    fixture,
    api: AnophelesSnpFrequencyAnalysis,
):
    """Check time series plotting when the ``areas`` argument is supplied.

    The plot is requested once with a single area and once with a list of
    areas, both drawn from the ``cohort_area`` values of the computed dataset.
    """
    # Randomised inputs.
    sample_sets = random.choice(api.sample_sets()["sample_set"].to_list())
    site_mask = random.choice(api.site_mask_ids + (None,))
    transcript = random_transcript(api=api).name
    area_by = random.choice(["country", "admin1_iso", "admin2_name"])
    period_by = random.choice(["year", "quarter", "month"])

    # SNP frequencies with min_cohort_size=1 so that no samples are
    # excluded, cut down to the first 100 variants for speed.
    ds = api.snp_allele_frequencies_advanced(
        transcript=transcript,
        area_by=area_by,
        period_by=period_by,
        sample_sets=sample_sets,
        min_cohort_size=1,
        site_mask=site_mask,
    ).isel(variants=slice(0, 100))

    # Gather the dataset variables named "cohort_*" into a DataFrame.
    cohort_vars = []
    for var_name in ds:
        if str(var_name).startswith("cohort_"):
            cohort_vars.append(var_name)
    df_cohorts = ds[cohort_vars].to_dataframe()

    # Areas that occur in the cohorts, missing values removed.
    cohorts_areas = df_cohorts["cohort_area"].dropna().unique().tolist()
    area = random.choice(cohorts_areas)
    n_areas = random.randint(1, len(cohorts_areas))
    areas = random.sample(cohorts_areas, n_areas)

    # Single area.
    fig_area = api.plot_frequencies_time_series(ds, show=False, areas=area)
    assert isinstance(fig_area, go.Figure)

    # List of areas.
    fig_areas = api.plot_frequencies_time_series(ds, show=False, areas=areas)
    assert isinstance(fig_areas, go.Figure)


@parametrize_with_cases("fixture,api", cases=".")
def test_plot_frequencies_interactive_map(
    fixture,
    api: AnophelesSnpFrequencyAnalysis,
):
    """Check that interactive maps are produced for SNP and amino acid data.

    Randomly chosen parameters feed the two ``*_allele_frequencies_advanced``
    computations; each result is trimmed to its first 100 variants and passed
    to ``plot_frequencies_interactive_map``, which must return an ipywidgets
    ``Widget``.
    """
    import ipywidgets  # type: ignore

    # Randomised inputs.
    sample_sets = random.choice(api.sample_sets()["sample_set"].to_list())
    site_mask = random.choice(api.site_mask_ids + (None,))
    min_cohort_size = random.randint(0, 2)
    transcript = random_transcript(api=api).name
    area_by = random.choice(["country", "admin1_iso", "admin2_name"])
    period_by = random.choice(["year", "quarter", "month"])

    # Keyword arguments common to both frequency computations.
    shared = dict(
        transcript=transcript,
        area_by=area_by,
        period_by=period_by,
        sample_sets=sample_sets,
        min_cohort_size=min_cohort_size,
    )

    # SNP frequencies (the only call that also takes the site mask),
    # cut down to the first 100 variants for speed.
    ds_snp = api.snp_allele_frequencies_advanced(site_mask=site_mask, **shared)
    ds_snp = ds_snp.isel(variants=slice(0, 100))
    snp_map = api.plot_frequencies_interactive_map(ds_snp)
    assert isinstance(snp_map, ipywidgets.Widget)

    # Amino acid change frequencies, trimmed the same way.
    ds_aa = api.aa_allele_frequencies_advanced(**shared)
    ds_aa = ds_aa.isel(variants=slice(0, 100))
    aa_map = api.plot_frequencies_interactive_map(ds_aa)
    assert isinstance(aa_map, ipywidgets.Widget)

0 comments on commit 06be149

Please sign in to comment.