diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index 604fc601b..b09f8fb93 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -8,7 +8,7 @@ import torch from torch import Tensor -from torch.nn import Linear +from torch.nn import Linear, Identity from torch_geometric.data import Data if TYPE_CHECKING: @@ -235,6 +235,7 @@ def __init__( self, hidden_size: int, loss_function: "LossFunction", + disable_affine: bool = False, **task_kwargs: Any, ): """Construct `LearnedTask`. @@ -244,18 +245,23 @@ def __init__( the last latent layer of `Model` using this Task. Available through `Model.nb_outputs` loss_function: Loss function appropriate to the task. + disable_affine: Disable linear layer mapping from hidden layer size + to number of inputs. """ # Base class constructor super().__init__(**task_kwargs) # Mapping from last hidden layer to required size of input self._loss_function = loss_function - self._affine = Linear(hidden_size, self.nb_inputs) + self._disable_affine = disable_affine + + if self._disable_affine: + self._affine = Identity() + else: + self._affine = Linear(hidden_size, self.nb_inputs) @abstractmethod - def _forward( # type: ignore - self, x: Union[Tensor, Data] - ) -> Union[Tensor, Data]: + def _forward(self, x: Union[Tensor, Data]) -> Union[Tensor, Data]: """Syntax like `.forward`, for implentation in inheriting classes.""" raise NotImplementedError @@ -272,9 +278,7 @@ def nb_inputs(self) -> int: """Return number of inputs assumed by task.""" @final - def forward( # type: ignore - self, x: Union[Tensor, Data] - ) -> Union[Tensor, Data]: + def forward(self, x: Union[Tensor, Data]) -> Union[Tensor, Data]: """Forward call for `LearnedTask`. The learned embedding transforms last latent layer of Model to meet