diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 620bd07344..1bb77d1104 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -603,8 +603,9 @@ async def _request( except self._RetryRequest: continue + @staticmethod @typing.final - def _stringify_http_message(self, headers: data_binding.Headers, body: typing.Any) -> str: + def _stringify_http_message(headers: data_binding.Headers, body: typing.Any) -> str: string = "\n".join( f" {name}: {value}" if name != _AUTHORIZATION_HEADER else f" {name}: **REDACTED TOKEN**" for name, value in headers.items() diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 4d534f4989..f529d7331d 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -117,8 +117,20 @@ class StubRestClient: return StubRestClient() @pytest.fixture() - def rest_provider(self, rest_client): - return rest._RESTProvider(lambda: mock.Mock(), None, lambda: mock.Mock(), lambda: rest_client) + def cache(self): + return mock.Mock() + + @pytest.fixture() + def executor(self): + return mock.Mock() + + @pytest.fixture() + def entity_factory(self): + return mock.Mock() + + @pytest.fixture() + def rest_provider(self, rest_client, cache, executor, entity_factory): + return rest._RESTProvider(lambda: entity_factory, executor, lambda: cache, lambda: rest_client) def test_rest_property(self, rest_provider, rest_client): assert rest_provider.rest == rest_client @@ -129,6 +141,18 @@ def test_http_settings_property(self, rest_provider, rest_client): def test_proxy_settings_property(self, rest_provider, rest_client): assert rest_provider.proxy_settings == rest_client.proxy_settings + def test_entity_factory_property(self, rest_provider, entity_factory): + assert rest_provider.entity_factory == entity_factory + + def test_cache_property(self, rest_provider, cache): + assert rest_provider.cache == cache + + def test_executor_property(self, rest_provider, executor): + assert rest_provider.executor == executor + + def test_me_property(self, rest_provider, cache): + assert rest_provider.me == cache.get_me() + ########### # RESTApp # @@ -148,6 +172,24 @@ def rest_app(): class TestRESTApp: + def test__init__when_connector_factory_is_None(self): + http_settings = object() + + with mock.patch.object(rest, "BasicLazyCachedTCPConnectorFactory") as factory: + rest_app = rest.RESTApp( + connector_factory=None, + connector_owner=False, + executor=None, + http_settings=http_settings, + proxy_settings=None, + url=None, + ) + + factory.assert_called_once_with(http_settings) + + assert rest_app._connector_factory is factory() + assert rest_app._connector_owner is True + def test_executor_property(self, rest_app): mock_executor = object() rest_app._executor = mock_executor @@ -166,10 +208,9 @@ def test_proxy_settings(self, rest_app): def test_acquire(self, rest_app): mock_event_loop = object() rest_app._event_loop = mock_event_loop - mock_entity_factory = object() stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl", return_value=mock_entity_factory)) + _entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl")) mock_client = stack.enter_context(mock.patch.object(rest, rest.RESTClientImpl.__qualname__)) stack.enter_context(mock.patch.object(asyncio, "get_running_loop", return_value=mock_event_loop)) @@ -179,7 +220,7 @@ def test_acquire(self, rest_app): mock_client.assert_called_once_with( connector_factory=rest_app._connector_factory, connector_owner=rest_app._connector_owner, - entity_factory=mock_entity_factory, + entity_factory=_entity_factory(), executor=rest_app._executor, http_settings=rest_app._http_settings, proxy_settings=rest_app._proxy_settings, @@ -188,6 +229,26 @@ def test_acquire(self, rest_app): rest_url=rest_app._url, ) + def test_acquire_when_even_loop_not_set(self, rest_app): + mock_event_loop = object() + + stack = contextlib.ExitStack() + _entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl")) + stack.enter_context(mock.patch.object(rest, rest.RESTClientImpl.__qualname__)) + stack.enter_context(mock.patch.object(asyncio, "get_running_loop", return_value=mock_event_loop)) + + with stack: + rest_app.acquire(token="token", token_type="Type") + + assert rest_app._event_loop is mock_event_loop + + # This is just to test the lambdas so it counts towards coverage + assert _entity_factory.call_count == 1 + factory = _entity_factory.call_args_list[0][0][0] + factory.entity_factory + factory.cache + factory.rest + def test_acquire_when__event_loop_and_loop_do_not_equal(self, rest_app): rest_app._event_loop = object() with mock.patch.object(asyncio, "get_running_loop"): @@ -479,6 +540,19 @@ def test__generate_allowed_mentions(self, rest_client, function_input, expected_ def test__transform_emoji_to_url_format(self, rest_client, emoji, expected_return): assert rest_client._transform_emoji_to_url_format(emoji) == expected_return + def test__stringify_http_message_when_body_is_None(self, rest_client): + headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} + expected_return = " HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**" + assert rest_client._stringify_http_message(headers, None) == expected_return + + @pytest.mark.parametrize(("body", "expected"), [(bytes("hello :)", "ascii"), "hello :)"), (123, "123")]) + def test__stringify_http_message_when_body_is_not_None(self, rest_client, body, expected): + headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} + expected_return = ( + f" HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**\n\n {expected}" + ) + assert rest_client._stringify_http_message(headers, body) == expected_return + ####################### # Non-async endpoints # ####################### @@ -743,6 +817,14 @@ class ExitException(Exception): return ExitException + async def test___aenter__and__aexit__(self, rest_client): + with mock.patch.object(rest_client, "close") as close: + async with rest_client as client: + assert client is rest_client + close.assert_not_called() + + close.assert_awaited_once_with() + @hikari_test_helpers.timeout() async def test__request_when_buckets_not_started(self, rest_client, exit_exception): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) @@ -770,12 +852,16 @@ async def test__request_when__token_is_None(self, rest_client, exit_exception): rest_client.buckets.is_started = True rest_client._token = None rest_client._acquire_client_session = mock.Mock(return_value=mock_session) - with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): - with pytest.raises(exit_exception): - await rest_client._request(route) + rest_client._stringify_http_message = mock.Mock() + with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger: + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): + with pytest.raises(exit_exception): + await rest_client._request(route) - _, kwargs = mock_session.request.call_args_list[0] - assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] + _, kwargs = mock_session.request.call_args_list[0] + assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] + + assert logger.log.call_count == 0 @hikari_test_helpers.timeout() async def test__request_when__token_is_not_None(self, rest_client, exit_exception): @@ -784,13 +870,17 @@ async def test__request_when__token_is_not_None(self, rest_client, exit_exceptio rest_client.buckets.is_started = True rest_client._token = "token" rest_client._acquire_client_session = mock.Mock(return_value=mock_session) - with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): - with pytest.raises(exit_exception): - await rest_client._request(route) + rest_client._stringify_http_message = mock.Mock() + with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger: + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): + with pytest.raises(exit_exception): + await rest_client._request(route) _, kwargs = mock_session.request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "token" + assert logger.log.call_count == 0 + @hikari_test_helpers.timeout() async def test__request_when_no_auth_passed(self, rest_client, exit_exception): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) @@ -798,12 +888,16 @@ async def test__request_when_no_auth_passed(self, rest_client, exit_exception): rest_client.buckets.is_started = True rest_client._token = "token" rest_client._acquire_client_session = mock.Mock(return_value=mock_session) - with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): - with pytest.raises(exit_exception): - await rest_client._request(route, no_auth=True) + rest_client._stringify_http_message = mock.Mock() + with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=0))) as logger: + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): + with pytest.raises(exit_exception): + await rest_client._request(route, no_auth=True) - _, kwargs = mock_session.request.call_args_list[0] - assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] + _, kwargs = mock_session.request.call_args_list[0] + assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] + + assert logger.log.call_count == 1 @hikari_test_helpers.timeout() async def test__request_when_response_is_NO_CONTENT(self, rest_client): @@ -817,8 +911,12 @@ class StubResponse: rest_client._debug = False rest_client._parse_ratelimits = mock.AsyncMock() rest_client._acquire_client_session = mock.Mock(return_value=mock_session) - with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): - assert (await rest_client._request(route)) is None + rest_client._stringify_http_message = mock.Mock() + with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger: + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): + assert (await rest_client._request(route)) is None + + assert logger.log.call_count == 0 @hikari_test_helpers.timeout() async def test__request_when_response_is_APPLICATION_JSON(self, rest_client): @@ -837,8 +935,12 @@ async def read(self): rest_client._debug = True rest_client._parse_ratelimits = mock.AsyncMock() rest_client._acquire_client_session = mock.Mock(return_value=mock_session) - with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): - assert (await rest_client._request(route)) == {"something": None} + rest_client._stringify_http_message = mock.Mock() + with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=0))) as logger: + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): + assert (await rest_client._request(route)) == {"something": None} + + assert logger.log.call_count == 2 @hikari_test_helpers.timeout() async def test__request_when_response_is_not_JSON(self, rest_client): @@ -854,9 +956,13 @@ class StubResponse: rest_client._debug = False rest_client._parse_ratelimits = mock.AsyncMock() rest_client._acquire_client_session = mock.Mock(return_value=mock_session) - with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): - with pytest.raises(errors.HTTPError): - await rest_client._request(route) + rest_client._stringify_http_message = mock.Mock() + with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger: + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): + with pytest.raises(errors.HTTPError): + await rest_client._request(route) + + assert logger.log.call_count == 0 @hikari_test_helpers.timeout() async def test__request_when_response_is_not_between_200_and_300(self, rest_client, exit_exception): @@ -872,9 +978,13 @@ class StubResponse: rest_client._parse_ratelimits = mock.AsyncMock() rest_client._handle_error_response = mock.AsyncMock(side_effect=exit_exception) rest_client._acquire_client_session = mock.Mock(return_value=mock_session) - with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): - with pytest.raises(exit_exception): - await rest_client._request(route) + rest_client._stringify_http_message = mock.Mock() + with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger: + with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()): + with pytest.raises(exit_exception): + await rest_client._request(route) + + assert logger.log.call_count == 0 @hikari_test_helpers.timeout() async def test__request_when_response__RetryRequest_gets_handled(self, rest_client, exit_exception): diff --git a/tests/hikari/test_channels.py b/tests/hikari/test_channels.py index 64be17b27b..df597f25c6 100644 --- a/tests/hikari/test_channels.py +++ b/tests/hikari/test_channels.py @@ -186,6 +186,7 @@ def test_format_icon_when_hash_is_None(self, model): assert model.format_icon() is None +@pytest.mark.asyncio class TestTextChannel: @pytest.fixture() def model(self, mock_app): @@ -196,7 +197,6 @@ def model(self, mock_app): type=channels.ChannelType.GUILD_TEXT, ) - @pytest.mark.asyncio async def test_history(self, model): model.app.rest.fetch_messages = mock.AsyncMock() @@ -213,7 +213,6 @@ async def test_history(self, model): around=datetime.datetime(2020, 4, 1, 0, 30, 0), ) - @pytest.mark.asyncio async def test_send(self, model): model.app.rest.create_message = mock.AsyncMock() mock_attachment = object() @@ -245,6 +244,13 @@ async def test_send(self, model): role_mentions=[789, 567], ) + def test_trigger_typing(self, model): + model.app.rest.trigger_typing = mock.Mock() + + model.trigger_typing() + + model.app.rest.trigger_typing.assert_called_once_with(12345679) + class TestGuildChannel: @pytest.fixture() diff --git a/tests/hikari/utilities/test_collections.py b/tests/hikari/utilities/test_collections.py index a76c6e8cb6..ca7f0b8ada 100644 --- a/tests/hikari/utilities/test_collections.py +++ b/tests/hikari/utilities/test_collections.py @@ -210,22 +210,19 @@ def test___setitem___removes_old_entry_instead_of_replacing(self): mock_map["ok"] = "foo" assert list(mock_map.items())[2] == ("ok", "foo") - # TODO: fix this so that it is not flaky. - # https://travis-ci.org/github/nekokatt/hikari/jobs/724494888#L797 - @pytest.mark.skip("flaky test, might fail on Windows runners.") @pytest.mark.asyncio async def test___setitem___garbage_collection(self): mock_map = collections.TimedCacheMap( expiry=datetime.timedelta(seconds=hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 3) ) - mock_map.update({"OK": "no", "blam": "booga"}) + mock_map["OK"] = "no" await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 2) - assert mock_map == {"OK": "no", "blam": "booga"} - mock_map.update({"ayanami": "rei", "owo": "awoo"}) - assert mock_map == {"OK": "no", "blam": "booga", "ayanami": "rei", "owo": "awoo"} + assert mock_map == {"OK": "no"} + mock_map["ayanami"] = "rei" + assert mock_map == {"OK": "no", "ayanami": "rei"} await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 2) - mock_map.update({"nyaa": "qt"}) - assert mock_map == {"ayanami": "rei", "owo": "awoo", "nyaa": "qt"} + mock_map["nyaa"] = "qt" + assert mock_map == {"ayanami": "rei", "nyaa": "qt"} class TestLimitedCapacityCacheMap: