diff --git a/CHANGELOG.md b/CHANGELOG.md index 734e8416..cb1fe3d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added - [dev] Add instructions and script for running `postgres` and `postgis` tests. +- Add ability to pass `str` values to `foreign_key` for recipes from other modules [PR #120](https://github.com/model-bakers/model_bakery/pull/120) ### Changed - Fixed _model parameter annotations [PR #115](https://github.com/model-bakers/model_bakery/pull/115) diff --git a/model_bakery/recipe.py b/model_bakery/recipe.py index 998c3838..556c91ef 100644 --- a/model_bakery/recipe.py +++ b/model_bakery/recipe.py @@ -1,4 +1,3 @@ -import inspect import itertools from typing import Any, Dict, List, Type, Union, cast @@ -7,7 +6,10 @@ from . import baker from .exceptions import RecipeNotFound -from .utils import seq # NoQA: Enable seq to be imported from recipes +from .utils import ( # NoQA: Enable seq to be imported from recipes + get_calling_module, + seq, +) finder = baker.ModelFinder() @@ -68,18 +70,29 @@ def extend(self, **attrs) -> "Recipe": return type(self)(self._model, **attr_mapping) +def _load_recipe_from_calling_module(recipe: str) -> Recipe: + """Load `Recipe` from the string attribute given from the calling module. + + Args: + recipe (str): the name of the recipe attribute within the module from + which it should be loaded + + Returns: + (Recipe): recipe resolved from calling module + """ + recipe = getattr(get_calling_module(2), recipe) + if recipe: + return cast(Recipe, recipe) + else: + raise RecipeNotFound + + class RecipeForeignKey(object): - def __init__(self, recipe: Union[str, Recipe]) -> None: + """A `Recipe` to use for making ManyToOne related objects.""" + + def __init__(self, recipe: Recipe) -> None: if isinstance(recipe, Recipe): self.recipe = recipe - elif isinstance(recipe, str): - frame = inspect.stack()[2] - caller_module = inspect.getmodule(frame[0]) - recipe = getattr(caller_module, recipe) - if recipe: - self.recipe = cast(Recipe, recipe) - else: - raise RecipeNotFound else: raise TypeError("Not a recipe") @@ -89,8 +102,20 @@ def foreign_key(recipe: Union[Recipe, str]) -> RecipeForeignKey: Return the callable, so that the associated `_model` will not be created during the recipe definition. + + This resolves recipes supplied as strings from other module paths or from + the calling code's module. """ - return RecipeForeignKey(recipe) + if isinstance(recipe, str): + # Load `Recipe` from string before handing off to `RecipeForeignKey` + try: + # Try to load from another module + recipe = baker._recipe(recipe) + except (AttributeError, ImportError, ValueError): + # Probably not in another module, so load it from calling module + recipe = _load_recipe_from_calling_module(cast(str, recipe)) + + return RecipeForeignKey(cast(Recipe, recipe)) class related(object): # FIXME @@ -100,9 +125,7 @@ def __init__(self, *args) -> None: if isinstance(recipe, Recipe): self.related.append(recipe) elif isinstance(recipe, str): - frame = inspect.stack()[1] - caller_module = inspect.getmodule(frame[0]) - recipe = getattr(caller_module, recipe) + recipe = _load_recipe_from_calling_module(recipe) if recipe: self.related.append(recipe) else: diff --git a/model_bakery/utils.py b/model_bakery/utils.py index 04d0ea55..aa07963e 100644 --- a/model_bakery/utils.py +++ b/model_bakery/utils.py @@ -1,7 +1,9 @@ import datetime import importlib +import inspect import itertools import warnings +from types import ModuleType from typing import Any, Callable, Optional, Union from .timezone import tz_aware @@ -21,6 +23,22 @@ def import_from_str(import_string: Optional[Union[Callable, str]]) -> Any: return import_string +def get_calling_module(levels_back: int) -> Optional[ModuleType]: + """Get the module some number of stack frames back from the current one. + + Make sure to account for the number of frames between the "calling" code + and the one that calls this function. + + Args: + levels_back (int): Number of stack frames back from the current + + Returns: + (ModuleType): the module from which the code was called + """ + frame = inspect.stack()[levels_back + 1][0] + return inspect.getmodule(frame) + + def seq(value, increment_by=1, start=None, suffix=None): """Generate a sequence of values based on a running count. diff --git a/tests/test_recipes.py b/tests/test_recipes.py index 15e56ff3..55c61e66 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -341,6 +341,18 @@ def test_not_accept_other_type(self): exception = c.value assert str(exception) == "Not a recipe" + def test_load_from_other_module_recipe(self): + dog = Recipe(Dog, owner=foreign_key("tests.generic.person")).make() + assert dog.owner.name == "John Doe" + + def test_fail_load_invalid_recipe(self): + with pytest.raises(AttributeError): + foreign_key("tests.generic.nonexisting_recipe") + + def test_class_directly_with_string(self): + with pytest.raises(TypeError): + RecipeForeignKey("foo") + def test_do_not_create_related_model(self): """It should not create another object when passing the object as argument.""" person = baker.make_recipe("tests.generic.person") diff --git a/tests/test_utils.py b/tests/test_utils.py index 3efeb763..cecad4cd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,8 @@ +from inspect import getmodule + import pytest -from model_bakery.utils import import_from_str +from model_bakery.utils import get_calling_module, import_from_str from tests.generic.models import User @@ -13,3 +15,32 @@ def test_import_from_str(): assert import_from_str("tests.generic.models.User") == User assert import_from_str(User) == User + + +def test_get_calling_module(): + # Reference to this very module + this_module = getmodule(test_get_calling_module) + + # Once removed is the `pytest` module calling this function + pytest_module = get_calling_module(1) + assert pytest_module != this_module + assert "pytest" in pytest_module.__name__ + + # Test functions + def dummy_secondary_method(): + return get_calling_module(2), get_calling_module(3) + + def dummy_method(): + return (*dummy_secondary_method(), get_calling_module(1), get_calling_module(2)) + + # Unpack results from the function chain + sec_mod, sec_pytest_mod, dummy_mod, pytest_mod = dummy_method() + + assert sec_mod == this_module + assert "pytest" in sec_pytest_mod.__name__ + assert dummy_mod == this_module + assert "pytest" in pytest_mod.__name__ + + # Raise an `IndexError` when attempting to access too many frames removed + with pytest.raises(IndexError): + assert get_calling_module(100)