Skip to content

Commit

Permalink
#125 switched to cropping mode.
Browse files (browse the repository at this point in the history)
  • Loading branch information
nicolay-r committed Nov 23, 2023
1 parent e787892 commit 03d61f6
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 16 deletions.
3 changes: 2 additions & 1 deletion arelight/run/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ def setup_collection_name(value):
provider_type=SampleFormattersService.name_to_type(args.text_b_type) if args.text_b_type is not None else None,
# We annotate everything with NoLabel.
label_scaler=SingleLabelScaler(NoLabel()),
entity_formatter=SharpPrefixedEntitiesSimpleFormatter()),
entity_formatter=SharpPrefixedEntitiesSimpleFormatter(),
crop_window=terms_per_context),
"samples_io": SamplesIO(target_dir=output_dir,
prefix=collection_name,
reader=JsonlReader(),
Expand Down
16 changes: 6 additions & 10 deletions arelight/samplers/bert.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from enum import Enum

from arekit.common.data.input.providers.label.multiple import MultipleLabelProvider
from arekit.common.data.input.providers.rows.samples import BaseSampleRowProvider
from arekit.common.data.input.providers.text.single import BaseSingleTextProvider
from arekit.common.entities.str_fmt import StringEntitiesFormatter
from arekit.common.labels.scaler.base import BaseLabelScaler
from arekit.contrib.bert.input.providers.text_pair import PairTextProvider
from arekit.contrib.bert.input.providers.cropped_sample import CroppedBertSampleRowProvider
from arekit.contrib.bert.terms.mapper import BertDefaultStringTextTermsMapper

from arelight.samplers.types import BertSampleProviderTypes
Expand All @@ -28,7 +25,7 @@ class BertTextBRussianPrompts(Enum):
QA = 'Что вы думаете по поводу отношения {subject} к {object} в контексте : << {context} >> ?'


def create_bert_sample_provider(provider_type, label_scaler, entity_formatter):
def create_bert_sample_provider(provider_type, label_scaler, entity_formatter, crop_window):
""" This is a factory method, which allows to instantiate any of the
supported bert_sample_encoders
"""
Expand All @@ -44,8 +41,7 @@ def create_bert_sample_provider(provider_type, label_scaler, entity_formatter):
if provider_type == BertSampleProviderTypes.QA_M:
text_b_prompt = BertTextBRussianPrompts.QA.value

text_provider = PairTextProvider(text_b_prompt=text_b_prompt, text_terms_mapper=text_terms_mapper)\
if text_b_prompt is not None else BaseSingleTextProvider(text_terms_mapper)

return BaseSampleRowProvider(label_provider=MultipleLabelProvider(label_scaler),
text_provider=text_provider)
return CroppedBertSampleRowProvider(crop_window_size=crop_window,
text_b_template=text_b_prompt,
text_terms_mapper=text_terms_mapper,
label_scaler=label_scaler)
3 changes: 2 additions & 1 deletion test/test_pipeline_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def create_sampling_params(self):
"rows_provider": create_bert_sample_provider(
label_scaler=SingleLabelScaler(NoLabel()),
provider_type=BertSampleProviderTypes.NLI_M,
entity_formatter=SharpPrefixedEntitiesSimpleFormatter()),
entity_formatter=SharpPrefixedEntitiesSimpleFormatter(),
crop_window=50),
"save_labels_func": lambda _: False,
"samples_io": SamplesIO(target_dir=utils.TEST_OUT_DIR,
reader=JsonlReader(),
Expand Down
8 changes: 5 additions & 3 deletions test/test_pipeline_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,11 @@ def test(self):
doc_provider = utils.InMemoryDocProvider(docs=BertTestSerialization.input_to_docs(texts))
pipeline = BasePipeline([AREkitSerializerPipelineItem(
rows_provider=create_bert_sample_provider(
label_scaler=SingleLabelScaler(NoLabel()),
provider_type=BertSampleProviderTypes.NLI_M,
entity_formatter=SharpPrefixedEntitiesSimpleFormatter()),
label_scaler=SingleLabelScaler(NoLabel()),
provider_type=BertSampleProviderTypes.NLI_M,
entity_formatter=SharpPrefixedEntitiesSimpleFormatter(),
crop_window=50,
),
save_labels_func=lambda _: False,
samples_io=SamplesIO(target_dir=utils.TEST_OUT_DIR, writer=NativeCsvWriter(delimiter=',')),
storage=RowCacheStorage(force_collect_columns=[
Expand Down
3 changes: 2 additions & 1 deletion test/utils_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def test_ner(texts, ner_ppl_items, prefix):
rows_provider = create_bert_sample_provider(
label_scaler=single_label_scaler,
provider_type=BertSampleProviderTypes.NLI_M,
entity_formatter=SharpPrefixedEntitiesSimpleFormatter())
entity_formatter=SharpPrefixedEntitiesSimpleFormatter(),
crop_window=50)

pipeline = BasePipeline([
BaseSerializerPipelineItem(
Expand Down

0 comments on commit 03d61f6

Please sign in to comment.