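This PR swaps torchtyping annotations for jaxtyping throughout. The mapping is mechanical: the dtype moves into a wrapper class (Float, Int, Bool) and the axis names collapse into a single space-separated string. A minimal sketch of the correspondence, using an illustrative function that is not part of the diff:

```python
import torch
from jaxtyping import Float, Int

# torchtyping (removed):  x: TensorType['batch', 'embedding', float]
# jaxtyping  (added):     x: Float[torch.Tensor, "batch embedding"]
def pick_class(x: Float[torch.Tensor, "batch embedding"]) -> Int[torch.Tensor, "batch"]:
    # argmax over the last axis yields an integer tensor, one value per batch row
    return x.argmax(-1)
```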
8 changes: 5 additions & 3 deletions chytorch/nn/losses.py
@@ -23,7 +23,9 @@
from torch import float32, zeros_like, exp, Tensor
from torch.nn import Parameter, MSELoss
from torch.nn.modules.loss import _Loss
from torchtyping import TensorType
import torch
from jaxtyping import Bool, Float



class MultiTaskLoss(_Loss):
@@ -32,15 +34,15 @@ class MultiTaskLoss(_Loss):

https://arxiv.org/abs/1705.07115
"""
def __init__(self, loss_type: TensorType['loss_type', bool], *, reduction='mean'):
def __init__(self, loss_type: Bool[torch.Tensor, "loss_type"], *, reduction='mean'):
"""
:param loss_type: vector with one element per task loss: True for regression, False for classification.
"""
super().__init__(reduction=reduction)
self.log = Parameter(zeros_like(loss_type, dtype=float32))
self.register_buffer('coefficient', (loss_type + 1.).to(float32))

def forward(self, x: TensorType['loss', float]):
def forward(self, x: Float[torch.Tensor, "loss"]):
"""
:param x: 1d vector of losses or 2d matrix of batch X losses.
"""
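For context, a usage sketch of the retyped MultiTaskLoss: loss_type flags each task as regression (True) or classification (False), and forward consumes the stacked per-task losses. The call pattern below is inferred from the docstrings, not taken from this diff:

```python
import torch
from chytorch.nn.losses import MultiTaskLoss

# Two tasks: one regression (True), one classification (False).
mtl = MultiTaskLoss(torch.tensor([True, False]))

per_task = torch.tensor([0.73, 1.21])  # e.g. an MSE value and a cross-entropy value
total = mtl(per_task)                  # weighted total under reduction='mean'
```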
10 changes: 6 additions & 4 deletions chytorch/nn/molecule/encoder.py
@@ -22,7 +22,9 @@
#
from itertools import repeat
from torch.nn import GELU, Module, ModuleList, LayerNorm
from torchtyping import TensorType
import torch
from jaxtyping import Float

from typing import Tuple, Optional, List
from warnings import warn
from ._embedding import EmbeddingBag
@@ -112,9 +114,9 @@ def __init__(self, max_neighbors: int = 14, max_distance: int = 10, d_model: int
self._register_load_state_dict_pre_hook(_update)

def forward(self, batch: MoleculeDataBatch, /, *,
cache: Optional[List[Tuple[TensorType['batch', 'atoms+conditions', 'embedding'],
TensorType['batch', 'atoms+conditions', 'embedding']]]] = None) -> \
TensorType['batch', 'atoms', 'embedding']:
cache: Optional[List[Tuple[Float[torch.Tensor, "batch atoms+conditions embedding"],
Float[torch.Tensor, "batch atoms+conditions embedding"]]]] = None) -> \
Float[torch.Tensor, "batch atoms embedding"]:
"""
Use 0 for padding.
Atoms should be coded by atomic numbers + 2.
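Worth noting: jaxtyping annotations such as the forward signature above are inert by default; shapes are only verified under an explicit runtime typechecker. A sketch of opt-in checking, assuming beartype, which is not a dependency introduced by this diff:

```python
import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped

@jaxtyped(typechecker=beartype)
def pool(x: Float[torch.Tensor, "batch atoms embedding"]) -> Float[torch.Tensor, "batch embedding"]:
    # mean over the atoms axis; axis counts and shared names are checked per call
    return x.mean(1)

pool(torch.randn(2, 5, 8))   # passes
# pool(torch.randn(2, 5))    # raises: expected 3 axes, got 2
```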
6 changes: 4 additions & 2 deletions chytorch/nn/reaction.py
@@ -23,7 +23,9 @@
from math import inf
from torch import zeros_like, float as t_float
from torch.nn import Embedding, GELU, Module
from torchtyping import TensorType
import torch
from jaxtyping import Float

from .molecule import MoleculeEncoder
from .transformer import EncoderLayer
from ..utils.data import ReactionEncoderDataBatch
@@ -58,7 +60,7 @@ def max_distance(self):
"""
return self.molecule_encoder.max_distance

def forward(self, batch: ReactionEncoderDataBatch) -> TensorType['batch', 'atoms', 'embedding']:
def forward(self, batch: ReactionEncoderDataBatch) -> Float[torch.Tensor, "batch atoms embedding"]:
"""
Use 0 for padding. Roles should be coded by 2 for reactants, 3 for products and 1 for special cls token.
Distances: same as molecular encoder distances, but batched diagonally.
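The role coding described in the docstring, written out as a concrete (hypothetical) batch:

```python
import torch

# 1 = special cls token, 2 = reactant atoms, 3 = product atoms, 0 = padding
roles = torch.tensor([[1, 2, 2, 3, 3, 0]])  # Int[torch.Tensor, "batch atoms"]
```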
14 changes: 8 additions & 6 deletions chytorch/nn/voting/binary.py
@@ -24,7 +24,9 @@
from torch import sigmoid, no_grad
from torch.nn import GELU
from torch.nn.functional import binary_cross_entropy_with_logits
from torchtyping import TensorType
import torch
from jaxtyping import Float, Int

from typing import Union, Optional
from ._kfold import k_fold_mask
from .regressor import VotingRegressor
@@ -42,8 +44,8 @@ def __init__(self, ensemble: int = 10, output: int = 1, hidden: int = 256, input
layer_norm_eps, loss_function, norm_first)

@no_grad()
def predict(self, x: TensorType['batch', 'embedding'], *,
k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
def predict(self, x: Float[torch.Tensor, "batch embedding"], *,
k_fold: Optional[int] = None) -> Union[Int[torch.Tensor, "batch"], Int[torch.Tensor, "batch output"]]:
"""
Average class prediction

@@ -53,9 +55,9 @@ def predict(self, x: TensorType['batch', 'embedding'], *,
return (self.predict_proba(x, k_fold=k_fold) > .5).long()

@no_grad()
def predict_proba(self, x: TensorType['batch', 'embedding'], *,
k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
TensorType['batch', 'output', float]]:
def predict_proba(self, x: Float[torch.Tensor, "batch embedding"], *,
k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch"],
Float[torch.Tensor, "batch output"]]:
"""
Average probability

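predict here simply thresholds predict_proba at 0.5 and casts to long, so the Float shapes of the probability output carry over to the Int return. That step in isolation:

```python
import torch

proba = torch.tensor([0.20, 0.80, 0.55])   # Float[torch.Tensor, "batch"]
labels = (proba > .5).long()               # Int[torch.Tensor, "batch"] -> tensor([0, 1, 1])
```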
18 changes: 10 additions & 8 deletions chytorch/nn/voting/classifier.py
@@ -24,7 +24,9 @@
from torch import bmm, no_grad, Tensor
from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module
from torch.nn.functional import cross_entropy, softmax
from torchtyping import TensorType
import torch
from jaxtyping import Float, Int

from typing import Optional, Union
from ._kfold import k_fold_mask

@@ -83,8 +85,8 @@ def forward(self, x):
return x.view(-1, self._output, self._ensemble, self._n_classes) # B x O x E x C
return x # B x E x C

def loss(self, x: TensorType['batch', 'embedding'],
y: Union[TensorType['batch', 1, int], TensorType['batch', 'output', int]],
def loss(self, x: Float[torch.Tensor, "batch embedding"],
y: Union[Int[torch.Tensor, "batch 1"], Int[torch.Tensor, "batch output"]],
k_fold: Optional[int] = None, ignore_index: int = -100) -> Tensor:
"""
Apply loss function to ensemble of predictions.
@@ -120,8 +122,8 @@ def loss(self, x: TensorType['batch', 'embedding'],
return self.loss_function(p, y)

@no_grad()
def predict(self, x: TensorType['batch', 'embedding'], *,
k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
def predict(self, x: Float[torch.Tensor, "batch embedding"], *,
k_fold: Optional[int] = None) -> Union[Int[torch.Tensor, "batch"], Int[torch.Tensor, "batch output"]]:
"""
Average class prediction

@@ -130,9 +132,9 @@ def predict(self, x: TensorType['batch', 'embedding'], *,
return self.predict_proba(x, k_fold=k_fold).argmax(-1) # B or B x O

@no_grad()
def predict_proba(self, x: TensorType['batch', 'embedding'], *,
k_fold: Optional[int] = None) -> Union[TensorType['batch', 'classes', float],
TensorType['batch', 'output', 'classes', float]]:
def predict_proba(self, x: Float[torch.Tensor, "batch embedding"], *,
k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch classes"],
Float[torch.Tensor, "batch output classes"]]:
"""
Average probability

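The repaired Union on y in loss admits either a single shared label column or one column per output head; a shape sketch, with the model construction elided and sizes chosen arbitrarily:

```python
import torch

batch, output, n_classes = 4, 3, 2
x = torch.randn(batch, 128)                          # Float[torch.Tensor, "batch embedding"]
y_single = torch.randint(n_classes, (batch, 1))      # Int[torch.Tensor, "batch 1"]
y_multi = torch.randint(n_classes, (batch, output))  # Int[torch.Tensor, "batch output"]
# model.loss(x, y_single) or model.loss(x, y_multi); entries equal to -100 are ignored
```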
14 changes: 8 additions & 6 deletions chytorch/nn/voting/regressor.py
@@ -24,7 +24,9 @@
from torch import bmm, no_grad, Tensor
from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module
from torch.nn.functional import smooth_l1_loss
from torchtyping import TensorType
import torch
from jaxtyping import Float

from typing import Optional, Union
from ._kfold import k_fold_mask

@@ -81,8 +83,8 @@ def forward(self, x):
return x.view(-1, self._output, self._ensemble) # B x O x E
return x # B x E

def loss(self, x: TensorType['batch', 'embedding'],
y: Union[TensorType['batch', 1, float], TensorType['batch', 'output', float]],
def loss(self, x: Float[torch.Tensor, "batch embedding"],
y: Union[Float[torch.Tensor, "batch 1"], Float[torch.Tensor, "batch output"]],
k_fold: Optional[int] = None) -> Tensor:
"""
Apply loss function to ensemble of predictions.
@@ -110,9 +112,9 @@ def loss(self, x: TensorType['batch', 'embedding'],
return self.loss_function(p, y)

@no_grad()
def predict(self, x: TensorType['batch', 'embedding'], *,
k_fold: Optional[int] = None) -> Union[TensorType['batch', float],
TensorType['batch', 'output', float]]:
def predict(self, x: Float[torch.Tensor, "batch embedding"], *,
k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch"],
Float[torch.Tensor, "batch output"]]:
"""
Average prediction

16 changes: 9 additions & 7 deletions chytorch/utils/data/molecule/conformer.py
@@ -28,20 +28,22 @@
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate_fn_map
from torchtyping import TensorType
import torch
from jaxtyping import Int

from typing import Sequence, Tuple, Union, NamedTuple


class ConformerDataPoint(NamedTuple):
atoms: TensorType['atoms', int]
hydrogens: TensorType['atoms', int]
distances: TensorType['atoms', 'atoms', int]
atoms: Int[torch.Tensor, "atoms"]
hydrogens: Int[torch.Tensor, "atoms"]
distances: Int[torch.Tensor, "atoms atoms"]


class ConformerDataBatch(NamedTuple):
atoms: TensorType['batch', 'atoms', int]
hydrogens: TensorType['batch', 'atoms', int]
distances: TensorType['batch', 'atoms', 'atoms', int]
atoms: Int[torch.Tensor, "batch atoms"]
hydrogens: Int[torch.Tensor, "batch atoms"]
distances: Int[torch.Tensor, "batch atoms atoms"]

def to(self, *args, **kwargs):
return ConformerDataBatch(*(x.to(*args, **kwargs) for x in self))
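Both batch NamedTuples forward .to(...) to every field, so one call moves an entire batch between devices or dtypes. A sketch with placeholder tensors (import path taken from the file header above):

```python
import torch
from chytorch.utils.data.molecule.conformer import ConformerDataBatch

batch = ConformerDataBatch(
    atoms=torch.zeros(2, 5, dtype=torch.long),
    hydrogens=torch.zeros(2, 5, dtype=torch.long),
    distances=torch.zeros(2, 5, 5, dtype=torch.long),
)
moved = batch.to('cuda')  # each field is moved; assumes a CUDA device is available
```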
18 changes: 10 additions & 8 deletions chytorch/utils/data/molecule/encoder.py
@@ -28,21 +28,23 @@
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate_fn_map
from torchtyping import TensorType
import torch
from jaxtyping import Int

from typing import Sequence, Union, NamedTuple, Tuple
from zlib import decompress


class MoleculeDataPoint(NamedTuple):
atoms: TensorType['atoms', int]
neighbors: TensorType['atoms', int]
distances: TensorType['atoms', 'atoms', int]
atoms: Int[torch.Tensor, "atoms"]
neighbors: Int[torch.Tensor, "atoms"]
distances: Int[torch.Tensor, "atoms atoms"]


class MoleculeDataBatch(NamedTuple):
atoms: TensorType['batch', 'atoms', int]
neighbors: TensorType['batch', 'atoms', int]
distances: TensorType['batch', 'atoms', 'atoms', int]
atoms: Int[torch.Tensor, "batch atoms"]
neighbors: Int[torch.Tensor, "batch atoms"]
distances: Int[torch.Tensor, "batch atoms atoms"]

def to(self, *args, **kwargs):
return MoleculeDataBatch(*(x.to(*args, **kwargs) for x in self))
@@ -92,7 +94,7 @@ def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None)
class MoleculeDataset(Dataset):
def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
add_cls: bool = True, cls_token: Union[int, Tuple[int, ...], Sequence[int], Sequence[Tuple[int, ...]],
TensorType['cls', int], TensorType['dataset', 1, int], TensorType['dataset', 'cls', int]] = 1,
Int[torch.Tensor, "cls"], Int[torch.Tensor, "dataset 1 int] TensorType[dataset cls"]] = 1,
max_distance: int = 10, max_neighbors: int = 14,
attention_schema: str = 'bert', components_attention: bool = True,
unpack: bool = False, compressed: bool = True, distance_cutoff=None):
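The repaired cls_token annotation admits several layouts; sketched below with hypothetical sizes (the semantics of each form are an assumption read off the shapes alone):

```python
import torch

n_records = 100
cls_shared = 1                                               # one token shared by every record
cls_tokens = torch.tensor([1, 2])                            # Int[torch.Tensor, "cls"]
cls_per_record = torch.ones(n_records, 1, dtype=torch.long)  # Int[torch.Tensor, "dataset 1"]
cls_full = torch.ones(n_records, 2, dtype=torch.long)        # Int[torch.Tensor, "dataset cls"]
```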
20 changes: 11 additions & 9 deletions chytorch/utils/data/reaction/encoder.py
@@ -26,23 +26,25 @@
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from torch.utils.data._utils.collate import default_collate_fn_map
from torchtyping import TensorType
import torch
from jaxtyping import Int

from typing import Sequence, Union, NamedTuple
from ..molecule import MoleculeDataset


class ReactionEncoderDataPoint(NamedTuple):
atoms: TensorType['atoms', int]
neighbors: TensorType['atoms', int]
distances: TensorType['atoms', 'atoms', int]
roles: TensorType['atoms', int]
atoms: Int[torch.Tensor, "atoms"]
neighbors: Int[torch.Tensor, "atoms"]
distances: Int[torch.Tensor, "atoms atoms"]
roles: Int[torch.Tensor, "atoms"]


class ReactionEncoderDataBatch(NamedTuple):
atoms: TensorType['batch', 'atoms', int]
neighbors: TensorType['batch', 'atoms', int]
distances: TensorType['batch', 'atoms', 'atoms', int]
roles: TensorType['batch', 'atoms', int]
atoms: Int[torch.Tensor, "batch atoms"]
neighbors: Int[torch.Tensor, "batch atoms"]
distances: Int[torch.Tensor, "batch atoms atoms"]
roles: Int[torch.Tensor, "batch atoms"]

def to(self, *args, **kwargs):
return ReactionEncoderDataBatch(*(x.to(*args, **kwargs) for x in self))