Skip to content

Commit

Permalink
further clean-up of misc
Browse files Browse the repository at this point in the history
  • Loading branch information
bimac committed Nov 1, 2023
1 parent 8e3bf3b commit 72f6c14
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 44 deletions.
65 changes: 32 additions & 33 deletions iblrig/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,43 +133,42 @@ def texp(factor: float = 0.35, min_: float = 0.2, max_: float = 0.5) -> float:
return texp(factor=factor, min_=min_, max_=max_)


def get_biased_probs(n: int, idx: int = -1, idx_probability: float = 0.5) -> list:
def get_biased_probs(n: int, idx: int = -1, p_idx: float = 0.5) -> list[float]:
"""
Calculate the biased probability for all elements of an array so that
the <idx> value has <prob> probability of being drawn in respect to the
remaining values.
https://github.com/int-brain-lab/iblrig/issues/74
For prob == 0.5
p = [2 / (2 * len(contrast_set) - 1) for x in contrast_set]
p[-1] *= 1 / 2
For arbitrary probs
p = [1/(n-1 + 0.5)] * (n - 1)
e.g. get_biased_probs(3, idx=-1, prob=0.5)
>>> [0.4, 0.4, 0.2]
:param n: The length of the array, i.e. the num of probas to generate
:type n: int
:param idx: The index of the value that has the biased probability,
defaults to -1
:type idx: int, optional
:param idx_probability: The probability of the idxth value relative top the rest,
defaults to 0.5
:type idx_probability: float, optional
:return: List of biased probabilities
:rtype: list
Calculate biased probabilities for all elements of an array such that the
`i`th value has probability `p_i` for being drawn relative to the remaining
values.
See: https://github.com/int-brain-lab/iblrig/issues/74
Parameters
----------
n : int
The length of the array, i.e., the number of probabilities to generate.
idx : int, optional
The index of the value that has the biased probability. Defaults to -1.
p_idx : float, optional
The probability of the `idx`-th value relative to the rest. Defaults to 0.5.
Returns
-------
List[float]
List of biased probabilities.
Raises
------
ValueError
If `idx` is outside the valid range [-1, n), or if `p_idx` is 0.
"""
if idx < -1 or idx >= n:
raise ValueError("Invalid index. Index should be in the range [-1, n).")
# z = n - 1 + idx_probability
# p = [1 / z] * (n + 1)
# p[idx] *= idx_probability
# return p
n_1 = n - 1
z = n_1 + idx_probability
p = [1 / z] * (n_1 + 1)
p[idx] *= idx_probability
if n == 1:
return [1.0]
if p_idx == 0:
raise ValueError("Probability must be larger than 0.")
z = n - 1 + p_idx
p = [1 / z] * n
p[idx] *= p_idx
return p


Expand Down Expand Up @@ -205,7 +204,7 @@ def draw_contrast(contrast_set: Iterable[float],
If an unsupported `probability_type` is provided.
"""
if probability_type in ["skew_zero", "biased"]:
p = get_biased_probs(len(contrast_set), idx=idx, idx_probability=idx_probability)
p = get_biased_probs(n=len(contrast_set), idx=idx, p_idx=idx_probability)
return np.random.choice(contrast_set, p=p)
elif probability_type == "uniform":
return np.random.choice(contrast_set)
Expand Down
28 changes: 17 additions & 11 deletions iblrig/test/test_misc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from typing import Iterable

import numpy as np
from scipy import stats
Expand All @@ -9,20 +10,25 @@
class TestMisc(unittest.TestCase):
def test_draw_contrast(self):

contrast_set = np.linspace(0, 1, 11)
n = 500
n_draws = 400
n_contrasts = 10
contrast_set = np.linspace(0, 1, n_contrasts)

drawn_contrasts = [misc.draw_contrast(contrast_set, "uniform") for i in range(n)]
frequencies = np.unique(drawn_contrasts, return_counts=True)[1]
assert stats.chisquare(frequencies).pvalue > 0.05
def assert_distribution(values: int, f_exp: float | None = None) -> None:
f_obs = np.unique(values, return_counts=True)[1]
assert stats.chisquare(f_obs, f_exp).pvalue > 0.05

for p_idx in np.linspace(0.1, 0.9, 7):
drawn_contrasts = [misc.draw_contrast(contrast_set, "biased", 0, p_idx) for i in range(n)]
expected = np.ones(contrast_set.size)
# uniform distribution
contrasts = [misc.draw_contrast(contrast_set, "uniform") for i in range(n_draws)]
assert_distribution(contrasts)

# biased distribution
for p_idx in [0.25, 0.5, 0.75, 1.25]:
contrasts = [misc.draw_contrast(contrast_set, "biased", 0, p_idx) for i in range(n_draws)]
expected = np.ones(n_contrasts)
expected[0] = p_idx
expected = expected / expected.sum() * n
frequencies = np.unique(drawn_contrasts, return_counts=True)[1]
assert stats.chisquare(frequencies, expected).pvalue > 0.05
expected = expected / expected.sum() * n_draws
assert_distribution(contrasts, expected)

self.assertRaises(ValueError, misc.draw_contrast, [], "incorrect_type") # assert exception for incorrect type
self.assertRaises(ValueError, misc.draw_contrast, [0, 1], "biased", 2) # assert exception for out-of-range index

0 comments on commit 72f6c14

Please sign in to comment.