diff --git a/src/pytorch_ie/taskmodules/taskmodule.py b/src/pytorch_ie/taskmodules/taskmodule.py index 695d53f1..b3ef0217 100644 --- a/src/pytorch_ie/taskmodules/taskmodule.py +++ b/src/pytorch_ie/taskmodules/taskmodule.py @@ -101,7 +101,9 @@ def encode( @abstractmethod def encode_input( - self, documents: List[Document] + self, + documents: List[Document], + is_training: bool = False, ) -> Tuple[List[InputEncoding], List[Metadata], Optional[List[Document]]]: raise NotImplementedError() diff --git a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py index 7043fb3c..f361810c 100644 --- a/src/pytorch_ie/taskmodules/transformer_re_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_re_text_classification.py @@ -266,7 +266,9 @@ def _encode_text( return encoding def encode_input( - self, documents: List[Document] + self, + documents: List[Document], + is_training: bool = False, ) -> Tuple[ List[TransformerReTextClassificationInputEncoding], List[Metadata], diff --git a/src/pytorch_ie/taskmodules/transformer_seq2seq.py b/src/pytorch_ie/taskmodules/transformer_seq2seq.py index 6ac12143..0af81dc0 100644 --- a/src/pytorch_ie/taskmodules/transformer_seq2seq.py +++ b/src/pytorch_ie/taskmodules/transformer_seq2seq.py @@ -81,7 +81,9 @@ def encode_input_strings(self, inputs: List[str]) -> List[TransformerSeq2SeqInpu ] def encode_input( - self, documents: List[Document] + self, + documents: List[Document], + is_training: bool = False, ) -> Tuple[List[TransformerSeq2SeqInputEncoding], List[Metadata], Optional[List[Document]]]: input_strings = [self.document_to_input_string(document) for document in documents] return ( diff --git a/src/pytorch_ie/taskmodules/transformer_span_classification.py b/src/pytorch_ie/taskmodules/transformer_span_classification.py index d2a367ed..6c844c1d 100644 --- a/src/pytorch_ie/taskmodules/transformer_span_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_span_classification.py @@ -99,7 +99,9 @@ def prepare(self, documents: List[Document]) -> None: self.id_to_label = {v: k for k, v in self.label_to_id.items()} def encode_input( - self, documents: List[Document] + self, + documents: List[Document], + is_training: bool = False, ) -> Tuple[ List[TransformerSpanClassificationInputEncoding], List[Metadata], diff --git a/src/pytorch_ie/taskmodules/transformer_text_classification.py b/src/pytorch_ie/taskmodules/transformer_text_classification.py index 53018c42..11221e1b 100644 --- a/src/pytorch_ie/taskmodules/transformer_text_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_text_classification.py @@ -119,7 +119,9 @@ def prepare(self, documents: List[Document]) -> None: self.id_to_label = {v: k for k, v in self.label_to_id.items()} def encode_input( - self, documents: List[Document] + self, + documents: List[Document], + is_training: bool = False, ) -> Tuple[ List[TransformerTextClassificationInputEncoding], List[Metadata], diff --git a/src/pytorch_ie/taskmodules/transformer_token_classification.py b/src/pytorch_ie/taskmodules/transformer_token_classification.py index 2c0f9906..9904b7f6 100644 --- a/src/pytorch_ie/taskmodules/transformer_token_classification.py +++ b/src/pytorch_ie/taskmodules/transformer_token_classification.py @@ -127,7 +127,9 @@ def encode_text( ) def encode_input( - self, documents: List[Document] + self, + documents: List[Document], + is_training: bool = False, ) -> Tuple[ List[TransformerTokenClassificationInputEncoding], List[Metadata],