120120from pymc .logprob .utils import CheckParameterValue , check_potential_measurability
121121
122122
123- class RVTransform (abc .ABC ):
123+ class Transform (abc .ABC ):
124124 ndim_supp = None
125125
126126 @abc .abstractmethod
@@ -174,10 +174,10 @@ class MeasurableTransform(MeasurableElemwise):
174174
175175 # Cannot use `transform` as name because it would clash with the property added by
176176 # the `TransformValuesRewrite`
177- transform_elemwise : RVTransform
177+ transform_elemwise : Transform
178178 measurable_input_idx : int
179179
180- def __init__ (self , * args , transform : RVTransform , measurable_input_idx : int , ** kwargs ):
180+ def __init__ (self , * args , transform : Transform , measurable_input_idx : int , ** kwargs ):
181181 self .transform_elemwise = transform
182182 self .measurable_input_idx = measurable_input_idx
183183 super ().__init__ (* args , ** kwargs )
@@ -444,7 +444,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
444444 scalar_op = node .op .scalar_op
445445 measurable_input_idx = 0
446446 transform_inputs : Tuple [TensorVariable , ...] = (measurable_input ,)
447- transform : RVTransform
447+ transform : Transform
448448
449449 transform_dict = {
450450 Exp : ExpTransform (),
@@ -559,7 +559,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
559559)
560560
561561
562- class SinhTransform (RVTransform ):
562+ class SinhTransform (Transform ):
563563 name = "sinh"
564564 ndim_supp = 0
565565
@@ -570,7 +570,7 @@ def backward(self, value, *inputs):
570570 return pt .arcsinh (value )
571571
572572
573- class CoshTransform (RVTransform ):
573+ class CoshTransform (Transform ):
574574 name = "cosh"
575575 ndim_supp = 0
576576
@@ -589,7 +589,7 @@ def log_jac_det(self, value, *inputs):
589589 )
590590
591591
592- class TanhTransform (RVTransform ):
592+ class TanhTransform (Transform ):
593593 name = "tanh"
594594 ndim_supp = 0
595595
@@ -600,7 +600,7 @@ def backward(self, value, *inputs):
600600 return pt .arctanh (value )
601601
602602
603- class ArcsinhTransform (RVTransform ):
603+ class ArcsinhTransform (Transform ):
604604 name = "arcsinh"
605605 ndim_supp = 0
606606
@@ -611,7 +611,7 @@ def backward(self, value, *inputs):
611611 return pt .sinh (value )
612612
613613
614- class ArccoshTransform (RVTransform ):
614+ class ArccoshTransform (Transform ):
615615 name = "arccosh"
616616 ndim_supp = 0
617617
@@ -622,7 +622,7 @@ def backward(self, value, *inputs):
622622 return pt .cosh (value )
623623
624624
625- class ArctanhTransform (RVTransform ):
625+ class ArctanhTransform (Transform ):
626626 name = "arctanh"
627627 ndim_supp = 0
628628
@@ -633,7 +633,7 @@ def backward(self, value, *inputs):
633633 return pt .tanh (value )
634634
635635
636- class ErfTransform (RVTransform ):
636+ class ErfTransform (Transform ):
637637 name = "erf"
638638 ndim_supp = 0
639639
@@ -644,7 +644,7 @@ def backward(self, value, *inputs):
644644 return pt .erfinv (value )
645645
646646
647- class ErfcTransform (RVTransform ):
647+ class ErfcTransform (Transform ):
648648 name = "erfc"
649649 ndim_supp = 0
650650
@@ -655,7 +655,7 @@ def backward(self, value, *inputs):
655655 return pt .erfcinv (value )
656656
657657
658- class ErfcxTransform (RVTransform ):
658+ class ErfcxTransform (Transform ):
659659 name = "erfcx"
660660 ndim_supp = 0
661661
@@ -681,7 +681,7 @@ def calc_delta_x(value, prior_result):
681681 return result [- 1 ]
682682
683683
684- class LocTransform (RVTransform ):
684+ class LocTransform (Transform ):
685685 name = "loc"
686686
687687 def __init__ (self , transform_args_fn ):
@@ -699,7 +699,7 @@ def log_jac_det(self, value, *inputs):
699699 return pt .zeros_like (value )
700700
701701
702- class ScaleTransform (RVTransform ):
702+ class ScaleTransform (Transform ):
703703 name = "scale"
704704
705705 def __init__ (self , transform_args_fn ):
@@ -718,7 +718,7 @@ def log_jac_det(self, value, *inputs):
718718 return - pt .log (pt .abs (pt .broadcast_to (scale , value .shape )))
719719
720720
721- class LogTransform (RVTransform ):
721+ class LogTransform (Transform ):
722722 name = "log"
723723
724724 def forward (self , value , * inputs ):
@@ -731,7 +731,7 @@ def log_jac_det(self, value, *inputs):
731731 return value
732732
733733
734- class ExpTransform (RVTransform ):
734+ class ExpTransform (Transform ):
735735 name = "exp"
736736
737737 def forward (self , value , * inputs ):
@@ -744,7 +744,7 @@ def log_jac_det(self, value, *inputs):
744744 return - pt .log (value )
745745
746746
747- class AbsTransform (RVTransform ):
747+ class AbsTransform (Transform ):
748748 name = "abs"
749749
750750 def forward (self , value , * inputs ):
@@ -758,7 +758,7 @@ def log_jac_det(self, value, *inputs):
758758 return pt .switch (value >= 0 , 0 , np .nan )
759759
760760
761- class PowerTransform (RVTransform ):
761+ class PowerTransform (Transform ):
762762 name = "power"
763763
764764 def __init__ (self , power = None ):
@@ -801,7 +801,7 @@ def log_jac_det(self, value, *inputs):
801801 return res
802802
803803
804- class IntervalTransform (RVTransform ):
804+ class IntervalTransform (Transform ):
805805 name = "interval"
806806
807807 def __init__ (self , args_fn : Callable [..., Tuple [Optional [Variable ], Optional [Variable ]]]):
@@ -909,7 +909,7 @@ def log_jac_det(self, value, *inputs):
909909 return pt .zeros_like (value )
910910
911911
912- class LogOddsTransform (RVTransform ):
912+ class LogOddsTransform (Transform ):
913913 name = "logodds"
914914
915915 def backward (self , value , * inputs ):
@@ -923,7 +923,7 @@ def log_jac_det(self, value, *inputs):
923923 return pt .log (sigmoid_value ) + pt .log1p (- sigmoid_value )
924924
925925
926- class SimplexTransform (RVTransform ):
926+ class SimplexTransform (Transform ):
927927 name = "simplex"
928928
929929 def forward (self , value , * inputs ):
@@ -950,7 +950,7 @@ def log_jac_det(self, value, *inputs):
950950 return pt .sum (res , - 1 )
951951
952952
953- class CircularTransform (RVTransform ):
953+ class CircularTransform (Transform ):
954954 name = "circular"
955955
956956 def backward (self , value , * inputs ):
@@ -963,7 +963,7 @@ def log_jac_det(self, value, *inputs):
963963 return pt .zeros (value .shape )
964964
965965
966- class ChainedTransform (RVTransform ):
966+ class ChainedTransform (Transform ):
967967 name = "chain"
968968
969969 def __init__ (self , transform_list , base_op ):
0 commit comments