diff --git a/marshmallow_jsonschema/base.py b/marshmallow_jsonschema/base.py index 91dbbb1..e27a206 100644 --- a/marshmallow_jsonschema/base.py +++ b/marshmallow_jsonschema/base.py @@ -8,7 +8,8 @@ from .compat import text_type, binary_type, basestring, dot_data_backwards_compatable from marshmallow.decorators import post_dump -from .validation import handle_length, handle_one_of, handle_range +from .validation import (handle_length, handle_one_of, handle_range, + handle_regexp) __all__ = ( @@ -77,6 +78,7 @@ validate.Length: handle_length, validate.OneOf: handle_one_of, validate.Range: handle_range, + validate.Regexp: handle_regexp } diff --git a/marshmallow_jsonschema/validation.py b/marshmallow_jsonschema/validation.py index 0f3fc1c..ce20573 100644 --- a/marshmallow_jsonschema/validation.py +++ b/marshmallow_jsonschema/validation.py @@ -100,3 +100,29 @@ def handle_range(schema, field, validator, parent_schema): schema['maximum'] = validator.max return schema + + +def handle_regexp(schema, field, validator, parent_schema): + """Adds validation logic for ``marshmallow.validate.Regexp``, setting the + values appropriately for ``fields.String`` and its subclasses. + + Args: + schema (dict): The original JSON schema we generated. This is what we + want to post-process. + field (fields.Field): The field that generated the original schema and + who this post-processor belongs to. + validator (marshmallow.validate.Regexp): The validator attached to the + passed in field. + parent_schema (marshmallow.Schema): The Schema instance that the field + belongs to. + + Returns: + dict: A, possibly, new JSON Schema that has been post processed and + altered. + """ + if not isinstance(field, fields.String): + return schema + + if validator.regex and getattr(validator.regex, 'pattern', None): + schema['pattern'] = validator.regex.pattern + return schema diff --git a/tests/__init__.py b/tests/__init__.py index 5450ebc..2ec4615 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -46,6 +46,7 @@ class UserSchema(Schema): validate=validate.Length(min=1, max=3)) github = fields.Nested(GithubProfile) const = fields.String(validate=validate.Length(equal=50)) + hex_number = fields.String(validate=validate.Regexp('^[a-fA-F0-9]+$')) class BaseTest(unittest.TestCase): diff --git a/tests/test_dump.py b/tests/test_dump.py index a7a54b3..38d2ac1 100644 --- a/tests/test_dump.py +++ b/tests/test_dump.py @@ -486,3 +486,11 @@ class TestSchema(Schema): json_schema = JSONSchema() dumped = dot_data_backwards_compatable(json_schema.dump(schema)) assert 'required' not in dumped['definitions']['TestSchema'] + + +def test_regexp_validator(): + schema = UserSchema() + json_schema = JSONSchema() + dumped = dot_data_backwards_compatable(json_schema.dump(schema)) + _validate_schema(dumped) + assert dumped['properties']['hex_number']['pattern'] == '^[a-fA-F0-9]+$'