Skip to content

Commit

Permalink
New tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davfsa committed Oct 3, 2020
1 parent 5429e11 commit 2979078
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 40 deletions.
3 changes: 2 additions & 1 deletion hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,9 @@ async def _request(
except self._RetryRequest:
continue

@staticmethod
@typing.final
def _stringify_http_message(self, headers: data_binding.Headers, body: typing.Any) -> str:
def _stringify_http_message(headers: data_binding.Headers, body: typing.Any) -> str:
string = "\n".join(
f" {name}: {value}" if name != _AUTHORIZATION_HEADER else f" {name}: **REDACTED TOKEN**"
for name, value in headers.items()
Expand Down
168 changes: 140 additions & 28 deletions tests/hikari/impl/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,20 @@ class StubRestClient:
return StubRestClient()

@pytest.fixture()
def rest_provider(self, rest_client):
return rest._RESTProvider(lambda: mock.Mock(), None, lambda: mock.Mock(), lambda: rest_client)
def cache(self):
return mock.Mock()

@pytest.fixture()
def executor(self):
return mock.Mock()

@pytest.fixture()
def entity_factory(self):
return mock.Mock()

@pytest.fixture()
def rest_provider(self, rest_client, cache, executor, entity_factory):
return rest._RESTProvider(lambda: entity_factory, executor, lambda: cache, lambda: rest_client)

def test_rest_property(self, rest_provider, rest_client):
assert rest_provider.rest == rest_client
Expand All @@ -129,6 +141,18 @@ def test_http_settings_property(self, rest_provider, rest_client):
def test_proxy_settings_property(self, rest_provider, rest_client):
assert rest_provider.proxy_settings == rest_client.proxy_settings

def test_entity_factory_property(self, rest_provider, entity_factory):
assert rest_provider.entity_factory == entity_factory

def test_cache_property(self, rest_provider, cache):
assert rest_provider.cache == cache

def test_executor_property(self, rest_provider, executor):
assert rest_provider.executor == executor

def test_me_property(self, rest_provider, cache):
assert rest_provider.me == cache.get_me()


###########
# RESTApp #
Expand All @@ -148,6 +172,24 @@ def rest_app():


class TestRESTApp:
def test__init__when_connector_factory_is_None(self):
http_settings = object()

with mock.patch.object(rest, "BasicLazyCachedTCPConnectorFactory") as factory:
rest_app = rest.RESTApp(
connector_factory=None,
connector_owner=False,
executor=None,
http_settings=http_settings,
proxy_settings=None,
url=None,
)

factory.assert_called_once_with(http_settings)

assert rest_app._connector_factory is factory()
assert rest_app._connector_owner is True

def test_executor_property(self, rest_app):
mock_executor = object()
rest_app._executor = mock_executor
Expand All @@ -166,10 +208,9 @@ def test_proxy_settings(self, rest_app):
def test_acquire(self, rest_app):
mock_event_loop = object()
rest_app._event_loop = mock_event_loop
mock_entity_factory = object()

stack = contextlib.ExitStack()
stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl", return_value=mock_entity_factory))
_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl"))
mock_client = stack.enter_context(mock.patch.object(rest, rest.RESTClientImpl.__qualname__))
stack.enter_context(mock.patch.object(asyncio, "get_running_loop", return_value=mock_event_loop))

Expand All @@ -179,7 +220,7 @@ def test_acquire(self, rest_app):
mock_client.assert_called_once_with(
connector_factory=rest_app._connector_factory,
connector_owner=rest_app._connector_owner,
entity_factory=mock_entity_factory,
entity_factory=_entity_factory(),
executor=rest_app._executor,
http_settings=rest_app._http_settings,
proxy_settings=rest_app._proxy_settings,
Expand All @@ -188,6 +229,26 @@ def test_acquire(self, rest_app):
rest_url=rest_app._url,
)

def test_acquire_when_even_loop_not_set(self, rest_app):
mock_event_loop = object()

stack = contextlib.ExitStack()
_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl"))
stack.enter_context(mock.patch.object(rest, rest.RESTClientImpl.__qualname__))
stack.enter_context(mock.patch.object(asyncio, "get_running_loop", return_value=mock_event_loop))

with stack:
rest_app.acquire(token="token", token_type="Type")

assert rest_app._event_loop is mock_event_loop

# This is just to test the lambdas so it counts towards coverage
assert _entity_factory.call_count == 1
factory = _entity_factory.call_args_list[0][0][0]
factory.entity_factory
factory.cache
factory.rest

def test_acquire_when__event_loop_and_loop_do_not_equal(self, rest_app):
rest_app._event_loop = object()
with mock.patch.object(asyncio, "get_running_loop"):
Expand Down Expand Up @@ -479,6 +540,19 @@ def test__generate_allowed_mentions(self, rest_client, function_input, expected_
def test__transform_emoji_to_url_format(self, rest_client, emoji, expected_return):
assert rest_client._transform_emoji_to_url_format(emoji) == expected_return

def test__stringify_http_message_when_body_is_None(self, rest_client):
headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"}
expected_return = " HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**"
assert rest_client._stringify_http_message(headers, None) == expected_return

@pytest.mark.parametrize(("body", "expected"), [(bytes("hello :)", "ascii"), "hello :)"), (123, "123")])
def test__stringify_http_message_when_body_is_not_None(self, rest_client, body, expected):
headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"}
expected_return = (
f" HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**\n\n {expected}"
)
assert rest_client._stringify_http_message(headers, body) == expected_return

#######################
# Non-async endpoints #
#######################
Expand Down Expand Up @@ -743,6 +817,14 @@ class ExitException(Exception):

return ExitException

async def test___aenter__and__aexit__(self, rest_client):
with mock.patch.object(rest_client, "close") as close:
async with rest_client as client:
assert client is rest_client
close.assert_not_called()

close.assert_awaited_once_with()

@hikari_test_helpers.timeout()
async def test__request_when_buckets_not_started(self, rest_client, exit_exception):
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
Expand All @@ -763,19 +845,25 @@ async def test__request_when_buckets_started(self, rest_client, exit_exception):

rest_client.buckets.start.assert_not_called()

# FIXME: Move logger thingy to seperate test

@hikari_test_helpers.timeout()
async def test__request_when__token_is_None(self, rest_client, exit_exception):
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception))
rest_client.buckets.is_started = True
rest_client._token = None
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)

