-
Notifications
You must be signed in to change notification settings - Fork 3
/
test_tutorial_pipeline_sampling_bert.py
122 lines (97 loc) · 5.3 KB
/
test_tutorial_pipeline_sampling_bert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import unittest
from collections import OrderedDict
from arekit.common.data.input.providers.label.multiple import MultipleLabelProvider
from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.data.input.writers.tsv import TsvWriter
from arekit.common.entities.base import Entity
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.entities.types import OpinionEntityType
from arekit.common.experiment.data_type import DataType
from arekit.common.folding.nofold import NoFolding
from arekit.common.labels.base import NoLabel, Label
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.common.labels.str_fmt import StringLabelsFormatter
from arekit.common.pipeline.base import BasePipeline
from arekit.common.text.parser import BaseTextParser
from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
from arekit.contrib.bert.terms.mapper import BertDefaultStringTextTermsMapper
from arekit.contrib.source.brat.entities.parser import BratTextEntitiesParser
from arekit.contrib.utils.io_utils.samples import SamplesIO
from arekit.contrib.utils.pipelines.items.sampling.bert import BertExperimentInputSerializerPipelineItem
from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer
from arekit.contrib.utils.pipelines.text_opinion.annot.predefined import PredefinedTextOpinionAnnotator
from arekit.contrib.utils.pipelines.text_opinion.extraction import text_opinion_extraction_pipeline
from arekit.contrib.utils.pipelines.text_opinion.filters.distance_based import DistanceLimitedTextOpinionFilter
from tests.tutorials.test_tutorial_pipeline_text_opinion_annotation import FooDocumentOperations
class Positive(Label):
    """Marker label type for the positive class; adds no behavior to Label."""
class Negative(Label):
    """Marker label type for the negative class; adds no behavior to Label."""
class SentimentLabelScaler(BaseLabelScaler):
    """Label scaler for the three labels used here: NoLabel, Positive, Negative."""

    def __init__(self):
        """Build the signed and unsigned label-to-value tables and hand them to the base."""
        # Signed encoding: neutral = 0, positive = 1, negative = -1.
        signed_values = OrderedDict()
        signed_values[NoLabel()] = 0
        signed_values[Positive()] = 1
        signed_values[Negative()] = -1

        # Unsigned encoding: the same labels, with the negative one mapped to 2.
        unsigned_values = OrderedDict()
        unsigned_values[NoLabel()] = 0
        unsigned_values[Positive()] = 1
        unsigned_values[Negative()] = 2

        # NOTE(review): arguments are passed positionally as (signed, unsigned);
        # confirm this matches the parameter order of BaseLabelScaler.__init__.
        super().__init__(signed_values, unsigned_values)
class CustomLabelsFormatter(StringLabelsFormatter):
    """Labels formatter that maps the annotation strings "POSITIVE_TO" and
    "NEGATIVE_TO" onto caller-supplied label types.
    """

    def __init__(self, pos_label_type, neg_label_type):
        """Register the string-to-label mapping in the base formatter.

        pos_label_type: label type associated with the "POSITIVE_TO" string.
        neg_label_type: label type associated with the "NEGATIVE_TO" string.
        """
        # Each key must point at the label type of the same polarity; the
        # previous version had the two values crossed, inverting every label.
        stol = {"POSITIVE_TO": pos_label_type, "NEGATIVE_TO": neg_label_type}
        super(CustomLabelsFormatter, self).__init__(stol=stol)
