diff --git a/changelog.d/333.misc b/changelog.d/333.misc new file mode 100644 index 00000000..f4477225 --- /dev/null +++ b/changelog.d/333.misc @@ -0,0 +1 @@ +Improve static type checking. diff --git a/mypy.ini b/mypy.ini index 4fd26680..3f0a41c3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,7 @@ [mypy] plugins = mypy_zope:plugin check_untyped_defs = True +disallow_untyped_defs = True show_error_codes = True show_traceback = True mypy_path = stubs @@ -43,3 +44,36 @@ ignore_missing_imports = True [mypy-pywebpush] ignore_missing_imports = True + +[mypy-sygnal.helper.*] +disallow_untyped_defs = False + +[mypy-sygnal.notifications] +disallow_untyped_defs = False + +[mypy-sygnal.http] +disallow_untyped_defs = False + +[mypy-sygnal.sygnal] +disallow_untyped_defs = False + +[mypy-tests.asyncio_test_helpers] +disallow_untyped_defs = False + +[mypy-tests.test_http] +disallow_untyped_defs = False + +[mypy-tests.test_httpproxy_asyncio] +disallow_untyped_defs = False + +[mypy-tests.test_httpproxy_twisted] +disallow_untyped_defs = False + +[mypy-tests.test_pushgateway_api_v1] +disallow_untyped_defs = False + +[mypy-tests.testutils] +disallow_untyped_defs = False + +[mypy-tests.twisted_test_helpers] +disallow_untyped_defs = False diff --git a/sygnal/apnstruncate.py b/sygnal/apnstruncate.py index 511cbd38..9654a416 100644 --- a/sygnal/apnstruncate.py +++ b/sygnal/apnstruncate.py @@ -28,7 +28,7 @@ ] -def json_encode(payload) -> bytes: +def json_encode(payload: Dict[str, Any]) -> bytes: return json.dumps(payload, ensure_ascii=False).encode() @@ -115,7 +115,7 @@ def _choppables_for_aps(aps: Dict[str, Any]) -> List[Choppable]: def _choppable_get( aps: Dict[str, Any], choppable: Choppable, -): +) -> str: if choppable[0] == "alert": return aps["alert"] elif choppable[0] == "alert.body": diff --git a/sygnal/gcmpushkin.py b/sygnal/gcmpushkin.py index f725b51f..c13e7ca6 100644 --- a/sygnal/gcmpushkin.py +++ b/sygnal/gcmpushkin.py @@ -165,7 +165,7 @@ async def create( return cls(name, sygnal, config) async def _perform_http_request( - self, body: Dict, headers: Dict[AnyStr, List[AnyStr]] + self, body: Dict[str, Any], headers: Dict[AnyStr, List[AnyStr]] ) -> Tuple[IResponse, str]: """ Perform an HTTP request to the FCM server with the body and headers @@ -208,7 +208,7 @@ async def _request_dispatch( self, n: Notification, log: NotificationLoggerAdapter, - body: dict, + body: Dict[str, Any], headers: Dict[AnyStr, List[AnyStr]], pushkeys: List[str], span: Span, diff --git a/tests/test_apns.py b/tests/test_apns.py index b7142f36..6a6653e3 100644 --- a/tests/test_apns.py +++ b/tests/test_apns.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict from unittest.mock import MagicMock, patch from aioapns.common import NotificationResult, PushType @@ -56,7 +57,7 @@ class ApnsTestCase(testutils.TestCase): - def setUp(self): + def setUp(self) -> None: self.apns_mock_class = patch("sygnal.apnspushkin.APNs").start() self.apns_mock = MagicMock() self.apns_mock_class.return_value = self.apns_mock @@ -82,7 +83,7 @@ def get_test_pushkin(self, name: str) -> ApnsPushkin: assert isinstance(test_pushkin, ApnsPushkin) return test_pushkin - def config_setup(self, config): + def config_setup(self, config: Dict[str, Any]) -> None: super().config_setup(config) config["apps"][PUSHKIN_ID] = {"type": "apns", "certfile": TEST_CERTFILE_PATH} config["apps"][PUSHKIN_ID_WITH_PUSH_TYPE] = { @@ -91,7 +92,7 @@ def config_setup(self, config): "push_type": "alert", } - def test_payload_truncation(self): + def test_payload_truncation(self) -> None: """ Tests that APNS message bodies will be truncated to fit the limits of APNS. @@ -114,7 +115,7 @@ def test_payload_truncation(self): self.assertLessEqual(len(apnstruncate.json_encode(payload)), 240) - def test_payload_truncation_test_validity(self): + def test_payload_truncation_test_validity(self) -> None: """ This tests that L{test_payload_truncation_success} is a valid test by showing that not limiting the truncation size would result in a @@ -138,7 +139,7 @@ def test_payload_truncation_test_validity(self): self.assertGreater(len(apnstruncate.json_encode(payload)), 200) - def test_expected(self): + def test_expected(self) -> None: """ Tests the expected case: a good response from APNS means we pass on a good response to the homeserver. @@ -177,7 +178,7 @@ def test_expected(self): self.assertEqual({"rejected": []}, resp) - def test_expected_event_id_only_with_default_payload(self): + def test_expected_event_id_only_with_default_payload(self) -> None: """ Tests the expected fallback case: a good response from APNS means we pass on a good response to the homeserver. @@ -214,7 +215,7 @@ def test_expected_event_id_only_with_default_payload(self): self.assertEqual({"rejected": []}, resp) - def test_expected_badge_only_with_default_payload(self): + def test_expected_badge_only_with_default_payload(self) -> None: """ Tests the expected fallback case: a good response from APNS means we pass on a good response to the homeserver. @@ -243,7 +244,7 @@ def test_expected_badge_only_with_default_payload(self): self.assertEqual({"rejected": []}, resp) - def test_expected_full_with_default_payload(self): + def test_expected_full_with_default_payload(self) -> None: """ Tests the expected fallback case: a good response from APNS means we pass on a good response to the homeserver. @@ -285,7 +286,7 @@ def test_expected_full_with_default_payload(self): self.assertEqual({"rejected": []}, resp) - def test_misconfigured_payload_is_rejected(self): + def test_misconfigured_payload_is_rejected(self) -> None: """Test that a malformed default_payload causes pushkey to be rejected""" resp = self._request( @@ -294,7 +295,7 @@ def test_misconfigured_payload_is_rejected(self): self.assertEqual({"rejected": ["badpayload"]}, resp) - def test_rejection(self): + def test_rejection(self) -> None: """ Tests the rejection case: a rejection response from APNS leads to us passing on a rejection to the homeserver. @@ -312,7 +313,7 @@ def test_rejection(self): self.assertEqual(1, method.call_count) self.assertEqual({"rejected": ["spqr"]}, resp) - def test_no_retry_on_4xx(self): + def test_no_retry_on_4xx(self) -> None: """ Test that we don't retry when we get a 4xx error but do not mark as rejected. @@ -330,7 +331,7 @@ def test_no_retry_on_4xx(self): self.assertEqual(1, method.call_count) self.assertEqual(502, resp) - def test_retry_on_5xx(self): + def test_retry_on_5xx(self) -> None: """ Test that we DO retry when we get a 5xx error and do not mark as rejected. @@ -348,7 +349,7 @@ def test_retry_on_5xx(self): self.assertGreater(method.call_count, 1) self.assertEqual(502, resp) - def test_expected_with_push_type(self): + def test_expected_with_push_type(self) -> None: """ Tests the expected case: a good response from APNS means we pass on a good response to the homeserver. diff --git a/tests/test_apnstruncate.py b/tests/test_apnstruncate.py index 8343b200..bc0c0ba8 100644 --- a/tests/test_apnstruncate.py +++ b/tests/test_apnstruncate.py @@ -19,11 +19,12 @@ import string import unittest +from typing import Any, Dict from sygnal.apnstruncate import json_encode, truncate -def simplestring(length, offset=0): +def simplestring(length: int, offset: int = 0) -> str: """ Deterministically generates a string. Args: @@ -41,7 +42,7 @@ def simplestring(length, offset=0): ) -def sillystring(length, offset=0): +def sillystring(length: int, offset: int = 0) -> str: """ Deterministically generates a string Args: @@ -55,7 +56,7 @@ def sillystring(length, offset=0): return "".join([chars[(i + offset) % len(chars)] for i in range(length)]) -def payload_for_aps(aps): +def payload_for_aps(aps: Dict[str, Any]) -> Dict[str, Any]: """ Returns the APNS payload for an 'aps' dictionary. """ @@ -63,7 +64,7 @@ def payload_for_aps(aps): class TruncateTestCase(unittest.TestCase): - def test_dont_truncate(self): + def test_dont_truncate(self) -> None: """ Tests that truncation is not performed if unnecessary. """ @@ -72,7 +73,7 @@ def test_dont_truncate(self): aps = {"alert": txt} self.assertEqual(txt, truncate(payload_for_aps(aps), 256)["aps"]["alert"]) - def test_truncate_alert(self): + def test_truncate_alert(self) -> None: """ Tests that the 'alert' string field will be truncated when needed. """ @@ -83,7 +84,7 @@ def test_truncate_alert(self): txt[:5], truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"] ) - def test_truncate_alert_body(self): + def test_truncate_alert_body(self) -> None: """ Tests that the 'alert' 'body' field will be truncated when needed. """ @@ -95,7 +96,7 @@ def test_truncate_alert_body(self): truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["body"], ) - def test_truncate_loc_arg(self): + def test_truncate_loc_arg(self) -> None: """ Tests that the 'alert' 'loc-args' field will be truncated when needed. (Tests with one loc arg) @@ -108,7 +109,7 @@ def test_truncate_loc_arg(self): truncate(payload_for_aps(aps), overhead + 5)["aps"]["alert"]["loc-args"][0], ) - def test_truncate_loc_args(self): + def test_truncate_loc_args(self) -> None: """ Tests that the 'alert' 'loc-args' field will be truncated when needed. (Tests with two loc args) @@ -130,7 +131,7 @@ def test_truncate_loc_args(self): ], ) - def test_python_unicode_support(self): + def test_python_unicode_support(self) -> None: """ Tests Python's unicode support :- a one character unicode string should have a length of one, even if it's one @@ -146,7 +147,7 @@ def test_python_unicode_support(self): ) self.fail(msg) - def test_truncate_string_with_multibyte(self): + def test_truncate_string_with_multibyte(self) -> None: """ Tests that truncation works as expected on strings containing one multibyte character. @@ -160,7 +161,7 @@ def test_truncate_string_with_multibyte(self): txt[:17], truncate(payload_for_aps(aps), overhead + 20)["aps"]["alert"] ) - def test_truncate_multibyte(self): + def test_truncate_multibyte(self) -> None: """ Tests that truncation works as expected on strings containing only multibyte characters. diff --git a/tests/test_concurrency_limit.py b/tests/test_concurrency_limit.py index eefcf0e8..b4acf90f 100644 --- a/tests/test_concurrency_limit.py +++ b/tests/test_concurrency_limit.py @@ -13,11 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sygnal.notifications import ConcurrencyLimitedPushkin +from typing import TYPE_CHECKING, Any, Dict, List + +from sygnal.notifications import ConcurrencyLimitedPushkin, Device, Notification from sygnal.utils import twisted_sleep from tests.testutils import TestCase +if TYPE_CHECKING: + from sygnal.notifications import NotificationContext + DEVICE_GCM1_EXAMPLE = { "app_id": "com.example.gcm", "pushkey": "spqrg", @@ -36,7 +41,9 @@ class SlowConcurrencyLimitedDummyPushkin(ConcurrencyLimitedPushkin): - async def _dispatch_notification_unlimited(self, n, device, context): + async def dispatch_notification( + self, n: Notification, device: Device, context: "NotificationContext" + ) -> List[str]: """ We will deliver the notification to the mighty nobody and we will take one second to do it, because we are slow! @@ -46,7 +53,7 @@ async def _dispatch_notification_unlimited(self, n, device, context): class ConcurrencyLimitTestCase(TestCase): - def config_setup(self, config): + def config_setup(self, config: Dict[str, Any]) -> None: super().config_setup(config) config["apps"]["com.example.gcm"] = { "type": "tests.test_concurrency_limit.SlowConcurrencyLimitedDummyPushkin", @@ -57,7 +64,7 @@ def config_setup(self, config): "inflight_request_limit": 1, } - def test_passes_under_limit_one(self): + def test_passes_under_limit_one(self) -> None: """ Tests that a push notification succeeds if it is under the limit. """ @@ -65,7 +72,7 @@ def test_passes_under_limit_one(self): self.assertEqual(resp, {"rejected": []}) - def test_passes_under_limit_multiple_no_interfere(self): + def test_passes_under_limit_multiple_no_interfere(self) -> None: """ Tests that 2 push notifications succeed if they are to different pushkins (so do not hit a per-pushkin limit). @@ -76,7 +83,7 @@ def test_passes_under_limit_multiple_no_interfere(self): self.assertEqual(resp, {"rejected": []}) - def test_fails_when_limit_hit(self): + def test_fails_when_limit_hit(self) -> None: """ Tests that 1 of 2 push notifications fail if they are to the same pushkins (so do hit the per-pushkin limit of 1). diff --git a/tests/test_gcm.py b/tests/test_gcm.py index dcb685bf..3a4f61f9 100644 --- a/tests/test_gcm.py +++ b/tests/test_gcm.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Tuple from sygnal.gcmpushkin import GcmPushkin from tests import testutils from tests.testutils import DummyResponse +if TYPE_CHECKING: + from sygnal.sygnal import Sygnal + DEVICE_EXAMPLE = {"app_id": "com.example.gcm", "pushkey": "spqr", "pushkey_ts": 42} DEVICE_EXAMPLE2 = {"app_id": "com.example.gcm", "pushkey": "spqr2", "pushkey_ts": 42} DEVICE_EXAMPLE_WITH_DEFAULT_PAYLOAD = { @@ -57,22 +61,26 @@ class TestGcmPushkin(GcmPushkin): can be preloaded with virtual requests. """ - def __init__(self, name, sygnal, config): + def __init__(self, name: str, sygnal: "Sygnal", config: Dict[str, Any]): super().__init__(name, sygnal, config) - self.preloaded_response = None - self.preloaded_response_payload = None - self.last_request_body = None - self.last_request_headers = None + self.preloaded_response = DummyResponse(0) + self.preloaded_response_payload: Dict[str, Any] = {} + self.last_request_body: Dict[str, Any] = {} + self.last_request_headers: Dict[AnyStr, List[AnyStr]] = {} # type: ignore[valid-type] self.num_requests = 0 - def preload_with_response(self, code, response_payload): + def preload_with_response( + self, code: int, response_payload: Dict[str, Any] + ) -> None: """ Preloads a fake GCM response. """ self.preloaded_response = DummyResponse(code) self.preloaded_response_payload = response_payload - async def _perform_http_request(self, body, headers): + async def _perform_http_request( # type: ignore[override] + self, body: Dict[str, Any], headers: Dict[AnyStr, List[AnyStr]] + ) -> Tuple[DummyResponse, str]: self.last_request_body = body self.last_request_headers = headers self.num_requests += 1 @@ -80,7 +88,7 @@ async def _perform_http_request(self, body, headers): class GcmTestCase(testutils.TestCase): - def config_setup(self, config): + def config_setup(self, config: Dict[str, Any]) -> None: config["apps"]["com.example.gcm"] = { "type": "tests.test_gcm.TestGcmPushkin", "api_key": "kii", @@ -96,7 +104,7 @@ def get_test_pushkin(self, name: str) -> TestGcmPushkin: assert isinstance(pushkin, TestGcmPushkin) return pushkin - def test_expected(self): + def test_expected(self) -> None: """ Tests the expected case: a good response from GCM leads to a good response from Sygnal. @@ -111,7 +119,7 @@ def test_expected(self): self.assertEqual(resp, {"rejected": []}) self.assertEqual(gcm.num_requests, 1) - def test_expected_with_default_payload(self): + def test_expected_with_default_payload(self) -> None: """ Tests the expected case: a good response from GCM leads to a good response from Sygnal. @@ -128,7 +136,7 @@ def test_expected_with_default_payload(self): self.assertEqual(resp, {"rejected": []}) self.assertEqual(gcm.num_requests, 1) - def test_misformed_default_payload_rejected(self): + def test_misformed_default_payload_rejected(self) -> None: """ Tests that a non-dict default_payload is rejected. """ @@ -144,7 +152,7 @@ def test_misformed_default_payload_rejected(self): self.assertEqual(resp, {"rejected": ["badpayload"]}) self.assertEqual(gcm.num_requests, 0) - def test_rejected(self): + def test_rejected(self) -> None: """ Tests the rejected case: a pushkey rejected to GCM leads to Sygnal informing the homeserver of the rejection. @@ -159,7 +167,7 @@ def test_rejected(self): self.assertEqual(resp, {"rejected": ["spqr"]}) self.assertEqual(gcm.num_requests, 1) - def test_batching(self): + def test_batching(self) -> None: """ Tests that multiple GCM devices have their notification delivered to GCM together, instead of being delivered separately. @@ -184,7 +192,7 @@ def test_batching(self): self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"]) self.assertEqual(gcm.num_requests, 1) - def test_batching_individual_failure(self): + def test_batching_individual_failure(self) -> None: """ Tests that multiple GCM devices have their notification delivered to GCM together, instead of being delivered separately, @@ -211,7 +219,7 @@ def test_batching_individual_failure(self): self.assertEqual(gcm.last_request_body["registration_ids"], ["spqr", "spqr2"]) self.assertEqual(gcm.num_requests, 1) - def test_fcm_options(self): + def test_fcm_options(self) -> None: """ Tests that the config option `fcm_options` allows setting a base layer of options to pass to FCM, for example ones that would be needed for iOS. diff --git a/tests/test_proxy_url_parsing.py b/tests/test_proxy_url_parsing.py index 21c70534..03e90b71 100644 --- a/tests/test_proxy_url_parsing.py +++ b/tests/test_proxy_url_parsing.py @@ -18,7 +18,7 @@ class ProxyUrlTestCase(unittest.TestCase): - def test_decompose_http_proxy_url(self): + def test_decompose_http_proxy_url(self) -> None: parts = decompose_http_proxy_url("http://example.org") self.assertEqual(parts, HttpProxyUrl("example.org", 80, None)) @@ -35,7 +35,7 @@ def test_decompose_http_proxy_url(self): parts, HttpProxyUrl("example.org", 8080, ("bob", "secretsquirrel")) ) - def test_decompose_username_only(self): + def test_decompose_username_only(self) -> None: """ We do not support usernames without passwords for now — this tests the current behaviour, though (it ignores the username). @@ -44,7 +44,7 @@ def test_decompose_username_only(self): parts = decompose_http_proxy_url("http://bob@example.org:8080") self.assertEqual(parts, HttpProxyUrl("example.org", 8080, None)) - def test_decompose_http_proxy_url_failure(self): + def test_decompose_http_proxy_url_failure(self) -> None: # test that non-HTTP schemes raise an exception self.assertRaises( RuntimeError, lambda: decompose_http_proxy_url("ftp://example.org")