Skip to content

Commit 29d2493

Browse files
committed
Convert torchtyping to jaxtyping for PyTorch>=2.4
1 parent 46c41e5 commit 29d2493

File tree

10 files changed

+71
-53
lines changed

10 files changed

+71
-53
lines changed

chytorch/nn/losses.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Bool, Float
24
#
35
# Copyright 2023, 2024 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -23,7 +25,7 @@
2325
from torch import float32, zeros_like, exp, Tensor
2426
from torch.nn import Parameter, MSELoss
2527
from torch.nn.modules.loss import _Loss
26-
from torchtyping import TensorType
28+
2729

2830

2931
class MultiTaskLoss(_Loss):
@@ -32,15 +34,15 @@ class MultiTaskLoss(_Loss):
3234
3335
https://arxiv.org/abs/1705.07115
3436
"""
35-
def __init__(self, loss_type: TensorType['loss_type', bool], *, reduction='mean'):
37+
def __init__(self, loss_type: Bool[torch.Tensor, "loss_type"], *, reduction='mean'):
3638
"""
3739
:param loss_type: vector equal to the number of tasks losses. True for regression and False for classification.
3840
"""
3941
super().__init__(reduction=reduction)
4042
self.log = Parameter(zeros_like(loss_type, dtype=float32))
4143
self.register_buffer('coefficient', (loss_type + 1.).to(float32))
4244

43-
def forward(self, x: TensorType['loss', float]):
45+
def forward(self, x: Float[torch.Tensor, "loss"]):
4446
"""
4547
:param x: 1d vector of losses or 2d matrix of batch X losses.
4648
"""

chytorch/nn/molecule/encoder.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Float
24
#
35
# Copyright 2021-2024 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -22,7 +24,7 @@
2224
#
2325
from itertools import repeat
2426
from torch.nn import GELU, Module, ModuleList, LayerNorm
25-
from torchtyping import TensorType
27+
2628
from typing import Tuple, Optional, List
2729
from warnings import warn
2830
from ._embedding import EmbeddingBag
@@ -118,9 +120,9 @@ def __init__(self, max_neighbors: int = 14, max_distance: int = 10, d_model: int
118120
self._register_load_state_dict_pre_hook(_update)
119121

120122
def forward(self, batch: MoleculeDataBatch, /, *,
121-
cache: Optional[List[Tuple[TensorType['batch', 'atoms+conditions', 'embedding'],
122-
TensorType['batch', 'atoms+conditions', 'embedding']]]] = None) -> \
123-
TensorType['batch', 'atoms', 'embedding']:
123+
cache: Optional[List[Tuple[Float[torch.Tensor, "batch atoms+conditions embedding"],
124+
Float[torch.Tensor, "batch atoms+conditions embedding"]]]] = None) -> \
125+
Float[torch.Tensor, "batch atoms embedding"]:
124126
"""
125127
Use 0 for padding.
126128
Atoms should be coded by atomic numbers + 2.

chytorch/nn/reaction/encoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Float
24
#
35
# Copyright 2021-2023 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -23,7 +25,7 @@
2325
from math import inf
2426
from torch import zeros_like, float as t_float
2527
from torch.nn import Embedding, GELU, Module
26-
from torchtyping import TensorType
28+
2729
from ..molecule import MoleculeEncoder
2830
from ..transformer import EncoderLayer
2931
from ...utils.data import ReactionEncoderDataBatch
@@ -58,7 +60,7 @@ def max_distance(self):
5860
"""
5961
return self.molecule_encoder.max_distance
6062

61-
def forward(self, batch: ReactionEncoderDataBatch) -> TensorType['batch', 'atoms', 'embedding']:
63+
def forward(self, batch: ReactionEncoderDataBatch) -> Float[torch.Tensor, "batch atoms embedding"]:
6264
"""
6365
Use 0 for padding. Roles should be coded by 2 for reactants, 3 for products and 1 for special cls token.
6466
Distances - same as molecular encoder distances but batched diagonally.

chytorch/nn/voting/binary.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Float, Int
24
#
35
# Copyright 2022, 2023 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -24,7 +26,7 @@
2426
from torch import sigmoid, no_grad
2527
from torch.nn import GELU
2628
from torch.nn.functional import binary_cross_entropy_with_logits
27-
from torchtyping import TensorType
29+
2830
from typing import Union, Optional
2931
from ._kfold import k_fold_mask
3032
from .regressor import VotingRegressor
@@ -42,8 +44,8 @@ def __init__(self, ensemble: int = 10, output: int = 1, hidden: int = 256, input
4244
layer_norm_eps, loss_function, norm_first)
4345

