From e805643aa7e53f80875520286a3a37cfe633ddb0 Mon Sep 17 00:00:00 2001 From: Jonathan Dekermanjian <39779176+Dekermanjian@users.noreply.github.com> Date: Sun, 15 Sep 2024 23:29:59 -0600 Subject: [PATCH] =?UTF-8?q?implemented=20fix=20for=20escaping=20underscore?= =?UTF-8?q?s=20in=20latex=20repr=20and=20added=20a=20un=E2=80=A6=20(#7501)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * implemented fix for escaping underscores in latex repr and added a unit test * updated unit test staticmethod to include underscore in var name * add underscore escape fix to distribution repr as well as model repr, fixed testing to expect underscores in LaTeX representation to be escaped * added cleaner method using re to escape underscores, added cleaner test to assert underscores are escaped --- pymc/printing.py | 11 +++++++++++ tests/test_printing.py | 34 +++++++++++++++++++++++++++------- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/pymc/printing.py b/pymc/printing.py index 56445ab9ea8..ef417f37993 100644 --- a/pymc/printing.py +++ b/pymc/printing.py @@ -13,6 +13,8 @@ # limitations under the License. 
+import re + from functools import partial from pytensor.compile import SharedVariable @@ -58,6 +60,7 @@ def str_for_dist( if "latex" in formatting: if print_name is not None: print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}" + print_name = _format_underscore(print_name) op_name = ( dist.owner.op._print_name[1] @@ -114,6 +117,7 @@ def str_for_model(model: Model, formatting: str = "plain", include_params: bool if not var_reprs: return "" if "latex" in formatting: + var_reprs = [_format_underscore(x) for x in var_reprs] var_reprs = [ var_repr.replace(r"\sim", r"&\sim &").strip("$") for var_repr in var_reprs @@ -295,3 +299,10 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle): except (ModuleNotFoundError, AttributeError): # no ipython shell pass + + +def _format_underscore(variable: str) -> str: + """ + Escapes all unescaped underscores in the variable name for LaTeX representation. + """ + return re.sub(r"(?<!\\)_", r"\\_", variable) [extraction gap: the diff header and opening hunk lines for tests/test_printing.py were lost here] \text{<constant>})$", ( - r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w}," + r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}(\text{w}," r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Poisson}(5))," r"~\operatorname{Censored}(\operatorname{Bernoulli}(0.5),~-1,~1))$" ), - r"$\text{Y_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$", + r"$\text{Y\_obs} \sim \operatorname{Normal}(\text{mu},~\text{sigma})$", r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$", r"$\text{pred} \sim \operatorname{Deterministic}(f(\text{<constant>}))", ], @@ -189,11 +190,11 @@ def setup_class(self): r"$\text{mu} \sim \operatorname{Deterministic}$", r"$\text{beta} \sim \operatorname{Normal}$", r"$\text{Z} \sim \operatorname{MultivariateNormal}$", - r"$\text{nb_with_p_n} \sim \operatorname{NegativeBinomial}$", + r"$\text{nb\_with\_p\_n} \sim \operatorname{NegativeBinomial}$", r"$\text{zip} \sim \operatorname{MarginalMixture}$", r"$\text{w} \sim \operatorname{Dirichlet}$", - 
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}$", - r"$\text{Y_obs} \sim \operatorname{Normal}$", + r"$\text{nested\_mix} \sim \operatorname{MarginalMixture}$", + r"$\text{Y\_obs} \sim \operatorname{Normal}$", r"$\text{pot} \sim \operatorname{Potential}$", r"$\text{pred} \sim \operatorname{Deterministic}", ], @@ -256,7 +257,7 @@ def test_model_latex_repr_three_levels_model(): "$$", "\\begin{array}{rcl}", "\\text{mu} &\\sim & \\operatorname{Normal}(0,~5)\\\\\\text{sigma} &\\sim & " - "\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored_normal} &\\sim & " + "\\operatorname{HalfCauchy}(0,~2.5)\\\\\\text{censored\\_normal} &\\sim & " "\\operatorname{Censored}(\\operatorname{Normal}(\\text{mu},~\\text{sigma}),~-2,~2)", "\\end{array}", "$$", @@ -316,3 +317,22 @@ def random(rng, mu, size): str_repr = model.str_repr(include_params=False) assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"]) + + +class TestLatexRepr: + @staticmethod + def simple_model() -> Model: + with Model() as simple_model: + error = HalfNormal("error", 0.5) + alpha_a = Normal("alpha_a", 0, 1) + Normal("y", alpha_a, error) + return simple_model + + def test_latex_escaped_underscore(self): + """ + Ensures that all underscores in model variable names are properly escaped for LaTeX representation + """ + model = self.simple_model() + model_str = model.str_repr(formatting="latex") + assert "\\_" in model_str + assert "_" not in model_str.replace("\\_", "")