Creating instances of jaxtyped dataclasses is slow #232
Comments
I do not know if this helps, but quick profiling revealed that most of the time is spent in the |
Hmm, is this not just the overhead from doing the actual type checking itself? For what it's worth I don't think we currently respect |
I am not very used to the actual |
Did a bit of line profiling on the above example, see below.
Edit: hold on: isn't the actual type checking call just 3.6% of the time (the last line)? In which case, can't we simply cache this

Timer unit: 1e-06 s
Total time: 0.201161 s
Function: _check_dataclass_annotations at line 1

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
     1                                           def _check_dataclass_annotations(self, typechecker):
     2                                               """Creates and calls a function that checks the attributes of `self`
     3
     4                                               `self` should be a dataclass instance. `typechecker` should be e.g.
     5                                               `beartype.beartype` or `typeguard.typechecked`.
     6                                               """
     7      1000       1748.0      1.7      0.9      parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
     8      1000        113.0      0.1      0.1      values = {}
     9      2000       1397.0      0.7      0.7      for field in dataclasses.fields(self):
    10      1000        124.0      0.1      0.1          annotation = field.type
    11      1000        161.0      0.2      0.1          if isinstance(annotation, str):
    12                                                       # Don't check stringified annotations. These are basically impossible to
    13                                                       # resolve correctly, so just skip them.
    14                                                       continue
    15      1000        742.0      0.7      0.4          if get_origin(annotation) is type:
    16                                                       args = get_args(annotation)
    17                                                       if len(args) == 1 and isinstance(args[0], str):
    18                                                           # We also special-case this one kind of partially-stringified type
    19                                                           # annotation, so as to support Equinox <v0.11.1.
    20                                                           # This was fixed in Equinox in
    21                                                           # https://github.com/patrick-kidger/equinox/pull/543
    22                                                           continue
    23      1000         75.0      0.1      0.0          try:
    24      1000        537.0      0.5      0.3              value = getattr(self, field.name)  # noqa: F841
    25      1000        117.0      0.1      0.1          except AttributeError:
    26      1000        109.0      0.1      0.1              continue  # allow uninitialised fields, which are allowed on dataclasses
    27
    28                                                   parameters.append(
    29                                                       inspect.Parameter(
    30                                                           field.name,
    31                                                           inspect.Parameter.POSITIONAL_OR_KEYWORD,
    32                                                           annotation=field.type,
    33                                                       )
    34                                                   )
    35                                                   values[field.name] = value
    36
    37      1000       1853.0      1.9      0.9      signature = inspect.Signature(parameters)
    38      2000      22255.0     11.1     11.1      f = _make_fn_with_signature(
    39      1000        152.0      0.2      0.1          self.__class__.__name__,
    40      1000        137.0      0.1      0.1          self.__class__.__qualname__,
    41      1000        137.0      0.1      0.1          self.__class__.__module__,
    42      1000         73.0      0.1      0.0          signature,
    43      1000         70.0      0.1      0.0          output=False,
    44                                               )
    45      1000     164024.0    164.0     81.5      f = jaxtyped(f, typechecker=typechecker)
    46      1000       7337.0      7.3      3.6      f(self, **values)
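To put the numbers from the table in perspective, here is a quick back-of-the-envelope calculation. It only reuses the timings shown above. The speedup is an upper bound under the assumption that both the `_make_fn_with_signature(...)` call and the `jaxtyped(...)` wrapping could be served from a cache at zero cost, which a real cache would not quite achieve.

```python
# Times in microseconds for 1000 calls, copied from the profile above.
total = 201161.0
make_fn = 22255.0   # f = _make_fn_with_signature(...)
wrap = 164024.0     # f = jaxtyped(f, typechecker=typechecker)
check = 7337.0      # f(self, **values), the actual type check


def share(t):
    """Percentage of the total profiled time."""
    return 100.0 * t / total


# Assumption: both function-building steps are fully cacheable.
cacheable = make_fn + wrap
remaining = total - cacheable
speedup_upper_bound = total / remaining

print(f"wrapping: {share(wrap):.1f}%, checking: {share(check):.1f}%")
print(f"cacheable share: {share(cacheable):.1f}%")
print(f"upper bound on speedup: {speedup_upper_bound:.1f}x")
```

So building and wrapping the checker function accounts for over 90% of the profiled time, while the check itself stays at 3.6%.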
Oh interesting! Thank you for profiling this -- I agree, caching sounds reasonable. I'd be happy to take a PR on this!
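For illustration, the caching idea could look roughly like the sketch below. This is not jaxtyping's actual code: `build_checker` is a hypothetical stand-in for the expensive `_make_fn_with_signature(...)` plus `jaxtyped(...)` step, and it does a plain `isinstance` check so the example runs with the standard library only. The cache key is assumed to be the class plus the tuple of fields that are actually checked, since the generated signature depends on which fields are initialised. A real version would presumably also need the `typechecker` in the key.

```python
import dataclasses
import functools

build_count = 0  # how many times the expensive build step actually ran


@functools.lru_cache(maxsize=None)
def build_checker(cls, field_names):
    """Hypothetical stand-in for `_make_fn_with_signature` + `jaxtyped(...)`.

    Cached on (class, checked field names), so it runs once per distinct key.
    """
    global build_count
    build_count += 1
    expected = {f.name: f.type for f in dataclasses.fields(cls)}

    def check(instance, **values):
        for name in field_names:
            hint = expected[name]
            # Toy check: plain isinstance, only for annotations that are classes.
            if isinstance(hint, type) and not isinstance(values[name], hint):
                raise TypeError(f"{cls.__name__}.{name}: expected {hint.__name__}")

    return check


def check_dataclass_annotations(instance):
    # Mirror the original loop: skip stringified annotations and unset fields.
    names = tuple(
        f.name
        for f in dataclasses.fields(instance)
        if not isinstance(f.type, str) and hasattr(instance, f.name)
    )
    values = {name: getattr(instance, name) for name in names}
    build_checker(type(instance), names)(instance, **values)


@dataclasses.dataclass
class Point:
    x: int
    y: float


for i in range(1000):
    check_dataclass_annotations(Point(i, 0.5))

print(build_count)  # the checker was built once, not 1000 times
```

With this shape, only the cheap per-instance work (collecting values and calling the cached checker) is repeated for every instance.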
Annotating a `dataclass` with `@jaxtyped` makes creating instances of that class ~1000x slower. This is especially problematic in cases where the entire package is jaxtyped with `install_import_hook()`, because it is not possible to exclude a frequently used dataclass from being jaxtyped.

Here is a small benchmark:
Output: