-
-
Notifications
You must be signed in to change notification settings - Fork 127
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c74fb3d
commit e05daf1
Showing
3 changed files
with
126 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from typing import Sequence | ||
|
||
from formulae import model_description | ||
from formulae.terms.variable import Variable | ||
|
||
|
||
class Formula: | ||
"""Model formula | ||
Allows to describe a model with multiple formulas. The first formula describes the response | ||
variable and its predictors. The following formulas describe predictors for other parameters | ||
of the likelihood function, allowing distributional models. | ||
Parameters | ||
---------- | ||
formula : str | ||
A model description written using the formula syntax from the ``formulae`` library. | ||
*additionals : str | ||
Additional formulas that describe the | ||
""" | ||
|
||
def __init__(self, formula: str, *additionals: str): | ||
self.additional_formulas_lhs = [] | ||
self.formula = formula | ||
self.additional_formulas = self.check_additionals(additionals) | ||
|
||
def check_additionals(self, additionals: Sequence[str]): | ||
"""Check if the additional formulas match the expected format | ||
Parameters | ||
---------- | ||
additionals : Sequence[str] | ||
Model formulas that describe model parameters rather than a response variable | ||
Returns | ||
------- | ||
additionals : Sequence[str] | ||
If all formulas match the required format, it return them. | ||
""" | ||
for additional in additionals: | ||
self.check_additional(additional) | ||
return additionals | ||
|
||
def check_additional(self, additional: str): | ||
"""Check if an additional formula match the expected format | ||
Parameters | ||
---------- | ||
additional : str | ||
A model formula that describes a model parameter. | ||
Raises | ||
------ | ||
ValueError | ||
If the formula does not contain a response term | ||
ValueError | ||
If the response term is not a plain name | ||
""" | ||
response = model_description(additional).response | ||
|
||
# There's a response in the formula | ||
if response is None: | ||
raise ValueError("Additional formulas must contain a response name.") | ||
|
||
# The response is a name, not a function call for example | ||
if not isinstance(response.term.components[0], Variable): | ||
raise ValueError("The response must be a name.") | ||
|
||
self.additional_formulas_lhs.append(response.term.name) | ||
|
||
def __str__(self): | ||
formulas = [self.formula] + list(self.additional_formulas) | ||
middle = ", ".join(formulas) | ||
return f"Formula({middle})" | ||
|
||
def __repr__(self): | ||
formulas = [self.formula] + list(self.additional_formulas) | ||
middle = ", ".join([f"'{formula}'" for formula in formulas]) | ||
return f"Formula({middle})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import pytest | ||
|
||
import bambi as bmb | ||
|
||
|
||
def test_regular_formula(): | ||
f1 = bmb.Formula("y ~ x1 + x2") | ||
assert f1.formula == "y ~ x1 + x2" | ||
assert f1.additional_formulas == tuple() | ||
assert f1.additional_formulas_lhs == list() | ||
|
||
|
||
def test_additional_empty_response(): | ||
with pytest.raises(ValueError, match="Additional formulas must contain a response name"): | ||
bmb.Formula("y ~ x1", "x1") | ||
|
||
|
||
def test_additional_call_response(): | ||
with pytest.raises(ValueError, match="The response must be a name"): | ||
bmb.Formula("y ~ x1", "log(sigma) ~ x1") | ||
|
||
|
||
def test_access_additional_names(): | ||
f1 = bmb.Formula("y ~ x") | ||
f2 = bmb.Formula("y ~ x1", "sigma ~ 1", "gamma ~ x") | ||
|
||
assert f1.additional_formulas_lhs == [] | ||
assert f2.additional_formulas_lhs == ["sigma", "gamma"] | ||
|
||
|
||
def test_formula_str(): | ||
f1 = bmb.Formula("y ~ x") | ||
f2 = bmb.Formula("y ~ x", "sigma ~ 1", "gamma ~ x") | ||
|
||
assert str(f1) == "Formula(y ~ x)" | ||
assert str(f2) == "Formula(y ~ x, sigma ~ 1, gamma ~ x)" | ||
|
||
|
||
def test_formula_repr(): | ||
f1 = bmb.Formula("y ~ x") | ||
f2 = bmb.Formula("y ~ x", "sigma ~ 1", "gamma ~ x") | ||
|
||
assert repr(f1) == "Formula('y ~ x')" | ||
assert repr(f2) == "Formula('y ~ x', 'sigma ~ 1', 'gamma ~ x')" |