4446
@no_grad()
45-
def predict(self, x: TensorType['batch', 'embedding'], *,
46-
k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
47+
def predict(self, x: Float[torch.Tensor, "batch embedding"], *,
48+
k_fold: Optional[int] = None) -> Union[Int[torch.Tensor, "batch"], Int[torch.Tensor, "batch output"]]:
4749
"""
4850
Average class prediction
4951
@@ -53,9 +55,9 @@ def predict(self, x: TensorType['batch', 'embedding'], *,
5355
return (self.predict_proba(x, k_fold=k_fold) > .5).long()
5456

5557
@no_grad()
56-
def predict_proba(self, x: TensorType['batch', 'embedding'], *,
57-
k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
58-
TensorType['batch', 'output', float]]:
58+
def predict_proba(self, x: Float[torch.Tensor, "batch embedding"], *,
59+
k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch"],
60+
Float[torch.Tensor, "batch output"]]:
5961
"""
6062
Average probability
6163

chytorch/nn/voting/classifier.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Float, Int
24
#
35
# Copyright 2022, 2023 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -24,7 +26,7 @@
2426
from torch import bmm, no_grad, Tensor
2527
from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module
2628
from torch.nn.functional import cross_entropy, softmax
27-
from torchtyping import TensorType
29+
2830
from typing import Optional, Union
2931
from ._kfold import k_fold_mask
3032

@@ -83,8 +85,8 @@ def forward(self, x):
8385
return x.view(-1, self._output, self._ensemble, self._n_classes) # B x O x E x C
8486
return x # B x E x C
8587

86-
def loss(self, x: TensorType['batch', 'embedding'],
87-
y: Union[TensorType['batch', 1, int], TensorType['batch', 'output', int]],
88+
def loss(self, x: Float[torch.Tensor, "batch embedding"],
89+
y: Union[Int[torch.Tensor, "batch 1 int] TensorType[batch output"]],
8890
k_fold: Optional[int] = None, ignore_index: int = -100) -> Tensor:
8991
"""
9092
Apply loss function to ensemble of predictions.
@@ -120,8 +122,8 @@ def loss(self, x: TensorType['batch', 'embedding'],
120122
return self.loss_function(p, y)
121123

122124
@no_grad()
123-
def predict(self, x: TensorType['batch', 'embedding'], *,
124-
k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
125+
def predict(self, x: Float[torch.Tensor, "batch embedding"], *,
126+
k_fold: Optional[int] = None) -> Union[Int[torch.Tensor, "batch"], Int[torch.Tensor, "batch output"]]:
125127
"""
126128
Average class prediction
127129
@@ -130,9 +132,9 @@ def predict(self, x: TensorType['batch', 'embedding'], *,
130132
return self.predict_proba(x, k_fold=k_fold).argmax(-1) # B or B x O
131133

132134
@no_grad()
133-
def predict_proba(self, x: TensorType['batch', 'embedding'], *,
134-
k_fold: Optional[int] = None) -> Union[TensorType['batch', 'classes', float],
135-
TensorType['batch', 'output', 'classes', float]]:
135+
def predict_proba(self, x: Float[torch.Tensor, "batch embedding"], *,
136+
k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch classes"],
137+
Float[torch.Tensor, "batch output classes"]]:
136138
"""
137139
Average probability
138140

chytorch/nn/voting/regressor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Float
24
#
35
# Copyright 2022, 2023 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -24,7 +26,7 @@
2426
from torch import bmm, no_grad, Tensor
2527
from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module
2628
from torch.nn.functional import smooth_l1_loss
27-
from torchtyping import TensorType
29+
2830
from typing import Optional, Union
2931
from ._kfold import k_fold_mask
3032

@@ -81,8 +83,8 @@ def forward(self, x):
8183
return x.view(-1, self._output, self._ensemble) # B x O x E
8284
return x # B x E
8385

84-
def loss(self, x: TensorType['batch', 'embedding'],
85-
y: Union[TensorType['batch', 1, float], TensorType['batch', 'output', float]],
86+
def loss(self, x: Float[torch.Tensor, "batch embedding"],
87+
y: Union[Float[torch.Tensor, "batch 1 float] TensorType[batch output"]],
8688
k_fold: Optional[int] = None) -> Tensor:
8789
"""
8890
Apply loss function to ensemble of predictions.
@@ -110,9 +112,9 @@ def loss(self, x: TensorType['batch', 'embedding'],
110112
return self.loss_function(p, y)
111113

112114
@no_grad()
113-
def predict(self, x: TensorType['batch', 'embedding'], *,
114-
k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
115-
TensorType['batch', 'output', float]]:
115+
def predict(self, x: Float[torch.Tensor, "batch embedding"], *,
116+
k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch"],
117+
Float[torch.Tensor, "batch output"]]:
116118
"""
117119
Average prediction
118120

chytorch/utils/data/molecule/conformer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Int
24
#
35
# Copyright 2022-2024 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -27,21 +29,21 @@
2729
from torch import IntTensor, Size, zeros, ones as t_ones, int32 as t_int32, eye
2830
from torch.nn.utils.rnn import pad_sequence
2931
from torch.utils.data import Dataset
30-
from torchtyping import TensorType
32+
3133
from typing import Sequence, Tuple, Union, NamedTuple
3234
from .._abc import default_collate_fn_map
3335

3436

3537
class ConformerDataPoint(NamedTuple):
36-
atoms: TensorType['atoms', int]
37-
hydrogens: TensorType['atoms', int]
38-
distances: TensorType['atoms', 'atoms', int]
38+
atoms: Int[torch.Tensor, "atoms"]
39+
hydrogens: Int[torch.Tensor, "atoms"]
40+
distances: Int[torch.Tensor, "atoms atoms"]
3941

4042

4143
class ConformerDataBatch(NamedTuple):
42-
atoms: TensorType['batch', 'atoms', int]
43-
hydrogens: TensorType['batch', 'atoms', int]
44-
distances: TensorType['batch', 'atoms', 'atoms', int]
44+
atoms: Int[torch.Tensor, "batch atoms"]
45+
hydrogens: Int[torch.Tensor, "batch atoms"]
46+
distances: Int[torch.Tensor, "batch atoms atoms"]
4547

4648
def to(self, *args, **kwargs):
4749
return ConformerDataBatch(*(x.to(*args, **kwargs) for x in self))

chytorch/utils/data/molecule/encoder.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Int
24
#
35
# Copyright 2021-2024 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -26,22 +28,22 @@
2628
from torch import IntTensor, Size, int32, ones, zeros, eye, empty, full
2729
from torch.nn.utils.rnn import pad_sequence
2830
from torch.utils.data import Dataset
29-
from torchtyping import TensorType
31+
3032
from typing import Sequence, Union, NamedTuple, Optional, Tuple
3133
from zlib import decompress
3234
from .._abc import default_collate_fn_map
3335

3436

3537
class MoleculeDataPoint(NamedTuple):
36-
atoms: TensorType['atoms', int]
37-
neighbors: TensorType['atoms', int]
38-
distances: TensorType['atoms', 'atoms', int]
38+
atoms: Int[torch.Tensor, "atoms"]
39+
neighbors: Int[torch.Tensor, "atoms"]
40+
distances: Int[torch.Tensor, "atoms atoms"]
3941

4042

4143
class MoleculeDataBatch(NamedTuple):
42-
atoms: TensorType['batch', 'atoms', int]
43-
neighbors: TensorType['batch', 'atoms', int]
44-
distances: TensorType['batch', 'atoms', 'atoms', int]
44+
atoms: Int[torch.Tensor, "batch atoms"]
45+
neighbors: Int[torch.Tensor, "batch atoms"]
46+
distances: Int[torch.Tensor, "batch atoms atoms"]
4547

4648
def to(self, *args, **kwargs):
4749
return MoleculeDataBatch(*(x.to(*args, **kwargs) for x in self))

chytorch/utils/data/reaction/encoder.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import torch
3+
from jaxtyping import Int
24
#
35
# Copyright 2021-2023 Ramil Nugmanov <nougmanoff@protonmail.com>
46
#
@@ -25,24 +27,24 @@
2527
from torch import IntTensor, cat, zeros, int32, Size, eye
2628
from torch.nn.utils.rnn import pad_sequence
2729
from torch.utils.data import Dataset
28-
from torchtyping import TensorType
30+
2931
from typing import Sequence, Union, NamedTuple
3032
from ..molecule import MoleculeDataset
3133
from .._abc import default_collate_fn_map
3234

3335

3436
class ReactionEncoderDataPoint(NamedTuple):
35-
atoms: TensorType['atoms', int]
36-
neighbors: TensorType['atoms', int]
37-
distances: TensorType['atoms', 'atoms', int]
38-
roles: TensorType['atoms', int]
37+
atoms: Int[torch.Tensor, "atoms"]
38+
neighbors: Int[torch.Tensor, "atoms"]
39+
distances: Int[torch.Tensor, "atoms atoms"]
40+
roles: Int[torch.Tensor, "atoms"]
3941

4042

4143
class ReactionEncoderDataBatch(NamedTuple):
42-
atoms: TensorType['batch', 'atoms', int]
43-
neighbors: TensorType['batch', 'atoms', int]
44-
distances: TensorType['batch', 'atoms', 'atoms', int]
45-
roles: TensorType['batch', 'atoms', int]
44+
atoms: Int[torch.Tensor, "batch atoms"]
45+
neighbors: Int[torch.Tensor, "batch atoms"]
46+
distances: Int[torch.Tensor, "batch atoms atoms"]
47+
roles: Int[torch.Tensor, "batch atoms"]
4648

4749
def to(self, *args, **kwargs):
4850
return ReactionEncoderDataBatch(*(x.to(*args, **kwargs) for x in self))

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ include = [
2929

3030
[tool.poetry.dependencies]
3131
python = '>=3.8,<3.12'
32-
torchtyping = '^0.1.4'
32+
jaxtyping = '^0.3.2'
3333
chython = '^1.70'
3434
scipy = '^1.10'
3535
torch = '>=1.8'

0 commit comments

Comments
 (0)