Skip to content

Commit

Permalink
Loc-scale variant of Tuncated Normal distribution (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
Qazalbash authored Jan 22, 2024
1 parent 08dfcc3 commit 9dd7988
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 40 deletions.
72 changes: 34 additions & 38 deletions jaxampler/_src/rvs/truncnormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,79 +29,75 @@
class TruncNormal(ContinuousRV):
def __init__(
self,
mu: Numeric | Any,
sigma: Numeric | Any,
low: Numeric | Any,
high: Numeric | Any,
loc: Numeric | Any = 0.0,
scale: Numeric | Any = 1.0,
low: Numeric | Any = -1.0,
high: Numeric | Any = 1.0,
name: Optional[str] = None,
) -> None:
shape, self._mu, self._sigma, self._low, self._high = jx_cast(mu, sigma, low, high)
shape, self._loc, self._scale, self._low, self._high = jx_cast(loc, scale, low, high)
self.check_params()
self._alpha = (self._low - self._mu) / self._sigma
self._beta = (self._high - self._mu) / self._sigma
self._alpha = (self._low - self._loc) / self._scale
self._beta = (self._high - self._loc) / self._scale
super().__init__(name=name, shape=shape)

def check_params(self) -> None:
assert jnp.all(self._low < self._high), "low must be smaller than high"
assert jnp.all(self._sigma > 0), "sigma must be positive"
assert jnp.all(self._scale > 0), "sigma must be positive"

@partial(jit, static_argnums=(0,))
def logpdf_x(self, x: Numeric) -> Numeric:
return jax_truncnorm.logpdf(
x,
self._alpha,
self._beta,
loc=self._mu,
scale=self._sigma,
x=x,
a=self._alpha,
b=self._beta,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def pdf_x(self, x: Numeric) -> Numeric:
return jax_truncnorm.pdf(
x,
self._alpha,
self._beta,
loc=self._mu,
scale=self._sigma,
x=x,
a=self._alpha,
b=self._beta,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def logcdf_x(self, x: Numeric) -> Numeric:
return jax_truncnorm.logcdf(
x,
self._alpha,
self._beta,
loc=self._mu,
scale=self._sigma,
x=x,
a=self._alpha,
b=self._beta,
loc=self._loc,
scale=self._scale,
)

@partial(jit, static_argnums=(0,))
def cdf_x(self, x: Numeric) -> Numeric:
return jax_truncnorm.cdf(
x,
self._alpha,
self._beta,
loc=self._mu,
scale=self._sigma,
x=x,
a=self._alpha,
b=self._beta,
loc=self._loc,
scale=self._scale,
)

def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
if key is None:
key = self.get_key()
new_shape = shape + self._shape
return (
jax.random.truncated_normal(
key,
self._alpha,
self._beta,
shape=new_shape,
)
* self._sigma
+ self._mu
return self._loc + self._scale * jax.random.truncated_normal(
key,
self._alpha,
self._beta,
shape=new_shape,
)

def __repr__(self) -> str:
string = f"TruncNorm(mu={self._mu}, sigma={self._sigma}, low={self._low}, high={self._high}"
string = f"TruncNorm(mu={self._loc}, sigma={self._scale}, low={self._low}, high={self._high}"
if self._name is not None:
string += f", name={self._name}"
string += ")"
Expand Down
4 changes: 2 additions & 2 deletions jaxampler/_src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
import numpy as np
from jax import lax, numpy as jnp
from jax._src import core
from jaxtyping import Array, Integer
from jaxtyping import Integer


def jx_cast(
*args: Any,
) -> tuple[tuple[int, ...], Unpack[tuple[Array, ...]]]:
) -> tuple[tuple[int, ...], Unpack[tuple[Any, ...]]]:
"""Cast provided arguments to `jnp.array` and checks if they can be
broadcast.
Expand Down

0 comments on commit 9dd7988

Please sign in to comment.