diff --git a/chytorch/nn/losses.py b/chytorch/nn/losses.py index 417a35b..8bce6c5 100644 --- a/chytorch/nn/losses.py +++ b/chytorch/nn/losses.py @@ -23,7 +23,9 @@ from torch import float32, zeros_like, exp, Tensor from torch.nn import Parameter, MSELoss from torch.nn.modules.loss import _Loss -from torchtyping import TensorType +import torch +from jaxtyping import Bool, Float + class MultiTaskLoss(_Loss): @@ -32,7 +34,7 @@ class MultiTaskLoss(_Loss): https://arxiv.org/abs/1705.07115 """ - def __init__(self, loss_type: TensorType['loss_type', bool], *, reduction='mean'): + def __init__(self, loss_type: Bool[torch.Tensor, "loss_type"], *, reduction='mean'): """ :param loss_type: vector equal to the number of tasks losses. True for regression and False for classification. """ @@ -40,7 +42,7 @@ def __init__(self, loss_type: TensorType['loss_type', bool], *, reduction='mean' self.log = Parameter(zeros_like(loss_type, dtype=float32)) self.register_buffer('coefficient', (loss_type + 1.).to(float32)) - def forward(self, x: TensorType['loss', float]): + def forward(self, x: Float[torch.Tensor, "loss"]): """ :param x: 1d vector of losses or 2d matrix of batch X losses. """ diff --git a/chytorch/nn/molecule/encoder.py b/chytorch/nn/molecule/encoder.py index 130b147..b467a3b 100644 --- a/chytorch/nn/molecule/encoder.py +++ b/chytorch/nn/molecule/encoder.py @@ -22,7 +22,9 @@ # from itertools import repeat from torch.nn import GELU, Module, ModuleList, LayerNorm -from torchtyping import TensorType +import torch +from jaxtyping import Float + from typing import Tuple, Optional, List from warnings import warn from ._embedding import EmbeddingBag @@ -112,9 +114,9 @@ def __init__(self, max_neighbors: int = 14, max_distance: int = 10, d_model: int self._register_load_state_dict_pre_hook(_update) def forward(self, batch: MoleculeDataBatch, /, *, - cache: Optional[List[Tuple[TensorType['batch', 'atoms+conditions', 'embedding'], - TensorType['batch', 'atoms+conditions', 'embedding']]]] = None) -> \ - TensorType['batch', 'atoms', 'embedding']: + cache: Optional[List[Tuple[Float[torch.Tensor, "batch atoms+conditions embedding"], + Float[torch.Tensor, "batch atoms+conditions embedding"]]]] = None) -> \ + Float[torch.Tensor, "batch atoms embedding"]: """ Use 0 for padding. Atoms should be coded by atomic numbers + 2. diff --git a/chytorch/nn/reaction.py b/chytorch/nn/reaction.py index 9ef2c1d..afe8012 100644 --- a/chytorch/nn/reaction.py +++ b/chytorch/nn/reaction.py @@ -23,7 +23,9 @@ from math import inf from torch import zeros_like, float as t_float from torch.nn import Embedding, GELU, Module -from torchtyping import TensorType +import torch +from jaxtyping import Float + from .molecule import MoleculeEncoder from .transformer import EncoderLayer from ..utils.data import ReactionEncoderDataBatch @@ -58,7 +60,7 @@ def max_distance(self): """ return self.molecule_encoder.max_distance - def forward(self, batch: ReactionEncoderDataBatch) -> TensorType['batch', 'atoms', 'embedding']: + def forward(self, batch: ReactionEncoderDataBatch) -> Float[torch.Tensor, "batch atoms embedding"]: """ Use 0 for padding. Roles should be coded by 2 for reactants, 3 for products and 1 for special cls token. Distances - same as molecular encoder distances but batched diagonally. diff --git a/chytorch/nn/voting/binary.py b/chytorch/nn/voting/binary.py index ca56df4..1730a4d 100644 --- a/chytorch/nn/voting/binary.py +++ b/chytorch/nn/voting/binary.py @@ -24,7 +24,9 @@ from torch import sigmoid, no_grad from torch.nn import GELU from torch.nn.functional import binary_cross_entropy_with_logits -from torchtyping import TensorType +import torch +from jaxtyping import Float, Int + from typing import Union, Optional from ._kfold import k_fold_mask from .regressor import VotingRegressor @@ -42,8 +44,8 @@ def __init__(self, ensemble: int = 10, output: int = 1, hidden: int = 256, input layer_norm_eps, loss_function, norm_first) @no_grad() - def predict(self, x: TensorType['batch', 'embedding'], *, - k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]: + def predict(self, x: Float[torch.Tensor, "batch embedding"], *, + k_fold: Optional[int] = None) -> Union[Int[torch.Tensor, "batch"], Int[torch.Tensor, "batch output"]]: """ Average class prediction @@ -53,9 +55,9 @@ def predict(self, x: TensorType['batch', 'embedding'], *, return (self.predict_proba(x, k_fold=k_fold) > .5).long() @no_grad() - def predict_proba(self, x: TensorType['batch', 'embedding'], *, - k_fold: Optional[int] = None) -> Union[TensorType['batch', float], - TensorType['batch', 'output', float]]: + def predict_proba(self, x: Float[torch.Tensor, "batch embedding"], *, + k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch"], + Float[torch.Tensor, "batch output"]]: """ Average probability diff --git a/chytorch/nn/voting/classifier.py b/chytorch/nn/voting/classifier.py index 28b0ef7..a630851 100644 --- a/chytorch/nn/voting/classifier.py +++ b/chytorch/nn/voting/classifier.py @@ -24,7 +24,9 @@ from torch import bmm, no_grad, Tensor from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module from torch.nn.functional import cross_entropy, softmax -from torchtyping import TensorType +import torch +from jaxtyping import Float, Int + from typing import Optional, Union from ._kfold import k_fold_mask @@ -83,8 +85,8 @@ def forward(self, x): return x.view(-1, self._output, self._ensemble, self._n_classes) # B x O x E x C return x # B x E x C - def loss(self, x: TensorType['batch', 'embedding'], - y: Union[TensorType['batch', 1, int], TensorType['batch', 'output', int]], + def loss(self, x: Float[torch.Tensor, "batch embedding"], + y: Union[Int[torch.Tensor, "batch 1 int] TensorType[batch output"]], k_fold: Optional[int] = None, ignore_index: int = -100) -> Tensor: """ Apply loss function to ensemble of predictions. @@ -120,8 +122,8 @@ def loss(self, x: TensorType['batch', 'embedding'], return self.loss_function(p, y) @no_grad() - def predict(self, x: TensorType['batch', 'embedding'], *, - k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]: + def predict(self, x: Float[torch.Tensor, "batch embedding"], *, + k_fold: Optional[int] = None) -> Union[Int[torch.Tensor, "batch"], Int[torch.Tensor, "batch output"]]: """ Average class prediction @@ -130,9 +132,9 @@ def predict(self, x: TensorType['batch', 'embedding'], *, return self.predict_proba(x, k_fold=k_fold).argmax(-1) # B or B x O @no_grad() - def predict_proba(self, x: TensorType['batch', 'embedding'], *, - k_fold: Optional[int] = None) -> Union[TensorType['batch', 'classes', float], - TensorType['batch', 'output', 'classes', float]]: + def predict_proba(self, x: Float[torch.Tensor, "batch embedding"], *, + k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch classes"], + Float[torch.Tensor, "batch output classes"]]: """ Average probability diff --git a/chytorch/nn/voting/regressor.py b/chytorch/nn/voting/regressor.py index 85512dc..0d931d0 100644 --- a/chytorch/nn/voting/regressor.py +++ b/chytorch/nn/voting/regressor.py @@ -24,7 +24,9 @@ from torch import bmm, no_grad, Tensor from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module from torch.nn.functional import smooth_l1_loss -from torchtyping import TensorType +import torch +from jaxtyping import Float + from typing import Optional, Union from ._kfold import k_fold_mask @@ -81,8 +83,8 @@ def forward(self, x): return x.view(-1, self._output, self._ensemble) # B x O x E return x # B x E - def loss(self, x: TensorType['batch', 'embedding'], - y: Union[TensorType['batch', 1, float], TensorType['batch', 'output', float]], + def loss(self, x: Float[torch.Tensor, "batch embedding"], + y: Union[Float[torch.Tensor, "batch 1 float] TensorType[batch output"]], k_fold: Optional[int] = None) -> Tensor: """ Apply loss function to ensemble of predictions. @@ -110,9 +112,9 @@ def loss(self, x: TensorType['batch', 'embedding'], return self.loss_function(p, y) @no_grad() - def predict(self, x: TensorType['batch', 'embedding'], *, - k_fold: Optional[int] = None) -> Union[TensorType['batch', float], - TensorType['batch', 'output', float]]: + def predict(self, x: Float[torch.Tensor, "batch embedding"], *, + k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch"], + Float[torch.Tensor, "batch output"]]: """ Average prediction diff --git a/chytorch/utils/data/molecule/conformer.py b/chytorch/utils/data/molecule/conformer.py index abdf887..d380a65 100644 --- a/chytorch/utils/data/molecule/conformer.py +++ b/chytorch/utils/data/molecule/conformer.py @@ -28,20 +28,22 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate_fn_map -from torchtyping import TensorType +import torch +from jaxtyping import Int + from typing import Sequence, Tuple, Union, NamedTuple class ConformerDataPoint(NamedTuple): - atoms: TensorType['atoms', int] - hydrogens: TensorType['atoms', int] - distances: TensorType['atoms', 'atoms', int] + atoms: Int[torch.Tensor, "atoms"] + hydrogens: Int[torch.Tensor, "atoms"] + distances: Int[torch.Tensor, "atoms atoms"] class ConformerDataBatch(NamedTuple): - atoms: TensorType['batch', 'atoms', int] - hydrogens: TensorType['batch', 'atoms', int] - distances: TensorType['batch', 'atoms', 'atoms', int] + atoms: Int[torch.Tensor, "batch atoms"] + hydrogens: Int[torch.Tensor, "batch atoms"] + distances: Int[torch.Tensor, "batch atoms atoms"] def to(self, *args, **kwargs): return ConformerDataBatch(*(x.to(*args, **kwargs) for x in self)) diff --git a/chytorch/utils/data/molecule/encoder.py b/chytorch/utils/data/molecule/encoder.py index 1f48976..91f2212 100644 --- a/chytorch/utils/data/molecule/encoder.py +++ b/chytorch/utils/data/molecule/encoder.py @@ -28,21 +28,23 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate_fn_map -from torchtyping import TensorType +import torch +from jaxtyping import Int + from typing import Sequence, Union, NamedTuple, Tuple from zlib import decompress class MoleculeDataPoint(NamedTuple): - atoms: TensorType['atoms', int] - neighbors: TensorType['atoms', int] - distances: TensorType['atoms', 'atoms', int] + atoms: Int[torch.Tensor, "atoms"] + neighbors: Int[torch.Tensor, "atoms"] + distances: Int[torch.Tensor, "atoms atoms"] class MoleculeDataBatch(NamedTuple): - atoms: TensorType['batch', 'atoms', int] - neighbors: TensorType['batch', 'atoms', int] - distances: TensorType['batch', 'atoms', 'atoms', int] + atoms: Int[torch.Tensor, "batch atoms"] + neighbors: Int[torch.Tensor, "batch atoms"] + distances: Int[torch.Tensor, "batch atoms atoms"] def to(self, *args, **kwargs): return MoleculeDataBatch(*(x.to(*args, **kwargs) for x in self)) @@ -92,7 +94,7 @@ def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None) class MoleculeDataset(Dataset): def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *, add_cls: bool = True, cls_token: Union[int, Tuple[int, ...], Sequence[int], Sequence[Tuple[int, ...]], - TensorType['cls', int], TensorType['dataset', 1, int], TensorType['dataset', 'cls', int]] = 1, + Int[torch.Tensor, "cls"], Int[torch.Tensor, "dataset 1 int] TensorType[dataset cls"]] = 1, max_distance: int = 10, max_neighbors: int = 14, attention_schema: str = 'bert', components_attention: bool = True, unpack: bool = False, compressed: bool = True, distance_cutoff=None): diff --git a/chytorch/utils/data/reaction/encoder.py b/chytorch/utils/data/reaction/encoder.py index 219acbd..5fb522a 100644 --- a/chytorch/utils/data/reaction/encoder.py +++ b/chytorch/utils/data/reaction/encoder.py @@ -26,23 +26,25 @@ from torch.nn.utils.rnn import pad_sequence from torch.utils.data import Dataset from torch.utils.data._utils.collate import default_collate_fn_map -from torchtyping import TensorType +import torch +from jaxtyping import Int + from typing import Sequence, Union, NamedTuple from ..molecule import MoleculeDataset class ReactionEncoderDataPoint(NamedTuple): - atoms: TensorType['atoms', int] - neighbors: TensorType['atoms', int] - distances: TensorType['atoms', 'atoms', int] - roles: TensorType['atoms', int] + atoms: Int[torch.Tensor, "atoms"] + neighbors: Int[torch.Tensor, "atoms"] + distances: Int[torch.Tensor, "atoms atoms"] + roles: Int[torch.Tensor, "atoms"] class ReactionEncoderDataBatch(NamedTuple): - atoms: TensorType['batch', 'atoms', int] - neighbors: TensorType['batch', 'atoms', int] - distances: TensorType['batch', 'atoms', 'atoms', int] - roles: TensorType['batch', 'atoms', int] + atoms: Int[torch.Tensor, "batch atoms"] + neighbors: Int[torch.Tensor, "batch atoms"] + distances: Int[torch.Tensor, "batch atoms atoms"] + roles: Int[torch.Tensor, "batch atoms"] def to(self, *args, **kwargs): return ReactionEncoderDataBatch(*(x.to(*args, **kwargs) for x in self))