Skip to content

Commit

Permalink
Track specific Elemwise Ops in logprob rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 5, 2022
1 parent 7c32d3c commit ddc6b65
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 17 deletions.
15 changes: 3 additions & 12 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
from pytensor.scalar.basic import clip as scalar_clip
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.var import TensorConstant

from pymc.logprob.abstract import (
Expand All @@ -67,7 +67,7 @@ class MeasurableClip(MeasurableElemwise):
measurable_clip = MeasurableClip(scalar_clip)


@node_rewriter(tracks=[Elemwise])
@node_rewriter(tracks=[clip])
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableClip]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

Expand All @@ -78,9 +78,6 @@ def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[List[Me
if isinstance(node.op, MeasurableClip):
return None # pragma: no cover

if not (isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, Clip)):
return None

clipped_var = node.outputs[0]
base_var, lower_bound, upper_bound = node.inputs

Expand Down Expand Up @@ -179,7 +176,7 @@ class MeasurableRound(MeasurableElemwise):
valid_scalar_types = (RoundHalfToEven, Floor, Ceil)


@node_rewriter(tracks=[Elemwise])
@node_rewriter(tracks=[ceil, floor, round_half_to_even])
def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[List[MeasurableRound]]:

rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
Expand All @@ -189,12 +186,6 @@ def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[Lis
if isinstance(node.op, MeasurableRound):
return None # pragma: no cover

if not (
isinstance(node.op, Elemwise)
and isinstance(node.op.scalar_op, MeasurableRound.valid_scalar_types)
):
return None

(rounded_var,) = node.outputs
(base_var,) = node.inputs

Expand Down
8 changes: 3 additions & 5 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.scalar import Add, Exp, Log, Mul
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import add, exp, log, mul
from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
Expand Down Expand Up @@ -256,12 +256,9 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
return input_logprob + jacobian


@node_rewriter([Elemwise])
@node_rewriter([exp, log, add, mul])
def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]:
"""Find measurable transformations from Elemwise operators."""
scalar_op = node.op.scalar_op
if not isinstance(scalar_op, MeasurableTransform.valid_scalar_types):
return None

# Node was already converted
if isinstance(node.op, MeasurableVariable):
Expand Down Expand Up @@ -311,6 +308,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
# This seems to be the only thing preventing nested rewrites from being erased
measurable_input = assign_custom_measurable_outputs(measurable_input.owner)

scalar_op = node.op.scalar_op
measurable_input_idx = 0
transform_inputs: Tuple[TensorVariable, ...] = (measurable_input,)
transform: RVTransform
Expand Down

0 comments on commit ddc6b65

Please sign in to comment.