diff --git a/flair/datasets.py b/flair/datasets.py index a56ec0baaf..cb0cef67d8 100644 --- a/flair/datasets.py +++ b/flair/datasets.py @@ -681,6 +681,22 @@ def __init__( for row in csv_reader: + # test if format is OK + wrong_format = False + for text_column in self.text_columns: + if text_column >= len(row): + wrong_format = True + + # test if at least one label given + has_label = False + for column in self.column_name_map: + if self.column_name_map[column].startswith("label") and row[column]: + has_label = True + break + + if wrong_format or not has_label: + continue + if self.in_memory: text = " || ".join( diff --git a/flair/models/language_model.py b/flair/models/language_model.py index 9084d3f7a0..5f6fdecf84 100644 --- a/flair/models/language_model.py +++ b/flair/models/language_model.py @@ -130,7 +130,7 @@ def get_representation(self, strings: List[str], chars_per_chunk: int = 512): prediction, rnn_output, hidden = self.forward(batch, hidden) rnn_output = rnn_output.detach() - output_parts.append(rnn_output) + output_parts.append(rnn_output.to("cpu")) # concatenate all chunks to make final output output = torch.cat(output_parts)