@@ -1,4 +1,6 @@
 # -*- coding: utf-8 -*-
+import torch
+from jaxtyping import Float, Int
 #
 # Copyright 2022, 2023 Ramil Nugmanov <nougmanoff@protonmail.com>
 #
@@ -24,7 +26,7 @@
 from torch import bmm, no_grad, Tensor
 from torch.nn import Dropout, GELU, LayerNorm, LazyLinear, Linear, Module
 from torch.nn.functional import cross_entropy, softmax
-from torchtyping import TensorType
+
 from typing import Optional, Union
 from ._kfold import k_fold_mask

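As a rough, illustrative mapping (not taken from the commit), the jaxtyping annotations added above replace the removed torchtyping ones like-for-like: the dtype moves into the wrapper class (Float/Int) and the dimension names collapse into one space-separated string.

# Illustrative sketch only, assuming plain jaxtyping usage.
import torch
from jaxtyping import Float, Int

# torchtyping: TensorType['batch', 'embedding']   (float-valued)
x: Float[torch.Tensor, "batch embedding"]
# torchtyping: TensorType['batch', 1, int]        (integer-valued)
y: Int[torch.Tensor, "batch 1"]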
@@ -83,8 +85,8 @@ def forward(self, x):
             return x.view(-1, self._output, self._ensemble, self._n_classes)  # B x O x E x C
         return x  # B x E x C

-    def loss(self, x: TensorType['batch', 'embedding'],
-             y: Union[TensorType['batch', 1, int], TensorType['batch', 'output', int]],
+    def loss(self, x: Float[torch.Tensor, "batch embedding"],
+             y: Union[Int[torch.Tensor, "batch 1"], Int[torch.Tensor, "batch output"]],
              k_fold: Optional[int] = None, ignore_index: int = -100) -> Tensor:
         """
         Apply loss function to ensemble of predictions.
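The loss signature above takes a batch x embedding input and integer targets shaped batch x 1 (or batch x output). A minimal sketch of how a cross-entropy over an ensemble of that shape can be computed, assuming E voting members each emit C class logits per sample; the broadcasting below is illustrative, not the project's exact implementation:

import torch
from torch.nn.functional import cross_entropy

B, E, C = 32, 10, 5
logits = torch.randn(B, E, C)                  # B x E x C, as in the forward() comments above
y = torch.randint(0, C, (B, 1))                # B x 1 integer targets
p = logits.flatten(end_dim=1)                  # (B * E) x C logits
t = y.expand(B, E).flatten()                   # repeat each target for every ensemble member
loss = cross_entropy(p, t, ignore_index=-100)  # scalar; -100 matches the ignore_index default above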
@@ -120,8 +122,8 @@ def loss(self, x: TensorType['batch', 'embedding'],
         return self.loss_function(p, y)

     @no_grad()
-    def predict(self, x: TensorType['batch', 'embedding'], *,
-                k_fold: Optional[int] = None) -> Union[TensorType['batch', int], TensorType['batch', 'output', int]]:
+    def predict(self, x: Float[torch.Tensor, "batch embedding"], *,
+                k_fold: Optional[int] = None) -> Union[Int[torch.Tensor, "batch"], Int[torch.Tensor, "batch output"]]:
         """
         Average class prediction

@@ -130,9 +132,9 @@ def predict(self, x: TensorType['batch', 'embedding'], *,
         return self.predict_proba(x, k_fold=k_fold).argmax(-1)  # B or B x O

     @no_grad()
-    def predict_proba(self, x: TensorType['batch', 'embedding'], *,
-                      k_fold: Optional[int] = None) -> Union[TensorType['batch', 'classes', float],
-                                                             TensorType['batch', 'output', 'classes', float]]:
+    def predict_proba(self, x: Float[torch.Tensor, "batch embedding"], *,
+                      k_fold: Optional[int] = None) -> Union[Float[torch.Tensor, "batch classes"],
+                                                             Float[torch.Tensor, "batch output classes"]]:
         """
         Average probability

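jaxtyping annotations like the ones introduced here are purely declarative unless a runtime checker is enabled. A hedged sketch of how they could be enforced with jaxtyped plus beartype (an assumption; nothing in this commit turns runtime checking on), using a stand-in function with the same shape contract as predict_proba:

import torch
from beartype import beartype
from jaxtyping import Float, jaxtyped

@jaxtyped(typechecker=beartype)
def predict_proba(x: Float[torch.Tensor, "batch embedding"]) -> Float[torch.Tensor, "batch classes"]:
    # stand-in body: softmax over a random projection, just to exercise the shape contract
    return torch.softmax(x @ torch.randn(x.shape[-1], 3), dim=-1)

probs = predict_proba(torch.randn(4, 8))   # OK: returns a 4 x 3 tensor of probabilities
# predict_proba(torch.randn(4, 8, 2))      # would raise: rank mismatch with "batch embedding"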