#79 done, #71 done
nicolay-r committed Sep 21, 2023
1 parent 9e7db0e commit bc65f82
Showing 6 changed files with 58 additions and 185 deletions.
51 changes: 8 additions & 43 deletions README.md
@@ -31,61 +31,26 @@ pip install git+https://github.com/nicolay-r/arelight@v0.24.0
 python standalone.py
 ```

-## Inference
-
-> **Supported Languages**: Russian
+## Usage

 Infer sentiment attitudes from mass-media document(s).

 Using the fine-tuned `BERT` model:
 ```bash
-python3 -m arelight.run.infer.py --from-files data/texts-inosmi-rus/e1.txt \
+python3 -m arelight.run.infer --from-files data/texts-inosmi-rus/e1.txt \
     --ner-model-name "ner_ontonotes_bert_mult" \
     --ner-types "ORG|PERSON|LOC|GPE" \
     --synonyms data/synonyms.txt \
     --labels-count 3 \
     --terms-per-context 50 \
-    --sentence-parser "ru" \
-    --text-b-type "nli_m" \
     --tokens-per-context 128 \
+    --text-b-type nli_m \
+    --sentence-parser ru \
+    --pretrained-bert "bert-base-uncased" \
     -o output/brat_inference_output
 ```
-From `CSV` file (you need to have `text` column; sentence parser could be disabled):
-```bash
-python3 -m arelight.run.infer.py \
-    --from-dataframe data/examples.csv \
-    --entities-parser bert-ontonotes \
-    --terms-per-context 50 \
-    --sentence-parser ru \
-    -o output/data
-```
 <p align="center">
     <img src="docs/inference-bert-e1.png"/>
 </p>

-## Serialization
-
-> **Supported Languages**: Russian/English
-From a list of files:
-```bash
-python3 -m arelight.run.serialize --from-files data/texts-inosmi-rus/e1.txt \
-    --entities-parser bert-ontonotes \
-    --terms-per-context 50 \
-    --sentence-parser ru \
-    -o output/e1
-```
-From `CSV` file (you need to have `text` column; sentence parser could be disabled):
-```bash
-python3 -m arelight.run.serialize \
-    --from-dataframe data/examples.csv \
-    --entities-parser bert-ontonotes \
-    --terms-per-context 50 \
-    --sentence-parser ru \
-    -o output/data
-```
-
-<p align="center">
-    <img src="docs/samples-bert.png">
-</p>

 ## Reference

 * [Nicolay Rusnachenko: Language Models Application in Sentiment Attitude Extraction Task (2021) [RUS]](https://nicolay-r.github.io/website/data/rusnachenko2021language.pdf)
29 changes: 24 additions & 5 deletions arelight/pipelines/demo/infer_bert.py
@@ -6,9 +6,11 @@
 from arekit.contrib.utils.data.storages.row_cache import RowCacheStorage
 from arekit.contrib.utils.data.writers.csv_native import NativeCsvWriter
 from arekit.contrib.utils.io_utils.samples import SamplesIO
+from arekit.contrib.utils.pipelines.items.sampling.base import BaseSerializerPipelineItem
 from arekit.contrib.utils.pipelines.items.sampling.bert import BertExperimentInputSerializerPipelineItem

 from arelight.pipelines.demo.labels.base import PositiveLabel, NegativeLabel
+from arelight.pipelines.items.backend_brat_html import BratHtmlEmbeddingPipelineItem
 from arelight.pipelines.items.backend_brat_json import BratBackendContentsPipelineItem
 from arelight.pipelines.items.inference_bert import BertInferencePipelineItem
 from arelight.predict_writer_csv import TsvPredictWriter
@@ -23,8 +25,10 @@ def demo_infer_texts_bert_pipeline(pretrained_bert,
                                    labels_scaler,
                                    bert_config_path=None,
                                    bert_vocab_path=None,
+                                   brat_backend=False,
                                    text_b_type=SampleFormattersService.name_to_type("nli_m"),
                                    max_seq_length=128):
+    assert(isinstance(pretrained_bert, str) or pretrained_bert is None)
     assert(isinstance(samples_output_dir, str))
     assert(isinstance(samples_prefix, str))
     assert(isinstance(labels_scaler, BaseLabelScaler))
@@ -34,9 +38,9 @@
         prefix=samples_prefix,
         writer=NativeCsvWriter(delimiter=','))

-    pipeline = BasePipeline(pipeline=[
-
-        BertExperimentInputSerializerPipelineItem(
+    # Serialization by default in the pipeline.
+    pipeline = [
+        BaseSerializerPipelineItem(
             rows_provider=create_bert_sample_provider(
                 provider_type=text_b_type,
                 label_scaler=labels_scaler,
@@ -46,8 +50,14 @@
                 # These additional columns required for BRAT visualization.
                 const.ENTITIES, const.ENTITY_VALUES, const.ENTITY_TYPES, const.SENT_IND
             ]),
-            save_labels_func=lambda data_type: data_type != DataType.Test),
+            save_labels_func=lambda data_type: data_type != DataType.Test)
+    ]
+
+    if pretrained_bert is None:
+        return pipeline
+
+    # Add BERT processing pipeline.
+    pipeline += [
         BertInferencePipelineItem(
             pretrained_bert=pretrained_bert,
             data_type=DataType.Test,
@@ -57,14 +67,23 @@
             vocab_filepath=bert_vocab_path,
             max_seq_length=max_seq_length,
             labels_count=labels_scaler.LabelsCount),
+    ]
+
+    if not brat_backend:
+        return pipeline
+
+    pipeline += [
         BratBackendContentsPipelineItem(label_to_rel={
             str(labels_scaler.label_to_uint(PositiveLabel())): "POS",
             str(labels_scaler.label_to_uint(NegativeLabel())): "NEG"
             },
             obj_color_types={"ORG": '#7fa2ff', "GPE": "#7fa200", "PERSON": "#7f00ff", "Frame": "#00a2ff"},
             rel_color_types={"POS": "GREEN", "NEG": "RED"},
         )
-    ])
+    ]
+
+    pipeline += [
+        BratHtmlEmbeddingPipelineItem(brat_url="http://localhost:8001/")
+    ]

     return pipeline
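
The refactoring above changes `demo_infer_texts_bert_pipeline` from returning a ready `BasePipeline` to returning a plain list of pipeline items, truncated early when optional stages are switched off: no `pretrained_bert` means serialization only, and `brat_backend=False` skips the BRAT visualization stages. A minimal sketch of this composition pattern (stage names are illustrative stand-ins, not the real arelight items):

```python
# Hypothetical mini-version of the builder; only the early-return
# composition pattern mirrors the diff above.
def build_stages(pretrained_model=None, brat_backend=False):
    stages = ["serialize"]           # serialization always comes first
    if pretrained_model is None:     # no model -> serialization-only pipeline
        return stages
    stages.append("bert-inference")  # inference needs a model
    if not brat_backend:
        return stages
    return stages + ["brat-json", "brat-html"]

assert build_stages() == ["serialize"]
assert build_stages("bert-base-uncased") == ["serialize", "bert-inference"]
assert build_stages("bert-base-uncased", brat_backend=True) == [
    "serialize", "bert-inference", "brat-json", "brat-html"]
```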
5 changes: 5 additions & 0 deletions arelight/predict_writer_csv.py
@@ -1,9 +1,13 @@
 import gzip
+import logging

 from arekit.common.utils import progress_bar_iter, create_dir_if_not_exists

 from arelight.predict_writer import BasePredictWriter

+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+

 class TsvPredictWriter(BasePredictWriter):

@@ -34,6 +38,7 @@ def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
+        logger.info(f"Saved: {self._target}")
         self.__f.close()

     # endregion
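
The writer change is small: the gzip-backed TSV writer now logs where results were saved when its context manager exits. A self-contained sketch of the same pattern, using a hypothetical stand-in class rather than the real `TsvPredictWriter`:

```python
import gzip
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class GzipTsvWriter:
    """Hypothetical stand-in illustrating the log-on-__exit__ pattern."""

    def __init__(self, target):
        self._target = target
        self._f = None

    def __enter__(self):
        self._f = gzip.open(self._target, "wt", encoding="utf-8")
        return self

    def write_row(self, values):
        self._f.write("\t".join(map(str, values)) + "\n")

    def __exit__(self, exc_type, exc_val, exc_tb):
        logger.info(f"Saved: {self._target}")  # report the output location
        self._f.close()


with GzipTsvWriter("predictions.tsv.gz") as w:
    w.write_row(["doc_id", "label"])
```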
31 changes: 18 additions & 13 deletions arelight/run/infer.py
@@ -3,6 +3,7 @@
 from arekit.common.docs.entities_grouping import EntitiesGroupingPipelineItem
 from arekit.common.experiment.data_type import DataType
+from arekit.common.pipeline.base import BasePipeline
 from arekit.common.synonyms.grouping import SynonymsCollectionValuesGroupingProviders
 from arekit.common.text.parser import BaseTextParser
 from arekit.contrib.utils.pipelines.items.text.terms_splitter import TermsSplitterParser
@@ -11,7 +12,6 @@
 from arelight.doc_provider import InMemoryDocProvider
 from arelight.pipelines.data.annot_pairs_nolabel import create_neutral_annotation_pipeline
 from arelight.pipelines.demo.infer_bert import demo_infer_texts_bert_pipeline
-from arelight.pipelines.items.backend_brat_html import BratHtmlEmbeddingPipelineItem
 from arelight.pipelines.items.id_assigner import IdAssigner
 from arelight.pipelines.items.utils import input_to_docs
 from arelight.run.args import common, const
@@ -36,7 +36,7 @@
 common.PredictOutputFilepathArg.add_argument(parser, default=const.OUTPUT_TEMPLATE)
 common.NERModelNameArg.add_argument(parser, default="ner_ontonotes_bert_mult")
 common.NERObjectTypes.add_argument(parser, default="ORG|PERSON|LOC|GPE")
-common.PretrainedBERTArg.add_argument(parser, default="bert-base-uncased")
+common.PretrainedBERTArg.add_argument(parser, default=None)
 common.SentenceParserArg.add_argument(parser)
 common.BertConfigFilepathArg.add_argument(parser, default=None)
 common.BertVocabFilepathArg.add_argument(parser, default=None)
@@ -58,6 +58,7 @@
 backend_template = common.PredictOutputFilepathArg.read_argument(args)
 pretrained_bert = common.PretrainedBERTArg.read_argument(args)

+# Setup main pipeline.
 pipeline = demo_infer_texts_bert_pipeline(
     pretrained_bert=pretrained_bert,
     samples_output_dir=dirname(backend_template),
@@ -66,12 +67,16 @@
     labels_scaler=create_labels_scaler(common.LabelsCountArg.read_argument(args)),
     bert_config_path=common.BertConfigFilepathArg.read_argument(args),
     bert_vocab_path=common.BertVocabFilepathArg.read_argument(args),
-    max_seq_length=common.TokensPerContextArg.read_argument(args))
+    max_seq_length=common.TokensPerContextArg.read_argument(args),
+    brat_backend=True)
+
+pipeline = BasePipeline(pipeline)

 synonyms_collection_path = common.SynonymsCollectionFilepathArg.read_argument(args)
 synonyms = read_synonyms_collection(synonyms_collection_path) if synonyms_collection_path is not None else \
     SimpleSynonymCollection(iter_group_values_lists=[], is_read_only=False)

+# Setup text parser.
 text_parser = BaseTextParser(pipeline=[
     TermsSplitterParser(),
     create_entity_parser(ner_model_name=ner_model_name,
@@ -82,21 +87,21 @@
                          synonyms=synonyms, value=value))
 ])

+# Setup data annotation pipeline.
 data_pipeline = create_neutral_annotation_pipeline(
     synonyms=synonyms,
     dist_in_terms_bound=terms_per_context,
     doc_ops=InMemoryDocProvider(docs=input_to_docs(actual_content, sentence_parser=sentence_parser)),
     terms_per_context=terms_per_context,
     text_parser=text_parser)

-pipeline.append(
-    BratHtmlEmbeddingPipelineItem(brat_url="http://localhost:8001/")
-)
-
-pipeline.run(None, {
-    "template_filepath": join(const.DATA_DIR, "brat_template.html"),
-    "predict_fp": "{}.tsv.gz".format(backend_template) if backend_template is not None else None,
-    "brat_vis_fp": "{}.html".format(backend_template) if backend_template is not None else None,
-    "data_type_pipelines": {DataType.Test: data_pipeline},
-    "doc_ids": list(range(len(actual_content))),
+# Launch application.
+pipeline.run(
+    input_data=None,
+    params_dict={
+        "template_filepath": join(const.DATA_DIR, "brat_template.html"),
+        "predict_fp": "{}.tsv.gz".format(backend_template) if backend_template is not None else None,
+        "brat_vis_fp": "{}.html".format(backend_template) if backend_template is not None else None,
+        "data_type_pipelines": {DataType.Test: data_pipeline},
+        "doc_ids": list(range(len(actual_content)))
     })
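
The launch call now names its arguments, passing `input_data=None` and a `params_dict` shared by all stages. A toy runner sketching that calling convention (the real `BasePipeline` lives in arekit and is not shown in this commit):

```python
# Toy pipeline runner; illustrates run(input_data, params_dict) only.
class ToyPipeline:
    def __init__(self, items):
        self._items = items

    def run(self, input_data, params_dict):
        data = input_data
        for item in self._items:
            data = item(data, params_dict)  # every stage sees the same params
        return data


pipeline = ToyPipeline([
    lambda data, params: list(params["doc_ids"]),  # first stage ignores input
    lambda data, params: [i * 10 for i in data],   # next stage transforms it
])
print(pipeline.run(input_data=None, params_dict={"doc_ids": range(3)}))  # [0, 10, 20]
```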
120 changes: 0 additions & 120 deletions arelight/run/serialize.py

This file was deleted.

