diff --git a/dirty_equals/__init__.py b/dirty_equals/__init__.py index 7f301c2..20deef4 100644 --- a/dirty_equals/__init__.py +++ b/dirty_equals/__init__.py @@ -18,7 +18,7 @@ IsPositiveFloat, IsPositiveInt, ) -from ._other import FunctionCheck, IsHash, IsIP, IsJson, IsUUID +from ._other import FunctionCheck, IsHash, IsIP, IsJson, IsUrl, IsUUID from ._sequence import Contains, HasLen, IsList, IsListOrTuple, IsTuple from ._strings import IsAnyStr, IsBytes, IsStr from .version import VERSION @@ -70,6 +70,7 @@ 'FunctionCheck', 'IsJson', 'IsUUID', + 'IsUrl', 'IsHash', 'IsIP', # strings diff --git a/dirty_equals/_other.py b/dirty_equals/_other.py index c306e31..ab5e5c7 100644 --- a/dirty_equals/_other.py +++ b/dirty_equals/_other.py @@ -1,7 +1,7 @@ import json import re from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network -from typing import Any, Callable, Optional, TypeVar, Union, overload +from typing import Any, Callable, Optional, Set, TypeVar, Union, overload from uuid import UUID from ._base import DirtyEquals @@ -149,6 +149,122 @@ def equals(self, other: Any) -> bool: return self.func(other) +class IsUrl(DirtyEquals[str]): + """ + A class that checks if a value is a valid URL, optionally checking different URL types and attributes with + [Pydantic](https://pydantic-docs.helpmanual.io/usage/types/#urls). + """ + + allowed_attribute_checks: Set[str] = { + 'scheme', + 'host', + 'host_type', + 'user', + 'password', + 'tld', + 'port', + 'path', + 'query', + 'fragment', + } + + def __init__( + self, + any_url: bool = False, + any_http_url: bool = False, + http_url: bool = False, + file_url: bool = False, + postgres_dsn: bool = False, + ampqp_dsn: bool = False, + redis_dsn: bool = False, + **expected_attributes: Any, + ): + """ + Args: + any_url: any scheme allowed, TLD not required, host required + any_http_url: scheme http or https, TLD not required, host required + http_url: scheme http or https, TLD required, host required, max length 2083 + file_url: scheme file, host not required + postgres_dsn: user info required, TLD not required + ampqp_dsn: schema amqp or amqps, user info not required, TLD not required, host not required + redis_dsn: scheme redis or rediss, user info not required, tld not required, host not required + **expected_attributes: Expected values for url attributes + ```py title="IsUrl" + from dirty_equals import IsUrl + + assert 'https://example.com' == IsUrl + assert 'https://example.com' == IsUrl(tld='com') + assert 'https://example.com' == IsUrl(scheme='https') + assert 'https://example.com' != IsUrl(scheme='http') + assert 'postgres://user:pass@localhost:5432/app' == IsUrl(postgres_dsn=True) + assert 'postgres://user:pass@localhost:5432/app' != IsUrl(http_url=True) + ``` + """ + try: + from pydantic import ( + AmqpDsn, + AnyHttpUrl, + AnyUrl, + FileUrl, + HttpUrl, + PostgresDsn, + RedisDsn, + ValidationError, + parse_obj_as, + ) + + self.AmqpDsn = AmqpDsn + self.AnyHttpUrl = AnyHttpUrl + self.AnyUrl = AnyUrl + self.FileUrl = FileUrl + self.HttpUrl = HttpUrl + self.PostgresDsn = PostgresDsn + self.RedisDsn = RedisDsn + self.parse_obj_as = parse_obj_as + self.ValidationError = ValidationError + except ImportError as e: + raise ImportError('pydantic is not installed, run `pip install dirty-equals[pydantic]`') from e + url_type_mappings = { + self.AnyUrl: any_url, + self.AnyHttpUrl: any_http_url, + self.HttpUrl: http_url, + self.FileUrl: file_url, + self.PostgresDsn: postgres_dsn, + self.AmqpDsn: ampqp_dsn, + self.RedisDsn: redis_dsn, + } + url_types_sum = sum(url_type_mappings.values()) + if url_types_sum > 1: + raise ValueError('You can only check against one Pydantic url type at a time') + for item in expected_attributes: + if item not in self.allowed_attribute_checks: + raise TypeError( + 'IsURL only checks these attributes: scheme, host, host_type, user, password, tld, ' + 'port, path, query, fragment' + ) + self.attribute_checks = expected_attributes + if url_types_sum == 0: + url_type = AnyUrl + else: + url_type = max(url_type_mappings, key=url_type_mappings.get) # type: ignore[arg-type] + self.url_type = url_type + super().__init__(url_type) + + def equals(self, other: Any) -> bool: + + try: + parsed = self.parse_obj_as(self.url_type, other) + except self.ValidationError: + raise ValueError('Invalid URL') + if not self.attribute_checks: + return parsed == other + + for attribute, expected in self.attribute_checks.items(): + if getattr(parsed, attribute) != expected: + return False + return parsed == other + + HashTypes = Literal['md5', 'sha-1', 'sha-256'] diff --git a/docs/types/other.md b/docs/types/other.md index cde8a5f..79359f6 100644 --- a/docs/types/other.md +++ b/docs/types/other.md @@ -12,6 +12,8 @@ ::: dirty_equals.IsOneOf +::: dirty_equals.IsUrl + ::: dirty_equals.IsHash ::: dirty_equals.IsIP diff --git a/pyproject.toml b/pyproject.toml index 6f51aa5..a1a0628 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ 'typing-extensions>=4.0.1;python_version<"3.8"', 'pytz>=2021.3', ] +optional-dependencies = {pydantic = ['pydantic>=1.9.1'] } dynamic = ['version'] [project.urls] diff --git a/requirements/linting.in b/requirements/linting.in index adde7ac..df660ee 100644 --- a/requirements/linting.in +++ b/requirements/linting.in @@ -4,6 +4,7 @@ flake8-quotes isort[colors] mypy pre-commit +pydantic pycodestyle pyflakes types-pytz diff --git a/requirements/linting.txt b/requirements/linting.txt index 3976e70..0909a1d 100644 --- a/requirements/linting.txt +++ b/requirements/linting.txt @@ -48,6 +48,8 @@ pycodestyle==2.9.1 # via # -r requirements/linting.in # flake8 +pydantic==1.10.2 + # via -r requirements/linting.in pyflakes==2.5.0 # via # -r requirements/linting.in @@ -63,7 +65,9 @@ tomli==2.0.1 types-pytz==2022.2.1.0 # via -r requirements/linting.in typing-extensions==4.3.0 - # via mypy + # via + # mypy + # pydantic virtualenv==20.16.4 # via pre-commit diff --git a/requirements/pyproject.txt b/requirements/pyproject.txt index 0ae8f15..1769127 100644 --- a/requirements/pyproject.txt +++ b/requirements/pyproject.txt @@ -2,9 +2,11 @@ # This file is autogenerated by pip-compile with python 3.10 # To update, run: # -# pip-compile --output-file=requirements/pyproject.txt pyproject.toml +# pip-compile --extra=pydantic --output-file=requirements/pyproject.txt pyproject.toml # +pydantic==1.10.2 + # via dirty-equals (pyproject.toml) pytz==2022.2.1 # via dirty-equals (pyproject.toml) typing-extensions==4.3.0 - # via dirty-equals (pyproject.toml) + # via pydantic diff --git a/tests/test_other.py b/tests/test_other.py index 02e7389..6aca284 100644 --- a/tests/test_other.py +++ b/tests/test_other.py @@ -4,7 +4,7 @@ import pytest -from dirty_equals import FunctionCheck, IsHash, IsIP, IsJson, IsUUID +from dirty_equals import FunctionCheck, IsHash, IsIP, IsJson, IsUrl, IsUUID @pytest.mark.parametrize( @@ -237,3 +237,46 @@ def test_hashlib_hashes(hash_func, hash_type): def test_wrong_hash_type(): with pytest.raises(ValueError, match='Hash type must be one of the following values: md5, sha-1, sha-256'): assert '123' == IsHash('ntlm') + + +@pytest.mark.parametrize( + 'other,dirty', + [ + ('https://example.com', IsUrl), + ('https://example.com', IsUrl(scheme='https')), + ('postgres://user:pass@localhost:5432/app', IsUrl(postgres_dsn=True)), + ], +) +def test_is_url_true(other, dirty): + assert other == dirty + + +@pytest.mark.parametrize( + 'other,dirty', + [ + ('https://example.com', IsUrl(postgres_dsn=True)), + ('https://example.com', IsUrl(scheme='http')), + ('definitely not a url', IsUrl), + (42, IsUrl), + ('https://anotherexample.com', IsUrl(postgres_dsn=True)), + ], +) +def test_is_url_false(other, dirty): + assert other != dirty + + +def test_is_url_invalid_kwargs(): + with pytest.raises( + TypeError, + match='IsURL only checks these attributes: scheme, host, host_type, user, password, tld, port, path, query, ' + 'fragment', + ): + IsUrl(https=True) + + +def test_is_url_too_many_url_types(): + with pytest.raises( + ValueError, + match='You can only check against one Pydantic url type at a time', + ): + assert 'https://example.com' == IsUrl(any_url=True, http_url=True, postgres_dsn=True)