Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add orjson as an optional speedup and allow to pass custom json.dumps and json.loads functions to all components. #1486

Merged
merged 22 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ other internal settings in the interpreter.

If you have a C compiler (Microsoft VC++ Redistributable 14.0 or newer, or a modern copy of GCC/G++, Clang, etc), it is
recommended you install Hikari using `pip install -U hikari[speedups]`. This will install `aiohttp` with its available
speedups, and `ciso8601` which will provide you with a small performance boost.
speedups, `ciso8601` and `orjson` which will provide you with a substantial performance boost.

### `uvloop`

Expand Down
1 change: 1 addition & 0 deletions changes/1486.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `orjson` as an optional speedup and allow to pass custom `json.dumps` and `json.loads` functions to all components.
4 changes: 3 additions & 1 deletion hikari/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)

import http
import json
import typing

import attr
Expand Down Expand Up @@ -317,7 +318,8 @@ def __str__(self) -> str:
try:
value += _dump_errors(self.errors).strip("\n")
except KeyError:
value += data_binding.dump_json(self.errors, indent=2)
# Use the stdlib json.dumps here to be able to indent
value += json.dumps(self.errors, indent=2)
davfsa marked this conversation as resolved.
Show resolved Hide resolved

self._cached_str = value
return value
Expand Down
15 changes: 15 additions & 0 deletions hikari/impl/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from hikari.impl import shard as shard_impl
from hikari.impl import voice as voice_impl
from hikari.internal import aio
from hikari.internal import data_binding
from hikari.internal import signals
from hikari.internal import time
from hikari.internal import ux
Expand Down Expand Up @@ -237,6 +238,10 @@ class GatewayBot(traits.GatewayBotAware):
proxy_settings : typing.Optional[hikari.impl.config.ProxySettings]
Custom proxy settings to use with network-layer logic
in your application to get through an HTTP-proxy.
dumps : hikari.internal.data_binding.JSONEncoder
The JSON encoder this application should use. Defaults to `hikari.internal.data_binding.default_json_dumps`.
loads : hikari.internal.data_binding.JSONDecoder
The JSON decoder this application should use. Defaults to `hikari.internal.data_binding.default_json_loads`.
rest_url : typing.Optional[str]
Defaults to the Discord REST API URL if `None`. Can be
overridden if you are attempting to point to an unofficial endpoint, or
Expand Down Expand Up @@ -291,6 +296,8 @@ class GatewayBot(traits.GatewayBotAware):
"_shards",
"_token",
"_voice",
"_loads",
"_dumps",
"shards",
)

