From 972f3c452f0f9ed6048caf9a676bc7490bcb2dba Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 2 Feb 2024 01:43:48 +0530 Subject: [PATCH] Rectify return type hints in logprob module rewrites Co-authored-by: Vivek Anand Singh <17vivekanandsingh@gmail.com> --- pymc/logprob/binary.py | 5 +++-- pymc/logprob/censoring.py | 5 +++-- pymc/logprob/checks.py | 5 +++-- pymc/logprob/cumsum.py | 3 ++- pymc/logprob/tensor.py | 7 +++---- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index 8326774e5ca..f5d8cf848c3 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -20,6 +20,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter from pytensor.scalar.basic import GE, GT, LE, LT, Invert +from pytensor.tensor import TensorVariable from pytensor.tensor.math import ge, gt, invert, le, lt from pymc.logprob.abstract import ( @@ -41,7 +42,7 @@ class MeasurableComparison(MeasurableElemwise): @node_rewriter(tracks=[gt, lt, ge, le]) def find_measurable_comparisons( fgraph: FunctionGraph, node: Node -) -> Optional[list[MeasurableComparison]]: +) -> Optional[list[TensorVariable]]: rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover @@ -133,7 +134,7 @@ class MeasurableBitwise(MeasurableElemwise): @node_rewriter(tracks=[invert]) -def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableBitwise]]: +def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]: rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover diff --git a/pymc/logprob/censoring.py b/pymc/logprob/censoring.py index 2a053fc06fd..b9221e08db8 100644 --- a/pymc/logprob/censoring.py +++ b/pymc/logprob/censoring.py @@ -44,6 +44,7 @@ from pytensor.graph.rewriting.basic import node_rewriter from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven from pytensor.scalar.basic import clip as scalar_clip +from pytensor.tensor import TensorVariable from pytensor.tensor.math import ceil, clip, floor, round_half_to_even from pytensor.tensor.variable import TensorConstant @@ -62,7 +63,7 @@ class MeasurableClip(MeasurableElemwise): @node_rewriter(tracks=[clip]) -def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableClip]]: +def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]: # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub) rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) @@ -157,7 +158,7 @@ class MeasurableRound(MeasurableElemwise): @node_rewriter(tracks=[ceil, floor, round_half_to_even]) -def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableRound]]: +def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]: rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover diff --git a/pymc/logprob/checks.py b/pymc/logprob/checks.py index 1c9d5559351..1cf202ec5e2 100644 --- a/pymc/logprob/checks.py +++ b/pymc/logprob/checks.py @@ -40,6 +40,7 @@ from pytensor.graph.rewriting.basic import node_rewriter from pytensor.raise_op import CheckAndRaise +from pytensor.tensor import TensorVariable from pytensor.tensor.shape import SpecifyShape from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper @@ -63,7 +64,7 @@ def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs): @node_rewriter([SpecifyShape]) -def find_measurable_specify_shapes(fgraph, node) -> Optional[list[MeasurableSpecifyShape]]: +def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable]]: r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed.""" if isinstance(node.op, MeasurableSpecifyShape): @@ -116,7 +117,7 @@ def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs): @node_rewriter([CheckAndRaise]) -def find_measurable_check_and_raise(fgraph, node) -> Optional[list[MeasurableCheckAndRaise]]: +def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariable]]: r"""Finds `AssertOp`\s for which a `logprob` can be computed.""" if isinstance(node.op, MeasurableCheckAndRaise): diff --git a/pymc/logprob/cumsum.py b/pymc/logprob/cumsum.py index 9154d3a0112..810f226c8ba 100644 --- a/pymc/logprob/cumsum.py +++ b/pymc/logprob/cumsum.py @@ -39,6 +39,7 @@ import pytensor.tensor as pt from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor import TensorVariable from pytensor.tensor.extra_ops import CumOp from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper @@ -77,7 +78,7 @@ def logprob_cumsum(op, values, base_rv, **kwargs): @node_rewriter([CumOp]) -def find_measurable_cumsums(fgraph, node) -> Optional[list[MeasurableCumsum]]: +def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]: r"""Finds `Cumsums`\s for which a `logprob` can be computed.""" if not (isinstance(node.op, CumOp) and node.op.mode == "add"): diff --git a/pymc/logprob/tensor.py b/pymc/logprob/tensor.py index e48945a4968..9cbf456b7bb 100644 --- a/pymc/logprob/tensor.py +++ b/pymc/logprob/tensor.py @@ -41,6 +41,7 @@ from pytensor import tensor as pt from pytensor.graph.op import compute_test_value from pytensor.graph.rewriting.basic import node_rewriter +from pytensor.tensor import TensorVariable from pytensor.tensor.basic import Alloc, Join, MakeVector from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.random.op import RandomVariable @@ -197,9 +198,7 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs): @node_rewriter([MakeVector, Join]) -def find_measurable_stacks( - fgraph, node -) -> Optional[list[Union[MeasurableMakeVector, MeasurableJoin]]]: +def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]: r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed.""" rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) @@ -273,7 +272,7 @@ def logprob_dimshuffle(op, values, base_var, **kwargs): @node_rewriter([DimShuffle]) -def find_measurable_dimshuffles(fgraph, node) -> Optional[list[MeasurableDimShuffle]]: +def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]: r"""Finds `Dimshuffle`\s for which a `logprob` can be computed.""" rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)