class CustomEntitiesFormatter(StringEntitiesFormatter):
    """Entity formatter that substitutes fixed placeholders for the subject and
    object of an opinion and keeps the original value of every other entity.
    """

    def __init__(self, subject_fmt="[subject]", object_fmt="[object]"):
        """Remember the placeholder strings.

        subject_fmt: text returned for subject (and synonym-subject) entities.
        object_fmt: text returned for object (and synonym-object) entities.
        """
        self.__obj_fmt = object_fmt
        self.__subj_fmt = subject_fmt

    def to_string(self, original_value, entity_type):
        """Return the textual representation of an entity for the given role.

        original_value: the Entity being formatted.
        entity_type: an OpinionEntityType value describing the entity's role.

        Returns the entity's own `Value` for `Other`, the matching placeholder
        for object/subject roles, and None for any other entity type.
        """
        assert isinstance(original_value, Entity)

        # Entities outside the opinion keep their original text.
        if entity_type == OpinionEntityType.Other:
            return original_value.Value

        object_roles = (OpinionEntityType.Object, OpinionEntityType.SynonymObject)
        if entity_type in object_roles:
            return self.__obj_fmt

        subject_roles = (OpinionEntityType.Subject, OpinionEntityType.SynonymSubject)
        if entity_type in subject_roles:
            return self.__subj_fmt

        # Unrecognized entity type: no representation.
        return None
class TestBertSerialization(unittest.TestCase):
    """Tutorial-style test that runs the BERT input serialization pipeline item
    over the documents provided by FooDocumentOperations.
    """

    def test(self):
        """Assemble the sample provider, the serializer pipeline and the
        text-opinion extraction pipeline, then run serialization for the
        Train data type. The test makes no assertions; it passes when the
        run completes without raising.
        """
        text_b_template = '{subject} к {object} в контексте : << {context} >>'

        # Terms mapper: subject/object entities become "#S"/"#O".
        entity_formatter = CustomEntitiesFormatter(subject_fmt="#S", object_fmt="#O")
        terms_mapper = BertDefaultStringTextTermsMapper(entity_formatter=entity_formatter)

        # Without a TEXT_B template a single-text provider is used,
        # otherwise a text-pair provider built from the template.
        if text_b_template is None:
            text_provider = BaseSingleTextProvider(terms_mapper)
        else:
            text_provider = PairTextProvider(text_b_template, terms_mapper)

        label_provider = MultipleLabelProvider(SentimentLabelScaler())
        sample_rows_provider = BaseSampleRowProvider(label_provider=label_provider,
                                                     text_provider=text_provider)

        # Sample storage configuration.
        writer = TsvWriter(write_header=True)
        samples_io = SamplesIO("out/", writer, target_extension=".tsv.gz")

        def save_labels_func(data_type):
            # Labels are kept for every data type.
            return True

        def balance_func(data_type):
            # Balancing is requested for the Train data type only.
            return data_type == DataType.Train

        pipeline_item = BertExperimentInputSerializerPipelineItem(
            sample_rows_provider=sample_rows_provider,
            samples_io=samples_io,
            save_labels_func=save_labels_func,
            balance_func=balance_func)

        pipeline = BasePipeline([pipeline_item])

        #####
        # Declaring pipeline related context parameters.
        #####
        no_folding = NoFolding(doc_ids=[0, 1], supported_data_type=DataType.Train)
        doc_ops = FooDocumentOperations()

        parser_items = [BratTextEntitiesParser(), DefaultTextTokenizer(keep_tokens=True)]
        text_parser = BaseTextParser(pipeline=parser_items)

        labels_formatter = CustomLabelsFormatter(pos_label_type=Positive,
                                                 neg_label_type=Negative)
        annotator = PredefinedTextOpinionAnnotator(doc_ops, label_formatter=labels_formatter)
        distance_filter = DistanceLimitedTextOpinionFilter(terms_per_context=50)

        train_pipeline = text_opinion_extraction_pipeline(
            annotators=[annotator],
            text_opinion_filters=[distance_filter],
            get_doc_func=lambda doc_id: doc_ops.get_doc(doc_id),
            text_parser=text_parser)
        #####

        run_params = {
            "data_folding": no_folding,
            "data_type_pipelines": {DataType.Train: train_pipeline},
        }
        pipeline.run(input_data=None, params_dict=run_params)