Skip to content

Commit

Permalink
Optionally drop empty dimensions from stats.
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Aug 21, 2019
1 parent fda0105 commit cc260d3
Show file tree
Hide file tree
Showing 2 changed files with 273 additions and 11 deletions.
237 changes: 237 additions & 0 deletions python/tests/test_tree_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -5043,6 +5043,243 @@ def f(x):
site_true_Y)


class TestOutputDimensions(StatsTestCase):
"""
Tests for the dimension stripping behaviour of the stats functions.
"""
def get_example_ts(self):
ts = msprime.simulate(10, mutation_rate=1, random_seed=1)
self.assertGreater(ts.num_sites, 1)
return ts

def test_afs_default_windows(self):
ts = self.get_example_ts()
n = ts.num_samples
A = ts.samples()[:5]
B = ts.samples()[5:]
for mode in ["site", "branch"]:
x = ts.allele_frequency_spectrum(mode=mode)
# x is a 1D numpy array with n + 1 values
self.assertEqual(x.shape, (n + 1,))
self.assertArrayEqual(
x, ts.allele_frequency_spectrum([ts.samples()], mode=mode))
x = ts.allele_frequency_spectrum([A, B], mode=mode)
self.assertEqual(x.shape, (len(A) + 1, len(B) + 1))

def test_afs_windows(self):
ts = self.get_example_ts()
L = ts.sequence_length

windows = [0, L / 4, L / 2, L]
A = ts.samples()[:5]
B = ts.samples()[5:]
for mode in ["site", "branch"]:
x = ts.allele_frequency_spectrum([A, B], windows=windows, mode=mode)
self.assertEqual(x.shape, (3, len(A) + 1, len(B) + 1))

x = ts.allele_frequency_spectrum([A], windows=windows, mode=mode)
self.assertEqual(x.shape, (3, len(A) + 1))

x = ts.allele_frequency_spectrum(windows=windows, mode=mode)
# Default returns this for all samples
self.assertEqual(x.shape, (3, ts.num_samples + 1))
y = ts.allele_frequency_spectrum([ts.samples()], windows=windows, mode=mode)
self.assertArrayEqual(x, y)

def test_one_way_stat_default_windows(self):
ts = self.get_example_ts()
# Use diversity as the example one-way stat.
for mode in ["site", "branch"]:
x = ts.diversity(mode=mode)
# x is a zero-d numpy value
self.assertEqual(np.shape(x), tuple())
self.assertEqual(x, float(x))
self.assertEqual(x, ts.diversity(ts.samples(), mode=mode))
self.assertArrayEqual([x], ts.diversity([ts.samples()], mode=mode))

mode = "node"
x = ts.diversity(mode=mode)
# x is a 1D numpy array with N values
self.assertEqual(x.shape, (ts.num_nodes,))
self.assertArrayEqual(x, ts.diversity(ts.samples(), mode=mode))
y = ts.diversity([ts.samples()], mode=mode)
# We're adding on the *last* dimension, so must reshape
self.assertArrayEqual(x.reshape(ts.num_nodes, 1), y)

def test_one_way_stat_windows(self):
ts = self.get_example_ts()
L = ts.sequence_length
N = ts.num_nodes

windows = [0, L / 4, L / 2, L]
A = ts.samples()[:5]
B = ts.samples()[5:]
for mode in ["site", "branch"]:
x = ts.diversity([A, B], windows=windows, mode=mode)
# Three windows, 2 sets.
self.assertEqual(x.shape, (3, 2))

x = ts.diversity([A], windows=windows, mode=mode)
# Three windows, 1 sets.
self.assertEqual(x.shape, (3, 1))

x = ts.diversity(A, windows=windows, mode=mode)
# Dropping the outer list removes the last dimension
self.assertEqual(x.shape, (3, ))

x = ts.diversity(windows=windows, mode=mode)
# Default returns this for all samples
self.assertEqual(x.shape, (3, ))
y = ts.diversity(ts.samples(), windows=windows, mode=mode)
self.assertArrayEqual(x, y)

mode = "node"
x = ts.diversity([A, B], windows=windows, mode=mode)
# Three windows, N nodes and 2 sets.
self.assertEqual(x.shape, (3, N, 2))

x = ts.diversity([A], windows=windows, mode=mode)
# Three windows, N nodes and 1 set.
self.assertEqual(x.shape, (3, N, 1))

x = ts.diversity(A, windows=windows, mode=mode)
# Drop the outer list, so we lose the last dimension
self.assertEqual(x.shape, (3, N))

