diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 136f8e7a..1c67ab03 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ default_language_version: python: python3.7 repos: -- repo: git://github.com/pre-commit/pre-commit-hooks +- repo: https://github.com/pre-commit/pre-commit-hooks rev: c8bad492e1b1d65d9126dba3fe3bd49a5a52b9d6 # v2.1.0 hooks: - id: check-merge-conflict @@ -11,15 +11,15 @@ repos: exclude: ^docs/.*$ - id: trailing-whitespace exclude: README.md -- repo: git://github.com/PyCQA/flake8 +- repo: https://github.com/PyCQA/flake8 rev: 88caf5ac484f5c09aedc02167c59c66ff0af0068 # 3.7.7 hooks: - id: flake8 -- repo: git://github.com/asottile/seed-isort-config +- repo: https://github.com/asottile/seed-isort-config rev: v1.7.0 hooks: - id: seed-isort-config -- repo: git://github.com/pre-commit/mirrors-isort +- repo: https://github.com/pre-commit/mirrors-isort rev: v4.3.4 hooks: - id: isort diff --git a/graphene_sqlalchemy/tests/test_types.py b/graphene_sqlalchemy/tests/test_types.py index 32f01509..1f15fa1a 100644 --- a/graphene_sqlalchemy/tests/test_types.py +++ b/graphene_sqlalchemy/tests/test_types.py @@ -1,11 +1,14 @@ from unittest import mock import pytest +import sqlalchemy.exc +import sqlalchemy.orm.exc from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull, ObjectType, Schema, String) from graphene.relay import Connection +from .. import utils from ..converter import convert_sqlalchemy_composite from ..fields import (SQLAlchemyConnectionField, UnsortedSQLAlchemyConnectionField, createConnectionField, @@ -492,3 +495,65 @@ class Meta: def test_deprecated_createConnectionField(): with pytest.warns(DeprecationWarning): createConnectionField(None) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_unique_errors_propagate(class_mapper_mock): + # Define unique error to detect + class UniqueError(Exception): + pass + + # Mock class_mapper effect + class_mapper_mock.side_effect = UniqueError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleOne(SQLAlchemyObjectType): + class Meta(object): + model = Article + except UniqueError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, UniqueError) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_argument_errors_propagate(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.exc.ArgumentError + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleTwo(SQLAlchemyObjectType): + class Meta(object): + model = Article + except sqlalchemy.exc.ArgumentError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, sqlalchemy.exc.ArgumentError) + + +@mock.patch(utils.__name__ + '.class_mapper') +def test_unmapped_errors_reformat(class_mapper_mock): + # Mock class_mapper effect + class_mapper_mock.side_effect = sqlalchemy.orm.exc.UnmappedClassError(object) + + # Make sure that errors are propagated from class_mapper when instantiating new classes + error = None + try: + class ArticleThree(SQLAlchemyObjectType): + class Meta(object): + model = Article + except ValueError as e: + error = e + + # Check that an error occured, and that it was the unique error we gave + assert error is not None + assert isinstance(error, ValueError) + assert "You need to pass a valid SQLAlchemy Model" in str(error) diff --git a/graphene_sqlalchemy/types.py b/graphene_sqlalchemy/types.py index 72f06c06..ac69b697 100644 --- a/graphene_sqlalchemy/types.py +++ b/graphene_sqlalchemy/types.py @@ -207,9 +207,11 @@ def __init_subclass_with_meta__( _meta=None, **options ): - assert is_mapped_class(model), ( - "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".' - ).format(cls.__name__, model) + # Make sure model is a valid SQLAlchemy model + if not is_mapped_class(model): + raise ValueError( + "You need to pass a valid SQLAlchemy Model in " '{}.Meta, received "{}".'.format(cls.__name__, model) + ) if not registry: registry = get_global_registry() diff --git a/graphene_sqlalchemy/utils.py b/graphene_sqlalchemy/utils.py index b30c0eb4..340ad47e 100644 --- a/graphene_sqlalchemy/utils.py +++ b/graphene_sqlalchemy/utils.py @@ -27,7 +27,13 @@ def get_query(model, context): def is_mapped_class(cls): try: class_mapper(cls) - except (ArgumentError, UnmappedClassError): + except ArgumentError as error: + # Only handle ArgumentErrors for non-class objects + if "Class object expected" in str(error): + return False + raise + except UnmappedClassError: + # Unmapped classes return false return False else: return True