Skip to content

Commit

Permalink
Add Formula class (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Nov 5, 2022
1 parent c74fb3d commit e05daf1
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 2 deletions.
5 changes: 3 additions & 2 deletions bambi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@

from pymc import math

from .backend import PyMCModel
from .data import clear_data_home, load_data
from .families import Family, Likelihood, Link
from .formula import Formula
from .models import Model
from .priors import Prior
from .backend import PyMCModel
from .version import __version__


# Names exported by ``from bambi import *``. This first assignment is immediately
# superseded by the one below, which is identical except that it adds "Formula".
__all__ = ["Model", "Prior", "Family", "Likelihood", "Link", "PyMCModel"]
# Effective public API: the previous names plus the ``Formula`` class.
__all__ = ["Model", "Prior", "Family", "Likelihood", "Link", "PyMCModel", "Formula"]

# Package-level logger registered under the name "bambi".
_log = logging.getLogger("bambi")

Expand Down
79 changes: 79 additions & 0 deletions bambi/formula.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from typing import Sequence

from formulae import model_description
from formulae.terms.variable import Variable


class Formula:
    """Model formula

    Allows to describe a model with multiple formulas. The first formula describes the response
    variable and its predictors. The following formulas describe predictors for other parameters
    of the likelihood function, allowing distributional models.

    Parameters
    ----------
    formula : str
        A model description written using the formula syntax from the ``formulae`` library.
    *additionals : str
        Additional formulas that describe the other parameters of the likelihood function.
        Each one must have a plain name on its left-hand side (for example ``"sigma ~ x"``).

    Attributes
    ----------
    formula : str
        The main formula, stored exactly as given.
    additional_formulas : Sequence[str]
        The additional formulas, stored exactly as given once they passed validation.
    additional_formulas_lhs : list of str
        The response name of every additional formula validated so far, in validation order.
    """

    def __init__(self, formula: str, *additionals: str):
        # Must exist before validation runs: ``check_additional`` appends to it.
        self.additional_formulas_lhs = []
        self.formula = formula
        self.additional_formulas = self.check_additionals(additionals)

    def check_additionals(self, additionals: Sequence[str]) -> Sequence[str]:
        """Check if the additional formulas match the expected format

        Parameters
        ----------
        additionals : Sequence[str]
            Model formulas that describe model parameters rather than a response variable

        Returns
        -------
        additionals : Sequence[str]
            If all formulas match the required format, the very same object that was passed in.

        Raises
        ------
        ValueError
            If any of the formulas is rejected by ``check_additional``.
        """
        for additional in additionals:
            self.check_additional(additional)
        return additionals

    def check_additional(self, additional: str) -> None:
        """Check if an additional formula matches the expected format

        On success, the name of the response is appended to ``additional_formulas_lhs``.

        Parameters
        ----------
        additional : str
            A model formula that describes a model parameter.

        Raises
        ------
        ValueError
            If the formula does not contain a response term
        ValueError
            If the response term is not a plain name
        """
        response = model_description(additional).response

        # A formula without a left-hand side has no response at all
        if response is None:
            raise ValueError("Additional formulas must contain a response name.")

        # The response must be a plain variable name, not a function call for example
        if not isinstance(response.term.components[0], Variable):
            raise ValueError("The response must be a name.")

        self.additional_formulas_lhs.append(response.term.name)

    def _all_formulas(self) -> list:
        """Return the main formula followed by the additional formulas, as a new list."""
        return [self.formula, *self.additional_formulas]

    def __str__(self) -> str:
        middle = ", ".join(self._all_formulas())
        return f"Formula({middle})"

    def __repr__(self) -> str:
        # ``repr`` picks a quoting style that stays valid when a formula itself contains
        # quotes (e.g. ``levels=['a', 'b']``), so the output can be evaluated back.
        middle = ", ".join(repr(formula) for formula in self._all_formulas())
        return f"Formula({middle})"
44 changes: 44 additions & 0 deletions bambi/tests/test_formula.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

import bambi as bmb


def test_regular_formula():
    """A lone formula is stored verbatim and produces no additional formulas."""
    main_only = bmb.Formula("y ~ x1 + x2")

    assert main_only.formula == "y ~ x1 + x2"
    assert main_only.additional_formulas == ()
    assert main_only.additional_formulas_lhs == []


def test_additional_empty_response():
    """An additional formula without a left-hand side is rejected."""
    expected_message = "Additional formulas must contain a response name"
    with pytest.raises(ValueError, match=expected_message):
        bmb.Formula("y ~ x1", "x1")


def test_additional_call_response():
    """An additional formula whose left-hand side is a function call is rejected."""
    expected_message = "The response must be a name"
    with pytest.raises(ValueError, match=expected_message):
        bmb.Formula("y ~ x1", "log(sigma) ~ x1")


def test_access_additional_names():
    """Response names of the additional formulas are exposed in the order given."""
    without_additionals = bmb.Formula("y ~ x")
    with_additionals = bmb.Formula("y ~ x1", "sigma ~ 1", "gamma ~ x")

    observed = (
        without_additionals.additional_formulas_lhs,
        with_additionals.additional_formulas_lhs,
    )
    assert observed[0] == []
    assert observed[1] == ["sigma", "gamma"]


def test_formula_str():
    """``str`` lists the formulas unquoted, separated by commas."""
    single = bmb.Formula("y ~ x")
    multiple = bmb.Formula("y ~ x", "sigma ~ 1", "gamma ~ x")

    expectations = (
        (single, "Formula(y ~ x)"),
        (multiple, "Formula(y ~ x, sigma ~ 1, gamma ~ x)"),
    )
    for formula_object, expected_text in expectations:
        assert str(formula_object) == expected_text


def test_formula_repr():
    """``repr`` lists the formulas quoted, separated by commas."""
    single = bmb.Formula("y ~ x")
    multiple = bmb.Formula("y ~ x", "sigma ~ 1", "gamma ~ x")

    expectations = (
        (single, "Formula('y ~ x')"),
        (multiple, "Formula('y ~ x', 'sigma ~ 1', 'gamma ~ x')"),
    )
    for formula_object, expected_text in expectations:
        assert repr(formula_object) == expected_text

0 comments on commit e05daf1

Please sign in to comment.