Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[52] Allow string values for foriegn_key from other modules #120

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

### Added
- [dev] Add instructions and script for running `postgres` and `postgis` tests.
- Add ability to pass `str` values to `foreign_key` for recipes from other modules [PR #120](https://github.com/model-bakers/model_bakery/pull/120)

### Changed
- Fixed _model parameter annotations [PR #115](https://github.com/model-bakers/model_bakery/pull/115)
Expand Down
53 changes: 38 additions & 15 deletions model_bakery/recipe.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
import itertools
from typing import Any, Dict, List, Type, Union, cast

Expand All @@ -7,7 +6,10 @@

from . import baker
from .exceptions import RecipeNotFound
from .utils import seq # NoQA: Enable seq to be imported from recipes
from .utils import ( # NoQA: Enable seq to be imported from recipes
get_calling_module,
seq,
)

finder = baker.ModelFinder()

Expand Down Expand Up @@ -68,18 +70,29 @@ def extend(self, **attrs) -> "Recipe":
return type(self)(self._model, **attr_mapping)


def _load_recipe_from_calling_module(recipe: str) -> Recipe:
"""Load `Recipe` from the string attribute given from the calling module.

Args:
recipe (str): the name of the recipe attribute within the module from
which it should be loaded

Returns:
(Recipe): recipe resolved from calling module
"""
recipe = getattr(get_calling_module(2), recipe)
if recipe:
return cast(Recipe, recipe)
else:
raise RecipeNotFound


class RecipeForeignKey(object):
def __init__(self, recipe: Union[str, Recipe]) -> None:
"""A `Recipe` to use for making ManyToOne related objects."""

def __init__(self, recipe: Recipe) -> None:
if isinstance(recipe, Recipe):
self.recipe = recipe
elif isinstance(recipe, str):
frame = inspect.stack()[2]
caller_module = inspect.getmodule(frame[0])
recipe = getattr(caller_module, recipe)
if recipe:
self.recipe = cast(Recipe, recipe)
else:
raise RecipeNotFound
else:
raise TypeError("Not a recipe")

Expand All @@ -89,8 +102,20 @@ def foreign_key(recipe: Union[Recipe, str]) -> RecipeForeignKey:

Return the callable, so that the associated `_model` will not be created
during the recipe definition.

This resolves recipes supplied as strings from other module paths or from
the calling code's module.
"""
return RecipeForeignKey(recipe)
if isinstance(recipe, str):
# Load `Recipe` from string before handing off to `RecipeForeignKey`
try:
# Try to load from another module
recipe = baker._recipe(recipe)
except (AttributeError, ImportError, ValueError):
# Probably not in another module, so load it from calling module
recipe = _load_recipe_from_calling_module(cast(str, recipe))

return RecipeForeignKey(cast(Recipe, recipe))


class related(object): # FIXME
Expand All @@ -100,9 +125,7 @@ def __init__(self, *args) -> None:
if isinstance(recipe, Recipe):
self.related.append(recipe)
elif isinstance(recipe, str):
frame = inspect.stack()[1]
caller_module = inspect.getmodule(frame[0])
recipe = getattr(caller_module, recipe)
recipe = _load_recipe_from_calling_module(recipe)
if recipe:
self.related.append(recipe)
else:
Expand Down
18 changes: 18 additions & 0 deletions model_bakery/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import datetime
import importlib
import inspect
import itertools
import warnings
from types import ModuleType
from typing import Any, Callable, Optional, Union

from .timezone import tz_aware
Expand All @@ -21,6 +23,22 @@ def import_from_str(import_string: Optional[Union[Callable, str]]) -> Any:
return import_string


def get_calling_module(levels_back: int) -> Optional[ModuleType]:
"""Get the module some number of stack frames back from the current one.

Make sure to account for the number of frames between the "calling" code
and the one that calls this function.

Args:
levels_back (int): Number of stack frames back from the current

Returns:
(ModuleType): the module from which the code was called
"""
frame = inspect.stack()[levels_back + 1][0]
return inspect.getmodule(frame)
Comment on lines +26 to +39
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a function in utils.py named import_from_str. It's used by baker.make or baker.prepare to work with string paths instead of model classes.

This PR looks great and it definitely makes bakery more stable in terms of its external API. But I wonder if we could re-use the previous import function instead of defining new strategies to import models based on strings.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, this functionality was already present in the import logic of RecipeForeignKey and of related. I simply extracted it out into a re-usable function.

I personally think the logic here (using stack frames to get the module from which the code was called) is a bit fragile, but it was already the way it was implemented so I left it alone (aside from pulling it into its own function to avoid duplication).

I am definitely open to alternatives, if you think there is a better way to access the "calling module" and its attributes. Just let me know what you think, and I will update the code. Thanks!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For context, I did re-use the baker._recipe import function (which calls import_from_str) for the new logic that I added, however, I left the "import from calling module" logic as a fallback (to maintain existing use of the foreign_key function).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@berinhard Please let me know if there is a way to facilitate loading a recipe from the module that called foreign_key that can re-use any existing function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@timjklein36 sorry I didn't have the time to properly review this last week. But I'll do so today and will update the PR with comments.



def seq(value, increment_by=1, start=None, suffix=None):
"""Generate a sequence of values based on a running count.

Expand Down
12 changes: 12 additions & 0 deletions tests/test_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,18 @@ def test_not_accept_other_type(self):
exception = c.value
assert str(exception) == "Not a recipe"

def test_load_from_other_module_recipe(self):
dog = Recipe(Dog, owner=foreign_key("tests.generic.person")).make()
assert dog.owner.name == "John Doe"

def test_fail_load_invalid_recipe(self):
with pytest.raises(AttributeError):
foreign_key("tests.generic.nonexisting_recipe")

def test_class_directly_with_string(self):
with pytest.raises(TypeError):
RecipeForeignKey("foo")

def test_do_not_create_related_model(self):
"""It should not create another object when passing the object as argument."""
person = baker.make_recipe("tests.generic.person")
Expand Down
33 changes: 32 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from inspect import getmodule

import pytest

from model_bakery.utils import import_from_str
from model_bakery.utils import get_calling_module, import_from_str
from tests.generic.models import User


Expand All @@ -13,3 +15,32 @@ def test_import_from_str():

assert import_from_str("tests.generic.models.User") == User
assert import_from_str(User) == User


def test_get_calling_module():
# Reference to this very module
this_module = getmodule(test_get_calling_module)

# Once removed is the `pytest` module calling this function
pytest_module = get_calling_module(1)
assert pytest_module != this_module
assert "pytest" in pytest_module.__name__

# Test functions
def dummy_secondary_method():
return get_calling_module(2), get_calling_module(3)

def dummy_method():
return (*dummy_secondary_method(), get_calling_module(1), get_calling_module(2))

# Unpack results from the function chain
sec_mod, sec_pytest_mod, dummy_mod, pytest_mod = dummy_method()

assert sec_mod == this_module
assert "pytest" in sec_pytest_mod.__name__
assert dummy_mod == this_module
assert "pytest" in pytest_mod.__name__

# Raise an `IndexError` when attempting to access too many frames removed
with pytest.raises(IndexError):
assert get_calling_module(100)