_, kwargs = mock_session.request.call_args_list[0]
assert rest._AUTHORIZATION_HEADER not in kwargs["headers"]
_, kwargs = mock_session.request.call_args_list[0]
assert rest._AUTHORIZATION_HEADER not in kwargs["headers"]

assert logger.log.call_count == 0

@hikari_test_helpers.timeout()
async def test__request_when__token_is_not_None(self, rest_client, exit_exception):
Expand All @@ -784,26 +872,34 @@ async def test__request_when__token_is_not_None(self, rest_client, exit_exceptio
rest_client.buckets.is_started = True
rest_client._token = "token"
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)

_, kwargs = mock_session.request.call_args_list[0]
assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "token"

assert logger.log.call_count == 0

@hikari_test_helpers.timeout()
async def test__request_when_no_auth_passed(self, rest_client, exit_exception):
route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception))
rest_client.buckets.is_started = True
rest_client._token = "token"
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route, no_auth=True)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=0))) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route, no_auth=True)

_, kwargs = mock_session.request.call_args_list[0]
assert rest._AUTHORIZATION_HEADER not in kwargs["headers"]
_, kwargs = mock_session.request.call_args_list[0]
assert rest._AUTHORIZATION_HEADER not in kwargs["headers"]

assert logger.log.call_count == 1

@hikari_test_helpers.timeout()
async def test__request_when_response_is_NO_CONTENT(self, rest_client):
Expand All @@ -817,8 +913,12 @@ class StubResponse:
rest_client._debug = False
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
assert (await rest_client._request(route)) is None
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
assert (await rest_client._request(route)) is None

assert logger.log.call_count == 0

@hikari_test_helpers.timeout()
async def test__request_when_response_is_APPLICATION_JSON(self, rest_client):
Expand All @@ -837,8 +937,12 @@ async def read(self):
rest_client._debug = True
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
assert (await rest_client._request(route)) == {"something": None}
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=0))) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
assert (await rest_client._request(route)) == {"something": None}

