Skip to content

Commit

Permalink
Use sgkit.distarray for sample_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Sep 19, 2024
1 parent 4eea44c commit c552a84
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cubed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ jobs:
- name: Test with pytest
run: |
pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
pytest -v sgkit/tests/test_aggregation.py -k 'test_count_call_alleles or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
19 changes: 10 additions & 9 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,22 +803,23 @@ def sample_stats(
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
if mixed_ploidy:
raise ValueError("Mixed-ploidy dataset")
G = da.asarray(ds[call_genotype].data)
GT = da.asarray(ds[call_genotype].transpose("samples", "variants", "ploidy").data)
H = xr.DataArray(
da.map_blocks(
count_hom,
G.transpose(1, 0, 2),
lambda *args: count_hom(*args)[:, np.newaxis, :],
GT,
np.zeros(3, np.uint64),
drop_axis=(1, 2),
new_axis=1,
drop_axis=2,
new_axis=2,
dtype=np.int64,
chunks=(G.chunks[1], 3),
chunks=(GT.chunks[0], 1, 3),
),
dims=["samples", "categories"],
dims=["samples", "variants", "categories"],
)
n_variant, _, _ = G.shape
H = H.sum(axis=1)
_, n_variant, _ = GT.shape
n_called = H.sum(axis=-1)
call_rate = n_called / n_variant
call_rate = n_called.astype(float) / float(n_variant)
n_hom_ref = H[:, 0]
n_hom_alt = H[:, 1]
n_het = H[:, 2]
Expand Down
11 changes: 11 additions & 0 deletions sgkit/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,6 +857,17 @@ def test_sample_stats__raise_on_mixed_ploidy():
sample_stats(ds)


@pytest.mark.parametrize("chunks", [(-1, -1, -1), (100, -1, -1), (100, 10, -1)])
def test_sample_stats__chunks(chunks):
ds = simulate_genotype_call_dataset(
n_variant=1000, n_sample=30, missing_pct=0.01, seed=0
)
expect = sample_stats(ds, merge=False).compute()
ds["call_genotype"] = ds["call_genotype"].chunk(chunks)
actual = sample_stats(ds, merge=False).compute()
assert actual.equals(expect)


def test_infer_call_ploidy():
ds = get_dataset(
[
Expand Down

0 comments on commit c552a84

Please sign in to comment.