-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'remotes/origin/dev'
- Loading branch information
Showing
53 changed files
with
3,312 additions
and
769 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[flake8] | ||
max-line-length=120 | ||
ignore=D100,D101,D102,D103,D107,F403,F405 | ||
exclude=.git,__pycache__,build,dist,env |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,4 +107,7 @@ venv.bak/ | |
.idea/ | ||
|
||
#GIT | ||
.git/ | ||
.git/ | ||
|
||
#vscode | ||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
VERSION = "0.11.1" | ||
from core.log import init_logger | ||
|
||
STATE_API_VERSION = "0.12.1" | ||
|
||
init_logger() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,149 @@ | ||
from datetime import datetime | ||
from typing import Sequence, Hashable, Any, Callable, List, Dict | ||
from itertools import compress | ||
import operator | ||
import asyncio | ||
|
||
from core.state_manager import StateManager, get_state | ||
from core.skill_manager import SkillManager | ||
from models.hardcode_utterances import TG_START_UTT | ||
from core.state_schema import Dialog, Human | ||
from collections import defaultdict | ||
from time import time | ||
from typing import Any, Optional, Callable, Hashable | ||
|
||
Profile = Dict[str, Any] | ||
from core.pipeline import Pipeline | ||
from core.state_manager import StateManager | ||
from core.state_schema import Dialog | ||
from models.hardcode_utterances import TG_START_UTT | ||
|
||
|
||
class Agent: | ||
def __init__(self, state_manager: StateManager, preprocessors: List[Callable], | ||
postprocessor: Callable, | ||
skill_manager: SkillManager) -> None: | ||
def __init__(self, pipeline: Pipeline, state_manager: StateManager, | ||
process_logger_callable: Optional[Callable] = None, | ||
response_logger_callable: Optional[Callable] = None): | ||
self.workflow = dict() | ||
self.pipeline = pipeline | ||
self.state_manager = state_manager | ||
self.preprocessors = preprocessors | ||
self.postprocessor = postprocessor | ||
self.skill_manager = skill_manager | ||
|
||
def __call__(self, utterances: Sequence[str], user_telegram_ids: Sequence[Hashable], | ||
user_device_types: Sequence[Any], | ||
date_times: Sequence[datetime], locations=Sequence[Any], | ||
channel_types=Sequence[str]): | ||
should_reset = [utterance == TG_START_UTT for utterance in utterances] | ||
# here and further me stands for mongoengine | ||
me_users = self.state_manager.get_or_create_users(user_telegram_ids, user_device_types) | ||
me_dialogs = self.state_manager.get_or_create_dialogs(me_users, locations, channel_types, | ||
should_reset) | ||
self.state_manager.add_human_utterances(me_dialogs, utterances, date_times) | ||
informative_dialogs = list(compress(me_dialogs, map(operator.not_, should_reset))) | ||
self._update_annotations(informative_dialogs) | ||
|
||
selected_skills = self.skill_manager.get_skill_responses(me_dialogs) | ||
self._update_utterances(me_dialogs, selected_skills, key='selected_skills') | ||
|
||
skill_names, responses, confidences, profiles = self.skill_manager(me_dialogs) | ||
self._update_profiles(me_users, profiles) | ||
|
||
self.state_manager.add_bot_utterances(me_dialogs, responses, responses, | ||
[datetime.utcnow()] * len(me_dialogs), | ||
skill_names, confidences) | ||
|
||
sent_responses = self.postprocessor(me_dialogs) | ||
self._update_utterances(me_dialogs, sent_responses, key='text') | ||
self._update_annotations(me_dialogs) | ||
|
||
return sent_responses # return text only to the users | ||
|
||
def _update_annotations(self, me_dialogs: Sequence[Dialog]) -> None: | ||
for prep in self.preprocessors: | ||
annotations = prep(get_state(me_dialogs)) | ||
utterances = [dialog.utterances[-1] for dialog in me_dialogs] | ||
self.state_manager.add_annotations(utterances, annotations) | ||
|
||
def _update_profiles(self, me_users: Sequence[Human], profiles: List[Profile]) -> None: | ||
for me_user, profile in zip(me_users, profiles): | ||
if any(profile.values()): | ||
self.state_manager.update_user_profile(me_user, profile) | ||
|
||
def _update_utterances(self, me_dialogs: Sequence[Dialog], values: Sequence[Any], | ||
key: str) -> None: | ||
if values: | ||
utterances = [dialog.utterances[-1] for dialog in me_dialogs] | ||
for utt, val in zip(utterances, values): | ||
self.state_manager.update_me_object(utt, {key: val}) | ||
self.process_logger_callable = process_logger_callable | ||
self.response_logger_callable = response_logger_callable | ||
|
||
def add_workflow_record(self, dialog: Dialog, deadline_timestamp: Optional[float] = None, **kwargs): | ||
if str(dialog.id) in self.workflow.keys(): | ||
raise ValueError(f'dialog with id {dialog.id} is already in workflow') | ||
workflow_record = {'dialog_object': dialog, 'dialog': dialog.to_dict(), 'services': defaultdict(dict)} | ||
if deadline_timestamp: | ||
workflow_record['deadline_timestamp'] = deadline_timestamp | ||
if 'dialog_object' in kwargs: | ||
raise ValueError("'dialog_object' is system reserved workflow record field") | ||
workflow_record.update(kwargs) | ||
self.workflow[str(dialog.id)] = workflow_record | ||
|
||
def get_workflow_record(self, dialog_id): | ||
record = self.workflow.get(dialog_id, None) | ||
if not record: | ||
raise ValueError(f'dialog with id {dialog_id} is not exist in workflow') | ||
return record | ||
|
||
def flush_record(self, dialog_id: str): | ||
if dialog_id not in self.workflow.keys(): | ||
raise ValueError(f'dialog with id {dialog_id} is not exist in workflow') | ||
if self.response_logger_callable: | ||
self.response_logger_callable(self.workflow[dialog_id]) | ||
return self.workflow.pop(dialog_id) | ||
|
||
def register_service_request(self, dialog_id: str, service_name): | ||
if dialog_id not in self.workflow.keys(): | ||
raise ValueError(f'dialog with id {dialog_id} is not exist in workflow') | ||
self.workflow[dialog_id]['services'][service_name] = {'send': True, 'done': False, 'agent_send_time': time(), | ||
'agent_done_time': None} | ||
|
||
def get_services_status(self, dialog_id: str): | ||
if dialog_id not in self.workflow.keys(): | ||
raise ValueError(f'dialog with id {dialog_id} is not exist in workflow') | ||
done, waiting = set(), set() | ||
for key, value in self.workflow[dialog_id]['services'].items(): | ||
if value['done']: | ||
done.add(key) | ||
else: | ||
waiting.add(key) | ||
|
||
return done, waiting | ||
|
||
def process_service_response(self, dialog_id: str, service_name: str = None, response: Any = None, | ||
**kwargs): | ||
workflow_record = self.get_workflow_record(dialog_id) | ||
|
||
# Updating workflow with service response | ||
service = self.pipeline.get_service_by_name(service_name) | ||
if service: | ||
service_data = self.workflow[dialog_id]['services'][service_name] | ||
service_data['done'] = True | ||
service_data['agent_done_time'] = time() | ||
if response and service.state_processor_method: | ||
service.state_processor_method(dialog=workflow_record['dialog'], | ||
dialog_object=workflow_record['dialog_object'], | ||
payload=response, | ||
message_attrs=kwargs.pop('message_attrs', {})) | ||
|
||
# passing kwargs to services record | ||
if not set(service_data.keys()).intersection(set(kwargs.keys())): | ||
service_data.update(kwargs) | ||
|
||
# Flush record and return zero next services if service is is_responder | ||
if service.is_responder(): | ||
if not workflow_record.get('hold_flush'): | ||
self.flush_record(dialog_id) | ||
return [] | ||
|
||
# Calculating next steps | ||
done, waiting = self.get_services_status(dialog_id) | ||
next_services = self.pipeline.get_next_services(done, waiting) | ||
|
||
# Processing the case, when service is a skill selector | ||
if service and service.is_sselector(): | ||
selected_services = list(response.values())[0] | ||
result = [] | ||
for service in next_services: | ||
if service.name not in selected_services: | ||
self.workflow[dialog_id]['services'][service.name] = {'done': True, 'send': False, | ||
'agent_send_time': None, | ||
'agent_done_time': None} | ||
else: | ||
result.append(service) | ||
next_services = result | ||
# send dialog workflow record to further logging operations: | ||
if self.process_logger_callable: | ||
self.process_logger_callable(self.workflow['dialog_id']) | ||
|
||
return next_services | ||
|
||
async def register_msg(self, utterance: str, user_telegram_id: Hashable, | ||
user_device_type: Any, location: Any, | ||
channel_type: str, deadline_timestamp=None, | ||
require_response=False, **kwargs): | ||
user = self.state_manager.get_or_create_user(user_telegram_id, user_device_type) | ||
should_reset = True if utterance == TG_START_UTT else False | ||
dialog = self.state_manager.get_or_create_dialog(user, location, channel_type, should_reset=should_reset) | ||
dialog_id = str(dialog.id) | ||
service_name = 'input' | ||
message_attrs = kwargs.pop('message_attrs', {}) | ||
|
||
if require_response: | ||
event = asyncio.Event() | ||
kwargs['event'] = event | ||
self.add_workflow_record(dialog=dialog, deadline_timestamp=deadline_timestamp, hold_flush=True, **kwargs) | ||
self.register_service_request(dialog_id, service_name) | ||
await self.process(dialog_id, service_name, response=utterance, message_attrs=message_attrs) | ||
await event.wait() | ||
return self.flush_record(dialog_id) | ||
|
||
self.add_workflow_record(dialog=dialog, deadline_timestamp=deadline_timestamp, **kwargs) | ||
self.register_service_request(dialog_id, service_name) | ||
await self.process(dialog_id, service_name, response=utterance, message_attrs=message_attrs) | ||
|
||
async def process(self, dialog_id, service_name=None, response: Any = None, **kwargs): | ||
workflow_record = self.get_workflow_record(dialog_id) | ||
next_services = self.process_service_response(dialog_id, service_name, response, **kwargs) | ||
|
||
service_requests = [] | ||
for service in next_services: | ||
self.register_service_request(dialog_id, service.name) | ||
payload = service.apply_workflow_formatter(workflow_record) | ||
service_requests.append( | ||
service.connector_func(payload=payload, callback=self.process) | ||
) | ||
|
||
await asyncio.gather(*service_requests) |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.