x = ts.diversity(windows=windows, mode=mode)
# The default sample sets also drops the last dimension
self.assertEqual(x.shape, (3, N))

self.assertEqual(ts.num_trees, 1)
# In this example, we know that the trees are all the same so check this
# for sanity.
self.assertArrayEqual(x[0], x[1])
self.assertArrayEqual(x[0], x[2])

def test_two_way_stat_default_windows(self):
ts = self.get_example_ts()
# Use divergence as the example one-way stat.
A = ts.samples()[:5]
B = ts.samples()[5:]
for mode in ["site", "branch"]:
x = ts.divergence([A, B], mode=mode)
# x is a zero-d numpy value
self.assertEqual(np.shape(x), tuple())
self.assertEqual(x, float(x))
# If indexes is a 1D array, we also drop the outer dimension
self.assertEqual(x, ts.divergence([A, B, A], indexes=[0, 1], mode=mode))
# But, if it's a 2D array we keep the outer dimension
self.assertEqual([x], ts.divergence([A, B], indexes=[[0, 1]], mode=mode))

mode = "node"
x = ts.divergence([A, B], mode=mode)
# x is a 1D numpy array with N values
self.assertEqual(x.shape, (ts.num_nodes,))
self.assertArrayEqual(x, ts.divergence([A, B], indexes=[0, 1], mode=mode))
y = ts.divergence([A, B], indexes=[[0, 1]], mode=mode)
# We're adding on the *last* dimension, so must reshape
self.assertArrayEqual(x.reshape(ts.num_nodes, 1), y)

def test_two_way_stat_windows(self):
ts = self.get_example_ts()
L = ts.sequence_length
N = ts.num_nodes

windows = [0, L / 4, L / 2, L]
A = ts.samples()[:5]
B = ts.samples()[5:]
for mode in ["site", "branch"]:
x = ts.divergence(
[A, B, A], indexes=[[0, 1], [0, 2]], windows=windows, mode=mode)
# Three windows, 2 pairs
self.assertEqual(x.shape, (3, 2))

x = ts.divergence([A, B], indexes=[[0, 1]], windows=windows, mode=mode)
# Three windows, 1 pair
self.assertEqual(x.shape, (3, 1))

x = ts.divergence([A, B], indexes=[0, 1], windows=windows, mode=mode)
# Dropping the outer list removes the last dimension
self.assertEqual(x.shape, (3, ))

y = ts.divergence([A, B], windows=windows, mode=mode)
self.assertEqual(y.shape, (3, ))
self.assertArrayEqual(x, y)

mode = "node"
x = ts.divergence([A, B], indexes=[[0, 1], [0, 1]], windows=windows, mode=mode)
# Three windows, N nodes and 2 pairs
self.assertEqual(x.shape, (3, N, 2))

x = ts.divergence([A, B], indexes=[[0, 1]], windows=windows, mode=mode)
# Three windows, N nodes and 1 pairs
self.assertEqual(x.shape, (3, N, 1))

x = ts.divergence([A, B], indexes=[0, 1], windows=windows, mode=mode)
# Drop the outer list, so we lose the last dimension
self.assertEqual(x.shape, (3, N))

x = ts.divergence([A, B], windows=windows, mode=mode)
# The default sample sets also drops the last dimension
self.assertEqual(x.shape, (3, N))

self.assertEqual(ts.num_trees, 1)
# In this example, we know that the trees are all the same so check this
# for sanity.
self.assertArrayEqual(x[0], x[1])
self.assertArrayEqual(x[0], x[2])

def test_three_way_stat_windows(self):
ts = self.get_example_ts()
L = ts.sequence_length
N = ts.num_nodes

windows = [0, L / 4, L / 2, L]
A = ts.samples()[:2]
B = ts.samples()[2: 4]
C = ts.samples()[4:]
for mode in ["site", "branch"]:
x = ts.Y3(
[A, B, C], indexes=[[0, 1, 2], [0, 2, 1]], windows=windows, mode=mode)
# Three windows, 2 triple
self.assertEqual(x.shape, (3, 2))

x = ts.Y3([A, B, C], indexes=[[0, 1, 2]], windows=windows, mode=mode)
# Three windows, 1 triple
self.assertEqual(x.shape, (3, 1))

x = ts.Y3([A, B, C], indexes=[0, 1, 2], windows=windows, mode=mode)
# Dropping the outer list removes the last dimension
self.assertEqual(x.shape, (3, ))

y = ts.Y3([A, B, C], windows=windows, mode=mode)
self.assertEqual(y.shape, (3, ))
self.assertArrayEqual(x, y)

