diff --git a/dev-requirements.txt b/dev-requirements.txt index 2402a03595..69ac9fbf1e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,7 +7,7 @@ mock==4.0.3 # Py.test stuff. pytest==7.0.0 -pytest-asyncio==0.17.2 +pytest-asyncio==0.18.0 pytest-cov==3.0.0 pytest-randomly==3.11.0 diff --git a/pyproject.toml b/pyproject.toml index c7b24523a2..c421d77e1d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,7 +78,7 @@ norecursedirs = [ "public", "ci", ] -filterwarnings = ["ignore:.*\"@coroutine\" decorator is deprecated.*:DeprecationWarning"] +filterwarnings = ["ignore:.*assertions not in test modules or plugins will be ignored because assert statements are not executed by the underlying Python interpreter.*:pytest.PytestConfigWarning"] [tool.towncrier] package = "hikari" diff --git a/tests/hikari/events/test_channel_events.py b/tests/hikari/events/test_channel_events.py index 2004e7ab81..9eb7b4e8aa 100644 --- a/tests/hikari/events/test_channel_events.py +++ b/tests/hikari/events/test_channel_events.py @@ -175,7 +175,6 @@ async def test_fetch_invite(self, event): event.app.rest.fetch_invite.assert_awaited_once_with("Jx4cNGG") -@pytest.mark.asyncio() class TestInviteCreateEvent: @pytest.fixture() def event(self): @@ -184,14 +183,17 @@ def event(self): def test_app_property(self, event): assert event.app is event.invite.app + @pytest.mark.asyncio() async def test_channel_id_property(self, event): event.invite.channel_id = 123 assert event.channel_id == 123 + @pytest.mark.asyncio() async def test_guild_id_property(self, event): event.invite.guild_id = 123 assert event.guild_id == 123 + @pytest.mark.asyncio() async def test_code_property(self, event): event.invite.code = "Jx4cNGG" assert event.code == "Jx4cNGG" diff --git a/tests/hikari/events/test_typing_events.py b/tests/hikari/events/test_typing_events.py index 4abb933649..8bae1b534b 100644 --- a/tests/hikari/events/test_typing_events.py +++ b/tests/hikari/events/test_typing_events.py @@ -28,7 +28,6 @@ from tests.hikari import hikari_test_helpers -@pytest.mark.asyncio() class TestTypingEvent: @pytest.fixture() def event(self): @@ -42,7 +41,7 @@ def event(self): return cls() - async def test_get_user_when_no_cache(self, event): + def test_get_user_when_no_cache(self, event): event = hikari_test_helpers.mock_class_namespace(typing_events.TypingEvent, app=None)() assert event.get_user() is None @@ -50,14 +49,13 @@ async def test_get_user_when_no_cache(self, event): def test_get_user(self, event): assert event.get_user() is event.app.cache.get_user.return_value - async def test_trigger_typing(self, event): + def test_trigger_typing(self, event): event.app.rest.trigger_typing = mock.Mock() result = event.trigger_typing() event.app.rest.trigger_typing.assert_called_once_with(123) assert result is event.app.rest.trigger_typing.return_value -@pytest.mark.asyncio() class TestGuildTypingEvent: @pytest.fixture() def event(self): @@ -87,6 +85,7 @@ def test_get_channel(self, event, guild_channel_impl): assert result is event.app.cache.get_guild_channel.return_value event.app.cache.get_guild_channel.assert_called_once_with(123) + @pytest.mark.asyncio() async def test_get_guild_when_no_cache(self): event = hikari_test_helpers.mock_class_namespace(typing_events.GuildTypingEvent, app=None, init_=False)() @@ -111,6 +110,7 @@ def test_user_id(self, event): assert event.user_id == event.member.id assert event.user_id == 456 + @pytest.mark.asyncio() @pytest.mark.parametrize("guild_channel_impl", [channels.GuildNewsChannel, channels.GuildTextChannel]) async def test_fetch_channel(self, event, guild_channel_impl): event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=guild_channel_impl)) @@ -118,16 +118,19 @@ async def test_fetch_channel(self, event, guild_channel_impl): event.app.rest.fetch_channel.assert_awaited_once_with(123) + @pytest.mark.asyncio() async def test_fetch_guild(self, event): await event.fetch_guild() event.app.rest.fetch_guild.assert_awaited_once_with(789) + @pytest.mark.asyncio() async def test_fetch_guild_preview(self, event): await event.fetch_guild_preview() event.app.rest.fetch_guild_preview.assert_awaited_once_with(789) + @pytest.mark.asyncio() async def test_fetch_member(self, event): await event.fetch_member() diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 8d1fc61f8a..a07c1d35b7 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -107,7 +107,6 @@ def test_executor_property(self, rest_provider, executor): ############################# -@pytest.mark.asyncio() class TestClientCredentialsStrategy: @pytest.fixture() def mock_token(self): @@ -134,6 +133,7 @@ def test_token_type_property(self): assert token.token_type is applications.TokenType.BEARER + @pytest.mark.asyncio() async def test_acquire_on_new_instance(self, mock_token): mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(return_value=mock_token)) @@ -145,6 +145,7 @@ async def test_acquire_on_new_instance(self, mock_token): client=54123123, client_secret="123123123", scopes=("applications.commands.update", "identify") ) + @pytest.mark.asyncio() async def test_acquire_handles_out_of_date_token(self, mock_token): mock_old_token = mock.Mock( applications.PartialOAuth2Token, @@ -167,6 +168,7 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): assert new_token != token assert new_token == "Bearer okokok.fofofo.ddd" # noqa S105: Possible Hardcoded password + @pytest.mark.asyncio() async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token): lock = asyncio.Lock() mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(side_effect=[mock_token])) @@ -190,6 +192,7 @@ async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, moc "Bearer okokok.fofofo.ddd", ] + @pytest.mark.asyncio() async def test_acquire_after_invalidation(self, mock_token): mock_old_token = mock.Mock( applications.PartialOAuth2Token, @@ -212,6 +215,7 @@ async def test_acquire_after_invalidation(self, mock_token): assert new_token != token assert new_token == "Bearer okokok.fofofo.ddd" # noqa S105: Possible Hardcoded password + @pytest.mark.asyncio() async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self): class MockLock: def __init__(self, strategy): @@ -236,6 +240,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): mock_rest.authorize_client_credentials_token.assert_not_called() + @pytest.mark.asyncio() async def test_acquire_caches_client_http_response_error(self): mock_rest = mock.AsyncMock() error = errors.ClientHTTPResponseError( @@ -1062,166 +1067,478 @@ def test_build_action_row(self, rest_client): action_row_builder.assert_called_once_with() + def test__build_message_payload_with_undefined_args(self, rest_client): + with mock.patch.object( + mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} + ) as generate_allowed_mentions: + body, form = rest_client._build_message_payload() -@pytest.mark.asyncio() -class TestRESTClientImplAsync: - @pytest.fixture() - def exit_exception(self): - class ExitException(Exception): - ... + assert body == {"allowed_mentions": {"allowed_mentions": 1}} + assert form is None - return ExitException + generate_allowed_mentions.assert_called_once_with( + undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED + ) - async def test___aenter__and__aexit__(self, rest_client): - rest_client.close = mock.AsyncMock() - rest_client.start = mock.Mock() + @pytest.mark.parametrize("args", [("embeds", "components"), ("embed", "component")]) + def test__build_message_payload_with_None_args(self, rest_client, args): + kwargs = {} + for arg in args: + kwargs[arg] = None - async with rest_client as client: - assert client is rest_client - rest_client.start.assert_called_once() - rest_client.close.assert_not_called() + with mock.patch.object( + mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} + ) as generate_allowed_mentions: + body, form = rest_client._build_message_payload(**kwargs) - rest_client.close.assert_awaited_once_with() + assert body == {"embeds": [], "components": [], "allowed_mentions": {"allowed_mentions": 1}} + assert form is None - @hikari_test_helpers.timeout() - async def test__request_builds_form_when_passed(self, rest_client, exit_exception, live_attributes): - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session - rest_client._token = None - mock_form = mock.AsyncMock() - mock_stack = mock.AsyncMock() - mock_stack.__aenter__ = mock_stack + generate_allowed_mentions.assert_called_once_with( + undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED + ) - with mock.patch.object(contextlib, "AsyncExitStack", return_value=mock_stack) as exit_stack: - with pytest.raises(exit_exception): - await rest_client._request(route, form_builder=mock_form) + def test__build_message_payload_with_edit_and_all_mentions_undefined(self, rest_client): + with mock.patch.object(mentions, "generate_allowed_mentions") as generate_allowed_mentions: + body, form = rest_client._build_message_payload(edit=True) - _, kwargs = mock_session.request.call_args_list[0] - mock_form.build.assert_awaited_once_with(exit_stack.return_value) - assert kwargs["data"] is mock_form.build.return_value - assert live_attributes.still_alive.call_count == 3 + assert body == {} + assert form is None - @hikari_test_helpers.timeout() - async def test__request_url_encodes_reason_header(self, rest_client, exit_exception, live_attributes): - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session + generate_allowed_mentions.assert_not_called() - with pytest.raises(exit_exception): - await rest_client._request(route, reason="光のenergyが 大地に降りそそぐ") + def test__build_message_payload_embed_content_syntactic_sugar(self, rest_client): + embed = mock.Mock(embeds.Embed) - _, kwargs = mock_session.request.call_args_list[0] - assert kwargs["headers"][rest._X_AUDIT_LOG_REASON_HEADER] == ( - "%E5%85%89%E3%81%AEenergy%E3%81%8C%E3%80%80%E5%A4%" - "A7%E5%9C%B0%E3%81%AB%E9%99%8D%E3%82%8A%E3%81%9D%E3%81%9D%E3%81%90" + stack = contextlib.ExitStack() + generate_allowed_mentions = stack.enter_context( + mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) ) + rest_client._entity_factory.serialize_embed.return_value = ({"embed": 1}, []) - @hikari_test_helpers.timeout() - async def test__request_with_strategy_token(self, rest_client, exit_exception, live_attributes): - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session - rest_client._token = mock.Mock(rest_api.TokenStrategy, acquire=mock.AsyncMock(return_value="Bearer ok.ok.ok")) + with stack: + body, form = rest_client._build_message_payload(content=embed) - with pytest.raises(exit_exception): - await rest_client._request(route) + # Returned + assert body == {"embeds": [{"embed": 1}], "allowed_mentions": {"allowed_mentions": 1}} + assert form is None - _, kwargs = mock_session.request.call_args_list[0] - assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - assert live_attributes.still_alive.call_count == 3 + # Embeds + rest_client._entity_factory.serialize_embed.assert_called_once_with(embed) - @hikari_test_helpers.timeout() - async def test__request_retries_strategy_once(self, rest_client, exit_exception, live_attributes): - class StubResponse: - status = http.HTTPStatus.UNAUTHORIZED - content_type = rest._APPLICATION_JSON - reason = "cause why not" - headers = {"HEADER": "value", "HEADER": "value"} + # Generate allowed mentions + generate_allowed_mentions.assert_called_once_with( + undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED + ) - async def read(self): - return '{"something": null}' + def test__build_message_payload_attachment_content_syntactic_sugar(self, rest_client): + attachment = mock.Mock(files.Resource) + resource_attachment = object() - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock( - request=hikari_test_helpers.CopyingAsyncMock(side_effect=[StubResponse(), exit_exception]) + stack = contextlib.ExitStack() + ensure_resource = stack.enter_context( + mock.patch.object(files, "ensure_resource", return_value=resource_attachment) ) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session - rest_client._token = mock.Mock( - rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) + generate_allowed_mentions = stack.enter_context( + mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) ) + url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - with pytest.raises(exit_exception): - await rest_client._request(route) + with stack: + body, form = rest_client._build_message_payload(content=attachment) - _, kwargs = mock_session.request.call_args_list[0] - assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - _, kwargs = mock_session.request.call_args_list[1] - assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" - assert live_attributes.still_alive.call_count == 6 + # Returned + assert body == {"allowed_mentions": {"allowed_mentions": 1}} + assert form is url_encoded_form.return_value - @hikari_test_helpers.timeout() - async def test__request_raises_after_re_auth_attempt(self, rest_client, exit_exception, live_attributes): - class StubResponse: - status = http.HTTPStatus.UNAUTHORIZED - content_type = rest._APPLICATION_JSON - reason = "cause why not" - headers = {"HEADER": "value", "HEADER": "value"} - real_url = "okokokok" + # Attachments + ensure_resource.assert_called_once_with(attachment) - async def read(self): - return '{"something": null}' + # Generate allowed mentions + generate_allowed_mentions.assert_called_once_with( + undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED + ) - async def json(self): - return {"something": None} + # Form builder + url_encoded_form.assert_called_once_with(executor=rest_client._executor) + url_encoded_form.return_value.add_resource.assert_called_once_with("file0", resource_attachment) - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock( - request=hikari_test_helpers.CopyingAsyncMock(side_effect=[StubResponse(), StubResponse(), StubResponse()]) + def test__build_message_payload_with_singular_args(self, rest_client): + attachment = object() + resource_attachment = object() + component = mock.Mock(build=mock.Mock(return_value={"component": 1})) + embed = object() + embed_attachment = object() + mentions_everyone = object() + mentions_reply = object() + user_mentions = object() + role_mentions = object() + + stack = contextlib.ExitStack() + ensure_resource = stack.enter_context( + mock.patch.object(files, "ensure_resource", return_value=resource_attachment) ) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session - rest_client._token = mock.Mock( - rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) + generate_allowed_mentions = stack.enter_context( + mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) ) + url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) + rest_client._entity_factory.serialize_embed.return_value = ({"embed": 1}, [embed_attachment]) - with pytest.raises(errors.UnauthorizedError): - await rest_client._request(route) - - _, kwargs = mock_session.request.call_args_list[0] - assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - _, kwargs = mock_session.request.call_args_list[1] - assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" - assert live_attributes.still_alive.call_count == 6 + with stack: + body, form = rest_client._build_message_payload( + content=987654321, + attachment=attachment, + component=component, + embed=embed, + replace_attachments=True, + flags=120, + tts=True, + mentions_everyone=mentions_everyone, + mentions_reply=mentions_reply, + user_mentions=user_mentions, + role_mentions=role_mentions, + ) - @hikari_test_helpers.timeout() - async def test__request_when__token_is_None(self, rest_client, exit_exception, live_attributes): - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session - rest_client._token = None + # Returned + assert body == { + "content": "987654321", + "tts": True, + "flags": 120, + "embeds": [{"embed": 1}], + "components": [{"component": 1}], + "attachments": None, + "allowed_mentions": {"allowed_mentions": 1}, + } + assert form is url_encoded_form.return_value - with pytest.raises(exit_exception): - await rest_client._request(route) + # Attachments + ensure_resource.assert_called_once_with(attachment) - _, kwargs = mock_session.request.call_args_list[0] - assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] + # Embeds + rest_client._entity_factory.serialize_embed.assert_called_once_with(embed) - @hikari_test_helpers.timeout() - async def test__request_when__token_is_not_None(self, rest_client, exit_exception, live_attributes): - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session - rest_client._token = "token" + # Components + component.build.assert_called_once_with() - with pytest.raises(exit_exception): - await rest_client._request(route) + # Generate allowed mentions + generate_allowed_mentions.assert_called_once_with( + mentions_everyone, mentions_reply, user_mentions, role_mentions + ) + + # Form builder + url_encoded_form.assert_called_once_with(executor=rest_client._executor) + assert url_encoded_form.return_value.add_resource.call_count == 2 + url_encoded_form.return_value.add_resource.assert_has_calls( + [mock.call("file0", resource_attachment), mock.call("file1", embed_attachment)] + ) + + def test__build_message_payload_with_plural_args(self, rest_client): + attachment1 = object() + attachment2 = object() + resource_attachment1 = object() + resource_attachment2 = object() + component1 = mock.Mock(build=mock.Mock(return_value={"component": 1})) + component2 = mock.Mock(build=mock.Mock(return_value={"component": 2})) + embed1 = object() + embed2 = object() + embed_attachment1 = object() + embed_attachment2 = object() + embed_attachment3 = object() + embed_attachment4 = object() + mentions_everyone = object() + mentions_reply = object() + user_mentions = object() + role_mentions = object() + + stack = contextlib.ExitStack() + ensure_resource = stack.enter_context( + mock.patch.object(files, "ensure_resource", side_effect=[resource_attachment1, resource_attachment2]) + ) + generate_allowed_mentions = stack.enter_context( + mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) + ) + url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) + rest_client._entity_factory.serialize_embed.side_effect = [ + ({"embed": 1}, [embed_attachment1, embed_attachment2]), + ({"embed": 2}, [embed_attachment3, embed_attachment4]), + ] + + with stack: + body, form = rest_client._build_message_payload( + content=987654321, + attachments=[attachment1, attachment2], + components=[component1, component2], + embeds=[embed1, embed2], + replace_attachments=True, + flags=120, + tts=True, + mentions_everyone=mentions_everyone, + mentions_reply=mentions_reply, + user_mentions=user_mentions, + role_mentions=role_mentions, + ) + + # Returned + assert body == { + "content": "987654321", + "tts": True, + "flags": 120, + "embeds": [{"embed": 1}, {"embed": 2}], + "components": [{"component": 1}, {"component": 2}], + "attachments": None, + "allowed_mentions": {"allowed_mentions": 1}, + } + assert form is url_encoded_form.return_value + + # Attachments + assert ensure_resource.call_count == 2 + ensure_resource.assert_has_calls([mock.call(attachment1), mock.call(attachment2)]) + + # Embeds + assert rest_client._entity_factory.serialize_embed.call_count == 2 + rest_client._entity_factory.serialize_embed.assert_has_calls([mock.call(embed1), mock.call(embed2)]) + + # Components + component1.build.assert_called_once_with() + component2.build.assert_called_once_with() + + # Generate allowed mentions + generate_allowed_mentions.assert_called_once_with( + mentions_everyone, mentions_reply, user_mentions, role_mentions + ) + + # Form builder + url_encoded_form.assert_called_once_with(executor=rest_client._executor) + assert url_encoded_form.return_value.add_resource.call_count == 6 + url_encoded_form.return_value.add_resource.assert_has_calls( + [ + mock.call("file0", resource_attachment1), + mock.call("file1", resource_attachment2), + mock.call("file2", embed_attachment1), + mock.call("file3", embed_attachment2), + mock.call("file4", embed_attachment3), + mock.call("file5", embed_attachment4), + ] + ) + + @pytest.mark.parametrize( + ("singular_arg", "plural_arg"), + [("attachment", "attachments"), ("component", "components"), ("embed", "embeds")], + ) + def test__build_message_payload_when_both_single_and_plural_args_passed( + self, rest_client, singular_arg, plural_arg + ): + with pytest.raises( + ValueError, match=rf"You may only specify one of '{singular_arg}' or '{plural_arg}', not both" + ): + rest_client._build_message_payload(**{singular_arg: object(), plural_arg: object()}) + + @pytest.mark.parametrize( + ("singular_arg", "plural_arg"), + [("attachment", "attachments"), ("component", "components"), ("embed", "embeds")], + ) + def test__build_message_payload_when_non_collection_passed_to_plural(self, rest_client, singular_arg, plural_arg): + expected_error_message = ( + f"You passed a non-collection to '{plural_arg}', but this expects a collection. Maybe you meant to use " + f"'{singular_arg}' (singular) instead?" + ) + + with pytest.raises(TypeError, match=re.escape(expected_error_message)): + rest_client._build_message_payload(**{plural_arg: object()}) + + def test_interaction_deferred_builder(self, rest_client): + result = rest_client.interaction_deferred_builder(5) + + assert result.type == 5 + assert isinstance(result, special_endpoints.InteractionDeferredBuilder) + + def test_interaction_autocomplete_builder(self, rest_client): + result = rest_client.interaction_autocomplete_builder( + [ + commands.CommandChoice(name="name", value="value"), + commands.CommandChoice(name="a", value="b"), + ] + ) + + assert result.type == 8 + assert isinstance(result, special_endpoints.InteractionAutocompleteBuilder) + assert len(result.choices) == 2 + + raw = result.build(mock.Mock()) + assert raw["data"] == {"choices": [{"name": "name", "value": "value"}, {"name": "a", "value": "b"}]} + + def test_interaction_autocomplete_builder_with_set_choices(self, rest_client): + result = rest_client.interaction_autocomplete_builder([commands.CommandChoice(name="name", value="value")]) + + result.set_choices([commands.CommandChoice(name="a", value="b")]) + assert result.choices == [commands.CommandChoice(name="a", value="b")] + + def test_interaction_message_builder(self, rest_client): + result = rest_client.interaction_message_builder(4) + + assert result.type == 4 + assert isinstance(result, special_endpoints.InteractionMessageBuilder) + + +@pytest.mark.asyncio() +class TestRESTClientImplAsync: + @pytest.fixture() + def exit_exception(self): + class ExitException(Exception): + ... + + return ExitException + + async def test___aenter__and__aexit__(self, rest_client): + rest_client.close = mock.AsyncMock() + rest_client.start = mock.Mock() + + async with rest_client as client: + assert client is rest_client + rest_client.start.assert_called_once() + rest_client.close.assert_not_called() + + rest_client.close.assert_awaited_once_with() + + @hikari_test_helpers.timeout() + async def test__request_builds_form_when_passed(self, rest_client, exit_exception, live_attributes): + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) + live_attributes.buckets.is_started = True + live_attributes.client_session = mock_session + rest_client._token = None + mock_form = mock.AsyncMock() + mock_stack = mock.AsyncMock() + mock_stack.__aenter__ = mock_stack + + with mock.patch.object(contextlib, "AsyncExitStack", return_value=mock_stack) as exit_stack: + with pytest.raises(exit_exception): + await rest_client._request(route, form_builder=mock_form) + + _, kwargs = mock_session.request.call_args_list[0] + mock_form.build.assert_awaited_once_with(exit_stack.return_value) + assert kwargs["data"] is mock_form.build.return_value + assert live_attributes.still_alive.call_count == 3 + + @hikari_test_helpers.timeout() + async def test__request_url_encodes_reason_header(self, rest_client, exit_exception, live_attributes): + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) + live_attributes.buckets.is_started = True + live_attributes.client_session = mock_session + + with pytest.raises(exit_exception): + await rest_client._request(route, reason="光のenergyが 大地に降りそそぐ") + + _, kwargs = mock_session.request.call_args_list[0] + assert kwargs["headers"][rest._X_AUDIT_LOG_REASON_HEADER] == ( + "%E5%85%89%E3%81%AEenergy%E3%81%8C%E3%80%80%E5%A4%" + "A7%E5%9C%B0%E3%81%AB%E9%99%8D%E3%82%8A%E3%81%9D%E3%81%9D%E3%81%90" + ) + + @hikari_test_helpers.timeout() + async def test__request_with_strategy_token(self, rest_client, exit_exception, live_attributes): + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) + live_attributes.buckets.is_started = True + live_attributes.client_session = mock_session + rest_client._token = mock.Mock(rest_api.TokenStrategy, acquire=mock.AsyncMock(return_value="Bearer ok.ok.ok")) + + with pytest.raises(exit_exception): + await rest_client._request(route) + + _, kwargs = mock_session.request.call_args_list[0] + assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" + assert live_attributes.still_alive.call_count == 3 + + @hikari_test_helpers.timeout() + async def test__request_retries_strategy_once(self, rest_client, exit_exception, live_attributes): + class StubResponse: + status = http.HTTPStatus.UNAUTHORIZED + content_type = rest._APPLICATION_JSON + reason = "cause why not" + headers = {"HEADER": "value", "HEADER": "value"} + + async def read(self): + return '{"something": null}' + + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock( + request=hikari_test_helpers.CopyingAsyncMock(side_effect=[StubResponse(), exit_exception]) + ) + live_attributes.buckets.is_started = True + live_attributes.client_session = mock_session + rest_client._token = mock.Mock( + rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) + ) + + with pytest.raises(exit_exception): + await rest_client._request(route) + + _, kwargs = mock_session.request.call_args_list[0] + assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" + _, kwargs = mock_session.request.call_args_list[1] + assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" + assert live_attributes.still_alive.call_count == 6 + + @hikari_test_helpers.timeout() + async def test__request_raises_after_re_auth_attempt(self, rest_client, exit_exception, live_attributes): + class StubResponse: + status = http.HTTPStatus.UNAUTHORIZED + content_type = rest._APPLICATION_JSON + reason = "cause why not" + headers = {"HEADER": "value", "HEADER": "value"} + real_url = "okokokok" + + async def read(self): + return '{"something": null}' + + async def json(self): + return {"something": None} + + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock( + request=hikari_test_helpers.CopyingAsyncMock(side_effect=[StubResponse(), StubResponse(), StubResponse()]) + ) + live_attributes.buckets.is_started = True + live_attributes.client_session = mock_session + rest_client._token = mock.Mock( + rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) + ) + + with pytest.raises(errors.UnauthorizedError): + await rest_client._request(route) + + _, kwargs = mock_session.request.call_args_list[0] + assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" + _, kwargs = mock_session.request.call_args_list[1] + assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" + assert live_attributes.still_alive.call_count == 6 + + @hikari_test_helpers.timeout() + async def test__request_when__token_is_None(self, rest_client, exit_exception, live_attributes): + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) + live_attributes.buckets.is_started = True + live_attributes.client_session = mock_session + rest_client._token = None + + with pytest.raises(exit_exception): + await rest_client._request(route) + + _, kwargs = mock_session.request.call_args_list[0] + assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] + + @hikari_test_helpers.timeout() + async def test__request_when__token_is_not_None(self, rest_client, exit_exception, live_attributes): + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) + live_attributes.buckets.is_started = True + live_attributes.client_session = mock_session + rest_client._token = "token" + + with pytest.raises(exit_exception): + await rest_client._request(route) _, kwargs = mock_session.request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "token" @@ -1786,399 +2103,120 @@ async def test_edit_permission_overwrites(self, rest_client): name="", is_mentionable=True, permissions=0, - position=0, - bot_id=None, - integration_id=None, - is_premium_subscriber_role=False, - ), - channels.PermissionOverwriteType.ROLE, - ), - ( - channels.PermissionOverwrite(type=channels.PermissionOverwriteType.MEMBER, id=456), - channels.PermissionOverwriteType.MEMBER, - ), - ], - ) - async def test_edit_permission_overwrites_when_target_undefined(self, rest_client, target, expected_type): - expected_route = routes.PUT_CHANNEL_PERMISSIONS.compile(channel=123, overwrite=456) - rest_client._request = mock.AsyncMock() - expected_json = {"type": expected_type} - - await rest_client.edit_permission_overwrites(StubModel(123), target) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) - - async def test_edit_permission_overwrites_when_cant_determine_target_type(self, rest_client): - with pytest.raises(TypeError): - await rest_client.edit_permission_overwrites(StubModel(123), StubModel(123)) - - async def test_delete_permission_overwrite(self, rest_client): - expected_route = routes.DELETE_CHANNEL_PERMISSIONS.compile(channel=123, overwrite=456) - rest_client._request = mock.AsyncMock() - - await rest_client.delete_permission_overwrite(StubModel(123), StubModel(456)) - rest_client._request.assert_awaited_once_with(expected_route) - - async def test_fetch_channel_invites(self, rest_client): - invite1 = StubModel(456) - invite2 = StubModel(789) - expected_route = routes.GET_CHANNEL_INVITES.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_invite_with_metadata = mock.Mock(side_effect=[invite1, invite2]) - - assert await rest_client.fetch_channel_invites(StubModel(123)) == [invite1, invite2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_invite_with_metadata.call_count == 2 - rest_client._entity_factory.deserialize_invite_with_metadata.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) - - async def test_create_invite(self, rest_client): - expected_route = routes.POST_CHANNEL_INVITES.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value={"ID": "NOOOOOOOOPOOOOOOOI!"}) - expected_json = { - "max_age": 60, - "max_uses": 4, - "temporary": True, - "unique": True, - "target_type": invites.TargetType.STREAM, - "target_user_id": "456", - "target_application_id": "789", - } - - result = await rest_client.create_invite( - StubModel(123), - max_age=datetime.timedelta(minutes=1), - max_uses=4, - temporary=True, - unique=True, - target_type=invites.TargetType.STREAM, - target_user=StubModel(456), - target_application=StubModel(789), - reason="cause why not :)", - ) - - assert result is rest_client._entity_factory.deserialize_invite_with_metadata.return_value - rest_client._entity_factory.deserialize_invite_with_metadata.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause why not :)") - - async def test_fetch_pins(self, rest_client): - message1 = StubModel(456) - message2 = StubModel(789) - expected_route = routes.GET_CHANNEL_PINS.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_message = mock.Mock(side_effect=[message1, message2]) - - assert await rest_client.fetch_pins(StubModel(123)) == [message1, message2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_message.call_count == 2 - rest_client._entity_factory.deserialize_message.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) - - async def test_pin_message(self, rest_client): - expected_route = routes.PUT_CHANNEL_PINS.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock() - - await rest_client.pin_message(StubModel(123), StubModel(456)) - rest_client._request.assert_awaited_once_with(expected_route) - - async def test_unpin_message(self, rest_client): - expected_route = routes.DELETE_CHANNEL_PIN.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock() - - await rest_client.unpin_message(StubModel(123), StubModel(456)) - rest_client._request.assert_awaited_once_with(expected_route) - - async def test_fetch_message(self, rest_client): - message_obj = mock.Mock() - expected_route = routes.GET_CHANNEL_MESSAGE.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_message = mock.Mock(return_value=message_obj) - - assert await rest_client.fetch_message(StubModel(123), StubModel(456)) is message_obj - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"}) - - def test__build_message_payload_with_undefined_args(self, rest_client): - with mock.patch.object( - mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} - ) as generate_allowed_mentions: - body, form = rest_client._build_message_payload() - - assert body == {"allowed_mentions": {"allowed_mentions": 1}} - assert form is None - - generate_allowed_mentions.assert_called_once_with( - undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED - ) - - @pytest.mark.parametrize("args", [("embeds", "components"), ("embed", "component")]) - def test__build_message_payload_with_None_args(self, rest_client, args): - kwargs = {} - for arg in args: - kwargs[arg] = None - - with mock.patch.object( - mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} - ) as generate_allowed_mentions: - body, form = rest_client._build_message_payload(**kwargs) - - assert body == {"embeds": [], "components": [], "allowed_mentions": {"allowed_mentions": 1}} - assert form is None - - generate_allowed_mentions.assert_called_once_with( - undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED - ) - - def test__build_message_payload_with_edit_and_all_mentions_undefined(self, rest_client): - with mock.patch.object(mentions, "generate_allowed_mentions") as generate_allowed_mentions: - body, form = rest_client._build_message_payload(edit=True) - - assert body == {} - assert form is None - - generate_allowed_mentions.assert_not_called() - - def test__build_message_payload_embed_content_syntactic_sugar(self, rest_client): - embed = mock.Mock(embeds.Embed) - - stack = contextlib.ExitStack() - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) - ) - rest_client._entity_factory.serialize_embed.return_value = ({"embed": 1}, []) - - with stack: - body, form = rest_client._build_message_payload(content=embed) - - # Returned - assert body == {"embeds": [{"embed": 1}], "allowed_mentions": {"allowed_mentions": 1}} - assert form is None - - # Embeds - rest_client._entity_factory.serialize_embed.assert_called_once_with(embed) - - # Generate allowed mentions - generate_allowed_mentions.assert_called_once_with( - undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED - ) - - def test__build_message_payload_attachment_content_syntactic_sugar(self, rest_client): - attachment = mock.Mock(files.Resource) - resource_attachment = object() - - stack = contextlib.ExitStack() - ensure_resource = stack.enter_context( - mock.patch.object(files, "ensure_resource", return_value=resource_attachment) - ) - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) - ) - url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - - with stack: - body, form = rest_client._build_message_payload(content=attachment) - - # Returned - assert body == {"allowed_mentions": {"allowed_mentions": 1}} - assert form is url_encoded_form.return_value - - # Attachments - ensure_resource.assert_called_once_with(attachment) - - # Generate allowed mentions - generate_allowed_mentions.assert_called_once_with( - undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED - ) - - # Form builder - url_encoded_form.assert_called_once_with(executor=rest_client._executor) - url_encoded_form.return_value.add_resource.assert_called_once_with("file0", resource_attachment) - - def test__build_message_payload_with_singular_args(self, rest_client): - attachment = object() - resource_attachment = object() - component = mock.Mock(build=mock.Mock(return_value={"component": 1})) - embed = object() - embed_attachment = object() - mentions_everyone = object() - mentions_reply = object() - user_mentions = object() - role_mentions = object() - - stack = contextlib.ExitStack() - ensure_resource = stack.enter_context( - mock.patch.object(files, "ensure_resource", return_value=resource_attachment) - ) - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) - ) - url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - rest_client._entity_factory.serialize_embed.return_value = ({"embed": 1}, [embed_attachment]) - - with stack: - body, form = rest_client._build_message_payload( - content=987654321, - attachment=attachment, - component=component, - embed=embed, - replace_attachments=True, - flags=120, - tts=True, - mentions_everyone=mentions_everyone, - mentions_reply=mentions_reply, - user_mentions=user_mentions, - role_mentions=role_mentions, - ) + position=0, + bot_id=None, + integration_id=None, + is_premium_subscriber_role=False, + ), + channels.PermissionOverwriteType.ROLE, + ), + ( + channels.PermissionOverwrite(type=channels.PermissionOverwriteType.MEMBER, id=456), + channels.PermissionOverwriteType.MEMBER, + ), + ], + ) + async def test_edit_permission_overwrites_when_target_undefined(self, rest_client, target, expected_type): + expected_route = routes.PUT_CHANNEL_PERMISSIONS.compile(channel=123, overwrite=456) + rest_client._request = mock.AsyncMock() + expected_json = {"type": expected_type} - # Returned - assert body == { - "content": "987654321", - "tts": True, - "flags": 120, - "embeds": [{"embed": 1}], - "components": [{"component": 1}], - "attachments": None, - "allowed_mentions": {"allowed_mentions": 1}, - } - assert form is url_encoded_form.return_value + await rest_client.edit_permission_overwrites(StubModel(123), target) + rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) - # Attachments - ensure_resource.assert_called_once_with(attachment) + async def test_edit_permission_overwrites_when_cant_determine_target_type(self, rest_client): + with pytest.raises(TypeError): + await rest_client.edit_permission_overwrites(StubModel(123), StubModel(123)) - # Embeds - rest_client._entity_factory.serialize_embed.assert_called_once_with(embed) + async def test_delete_permission_overwrite(self, rest_client): + expected_route = routes.DELETE_CHANNEL_PERMISSIONS.compile(channel=123, overwrite=456) + rest_client._request = mock.AsyncMock() - # Components - component.build.assert_called_once_with() + await rest_client.delete_permission_overwrite(StubModel(123), StubModel(456)) + rest_client._request.assert_awaited_once_with(expected_route) - # Generate allowed mentions - generate_allowed_mentions.assert_called_once_with( - mentions_everyone, mentions_reply, user_mentions, role_mentions - ) + async def test_fetch_channel_invites(self, rest_client): + invite1 = StubModel(456) + invite2 = StubModel(789) + expected_route = routes.GET_CHANNEL_INVITES.compile(channel=123) + rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) + rest_client._entity_factory.deserialize_invite_with_metadata = mock.Mock(side_effect=[invite1, invite2]) - # Form builder - url_encoded_form.assert_called_once_with(executor=rest_client._executor) - assert url_encoded_form.return_value.add_resource.call_count == 2 - url_encoded_form.return_value.add_resource.assert_has_calls( - [mock.call("file0", resource_attachment), mock.call("file1", embed_attachment)] + assert await rest_client.fetch_channel_invites(StubModel(123)) == [invite1, invite2] + rest_client._request.assert_awaited_once_with(expected_route) + assert rest_client._entity_factory.deserialize_invite_with_metadata.call_count == 2 + rest_client._entity_factory.deserialize_invite_with_metadata.assert_has_calls( + [mock.call({"id": "456"}), mock.call({"id": "789"})] ) - def test__build_message_payload_with_plural_args(self, rest_client): - attachment1 = object() - attachment2 = object() - resource_attachment1 = object() - resource_attachment2 = object() - component1 = mock.Mock(build=mock.Mock(return_value={"component": 1})) - component2 = mock.Mock(build=mock.Mock(return_value={"component": 2})) - embed1 = object() - embed2 = object() - embed_attachment1 = object() - embed_attachment2 = object() - embed_attachment3 = object() - embed_attachment4 = object() - mentions_everyone = object() - mentions_reply = object() - user_mentions = object() - role_mentions = object() + async def test_create_invite(self, rest_client): + expected_route = routes.POST_CHANNEL_INVITES.compile(channel=123) + rest_client._request = mock.AsyncMock(return_value={"ID": "NOOOOOOOOPOOOOOOOI!"}) + expected_json = { + "max_age": 60, + "max_uses": 4, + "temporary": True, + "unique": True, + "target_type": invites.TargetType.STREAM, + "target_user_id": "456", + "target_application_id": "789", + } - stack = contextlib.ExitStack() - ensure_resource = stack.enter_context( - mock.patch.object(files, "ensure_resource", side_effect=[resource_attachment1, resource_attachment2]) - ) - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) + result = await rest_client.create_invite( + StubModel(123), + max_age=datetime.timedelta(minutes=1), + max_uses=4, + temporary=True, + unique=True, + target_type=invites.TargetType.STREAM, + target_user=StubModel(456), + target_application=StubModel(789), + reason="cause why not :)", ) - url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - rest_client._entity_factory.serialize_embed.side_effect = [ - ({"embed": 1}, [embed_attachment1, embed_attachment2]), - ({"embed": 2}, [embed_attachment3, embed_attachment4]), - ] - - with stack: - body, form = rest_client._build_message_payload( - content=987654321, - attachments=[attachment1, attachment2], - components=[component1, component2], - embeds=[embed1, embed2], - replace_attachments=True, - flags=120, - tts=True, - mentions_everyone=mentions_everyone, - mentions_reply=mentions_reply, - user_mentions=user_mentions, - role_mentions=role_mentions, - ) - # Returned - assert body == { - "content": "987654321", - "tts": True, - "flags": 120, - "embeds": [{"embed": 1}, {"embed": 2}], - "components": [{"component": 1}, {"component": 2}], - "attachments": None, - "allowed_mentions": {"allowed_mentions": 1}, - } - assert form is url_encoded_form.return_value + assert result is rest_client._entity_factory.deserialize_invite_with_metadata.return_value + rest_client._entity_factory.deserialize_invite_with_metadata.assert_called_once_with( + rest_client._request.return_value + ) + rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause why not :)") - # Attachments - assert ensure_resource.call_count == 2 - ensure_resource.assert_has_calls([mock.call(attachment1), mock.call(attachment2)]) + async def test_fetch_pins(self, rest_client): + message1 = StubModel(456) + message2 = StubModel(789) + expected_route = routes.GET_CHANNEL_PINS.compile(channel=123) + rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) + rest_client._entity_factory.deserialize_message = mock.Mock(side_effect=[message1, message2]) - # Embeds - assert rest_client._entity_factory.serialize_embed.call_count == 2 - rest_client._entity_factory.serialize_embed.assert_has_calls([mock.call(embed1), mock.call(embed2)]) + assert await rest_client.fetch_pins(StubModel(123)) == [message1, message2] + rest_client._request.assert_awaited_once_with(expected_route) + assert rest_client._entity_factory.deserialize_message.call_count == 2 + rest_client._entity_factory.deserialize_message.assert_has_calls( + [mock.call({"id": "456"}), mock.call({"id": "789"})] + ) - # Components - component1.build.assert_called_once_with() - component2.build.assert_called_once_with() + async def test_pin_message(self, rest_client): + expected_route = routes.PUT_CHANNEL_PINS.compile(channel=123, message=456) + rest_client._request = mock.AsyncMock() - # Generate allowed mentions - generate_allowed_mentions.assert_called_once_with( - mentions_everyone, mentions_reply, user_mentions, role_mentions - ) + await rest_client.pin_message(StubModel(123), StubModel(456)) + rest_client._request.assert_awaited_once_with(expected_route) - # Form builder - url_encoded_form.assert_called_once_with(executor=rest_client._executor) - assert url_encoded_form.return_value.add_resource.call_count == 6 - url_encoded_form.return_value.add_resource.assert_has_calls( - [ - mock.call("file0", resource_attachment1), - mock.call("file1", resource_attachment2), - mock.call("file2", embed_attachment1), - mock.call("file3", embed_attachment2), - mock.call("file4", embed_attachment3), - mock.call("file5", embed_attachment4), - ] - ) + async def test_unpin_message(self, rest_client): + expected_route = routes.DELETE_CHANNEL_PIN.compile(channel=123, message=456) + rest_client._request = mock.AsyncMock() - @pytest.mark.parametrize( - ("singular_arg", "plural_arg"), - [("attachment", "attachments"), ("component", "components"), ("embed", "embeds")], - ) - def test__build_message_payload_when_both_single_and_plural_args_passed( - self, rest_client, singular_arg, plural_arg - ): - with pytest.raises( - ValueError, match=rf"You may only specify one of '{singular_arg}' or '{plural_arg}', not both" - ): - rest_client._build_message_payload(**{singular_arg: object(), plural_arg: object()}) + await rest_client.unpin_message(StubModel(123), StubModel(456)) + rest_client._request.assert_awaited_once_with(expected_route) - @pytest.mark.parametrize( - ("singular_arg", "plural_arg"), - [("attachment", "attachments"), ("component", "components"), ("embed", "embeds")], - ) - def test__build_message_payload_when_non_collection_passed_to_plural(self, rest_client, singular_arg, plural_arg): - expected_error_message = ( - f"You passed a non-collection to '{plural_arg}', but this expects a collection. Maybe you meant to use " - f"'{singular_arg}' (singular) instead?" - ) + async def test_fetch_message(self, rest_client): + message_obj = mock.Mock() + expected_route = routes.GET_CHANNEL_MESSAGE.compile(channel=123, message=456) + rest_client._request = mock.AsyncMock(return_value={"id": "456"}) + rest_client._entity_factory.deserialize_message = mock.Mock(return_value=message_obj) - with pytest.raises(TypeError, match=re.escape(expected_error_message)): - rest_client._build_message_payload(**{plural_arg: object()}) + assert await rest_client.fetch_message(StubModel(123), StubModel(456)) is message_obj + rest_client._request.assert_awaited_once_with(expected_route) + rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"}) async def test_create_message_when_form(self, rest_client): attachment_obj = object() @@ -4811,39 +4849,6 @@ async def test_set_application_command_permissions(self, rest_client): route, json={"permissions": [rest_client._entity_factory.serialize_command_permission.return_value]} ) - def test_interaction_deferred_builder(self, rest_client): - result = rest_client.interaction_deferred_builder(5) - - assert result.type == 5 - assert isinstance(result, special_endpoints.InteractionDeferredBuilder) - - def test_interaction_autocomplete_builder(self, rest_client): - result = rest_client.interaction_autocomplete_builder( - [ - commands.CommandChoice(name="name", value="value"), - commands.CommandChoice(name="a", value="b"), - ] - ) - - assert result.type == 8 - assert isinstance(result, special_endpoints.InteractionAutocompleteBuilder) - assert len(result.choices) == 2 - - raw = result.build(mock.Mock()) - assert raw["data"] == {"choices": [{"name": "name", "value": "value"}, {"name": "a", "value": "b"}]} - - def test_interaction_autocomplete_builder_with_set_choices(self, rest_client): - result = rest_client.interaction_autocomplete_builder([commands.CommandChoice(name="name", value="value")]) - - result.set_choices([commands.CommandChoice(name="a", value="b")]) - assert result.choices == [commands.CommandChoice(name="a", value="b")] - - def test_interaction_message_builder(self, rest_client): - result = rest_client.interaction_message_builder(4) - - assert result.type == 4 - assert isinstance(result, special_endpoints.InteractionMessageBuilder) - async def test_fetch_interaction_response(self, rest_client): expected_route = routes.GET_INTERACTION_RESPONSE.compile(webhook=1235432, token="go homo or go gnomo") rest_client._request = mock.AsyncMock(return_value={"id": "94949494949"}) diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index 63d7642436..b465707c97 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -60,7 +60,19 @@ def proxy_settings(): return mock.Mock(spec_set=config.ProxySettings) -@pytest.mark.asyncio() +class StubResponse: + def __init__( + self, + *, + type=None, + data=None, + extra=None, + ): + self.type = type + self.data = data + self.extra = extra + + class TestGatewayTransport: @pytest.fixture() def transport_impl(self): @@ -76,6 +88,7 @@ def test__init__calls_super(self): init.assert_called_once_with("arg1", "arg2", some_kwarg="kwarg1") + @pytest.mark.asyncio() async def test_send_close_when_not_closed_nor_closing_logs(self, transport_impl): transport_impl.sent_close = False @@ -86,6 +99,7 @@ async def test_send_close_when_not_closed_nor_closing_logs(self, transport_impl) wait_for.assert_awaited_once_with(close.return_value, timeout=5) close.assert_called_once_with(code=1234, message=b"some message") + @pytest.mark.asyncio() async def test_send_close_when_TimeoutError(self, transport_impl): transport_impl.sent_close = False @@ -94,6 +108,7 @@ async def test_send_close_when_TimeoutError(self, transport_impl): close.assert_called_once_with(code=1234, message=b"some message") + @pytest.mark.asyncio() @pytest.mark.parametrize("trace", [True, False]) async def test_receive_json(self, transport_impl, trace): transport_impl._receive_and_check = mock.AsyncMock(return_value="{'json_response': null}") @@ -105,6 +120,7 @@ async def test_receive_json(self, transport_impl, trace): transport_impl._receive_and_check.assert_awaited_once_with(69) mock_loads.assert_called_once_with("{'json_response': null}") + @pytest.mark.asyncio() @pytest.mark.parametrize("trace", [True, False]) async def test_send_json(self, transport_impl, trace): transport_impl.send_str = mock.AsyncMock() @@ -116,18 +132,7 @@ async def test_send_json(self, transport_impl, trace): transport_impl.send_str.assert_awaited_once_with("{'json_send': null}", 420) mock_dumps.assert_called_once_with({"json_send": None}) - class StubResponse: - def __init__( - self, - *, - type=None, - data=None, - extra=None, - ): - self.type = type - self.data = data - self.extra = extra - + @pytest.mark.asyncio() @pytest.mark.parametrize( "code", [ @@ -140,7 +145,7 @@ def __init__( ], ) async def test__receive_and_check_when_message_type_is_CLOSE_and_should_reconnect(self, code, transport_impl): - stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="some error extra", data=code) + stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="some error extra", data=code) transport_impl.receive = mock.AsyncMock(return_value=stub_response) with pytest.raises(errors.GatewayServerClosedConnectionError) as exinfo: @@ -152,12 +157,13 @@ async def test__receive_and_check_when_message_type_is_CLOSE_and_should_reconnec assert exception.can_reconnect is True transport_impl.receive.assert_awaited_once_with(10) + @pytest.mark.asyncio() @pytest.mark.parametrize( "code", [*range(4010, 4020), 5000], ) async def test__receive_and_check_when_message_type_is_CLOSE_and_should_not_reconnect(self, code, transport_impl): - stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="dont reconnect", data=code) + stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="dont reconnect", data=code) transport_impl.receive = mock.AsyncMock(return_value=stub_response) with pytest.raises(errors.GatewayServerClosedConnectionError) as exinfo: @@ -169,8 +175,9 @@ async def test__receive_and_check_when_message_type_is_CLOSE_and_should_not_reco assert exception.can_reconnect is False transport_impl.receive.assert_awaited_once_with(10) + @pytest.mark.asyncio() async def test__receive_and_check_when_message_type_is_CLOSING(self, transport_impl): - stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSING) + stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSING) transport_impl.receive = mock.AsyncMock(return_value=stub_response) with pytest.raises(errors.GatewayError, match="Socket has closed"): @@ -178,8 +185,9 @@ async def test__receive_and_check_when_message_type_is_CLOSING(self, transport_i transport_impl.receive.assert_awaited_once_with(10) + @pytest.mark.asyncio() async def test__receive_and_check_when_message_type_is_CLOSED(self, transport_impl): - stub_response = self.StubResponse(type=aiohttp.WSMsgType.CLOSED) + stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSED) transport_impl.receive = mock.AsyncMock(return_value=stub_response) with pytest.raises(errors.GatewayError, match="Socket has closed"): @@ -187,10 +195,11 @@ async def test__receive_and_check_when_message_type_is_CLOSED(self, transport_im transport_impl.receive.assert_awaited_once_with(10) + @pytest.mark.asyncio() async def test__receive_and_check_when_message_type_is_BINARY(self, transport_impl): - response1 = self.StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some") - response2 = self.StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data") - response3 = self.StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff") + response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some") + response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data") + response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff") transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2, response3]) transport_impl.zlib = mock.Mock(decompress=mock.Mock(return_value=b"utf-8 encoded bytes")) @@ -199,9 +208,10 @@ async def test__receive_and_check_when_message_type_is_BINARY(self, transport_im transport_impl.receive.assert_awaited_with(10) transport_impl.zlib.decompress.assert_called_once_with(bytearray(b"somedata\x00\x00\xff\xff")) + @pytest.mark.asyncio() async def test__receive_and_check_when_buff_but_next_is_not_BINARY(self, transport_impl): - response1 = self.StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some") - response2 = self.StubResponse(type=aiohttp.WSMsgType.TEXT) + response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some") + response2 = StubResponse(type=aiohttp.WSMsgType.TEXT) transport_impl.receive = mock.AsyncMock(side_effect=[response1, response2]) with pytest.raises(errors.GatewayError, match="Unexpected message type received TEXT, expected BINARY"): @@ -209,17 +219,19 @@ async def test__receive_and_check_when_buff_but_next_is_not_BINARY(self, transpo transport_impl.receive.assert_awaited_with(10) + @pytest.mark.asyncio() async def test__receive_and_check_when_message_type_is_TEXT(self, transport_impl): transport_impl.receive = mock.AsyncMock( - return_value=self.StubResponse(type=aiohttp.WSMsgType.TEXT, data="some text") + return_value=StubResponse(type=aiohttp.WSMsgType.TEXT, data="some text") ) assert await transport_impl._receive_and_check(10) == "some text" transport_impl.receive.assert_awaited_once_with(10) + @pytest.mark.asyncio() async def test__receive_and_check_when_message_type_is_unknown(self, transport_impl): - transport_impl.receive = mock.AsyncMock(return_value=self.StubResponse(type=aiohttp.WSMsgType.ERROR)) + transport_impl.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.ERROR)) transport_impl.exception = mock.Mock(return_value=Exception) with pytest.raises(errors.GatewayError, match="Unexpected websocket exception from gateway"): @@ -227,6 +239,7 @@ async def test__receive_and_check_when_message_type_is_unknown(self, transport_i transport_impl.receive.assert_awaited_once_with(10) + @pytest.mark.asyncio() async def test_connect_yields_websocket(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._GatewayTransport): closed = True @@ -293,6 +306,7 @@ def __init__(self): mock_websocket.assert_used_once() sleep.assert_awaited_once_with(0.25) + @pytest.mark.asyncio() async def test_connect_when_gateway_error_after_connecting(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._GatewayTransport): closed = False @@ -333,6 +347,7 @@ def __init__(self): mock_client_session.assert_used_once() mock_websocket.assert_used_once() + @pytest.mark.asyncio() async def test_connect_when_unexpected_error_after_connecting(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._GatewayTransport): closed = False @@ -373,6 +388,7 @@ def __init__(self): mock_client_session.assert_used_once() mock_websocket.assert_used_once() + @pytest.mark.asyncio() async def test_connect_when_no_error_and_not_closing(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._GatewayTransport): closed = False @@ -413,6 +429,7 @@ def __init__(self): mock_client_session.assert_used_once() mock_websocket.assert_used_once() + @pytest.mark.asyncio() async def test_connect_when_no_error_and_closing(self, http_settings, proxy_settings): class MockWS(hikari_test_helpers.AsyncContextManagerMock, shard._GatewayTransport): closed = False @@ -450,6 +467,7 @@ def __init__(self): mock_client_session.assert_used_once() mock_websocket.assert_used_once() + @pytest.mark.asyncio() async def test_connect_when_error_connecting(self, http_settings, proxy_settings): mock_client_session = hikari_test_helpers.AsyncContextManagerMock() mock_client_session.ws_connect = mock.MagicMock(side_effect=aiohttp.ClientConnectionError("some error")) @@ -478,6 +496,7 @@ async def test_connect_when_error_connecting(self, http_settings, proxy_settings sleep.assert_awaited_once_with(0.25) mock_client_session.assert_used_once() + @pytest.mark.asyncio() async def test_connect_when_handshake_error_with_unknown_reason(self, http_settings, proxy_settings): mock_client_session = hikari_test_helpers.AsyncContextManagerMock() mock_client_session.ws_connect = mock.MagicMock( @@ -515,6 +534,7 @@ async def test_connect_when_handshake_error_with_unknown_reason(self, http_setti sleep.assert_awaited_once_with(0.25) mock_client_session.assert_used_once() + @pytest.mark.asyncio() async def test_connect_when_handshake_error_with_known_reason(self, http_settings, proxy_settings): mock_client_session = hikari_test_helpers.AsyncContextManagerMock() mock_client_session.ws_connect = mock.MagicMock( @@ -553,30 +573,27 @@ async def test_connect_when_handshake_error_with_known_reason(self, http_setting mock_client_session.assert_used_once() -@pytest.mark.asyncio() -class TestGatewayShardImpl: - @pytest.fixture() - def client_session(self): - stub = client_session_stub.ClientSessionStub() - with mock.patch.object(aiohttp, "ClientSession", new=stub): - yield stub +@pytest.fixture() +def client_session(): + stub = client_session_stub.ClientSessionStub() + with mock.patch.object(aiohttp, "ClientSession", new=stub): + yield stub - @pytest.fixture(scope="module") - def unslotted_client_type(self): - return hikari_test_helpers.mock_class_namespace(shard.GatewayShardImpl, slots_=False) - @pytest.fixture() - def client(self, http_settings, proxy_settings, unslotted_client_type): - return unslotted_client_type( - event_manager=mock.Mock(), - event_factory=mock.Mock(), - url="wss://gateway.discord.gg", - intents=intents.Intents.ALL, - token="lol", - http_settings=http_settings, - proxy_settings=proxy_settings, - ) +@pytest.fixture() +def client(http_settings, proxy_settings): + return hikari_test_helpers.mock_class_namespace(shard.GatewayShardImpl, slots_=False)( + event_manager=mock.Mock(), + event_factory=mock.Mock(), + url="wss://gateway.discord.gg", + intents=intents.Intents.ALL, + token="lol", + http_settings=http_settings, + proxy_settings=proxy_settings, + ) + +class TestGatewayShardImpl: @pytest.mark.parametrize( ("compression", "expect"), [ @@ -610,7 +627,7 @@ def test_using_etf_is_unsupported(self, http_settings, proxy_settings): url="wss://erlpack-is-broken-lol.discord.meh", intents=intents.Intents.ALL, data_format="etf", - compression=True, + compression="testing", ) def test_heartbeat_latency_property(self, client): @@ -622,9 +639,9 @@ def test_id_property(self, client): assert client.id == 101 def test_intents_property(self, client): - intents = object() - client._intents = intents - assert client.intents is intents + mock_intents = object() + client._intents = mock_intents + assert client.intents is mock_intents def test_is_alive_property(self, client): client._run_task = None @@ -653,6 +670,159 @@ def test_shard__check_if_alive_when_alive(self, client): with mock.patch.object(shard.GatewayShardImpl, "is_alive", new=True): client._check_if_alive() + def test__get_ws_when_active(self, client): + mock_ws = client._ws = object() + + assert client._get_ws() is mock_ws + + def test__get_ws_when_inactive(self, client): + client._ws = None + + with pytest.raises(errors.ComponentStateConflictError): + client._get_ws() + + def test_dispatch_when_READY(self, client): + client._seq = 0 + client._session_id = 0 + client._user_id = 0 + client._logger = mock.Mock() + client._handshake_completed = mock.Mock() + client._event_manager = mock.Mock() + + pl = { + "session_id": 101, + "user": {"id": 123, "username": "hikari", "discriminator": "5863"}, + "guilds": [ + {"id": "123"}, + {"id": "456"}, + {"id": "789"}, + ], + "v": 8, + } + + client._dispatch( + "READY", + 10, + pl, + ) + + assert client._seq == 10 + assert client._session_id == 101 + assert client._user_id == 123 + client._logger.info.assert_called_once_with( + "shard is ready: %s guilds, %s (%s), session %r on v%s gateway", + 3, + "hikari#5863", + 123, + 101, + 8, + ) + client._handshake_completed.set.assert_called_once_with() + client._event_manager.consume_raw_event.assert_called_once_with( + "READY", + client, + pl, + ) + + def test__dipatch_when_RESUME(self, client): + client._seq = 0 + client._session_id = 123 + client._logger = mock.Mock() + client._handshake_completed = mock.Mock() + client._event_manager = mock.Mock() + + client._dispatch("RESUME", 10, {}) + + assert client._seq == 10 + client._logger.info.assert_called_once_with("shard has resumed [session:%s, seq:%s]", 123, 10) + client._handshake_completed.set.assert_called_once_with() + client._event_manager.consume_raw_event.assert_called_once_with("RESUME", client, {}) + + def test__dipatch(self, client): + client._logger = mock.Mock() + client._handshake_completed = mock.Mock() + client._event_manager = mock.Mock() + + client._dispatch("EVENT NAME", 10, {"payload": None}) + + client._logger.info.assert_not_called() + client._logger.debug.assert_not_called() + client._handshake_completed.set.assert_not_called() + client._event_manager.consume_raw_event.assert_called_once_with("EVENT NAME", client, {"payload": None}) + + def test__serialize_activity_when_activity_is_None(self, client): + assert client._serialize_activity(None) is None + + def test__serialize_activity_when_activity_is_not_None(self, client): + activity = mock.Mock(type="0", url="https://some.url") + activity.name = "some name" # This has to be set separate because if not, its set as the mock's name + assert client._serialize_activity(activity) == {"name": "some name", "type": 0, "url": "https://some.url"} + + @pytest.mark.parametrize("idle_since", [datetime.datetime.now(), None]) + @pytest.mark.parametrize("afk", [True, False]) + @pytest.mark.parametrize( + "status", + [presences.Status.DO_NOT_DISTURB, presences.Status.IDLE, presences.Status.ONLINE, presences.Status.OFFLINE], + ) + @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None]) + def test__serialize_and_store_presence_payload_when_all_args_undefined( + self, client, idle_since, afk, status, activity + ): + client._activity = activity + client._idle_since = idle_since + client._is_afk = afk + client._status = status + + actual_result = client._serialize_and_store_presence_payload() + + if activity is not undefined.UNDEFINED and activity is not None: + expected_activity = { + "name": activity.name, + "type": activity.type, + "url": activity.url, + } + else: + expected_activity = None + + if status == presences.Status.OFFLINE: + expected_status = "invisible" + else: + expected_status = status.value + + expected_result = { + "game": expected_activity, + "since": int(idle_since.timestamp() * 1_000) if idle_since is not None else None, + "afk": afk if afk is not undefined.UNDEFINED else False, + "status": expected_status, + } + + assert expected_result == actual_result + + @pytest.mark.parametrize("idle_since", [datetime.datetime.now(), None]) + @pytest.mark.parametrize("afk", [True, False]) + @pytest.mark.parametrize( + "status", + [presences.Status.DO_NOT_DISTURB, presences.Status.IDLE, presences.Status.ONLINE, presences.Status.OFFLINE], + ) + @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None]) + def test__serialize_and_store_presence_payload_sets_state(self, client, idle_since, afk, status, activity): + client._serialize_and_store_presence_payload(idle_since=idle_since, afk=afk, status=status, activity=activity) + + assert client._activity == activity + assert client._idle_since == idle_since + assert client._is_afk == afk + assert client._status == status + + def test__serialize_datetime_when_datetime_is_None(self, client): + assert client._serialize_datetime(None) is None + + def test__serialize_datetime_when_datetime_is_not_None(self, client): + dt = datetime.datetime(2004, 11, 22, tzinfo=datetime.timezone.utc) + assert client._serialize_datetime(dt) == 1101081600000 + + +@pytest.mark.asyncio() +class TestGatewayShardImplAsync: async def test_close_when_closing_event_set(self, client): client._closing_event = mock.Mock(is_set=mock.Mock(return_value=True)) client._closed_event = mock.Mock(wait=mock.AsyncMock()) @@ -710,17 +880,6 @@ async def test_when__user_id_is_not_None(self, client): client._user_id = 123 assert await client.get_user_id() == 123 - def test__get_ws_when_active(self, client): - mock_ws = client._ws = object() - - assert client._get_ws() is mock_ws - - def test__get_ws_when_inactive(self, client): - client._ws = None - - with pytest.raises(errors.ComponentStateConflictError): - client._get_ws() - async def test_join(self, client): client._closed_event = mock.Mock(wait=mock.AsyncMock()) @@ -904,75 +1063,6 @@ async def test_update_voice_state_without_optionals(self, client): client._send_json.assert_awaited_once_with({"op": 4, "d": payload}) - def test_dispatch_when_READY(self, client): - client._seq = 0 - client._session_id = 0 - client._user_id = 0 - client._logger = mock.Mock() - client._handshake_completed = mock.Mock() - client._event_manager = mock.Mock() - - pl = { - "session_id": 101, - "user": {"id": 123, "username": "hikari", "discriminator": "5863"}, - "guilds": [ - {"id": "123"}, - {"id": "456"}, - {"id": "789"}, - ], - "v": 8, - } - - client._dispatch( - "READY", - 10, - pl, - ) - - assert client._seq == 10 - assert client._session_id == 101 - assert client._user_id == 123 - client._logger.info.assert_called_once_with( - "shard is ready: %s guilds, %s (%s), session %r on v%s gateway", - 3, - "hikari#5863", - 123, - 101, - 8, - ) - client._handshake_completed.set.assert_called_once_with() - client._event_manager.consume_raw_event.assert_called_once_with( - "READY", - client, - pl, - ) - - def test__dipatch_when_RESUME(self, client): - client._seq = 0 - client._session_id = 123 - client._logger = mock.Mock() - client._handshake_completed = mock.Mock() - client._event_manager = mock.Mock() - - client._dispatch("RESUME", 10, {}) - - assert client._seq == 10 - client._logger.info.assert_called_once_with("shard has resumed [session:%s, seq:%s]", 123, 10) - client._handshake_completed.set.assert_called_once_with() - client._event_manager.consume_raw_event.assert_called_once_with("RESUME", client, {}) - - def test__dipatch(self, client): - client._logger = mock.Mock() - client._handshake_completed = mock.Mock() - client._event_manager = mock.Mock() - - client._dispatch("EVENT NAME", 10, {"payload": None}) - - client._logger.info.assert_not_called() - client._logger.debug.assert_not_called() - client._handshake_completed.set.assert_not_called() - client._event_manager.consume_raw_event.assert_called_once_with("EVENT NAME", client, {"payload": None}) - async def test__dispatch_for_unknown_event(self, client): client._logger = mock.Mock() client._handshake_completed = mock.Mock() @@ -1079,73 +1169,3 @@ async def test__send_heartbeat(self, client): client._send_json.assert_awaited_once_with({"op": 1, "d": 10}) assert client._last_heartbeat_sent == 200 - - def test__serialize_activity_when_activity_is_None(self, client): - assert client._serialize_activity(None) is None - - def test__serialize_activity_when_activity_is_not_None(self, client): - activity = mock.Mock(type="0", url="https://some.url") - activity.name = "some name" # This has to be set separate because if not, its set as the mock's name - assert client._serialize_activity(activity) == {"name": "some name", "type": 0, "url": "https://some.url"} - - @pytest.mark.parametrize("idle_since", [datetime.datetime.now(), None]) - @pytest.mark.parametrize("afk", [True, False]) - @pytest.mark.parametrize( - "status", - [presences.Status.DO_NOT_DISTURB, presences.Status.IDLE, presences.Status.ONLINE, presences.Status.OFFLINE], - ) - @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None]) - def test__serialize_and_store_presence_payload_when_all_args_undefined( - self, client, idle_since, afk, status, activity - ): - client._activity = activity - client._idle_since = idle_since - client._is_afk = afk - client._status = status - - actual_result = client._serialize_and_store_presence_payload() - - if activity is not undefined.UNDEFINED and activity is not None: - expected_activity = { - "name": activity.name, - "type": activity.type, - "url": activity.url, - } - else: - expected_activity = None - - if status == presences.Status.OFFLINE: - expected_status = "invisible" - else: - expected_status = status.value - - expected_result = { - "game": expected_activity, - "since": int(idle_since.timestamp() * 1_000) if idle_since is not None else None, - "afk": afk if afk is not undefined.UNDEFINED else False, - "status": expected_status, - } - - assert expected_result == actual_result - - @pytest.mark.parametrize("idle_since", [datetime.datetime.now(), None]) - @pytest.mark.parametrize("afk", [True, False]) - @pytest.mark.parametrize( - "status", - [presences.Status.DO_NOT_DISTURB, presences.Status.IDLE, presences.Status.ONLINE, presences.Status.OFFLINE], - ) - @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None]) - def test__serialize_and_store_presence_payload_sets_state(self, client, idle_since, afk, status, activity): - client._serialize_and_store_presence_payload(idle_since=idle_since, afk=afk, status=status, activity=activity) - - assert client._activity == activity - assert client._idle_since == idle_since - assert client._is_afk == afk - assert client._status == status - - def test__serialize_datetime_when_datetime_is_None(self, client): - assert client._serialize_datetime(None) is None - - def test__serialize_datetime_when_datetime_is_not_None(self, client): - dt = datetime.datetime(2004, 11, 22, tzinfo=datetime.timezone.utc) - assert client._serialize_datetime(dt) == 1101081600000 diff --git a/tests/hikari/test_channels.py b/tests/hikari/test_channels.py index d30f45e10a..01681ebf3a 100644 --- a/tests/hikari/test_channels.py +++ b/tests/hikari/test_channels.py @@ -192,7 +192,6 @@ def test_make_icon_url_when_hash_is_None(self, model): assert model.make_icon_url() is None -@pytest.mark.asyncio() class TestTextChannel: @pytest.fixture() def model(self, mock_app): @@ -203,6 +202,7 @@ def model(self, mock_app): type=channels.ChannelType.GUILD_TEXT, ) + @pytest.mark.asyncio() async def test_fetch_history(self, model): model.app.rest.fetch_messages = mock.AsyncMock() @@ -219,6 +219,7 @@ async def test_fetch_history(self, model): around=datetime.datetime(2020, 4, 1, 0, 30, 0), ) + @pytest.mark.asyncio() async def test_fetch_message(self, model): model.app.rest.fetch_message = mock.AsyncMock() @@ -226,6 +227,7 @@ async def test_fetch_message(self, model): model.app.rest.fetch_message.assert_awaited_once_with(12345679, 133742069) + @pytest.mark.asyncio() async def test_fetch_pins(self, model): model.app.rest.fetch_pins = mock.AsyncMock() @@ -233,6 +235,7 @@ async def test_fetch_pins(self, model): model.app.rest.fetch_pins.assert_awaited_once_with(12345679) + @pytest.mark.asyncio() async def test_pin_message(self, model): model.app.rest.pin_message = mock.AsyncMock() @@ -240,6 +243,7 @@ async def test_pin_message(self, model): model.app.rest.pin_message.assert_awaited_once_with(12345679, 77790) + @pytest.mark.asyncio() async def test_unpin_message(self, model): model.app.rest.unpin_message = mock.AsyncMock() @@ -247,6 +251,7 @@ async def test_unpin_message(self, model): model.app.rest.unpin_message.assert_awaited_once_with(12345679, 77790) + @pytest.mark.asyncio() async def test_delete_messages(self, model): model.app.rest.delete_messages = mock.AsyncMock() @@ -254,6 +259,7 @@ async def test_delete_messages(self, model): model.app.rest.delete_messages.assert_awaited_once_with(12345679, [77790, 88890, 1800], 1337) + @pytest.mark.asyncio() async def test_send(self, model): model.app.rest.create_message = mock.AsyncMock() mock_attachment = object()