Skip to content

Commit

Permalink
Probabilty inference for arc transformations
Browse files Browse the repository at this point in the history
Co-authored-by: Luke LB <ll17354@bristol.ac.uk>
  • Loading branch information
LukeLB and Luke LB authored Jun 14, 2023
1 parent 154f5b0 commit c792e88
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 2 deletions.
64 changes: 62 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@
from pytensor.scalar import (
Abs,
Add,
ArcCosh,
ArcSinh,
ArcTanh,
Cosh,
Erf,
Erfc,
Expand All @@ -71,6 +74,9 @@
from pytensor.tensor.math import (
abs,
add,
arccosh,
arcsinh,
arctanh,
cosh,
erf,
erfc,
Expand Down Expand Up @@ -369,7 +375,23 @@ def apply(self, fgraph: FunctionGraph):
class MeasurableTransform(MeasurableElemwise):
"""A placeholder used to specify a log-likelihood for a transformed measurable variable"""

valid_scalar_types = (Exp, Log, Add, Mul, Pow, Abs, Sinh, Cosh, Tanh, Erf, Erfc, Erfcx)
valid_scalar_types = (
Exp,
Log,
Add,
Mul,
Pow,
Abs,
Sinh,
Cosh,
Tanh,
ArcSinh,
ArcCosh,
ArcTanh,
Erf,
Erfc,
Erfcx,
)

# Cannot use `transform` as name because it would clash with the property added by
# the `TransformValuesRewrite`
Expand Down Expand Up @@ -501,7 +523,9 @@ def measurable_sub_to_neg(fgraph, node):
return [pt.add(minuend, pt.neg(subtrahend))]


@node_rewriter([exp, log, add, mul, pow, abs, sinh, cosh, tanh, erf, erfc, erfcx])
@node_rewriter(
[exp, log, add, mul, pow, abs, sinh, cosh, tanh, arcsinh, arccosh, arctanh, erf, erfc, erfcx]
)
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
"""Find measurable transformations from Elemwise operators."""

Expand Down Expand Up @@ -544,6 +568,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
Sinh: SinhTransform(),
Cosh: CoshTransform(),
Tanh: TanhTransform(),
ArcSinh: ArcsinhTransform(),
ArcCosh: ArccoshTransform(),
ArcTanh: ArctanhTransform(),
Erf: ErfTransform(),
Erfc: ErfcTransform(),
Erfcx: ErfcxTransform(),
Expand Down Expand Up @@ -660,6 +687,39 @@ def backward(self, value, *inputs):
return pt.arctanh(value)


class ArcsinhTransform(RVTransform):
name = "arcsinh"
ndim_supp = 0

def forward(self, value, *inputs):
return pt.arcsinh(value)

def backward(self, value, *inputs):
return pt.sinh(value)


class ArccoshTransform(RVTransform):
name = "arccosh"
ndim_supp = 0

def forward(self, value, *inputs):
return pt.arccosh(value)

def backward(self, value, *inputs):
return pt.cosh(value)


class ArctanhTransform(RVTransform):
name = "arctanh"
ndim_supp = 0

def forward(self, value, *inputs):
return pt.arctanh(value)

def backward(self, value, *inputs):
return pt.tanh(value)


class ErfTransform(RVTransform):
name = "erf"
ndim_supp = 0
Expand Down
9 changes: 9 additions & 0 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@
from pymc.logprob.abstract import MeasurableVariable, _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.transforms import (
ArccoshTransform,
ArcsinhTransform,
ArctanhTransform,
ChainedTransform,
CoshTransform,
ErfcTransform,
Expand Down Expand Up @@ -1028,6 +1031,9 @@ def test_multivariate_transform(shift, scale):
(pt.sinh, SinhTransform()),
(pt.cosh, CoshTransform()),
(pt.tanh, TanhTransform()),
(pt.arcsinh, ArcsinhTransform()),
(pt.arccosh, ArccoshTransform()),
(pt.arctanh, ArctanhTransform()),
],
)
def test_erf_logp(pt_transform, transform):
Expand Down Expand Up @@ -1060,6 +1066,9 @@ def test_erf_logp(pt_transform, transform):
SinhTransform(),
CoshTransform(),
TanhTransform(),
ArcsinhTransform(),
ArccoshTransform(),
ArctanhTransform(),
],
)
def test_check_jac_det(transform):
Expand Down

0 comments on commit c792e88

Please sign in to comment.