mode = "node"
x = ts.Y3([A, B, C], indexes=[[0, 1, 2], [0, 2, 1]], windows=windows, mode=mode)
# Three windows, N nodes and 2 triples
self.assertEqual(x.shape, (3, N, 2))

x = ts.Y3([A, B, C], indexes=[[0, 1, 2]], windows=windows, mode=mode)
# Three windows, N nodes and 1 triples
self.assertEqual(x.shape, (3, N, 1))

x = ts.Y3([A, B, C], indexes=[0, 1, 2], windows=windows, mode=mode)
# Drop the outer list, so we lose the last dimension
self.assertEqual(x.shape, (3, N))

x = ts.Y3([A, B, C], windows=windows, mode=mode)
# The default sample sets also drops the last dimension
self.assertEqual(x.shape, (3, N))

self.assertEqual(ts.num_trees, 1)
# In this example, we know that the trees are all the same so check this
# for sanity.
self.assertArrayEqual(x[0], x[1])
self.assertArrayEqual(x[0], x[2])


############################################
# Old code where stats are defined within type
# specific calculattors. These definititions have been
Expand Down
47 changes: 36 additions & 11 deletions python/tskit/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -3446,16 +3446,31 @@ def __run_windowed_stat(self, windows, method, *args, **kwargs):
stat = stat[0]
return stat

def __one_way_sample_set_stat(self, ll_method, sample_sets, windows=None,
mode=None, span_normalise=True, polarised=False):
def __one_way_sample_set_stat(
self, ll_method, sample_sets, windows=None, mode=None, span_normalise=True,
polarised=False):
if sample_sets is None:
sample_sets = self.samples()
# First try to convert to a 1D numpy array. If it is, then we strip off
# the corresponding dimension from the output.
drop_dimension = False
sample_sets = np.array(sample_sets)
if len(sample_sets.shape) == 1:
sample_sets = [sample_sets]
drop_dimension = True

sample_set_sizes = np.array(
[len(sample_set) for sample_set in sample_sets], dtype=np.uint32)
if np.any(sample_set_sizes == 0):
raise ValueError("Sample sets must contain at least one element")

flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32)
return self.__run_windowed_stat(
stat = self.__run_windowed_stat(
windows, ll_method, sample_set_sizes, flattened,
mode=mode, span_normalise=span_normalise, polarised=polarised)
if drop_dimension:
stat = stat.reshape(stat.shape[:-1])
return stat

def __k_way_sample_set_stat(
self, ll_method, k, sample_sets, indexes=None, windows=None,
Expand All @@ -3467,23 +3482,31 @@ def __k_way_sample_set_stat(
flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32)
if indexes is None:
if len(sample_sets) != k:
raise ValueError("Must specify indexes "
"if there are not exactly {} sample sets.".format(k))
else:
indexes = [np.arange(k, dtype=np.int32)]
raise ValueError(
"Must specify indexes if there are not exactly {} sample "
"sets.".format(k))
indexes = np.arange(k, dtype=np.int32)
drop_dimension = False
indexes = util.safe_np_int_cast(indexes, np.int32)
if len(indexes.shape) == 1:
indexes = indexes.reshape((1, indexes.shape[0]))
drop_dimension = True
if len(indexes.shape) != 2 or indexes.shape[1] != k:
raise ValueError("Indexes must be convertable to a 2D numpy array"
"with {} columns".format(k))
return self.__run_windowed_stat(
raise ValueError(
"Indexes must be convertable to a 2D numpy array with {} "
"columns".format(k))
stat = self.__run_windowed_stat(
windows, ll_method, sample_set_sizes, flattened, indexes,
mode=mode, span_normalise=span_normalise)
if drop_dimension:
stat = stat.reshape(stat.shape[:-1])
return stat

############################################
# Statistics definitions
############################################

def diversity(self, sample_sets, windows=None, mode="site",
def diversity(self, sample_sets=None, windows=None, mode="site",
span_normalise=True):
"""
Computes mean genetic diversity (also knowns as "Tajima's pi") in each of the
Expand Down Expand Up @@ -3894,6 +3917,8 @@ def allele_frequency_spectrum(
:return: A (k + 1) dimensional numpy array, where k is the number of sample
sets specified.
"""
# TODO should we allow a single sample_set to be specified here as a 1D array?
# This won't change the output dimensions like the other stats.
if sample_sets is None:
sample_sets = [self.samples()]
return self.__one_way_sample_set_stat(
Expand Down

0 comments on commit cc260d3

Please sign in to comment.