diff --git a/pymc/logprob/binary.py b/pymc/logprob/binary.py index a344d806734..6003b269c13 100644 --- a/pymc/logprob/binary.py +++ b/pymc/logprob/binary.py @@ -20,6 +20,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import node_rewriter from pytensor.scalar.basic import GE, GT, LE, LT, Invert +from pytensor.tensor import TensorVariable from pytensor.tensor.math import ge, gt, invert, le, lt from pymc.logprob.abstract import ( @@ -41,7 +42,7 @@ class MeasurableComparison(MeasurableElemwise): @node_rewriter(tracks=[gt, lt, ge, le]) def find_measurable_comparisons( fgraph: FunctionGraph, node: Node -) -> Optional[List[MeasurableComparison]]: +) -> Optional[List[TensorVariable]]: rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover