-
Notifications
You must be signed in to change notification settings - Fork 36
/
file_loader.py
804 lines (701 loc) · 30.6 KB
/
file_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
"""File loaders."""
from __future__ import annotations
import copy
import csv
import dataclasses
from dataclasses import dataclass
from io import StringIO
import itertools
import json
from typing import (
Any,
cast,
ClassVar,
final,
Iterable,
Optional,
Sized,
Type,
TypeVar,
Union,
)
from datalabs import DatasetDict, IterableDatasetDict, load_dataset
from datalabs.features.features import ClassLabel, Sequence
from explainaboard.analysis.analyses import Analysis
from explainaboard.analysis.feature import FeatureType
from explainaboard.constants import Source
from explainaboard.serialization.serializers import PrimitiveSerializer
from explainaboard.utils.load_resources import get_customized_features
from explainaboard.utils.preprocessor import Preprocessor
from explainaboard.utils.typing_utils import narrow
DType = Union[Type[int], Type[float], Type[str], Type[dict], Type[list]]
T = TypeVar('T')
@dataclass
class FileLoaderField:
    """A field in a file loader.

    Attributes:
        src_name: field name in the source file. use int for tsv column indices,
            str for dict keys, or tuple for hierarchical dict keys
        target_name: field name expected in the loaded data
        dtype: data type of the field in the loaded data. It is only intended for
            simple type conversion so it only supports int, float and str. Pass in
            None to turn off type conversion.
        strip_before_parsing: call strip() on strings before casting to either str,
            int or float. It is only intended to be used with these three data
            types. It defaults to True for str. For all other types, it defaults
            to False
        parser: a custom parser for the field. When called,
            `data_points[idx][src_name]` is passed in as input, it is expected to
            return the parsed result. If parser is not None,
            `strip_before_parsing` and dtype will not have any effect.
    """

    src_name: int | str | Iterable[str]
    target_name: int | str
    dtype: Optional[DType] = None
    strip_before_parsing: Optional[bool] = None
    parser: Optional[Preprocessor] = None

    # Sentinel source names that get replaced by the metadata's languages.
    SOURCE_LANGUAGE: ClassVar[str] = '__SOURCE_LANGUAGE__'
    TARGET_LANGUAGE: ClassVar[str] = '__TARGET_LANGUAGE__'

    def __post_init__(self):
        """Fill in defaults and verify the field configuration."""
        if self.strip_before_parsing is None:
            # Stripping whitespace is only a sensible default for plain strings.
            self.strip_before_parsing = self.dtype == str
        if self.strip_before_parsing and self.dtype is None:
            raise ValueError(
                "strip_before_parsing only works with int, float and str types"
            )
        allowed_dtypes = (str, int, float, dict, list, None)
        if self.dtype not in allowed_dtypes:
            raise ValueError(
                "dtype must be one of str, int, float, dict, list, and None"
            )
@dataclass
class FileLoaderMetadata:
    """Metadata populated in the process of loading the dataset or output files.

    Attributes:
        system_name: The name of the system that produced the output
        dataset_name: The name of the dataset
        sub_dataset_name: The name of the sub-dataset, if any
        split: The dataset split that was loaded
        source_language: The language of the input
        target_language: The language of the output
        supported_languages: All languages supported by the dataset at all
        task_name: The specific task to be analyzed
        supported_tasks: The task or tasks that *can* be handled (e.g. by a dataset)
        custom_features: Custom feature types, indexed by analysis level name
            then feature name
        custom_analyses: A list of custom analyses
    """

    system_name: str | None = None
    dataset_name: str | None = None
    sub_dataset_name: str | None = None
    split: str | None = None
    source_language: str | None = None
    target_language: str | None = None
    supported_languages: list[str] | None = None
    task_name: str | None = None
    supported_tasks: list[str] | None = None
    custom_features: dict[str, dict[str, FeatureType]] | None = None
    # analysis level name -> list of analyses dictionary
    custom_analyses: list[Analysis] | None = None

    def merge(self, other: FileLoaderMetadata) -> None:
        """Merge the information from the two pieces of metadata.

        In the case that the two conflict, the passed-in metadata get preference.
        """
        # TODO(gneubig): This should be changed into a for loop
        self.system_name = other.system_name or self.system_name
        self.dataset_name = other.dataset_name or self.dataset_name
        self.sub_dataset_name = other.sub_dataset_name or self.sub_dataset_name
        self.split = other.split or self.split
        self.source_language = other.source_language or self.source_language
        self.target_language = other.target_language or self.target_language
        self.supported_languages = other.supported_languages or self.supported_languages
        self.task_name = other.task_name or self.task_name
        self.supported_tasks = other.supported_tasks or self.supported_tasks
        self.custom_features = other.custom_features or self.custom_features
        self.custom_analyses = other.custom_analyses or self.custom_analyses

    @classmethod
    def from_dict(cls, data: dict) -> FileLoaderMetadata:
        """Deserialize a dictionary into File Loader Metadata.

        Args:
            data: The dictionary to deserialize.

        Raises:
            ValueError: If both "language" and "source_language"/"target_language"
                are set.
        """
        # TODO(gneubig): A better way to do this might be through a library such as
        # pydantic or dacite
        source_language = data.get('source_language')
        target_language = data.get('target_language')
        if data.get('language'):
            if source_language or target_language:
                raise ValueError(
                    'can not set both "language" and "source_language"/'
                    '"target_language"'
                )
            # "language" is shorthand for identical source and target languages
            source_language = target_language = data.get('language')
        custom_features: dict[str, dict[str, FeatureType]] | None = None
        custom_analyses: list[Analysis] | None = None
        if 'custom_features' in data:
            ft_serializer = PrimitiveSerializer()
            custom_features = {
                k1: {
                    # See https://github.com/python/mypy/issues/4717
                    k2: narrow(
                        FeatureType, ft_serializer.deserialize(v2)  # type: ignore
                    )
                    for k2, v2 in v1.items()
                }
                for k1, v1 in data['custom_features'].items()
            }
        if 'custom_analyses' in data:
            custom_analyses = data['custom_analyses']
        return FileLoaderMetadata(
            system_name=data.get('system_name'),
            dataset_name=data.get('dataset_name'),
            sub_dataset_name=data.get('sub_dataset_name'),
            split=data.get('split'),
            source_language=source_language,
            target_language=target_language,
            supported_languages=data.get('supported_languages'),
            task_name=data.get('task_name'),
            supported_tasks=data.get('supported_tasks'),
            custom_features=custom_features,
            custom_analyses=custom_analyses,
        )

    @classmethod
    def from_file(cls, file_name: str) -> FileLoaderMetadata:
        """Load meta data from a file.

        Args:
            file_name: The name of the file to load from.

        Returns:
            A file loader metadata class.

        Raises:
            ValueError: If the file does not contain a dict with a "metadata" key.
        """
        with open(file_name, 'r') as file_in:
            my_data = json.load(file_in)
            if not isinstance(my_data, dict) or 'metadata' not in my_data:
                raise ValueError(f'Could not find metadata in {file_name}')
            else:
                # BUG FIX: this previously called from_file() again with a dict,
                # which would fail in open(); the inner payload is a dict and
                # must go through from_dict().
                return FileLoaderMetadata.from_dict(my_data['metadata'])
@dataclass
class FileLoaderReturn(Sized):
    """Container for the result of a FileLoader.

    Attributes:
        samples: The samples that were loaded from the dataset.
        metadata: Metadata describing the samples or their dataset.
    """

    samples: list
    metadata: FileLoaderMetadata = dataclasses.field(
        # deferred via lambda so the class need not exist at field-definition time
        default_factory=lambda: FileLoaderMetadata()
    )

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Any:
        """Return the sample stored at position ``index``."""
        return self.samples[index]
class FileLoader:
    """A class that loads raw data from a file."""

    def __init__(
        self,
        fields: list[FileLoaderField] | None = None,
        use_idx_as_id: bool = True,
        id_field_name: Optional[str] = None,
    ) -> None:
        """Loader that loads data according to fields.

        Args:
            fields: A specification of the fields to be read in from the file.
            use_idx_as_id: whether to use sample indices as IDs. Generated IDs are
                str even though it represents an index.
                (This is to make sure all sample IDs are str.)
            id_field_name: The name of the field to be used as an ID
        """
        self._fields = fields or []
        self._use_idx_as_id = use_idx_as_id
        self._id_field_name = id_field_name
        self.validate()

    def validate(self) -> None:
        """Validates the setting of the fields.

        Raises:
            ValueError: if the fields are not valid
        """
        if self._use_idx_as_id and self._id_field_name:
            raise ValueError("id_field_name must be None when use_idx_as_id is True")
        target_names = [field.target_name for field in self._fields]
        if len(target_names) != len(set(target_names)):
            raise ValueError("target_name must be unique")

    @final
    def add_fields(self, fields: list[FileLoaderField]) -> None:
        """Add more fields to the FileLoader.

        Args:
            fields: The fields to add.

        Raises:
            ValueError: if the fields are wrong.
        """
        self._fields.extend(fields)
        self.validate()

    @staticmethod
    def parse_data(data: Any, field: FileLoaderField) -> Any:
        """Parse data loaded in from a file to the required data type.

        Args:
            data: The data loaded in from the file.
            field: Information about the field.

        Returns:
            The parsed data.
        """
        if field.parser:
            return field.parser(data)
        if field.strip_before_parsing:
            data = (
                data.strip() if isinstance(data, str) else data
            )  # some time data could be a nested json object
        dtype = field.dtype
        if dtype == int:
            return int(data)
        elif dtype == float:
            return float(data)
        elif dtype == str:
            return str(data)
        elif dtype == list or dtype == dict:
            return data  # TODO(Pengfei): I add the `dict` type for temporal use,
            # but wonder if we need to generalize the current type mechanism,
        elif dtype is None:
            return data
        raise NotImplementedError(f"dtype {dtype} is not supported")

    def generate_id(self, parsed_data_point: dict, sample_idx: int) -> None:
        """Generates an id attribute for each data point in place.

        Args:
            parsed_data_point: The data point parsed into a dict.
            sample_idx: The ID of the sample.

        Raises:
            ValueError: if the configured id field is missing from the data point.
        """
        if self._use_idx_as_id:
            parsed_data_point["id"] = str(sample_idx)
        elif self._id_field_name:
            if self._id_field_name not in parsed_data_point:
                raise ValueError(
                    f"The {sample_idx} data point in system outputs file does not have "
                    f"field {self._id_field_name}"
                )
            parsed_data_point["id"] = str(parsed_data_point[self._id_field_name])

    def load_raw(
        self, data: str | DatalabLoaderOption, source: Source
    ) -> FileLoaderReturn:
        """Load data from source and return an iterable of data points.

        It does not use fields information to parse the data points.

        Args:
            data: if str, it's either base64 encoded system output or a path
            source: source of data

        Returns:
            The data loaded from the file.
        """
        raise NotImplementedError(
            "load_raw() is not implemented for the base FileLoader"
        )

    def _map_fields(self, fields: list, field_mapping: dict[str, str] | None = None):
        """Return a deep copy of ``fields`` with their ``src_name``s re-mapped.

        Args:
            fields: The fields whose source names should be re-mapped.
            field_mapping: A mapping from original to replacement names. Names
                not found in the mapping are left unchanged.

        Returns:
            A deep copy of ``fields`` with mapped source names; the input list
            is never modified.
        """
        new_fields = copy.deepcopy(fields)
        if field_mapping is not None:
            for field in new_fields:
                # NOTE: the str check must precede the Iterable check, since
                # str is itself Iterable.
                if isinstance(field.src_name, str):
                    field.src_name = field_mapping.get(field.src_name, field.src_name)
                elif isinstance(field.src_name, Iterable):
                    field.src_name = [field_mapping.get(x, x) for x in field.src_name]
                # int indices (e.g. TSV columns) are left untouched
        return new_fields

    @classmethod
    def find_field(
        cls,
        data_point: dict,
        field: FileLoaderField,
        field_mapping: dict[str, str] | None = None,
    ):
        """In a structured dictionary, find a specified field.

        This can be specified by an
        * int index to a list (data_point[field])
        * str index to a dictionary (datapoint[field])
        * Iterable[str] index to a dictionary (datapoint[field[0]][field[1]]...)

        Args:
            data_point: The data to search
            field: The file loader field corresponding to the dict
            field_mapping: A mapping between field names. If a str in a field name
                exists as a key in the mapping, the value will be used to search
                instead

        Returns:
            the required data

        Raises:
            ValueError: if the field cannot be found in the data point.
        """
        if field_mapping is None:
            field_mapping = {}
        if isinstance(data_point, list):
            int_idx = int(field.src_name)
            if int_idx >= len(data_point):
                raise ValueError(
                    f'{cls.__name__}: Could not find '
                    f'field "{field.src_name}" in datapoint {data_point}'
                )
            return data_point[int_idx]
        elif isinstance(data_point, dict):
            if isinstance(field.src_name, int):
                raise ValueError(f'unexpected int index for dict data_point in {field}')
            # Parse a string or tuple identifier
            field_list = (
                [field.src_name] if isinstance(field.src_name, str) else field.src_name
            )
            ret_dict = data_point
            for sub_field in field_list:
                sub_field = field_mapping.get(sub_field, sub_field)
                if sub_field not in ret_dict:
                    raise ValueError(
                        f'{cls.__name__}: Could not find '
                        f'field "{field.src_name}" in datapoint {data_point}'
                    )
                ret_dict = ret_dict[sub_field]
            return ret_dict

    def load(
        self,
        data: str | DatalabLoaderOption,
        source: Source,
        field_mapping: dict[str, str] | None = None,
    ) -> FileLoaderReturn:
        """Load data from source, parse data points with fields information.

        Args:
            data: An indication of the data to be loading
            source: The source from which it should be loaded
            field_mapping: A mapping from field name in the loader spec to field
                name in the actual input. The caller's dict is never modified.

        Returns:
            an iterable of data points.
        """
        raw_data = self.load_raw(data, source)
        parsed_data_points: list[dict] = []
        # Get language information from meta-data if it doesn't exist already.
        # BUG FIX: work on a copy so the caller's field_mapping dict is not
        # mutated by the language entries added below.
        actual_mapping = dict(field_mapping) if field_mapping is not None else {}
        for lang, meta in [
            (FileLoaderField.SOURCE_LANGUAGE, raw_data.metadata.source_language),
            (FileLoaderField.TARGET_LANGUAGE, raw_data.metadata.target_language),
        ]:
            temp = actual_mapping.get(lang) or meta
            if temp is not None:
                actual_mapping[lang] = temp
        # map the field names (_map_fields deep-copies, so self._fields is
        # left untouched)
        fields = self._map_fields(self._fields, actual_mapping)
        if raw_data.metadata.custom_features is not None:
            for level_name, feats in raw_data.metadata.custom_features.items():
                if level_name == 'example':
                    # `feats` maps feature name -> feature type; only the names
                    # are needed to create pass-through fields
                    for feat in feats:
                        fields.append(FileLoaderField(feat, feat, None))
                else:
                    raise ValueError(
                        'cannot currently load custom features other '
                        f'than on the example level (got {level_name})'
                    )
        # process the actual data
        for idx, data_point in enumerate(raw_data.samples):
            parsed_data_point = {}
            for field in fields:  # parse data point according to fields
                parsed_data_point[field.target_name] = self.parse_data(
                    self.find_field(data_point, field, actual_mapping), field
                )
            self.generate_id(parsed_data_point, idx)
            parsed_data_points.append(parsed_data_point)
        return FileLoaderReturn(parsed_data_points, raw_data.metadata)
class TSVFileLoader(FileLoader):
    """A class for loading from TSV files."""

    def validate(self) -> None:
        """See FileLoader.validate."""
        super().validate()
        for field in self._fields:
            # TSV fields are addressed by column position only.
            if not isinstance(field.src_name, int):
                raise ValueError("field src_name for TSVFileLoader must be an int")

    def load_raw(
        self, data: str | DatalabLoaderOption, source: Source
    ) -> FileLoaderReturn:
        """See FileLoader.load_raw."""
        data = narrow(str, data)

        def _read(stream):
            # QUOTE_NONE: treat quote characters as literal text.
            return list(csv.reader(stream, delimiter='\t', quoting=csv.QUOTE_NONE))

        if source == Source.in_memory:
            rows = _read(StringIO(data))
        elif source == Source.local_filesystem:
            with open(data, "r", encoding="utf8") as stream:
                rows = _read(stream)
        else:
            raise NotImplementedError
        # drop empty rows (blank lines in the input)
        return FileLoaderReturn([row for row in rows if row])
class CoNLLFileLoader(FileLoader):
    """A loader from CoNLL-formatted files.

    CoNLL files hold one token per line with whitespace-separated annotation
    columns; sentences are separated by blank lines or "-DOCSTART-" markers.
    """

    def __init__(self, fields: list[FileLoaderField] | None = None) -> None:
        """Constructor.

        Args:
            fields: A list of fields to read from each column.
        """
        # use_idx_as_id=False: IDs are assigned per sentence inside load().
        super().__init__(fields, False)

    def validate(self) -> None:
        """See FileLoader.validate."""
        super().validate()
        # e.g. one field for tokens, optionally a second for tags
        if len(self._fields) not in [1, 2]:
            raise ValueError(
                "CoNLL file loader expects 1 or 2 fields "
                + f"({len(self._fields)} given)"
            )

    def load_raw(
        self, data: str | DatalabLoaderOption, source: Source
    ) -> FileLoaderReturn:
        """See FileLoader.load_raw.

        Returns one raw sample per LINE of the file; sentence grouping
        happens later, in load().
        """
        data = narrow(str, data)
        if source == Source.in_memory:
            return FileLoaderReturn(data.splitlines())
        elif source == Source.local_filesystem:
            with open(data, "r", encoding="utf8") as fin:
                return FileLoaderReturn([line.strip() for line in fin])
        raise NotImplementedError

    def load(
        self,
        data: str | DatalabLoaderOption,
        source: Source,
        field_mapping: dict[str, str] | None = None,
    ) -> FileLoaderReturn:
        """See FileLoader.load.

        Groups the per-line raw samples into one parsed sample per sentence,
        where each field value is the list of that column's tokens.
        """
        raw_data = self.load_raw(data, source)
        parsed_samples: list[dict] = []
        # running sentence index, used as the sample ID
        guid = 0
        # per-sentence accumulator: column annotations collected so far,
        # keyed by each field's src_name (a column index)
        curr_sentence_fields: dict[str | int | Iterable[str], list[str]] = {
            field.src_name: [] for field in self._fields
        }

        def add_sample() -> None:
            # Flush the accumulated sentence (if non-empty) into
            # parsed_samples and reset the accumulator.
            nonlocal guid, curr_sentence_fields
            # uses the first field to check if data is empty
            if curr_sentence_fields.get(self._fields[0].src_name):
                new_sample: dict = {}
                for field in self._fields:  # parse data point according to fields
                    new_sample[field.target_name] = curr_sentence_fields[field.src_name]
                new_sample["id"] = str(guid)
                parsed_samples.append(new_sample)
                guid += 1
                curr_sentence_fields = {
                    field.src_name: [] for field in self._fields
                }  # reset

        # highest column index any field needs — used to detect the separator
        max_field: int = max([narrow(int, x.src_name) for x in self._fields])
        for line in raw_data.samples:
            # at sentence boundary
            if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                add_sample()
            else:
                splits = line.split("\t")
                if len(splits) <= max_field:  # not separated by tabs
                    splits = line.split(" ")
                    if len(splits) <= max_field:  # not separated by tabs or spaces
                        raise ValueError(
                            f"not enough fields for {line} (sentence index: {guid})"
                        )
                for field in self._fields:
                    curr_sentence_fields[field.src_name].append(
                        self.parse_data(splits[narrow(int, field.src_name)], field)
                    )
        add_sample()  # add last example
        return FileLoaderReturn(parsed_samples, metadata=raw_data.metadata)
class JSONFileLoader(FileLoader):
    """A loader from JSON files."""

    def load_raw(
        self, data: str | DatalabLoaderOption, source: Source
    ) -> FileLoaderReturn:
        """See FileLoader.load_raw."""
        data = narrow(str, data)
        if source == Source.in_memory:
            loaded = json.loads(data)
        elif source == Source.local_filesystem:
            with open(data, 'r', encoding="utf8") as json_file:
                loaded = json.load(json_file)
        else:
            raise NotImplementedError
        # A bare JSON list is simply a list of examples with no metadata.
        if isinstance(loaded, list):
            return FileLoaderReturn(loaded)
        # Otherwise it must be a dict with "examples" and optionally "metadata".
        if 'examples' not in loaded or not isinstance(loaded['examples'], list):
            raise ValueError(
                f'Error loading {data}. Input JSON files in dict '
                'format must have a list "examples"'
            )
        raw_data = loaded.pop('examples')
        # After removing "examples", only an optional "metadata" key may remain.
        if len(loaded) > 1 or (len(loaded) == 1 and 'metadata' not in loaded):
            raise ValueError(
                f'Error loading {data}. Input JSON files in dict '
                'format must have "examples" and optionally '
                '"metadata", nothing else'
            )
        metadata = FileLoaderMetadata.from_dict(loaded.get('metadata', {}))
        return FileLoaderReturn(raw_data, metadata=metadata)
