
Commit

#70 implemented
nicolay-r committed Sep 2, 2023
1 parent 90488c2 commit 702b41b
Showing 2 changed files with 104 additions and 31 deletions.
117 changes: 95 additions & 22 deletions arekit_ss/text_parser/translator.py
@@ -31,24 +31,82 @@ def __init__(self, src, dest, attempts=10, timeout_for_connection_lost_sec=1.0):
         self.__attempts = attempts
         self.__timeout_for_connection_lost = timeout_for_connection_lost_sec

-    def apply_core(self, input_data, pipeline_ctx):
-        assert(isinstance(pipeline_ctx, PipelineContext))
-        assert(isinstance(input_data, list))
-
-        def __optionally_register(prts_to_join):
-            if len(prts_to_join) > 0:
-                content.append(" ".join(prts_to_join))
+    def fast_most_accurate_approach(self, input_data, entity_placeholder_template="<entityTag={}/>"):
+        """ This approach assumes that the translation won't corrupt the original
+            meta-annotation for entities and objects mentioned in text.
+        """
+
+        def __optionally_register(prts):
+            if len(prts) > 0:
+                content.append(" ".join(prts))
                 parts_to_join.clear()

-        # Check the pipeline state whether is an idle mode or not.
-        parent_ctx = pipeline_ctx.provide(PARENT_CTX)
-        idle_mode = parent_ctx.provide(IDLE_MODE)
         content = []
         origin_entities = []
         parts_to_join = []

-        # When pipeline utilized only for the assessing the expected amount
-        # of rows (common case of idle_mode), there is no need to perform
-        # translation.
-        if idle_mode:
-            return
+        for part in input_data:
+            if isinstance(part, str) and part.strip():
+                parts_to_join.append(part)
+            elif isinstance(part, Entity):
+                entity_index = len(origin_entities)
+                parts_to_join.append(entity_placeholder_template.format(entity_index))
+                # Register entities information for further restoration.
+                origin_entities.append(part)
+
+        # Register original text with masked named entities.
+        __optionally_register(parts_to_join)
+        # Register all named entities in order of their appearance in text.
+        content.extend([e.Value for e in origin_entities])
+
+        # Due to the potential opportunity of connection lost,
+        # we wrap everything in a loop with multiple attempts.
+        for attempt_index in range(self.__attempts):
+            try:
+                # Compose text parts.
+                translated_parts = [
+                    part.text for part in self.translator.translate(content, dest=self.__dest, src=self.__src)
+                ]
+
+                # Take the original text.
+                text = translated_parts[0]
+                for entity_index in range(len(origin_entities)):
+                    if entity_placeholder_template.format(entity_index) not in text:
+                        return []
+
+                # Enumerate entities.
+                from_ind = 0
+                text_parts = []
+                for entity_index, translated_value in enumerate(translated_parts[1:]):
+                    entity_placeholder_instance = entity_placeholder_template.format(entity_index)
+                    # Cropping text part.
+                    to_ind = text.index(entity_placeholder_instance)
+                    origin_entities[entity_index].set_display_value(translated_value.strip())
+                    # Register entities.
+                    text_parts.append(text[from_ind:to_ind])
+                    text_parts.append(origin_entities[entity_index])
+                    # Update from index.
+                    from_ind = to_ind + len(entity_placeholder_instance)
+
+                # Consider the remaining part.
+                text_parts.append(text[from_ind:])
+
+                return text_parts
+            except:
+                if attempt_index > 0:
+                    logger.info("Unable to perform translation. Try {} out of {}.".format(attempt_index, self.__attempts))
+                time.sleep(self.__timeout_for_connection_lost)
+        return []
+
+    def default_pre_part_splitting_approach(self, input_data):
+        """ This is the original strategy, based on the manually cropped named entities
+            before the actual translation call.
+        """
+
+        def __optionally_register(prts):
+            if len(prts) > 0:
+                content.append(" ".join(prts))
+                parts_to_join.clear()
+
+        content = []
+        origin_entities = []
@@ -68,8 +126,6 @@ def __optionally_register(prts_to_join):

         __optionally_register(parts_to_join)

-        translated_parts = []
-
         # Due to the potential opportunity of connection lost, we wrap everything in a loop with multiple attempts.
         for attempt_index in range(self.__attempts):
             try:
@@ -78,12 +134,29 @@ def __optionally_register(prts_to_join):
                                      self.translator.translate(content, dest=self.__dest, src=self.__src)]
                 for entity_ind, entity_part_ind in enumerate(origin_entity_ind):
                     entity = origin_entities[entity_ind]
-                    entity.set_display_value(translated_parts[entity_part_ind])
+                    entity.set_display_value(translated_parts[entity_part_ind].strip())
                     translated_parts[entity_part_ind] = entity
-                break
+                return translated_parts
             except:
-                logger.info("Unable to perform translation. Try {} out of {}.".format(attempt_index, self.__attempts))
+                if attempt_index > 0:
+                    logger.info("Unable to perform translation. Try {} out of {}.".format(attempt_index, self.__attempts))
                 time.sleep(self.__timeout_for_connection_lost)
-                translated_parts = []
+        return []

-        return translated_parts
+    def apply_core(self, input_data, pipeline_ctx):
+        assert(isinstance(pipeline_ctx, PipelineContext))
+        assert(isinstance(input_data, list))
+
+        # Check the pipeline state whether is an idle mode or not.
+        parent_ctx = pipeline_ctx.provide(PARENT_CTX)
+        idle_mode = parent_ctx.provide(IDLE_MODE)
+
+        # When pipeline utilized only for the assessing the expected amount
+        # of rows (common case of idle_mode), there is no need to perform
+        # translation.
+        if idle_mode:
+            return
+
+        fast_accurate = self.fast_most_accurate_approach(input_data)
+        return self.default_pre_part_splitting_approach(input_data) \
+            if len(fast_accurate) == 0 else fast_accurate
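Note on the change: fast_most_accurate_approach masks every Entity with an indexed placeholder such as <entityTag=0/>, sends the masked text plus all entity values to the translator in a single batch, and then splits the translated string back on the placeholders; if any placeholder comes back corrupted, apply_core falls back to the original per-part strategy. The sketch below is a minimal, self-contained illustration of that round trip, not the arekit code: the Entity class is a simplified stand-in, translate_batch is a stub for a googletrans-style backend, and the retry/timeout loop and the joining of adjacent text fragments are omitted for brevity.

class Entity:
    def __init__(self, value):
        self.Value = value
        self.DisplayValue = value

    def set_display_value(self, value):
        self.DisplayValue = value


def translate_batch(texts, src="ru", dest="en"):
    # Stub backend: returns the inputs unchanged so the sketch runs offline;
    # a real implementation would call a translation service here.
    return list(texts)


def fast_most_accurate(parts, template="<entityTag={}/>"):
    # Mask entities with indexed placeholders, translate one string,
    # then split the translated result back on those placeholders.
    entities, chunks = [], []
    for part in parts:
        if isinstance(part, Entity):
            chunks.append(template.format(len(entities)))
            entities.append(part)
        elif part.strip():
            chunks.append(part)
    batch = [" ".join(chunks)] + [e.Value for e in entities]
    translated = translate_batch(batch)
    text = translated[0]
    # A corrupted placeholder means entity positions cannot be restored;
    # report failure so the caller can fall back.
    if any(template.format(i) not in text for i in range(len(entities))):
        return []
    result, from_ind = [], 0
    for i, entity in enumerate(entities):
        placeholder = template.format(i)
        to_ind = text.index(placeholder)
        entity.set_display_value(translated[i + 1].strip())
        result.extend([text[from_ind:to_ind], entity])
        from_ind = to_ind + len(placeholder)
    result.append(text[from_ind:])
    return result


def default_pre_part_splitting(parts):
    # Original strategy: text fragments and entity values travel as separate
    # batch items, so entity positions survive by construction.
    batch = [p.Value if isinstance(p, Entity) else p for p in parts]
    translated = translate_batch(batch)
    out = []
    for part, text in zip(parts, translated):
        if isinstance(part, Entity):
            part.set_display_value(text.strip())
            out.append(part)
        else:
            out.append(text)
    return out


def translate_parts(parts):
    # Mirrors apply_core: prefer the fast path, fall back when it fails.
    fast = fast_most_accurate(parts)
    return default_pre_part_splitting(parts) if len(fast) == 0 else fast

For example, translate_parts(["control over the provinces went to", Entity("USA")]) returns the string chunks with the Entity object re-inserted at its original position and its DisplayValue taken from the translated batch.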
18 changes: 9 additions & 9 deletions test/test_translate.py
@@ -2,11 +2,11 @@
 import unittest

 from arekit.common.data.input.providers.const import IDLE_MODE
+from arekit.common.docs.base import Document
+from arekit.common.docs.parser import DocumentParser
+from arekit.common.docs.sentence import BaseDocumentSentence
 from arekit.common.entities.base import Entity
 from arekit.common.context.token import Token
-from arekit.common.news.sentence import BaseNewsSentence
-from arekit.common.news.base import News
-from arekit.common.news.parser import NewsParser
 from arekit.common.pipeline.context import PipelineContext
 from arekit.contrib.utils.pipelines.items.text.tokenizer import DefaultTextTokenizer
 from arekit.contrib.utils.pipelines.items.text.entities_default import TextEntitiesParser
@@ -19,10 +19,10 @@
 logging.basicConfig(level=logging.DEBUG)


-class TestTestParser(unittest.TestCase):
+class TestTextParser(unittest.TestCase):

     def test(self):
-        text = "А контроль над этими провинциями — [США] , которая не пытается ввести санкции против."
+        text = "А контроль над этими провинциями — [США] , которая не пытается ввести санкции против. [ВКC] "

         # Adopting translate pipeline item, based on google translator.
         text_parser = BaseTextParser(pipeline=[
@@ -31,10 +31,10 @@ def test(self):
             DefaultTextTokenizer(keep_tokens=True),
         ])

-        news = News(doc_id=0, sentences=[BaseNewsSentence(text.split())])
-        parsed_news = NewsParser.parse(news=news, text_parser=text_parser,
-                                       parent_ppl_ctx=PipelineContext({IDLE_MODE: False}))
-        self.debug_show_terms(parsed_news.iter_terms())
+        doc = Document(doc_id=0, sentences=[BaseDocumentSentence(text.split())])
+        parsed_doc = DocumentParser.parse(doc=doc, text_parser=text_parser,
+                                          parent_ppl_ctx=PipelineContext({IDLE_MODE: False}))
+        self.debug_show_terms(parsed_doc.iter_terms())

     @staticmethod
     def debug_show_terms(terms):
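To try the change locally, the updated test can be run on its own, e.g. with python -m unittest test.test_translate from the repository root (assuming arekit and the googletrans dependency are installed). The trailing [ВКC] mention added to the sample text presumably covers the case of an entity at the very end of the input surviving the placeholder round trip.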