Expand All @@ -305,6 +312,8 @@ def __init__(
force_color: bool = False,
cache_settings: typing.Optional[config_impl.CacheSettings] = None,
http_settings: typing.Optional[config_impl.HTTPSettings] = None,
dumps: data_binding.JSONEncoder = data_binding.default_json_dumps,
loads: data_binding.JSONDecoder = data_binding.default_json_loads,
intents: intents_.Intents = intents_.Intents.ALL_UNPRIVILEGED,
auto_chunk_members: bool = True,
logs: typing.Union[None, int, str, typing.Dict[str, typing.Any]] = "INFO",
Expand All @@ -326,6 +335,8 @@ def __init__(
self._intents = intents
self._proxy_settings = proxy_settings if proxy_settings is not None else config_impl.ProxySettings()
self._token = token.strip()
self._dumps = dumps
self._loads = loads

# Caching
cache_settings = cache_settings if cache_settings is not None else config_impl.CacheSettings()
Expand Down Expand Up @@ -357,6 +368,8 @@ def __init__(
http_settings=self._http_settings,
max_rate_limit=max_rate_limit,
proxy_settings=self._proxy_settings,
dumps=dumps,
loads=loads,
rest_url=rest_url,
max_retries=max_retries,
token=token,
Expand Down Expand Up @@ -1252,6 +1265,8 @@ async def _start_one_shard(
event_manager=self._event_manager,
event_factory=self._event_factory,
intents=self._intents,
dumps=self._dumps,
loads=self._loads,
initial_activity=activity,
initial_is_afk=afk,
initial_idle_since=idle_since,
Expand Down
21 changes: 11 additions & 10 deletions hikari/impl/interaction_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def status_code(self) -> int:

# Constant response
_PONG_RESPONSE: typing.Final[_Response] = _Response(
_OK_STATUS, data_binding.dump_json({"type": _PONG_RESPONSE_TYPE}).encode(), content_type=_JSON_CONTENT_TYPE
_OK_STATUS, data_binding.default_json_dumps({"type": _PONG_RESPONSE_TYPE}), content_type=_JSON_CONTENT_TYPE
)


Expand Down Expand Up @@ -193,10 +193,10 @@ class InteractionServer(interaction_server.InteractionServer):

Other Parameters
----------------
dumps : aiohttp.typedefs.JSONEncoder
The JSON encoder this server should use. Defaults to `json.dumps`.
loads : aiohttp.typedefs.JSONDecoder
The JSON decoder this server should use. Defaults to `json.loads`.
dumps : hikari.internal.data_binding.JSONEncoder
The JSON encoder this server should use. Defaults to `hikari.internal.data_binding.default_json_dumps`.
loads : hikari.internal.data_binding.JSONDecoder
The JSON decoder this server should use. Defaults to `hikari.internal.data_binding.default_json_loads`.
public_key : bytes
The public key this server should use for verifying request payloads from
Discord. If left as `None` then the client will try to work this
Expand Down Expand Up @@ -224,10 +224,10 @@ class InteractionServer(interaction_server.InteractionServer):
def __init__(
self,
*,
dumps: aiohttp.typedefs.JSONEncoder = data_binding.dump_json,
dumps: data_binding.JSONEncoder = data_binding.default_json_dumps,
entity_factory: entity_factory_api.EntityFactory,
executor: typing.Optional[concurrent.futures.Executor] = None,
loads: aiohttp.typedefs.JSONDecoder = data_binding.load_json,
loads: data_binding.JSONDecoder = data_binding.default_json_loads,
rest_client: rest_api.RESTClient,
public_key: typing.Optional[bytes] = None,
) -> None:
Expand Down Expand Up @@ -433,10 +433,11 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes)
return _Response(_BAD_REQUEST_STATUS, b"Invalid request signature")

try:
payload = self._loads(body.decode("utf-8"))
payload = self._loads(body)
assert isinstance(payload, dict)
interaction_type = int(payload["type"])

except (data_binding.JSONDecodeError, ValueError, TypeError) as exc:
except (ValueError, TypeError) as exc:
_LOGGER.error("Received a request with an invalid JSON body", exc_info=exc)
return _Response(_BAD_REQUEST_STATUS, b"Invalid JSON body")

Expand Down Expand Up @@ -484,7 +485,7 @@ async def on_interaction(self, body: bytes, signature: bytes, timestamp: bytes)
)
return _Response(_INTERNAL_SERVER_ERROR_STATUS, b"Exception occurred during interaction dispatch")

return _Response(_OK_STATUS, payload.encode(), files=files, content_type=_JSON_CONTENT_TYPE)
return _Response(_OK_STATUS, payload, files=files, content_type=_JSON_CONTENT_TYPE)

_LOGGER.debug(
"Ignoring interaction %s of type %s without registered listener", interaction.id, interaction.type
Expand Down
60 changes: 45 additions & 15 deletions hikari/impl/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ class RESTApp(traits.ExecutorAware):
http_settings : typing.Optional[hikari.impl.config.HTTPSettings]
HTTP settings to use. Sane defaults are used if this is
`None`.
dumps : hikari.internal.data_binding.JSONEncoder
The JSON encoder this application should use. Defaults to `hikari.internal.data_binding.default_json_dumps`.
loads : hikari.internal.data_binding.JSONDecoder
The JSON decoder this application should use. Defaults to `hikari.internal.data_binding.default_json_loads`.
max_rate_limit : float
Maximum number of seconds to sleep for when rate limited. If a rate
limit occurs that is longer than this value, then a
Expand Down Expand Up @@ -292,20 +296,26 @@ class RESTApp(traits.ExecutorAware):
"_url",
"_bucket_manager",
"_client_session",
"_loads",
"_dumps",
)

def __init__(
self,
*,
executor: typing.Optional[concurrent.futures.Executor] = None,
http_settings: typing.Optional[config_impl.HTTPSettings] = None,
dumps: data_binding.JSONEncoder = data_binding.default_json_dumps,
loads: data_binding.JSONDecoder = data_binding.default_json_loads,
max_rate_limit: float = 300.0,
max_retries: int = 3,
proxy_settings: typing.Optional[config_impl.ProxySettings] = None,
url: typing.Optional[str] = None,
) -> None:
self._http_settings = config_impl.HTTPSettings() if http_settings is None else http_settings
self._proxy_settings = config_impl.ProxySettings() if proxy_settings is None else proxy_settings
self._loads = loads
self._dumps = dumps
self._executor = executor
self._max_retries = max_retries
self._url = url
Expand Down Expand Up @@ -428,6 +438,8 @@ def acquire(
http_settings=self._http_settings,
max_retries=self._max_retries,
proxy_settings=self._proxy_settings,
loads=self._loads,
dumps=self._dumps,
token=token,
token_type=token_type,
rest_url=self._url,
Expand Down Expand Up @@ -487,6 +499,10 @@ class RESTClientImpl(rest_api.RESTClient):
max_retries : typing.Optional[int]
Maximum number of times a request will be retried if
it fails with a `5xx` status. Defaults to 3 if set to `None`.
dumps : hikari.internal.data_binding.JSONEncoder
The JSON encoder this application should use. Defaults to `hikari.internal.data_binding.default_json_dumps`.
loads : hikari.internal.data_binding.JSONDecoder
The JSON decoder this application should use. Defaults to `hikari.internal.data_binding.default_json_loads`.
token : typing.Union[str, None, hikari.api.rest.TokenStrategy]
The bot or bearer token. If no token is to be used,
this can be undefined.
Expand Down Expand Up @@ -518,6 +534,8 @@ class RESTClientImpl(rest_api.RESTClient):
"_http_settings",
"_max_retries",
"_proxy_settings",
"_dumps",
"_loads",
"_rest_url",
"_token",
"_token_type",
Expand All @@ -540,6 +558,8 @@ def __init__(
max_rate_limit: float = 300.0,
max_retries: int = 3,
proxy_settings: config_impl.ProxySettings,
dumps: data_binding.JSONEncoder = data_binding.default_json_dumps,
loads: data_binding.JSONDecoder = data_binding.default_json_loads,
token: typing.Union[str, None, rest_api.TokenStrategy],
token_type: typing.Union[applications.TokenType, str, None],
rest_url: typing.Optional[str],
Expand All @@ -562,6 +582,8 @@ def __init__(
self._http_settings = http_settings
self._max_retries = max_retries
self._proxy_settings = proxy_settings
self._dumps = dumps
self._loads = loads
self._bucket_manager = (
buckets_impl.RESTBucketManager(max_rate_limit) if bucket_manager is None else bucket_manager
)
Expand Down Expand Up @@ -723,7 +745,7 @@ async def _perform_request(
*,
query: typing.Optional[data_binding.StringMapBuilder] = None,
form_builder: typing.Optional[data_binding.URLEncodedFormBuilder] = None,
json: typing.Union[data_binding.JSONObjectBuilder, data_binding.JSONArray, None] = None,
json: typing.Union[data_binding.JSONObject, data_binding.JSONArray, None] = None,
reason: undefined.UndefinedOr[str] = undefined.UNDEFINED,
auth: undefined.UndefinedNoneOr[str] = undefined.UNDEFINED,
) -> typing.Union[None, data_binding.JSONObject, data_binding.JSONArray]:
Expand All @@ -749,6 +771,13 @@ async def _perform_request(
if auth:
headers[_AUTHORIZATION_HEADER] = auth

data: typing.Union[None, aiohttp.BytesPayload, aiohttp.FormData] = None
if json is not None:
if form_builder:
raise ValueError("Can only provide one of 'json' or 'form_builder', not both")
davfsa marked this conversation as resolved.
Show resolved Hide resolved

data = data_binding.JSONPayload(json, json_dumps=self._dumps)

url = compiled_route.create_url(self._rest_url)

stack = contextlib.AsyncExitStack()
Expand All @@ -759,7 +788,8 @@ async def _perform_request(

while True:
async with stack:
form = await form_builder.build(stack) if form_builder else None
if form_builder:
data = await form_builder.build(stack, executor=self._executor)

if compiled_route.route.has_ratelimits:
await stack.enter_async_context(self._bucket_manager.acquire_bucket(compiled_route, auth))
Expand All @@ -782,8 +812,7 @@ async def _perform_request(
url,
headers=headers,
params=query,
json=json,
data=form,
data=data,
allow_redirects=self._http_settings.max_redirects is not None,
max_redirects=self._http_settings.max_redirects,
proxy=self._proxy_settings.url,
Expand Down Expand Up @@ -817,7 +846,7 @@ async def _perform_request(
if 200 <= response.status < 300:
if response.content_type == _APPLICATION_JSON:
# Only deserializing here stops Cloudflare shenanigans messing us around.
return data_binding.load_json(await response.read())
return self._loads(await response.read())

real_url = str(response.real_url)
raise errors.HTTPError(f"Expected JSON [{response.content_type=}, {real_url=}]")
Expand Down Expand Up @@ -929,7 +958,8 @@ async def _parse_ratelimits(
f"received rate limited response with unexpected response type {response.content_type}",
)

body = await response.json()
body = self._loads(await response.read())
assert isinstance(body, dict)
body_retry_after = float(body["retry_after"])

if body.get("global", False) is True:
Expand Down Expand Up @@ -1401,7 +1431,7 @@ def _build_message_payload( # noqa: C901- Function too complex
continue

if not form_builder:
form_builder = data_binding.URLEncodedFormBuilder(executor=self._executor)
form_builder = data_binding.URLEncodedFormBuilder()

resource = files.ensure_resource(f)
attachments_payload.append({"id": attachment_id, "filename": resource.filename})
Expand Down Expand Up @@ -1464,7 +1494,7 @@ async def create_message(
body.put("message_reference", message_reference)

if form_builder is not None:
form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON)
form_builder.add_field("payload_json", self._dumps(body), content_type=_APPLICATION_JSON)
response = await self._request(route, form_builder=form_builder)
else:
response = await self._request(route, json=body)
Expand Down Expand Up @@ -1530,7 +1560,7 @@ async def edit_message(
)

if form_builder is not None:
form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON)
form_builder.add_field("payload_json", self._dumps(body), content_type=_APPLICATION_JSON)
response = await self._request(route, form_builder=form_builder)
else:
response = await self._request(route, json=body)
Expand Down Expand Up @@ -1845,7 +1875,7 @@ async def execute_webhook(
body.put("avatar_url", avatar_url, conversion=str)

if form_builder is not None:
form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON)
form_builder.add_field("payload_json", self._dumps(body), content_type=_APPLICATION_JSON)
response = await self._request(route, form_builder=form_builder, query=query, auth=None)
else:
response = await self._request(route, json=body, query=query, auth=None)
Expand Down Expand Up @@ -1923,7 +1953,7 @@ async def edit_webhook_message(
)

if form_builder is not None:
form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON)
form_builder.add_field("payload_json", self._dumps(body), content_type=_APPLICATION_JSON)
response = await self._request(route, form_builder=form_builder, query=query, auth=None)
else:
response = await self._request(route, json=body, query=query, auth=None)
Expand Down Expand Up @@ -2329,7 +2359,7 @@ async def create_sticker(
reason: undefined.UndefinedOr[str] = undefined.UNDEFINED,
) -> stickers.GuildSticker:
route = routes.POST_GUILD_STICKERS.compile(guild=guild)
form = data_binding.URLEncodedFormBuilder(executor=self._executor)
form = data_binding.URLEncodedFormBuilder()
form.add_field("name", name)
form.add_field("tags", tag)
form.add_field("description", description or "")
Expand Down Expand Up @@ -2862,7 +2892,7 @@ async def create_forum_post(
body.put("message", message_body)

if form_builder is not None:
form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON)
form_builder.add_field("payload_json", self._dumps(body), content_type=_APPLICATION_JSON)
response = await self._request(route, form_builder=form_builder, reason=reason)
else:
response = await self._request(route, json=body, reason=reason)
Expand Down Expand Up @@ -3937,7 +3967,7 @@ async def create_interaction_response(
body.put("data", data)

if form is not None:
form.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON)
form.add_field("payload_json", self._dumps(body), content_type=_APPLICATION_JSON)
await self._request(route, form_builder=form, auth=None)
else:
await self._request(route, json=body, auth=None)
Expand Down Expand Up @@ -3985,7 +4015,7 @@ async def edit_interaction_response(
)

if form_builder is not None:
form_builder.add_field("payload_json", data_binding.dump_json(body), content_type=_APPLICATION_JSON)
form_builder.add_field("payload_json", self._dumps(body), content_type=_APPLICATION_JSON)
response = await self._request(route, form_builder=form_builder, auth=None)
else:
response = await self._request(route, json=body, auth=None)
Expand Down
Loading