@dataclass
class DatalabLoaderOption:
    """Options controlling what the DatalabFileLoader retrieves.

    Attributes:
        dataset: The name of the dataset
        subdataset: The name of the sub dataset (optional)
        split: The name of the split
        custom_features: Custom features that are added to the retrieved file
        custom_analyses: Custom analyses that are added
    """

    dataset: str
    subdataset: str | None = None
    split: str = "test"
    custom_features: dict[str, dict[str, FeatureType]] | None = None
    custom_analyses: list[Analysis] | None = None
class DatalabFileLoader(FileLoader):
    """A file loader for loading from DataLab."""

    @classmethod
    def _replace_one(cls, names: list[str], lab: int):
        # Map a label index back to its string name; -1 denotes a missing
        # label and becomes the sentinel '_NULL_'.
        return names[lab] if lab != -1 else '_NULL_'

    @classmethod
    def _replace_labels(cls, features: dict, example: dict) -> dict:
        """Return a copy of ``example`` with integer class labels replaced
        by their string names, according to the dataset's feature specs."""
        new_example = {}
        for examp_k, examp_v in example.items():
            examp_f = features[examp_k]
            # Label feature
            if isinstance(examp_f, ClassLabel):
                names = cast(ClassLabel, examp_f).names
                new_example[examp_k] = cls._replace_one(names, examp_v)
            # Sequence feature
            elif isinstance(examp_f, Sequence):
                examp_seq = cast(Sequence, examp_f)
                # Sequence of labels
                if isinstance(examp_seq.feature, ClassLabel):
                    names = cast(ClassLabel, examp_seq.feature).names
                    new_example[examp_k] = [cls._replace_one(names, x) for x in examp_v]
                # Sequence of anything else
                else:
                    new_example[examp_k] = examp_v
            # Anything else
            else:
                new_example[examp_k] = examp_v
        return new_example

    def load_raw(
        self, data: str | DatalabLoaderOption, source: Source
    ) -> FileLoaderReturn:
        """See FileLoader.load_raw."""
        config = narrow(DatalabLoaderOption, data)
        ft_serializer = PrimitiveSerializer()
        # load customized features from global config files
        customized_features_from_config = get_customized_features()
        if config.dataset in customized_features_from_config:
            ds_feats = customized_features_from_config[config.dataset]
            # NOTE(review): this mutates the caller's DatalabLoaderOption in
            # place — confirm callers do not reuse the config object.
            if config.custom_features is None:
                config.custom_features = {}
            for level_name, level_feats in ds_feats['custom_features'].items():
                if level_name != 'example':
                    raise NotImplementedError(
                        'currently custom features are only '
                        'supported on the example level, but got '
                        f'{level_name}'
                    )
                parsed_level_feats = {
                    # See https://github.com/python/mypy/issues/4717
                    k: narrow(FeatureType, ft_serializer.deserialize(v))  # type: ignore
                    for k, v in level_feats.items()
                }
                new_features = config.custom_features.get(level_name, {})
                new_features.update(parsed_level_feats)
                config.custom_features[level_name] = new_features
        dataset = load_dataset(
            config.dataset, config.subdataset, split=config.split, streaming=False
        )
        # TODO(gneubig): patch for an inconsistency in datalab, where DatasetDict
        # doesn't have info
        if isinstance(dataset, DatasetDict) or isinstance(dataset, IterableDatasetDict):
            raise ValueError('Cannot handle DatasetDict returns')
        info = dataset.info
        # update src_name based on task schema (e.g., text1_column => sentence1)
        ignore_tasks = ["machine-translation", "code-generation"]
        if info.task_templates[0].task not in ignore_tasks:
            # NOTE(review): rewrites self._fields in place, so calling
            # load_raw() twice on the same loader would try to re-map names
            # that were already mapped — confirm loaders are single-use.
            for idx in range(len(self._fields)):
                src_name = cast(str, self._fields[idx].src_name)
                self._fields[idx].src_name = getattr(info.task_templates[0], src_name)
        # Infer metadata from the dataset
        metadata = FileLoaderMetadata(
            custom_features=config.custom_features,
            custom_analyses=config.custom_analyses,
        )
        # load customized features from global config files
        if config.dataset in customized_features_from_config:
            if metadata.custom_features is None:
                metadata.custom_features = {}
            # NOTE(review): this update() inserts the raw (unparsed) feature
            # dicts from the config file, overwriting the parsed FeatureType
            # values merged into config.custom_features above — verify this
            # double merge is intended.
            metadata.custom_features.update(
                customized_features_from_config[config.dataset]['custom_features']
            )
            if metadata.custom_analyses is None:
                metadata.custom_analyses = []
            metadata.custom_analyses.extend(
                customized_features_from_config[config.dataset]['custom_analyses']
            )
        if info.languages is not None:
            metadata.supported_languages = info.languages
        # Infer languages:
        # If only one language is supported, set both source and target to that
        # language. If two are supported set source to lang[0], target to lang[1].
        # Otherwise, do not infer the language at all.
        if (
            metadata.supported_languages
            and len(metadata.supported_languages) > 0
            and len(metadata.supported_languages) < 3
        ):
            metadata.source_language = metadata.supported_languages[0]
            metadata.target_language = metadata.supported_languages[
                0 if len(metadata.supported_languages) == 1 else 1
            ]
        if info.task_templates is not None:
            tt = info.task_templates
            # flatten each template's task plus its categories into one list
            metadata.supported_tasks = list(
                itertools.chain.from_iterable(
                    [[x.task] + x.task_categories for x in tt]
                )
            )
        # Return
        return FileLoaderReturn(
            [self._replace_labels(info.features, x) for x in dataset],
            metadata=metadata,
        )