assert logger.log.call_count == 2

@hikari_test_helpers.timeout()
async def test__request_when_response_is_not_JSON(self, rest_client):
Expand All @@ -854,9 +958,13 @@ class StubResponse:
rest_client._debug = False
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(errors.HTTPError):
await rest_client._request(route)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(errors.HTTPError):
await rest_client._request(route)

assert logger.log.call_count == 0

@hikari_test_helpers.timeout()
async def test__request_when_response_is_not_between_200_and_300(self, rest_client, exit_exception):
Expand All @@ -872,9 +980,13 @@ class StubResponse:
rest_client._parse_ratelimits = mock.AsyncMock()
rest_client._handle_error_response = mock.AsyncMock(side_effect=exit_exception)
rest_client._acquire_client_session = mock.Mock(return_value=mock_session)
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)
rest_client._stringify_http_message = mock.Mock()
with mock.patch.object(rest, "_LOGGER", new=mock.Mock(getEffectiveLevel=mock.Mock(return_value=100))) as logger:
with mock.patch.object(asyncio, "gather", new=mock.AsyncMock()):
with pytest.raises(exit_exception):
await rest_client._request(route)

assert logger.log.call_count == 0

@hikari_test_helpers.timeout()
async def test__request_when_response__RetryRequest_gets_handled(self, rest_client, exit_exception):
Expand Down
15 changes: 6 additions & 9 deletions tests/hikari/internal/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,22 +210,19 @@ def test___setitem___removes_old_entry_instead_of_replacing(self):
mock_map["ok"] = "foo"
assert list(mock_map.items())[2] == ("ok", "foo")

# TODO: fix this so that it is not flaky.
# https://travis-ci.org/github/nekokatt/hikari/jobs/724494888#L797
@pytest.mark.skip("flaky test, might fail on Windows runners.")
@pytest.mark.asyncio
async def test___setitem___garbage_collection(self):
mock_map = collections.TimedCacheMap(
expiry=datetime.timedelta(seconds=hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 3)
)
mock_map.update({"OK": "no", "blam": "booga"})
mock_map["OK"] = "no"
await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 2)
assert mock_map == {"OK": "no", "blam": "booga"}
mock_map.update({"ayanami": "rei", "owo": "awoo"})
assert mock_map == {"OK": "no", "blam": "booga", "ayanami": "rei", "owo": "awoo"}
assert mock_map == {"OK": "no"}
mock_map["ayanami"] = "rei"
assert mock_map == {"OK": "no", "ayanami": "rei"}
await asyncio.sleep(hikari_test_helpers.REASONABLE_QUICK_RESPONSE_TIME * 2)
mock_map.update({"nyaa": "qt"})
assert mock_map == {"ayanami": "rei", "owo": "awoo", "nyaa": "qt"}
mock_map["nyaa"] = "qt"
assert mock_map == {"ayanami": "rei", "nyaa": "qt"}


class TestLimitedCapacityCacheMap:
Expand Down
10 changes: 8 additions & 2 deletions tests/hikari/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def test_format_icon_when_hash_is_None(self, model):
assert model.format_icon() is None


@pytest.mark.asyncio
class TestTextChannel:
@pytest.fixture()
def model(self, mock_app):
Expand All @@ -196,7 +197,6 @@ def model(self, mock_app):
type=channels.ChannelType.GUILD_TEXT,
)

@pytest.mark.asyncio
async def test_history(self, model):
model.app.rest.fetch_messages = mock.AsyncMock()

Expand All @@ -213,7 +213,6 @@ async def test_history(self, model):
around=datetime.datetime(2020, 4, 1, 0, 30, 0),
)

@pytest.mark.asyncio
async def test_send(self, model):
model.app.rest.create_message = mock.AsyncMock()
mock_attachment = object()
Expand Down Expand Up @@ -245,6 +244,13 @@ async def test_send(self, model):
role_mentions=[789, 567],
)

def test_trigger_typing(self, model):
model.app.rest.trigger_typing = mock.Mock()

model.trigger_typing()

model.app.rest.trigger_typing.assert_called_once_with(12345679)


class TestGuildChannel:
@pytest.fixture()
Expand Down

0 comments on commit 2979078

Please sign in to comment.