Skip to content

Commit

Permalink
Rectify return type hints in logprob module rewrites
Browse files Browse the repository at this point in the history
Co-authored-by: Vivek Anand Singh <17vivekanandsingh@gmail.com>
  • Loading branch information
2 people authored and ricardoV94 committed Feb 6, 2024
1 parent ad500d0 commit 972f3c4
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 11 deletions.
5 changes: 3 additions & 2 deletions pymc/logprob/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import GE, GT, LE, LT, Invert
from pytensor.tensor import TensorVariable
from pytensor.tensor.math import ge, gt, invert, le, lt

from pymc.logprob.abstract import (
Expand All @@ -41,7 +42,7 @@ class MeasurableComparison(MeasurableElemwise):
@node_rewriter(tracks=[gt, lt, ge, le])
def find_measurable_comparisons(
fgraph: FunctionGraph, node: Node
) -> Optional[list[MeasurableComparison]]:
) -> Optional[list[TensorVariable]]:
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover
Expand Down Expand Up @@ -133,7 +134,7 @@ class MeasurableBitwise(MeasurableElemwise):


@node_rewriter(tracks=[invert])
def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableBitwise]]:
def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover
Expand Down
5 changes: 3 additions & 2 deletions pymc/logprob/censoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
from pytensor.scalar.basic import clip as scalar_clip
from pytensor.tensor import TensorVariable
from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
from pytensor.tensor.variable import TensorConstant

Expand All @@ -62,7 +63,7 @@ class MeasurableClip(MeasurableElemwise):


@node_rewriter(tracks=[clip])
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableClip]]:
def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
# TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
Expand Down Expand Up @@ -157,7 +158,7 @@ class MeasurableRound(MeasurableElemwise):


@node_rewriter(tracks=[ceil, floor, round_half_to_even])
def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableRound]]:
def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover
Expand Down
5 changes: 3 additions & 2 deletions pymc/logprob/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import TensorVariable
from pytensor.tensor.shape import SpecifyShape

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
Expand All @@ -63,7 +64,7 @@ def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):


@node_rewriter([SpecifyShape])
def find_measurable_specify_shapes(fgraph, node) -> Optional[list[MeasurableSpecifyShape]]:
def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable]]:
r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed."""

if isinstance(node.op, MeasurableSpecifyShape):
Expand Down Expand Up @@ -116,7 +117,7 @@ def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):


@node_rewriter([CheckAndRaise])
def find_measurable_check_and_raise(fgraph, node) -> Optional[list[MeasurableCheckAndRaise]]:
def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariable]]:
r"""Finds `AssertOp`\s for which a `logprob` can be computed."""

if isinstance(node.op, MeasurableCheckAndRaise):
Expand Down
3 changes: 2 additions & 1 deletion pymc/logprob/cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import pytensor.tensor as pt

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import CumOp

from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
Expand Down Expand Up @@ -77,7 +78,7 @@ def logprob_cumsum(op, values, base_rv, **kwargs):


@node_rewriter([CumOp])
def find_measurable_cumsums(fgraph, node) -> Optional[list[MeasurableCumsum]]:
def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]:
r"""Finds `Cumsums`\s for which a `logprob` can be computed."""

if not (isinstance(node.op, CumOp) and node.op.mode == "add"):
Expand Down
7 changes: 3 additions & 4 deletions pymc/logprob/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from pytensor import tensor as pt
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import TensorVariable
from pytensor.tensor.basic import Alloc, Join, MakeVector
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -197,9 +198,7 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):


@node_rewriter([MakeVector, Join])
def find_measurable_stacks(
fgraph, node
) -> Optional[list[Union[MeasurableMakeVector, MeasurableJoin]]]:
def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]:
r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed."""

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
Expand Down Expand Up @@ -273,7 +272,7 @@ def logprob_dimshuffle(op, values, base_var, **kwargs):


@node_rewriter([DimShuffle])
def find_measurable_dimshuffles(fgraph, node) -> Optional[list[MeasurableDimShuffle]]:
def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]:
r"""Finds `Dimshuffle`\s for which a `logprob` can be computed."""

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
Expand Down

0 comments on commit 972f3c4

Please sign in to comment.