class TextFileLoader(FileLoader):
    """Loads a text file where each line is a different sample.

    Only one field is allowed. It is often used for predicted outputs of text
    generation models.
    """

    def __init__(
        self,
        target_name: str = "output",
        dtype: DType = str,
        strip_before_parsing: Optional[bool] = None,
    ) -> None:
        """Constructor.

        src_name is not used for this file loader, it overrides the base load
        method.

        Args:
            target_name: The name of the target field.
            dtype: The type of the field.
            strip_before_parsing: Whether to strip white space before parsing.
        """
        super().__init__(
            [FileLoaderField("_", target_name, dtype, strip_before_parsing)],
            use_idx_as_id=True,
        )

    @classmethod
    def load_raw(
        cls, data: str | DatalabLoaderOption, source: Source
    ) -> FileLoaderReturn:
        """See FileLoader.load_raw."""
        data = narrow(str, data)
        if source == Source.in_memory:
            return FileLoaderReturn(data.splitlines())
        elif source == Source.local_filesystem:
            with open(data, "r", encoding="utf8") as f:
                # BUG FIX: use splitlines() instead of readlines() so lines do
                # not keep their trailing newline, matching the in-memory path.
                return FileLoaderReturn(f.read().splitlines())
        raise NotImplementedError

    def validate(self) -> None:
        """See FileLoader.validate."""
        super().validate()
        if len(self._fields) != 1:
            raise ValueError("Text File Loader only takes one field")

    def load(
        self,
        data: str | DatalabLoaderOption,
        source: Source,
        field_mapping: dict[str, str] | None = None,
    ) -> FileLoaderReturn:
        """See FileLoader.load."""
        raw_data = self.load_raw(data, source)
        data_list: list[str] = raw_data.samples
        parsed_data_points: list[dict] = []
        for idx, data_point in enumerate(data_list):
            parsed_data_point = {}
            field = self._fields[0]
            parsed_data_point[field.target_name] = self.parse_data(data_point, field)
            self.generate_id(parsed_data_point, idx)
            parsed_data_points.append(parsed_data_point)
        # BUG FIX: propagate the loaded metadata (it was previously dropped),
        # consistent with FileLoader.load.
        return FileLoaderReturn(parsed_data_points, raw_data.metadata)