diff --git a/changes/1230.breaking.md b/changes/1230.breaking.md
new file mode 100644
index 0000000000..23314c10d8
--- /dev/null
+++ b/changes/1230.breaking.md
@@ -0,0 +1 @@
+`RESTApp` and `RESTBucketManager` now need to be explicitly started and stopped using `.start` and `.close`.
diff --git a/changes/1230.bugfix.md b/changes/1230.bugfix.md
new file mode 100644
index 0000000000..ef0a17dde8
--- /dev/null
+++ b/changes/1230.bugfix.md
@@ -0,0 +1 @@
+Buckets are no longer shared across different authentications; previously, sharing them could lead to incorrect rate limiting.
diff --git a/changes/1230.feature.md b/changes/1230.feature.md
new file mode 100644
index 0000000000..bfcf9478f2
--- /dev/null
+++ b/changes/1230.feature.md
@@ -0,0 +1,5 @@
+`RESTClientImpl` improvements:
+ - Client sessions and bucket managers can now be shared across `RESTClientImpl` instances, or created for you automatically.
+ - Sped up the request lifecycle.
+ - Routes without rate limits no longer attempt to acquire a rate limit bucket.
+ - As a safety measure, if a bucket is ever received for such a route, the route is treated as rate limited from then on and an error log is emitted. If you spot that log, please inform us!
diff --git a/examples/oauth.py b/examples/oauth.py
new file mode 100644
index 0000000000..014c209f59
--- /dev/null
+++ b/examples/oauth.py
@@ -0,0 +1,78 @@
+# -*- coding: utf-8 -*-
+# Hikari Examples - A collection of examples for Hikari.
+#
+# To the extent possible under law, the author(s) have dedicated all copyright
+# and related and neighboring rights to this software to the public domain worldwide.
+# This software is distributed without any warranty.
+#
+# You should have received a copy of the CC0 Public Domain Dedication along with this software.
+# If not, see <https://creativecommons.org/publicdomain/zero/1.0/>.
+"""An example OAuth server."""
+import logging
+import os
+
+from aiohttp import web
+
+import hikari
+
+logging.basicConfig(level=logging.DEBUG)
+
+host = "localhost"
+port = 8080
+CLIENT_ID = int(os.environ["CLIENT_ID"])  # ID as an int
+CLIENT_SECRET = os.environ["CLIENT_SECRET"]  # Secret as a str
+BOT_TOKEN = os.environ["BOT_TOKEN"]  # Token as a str
+CHANNEL_ID = int(os.environ["CHANNEL_ID"])  # Channel to post in as an int
+REDIRECT_URI = "http://localhost:8080"
+
+route_table = web.RouteTableDef()
+
+
+@route_table.get("/")
+async def oauth(request: web.Request) -> web.Response:
+    """Handle an OAuth request."""
+    code = request.query.get("code")
+    if not code:
+        return web.json_response({"error": "'code' is not provided"}, status=400)
+
+    discord_rest: hikari.RESTApp = request.app["discord_rest"]
+
+    # Exchange the code for a Bearer access token for the user
+    async with discord_rest.acquire(None) as r:
+        auth = await r.authorize_access_token(CLIENT_ID, CLIENT_SECRET, code, REDIRECT_URI)
+
+    # Perform a request as the user to get their own user object
+    async with discord_rest.acquire(auth.access_token, hikari.TokenType.BEARER) as client:
+        user = await client.fetch_my_user()
+        # user is a hikari.OwnUser object whose attributes we can access
+
+    # Notify of the successful authorization
+    async with discord_rest.acquire(BOT_TOKEN, hikari.TokenType.BOT) as client:
+        await client.create_message(CHANNEL_ID, f"{user} ({user.id}) just authorized!")
+
+    return web.Response(text="Successfully authenticated!")
+
+
+async def start_discord_rest(app: web.Application) -> None:
+    """Start the RESTApp."""
+    discord_rest = hikari.RESTApp()
+    await discord_rest.start()
+
+    app["discord_rest"] = discord_rest
+
+
+async def stop_discord_rest(app: web.Application) -> None:
+    """Stop the RESTApp."""
+    discord_rest: 
hikari.RESTApp = app["discord_rest"] + + await discord_rest.close() + + +if __name__ == "__main__": + server = web.Application() + server.add_routes(route_table) + + server.on_startup.append(start_discord_rest) + server.on_cleanup.append(stop_discord_rest) + + web.run_app(server, host=host, port=port) diff --git a/hikari/impl/bot.py b/hikari/impl/bot.py index 7066671136..68455d52dc 100644 --- a/hikari/impl/bot.py +++ b/hikari/impl/bot.py @@ -97,6 +97,21 @@ def _validate_activity(activity: undefined.UndefinedNoneOr[presences.Activity]) ) +async def _close_resource(name: str, awaitable: typing.Awaitable[typing.Any]) -> None: + future = asyncio.ensure_future(awaitable) + + try: + await future + except Exception as ex: + asyncio.get_running_loop().call_exception_handler( + { + "message": f"{name} raised an exception during shut down", + "future": future, + "exception": ex, + } + ) + + class GatewayBot(traits.GatewayBotAware): """Basic auto-sharding bot implementation. @@ -288,7 +303,7 @@ def __init__( intents: intents_.Intents = intents_.Intents.ALL_UNPRIVILEGED, auto_chunk_members: bool = True, logs: typing.Union[None, int, str, typing.Dict[str, typing.Any]] = "INFO", - max_rate_limit: float = 300, + max_rate_limit: float = 300.0, max_retries: int = 3, proxy_settings: typing.Optional[config_impl.ProxySettings] = None, rest_url: typing.Optional[str] = None, @@ -426,30 +441,14 @@ async def close(self) -> None: await self._event_manager.dispatch(self._event_factory.deserialize_stopping_event()) _LOGGER.log(ux.TRACE, "StoppingEvent dispatch completed, now beginning termination") - loop = asyncio.get_running_loop() - - async def handle(name: str, awaitable: typing.Awaitable[typing.Any]) -> None: - future = asyncio.ensure_future(awaitable) - - try: - await future - except Exception as ex: - loop.call_exception_handler( - { - "message": f"{name} raised an exception during shut down", - "future": future, - "exception": ex, - } - ) - - await handle("voice handler", self._voice.close()) + await _close_resource("voice handler", self._voice.close()) - shards = tuple((handle(f"shard {s.id}", s.close()) for s in self._shards.values() if s.is_alive)) + shards = tuple(_close_resource(f"shard {s.id}", s.close()) for s in self._shards.values() if s.is_alive) for coro in asyncio.as_completed(shards): await coro - await handle("rest", self._rest.close()) + await _close_resource("rest", self._rest.close()) # Clear out cache and shard map self._cache.clear() diff --git a/hikari/impl/buckets.py b/hikari/impl/buckets.py index e68cbdd002..573b3cb7b4 100644 --- a/hikari/impl/buckets.py +++ b/hikari/impl/buckets.py @@ -227,12 +227,24 @@ class RESTBucket(rate_limits.WindowedBurstRateLimiter): which allows dynamically changing the enforced rate limits at any time. 
""" - __slots__: typing.Sequence[str] = ("_compiled_route", "_max_rate_limit", "_lock") + __slots__: typing.Sequence[str] = ( + "_compiled_route", + "_max_rate_limit", + "_global_ratelimit", + "_lock", + ) - def __init__(self, name: str, compiled_route: routes.CompiledRoute, max_rate_limit: float) -> None: + def __init__( + self, + name: str, + compiled_route: routes.CompiledRoute, + global_ratelimit: rate_limits.ManualRateLimiter, + max_rate_limit: float, + ) -> None: super().__init__(name, 1, 1) self._compiled_route = compiled_route self._max_rate_limit = max_rate_limit + self._global_ratelimit = global_ratelimit self._lock = asyncio.Lock() async def __aenter__(self) -> None: @@ -290,6 +302,8 @@ async def acquire(self) -> None: await super().acquire() + await self._global_ratelimit.acquire() + def update_rate_limit(self, remaining: int, limit: int, reset_at: float) -> None: """Update the rate limit information. @@ -330,8 +344,12 @@ def resolve(self, real_bucket_hash: str) -> None: self.name: str = real_bucket_hash -def _create_unknown_hash(route: routes.CompiledRoute) -> str: - return UNKNOWN_HASH + routes.HASH_SEPARATOR + str(hash(route)) +def _create_authentication_hash(authentication: typing.Optional[str]) -> str: + return str(hash(authentication)) + + +def _create_unknown_hash(route: routes.CompiledRoute, authentication_hash: str) -> str: + return f"{UNKNOWN_HASH}{routes.HASH_SEPARATOR}{authentication_hash}{routes.HASH_SEPARATOR}{str(hash(route))}" class RESTBucketManager: @@ -349,52 +367,26 @@ class RESTBucketManager: """ __slots__: typing.Sequence[str] = ( - "routes_to_hashes", - "real_hashes_to_buckets", - "closed_event", - "gc_task", - "max_rate_limit", + "_routes_to_hashes", + "_real_hashes_to_buckets", + "_global_ratelimit", + "_closed_event", + "_gc_task", + "_max_rate_limit", ) - routes_to_hashes: typing.Final[typing.MutableMapping[routes.Route, str]] - """Maps routes to their `X-RateLimit-Bucket` header being used.""" - - real_hashes_to_buckets: typing.Final[typing.MutableMapping[str, RESTBucket]] - """Maps full bucket hashes to their corresponding rate limiters. - - The full bucket hash consists of `X-RateLimit-Bucket` appended with a hash of - major parameters used in that compiled route. - """ - - closed_event: typing.Final[asyncio.Event] - """An internal event that is set when the object is shut down.""" - - gc_task: typing.Optional[asyncio.Task[None]] - """The internal garbage collector task.""" - - max_rate_limit: float - """The max number of seconds to backoff for when rate limited. - - Anything greater than this will instead raise an error. 
- """ - def __init__(self, max_rate_limit: float) -> None: - self.routes_to_hashes = {} - self.real_hashes_to_buckets = {} - self.closed_event: asyncio.Event = asyncio.Event() - self.gc_task: typing.Optional[asyncio.Task[None]] = None - self.max_rate_limit = max_rate_limit - - def __enter__(self) -> RESTBucketManager: - return self + self._routes_to_hashes: typing.Dict[routes.Route, str] = {} + self._real_hashes_to_buckets: typing.Dict[str, RESTBucket] = {} + self._closed_event: typing.Optional[asyncio.Event] = None + self._gc_task: typing.Optional[asyncio.Task[None]] = None + self._max_rate_limit = max_rate_limit + self._global_ratelimit = rate_limits.ManualRateLimiter() - def __exit__( - self, - exc_type: typing.Optional[typing.Type[Exception]], - exc_val: typing.Optional[Exception], - exc_tb: typing.Optional[types.TracebackType], - ) -> None: - self.close() + @property + def is_alive(self) -> bool: + """Whether the component is alive.""" + return self._closed_event is not None def start(self, poll_period: float = 20.0, expire_after: float = 10.0) -> None: """Start this ratelimiter up. @@ -415,27 +407,37 @@ def start(self, poll_period: float = 20.0, expire_after: float = 10.0) -> None: result. Using `0` will make the bucket get garbage collected as soon as the rate limit has reset. Defaults to `10` seconds. """ - if not self.gc_task: - self.gc_task = asyncio.create_task(self.gc(poll_period, expire_after)) + if self._closed_event: + raise errors.ComponentStateConflictError("Cannot start an active bucket manager") + + # Assert is in running loop + asyncio.get_running_loop() + + self._closed_event = asyncio.Event() + self._gc_task = asyncio.create_task(self._gc(poll_period, expire_after)) def close(self) -> None: - """Close the garbage collector and kill any tasks waiting on ratelimits. + """Close the garbage collector and kill any tasks waiting on ratelimits.""" + if not self._closed_event: + raise errors.ComponentStateConflictError("Cannot interact with an inactive bucket manager") - Once this has been called, this object is considered to be effectively - dead. To reuse it, one should create a new instance. - """ - self.closed_event.set() - for bucket in self.real_hashes_to_buckets.values(): + assert self._closed_event is not None + + for bucket in self._real_hashes_to_buckets.values(): bucket.close() - self.real_hashes_to_buckets.clear() - self.routes_to_hashes.clear() - if self.gc_task is not None: - self.gc_task.cancel() - self.gc_task = None + self._global_ratelimit.close() + self._real_hashes_to_buckets.clear() + self._routes_to_hashes.clear() + + if self._gc_task is not None: + self._gc_task.cancel() + self._gc_task = None - # Ignore docstring not starting in an imperative mood - async def gc(self, poll_period: float, expire_after: float) -> None: + self._closed_event.set() + self._closed_event = None + + async def _gc(self, poll_period: float, expire_after: float) -> None: """Run the garbage collector loop. This is designed to run in the background and manage removing unused @@ -463,37 +465,16 @@ async def gc(self, poll_period: float, expire_after: float) -> None: # Prevent filling memory increasingly until we run out by removing dead buckets every 20s # Allocations are somewhat cheap if we only do them every so-many seconds, after all. 
_LOGGER.log(ux.TRACE, "rate limit garbage collector started") - while not self.closed_event.is_set(): + + assert self._closed_event is not None + while not self._closed_event.is_set(): try: - await asyncio.wait_for(self.closed_event.wait(), timeout=poll_period) + await asyncio.wait_for(self._closed_event.wait(), timeout=poll_period) except asyncio.TimeoutError: _LOGGER.log(ux.TRACE, "performing rate limit garbage collection pass") - self.do_gc_pass(expire_after) - self.gc_task = None - - def do_gc_pass(self, expire_after: float) -> None: - """Perform a single garbage collection pass. + self._purge_stale_buckets(expire_after) - This will assess any routes stored in the internal mappings of this - object and remove any that are deemed to be inactive or dead in order - to save memory. - - If the removed routes are used again in the future, they will be - re-cached automatically. - - .. warning:: - You generally have no need to invoke this directly. Use - `RESTBucketManager.start` and `RESTBucketManager.close` to control - this instead. - - Parameters - ---------- - expire_after : float - Time after which the last `reset_at` was hit for a bucket to\ - remove it. Defaults to `reset_at` + 20 seconds. Higher values will - retain unneeded ratelimit info for longer, but may produce more - effective ratelimiting logic as a result. - """ + def _purge_stale_buckets(self, expire_after: float) -> None: buckets_to_purge: typing.List[str] = [] now = time.monotonic() @@ -505,7 +486,7 @@ def do_gc_pass(self, expire_after: float) -> None: active = 0 # Discover and purge - bucket_pairs = self.real_hashes_to_buckets.items() + bucket_pairs = self._real_hashes_to_buckets.items() for full_hash, bucket in bucket_pairs: if bucket.is_empty and bucket.reset_at + expire_after < now: @@ -521,48 +502,63 @@ def do_gc_pass(self, expire_after: float) -> None: survival = total - active - dead for full_hash in buckets_to_purge: - self.real_hashes_to_buckets[full_hash].close() - del self.real_hashes_to_buckets[full_hash] + self._real_hashes_to_buckets[full_hash].close() + del self._real_hashes_to_buckets[full_hash] - _LOGGER.log(ux.TRACE, "purged %s stale buckets, %s remain in survival, %s active", dead, survival, active) + if dead: + _LOGGER.debug("purged %s stale buckets, %s remain in survival, %s active", dead, survival, active) + else: + _LOGGER.log(ux.TRACE, "no buckets purged, %s remain in survival, %s active", survival, active) - def acquire(self, compiled_route: routes.CompiledRoute) -> RESTBucket: + def acquire_bucket( + self, compiled_route: routes.CompiledRoute, authentication: typing.Optional[str] + ) -> typing.AsyncContextManager[None]: """Acquire a bucket for the given route. .. note:: - You MUST keep the context manager of the bucket acquired during the - full duration of the request. From making the request until calling - `update_rate_limits`. + You MUST keep the context manager acquired during the full duration + of the request: from making the request until calling `update_rate_limits`. Parameters ---------- compiled_route : hikari.internal.routes.CompiledRoute The route to get the bucket for. + authentication : typing.Optional[str] + The authentication that will be used in the request. Returns ------- - hikari.impl.RESTBucket - The bucket for this route. + typing.AsyncContextManager + The context manager to use during the duration of the request. 
""" - try: - bucket_hash = self.routes_to_hashes[compiled_route.route] - real_bucket_hash = compiled_route.create_real_bucket_hash(bucket_hash) - except KeyError: - real_bucket_hash = _create_unknown_hash(compiled_route) - - try: - bucket = self.real_hashes_to_buckets[real_bucket_hash] + if not self._closed_event: + raise errors.ComponentStateConflictError("Cannot interact with an inactive bucket manager") + + authentication_hash = _create_authentication_hash(authentication) + + if bucket_hash := self._routes_to_hashes.get(compiled_route.route): + real_bucket_hash = compiled_route.create_real_bucket_hash(bucket_hash, authentication_hash) + else: + real_bucket_hash = _create_unknown_hash(compiled_route, authentication_hash) + + if bucket := self._real_hashes_to_buckets.get(real_bucket_hash): _LOGGER.debug("%s is being mapped to existing bucket %s", compiled_route, real_bucket_hash) - except KeyError: + else: _LOGGER.debug("%s is being mapped to new bucket %s", compiled_route, real_bucket_hash) - bucket = RESTBucket(real_bucket_hash, compiled_route, self.max_rate_limit) - self.real_hashes_to_buckets[real_bucket_hash] = bucket + bucket = RESTBucket( + real_bucket_hash, + compiled_route, + self._global_ratelimit, + self._max_rate_limit, + ) + self._real_hashes_to_buckets[real_bucket_hash] = bucket return bucket def update_rate_limits( self, compiled_route: routes.CompiledRoute, + authentication: typing.Optional[str], bucket_header: str, remaining_header: int, limit_header: int, @@ -574,7 +570,9 @@ def update_rate_limits( ---------- compiled_route : hikari.internal.routes.CompiledRoute The compiled route to get the bucket for. - bucket_header : typing.Optional[str] + authentication : typing.Optional[str] + The authentication that was used in the request. + bucket_header : str The `X-RateLimit-Bucket` header that was provided in the response. remaining_header : int The `X-RateLimit-Remaining` header cast to an `int`. @@ -583,10 +581,14 @@ def update_rate_limits( reset_after : float The `X-RateLimit-Reset-After` header cast to a `float`. 
""" - self.routes_to_hashes[compiled_route.route] = bucket_header - real_bucket_hash = compiled_route.create_real_bucket_hash(bucket_header) + if not self._closed_event: + raise errors.ComponentStateConflictError("Cannot interact with an inactive bucket manager") + + self._routes_to_hashes[compiled_route.route] = bucket_header + authentication_hash = _create_authentication_hash(authentication) + real_bucket_hash = compiled_route.create_real_bucket_hash(bucket_header, authentication_hash) - if bucket := self.real_hashes_to_buckets.get(real_bucket_hash): + if bucket := self._real_hashes_to_buckets.get(real_bucket_hash): _LOGGER.debug( "updating %s with bucket %s [reset-after:%ss, limit:%s, remaining:%s]", compiled_route, @@ -596,9 +598,9 @@ def update_rate_limits( remaining_header, ) else: - unknown_bucket_hash = _create_unknown_hash(compiled_route) + unknown_bucket_hash = _create_unknown_hash(compiled_route, authentication_hash) - if bucket := self.real_hashes_to_buckets.pop(unknown_bucket_hash, None): + if bucket := self._real_hashes_to_buckets.pop(unknown_bucket_hash, None): bucket.resolve(real_bucket_hash) _LOGGER.debug( "remapping %s with existing bucket %s [reset-after:%ss, limit:%s, remaining:%s]", @@ -617,14 +619,25 @@ def update_rate_limits( limit_header, remaining_header, ) - bucket = RESTBucket(real_bucket_hash, compiled_route, self.max_rate_limit) - self.real_hashes_to_buckets[real_bucket_hash] = bucket + bucket = RESTBucket( + real_bucket_hash, + compiled_route, + self._global_ratelimit, + self._max_rate_limit, + ) + + self._real_hashes_to_buckets[real_bucket_hash] = bucket reset_at_monotonic = time.monotonic() + reset_after bucket.update_rate_limit(remaining_header, limit_header, reset_at_monotonic) - @property - def is_started(self) -> bool: - """Return `True` if the rate limiter GC task is started.""" - return self.gc_task is not None + def throttle(self, retry_after: float) -> None: + """Throttle the global ratelimit for the buckets. + + Parameters + ---------- + retry_after : float + How long to throttle for. 
+ """ + self._global_ratelimit.throttle(retry_after) diff --git a/hikari/impl/rest.py b/hikari/impl/rest.py index 313cc1be90..ef54b82454 100644 --- a/hikari/impl/rest.py +++ b/hikari/impl/rest.py @@ -45,7 +45,6 @@ import urllib.parse import aiohttp -import attr from hikari import _about as about from hikari import applications @@ -74,6 +73,7 @@ from hikari.impl import rate_limits from hikari.impl import special_endpoints as special_endpoints_impl from hikari.interactions import base_interactions +from hikari.internal import aio from hikari.internal import data_binding from hikari.internal import deprecation from hikari.internal import mentions @@ -112,7 +112,7 @@ _X_RATELIMIT_LIMIT_HEADER: typing.Final[str] = sys.intern("X-RateLimit-Limit") _X_RATELIMIT_REMAINING_HEADER: typing.Final[str] = sys.intern("X-RateLimit-Remaining") _X_RATELIMIT_RESET_AFTER_HEADER: typing.Final[str] = sys.intern("X-RateLimit-Reset-After") -_RETRY_ERROR_CODES: typing.Final[typing.FrozenSet[int]] = frozenset({500, 502, 503, 504}) +_RETRY_ERROR_CODES: typing.Final[typing.FrozenSet[int]] = frozenset((500, 502, 503, 504)) _MAX_BACKOFF_DURATION: typing.Final[int] = 16 @@ -287,10 +287,11 @@ class RESTApp(traits.ExecutorAware): __slots__: typing.Sequence[str] = ( "_executor", "_http_settings", - "_max_rate_limit", "_max_retries", "_proxy_settings", "_url", + "_bucket_manager", + "_client_session", ) def __init__( @@ -298,7 +299,7 @@ def __init__( *, executor: typing.Optional[concurrent.futures.Executor] = None, http_settings: typing.Optional[config_impl.HTTPSettings] = None, - max_rate_limit: float = 300, + max_rate_limit: float = 300.0, max_retries: int = 3, proxy_settings: typing.Optional[config_impl.ProxySettings] = None, url: typing.Optional[str] = None, @@ -306,9 +307,10 @@ def __init__( self._http_settings = config_impl.HTTPSettings() if http_settings is None else http_settings self._proxy_settings = config_impl.ProxySettings() if proxy_settings is None else proxy_settings self._executor = executor - self._max_rate_limit = max_rate_limit self._max_retries = max_retries self._url = url + self._bucket_manager = buckets_impl.RESTBucketManager(max_rate_limit) + self._client_session: typing.Optional[aiohttp.ClientSession] = None @property def executor(self) -> typing.Optional[concurrent.futures.Executor]: @@ -322,6 +324,26 @@ def http_settings(self) -> config_impl.HTTPSettings: def proxy_settings(self) -> config_impl.ProxySettings: return self._proxy_settings + async def start(self) -> None: + if self._client_session: + raise errors.ComponentStateConflictError("Rest app has already been started") + + self._bucket_manager.start() + self._client_session = net.create_client_session( + connector=net.create_tcp_connector(self._http_settings), + connector_owner=True, # Ensure closing the TCP connector + http_settings=self._http_settings, + raise_for_status=False, + trust_env=self._proxy_settings.trust_env, + ) + + async def close(self) -> None: + if self._client_session is None: + raise errors.ComponentStateConflictError("Rest app is not running") + + await self._client_session.close() + self._bucket_manager.close() + @typing.overload def acquire(self, token: typing.Optional[rest_api.TokenStrategy] = None) -> RESTClientImpl: ... @@ -351,12 +373,15 @@ def acquire( .. code-block:: python rest_app = RESTApp() + await rest_app.start() # Using the returned client as a context manager to implicitly start # and stop it. 
async with rest_app.acquire("A token", "Bot") as client: user = await client.fetch_my_user() + await rest_app.close() + Parameters ---------- token : typing.Union[str, None, hikari.api.rest.TokenStrategy] @@ -381,6 +406,9 @@ def acquire( ValueError If `token_type` is provided when a token strategy is passed for `token`. """ + if not self._client_session: + raise errors.ComponentStateConflictError("Rest app is not running so it cannot be interacted with") + # Since we essentially mimic a fake App instance, we need to make a circular provider. # We can achieve this using a lambda. This allows the entity factory to build models that # are also REST-aware @@ -398,87 +426,48 @@ def acquire( entity_factory=entity_factory, executor=self._executor, http_settings=self._http_settings, - max_rate_limit=self._max_rate_limit, max_retries=self._max_retries, proxy_settings=self._proxy_settings, token=token, token_type=token_type, rest_url=self._url, + bucket_manager=self._bucket_manager, + bucket_manager_owner=False, + client_session=self._client_session, + client_session_owner=False, ) return rest_client -@attr.define() -class _LiveAttributes: - """Fields which are only present within `RESTClientImpl` while it's "alive". +def _stringify_http_message(headers: data_binding.Headers, body: typing.Any) -> str: + string = "\n".join( + f" {name}: {value}" if name != _AUTHORIZATION_HEADER else f" {name}: **REDACTED TOKEN**" + for name, value in headers.items() + ) - .. note:: - This must be started within an active asyncio event loop. - """ + if body is not None: + string += "\n\n " + string += body.decode("ascii") if isinstance(body, bytes) else str(body) - buckets: buckets_impl.RESTBucketManager = attr.field() - client_session: aiohttp.ClientSession = attr.field() - closed_event: asyncio.Event = attr.field() - # We've been told in DAPI that this is per token. - global_rate_limit: rate_limits.ManualRateLimiter = attr.field() - tcp_connector: aiohttp.TCPConnector = attr.field() - is_closing: bool = attr.field(default=False, init=False) - - @classmethod - def build( - cls, max_rate_limit: float, http_settings: config_impl.HTTPSettings, proxy_settings: config_impl.ProxySettings - ) -> _LiveAttributes: - """Build a live attributes object. - - .. warning:: - This can only be called when the current thread has an active - asyncio loop. - """ - # This asserts that this is called within an active event loop. - asyncio.get_running_loop() - tcp_connector = net.create_tcp_connector(http_settings) - _LOGGER.log(ux.TRACE, "acquired new tcp connector") - client_session = net.create_client_session( - connector=tcp_connector, - # No, this is correct. We manage closing the connector ourselves in this class. - # This works around some other lifespan issues. 
- connector_owner=False, - http_settings=http_settings, - raise_for_status=False, - trust_env=proxy_settings.trust_env, - ) - _LOGGER.log(ux.TRACE, "acquired new aiohttp client session") - buckets = buckets_impl.RESTBucketManager(max_rate_limit) - buckets.start() - return _LiveAttributes( - buckets=buckets, - client_session=client_session, - closed_event=asyncio.Event(), - global_rate_limit=rate_limits.ManualRateLimiter(), - tcp_connector=tcp_connector, - ) + return string - async def close(self) -> None: - self.is_closing = True - self.closed_event.set() - self.buckets.close() - self.global_rate_limit.close() - await self.client_session.close() - await self.tcp_connector.close() - - def still_alive(self) -> _LiveAttributes: - """Chained method used to Check if `close` has been called before using this object's resources.""" - if self.is_closing: - raise errors.ComponentStateConflictError("The REST client was closed mid-request") - return self +def _transform_emoji_to_url_format( + emoji: typing.Union[str, emojis.Emoji], + emoji_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[emojis.CustomEmoji]], + /, +) -> str: + if isinstance(emoji, emojis.Emoji): + if emoji_id is not undefined.UNDEFINED: + raise ValueError("emoji_id shouldn't be passed when an Emoji object is passed for emoji") + + return emoji.url_name + if emoji_id is not undefined.UNDEFINED: + return f"{emoji}:{snowflakes.Snowflake(emoji_id)}" -# The standard exceptions are all unsloted so slotting here would be a waste of time. -@attr.define(auto_exc=True, repr=False, slots=False) -class _RetryRequest(RuntimeError): - ... + return emoji class RESTClientImpl(rest_api.RESTClient): @@ -495,13 +484,6 @@ class RESTClientImpl(rest_api.RESTClient): executor : typing.Optional[concurrent.futures.Executor] The executor to use for blocking IO. Defaults to the `asyncio` thread pool if set to `None`. - max_rate_limit : float - Maximum number of seconds to sleep for when rate limited. If a rate - limit occurs that is longer than this value, then a - `hikari.errors.RateLimitedError` will be raised instead of waiting. - - This is provided since some endpoints may respond with non-sensible - rate limits. max_retries : typing.Optional[int] Maximum number of times a request will be retried if it fails with a `5xx` status. Defaults to 3 if set to `None`. 
@@ -528,17 +510,20 @@ class RESTClientImpl(rest_api.RESTClient): """ __slots__: typing.Sequence[str] = ( + "_bucket_manager", + "_bucket_manager_owner", "_cache", "_entity_factory", "_executor", "_http_settings", - "_live_attributes", - "_max_rate_limit", "_max_retries", "_proxy_settings", "_rest_url", "_token", "_token_type", + "_client_session", + "_client_session_owner", + "_close_event", ) def __init__( @@ -548,7 +533,11 @@ def __init__( entity_factory: entity_factory_.EntityFactory, executor: typing.Optional[concurrent.futures.Executor], http_settings: config_impl.HTTPSettings, - max_rate_limit: float, + bucket_manager: typing.Optional[buckets_impl.RESTBucketManager] = None, + bucket_manager_owner: bool = True, + client_session: typing.Optional[aiohttp.ClientSession] = None, + client_session_owner: bool = True, + max_rate_limit: float = 300.0, max_retries: int = 3, proxy_settings: config_impl.ProxySettings, token: typing.Union[str, None, rest_api.TokenStrategy], @@ -558,14 +547,28 @@ def __init__( if max_retries > 5: raise ValueError("'max_retries' must be below or equal to 5") + if client_session_owner is False and client_session is None: + raise ValueError( + "Cannot delegate ownership of unknown client session [client_session_owner=False, client_session=None]" + ) + if bucket_manager_owner is False and bucket_manager is None: + raise ValueError( + "Cannot delegate ownership of unknown bucket manager [bucket_manager_owner=False, bucket_manager=None]" + ) + self._cache = cache self._entity_factory = entity_factory self._executor = executor self._http_settings = http_settings - self._live_attributes: typing.Optional[_LiveAttributes] = None - self._max_rate_limit = max_rate_limit self._max_retries = max_retries self._proxy_settings = proxy_settings + self._bucket_manager = ( + buckets_impl.RESTBucketManager(max_rate_limit) if bucket_manager is None else bucket_manager + ) + self._bucket_manager_owner = bucket_manager_owner + self._client_session = client_session + self._client_session_owner = client_session_owner + self._close_event: typing.Optional[asyncio.Event] = None self._token: typing.Union[str, rest_api.TokenStrategy, None] = None self._token_type: typing.Optional[str] = None @@ -589,7 +592,7 @@ def __init__( @property def is_alive(self) -> bool: - return self._live_attributes is not None + return self._close_event is not None @property def http_settings(self) -> config_impl.HTTPSettings: @@ -607,14 +610,21 @@ def entity_factory(self) -> entity_factory_.EntityFactory: def token_type(self) -> typing.Union[str, applications.TokenType, None]: return self._token_type - @typing.final async def close(self) -> None: """Close the HTTP client and any open HTTP connections.""" - live_attributes = self._get_live_attributes() - self._live_attributes = None - await live_attributes.close() + if not self._close_event or not self._client_session: + raise errors.ComponentStateConflictError("Cannot close an inactive REST client") + + self._close_event.set() + self._close_event = None + + if self._client_session_owner: + await self._client_session.close() + self._client_session = None + + if self._bucket_manager_owner: + self._bucket_manager.close() - @typing.final def start(self) -> None: """Start the HTTP client. @@ -626,16 +636,25 @@ def start(self) -> None: RuntimeError If this is called in an environment without an active event loop. 
""" - if self._live_attributes: + if self._close_event: raise errors.ComponentStateConflictError("Cannot start a REST Client which is already alive") - self._live_attributes = _LiveAttributes.build(self._max_rate_limit, self._http_settings, self._proxy_settings) + # Assert is in running loop + asyncio.get_running_loop() - def _get_live_attributes(self) -> _LiveAttributes: - if self._live_attributes: - return self._live_attributes + self._close_event = asyncio.Event() - raise errors.ComponentStateConflictError("Cannot use an inactive REST client") + if self._client_session_owner: + self._client_session = net.create_client_session( + connector=net.create_tcp_connector(self._http_settings), + connector_owner=True, # Ensure closing the TCP connector + http_settings=self._http_settings, + raise_for_status=False, + trust_env=self._proxy_settings.trust_env, + ) + + if self._bucket_manager_owner: + self._bucket_manager.start() async def __aenter__(self) -> RESTClientImpl: self.start() @@ -674,159 +693,171 @@ async def _request( form_builder: typing.Optional[data_binding.URLEncodedFormBuilder] = None, json: typing.Union[data_binding.JSONObjectBuilder, data_binding.JSONArray, None] = None, reason: undefined.UndefinedOr[str] = undefined.UNDEFINED, - no_auth: bool = False, - auth: typing.Optional[str] = None, + auth: undefined.UndefinedNoneOr[str] = undefined.UNDEFINED, ) -> typing.Union[None, data_binding.JSONObject, data_binding.JSONArray]: - # Make a ratelimit-protected HTTP request to a JSON endpoint and expect some form - # of JSON response. - live_attributes = self._get_live_attributes() - headers = data_binding.StringMapBuilder() - headers.setdefault(_USER_AGENT_HEADER, _HTTP_USER_AGENT) + if not self._close_event: + raise errors.ComponentStateConflictError("Cannot use an inactive REST client") - re_authed = False - token: typing.Optional[str] = None - if auth: - headers[_AUTHORIZATION_HEADER] = auth + request_task = asyncio.create_task( + self._perform_request( + compiled_route=compiled_route, + query=query, + form_builder=form_builder, + json=json, + reason=reason, + auth=auth, + ) + ) + + await aio.first_completed(request_task, self._close_event.wait()) + + if not self._close_event.is_set(): + return request_task.result() - elif not no_auth: - if isinstance(self._token, str): - headers[_AUTHORIZATION_HEADER] = self._token + raise errors.ComponentStateConflictError("The REST client was closed mid-request") - elif self._token is not None: - token = await self._token.acquire(self) - headers[_AUTHORIZATION_HEADER] = token + @typing.final + async def _perform_request( + self, + compiled_route: routes.CompiledRoute, + *, + query: typing.Optional[data_binding.StringMapBuilder] = None, + form_builder: typing.Optional[data_binding.URLEncodedFormBuilder] = None, + json: typing.Union[data_binding.JSONObjectBuilder, data_binding.JSONArray, None] = None, + reason: undefined.UndefinedOr[str] = undefined.UNDEFINED, + auth: undefined.UndefinedNoneOr[str] = undefined.UNDEFINED, + ) -> typing.Union[None, data_binding.JSONObject, data_binding.JSONArray]: + # Make a ratelimit-protected HTTP request to a JSON endpoint and expect some form + # of JSON response. + + assert self._client_session is not None # This will never be None here + headers = data_binding.StringMapBuilder() + headers.put(_USER_AGENT_HEADER, _HTTP_USER_AGENT) # As per the docs, UTF-8 characters are only supported here if it's url-encoded. 
headers.put(_X_AUDIT_LOG_REASON_HEADER, reason, conversion=urllib.parse.quote) + can_re_auth = False + if auth is undefined.UNDEFINED: + if isinstance(self._token, rest_api.TokenStrategy): + auth = await self._token.acquire(self) + can_re_auth = True + + else: + auth = self._token + + if auth: + headers[_AUTHORIZATION_HEADER] = auth + url = compiled_route.create_url(self._rest_url) + stack = contextlib.AsyncExitStack() # This is initiated the first time we hit a 5xx error to save a little memory when nothing goes wrong backoff: typing.Optional[rate_limits.ExponentialBackOff] = None retry_count = 0 - - stack = contextlib.AsyncExitStack() trace_logging_enabled = _LOGGER.isEnabledFor(ux.TRACE) + while True: - try: - uuid = time.uuid() - async with stack: - form = await form_builder.build(stack) if form_builder else None - - await stack.enter_async_context(live_attributes.still_alive().buckets.acquire(compiled_route)) - # Buckets not using authentication still have a global - # rate limit, but it is different from the token one. - if not no_auth: - await live_attributes.still_alive().global_rate_limit.acquire() - - if trace_logging_enabled: - _LOGGER.log( - ux.TRACE, - "%s %s %s\n%s", - uuid, - compiled_route.method, - url, - self._stringify_http_message(headers, json), - ) - start = time.monotonic() - - # Make the request. - response = await live_attributes.still_alive().client_session.request( + async with stack: + form = await form_builder.build(stack) if form_builder else None + + if compiled_route.route.has_ratelimits: + await stack.enter_async_context(self._bucket_manager.acquire_bucket(compiled_route, auth)) + + if trace_logging_enabled: + uuid = time.uuid() + _LOGGER.log( + ux.TRACE, + "%s %s %s\n%s", + uuid, compiled_route.method, url, - headers=headers, - params=query, - json=json, - data=form, - allow_redirects=self._http_settings.max_redirects is not None, - max_redirects=self._http_settings.max_redirects, - proxy=self._proxy_settings.url, - proxy_headers=self._proxy_settings.all_headers, + _stringify_http_message(headers, json), ) + start = time.monotonic() + + # Make the request. + response = await self._client_session.request( + compiled_route.method, + url, + headers=headers, + params=query, + json=json, + data=form, + allow_redirects=self._http_settings.max_redirects is not None, + max_redirects=self._http_settings.max_redirects, + proxy=self._proxy_settings.url, + proxy_headers=self._proxy_settings.all_headers, + ) - if trace_logging_enabled: - time_taken = (time.monotonic() - start) * 1_000 # pyright: ignore[reportUnboundVariable] - _LOGGER.log( - ux.TRACE, - "%s %s %s in %sms\n%s", - uuid, - response.status, - response.reason, - time_taken, - self._stringify_http_message(response.headers, await response.read()), - ) - - # Ensure we are not rate limited, and update rate limiting headers where appropriate. - await self._parse_ratelimits(compiled_route, response, live_attributes) - - # Don't bother processing any further if we got NO CONTENT. There's not anything - # to check. - if response.status == http.HTTPStatus.NO_CONTENT: - return None - - # Handle the response when everything went good - if 200 <= response.status < 300: - if response.content_type == _APPLICATION_JSON: - # Only deserializing here stops Cloudflare shenanigans messing us around. 
- return data_binding.load_json(await response.read()) - - real_url = str(response.real_url) - raise errors.HTTPError(f"Expected JSON [{response.content_type=}, {real_url=}]") - - # Handling 5xx errors - if response.status in _RETRY_ERROR_CODES and retry_count < self._max_retries: - if backoff is None: - backoff = rate_limits.ExponentialBackOff(maximum=_MAX_BACKOFF_DURATION) - - sleep_time = next(backoff) - _LOGGER.warning( - "Received status %s on request, backing off for %.2fs and retrying. Retries remaining: %s", + if trace_logging_enabled: + time_taken = (time.monotonic() - start) * 1_000 # pyright: ignore[reportUnboundVariable] + _LOGGER.log( + ux.TRACE, + "%s %s %s in %sms\n%s", + uuid, # pyright: ignore[reportUnboundVariable] response.status, - sleep_time, - self._max_retries - retry_count, + response.reason, + time_taken, + _stringify_http_message(response.headers, await response.read()), ) - retry_count += 1 - await asyncio.sleep(sleep_time) - continue + # Ensure we are not rate limited, and update rate limiting headers where appropriate. + retry = await self._parse_ratelimits(compiled_route, auth, response) - # Attempt to re-auth on UNAUTHORIZED if we are using a TokenStrategy - can_re_auth = response.status == 401 and not (auth or no_auth or re_authed) - if can_re_auth and isinstance(self._token, rest_api.TokenStrategy): - self._token.invalidate(token) - token = await self._token.acquire(self) - headers[_AUTHORIZATION_HEADER] = token - re_authed = True - continue + if retry: + continue - await self._handle_error_response(response) + # Don't bother processing any further if we got NO CONTENT. There's not anything + # to check. + if response.status == http.HTTPStatus.NO_CONTENT: + return None + + # Handle the response when everything went good + if 200 <= response.status < 300: + if response.content_type == _APPLICATION_JSON: + # Only deserializing here stops Cloudflare shenanigans messing us around. + return data_binding.load_json(await response.read()) + + real_url = str(response.real_url) + raise errors.HTTPError(f"Expected JSON [{response.content_type=}, {real_url=}]") + + # Handling 5xx errors + if response.status in _RETRY_ERROR_CODES and retry_count < self._max_retries: + if not backoff: + backoff = rate_limits.ExponentialBackOff(maximum=_MAX_BACKOFF_DURATION) + + sleep_time = next(backoff) + retry_count += 1 + _LOGGER.warning( + "Received status %s on request, backing off for %.2fs and retrying. 
Retries remaining: %s", + response.status, + sleep_time, + self._max_retries - retry_count, + ) - except _RetryRequest: + await asyncio.sleep(sleep_time) continue - @staticmethod - @typing.final - def _stringify_http_message(headers: data_binding.Headers, body: typing.Any) -> str: - string = "\n".join( - f" {name}: {value}" if name != _AUTHORIZATION_HEADER else f" {name}: **REDACTED TOKEN**" - for name, value in headers.items() - ) + # Attempt to re-auth on UNAUTHORIZED if we are using a TokenStrategy + if can_re_auth and response.status == 401: + # can_re_auth ensures that it is a token strategy + assert isinstance(self._token, rest_api.TokenStrategy) - if body is not None: - string += "\n\n " - string += body.decode("ascii") if isinstance(body, bytes) else str(body) - - return string + self._token.invalidate(auth) + auth = headers[_AUTHORIZATION_HEADER] = await self._token.acquire(self) + can_re_auth = False + continue - @staticmethod - @typing.final - async def _handle_error_response(response: aiohttp.ClientResponse) -> typing.NoReturn: - raise await net.generate_error_response(response) + raise await net.generate_error_response(response) @typing.final async def _parse_ratelimits( - self, compiled_route: routes.CompiledRoute, response: aiohttp.ClientResponse, live_attributes: _LiveAttributes - ) -> None: + self, + compiled_route: routes.CompiledRoute, + authentication: typing.Optional[str], + response: aiohttp.ClientResponse, + ) -> bool: # Handle rate limiting. resp_headers = response.headers limit = int(resp_headers.get(_X_RATELIMIT_LIMIT_HEADER, "1")) @@ -835,8 +866,20 @@ async def _parse_ratelimits( reset_after = float(resp_headers.get(_X_RATELIMIT_RESET_AFTER_HEADER, "0")) if bucket: - live_attributes.still_alive().buckets.update_rate_limits( + if not compiled_route.route.has_ratelimits: + # This should theoretically never see the light of day, but it scares me that Discord might + # pull a funny one and this may go unnoticed, so better safe to have it! + _LOGGER.error( + "Received an unexpected bucket header for %r. " + "The route will be treated as having a ratelimit for the duration of this applications runtime. " + "If you see this, please report it to the maintainers so the route can be updated!", + compiled_route.route, + ) + compiled_route.route.has_ratelimits = True + + self._bucket_manager.update_rate_limits( compiled_route=compiled_route, + authentication=authentication, bucket_header=bucket, remaining_header=remaining, limit_header=limit, @@ -844,7 +887,7 @@ async def _parse_ratelimits( ) if response.status != http.HTTPStatus.TOO_MANY_REQUESTS: - return + return False # Discord have started applying ratelimits to operations on some endpoints # based on specific fields used in the JSON body. @@ -872,7 +915,7 @@ async def _parse_ratelimits( "rate limited on bucket %s, maybe you are running more than one bot on this token? Retrying request...", bucket, ) - raise _RetryRequest + return True if response.content_type != _APPLICATION_JSON: # We don't know exactly what this could imply. It is likely Cloudflare interfering @@ -894,15 +937,15 @@ async def _parse_ratelimits( "rate limited on the global bucket. You should consider lowering the number of requests you make or " "contacting Discord to raise this limit. Backing off and retrying request..." 
) - live_attributes.still_alive().global_rate_limit.throttle(body_retry_after) - raise _RetryRequest + self._bucket_manager.throttle(body_retry_after) + return True # If the values are within 20% of each other by relativistic tolerance, it is probably # safe to retry the request, as they are likely the same value just with some # measuring difference. 20% was used as a rounded figure. if math.isclose(body_retry_after, reset_after, rel_tol=0.20): _LOGGER.error("rate limited on a sub bucket on bucket %s, but it is safe to retry", bucket) - raise _RetryRequest + return True raise errors.RateLimitedError( url=str(response.real_url), @@ -1159,8 +1202,11 @@ async def create_invite( def trigger_typing( self, channel: snowflakes.SnowflakeishOr[channels_.TextableChannel] ) -> special_endpoints.TypingIndicator: + if not self._close_event: + raise errors.ComponentStateConflictError("Cannot use an inactive REST client") + return special_endpoints_impl.TypingIndicator( - request_call=self._request, channel=channel, rest_closed_event=self._get_live_attributes().closed_event + request_call=self._request, channel=channel, rest_close_event=self._close_event ) async def fetch_pins( @@ -1555,23 +1601,6 @@ async def delete_messages( except Exception as ex: raise errors.BulkDeleteError(deleted) from ex - @staticmethod - def _transform_emoji_to_url_format( - emoji: typing.Union[str, emojis.Emoji], - emoji_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[emojis.CustomEmoji]], - /, - ) -> str: - if isinstance(emoji, emojis.Emoji): - if emoji_id is not undefined.UNDEFINED: - raise ValueError("emoji_id shouldn't be passed when an Emoji object is passed for emoji") - - return emoji.url_name - - if emoji_id is not undefined.UNDEFINED: - return f"{emoji}:{snowflakes.Snowflake(emoji_id)}" - - return emoji - async def add_reaction( self, channel: snowflakes.SnowflakeishOr[channels_.TextableChannel], @@ -1580,7 +1609,7 @@ async def add_reaction( emoji_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[emojis.CustomEmoji]] = undefined.UNDEFINED, ) -> None: route = routes.PUT_MY_REACTION.compile( - emoji=self._transform_emoji_to_url_format(emoji, emoji_id), + emoji=_transform_emoji_to_url_format(emoji, emoji_id), channel=channel, message=message, ) @@ -1594,7 +1623,7 @@ async def delete_my_reaction( emoji_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[emojis.CustomEmoji]] = undefined.UNDEFINED, ) -> None: route = routes.DELETE_MY_REACTION.compile( - emoji=self._transform_emoji_to_url_format(emoji, emoji_id), + emoji=_transform_emoji_to_url_format(emoji, emoji_id), channel=channel, message=message, ) @@ -1608,7 +1637,7 @@ async def delete_all_reactions_for_emoji( emoji_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[emojis.CustomEmoji]] = undefined.UNDEFINED, ) -> None: route = routes.DELETE_REACTION_EMOJI.compile( - emoji=self._transform_emoji_to_url_format(emoji, emoji_id), + emoji=_transform_emoji_to_url_format(emoji, emoji_id), channel=channel, message=message, ) @@ -1623,7 +1652,7 @@ async def delete_reaction( emoji_id: undefined.UndefinedOr[snowflakes.SnowflakeishOr[emojis.CustomEmoji]] = undefined.UNDEFINED, ) -> None: route = routes.DELETE_REACTION_USER.compile( - emoji=self._transform_emoji_to_url_format(emoji, emoji_id), + emoji=_transform_emoji_to_url_format(emoji, emoji_id), channel=channel, message=message, user=user, @@ -1650,7 +1679,7 @@ def fetch_reactions_for_emoji( request_call=self._request, channel=channel, message=message, - emoji=self._transform_emoji_to_url_format(emoji, emoji_id), + 
emoji=_transform_emoji_to_url_format(emoji, emoji_id), ) async def create_webhook( @@ -1682,12 +1711,12 @@ async def fetch_webhook( ) -> webhooks.PartialWebhook: if token is undefined.UNDEFINED: route = routes.GET_WEBHOOK.compile(webhook=webhook) - no_auth = False + auth = undefined.UNDEFINED else: route = routes.GET_WEBHOOK_WITH_TOKEN.compile(webhook=webhook, token=token) - no_auth = True + auth = None - response = await self._request(route, no_auth=no_auth) + response = await self._request(route, auth=auth) assert isinstance(response, dict) return self._entity_factory.deserialize_webhook(response) @@ -1721,10 +1750,10 @@ async def edit_webhook( ) -> webhooks.PartialWebhook: if token is undefined.UNDEFINED: route = routes.PATCH_WEBHOOK.compile(webhook=webhook) - no_auth = False + auth = undefined.UNDEFINED else: route = routes.PATCH_WEBHOOK_WITH_TOKEN.compile(webhook=webhook, token=token) - no_auth = True + auth = None body = data_binding.JSONObjectBuilder() body.put("name", name) @@ -1737,7 +1766,7 @@ async def edit_webhook( async with avatar_resource.stream(executor=self._executor) as stream: body.put("avatar", await stream.data_uri()) - response = await self._request(route, json=body, reason=reason, no_auth=no_auth) + response = await self._request(route, json=body, reason=reason, auth=auth) assert isinstance(response, dict) return self._entity_factory.deserialize_webhook(response) @@ -1749,12 +1778,12 @@ async def delete_webhook( ) -> None: if token is undefined.UNDEFINED: route = routes.DELETE_WEBHOOK.compile(webhook=webhook) - no_auth = False + auth = undefined.UNDEFINED else: route = routes.DELETE_WEBHOOK_WITH_TOKEN.compile(webhook=webhook, token=token) - no_auth = True + auth = None - await self._request(route, no_auth=no_auth) + await self._request(route, auth=auth) async def execute_webhook( self, @@ -1810,9 +1839,9 @@ async def execute_webhook( if form_builder is not None: form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON) - response = await self._request(route, form_builder=form_builder, query=query, no_auth=True) + response = await self._request(route, form_builder=form_builder, query=query, auth=None) else: - response = await self._request(route, json=body, query=query, no_auth=True) + response = await self._request(route, json=body, query=query, auth=None) assert isinstance(response, dict) return self._entity_factory.deserialize_message(response) @@ -1832,7 +1861,7 @@ async def fetch_webhook_message( route = routes.GET_WEBHOOK_MESSAGE.compile(webhook=webhook_id, token=token, message=message) query = data_binding.StringMapBuilder() query.put("thread_id", thread) - response = await self._request(route, no_auth=True, query=query) + response = await self._request(route, auth=None, query=query) assert isinstance(response, dict) return self._entity_factory.deserialize_message(response) @@ -1888,9 +1917,9 @@ async def edit_webhook_message( if form_builder is not None: form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON) - response = await self._request(route, form_builder=form_builder, query=query, no_auth=True) + response = await self._request(route, form_builder=form_builder, query=query, auth=None) else: - response = await self._request(route, json=body, query=query, no_auth=True) + response = await self._request(route, json=body, query=query, auth=None) assert isinstance(response, dict) return self._entity_factory.deserialize_message(response) @@ -1910,12 +1939,12 @@ async def 
delete_webhook_message( query = data_binding.StringMapBuilder() query.put("thread_id", thread) route = routes.DELETE_WEBHOOK_MESSAGE.compile(webhook=webhook_id, token=token, message=message) - await self._request(route, no_auth=True, query=query) + await self._request(route, query=query, auth=None) async def fetch_gateway_url(self) -> str: route = routes.GET_GATEWAY.compile() # This doesn't need authorization. - response = await self._request(route, no_auth=True) + response = await self._request(route, auth=None) assert isinstance(response, dict) url = response["url"] assert isinstance(url, str) @@ -2242,7 +2271,7 @@ async def delete_emoji( async def fetch_available_sticker_packs(self) -> typing.Sequence[stickers.StickerPack]: route = routes.GET_STICKER_PACKS.compile() - response = await self._request(route, no_auth=True) + response = await self._request(route, auth=None) assert isinstance(response, dict) return [ self._entity_factory.deserialize_sticker_pack(sticker_pack_payload) @@ -3851,7 +3880,7 @@ async def fetch_interaction_response( self, application: snowflakes.SnowflakeishOr[guilds.PartialApplication], token: str ) -> messages_.Message: route = routes.GET_INTERACTION_RESPONSE.compile(webhook=application, token=token) - response = await self._request(route, no_auth=True) + response = await self._request(route, auth=None) assert isinstance(response, dict) return self._entity_factory.deserialize_message(response) @@ -3902,9 +3931,9 @@ async def create_interaction_response( if form is not None: form.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON) - await self._request(route, form_builder=form, no_auth=True) + await self._request(route, form_builder=form, auth=None) else: - await self._request(route, json=body, no_auth=True) + await self._request(route, json=body, auth=None) async def edit_interaction_response( self, @@ -3950,9 +3979,9 @@ async def edit_interaction_response( if form_builder is not None: form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON) - response = await self._request(route, form_builder=form_builder, no_auth=True) + response = await self._request(route, form_builder=form_builder, auth=None) else: - response = await self._request(route, json=body, no_auth=True) + response = await self._request(route, json=body, auth=None) assert isinstance(response, dict) return self._entity_factory.deserialize_message(response) @@ -3961,7 +3990,7 @@ async def delete_interaction_response( self, application: snowflakes.SnowflakeishOr[guilds.PartialApplication], token: str ) -> None: route = routes.DELETE_INTERACTION_RESPONSE.compile(webhook=application, token=token) - await self._request(route, no_auth=True) + await self._request(route, auth=None) async def create_autocomplete_response( self, @@ -3978,7 +4007,7 @@ async def create_autocomplete_response( data.put("choices", [{"name": choice.name, "value": choice.value} for choice in choices]) body.put("data", data) - await self._request(route, json=body, no_auth=True) + await self._request(route, json=body, auth=None) async def create_modal_response( self, @@ -4009,7 +4038,7 @@ async def create_modal_response( body.put("data", data) - await self._request(route, json=body, no_auth=True) + await self._request(route, json=body, auth=None) def build_action_row(self) -> special_endpoints.MessageActionRowBuilder: """Build a message action row message component for use in message create and REST calls. 
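# A minimal, hedged sketch of the RESTApp lifecycle this patch introduces (see
# changes/1230.breaking.md, the updated `RESTApp.acquire` docstring, and the new
# examples/oauth.py above): the app must now be started with `.start()` and shut
# down with `.close()`, and clients acquired from it share its client session and
# bucket manager. The BOT_TOKEN environment variable below is an illustrative
# assumption, not part of this patch.
import asyncio
import os

import hikari


async def main() -> None:
    rest_app = hikari.RESTApp()
    await rest_app.start()  # creates the shared client session and bucket manager

    try:
        # The async context manager starts and stops the acquired client itself;
        # the shared session and bucket manager stay alive until rest_app.close().
        async with rest_app.acquire(os.environ["BOT_TOKEN"], hikari.TokenType.BOT) as client:
            me = await client.fetch_my_user()
            print(me)
    finally:
        await rest_app.close()  # closes the shared session and bucket manager


if __name__ == "__main__":
    asyncio.run(main())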
diff --git a/hikari/impl/special_endpoints.py b/hikari/impl/special_endpoints.py index f3716777ea..a2023c9074 100644 --- a/hikari/impl/special_endpoints.py +++ b/hikari/impl/special_endpoints.py @@ -109,8 +109,7 @@ async def __call__( form_builder: typing.Optional[data_binding.URLEncodedFormBuilder] = None, json: typing.Union[data_binding.JSONObjectBuilder, data_binding.JSONArray, None] = None, reason: undefined.UndefinedOr[str] = undefined.UNDEFINED, - no_auth: bool = False, - auth: typing.Optional[str] = None, + auth: undefined.UndefinedNoneOr[str] = undefined.UNDEFINED, ) -> typing.Union[None, data_binding.JSONObject, data_binding.JSONArray]: ... @@ -156,13 +155,13 @@ def __init__( self, request_call: _RequestCallSig, channel: snowflakes.SnowflakeishOr[channels.TextableChannel], - rest_closed_event: asyncio.Event, + rest_close_event: asyncio.Event, ) -> None: self._route = routes.POST_CHANNEL_TYPING.compile(channel=channel) self._request_call = request_call self._task_name = f"repeatedly trigger typing in {channel}" self._task: typing.Optional[asyncio.Task[None]] = None - self._rest_close_event = rest_closed_event + self._rest_close_event = rest_close_event def __await__(self) -> typing.Generator[typing.Any, typing.Any, typing.Any]: return self._request_call(self._route).__await__() diff --git a/hikari/internal/net.py b/hikari/internal/net.py index 59243fd36c..3aaebbe2ff 100644 --- a/hikari/internal/net.py +++ b/hikari/internal/net.py @@ -90,13 +90,13 @@ def create_tcp_connector( Optional Parameters ------------------- - dns_cache: typing.Union[None, bool, int] + dns_cache : typing.Union[None, bool, int] If `True`, DNS caching is used with a default TTL of 10 seconds. If `False`, DNS caching is disabled. If an `int` is given, then DNS caching is enabled with an explicit TTL set. If `None`, the cache will be enabled and never invalidate. limit : int - Number of connections to allow in the pool at a maximum. + Number of connections to allow in the pool at any given time. Returns ------- diff --git a/hikari/internal/routes.py b/hikari/internal/routes.py index 210cd9899e..5d7851b76c 100644 --- a/hikari/internal/routes.py +++ b/hikari/internal/routes.py @@ -88,7 +88,7 @@ def create_url(self, base_url: str) -> str: """ return base_url + self.compiled_path - def create_real_bucket_hash(self, initial_bucket_hash: str) -> str: + def create_real_bucket_hash(self, initial_bucket_hash: str, authentication_hash: str) -> str: """Create a full bucket hash from a given initial hash. The result of this hash will be decided by the value of the major @@ -99,6 +99,8 @@ def create_real_bucket_hash(self, initial_bucket_hash: str) -> str: initial_bucket_hash : str The initial bucket hash provided by Discord in the HTTP headers for a given response. + authentication_hash : str + The token hash. Returns ------- @@ -106,7 +108,7 @@ def create_real_bucket_hash(self, initial_bucket_hash: str) -> str: The input hash amalgamated with a hash code produced by the major parameters in this compiled route instance. 
""" - return initial_bucket_hash + HASH_SEPARATOR + self.major_param_hash + return f"{initial_bucket_hash}{HASH_SEPARATOR}{authentication_hash}{HASH_SEPARATOR}{self.major_param_hash}" def __str__(self) -> str: return f"{self.method} {self.compiled_path}" @@ -135,12 +137,21 @@ class Route: path_template: str = attr.field() """The template string used for the path.""" - major_params: typing.Optional[typing.FrozenSet[str]] = attr.field(hash=False, eq=False) + major_params: typing.Optional[typing.FrozenSet[str]] = attr.field(hash=False, eq=False, repr=False) """The optional major parameter name combination for this endpoint.""" - def __init__(self, method: str, path_template: str) -> None: + has_ratelimits: bool = attr.field(hash=False, eq=False, repr=False) + """Whether this route is affected by ratelimits. + + This should be left as `True` (the default) for most routes. This + only covers specific routes where no ratelimits exist, so we can + be a bit more efficient with them. + """ + + def __init__(self, method: str, path_template: str, *, has_ratelimits: bool = True) -> None: self.method = method self.path_template = path_template + self.has_ratelimits = has_ratelimits self.major_params = None match = PARAM_REGEX.findall(path_template) @@ -528,16 +539,17 @@ def compile_to_file( # For these endpoints "webhook" is the application ID. GET_INTERACTION_RESPONSE: typing.Final[Route] = Route(GET, "/webhooks/{webhook}/{token}/messages/@original") PATCH_INTERACTION_RESPONSE: typing.Final[Route] = Route(PATCH, "/webhooks/{webhook}/{token}/messages/@original") -POST_INTERACTION_RESPONSE: typing.Final[Route] = Route(POST, "/interactions/{interaction}/{token}/callback") +POST_INTERACTION_RESPONSE: typing.Final[Route] = Route( + POST, "/interactions/{interaction}/{token}/callback", has_ratelimits=False +) DELETE_INTERACTION_RESPONSE: typing.Final[Route] = Route(DELETE, "/webhooks/{webhook}/{token}/messages/@original") # OAuth2 API GET_MY_APPLICATION: typing.Final[Route] = Route(GET, "/oauth2/applications/@me") GET_MY_AUTHORIZATION: typing.Final[Route] = Route(GET, "/oauth2/@me") -POST_AUTHORIZE: typing.Final[Route] = Route(POST, "/oauth2/authorize") -POST_TOKEN: typing.Final[Route] = Route(POST, "/oauth2/token") -POST_TOKEN_REVOKE: typing.Final[Route] = Route(POST, "/oauth2/token/revoke") +POST_TOKEN: typing.Final[Route] = Route(POST, "/oauth2/token", has_ratelimits=False) +POST_TOKEN_REVOKE: typing.Final[Route] = Route(POST, "/oauth2/token/revoke", has_ratelimits=False) # Gateway GET_GATEWAY: typing.Final[Route] = Route(GET, "/gateway") diff --git a/tests/hikari/hikari_test_helpers.py b/tests/hikari/hikari_test_helpers.py index 81a3fe2080..db5498af30 100644 --- a/tests/hikari/hikari_test_helpers.py +++ b/tests/hikari/hikari_test_helpers.py @@ -35,7 +35,7 @@ # How long to wait for before considering a test to be jammed in an unbreakable # condition, and thus acceptable to terminate the test and fail it. 
-REASONABLE_TIMEOUT_AFTER = 10 +REASONABLE_TIMEOUT_AFTER = 5 _T = typing.TypeVar("_T") diff --git a/tests/hikari/impl/test_buckets.py b/tests/hikari/impl/test_buckets.py index 23fab9bec4..d80824096a 100644 --- a/tests/hikari/impl/test_buckets.py +++ b/tests/hikari/impl/test_buckets.py @@ -31,7 +31,6 @@ from hikari.impl import rate_limits from hikari.internal import routes from hikari.internal import time as hikari_date -from tests.hikari import hikari_test_helpers class TestRESTBucket: @@ -47,7 +46,7 @@ def compiled_route(self, template): async def test_async_context_manager(self, compiled_route): with mock.patch.object(buckets.RESTBucket, "acquire", new=mock.AsyncMock()) as acquire: with mock.patch.object(buckets.RESTBucket, "release") as release: - async with buckets.RESTBucket("spaghetti", compiled_route, float("inf")): + async with buckets.RESTBucket("spaghetti", compiled_route, object(), float("inf")): acquire.assert_awaited_once_with() release.assert_not_called() @@ -55,11 +54,11 @@ async def test_async_context_manager(self, compiled_route): @pytest.mark.parametrize("name", ["spaghetti", buckets.UNKNOWN_HASH]) def test_is_unknown(self, name, compiled_route): - with buckets.RESTBucket(name, compiled_route, float("inf")) as rl: + with buckets.RESTBucket(name, compiled_route, object(), float("inf")) as rl: assert rl.is_unknown is (name == buckets.UNKNOWN_HASH) def test_release(self, compiled_route): - with buckets.RESTBucket(__name__, compiled_route, float("inf")) as rl: + with buckets.RESTBucket(__name__, compiled_route, object(), float("inf")) as rl: rl._lock = mock.Mock() rl.release() @@ -67,7 +66,7 @@ def test_release(self, compiled_route): rl._lock.release.assert_called_once_with() def test_update_rate_limit(self, compiled_route): - with buckets.RESTBucket(__name__, compiled_route, float("inf")) as rl: + with buckets.RESTBucket(__name__, compiled_route, object(), float("inf")) as rl: rl.remaining = 1 rl.limit = 2 rl.reset_at = 3 @@ -83,7 +82,7 @@ def test_update_rate_limit(self, compiled_route): @pytest.mark.asyncio() async def test_acquire_when_unknown_bucket(self, compiled_route): - with buckets.RESTBucket(buckets.UNKNOWN_HASH, compiled_route, float("inf")) as rl: + with buckets.RESTBucket(buckets.UNKNOWN_HASH, compiled_route, object(), float("inf")) as rl: rl._lock = mock.AsyncMock() with mock.patch.object(rate_limits.WindowedBurstRateLimiter, "acquire") as super_acquire: assert await rl.acquire() is None @@ -94,7 +93,7 @@ async def test_acquire_when_unknown_bucket(self, compiled_route): @pytest.mark.asyncio() async def test_acquire_when_too_long_ratelimit(self, compiled_route): stack = contextlib.ExitStack() - rl = stack.enter_context(buckets.RESTBucket("spaghetti", compiled_route, 60)) + rl = stack.enter_context(buckets.RESTBucket("spaghetti", compiled_route, object(), 60)) rl._lock = mock.Mock(acquire=mock.AsyncMock()) rl.reset_at = time.perf_counter() + 999999999999999999999999999 stack.enter_context(mock.patch.object(buckets.RESTBucket, "is_rate_limited", return_value=True)) @@ -108,290 +107,311 @@ async def test_acquire_when_too_long_ratelimit(self, compiled_route): @pytest.mark.asyncio() async def test_acquire(self, compiled_route): - with buckets.RESTBucket("spaghetti", compiled_route, float("inf")) as rl: + global_ratelimit = mock.AsyncMock() + + with buckets.RESTBucket("spaghetti", compiled_route, global_ratelimit, float("inf")) as rl: rl._lock = mock.AsyncMock() with mock.patch.object(rate_limits.WindowedBurstRateLimiter, "acquire") as super_acquire: await 
rl.acquire() - super_acquire.assert_awaited_once_with() - rl._lock.acquire.assert_awaited_once_with() + super_acquire.assert_awaited_once_with() + rl._lock.acquire.assert_awaited_once_with() + global_ratelimit.acquire.assert_awaited_once_with() def test_resolve_when_not_unknown(self, compiled_route): - with buckets.RESTBucket("spaghetti", compiled_route, float("inf")) as rl: + with buckets.RESTBucket("spaghetti", compiled_route, object(), float("inf")) as rl: with pytest.raises(RuntimeError, match=r"Cannot resolve known bucket"): rl.resolve("test") assert rl.name == "spaghetti" def test_resolve(self, compiled_route): - with buckets.RESTBucket(buckets.UNKNOWN_HASH, compiled_route, float("inf")) as rl: + with buckets.RESTBucket(buckets.UNKNOWN_HASH, compiled_route, object(), float("inf")) as rl: rl.resolve("test") assert rl.name == "test" class TestRESTBucketManager: - @pytest.mark.asyncio() - async def test_close_closes_all_buckets(self): - class MockBucket: - def __init__(self): - self.close = mock.Mock() + @pytest.fixture() + def bucket_manager(self): + manager = buckets.RESTBucketManager(max_rate_limit=float("inf")) + manager._closed_event = mock.Mock(wait=mock.AsyncMock(), is_set=mock.Mock(return_value=False)) + + return manager - buckets_array = [MockBucket() for _ in range(30)] + @pytest.mark.asyncio() + async def test_close_closes_all_buckets(self, bucket_manager): + buckets_array = [mock.Mock() for _ in range(30)] + bucket_manager._real_hashes_to_buckets = {f"blah{i}": b for i, b in enumerate(buckets_array)} - mgr = buckets.RESTBucketManager(max_rate_limit=float("inf")) - mgr.real_hashes_to_buckets = {f"blah{i}": bucket for i, bucket in enumerate(buckets_array)} + bucket_manager.close() - mgr.close() + assert bucket_manager._real_hashes_to_buckets == {} - for i, bucket in enumerate(buckets_array): - bucket.close.assert_called_once(), i + for i, b in enumerate(buckets_array): + b.close.assert_called_once(), i @pytest.mark.asyncio() - async def test_close_sets_closed_event(self): - mgr = buckets.RESTBucketManager(max_rate_limit=float("inf")) - assert not mgr.closed_event.is_set() - mgr.close() - assert mgr.closed_event.is_set() + async def test_close_sets_closed_event(self, bucket_manager): + closed_event = mock.Mock() + bucket_manager._closed_event = closed_event - @pytest.mark.asyncio() - async def test_start(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - assert mgr.gc_task is None - mgr.start() - assert mgr.gc_task is not None + bucket_manager.close() + + assert bucket_manager._closed_event is None + closed_event.set.assert_called_once() @pytest.mark.asyncio() - async def test_start_when_already_started(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - mock_task = mock.Mock() - mgr.gc_task = mock_task - mgr.start() - assert mgr.gc_task is mock_task + async def test_start(self, bucket_manager): + bucket_manager._closed_event = None + + assert bucket_manager._gc_task is None + bucket_manager.start() + assert bucket_manager._gc_task is not None + + # cancel created task + bucket_manager._gc_task.cancel() + try: + await bucket_manager._gc_task + except asyncio.CancelledError: + pass @pytest.mark.asyncio() - async def test_exit_closes(self): - with mock.patch.object(buckets.RESTBucketManager, "close") as close: - with mock.patch.object(buckets.RESTBucketManager, "gc") as gc: - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - mgr.start(0.01, 32) - gc.assert_called_once_with(0.01, 32) - close.assert_called() 
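The RESTBucketManager tests that follow exercise the new explicit lifecycle (`start`/`close`, with `ComponentStateConflictError` on a double start) and the per-authentication bucket lookup via `acquire_bucket`. A rough usage sketch inferred from those tests; the route and token are illustrative, and treating the return value of `acquire_bucket` as an async context manager is an assumption based on the `AsyncContextManagerMock` used in the REST client fixture further down:

import asyncio

from hikari.impl import buckets
from hikari.internal import routes

async def main() -> None:
    # start() needs a running event loop (it spawns the GC task); close()
    # tears down every bucket the manager is still tracking.
    manager = buckets.RESTBucketManager(max_rate_limit=300.0)
    manager.start()
    try:
        route = routes.Route("GET", "/channels/{channel}").compile(channel=1234)
        # Buckets are keyed per route *and* per authentication, so different
        # tokens no longer share rate limit state.
        async with manager.acquire_bucket(route, "Bot my-token"):
            ...  # perform the HTTP request while holding the bucket
    finally:
        manager.close()

asyncio.run(main())
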
+ async def test_start_when_already_started(self, bucket_manager): + bucket_manager._closed_event = object() + + with pytest.raises(errors.ComponentStateConflictError): + bucket_manager.start() @pytest.mark.asyncio() - async def test_gc_polls_until_closed_event_set(self, event_loop): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - # Start the gc and initial assertions - task = event_loop.create_task(mgr.gc(0.001, float("inf"))) - assert not task.done() + async def test_gc_polls_until_closed_event_set(self, event_loop, bucket_manager): + bucket_manager._closed_event = asyncio.Event() + + # Start the gc and initial assertions + task = event_loop.create_task(bucket_manager._gc(0.001, float("inf"))) + assert not task.done() - # [First poll] event not set => shouldn't complete the task - await asyncio.sleep(0.001) - assert not task.done() + # [First poll] event not set => shouldn't complete the task + await asyncio.sleep(0.001) + assert not task.done() - # [Second poll] event not set during poll => shouldn't complete the task - await asyncio.sleep(0.001) - mgr.closed_event.set() - assert not task.done() + # [Second poll] event not set during poll => shouldn't complete the task + await asyncio.sleep(0.001) + bucket_manager._closed_event.set() + assert not task.done() - # [Third poll] event set => should complete the task - await asyncio.sleep(0.001) - assert task.done() + # [Third poll] event set => should complete the task + await asyncio.sleep(0.001) + assert task.done() @pytest.mark.asyncio() - async def test_gc_calls_do_pass(self): + async def test_gc_makes_gc_pass(self, bucket_manager): class ExitError(Exception): ... - with hikari_test_helpers.mock_class_namespace(buckets.RESTBucketManager, slots_=False)( - max_rate_limit=float("inf") - ) as mgr: - mgr.do_gc_pass = mock.Mock(side_effect=ExitError) - with pytest.raises(ExitError): - await mgr.gc(0.001, 33) + bucket_manager._closed_event.wait = mock.Mock() + + with mock.patch.object(buckets.RESTBucketManager, "_purge_stale_buckets") as purge_stale_buckets: + with mock.patch.object(asyncio, "wait_for", side_effect=[asyncio.TimeoutError, ExitError]): + with pytest.raises(ExitError): + await bucket_manager._gc(0.001, 33) - mgr.do_gc_pass.assert_called_with(33) + purge_stale_buckets.assert_called_with(33) @pytest.mark.asyncio() - async def test_do_gc_pass_any_buckets_that_are_empty_but_still_rate_limited_are_kept_alive(self): - with hikari_test_helpers.mock_class_namespace(buckets.RESTBucketManager)(max_rate_limit=float("inf")) as mgr: - bucket = mock.Mock() - bucket.is_empty = True - bucket.is_unknown = False - bucket.reset_at = time.perf_counter() + 999999999999999999999999999 + async def test_purge_stale_buckets_any_buckets_that_are_empty_but_still_rate_limited_are_kept_alive( + self, bucket_manager + ): + bucket = mock.Mock() + bucket.is_empty = True + bucket.is_unknown = False + bucket.reset_at = time.perf_counter() + 999999999999999999999999999 - mgr.real_hashes_to_buckets["foobar"] = bucket + bucket_manager._real_hashes_to_buckets["foobar"] = bucket - mgr.do_gc_pass(0) + bucket_manager._purge_stale_buckets(0) - assert "foobar" in mgr.real_hashes_to_buckets - bucket.close.assert_not_called() + assert "foobar" in bucket_manager._real_hashes_to_buckets + bucket.close.assert_not_called() @pytest.mark.asyncio() - async def test_do_gc_pass_any_buckets_that_are_empty_but_not_rate_limited_and_not_expired_are_kept_alive(self): - with 
hikari_test_helpers.mock_class_namespace(buckets.RESTBucketManager)(max_rate_limit=float("inf")) as mgr: - bucket = mock.Mock() - bucket.is_empty = True - bucket.is_unknown = False - bucket.reset_at = time.perf_counter() + async def test_purge_stale_buckets_any_buckets_that_are_empty_but_not_rate_limited_and_not_expired_are_kept_alive( + self, bucket_manager + ): + bucket = mock.Mock() + bucket.is_empty = True + bucket.is_unknown = False + bucket.reset_at = time.perf_counter() - mgr.real_hashes_to_buckets["foobar"] = bucket + bucket_manager._real_hashes_to_buckets["foobar"] = bucket - mgr.do_gc_pass(10) + bucket_manager._purge_stale_buckets(10) - assert "foobar" in mgr.real_hashes_to_buckets - bucket.close.assert_not_called() + assert "foobar" in bucket_manager._real_hashes_to_buckets + bucket.close.assert_not_called() @pytest.mark.asyncio() - async def test_do_gc_pass_any_buckets_that_are_empty_but_not_rate_limited_and_expired_are_closed(self): - with hikari_test_helpers.mock_class_namespace(buckets.RESTBucketManager)(max_rate_limit=float("inf")) as mgr: - bucket = mock.Mock() - bucket.is_empty = True - bucket.is_unknown = False - bucket.reset_at = time.perf_counter() - 999999999999999999999999999 + async def test_purge_stale_buckets_any_buckets_that_are_empty_but_not_rate_limited_and_expired_are_closed( + self, bucket_manager + ): + bucket = mock.Mock() + bucket.is_empty = True + bucket.is_unknown = False + bucket.reset_at = time.perf_counter() - 999999999999999999999999999 - mgr.real_hashes_to_buckets["foobar"] = bucket + bucket_manager._real_hashes_to_buckets["foobar"] = bucket - mgr.do_gc_pass(0) + bucket_manager._purge_stale_buckets(0) - assert "foobar" not in mgr.real_hashes_to_buckets - bucket.close.assert_called_once() + assert "foobar" not in bucket_manager._real_hashes_to_buckets + bucket.close.assert_called_once() @pytest.mark.asyncio() - async def test_do_gc_pass_any_buckets_that_are_not_empty_are_kept_alive(self): - with hikari_test_helpers.mock_class_namespace(buckets.RESTBucketManager)(max_rate_limit=float("inf")) as mgr: - bucket = mock.Mock() - bucket.is_empty = False - bucket.is_unknown = True - bucket.reset_at = time.perf_counter() + async def test_purge_stale_buckets_any_buckets_that_are_not_empty_are_kept_alive(self, bucket_manager): + bucket = mock.Mock() + bucket.is_empty = False + bucket.is_unknown = True + bucket.reset_at = time.perf_counter() - mgr.real_hashes_to_buckets["foobar"] = bucket + bucket_manager._real_hashes_to_buckets["foobar"] = bucket - mgr.do_gc_pass(0) + bucket_manager._purge_stale_buckets(0) - assert "foobar" in mgr.real_hashes_to_buckets - bucket.close.assert_not_called() + assert "foobar" in bucket_manager._real_hashes_to_buckets + bucket.close.assert_not_called() @pytest.mark.asyncio() - async def test_acquire_route_when_not_in_routes_to_real_hashes_makes_new_bucket_using_initial_hash(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() + async def test_acquire_route_when_not_in_routes_to_real_hashes_makes_new_bucket_using_initial_hash( + self, bucket_manager + ): + route = mock.Mock() - with mock.patch.object(buckets, "_create_unknown_hash", return_value="UNKNOWN;bobs") as create_unknown_hash: - mgr.acquire(route) + with mock.patch.object(buckets, "_create_authentication_hash", return_value="auth_hash"): + with mock.patch.object( + buckets, "_create_unknown_hash", return_value="UNKNOWN;auth_hash;bobs" + ) as create_unknown_hash: + bucket_manager.acquire_bucket(route, "auth") - assert 
"UNKNOWN;bobs" in mgr.real_hashes_to_buckets - assert isinstance(mgr.real_hashes_to_buckets["UNKNOWN;bobs"], buckets.RESTBucket) - create_unknown_hash.assert_called_once_with(route) + assert "UNKNOWN;auth_hash;bobs" in bucket_manager._real_hashes_to_buckets + assert isinstance(bucket_manager._real_hashes_to_buckets["UNKNOWN;auth_hash;bobs"], buckets.RESTBucket) + create_unknown_hash.assert_called_once_with(route, "auth_hash") @pytest.mark.asyncio() - async def test_acquire_route_when_not_in_routes_to_real_hashes_doesnt_cache_route(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() - route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash: initial_hash + ";bobs") + async def test_acquire_route_when_not_in_routes_to_real_hashes_doesnt_cache_route(self, bucket_manager): + route = mock.Mock() + route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") - mgr.acquire(route) + bucket_manager.acquire_bucket(route, "auth") - assert mgr.routes_to_hashes.get(route.route) is None + assert bucket_manager._routes_to_hashes.get(route.route) is None @pytest.mark.asyncio() - async def test_acquire_route_when_route_cached_already_obtains_hash_from_route_and_bucket_from_hash(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() - route.create_real_bucket_hash = mock.Mock(return_value="eat pant;1234") - bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) - mgr.routes_to_hashes[route.route] = "eat pant" - mgr.real_hashes_to_buckets["eat pant;1234"] = bucket + async def test_acquire_route_when_route_cached_already_obtains_hash_from_route_and_bucket_from_hash( + self, bucket_manager + ): + route = mock.Mock() + route.create_real_bucket_hash = mock.Mock(return_value="eat pant;1234") + bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) + bucket_manager._routes_to_hashes[route.route] = "eat pant" + bucket_manager._real_hashes_to_buckets["eat pant;1234"] = bucket - assert mgr.acquire(route) is bucket + assert bucket_manager.acquire_bucket(route, "auth") is bucket @pytest.mark.asyncio() - async def test_acquire_route_returns_context_manager(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() + async def test_acquire_route_returns_context_manager(self, bucket_manager): + route = mock.Mock() - bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) - with mock.patch.object(buckets, "RESTBucket", return_value=bucket): - route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash: initial_hash + ";bobs") + bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) + with mock.patch.object(buckets, "RESTBucket", return_value=bucket): + route.create_real_bucket_hash = mock.Mock( + wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs" + ) - assert mgr.acquire(route) is bucket + assert bucket_manager.acquire_bucket(route, "auth") is bucket @pytest.mark.asyncio() - async def test_acquire_unknown_route_returns_context_manager_for_new_bucket(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() - route.create_real_bucket_hash = mock.Mock(return_value="eat pant;bobs") - bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) - mgr.routes_to_hashes[route.route] = "eat pant" - mgr.real_hashes_to_buckets["eat pant;bobs"] = bucket + async def 
test_acquire_unknown_route_returns_context_manager_for_new_bucket(self, bucket_manager): + route = mock.Mock() + route.create_real_bucket_hash = mock.Mock(return_value="eat pant;bobs") + bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) + bucket_manager._routes_to_hashes[route.route] = "eat pant" + bucket_manager._real_hashes_to_buckets["eat pant;bobs"] = bucket - assert mgr.acquire(route) is bucket + assert bucket_manager.acquire_bucket(route, "auth") is bucket @pytest.mark.asyncio() - async def test_update_rate_limits_if_wrong_bucket_hash_reroutes_route(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() - route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash: initial_hash + ";bobs") - mgr.routes_to_hashes[route.route] = "123" + async def test_update_rate_limits_if_wrong_bucket_hash_reroutes_route(self, bucket_manager): + route = mock.Mock() + route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") + bucket_manager._routes_to_hashes[route.route] = "123" + with mock.patch.object(buckets, "_create_authentication_hash", return_value="auth_hash"): with mock.patch.object(hikari_date, "monotonic", return_value=27): with mock.patch.object(buckets, "RESTBucket") as bucket: - mgr.update_rate_limits(route, "blep", 22, 23, 3.56) + bucket_manager.update_rate_limits(route, "auth", "blep", 22, 23, 3.56) - assert mgr.routes_to_hashes[route.route] == "blep" - assert mgr.real_hashes_to_buckets["blep;bobs"] is bucket.return_value - bucket.return_value.update_rate_limit.assert_called_once_with(22, 23, 27 + 3.56) + assert bucket_manager._routes_to_hashes[route.route] == "blep" + assert bucket_manager._real_hashes_to_buckets["blep;auth_hash;bobs"] is bucket.return_value + bucket.return_value.update_rate_limit.assert_called_once_with(22, 23, 27 + 3.56) @pytest.mark.asyncio() - async def test_update_rate_limits_if_unknown_bucket_hash_reroutes_route(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() - route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash: initial_hash + ";bobs") - mgr.routes_to_hashes[route.route] = "123" - bucket = mock.Mock() - mgr.real_hashes_to_buckets["UNKNOWN;bobs"] = bucket - - with mock.patch.object(buckets, "_create_unknown_hash", return_value="UNKNOWN;bobs") as create_unknown_hash: - with mock.patch.object(hikari_date, "monotonic", return_value=27): - mgr.update_rate_limits(route, "blep", 22, 23, 3.56) - - assert mgr.routes_to_hashes[route.route] == "blep" - assert mgr.real_hashes_to_buckets["blep;bobs"] is bucket - bucket.resolve.assert_called_once_with("blep;bobs") - bucket.update_rate_limit.assert_called_once_with(22, 23, 27 + 3.56) - create_unknown_hash.assert_called_once_with(route) + async def test_update_rate_limits_if_unknown_bucket_hash_reroutes_route(self, bucket_manager): + route = mock.Mock() + route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") + bucket_manager._routes_to_hashes[route.route] = "123" + bucket = mock.Mock() + bucket_manager._real_hashes_to_buckets["UNKNOWN;auth_hash;bobs"] = bucket - @pytest.mark.asyncio() - async def test_update_rate_limits_if_right_bucket_hash_does_nothing_to_hash(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() - route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash: initial_hash + ";bobs") - 
mgr.routes_to_hashes[route.route] = "123" - bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) - mgr.real_hashes_to_buckets["123;bobs"] = bucket + stack = contextlib.ExitStack() + create_authentication_hash = stack.enter_context( + mock.patch.object(buckets, "_create_authentication_hash", return_value="auth_hash") + ) + create_unknown_hash = stack.enter_context( + mock.patch.object(buckets, "_create_unknown_hash", return_value="UNKNOWN;auth_hash;bobs") + ) + stack.enter_context(mock.patch.object(hikari_date, "monotonic", return_value=27)) - with mock.patch.object(hikari_date, "monotonic", return_value=27): - mgr.update_rate_limits(route, "123", 22, 23, 7.65) + with stack: + bucket_manager.update_rate_limits(route, "auth", "blep", 22, 23, 3.56) - assert mgr.routes_to_hashes[route.route] == "123" - assert mgr.real_hashes_to_buckets["123;bobs"] is bucket - bucket.update_rate_limit.assert_called_once_with(22, 23, 27 + 7.65) + assert bucket_manager._routes_to_hashes[route.route] == "blep" + assert bucket_manager._real_hashes_to_buckets["blep;auth_hash;bobs"] is bucket + bucket.resolve.assert_called_once_with("blep;auth_hash;bobs") + bucket.update_rate_limit.assert_called_once_with(22, 23, 27 + 3.56) + create_unknown_hash.assert_called_once_with(route, "auth_hash") + create_authentication_hash.assert_called_once_with("auth") @pytest.mark.asyncio() - async def test_update_rate_limits_updates_params(self): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - route = mock.Mock() - route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash: initial_hash + ";bobs") - mgr.routes_to_hashes[route.route] = "123" - bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) - mgr.real_hashes_to_buckets["123;bobs"] = bucket + async def test_update_rate_limits_if_right_bucket_hash_does_nothing_to_hash(self, bucket_manager): + route = mock.Mock() + route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") + bucket_manager._routes_to_hashes[route.route] = "123" + bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) + bucket_manager._real_hashes_to_buckets["123;auth_hash;bobs"] = bucket + + with mock.patch.object(buckets, "_create_authentication_hash", return_value="auth_hash"): + with mock.patch.object(hikari_date, "monotonic", return_value=27): + bucket_manager.update_rate_limits(route, "auth", "123", 22, 23, 7.65) + assert bucket_manager._routes_to_hashes[route.route] == "123" + assert bucket_manager._real_hashes_to_buckets["123;auth_hash;bobs"] is bucket + bucket.update_rate_limit.assert_called_once_with(22, 23, 27 + 7.65) + + @pytest.mark.asyncio() + async def test_update_rate_limits_updates_params(self, bucket_manager): + route = mock.Mock() + route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") + bucket_manager._routes_to_hashes[route.route] = "123" + bucket = mock.Mock(reset_at=time.perf_counter() + 999999999999999999999999999) + bucket_manager._real_hashes_to_buckets["123;auth_hash;bobs"] = bucket + + with mock.patch.object(buckets, "_create_authentication_hash", return_value="auth_hash"): with mock.patch.object(hikari_date, "monotonic", return_value=27): - mgr.update_rate_limits(route, "123", 22, 23, 5.32) + bucket_manager.update_rate_limits(route, "auth", "123", 22, 23, 5.32) bucket.update_rate_limit.assert_called_once_with(22, 23, 27 + 5.32) - @pytest.mark.parametrize(("gc_task", 
"is_started"), [(None, False), (mock.Mock(spec_set=asyncio.Task), True)]) - def test_is_started(self, gc_task, is_started): - with buckets.RESTBucketManager(max_rate_limit=float("inf")) as mgr: - mgr.gc_task = gc_task - assert mgr.is_started is is_started + @pytest.mark.parametrize(("closed_event", "is_alive"), [(None, False), ("some", True)]) + def test_is_alive(self, bucket_manager, closed_event, is_alive): + bucket_manager._closed_event = closed_event + assert bucket_manager.is_alive is is_alive diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index d992e626eb..c0fc9c5044 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -51,7 +51,6 @@ from hikari import users from hikari import webhooks from hikari.api import rest as rest_api -from hikari.impl import buckets from hikari.impl import config from hikari.impl import entity_factory from hikari.impl import rate_limits @@ -294,82 +293,6 @@ def test_invalidate_when_token_is_stored_token(self): assert strategy._token is None -################### -# _LiveAttributes # -################### - - -class Test_LiveAttributes: - def test_build(self): - stack = contextlib.ExitStack() - create_tcp_connector = stack.enter_context(mock.patch.object(net, "create_tcp_connector")) - create_client_session = stack.enter_context(mock.patch.object(net, "create_client_session")) - bucket_manager = stack.enter_context(mock.patch.object(buckets, "RESTBucketManager")) - manual_rate_limiter = stack.enter_context(mock.patch.object(rate_limits, "ManualRateLimiter")) - stack.enter_context(mock.patch.object(asyncio, "get_running_loop")) - mock_settings = object() - mock_proxy_settings = mock.Mock() - - with stack: - attributes = rest._LiveAttributes.build(123.321, mock_settings, mock_proxy_settings) - - assert isinstance(attributes, rest._LiveAttributes) - assert attributes.is_closing is False - assert attributes.buckets is bucket_manager.return_value - assert attributes.client_session is create_client_session.return_value - assert isinstance(attributes.closed_event, asyncio.Event) - assert attributes.global_rate_limit is manual_rate_limiter.return_value - assert attributes.tcp_connector is create_tcp_connector.return_value - - bucket_manager.assert_called_once_with(123.321) - bucket_manager.return_value.start.assert_called_once_with() - create_tcp_connector.assert_called_once_with(mock_settings) - create_client_session.assert_called_once_with( - connector=create_tcp_connector.return_value, - connector_owner=False, - http_settings=mock_settings, - raise_for_status=False, - trust_env=mock_proxy_settings.trust_env, - ) - manual_rate_limiter.assert_called_once_with() - - def test_build_when_no_running_loop(self): - with pytest.raises(RuntimeError): - rest._LiveAttributes.build(123.321, object(), object()) - - @pytest.mark.asyncio() - async def test_close(self): - attributes = rest._LiveAttributes( - buckets=mock.Mock(), - client_session=mock.AsyncMock(), - closed_event=mock.Mock(), - global_rate_limit=mock.Mock(), - tcp_connector=mock.AsyncMock(), - ) - - await attributes.close() - - assert attributes.is_closing is True - attributes.buckets.close.assert_called_once_with() - attributes.client_session.close.assert_awaited_once_with() - attributes.closed_event.set.assert_called_once_with() - attributes.global_rate_limit.close.assert_called_once_with() - attributes.tcp_connector.close.assert_awaited_once_with() - - def test_still_alive_when_alive(self): - attributes = 
hikari_test_helpers.mock_class_namespace(rest._LiveAttributes, init_=False)() - attributes.is_closing = False - - assert attributes.still_alive() is attributes - - def test_still_alive_when_closing(self): - attributes = hikari_test_helpers.mock_class_namespace(rest._LiveAttributes, init_=False)() - attributes.is_closing = True - - with pytest.raises(errors.ComponentStateConflictError): - attributes.still_alive() - - ########### # RESTApp # ########### @@ -403,6 +326,8 @@ def test_proxy_settings(self, rest_app): assert rest_app.proxy_settings is mock_proxy_settings def test_acquire(self, rest_app): + rest_app._client_session = object() + rest_app._bucket_manager = object() stack = contextlib.ExitStack() mock_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl")) mock_client = stack.enter_context(mock.patch.object(rest, "RESTClientImpl")) @@ -415,12 +340,15 @@ def test_acquire(self, rest_app): entity_factory=mock_entity_factory.return_value, executor=rest_app._executor, http_settings=rest_app._http_settings, - max_rate_limit=float("inf"), max_retries=0, proxy_settings=rest_app._proxy_settings, token="token", token_type="Type", rest_url=rest_app._url, + bucket_manager=rest_app._bucket_manager, + bucket_manager_owner=False, + client_session=rest_app._client_session, + client_session_owner=False, ) rest_provider = mock_entity_factory.call_args_list[0][0][0] @@ -429,6 +357,8 @@ def test_acquire(self, rest_app): assert rest_provider.executor is rest_app._executor def test_acquire_defaults_to_bearer_for_a_string_token(self, rest_app): + rest_app._client_session = object() + rest_app._bucket_manager = object() stack = contextlib.ExitStack() mock_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl")) mock_client = stack.enter_context(mock.patch.object(rest, "RESTClientImpl")) @@ -441,12 +371,15 @@ def test_acquire_defaults_to_bearer_for_a_string_token(self, rest_app): entity_factory=mock_entity_factory.return_value, executor=rest_app._executor, http_settings=rest_app._http_settings, - max_rate_limit=float("inf"), max_retries=0, proxy_settings=rest_app._proxy_settings, token="token", token_type=applications.TokenType.BEARER, rest_url=rest_app._url, + bucket_manager=rest_app._bucket_manager, + bucket_manager_owner=False, + client_session=rest_app._client_session, + client_session_owner=False, ) rest_provider = mock_entity_factory.call_args_list[0][0][0] @@ -465,24 +398,13 @@ def rest_client_class(): return hikari_test_helpers.mock_class_namespace(rest.RESTClientImpl, slots_=False) -@pytest.fixture() -def live_attributes(): - attributes = mock.Mock( - buckets=mock.Mock(acquire=mock.Mock(return_value=hikari_test_helpers.AsyncContextManagerMock())), - global_rate_limit=mock.Mock(acquire=mock.AsyncMock()), - close=mock.AsyncMock(), - ) - attributes.still_alive.return_value = attributes - return attributes - - @pytest.fixture() def mock_cache(): return mock.Mock() @pytest.fixture() -def rest_client(rest_client_class, live_attributes, mock_cache): +def rest_client(rest_client_class, mock_cache): obj = rest_client_class( cache=mock_cache, http_settings=mock.Mock(spec=config.HTTPSettings), @@ -494,8 +416,13 @@ def rest_client(rest_client_class, live_attributes, mock_cache): rest_url="https://some.where/api/v3", executor=object(), entity_factory=mock.Mock(), + bucket_manager=mock.Mock( + acquire_bucket=mock.Mock(return_value=hikari_test_helpers.AsyncContextManagerMock()), + acquire_authentication=mock.AsyncMock(), + ), + 
client_session=mock.Mock(request=mock.AsyncMock()), ) - obj._live_attributes = live_attributes + obj._close_event = object() return obj @@ -546,6 +473,48 @@ def __init__(self, id=0): self.id = snowflakes.Snowflake(id) +class TestStringifyHttpMessage: + def test_when_body_is_None(self, rest_client): + headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} + expected_return = " HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**" + assert rest._stringify_http_message(headers, None) == expected_return + + @pytest.mark.parametrize(("body", "expected"), [(bytes("hello :)", "ascii"), "hello :)"), (123, "123")]) + def test_when_body_is_not_None(self, rest_client, body, expected): + headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} + expected_return = ( + f" HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**\n\n {expected}" + ) + assert rest._stringify_http_message(headers, body) == expected_return + + +class TestTransformEmojiToUrlFormat: + @pytest.mark.parametrize( + ("emoji", "expected_return"), + [ + (emojis.CustomEmoji(id=123, name="rooYay", is_animated=False), "rooYay:123"), + ("\N{OK HAND SIGN}", "\N{OK HAND SIGN}"), + (emojis.UnicodeEmoji("\N{OK HAND SIGN}"), "\N{OK HAND SIGN}"), + ], + ) + def test_expected(self, rest_client, emoji, expected_return): + assert rest._transform_emoji_to_url_format(emoji, undefined.UNDEFINED) == expected_return + + def test_with_id(self, rest_client): + assert rest._transform_emoji_to_url_format("rooYay", 123) == "rooYay:123" + + @pytest.mark.parametrize( + "emoji", + [ + emojis.CustomEmoji(id=123, name="rooYay", is_animated=False), + emojis.UnicodeEmoji("\N{OK HAND SIGN}"), + ], + ) + def test_when_id_passed_with_emoji_object(self, rest_client, emoji): + with pytest.raises(ValueError, match="emoji_id shouldn't be passed when an Emoji object is passed for emoji"): + rest._transform_emoji_to_url_format(emoji, 123) + + class TestRESTClientImpl: def test__init__when_max_retries_over_5(self): with pytest.raises(ValueError, match="'max_retries' must be below or equal to 5"): @@ -678,7 +647,7 @@ def test___exit__(self, rest_client): @pytest.mark.parametrize(("attributes", "expected_result"), [(None, False), (object(), True)]) def test_is_alive_property(self, rest_client, attributes, expected_result): - rest_client._live_attributes = attributes + rest_client._close_event = attributes assert rest_client.is_alive is expected_result @@ -700,87 +669,78 @@ def test_token_type_property(self, rest_client): rest_client._token_type = mock_type assert rest_client.token_type is mock_type + @pytest.mark.parametrize("client_session_owner", [True, False]) + @pytest.mark.parametrize("bucket_manager_owner", [True, False]) @pytest.mark.asyncio() - async def test_close(self, rest_client): - rest_client._live_attributes = mock_live_attributes = mock.AsyncMock() + async def test_close(self, rest_client, client_session_owner, bucket_manager_owner): + rest_client._close_event = mock_close_event = mock.Mock() + rest_client._client_session.close = client_close = mock.AsyncMock() + rest_client._client_session_owner = client_session_owner + rest_client._bucket_manager_owner = bucket_manager_owner await rest_client.close() - mock_live_attributes.close.assert_awaited_once_with() - assert rest_client._live_attributes is None - - def test_start(self, rest_client): - rest_client._live_attributes = None + 
mock_close_event.set.assert_called_once_with() + assert rest_client._close_event is None - with mock.patch.object(rest._LiveAttributes, "build") as build: - rest_client.start() + if client_session_owner: + client_close.assert_awaited_once_with() + assert rest_client._client_session is None + else: + client_close.assert_not_called() + assert rest_client._client_session is not None - build.assert_called_once_with( - rest_client._max_rate_limit, rest_client.http_settings, rest_client.proxy_settings + if bucket_manager_owner: + rest_client._bucket_manager.close.assert_called_once_with() + else: + rest_client._bucket_manager.assert_not_called() + + @pytest.mark.parametrize("client_session_owner", [True, False]) + @pytest.mark.parametrize("bucket_manager_owner", [True, False]) + @pytest.mark.asyncio() # Function needs to be executed in a running loop + async def test_start(self, rest_client, client_session_owner, bucket_manager_owner): + rest_client._client_session = None + rest_client._close_event = None + rest_client._bucket_manager = mock.Mock() + rest_client._client_session_owner = client_session_owner + rest_client._bucket_manager_owner = bucket_manager_owner + + with mock.patch.object(net, "create_client_session") as create_client_session: + with mock.patch.object(net, "create_tcp_connector") as create_tcp_connector: + with mock.patch.object(asyncio, "Event") as event: + rest_client.start() + + assert rest_client._close_event is event.return_value + + if client_session_owner: + create_tcp_connector.assert_called_once_with(rest_client._http_settings) + create_client_session.assert_called_once_with( + connector=create_tcp_connector.return_value, + connector_owner=True, + http_settings=rest_client._http_settings, + raise_for_status=False, + trust_env=rest_client._proxy_settings.trust_env, ) - assert rest_client._live_attributes is build.return_value + assert rest_client._client_session is create_client_session.return_value + else: + assert rest_client._client_session is None + + if bucket_manager_owner: + rest_client._bucket_manager.start.assert_called_once_with() + else: + rest_client._bucket_manager.start.assert_not_called() def test_start_when_active(self, rest_client): - rest_client._live_attributes = object() + rest_client._close_event = object() with pytest.raises(errors.ComponentStateConflictError): rest_client.start() - def test__get_live_attributes_when_active(self, rest_client): - mock_attributes = rest_client._live_attributes = object() - - assert rest_client._get_live_attributes() is mock_attributes - - def test__get_live_attributes_when_inactive(self, rest_client): - rest_client._live_attributes = None - - with pytest.raises(errors.ComponentStateConflictError): - rest_client._get_live_attributes() - - @pytest.mark.parametrize( # noqa: PT014 - Duplicate test cases (false positive) - ("emoji", "expected_return"), - [ - (emojis.CustomEmoji(id=123, name="rooYay", is_animated=False), "rooYay:123"), - ("👌", "👌"), - ("\N{OK HAND SIGN}", "\N{OK HAND SIGN}"), - (emojis.UnicodeEmoji("\N{OK HAND SIGN}"), "\N{OK HAND SIGN}"), - ], - ) - def test__transform_emoji_to_url_format(self, rest_client, emoji, expected_return): - assert rest_client._transform_emoji_to_url_format(emoji, undefined.UNDEFINED) == expected_return - - def test__transform_emoji_to_url_format_with_id(self, rest_client): - assert rest_client._transform_emoji_to_url_format("rooYay", 123) == "rooYay:123" - - @pytest.mark.parametrize( - "emoji", - [ - emojis.CustomEmoji(id=123, name="rooYay", is_animated=False), - 
emojis.UnicodeEmoji("\N{OK HAND SIGN}"), - ], - ) - def test__transform_emoji_to_url_format_when_id_passed_with_emoji_object(self, rest_client, emoji): - with pytest.raises(ValueError, match="emoji_id shouldn't be passed when an Emoji object is passed for emoji"): - rest_client._transform_emoji_to_url_format(emoji, 123) - - def test__stringify_http_message_when_body_is_None(self, rest_client): - headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} - expected_return = " HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**" - assert rest_client._stringify_http_message(headers, None) == expected_return - - @pytest.mark.parametrize(("body", "expected"), [(bytes("hello :)", "ascii"), "hello :)"), (123, "123")]) - def test__stringify_http_message_when_body_is_not_None(self, rest_client, body, expected): - headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} - expected_return = ( - f" HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**\n\n {expected}" - ) - assert rest_client._stringify_http_message(headers, body) == expected_return - ####################### # Non-async endpoints # ####################### - def test_trigger_typing(self, rest_client, live_attributes): + def test_trigger_typing(self, rest_client): channel = StubModel(123) stub_iterator = mock.Mock() @@ -788,7 +748,7 @@ def test_trigger_typing(self, rest_client, live_attributes): assert rest_client.trigger_typing(channel) == stub_iterator typing_indicator.assert_called_once_with( - request_call=rest_client._request, channel=channel, rest_closed_event=live_attributes.closed_event + request_call=rest_client._request, channel=channel, rest_close_event=rest_client._close_event ) @pytest.mark.parametrize( @@ -889,10 +849,10 @@ def test_fetch_reactions_for_emoji(self, rest_client): channel = StubModel(123) message = StubModel(456) stub_iterator = mock.Mock() - rest_client._transform_emoji_to_url_format = mock.Mock(return_value="rooYay:123") with mock.patch.object(special_endpoints, "ReactorIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_reactions_for_emoji(channel, message, "<:rooYay:123>") == stub_iterator + with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"): + assert rest_client.fetch_reactions_for_emoji(channel, message, "<:rooYay:123>") == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, @@ -1762,11 +1722,9 @@ async def test___aenter__and__aexit__(self, rest_client): rest_client.close.assert_awaited_once_with() @hikari_test_helpers.timeout() - async def test__request_builds_form_when_passed(self, rest_client, exit_exception, live_attributes): + async def test_perform_request_builds_form_when_passed(self, rest_client, exit_exception): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session + rest_client._client_session.request.side_effect = exit_exception rest_client._token = None mock_form = mock.AsyncMock() mock_stack = mock.AsyncMock() @@ -1774,46 +1732,40 @@ async def test__request_builds_form_when_passed(self, rest_client, exit_exceptio with mock.patch.object(contextlib, "AsyncExitStack", return_value=mock_stack) as exit_stack: with pytest.raises(exit_exception): - await 
rest_client._request(route, form_builder=mock_form) + await rest_client._perform_request(route, form_builder=mock_form) - _, kwargs = mock_session.request.call_args_list[0] + _, kwargs = rest_client._client_session.request.call_args_list[0] mock_form.build.assert_awaited_once_with(exit_stack.return_value) assert kwargs["data"] is mock_form.build.return_value - assert live_attributes.still_alive.call_count == 3 @hikari_test_helpers.timeout() - async def test__request_url_encodes_reason_header(self, rest_client, exit_exception, live_attributes): + async def test_perform_request_url_encodes_reason_header(self, rest_client, exit_exception): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session + rest_client._client_session.request.side_effect = exit_exception with pytest.raises(exit_exception): - await rest_client._request(route, reason="光のenergyが 大地に降りそそぐ") + await rest_client._perform_request(route, reason="光のenergyが 大地に降りそそぐ") - _, kwargs = mock_session.request.call_args_list[0] + _, kwargs = rest_client._client_session.request.call_args_list[0] assert kwargs["headers"][rest._X_AUDIT_LOG_REASON_HEADER] == ( "%E5%85%89%E3%81%AEenergy%E3%81%8C%E3%80%80%E5%A4%" "A7%E5%9C%B0%E3%81%AB%E9%99%8D%E3%82%8A%E3%81%9D%E3%81%9D%E3%81%90" ) @hikari_test_helpers.timeout() - async def test__request_with_strategy_token(self, rest_client, exit_exception, live_attributes): + async def test_perform_request_with_strategy_token(self, rest_client, exit_exception): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session + rest_client._client_session.request.side_effect = exit_exception rest_client._token = mock.Mock(rest_api.TokenStrategy, acquire=mock.AsyncMock(return_value="Bearer ok.ok.ok")) with pytest.raises(exit_exception): - await rest_client._request(route) + await rest_client._perform_request(route) - _, kwargs = mock_session.request.call_args_list[0] + _, kwargs = rest_client._client_session.request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - assert live_attributes.still_alive.call_count == 3 @hikari_test_helpers.timeout() - async def test__request_retries_strategy_once(self, rest_client, exit_exception, live_attributes): + async def test_perform_request_retries_strategy_once(self, rest_client, exit_exception): class StubResponse: status = http.HTTPStatus.UNAUTHORIZED content_type = rest._APPLICATION_JSON @@ -1824,26 +1776,23 @@ async def read(self): return '{"something": null}' route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock( - request=hikari_test_helpers.CopyingAsyncMock(side_effect=[StubResponse(), exit_exception]) + rest_client._client_session.request = hikari_test_helpers.CopyingAsyncMock( + side_effect=[StubResponse(), exit_exception] ) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session rest_client._token = mock.Mock( rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) ) with pytest.raises(exit_exception): - await rest_client._request(route) + await rest_client._perform_request(route) - _, 
kwargs = mock_session.request.call_args_list[0] + _, kwargs = rest_client._client_session.request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - _, kwargs = mock_session.request.call_args_list[1] + _, kwargs = rest_client._client_session.request.call_args_list[1] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" - assert live_attributes.still_alive.call_count == 6 @hikari_test_helpers.timeout() - async def test__request_raises_after_re_auth_attempt(self, rest_client, exit_exception, live_attributes): + async def test_perform_request_raises_after_re_auth_attempt(self, rest_client, exit_exception): class StubResponse: status = http.HTTPStatus.UNAUTHORIZED content_type = rest._APPLICATION_JSON @@ -1858,107 +1807,87 @@ async def json(self): return {"something": None} route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock( - request=hikari_test_helpers.CopyingAsyncMock(side_effect=[StubResponse(), StubResponse(), StubResponse()]) + rest_client._client_session.request = hikari_test_helpers.CopyingAsyncMock( + side_effect=[StubResponse(), StubResponse(), StubResponse()] ) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session rest_client._token = mock.Mock( rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) ) with pytest.raises(errors.UnauthorizedError): - await rest_client._request(route) + await rest_client._perform_request(route) - _, kwargs = mock_session.request.call_args_list[0] + _, kwargs = rest_client._client_session.request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - _, kwargs = mock_session.request.call_args_list[1] + _, kwargs = rest_client._client_session.request.call_args_list[1] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" - assert live_attributes.still_alive.call_count == 6 @hikari_test_helpers.timeout() - async def test__request_when__token_is_None(self, rest_client, exit_exception, live_attributes): + async def test_perform_request_when__token_is_None(self, rest_client, exit_exception): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session + rest_client._client_session.request.side_effect = exit_exception rest_client._token = None with pytest.raises(exit_exception): - await rest_client._request(route) + await rest_client._perform_request(route) - _, kwargs = mock_session.request.call_args_list[0] + _, kwargs = rest_client._client_session.request.call_args_list[0] assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] @hikari_test_helpers.timeout() - async def test__request_when__token_is_not_None(self, rest_client, exit_exception, live_attributes): + async def test_perform_request_when__token_is_not_None(self, rest_client, exit_exception): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception)) - live_attributes.buckets.is_started = True - live_attributes.client_session = mock_session + rest_client._client_session.request.side_effect = exit_exception rest_client._token = "token" with pytest.raises(exit_exception): - await rest_client._request(route) + await 
rest_client._perform_request(route)

-        _, kwargs = mock_session.request.call_args_list[0]
+        _, kwargs = rest_client._client_session.request.call_args_list[0]
         assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "token"
-        assert live_attributes.still_alive.call_count == 3

     @hikari_test_helpers.timeout()
-    async def test__request_when_no_auth_passed(self, rest_client, exit_exception, live_attributes):
+    async def test_perform_request_when_no_auth_passed(self, rest_client, exit_exception):
         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception))
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
+        rest_client._client_session.request.side_effect = exit_exception
         rest_client._token = "token"

         with pytest.raises(exit_exception):
-            await rest_client._request(route, no_auth=True)
+            await rest_client._perform_request(route, auth=None)

-        _, kwargs = mock_session.request.call_args_list[0]
+        _, kwargs = rest_client._client_session.request.call_args_list[0]
         assert rest._AUTHORIZATION_HEADER not in kwargs["headers"]
-        live_attributes.buckets.acquire.assert_called_once_with(route)
-        live_attributes.buckets.acquire.return_value.assert_used_once()
-        live_attributes.global_rate_limit.acquire.assert_not_called()
-        assert live_attributes.still_alive.call_count == 2
+        rest_client._bucket_manager.acquire_bucket.assert_called_once_with(route, None)
+        rest_client._bucket_manager.acquire_bucket.return_value.assert_used_once()

     @hikari_test_helpers.timeout()
-    async def test__request_when_auth_passed(self, rest_client, exit_exception, live_attributes):
+    async def test_perform_request_when_auth_passed(self, rest_client, exit_exception):
         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exit_exception))
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
+        rest_client._client_session.request.side_effect = exit_exception
         rest_client._token = "token"

         with pytest.raises(exit_exception):
-            await rest_client._request(route, auth="ooga booga")
+            await rest_client._perform_request(route, auth="ooga booga")

-        _, kwargs = mock_session.request.call_args_list[0]
+        _, kwargs = rest_client._client_session.request.call_args_list[0]
         assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "ooga booga"
-        live_attributes.buckets.acquire.assert_called_once_with(route)
-        live_attributes.buckets.acquire.return_value.assert_used_once()
-        live_attributes.global_rate_limit.acquire.assert_awaited_once_with()
-        assert live_attributes.still_alive.call_count == 3
+        rest_client._bucket_manager.acquire_bucket.assert_called_once_with(route, "ooga booga")
+        rest_client._bucket_manager.acquire_bucket.return_value.assert_used_once()

     @hikari_test_helpers.timeout()
-    async def test__request_when_response_is_NO_CONTENT(self, rest_client, live_attributes):
+    async def test_perform_request_when_response_is_NO_CONTENT(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.NO_CONTENT
             reason = "cause why not"

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
-        rest_client._parse_ratelimits = mock.AsyncMock()
-
-        assert (await rest_client._request(route)) is None
+        rest_client._client_session.request.return_value = StubResponse()
+        rest_client._parse_ratelimits = mock.AsyncMock(return_value=False)

-        assert live_attributes.still_alive.call_count == 3
+        assert (await rest_client._perform_request(route)) is None

     @hikari_test_helpers.timeout()
-    async def test__request_when_response_is_APPLICATION_JSON(self, rest_client, live_attributes):
+    async def test_perform_request_when_response_is_APPLICATION_JSON(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.OK
             content_type = rest._APPLICATION_JSON
@@ -1969,17 +1898,13 @@ async def read(self):
                 return '{"something": null}'

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
-        rest_client._parse_ratelimits = mock.AsyncMock()
+        rest_client._client_session.request.return_value = StubResponse()
+        rest_client._parse_ratelimits = mock.AsyncMock(return_value=False)

-        assert (await rest_client._request(route)) == {"something": None}
-
-        assert live_attributes.still_alive.call_count == 3
+        assert (await rest_client._perform_request(route)) == {"something": None}

     @hikari_test_helpers.timeout()
-    async def test__request_when_response_is_not_JSON(self, rest_client, live_attributes):
+    async def test_perform_request_when_response_is_not_JSON(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.IM_USED
             content_type = "text/html"
@@ -1987,49 +1912,39 @@ class StubResponse:
             real_url = "https://some.url"

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
-        rest_client._parse_ratelimits = mock.AsyncMock()
+        rest_client._client_session.request.return_value = StubResponse()
+        rest_client._parse_ratelimits = mock.AsyncMock(return_value=False)

         with pytest.raises(errors.HTTPError):
-            await rest_client._request(route)
-
-        assert live_attributes.still_alive.call_count == 3
+            await rest_client._perform_request(route)

     @hikari_test_helpers.timeout()
-    async def test__request_when_response_unhandled_status(self, rest_client, exit_exception, live_attributes):
+    async def test_perform_request_when_response_unhandled_status(self, rest_client, exit_exception):
         class StubResponse:
             status = http.HTTPStatus.NOT_IMPLEMENTED
             content_type = "text/html"
             reason = "cause why not"

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
-        live_attributes.buckets.is_started = True
-        rest_client._parse_ratelimits = mock.AsyncMock()
-        live_attributes.client_session = mock_session
-        rest_client._handle_error_response = mock.AsyncMock(side_effect=exit_exception)
+        rest_client._client_session.request.return_value = StubResponse()

-        with pytest.raises(exit_exception):
-            await rest_client._request(route)
+        rest_client._parse_ratelimits = mock.AsyncMock(return_value=False)

-        assert live_attributes.still_alive.call_count == 3
+        with mock.patch.object(net, "generate_error_response", return_value=exit_exception):
+            with pytest.raises(exit_exception):
+                await rest_client._perform_request(route)

     @hikari_test_helpers.timeout()
-    async def test__request_when_status_in_retry_codes_will_retry_until_exhausted(
-        self, rest_client, exit_exception, live_attributes
+    async def test_perform_request_when_status_in_retry_codes_will_retry_until_exhausted(
+        self, rest_client, exit_exception
     ):
         class StubResponse:
             status = http.HTTPStatus.INTERNAL_SERVER_ERROR

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
+        rest_client._client_session.request.return_value = StubResponse()
         rest_client._max_retries = 3
-        rest_client._parse_ratelimits = mock.AsyncMock()
-        rest_client._handle_error_response = mock.AsyncMock(side_effect=exit_exception)
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
+        rest_client._parse_ratelimits = mock.AsyncMock(return_value=False)

         stack = contextlib.ExitStack()
         stack.enter_context(pytest.raises(exit_exception))
@@ -2041,32 +1956,21 @@ class StubResponse:
             )
         )
         asyncio_sleep = stack.enter_context(mock.patch.object(asyncio, "sleep"))
+        generate_error_response = stack.enter_context(
+            mock.patch.object(net, "generate_error_response", return_value=exit_exception)
+        )

         with stack:
-            await rest_client._request(route)
+            await rest_client._perform_request(route)

-        assert live_attributes.still_alive.call_count == 12
         assert exponential_backoff.return_value.__next__.call_count == 3
         exponential_backoff.assert_called_once_with(maximum=16)
         asyncio_sleep.assert_has_awaits([mock.call(1), mock.call(2), mock.call(3)])
-
-    @hikari_test_helpers.timeout()
-    async def test__request_when_response__RetryRequest_gets_handled(
-        self, rest_client, exit_exception, live_attributes
-    ):
-        route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=[rest._RetryRequest, exit_exception]))
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
-
-        with pytest.raises(exit_exception):
-            await rest_client._request(route)
-
-        assert live_attributes.still_alive.call_count == 6
+        generate_error_response.assert_called_once_with(rest_client._client_session.request.return_value)

     @pytest.mark.parametrize("enabled", [True, False])
     @hikari_test_helpers.timeout()
-    async def test__request_logger(self, rest_client, enabled, live_attributes):
+    async def test_perform_request_logger(self, rest_client, enabled):
         class StubResponse:
             status = http.HTTPStatus.NO_CONTENT
             headers = {}
@@ -2076,30 +1980,18 @@ async def read(self):
                 return None

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)
-        mock_session = mock.AsyncMock(request=mock.AsyncMock(return_value=StubResponse()))
-        live_attributes.buckets.is_started = True
-        live_attributes.client_session = mock_session
-        rest_client._parse_ratelimits = mock.AsyncMock()
+        rest_client._client_session.request.return_value = StubResponse()
+        rest_client._parse_ratelimits = mock.AsyncMock(return_value=False)

         with mock.patch.object(rest, "_LOGGER", new=mock.Mock(isEnabledFor=mock.Mock(return_value=enabled))) as logger:
-            await rest_client._request(route)
+            await rest_client._perform_request(route)

         if enabled:
             assert logger.log.call_count == 2
         else:
             assert logger.log.call_count == 0

-        assert live_attributes.still_alive.call_count == 3
-
-    async def test__handle_error_response(self, rest_client, exit_exception):
-        mock_response = mock.Mock()
-        with mock.patch.object(net, "generate_error_response", return_value=exit_exception) as generate_error_response:
-            with pytest.raises(exit_exception):
-                await rest_client._handle_error_response(mock_response)
-
-        generate_error_response.assert_called_once_with(mock_response)
-
-    async def test__parse_ratelimits_when_bucket_provided_updates_rate_limits(self, rest_client, live_attributes):
+    async def test__parse_ratelimits_when_bucket_provided_updates_rate_limits(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.OK
             headers = {
@@ -2112,17 +2004,18 @@ class StubResponse:
         response = StubResponse()
         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

-        await rest_client._parse_ratelimits(route, response, live_attributes)
+        await rest_client._parse_ratelimits(route, "auth", response)

-        live_attributes.buckets.update_rate_limits.assert_called_once_with(
+        rest_client._bucket_manager.update_rate_limits.assert_called_once_with(
            compiled_route=route,
            bucket_header="bucket_header",
+            authentication="auth",
            remaining_header=987654321,
            limit_header=123456789,
            reset_after=12.2,
        )

-    async def test__parse_ratelimits_when_not_ratelimited(self, rest_client, live_attributes):
+    async def test__parse_ratelimits_when_not_ratelimited(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.OK
             headers = {}
@@ -2132,12 +2025,11 @@ class StubResponse:
         response = StubResponse()
         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

-        await rest_client._parse_ratelimits(route, response, live_attributes)
+        await rest_client._parse_ratelimits(route, "auth", response)

         response.json.assert_not_called()
-        live_attributes.still_alive.assert_not_called()

-    async def test__parse_ratelimits_when_ratelimited(self, rest_client, exit_exception, live_attributes):
+    async def test__parse_ratelimits_when_ratelimited(self, rest_client, exit_exception):
         class StubResponse:
             status = http.HTTPStatus.TOO_MANY_REQUESTS
             content_type = rest._APPLICATION_JSON
@@ -2148,11 +2040,9 @@ async def json(self):
         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

         with pytest.raises(exit_exception):
-            await rest_client._parse_ratelimits(route, StubResponse(), live_attributes)
-
-        live_attributes.still_alive.assert_not_called()
+            await rest_client._parse_ratelimits(route, "auth", StubResponse())

-    async def test__parse_ratelimits_when_unexpected_content_type(self, rest_client, live_attributes):
+    async def test__parse_ratelimits_when_unexpected_content_type(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.TOO_MANY_REQUESTS
             content_type = "text/html"
@@ -2164,11 +2054,9 @@ async def read(self):
         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

         with pytest.raises(errors.HTTPResponseError):
-            await rest_client._parse_ratelimits(route, StubResponse(), live_attributes)
-
-        live_attributes.still_alive.assert_not_called()
+            await rest_client._parse_ratelimits(route, "auth", StubResponse())

-    async def test__parse_ratelimits_when_global_ratelimit(self, rest_client, live_attributes):
+    async def test__parse_ratelimits_when_global_ratelimit(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.TOO_MANY_REQUESTS
             content_type = rest._APPLICATION_JSON
@@ -2179,13 +2067,11 @@ async def json(self):
             return {"global": True, "retry_after": "2"}

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

-        with pytest.raises(rest._RetryRequest):
-            await rest_client._parse_ratelimits(route, StubResponse(), live_attributes)
+        assert (await rest_client._parse_ratelimits(route, "auth", StubResponse())) is True

-        live_attributes.global_rate_limit.throttle.assert_called_once_with(2.0)
-        assert live_attributes.still_alive.call_count == 1
+        rest_client._bucket_manager.throttle.assert_called_once_with(2.0)

-    async def test__parse_ratelimits_when_remaining_header_under_or_equal_to_0(self, rest_client, live_attributes):
+    async def test__parse_ratelimits_when_remaining_header_under_or_equal_to_0(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.TOO_MANY_REQUESTS
             content_type = rest._APPLICATION_JSON
@@ -2198,12 +2084,9 @@ async def json(self):
             return {"retry_after": "2", "global": False}

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

-        with pytest.raises(rest._RetryRequest):
-            await rest_client._parse_ratelimits(route, StubResponse(), live_attributes)
+        assert await rest_client._parse_ratelimits(route, "some auth", StubResponse()) is True

-        live_attributes.still_alive.assert_not_called()
-
-    async def test__parse_ratelimits_when_retry_after_is_close_enough(self, rest_client, live_attributes):
+    async def test__parse_ratelimits_when_retry_after_is_close_enough(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.TOO_MANY_REQUESTS
             content_type = rest._APPLICATION_JSON
@@ -2216,12 +2099,9 @@ async def json(self):
             return {"retry_after": "0.002"}

         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

-        with pytest.raises(rest._RetryRequest):
-            await rest_client._parse_ratelimits(route, StubResponse(), live_attributes)
-
-        live_attributes.still_alive.assert_not_called()
+        assert await rest_client._parse_ratelimits(route, "some auth", StubResponse()) is True

-    async def test__parse_ratelimits_when_retry_after_is_not_close_enough(self, rest_client, live_attributes):
+    async def test__parse_ratelimits_when_retry_after_is_not_close_enough(self, rest_client):
         class StubResponse:
             status = http.HTTPStatus.TOO_MANY_REQUESTS
             content_type = rest._APPLICATION_JSON
@@ -2233,9 +2113,7 @@ async def json(self):
         route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123)

         with pytest.raises(errors.RateLimitedError):
-            await rest_client._parse_ratelimits(route, StubResponse(), live_attributes)
-
-        live_attributes.still_alive.assert_not_called()
+            await rest_client._parse_ratelimits(route, "auth", StubResponse())

     #############
     # Endpoints #
@@ -3000,33 +2878,37 @@ async def test_delete_messages_with_async_iterable_and_args(self, rest_client):
     async def test_add_reaction(self, rest_client):
         expected_route = routes.PUT_MY_REACTION.compile(emoji="rooYay:123", channel=123, message=456)
         rest_client._request = mock.AsyncMock()
-        rest_client._transform_emoji_to_url_format = mock.Mock(return_value="rooYay:123")
-        await rest_client.add_reaction(StubModel(123), StubModel(456), "<:rooYay:123>")
+        with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"):
+            await rest_client.add_reaction(StubModel(123), StubModel(456), "<:rooYay:123>")
+
         rest_client._request.assert_awaited_once_with(expected_route)

     async def test_delete_my_reaction(self, rest_client):
         expected_route = routes.DELETE_MY_REACTION.compile(emoji="rooYay:123", channel=123, message=456)
         rest_client._request = mock.AsyncMock()
-        rest_client._transform_emoji_to_url_format = mock.Mock(return_value="rooYay:123")
-        await rest_client.delete_my_reaction(StubModel(123), StubModel(456), "<:rooYay:123>")
+        with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"):
+            await rest_client.delete_my_reaction(StubModel(123), StubModel(456), "<:rooYay:123>")
+
         rest_client._request.assert_awaited_once_with(expected_route)

     async def test_delete_all_reactions_for_emoji(self, rest_client):
         expected_route = routes.DELETE_REACTION_EMOJI.compile(emoji="rooYay:123", channel=123, message=456)
         rest_client._request = mock.AsyncMock()
-        rest_client._transform_emoji_to_url_format = mock.Mock(return_value="rooYay:123")
-        await rest_client.delete_all_reactions_for_emoji(StubModel(123), StubModel(456), "<:rooYay:123>")
+        with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"):
+            await rest_client.delete_all_reactions_for_emoji(StubModel(123), StubModel(456), "<:rooYay:123>")
+
         rest_client._request.assert_awaited_once_with(expected_route)

     async def test_delete_reaction(self, rest_client):
         expected_route = routes.DELETE_REACTION_USER.compile(emoji="rooYay:123", channel=123, message=456, user=789)
         rest_client._request = mock.AsyncMock()
-        rest_client._transform_emoji_to_url_format = mock.Mock(return_value="rooYay:123")
-        await rest_client.delete_reaction(StubModel(123), StubModel(456), StubModel(789), "<:rooYay:123>")
+        with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"):
+            await rest_client.delete_reaction(StubModel(123), StubModel(456), StubModel(789), "<:rooYay:123>")
+
         rest_client._request.assert_awaited_once_with(expected_route)

     async def test_delete_all_reactions(self, rest_client):
@@ -3069,7 +2951,7 @@ async def test_fetch_webhook(self, rest_client):
         rest_client._entity_factory.deserialize_webhook = mock.Mock(return_value=webhook)

         assert await rest_client.fetch_webhook(StubModel(123), token="token") is webhook

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None)
         rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"})

     async def test_fetch_webhook_without_token(self, rest_client):
@@ -3079,7 +2961,7 @@ async def test_fetch_webhook_without_token(self, rest_client):
         rest_client._entity_factory.deserialize_webhook = mock.Mock(return_value=webhook)

         assert await rest_client.fetch_webhook(StubModel(123)) is webhook

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=False)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=undefined.UNDEFINED)
         rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"})

     async def test_fetch_channel_webhooks(self, rest_client):
@@ -3160,7 +3042,7 @@ async def test_edit_webhook(self, rest_client):
         assert returned is webhook

         rest_client._request.assert_awaited_once_with(
-            expected_route, json=expected_json, reason="some smart reason to do this", no_auth=True
+            expected_route, json=expected_json, reason="some smart reason to do this", auth=None
         )
         rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"})

@@ -3175,7 +3057,7 @@ async def test_edit_webhook_without_token(self, rest_client):
         assert returned is webhook

         rest_client._request.assert_awaited_once_with(
-            expected_route, json=expected_json, reason=undefined.UNDEFINED, no_auth=False
+            expected_route, json=expected_json, reason=undefined.UNDEFINED, auth=undefined.UNDEFINED
         )
         rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"})

@@ -3189,7 +3071,7 @@ async def test_edit_webhook_when_avatar_is_file(self, rest_client, file_resource
         assert await rest_client.edit_webhook(StubModel(123), avatar="someavatar.png") is webhook

         rest_client._request.assert_awaited_once_with(
-            expected_route, json=expected_json, reason=undefined.UNDEFINED, no_auth=False
+            expected_route, json=expected_json, reason=undefined.UNDEFINED, auth=undefined.UNDEFINED
         )
         rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"})

@@ -3198,14 +3080,14 @@ async def test_delete_webhook(self, rest_client):
         rest_client._request = mock.AsyncMock(return_value={"id": "456"})

         await rest_client.delete_webhook(StubModel(123), token="token")

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None)

     async def test_delete_webhook_without_token(self, rest_client):
         expected_route = routes.DELETE_WEBHOOK.compile(webhook=123)
         rest_client._request = mock.AsyncMock(return_value={"id": "456"})

         await rest_client.delete_webhook(StubModel(123))

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=False)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=undefined.UNDEFINED)

     @pytest.mark.parametrize(
         ("webhook", "avatar_url"),
@@ -3271,7 +3153,7 @@ async def test_execute_webhook_when_form(self, rest_client, webhook, avatar_url)
             expected_route,
             form_builder=mock_form,
             query={"wait": "true"},
-            no_auth=True,
+            auth=None,
         )
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

@@ -3311,7 +3193,7 @@ async def test_execute_webhook_when_form_and_thread(self, rest_client):
             expected_route,
             form_builder=mock_form,
             query={"wait": "true", "thread_id": "1234543123"},
-            no_auth=True,
+            auth=None,
         )
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

@@ -3345,7 +3227,7 @@ async def test_execute_webhook_when_no_form(self, rest_client):
             expected_route,
             json={"testing": "ensure_in_test"},
             query={"wait": "true", "thread_id": "2134312123"},
-            no_auth=True,
+            auth=None,
         )
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

@@ -3400,7 +3282,7 @@ async def test_execute_webhook_when_thread_and_no_form(self, rest_client):
             expected_route,
             json={"testing": "ensure_in_test", "username": "davfsa", "avatar_url": "https://website.com/davfsa_logo"},
             query={"wait": "true"},
-            no_auth=True,
+            auth=None,
         )
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

@@ -3413,7 +3295,7 @@ async def test_fetch_webhook_message(self, rest_client, webhook):
         assert await rest_client.fetch_webhook_message(webhook, "hi, im a token", StubModel(456)) is message_obj

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True, query={})
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={})
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"})

     async def test_fetch_webhook_message_when_thread(self, rest_client):
@@ -3427,7 +3309,7 @@ async def test_fetch_webhook_message_when_thread(self, rest_client):
         )

         assert result is message_obj

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True, query={"thread_id": "54123123"})
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={"thread_id": "54123123"})
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"})

     @pytest.mark.parametrize("webhook", [mock.Mock(webhooks.ExecutableWebhook, webhook_id=432), 432])
@@ -3478,7 +3360,7 @@ async def test_edit_webhook_message_when_form(self, rest_client, webhook):
         mock_form.add_field.assert_called_once_with(
             "payload_json", '{"testing": "ensure_in_test"}', content_type="application/json"
         )
-        rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, query={}, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, query={}, auth=None)
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

     async def test_edit_webhook_message_when_form_and_thread(self, rest_client):
@@ -3511,7 +3393,7 @@ async def test_edit_webhook_message_when_form_and_thread(self, rest_client):
             "payload_json", '{"testing": "ensure_in_test"}', content_type="application/json"
         )
         rest_client._request.assert_awaited_once_with(
-            expected_route, form_builder=mock_form, query={"thread_id": "123543123"}, no_auth=True
+            expected_route, form_builder=mock_form, query={"thread_id": "123543123"}, auth=None
         )
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

@@ -3559,7 +3441,7 @@ async def test_edit_webhook_message_when_no_form(self, rest_client: rest_api.RES
             edit=True,
         )
         rest_client._request.assert_awaited_once_with(
-            expected_route, json={"testing": "ensure_in_test"}, query={}, no_auth=True
+            expected_route, json={"testing": "ensure_in_test"}, query={}, auth=None
         )
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

@@ -3589,7 +3471,7 @@ async def test_edit_webhook_message_when_thread_and_no_form(self, rest_client: r
             edit=True,
         )
         rest_client._request.assert_awaited_once_with(
-            expected_route, json={"testing": "ensure_in_test"}, query={"thread_id": "2346523432"}, no_auth=True
+            expected_route, json={"testing": "ensure_in_test"}, query={"thread_id": "2346523432"}, auth=None
         )
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

@@ -3600,7 +3482,7 @@ async def test_delete_webhook_message(self, rest_client, webhook):
         await rest_client.delete_webhook_message(webhook, "token", StubModel(456))

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True, query={})
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={})

     async def test_delete_webhook_message_when_thread(self, rest_client):
         expected_route = routes.DELETE_WEBHOOK_MESSAGE.compile(webhook=123, token="token", message=456)
@@ -3608,7 +3490,7 @@ async def test_delete_webhook_message_when_thread(self, rest_client):
         await rest_client.delete_webhook_message(123, "token", StubModel(456), thread=StubModel(432123))

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True, query={"thread_id": "432123"})
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={"thread_id": "432123"})

     async def test_fetch_gateway_url(self, rest_client):
         expected_route = routes.GET_GATEWAY.compile()
@@ -3616,7 +3498,7 @@ async def test_fetch_gateway_url(self, rest_client):
         assert await rest_client.fetch_gateway_url() == "wss://some.url"

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None)

     async def test_fetch_gateway_bot(self, rest_client):
         bot = StubModel(123)
@@ -4044,7 +3926,7 @@ async def test_fetch_sticker_packs(self, rest_client):
         assert await rest_client.fetch_available_sticker_packs() == [pack1, pack2, pack3]

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None)
         rest_client._entity_factory.deserialize_sticker_pack.assert_has_calls(
             [mock.call({"id": "123"}), mock.call({"id": "456"}), mock.call({"id": "789"})]
         )
@@ -6009,7 +5891,7 @@ async def test_fetch_interaction_response(self, rest_client):
         assert result is rest_client._entity_factory.deserialize_message.return_value

         rest_client._entity_factory.deserialize_message.assert_called_once_with(rest_client._request.return_value)
-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None)

     async def test_create_interaction_response_when_form(self, rest_client):
         attachment_obj = object()
@@ -6060,7 +5942,7 @@ async def test_create_interaction_response_when_form(self, rest_client):
         mock_form.add_field.assert_called_once_with(
             "payload_json", '{"type": 1, "data": {"testing": "ensure_in_test"}}', content_type="application/json"
         )
-        rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, auth=None)

     async def test_create_interaction_response_when_no_form(self, rest_client):
         attachment_obj = object()
@@ -6108,7 +5990,7 @@ async def test_create_interaction_response_when_no_form(self, rest_client):
             role_mentions=[1234],
         )
         rest_client._request.assert_awaited_once_with(
-            expected_route, json={"type": 1, "data": {"testing": "ensure_in_test"}}, no_auth=True
+            expected_route, json={"type": 1, "data": {"testing": "ensure_in_test"}}, auth=None
         )

     async def test_edit_interaction_response_when_form(self, rest_client):
@@ -6157,7 +6039,7 @@ async def test_edit_interaction_response_when_form(self, rest_client):
         mock_form.add_field.assert_called_once_with(
             "payload_json", '{"testing": "ensure_in_test"}', content_type="application/json"
         )
-        rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, auth=None)
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

     async def test_edit_interaction_response_when_no_form(self, rest_client):
@@ -6202,7 +6084,7 @@ async def test_edit_interaction_response_when_no_form(self, rest_client):
             role_mentions=[1234],
             edit=True,
         )
-        rest_client._request.assert_awaited_once_with(expected_route, json={"testing": "ensure_in_test"}, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, json={"testing": "ensure_in_test"}, auth=None)
         rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123})

     async def test_delete_interaction_response(self, rest_client):
@@ -6211,7 +6093,7 @@ async def test_delete_interaction_response(self, rest_client):
         await rest_client.delete_interaction_response(StubModel(1235431), "go homo now")

-        rest_client._request.assert_awaited_once_with(expected_route, no_auth=True)
+        rest_client._request.assert_awaited_once_with(expected_route, auth=None)

     async def test_create_autocomplete_response(self, rest_client):
         expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek")
@@ -6223,7 +6105,7 @@ async def test_create_autocomplete_response(self, rest_client):
         rest_client._request.assert_awaited_once_with(
             expected_route,
             json={"type": 8, "data": {"choices": [{"name": "a", "value": "b"}, {"name": "foo", "value": "bar"}]}},
-            no_auth=True,
+            auth=None,
         )

     async def test_create_modal_response(self, rest_client):
@@ -6241,7 +6123,7 @@ async def test_create_modal_response(self, rest_client):
                 "type": 9,
                 "data": {"title": "title", "custom_id": "idd", "components": [component.build.return_value]},
             },
-            no_auth=True,
+            auth=None,
         )

     async def test_create_modal_response_with_plural_args(self, rest_client):
@@ -6259,7 +6141,7 @@ async def test_create_modal_response_with_plural_args(self, rest_client):
                 "type": 9,
                 "data": {"title": "title", "custom_id": "idd", "components": [component.build.return_value]},
             },
-            no_auth=True,
+            auth=None,
         )

     async def test_create_modal_response_when_both_component_and_components_passed(self, rest_client):
diff --git a/tests/hikari/internal/test_routes.py b/tests/hikari/internal/test_routes.py
index d6a6610fb8..2e4245a319 100644
--- a/tests/hikari/internal/test_routes.py
+++ b/tests/hikari/internal/test_routes.py
@@ -41,7 +41,7 @@ def test_create_url(self, compiled_route):
         assert compiled_route.create_url("https://some.url/api") == "https://some.url/api/some/endpoint"

     def test_create_real_bucket_hash(self, compiled_route):
-        assert compiled_route.create_real_bucket_hash("UNKNOWN") == "UNKNOWN;abc123"
+        assert compiled_route.create_real_bucket_hash("UNKNOWN", "AUTH_HASH") == "UNKNOWN;AUTH_HASH;abc123"

     def test__str__(self, compiled_route):
         assert str(compiled_route) == "GET /some/endpoint"