From b76e83902f22cf12790a311dcd9b6b86ea21fa48 Mon Sep 17 00:00:00 2001 From: thanawan-atc <106889996+thanawan-atc@users.noreply.github.com> Date: Wed, 6 Jul 2022 13:15:22 -0700 Subject: [PATCH 1/6] new fresh copy + rich logging --- __init__.py | 0 chirpy/core/logging_formatting.py | 85 ++++++-- chirpy/core/logging_rich.py | 331 ++++++++++++++++++++++++++++++ chirpy/core/logging_utils.py | 43 +++- env.list | 21 ++ 5 files changed, 449 insertions(+), 31 deletions(-) create mode 100644 __init__.py create mode 100644 chirpy/core/logging_rich.py create mode 100644 env.list diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chirpy/core/logging_formatting.py b/chirpy/core/logging_formatting.py index 087bc32..b89df7d 100644 --- a/chirpy/core/logging_formatting.py +++ b/chirpy/core/logging_formatting.py @@ -2,23 +2,52 @@ from typing import Optional from colorama import Fore, Back +from pathlib import Path + +def get_active_branch_name(): + git_dir = Path(".") / ".git" + if git_dir.is_dir(): + head_dir = git_dir / "HEAD" + with head_dir.open("r") as f: content = f.read().splitlines() + + for line in content: + if line[0:4] == "ref:": + return line.partition("refs/heads/")[2] + else: # for integ. testing, we don't copy .git/ to the instance + return [] + LINEBREAK = '' -# The key in this dict must match the 'name' given to the component in baseline_bot.py (case-sensitive) -# The path_strings are strings we'll search for (case-sensitive) in the path of the file that does the log message +# The key in this dict must match the 'name' given to the component in baseline_bot.py (case-insensitive) +# The path_strings are strings we'll search for (case-insensitive) in the path of the file that does the log message # You can comment out parts of this dict and add your own components to make it easier to only see what you're working on +# See https://rich.readthedocs.io/en/stable/appendix/colors.html for list of rich colors +# See https://github.com/willmcgugan/rich/blob/master/rich/_emoji_codes.py for emoji codes COLOR_SETTINGS = { - 'WIKI': {'color': Fore.MAGENTA, 'path_strings': ['wiki']}, - 'MOVIES': {'color': Fore.GREEN, 'path_strings': ['movies']}, - # 'NEWS': {'color': Fore.CYAN, 'path_strings': ['news']}, - 'ACKNOWLEDGMENT': {'color': Fore.CYAN, 'path_strings': ['acknowledgment']}, - # 'LAUNCH': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['launch']}, - 'CATEGORIES': {'color': Fore.YELLOW, 'path_strings': ['categories']}, - 'NEURAL_CHAT': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['neural_chat']}, - 'entity_linker': {'color': Fore.LIGHTCYAN_EX, 'path_strings': ['entity_linker']}, - 'entity_tracker': {'color': Fore.LIGHTYELLOW_EX, 'path_strings': ['entity_tracker']}, - 'experiments': {'color': Fore.LIGHTGREEN_EX, 'path_strings': ['experiments']}, - 'navigational_intent': {'color': Fore.LIGHTMAGENTA_EX, 'path_strings': ['navigational_intent']} + 'ACKNOWLEDGMENT': {'color': Fore.CYAN, 'rich_color': '#0AAB42', + 'emoji': ':white_heavy_check_mark:', 'path_strings': ['acknowledgment']}, + 'ALEXA_COMMANDS': {'emoji': ':speaking_head_in_silhouette:', 'path_strings': ['alexa_commands']}, + 'ALIENS': {'rich_color': '#1EA8B3', 'emoji': ':alien:'}, + 'CATEGORIES': {'rich_color': '#15EBCE', 'emoji': ':newspaper:', 'path_strings': ['categories']}, + 'CORONAVIRUS': {'rich_color': '#F70C6E', 'emoji': ':face_with_medical_mask:'}, + 'FOOD': {'rich_color': '#97F20F', 'emoji': ':sushi:'}, + 'LAUNCH': {'emoji': ':checkered_flag:', 'path_strings': ['launch']}, + 'MOVIES': {'rich_color': '#F0D718', 'emoji': ':movie_camera:', 'path_strings': ['movies']}, + 'MUSIC': {'rich_color': '#0586FF', 'emoji': ':musical_notes:'}, + 'NEURAL_CHAT': {'rich_color': '#0EE827', 'emoji': ':brain:', 'path_strings': ['neural_chat']}, + 'NEWS': {'rich_color': '#1C64FF', 'emoji': ':newspaper:'}, + 'OFFENSIVE_USER': {'rich_color': '#EB5215', 'emoji': ':prohibited:'}, + 'ONE_TURN_HACK': {'rich_color': '#88B0B3', 'emoji': ':hammer:'}, + 'OPINION': {'rich_color': '#D011ED', 'emoji': ':thinking_face:'}, + 'PERSONAL_ISSUES': {'rich_color': '#BC3BEB', 'emoji': ':slightly_frowning_face:'}, + 'SPORTS': {'rich_color': '#EB8715', 'emoji': ':football:'}, + 'WIKI': {'rich_color': '#42C2F5', 'emoji': ':books:'}, + 'TRANSITION': {'rich_color': '#5FD700', 'emoji': ':soon_arrow:'}, + 'REOPEN': {'rich_color': '##5F00FF', 'emoji': ':door:'}, + 'entity_linker': {'color': Fore.LIGHTCYAN_EX, 'rich_color': '#0BC3E3', 'path_strings': ['entity_linker']}, + 'entity_tracker': {'color': Fore.LIGHTYELLOW_EX, 'rich_color': '#DB960D', 'path_strings': ['entity_tracker']}, + 'experiments': {'color': Fore.LIGHTGREEN_EX, 'rich_color': '#CADB0D', 'path_strings': ['experiments']}, + 'navigational_intent': {'color': Fore.LIGHTMAGENTA_EX, 'rich_color': '#DB0D93', 'path_strings': ['navigational_intent']} } LOG_FORMAT = '[%(levelname)s] [%(asctime)s] [fn_vers: {function_version}] [session_id: {session_id}] [%(pathname)s:%(lineno)d]\n%(message)s\n' @@ -33,8 +62,14 @@ def colored(str, fore=None, back=None, include_reset=True): new_str = '{}{}{}'.format(back, new_str, Back.RESET if include_reset else '') return new_str +def get_rich_color_for_rg(rg_name): + for component_name, settings in COLOR_SETTINGS.items(): + if component_name.lower() == rg_name.lower() and settings.get('rich_color'): + color = settings['rich_color'] + return f"[{color}]{rg_name}[/{color}]" + return rg_name -def get_line_color(line): +def get_line_color(line, branch_name): """ Given a line of logging (which is one line of a multiline log message), searches for component names at the beginning of the line. If one is found, returns its color. @@ -42,7 +77,9 @@ def get_line_color(line): first_part_line = line.strip().split()[0] for component_name, settings in COLOR_SETTINGS.items(): if component_name in first_part_line: - return settings['color'] + return settings.get('color') + if any(b.lower() in first_part_line.lower() for b in branch_name): + return Fore.BLUE return None @@ -62,7 +99,7 @@ def get_line_key(idx: int): class ChirpyFormatter(logging.Formatter): """ - A custom formatter that formats linebreaks and color according to logger_settings, and the context of each message. + A color formatter that formats linebreaks and color according to logger_settings, and the context of each message. Based on this: https://stackoverflow.com/a/14859558 """ @@ -72,6 +109,10 @@ def __init__(self, allow_multiline: bool, use_color: bool, session_id: Optional[ self.use_color = use_color self.session_id = session_id self.function_version = function_version + if self.use_color: + branch_name = get_active_branch_name() + branch_name = ''.join([x if x.isalpha() else ' ' for x in branch_name]) + self.branch_name = branch_name.split() self.update_format() def update_format(self): @@ -137,15 +178,19 @@ def format_color(self, record): lines = record.msg.split('\n') for idx, line in enumerate(lines): setattr(record, get_line_key(idx), line) # e.g. record['line_5'] -> the text of the 5th line of logging - line_colors = [get_line_color(line) for line in lines] # get the color for each line + line_colors = [get_line_color(line, self.branch_name) for line in lines] # get the color for each line self._style._fmt = self.fmt.replace('%(message)s', linecolored_msg_fmt(line_colors)) # this format string has keys for line_1, line_2, etc, along with line-specific colors # If the filepath of the calling function contains a path string for a colored component, return its color else: for component, settings in COLOR_SETTINGS.items(): - for path_string in settings['path_strings']: - if path_string in record.pathname: - self._style._fmt = colored(self.fmt, fore=settings['color']) + if settings.get('path_strings'): + for path_string in settings['path_strings']: + if path_string in record.pathname: + self._style._fmt = colored(self.fmt, fore=settings['color']) + continue + if any(b in record.pathname for b in self.branch_name): + self._style._fmt = colored(self.fmt, fore=Fore.BLUE) # Use the formatter class to do the formatting (with a possibly modified format) result = logging.Formatter.format(self, record) diff --git a/chirpy/core/logging_rich.py b/chirpy/core/logging_rich.py new file mode 100644 index 0000000..4557a64 --- /dev/null +++ b/chirpy/core/logging_rich.py @@ -0,0 +1,331 @@ +import logging +from datetime import datetime +from logging import Handler, LogRecord +from pathlib import Path +from collections import Iterable +from typing import ClassVar, Iterable, List, Optional, Type, TYPE_CHECKING, Union, Callable +import os +import rich + +from rich import get_console +from rich._log_render import LogRender, FormatTimeCallable +from rich.containers import Renderables +from rich.console import Console, ConsoleRenderable, RenderableType +from rich.highlighter import Highlighter, ReprHighlighter +from rich.text import Text, TextType +from rich.traceback import Traceback +from rich.logging import RichHandler + +from rich.table import Table + +from chirpy.core.logging_formatting import COLOR_SETTINGS + +PATH_WIDTH = 25 + +LEVEL_STYLES = {"primary_info": "dim", + "error": "bold red on bright_yellow"} + +LEVEL_LINE_COLORS = {"error": "red"} + +COBOT_HOME = os.environ.get('COBOT_HOME', os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +def get_rich_color(text): + for component_name, settings in COLOR_SETTINGS.items(): + if component_name.lower() in text.lower(): + return settings.get('rich_color') + if 'path_strings' in settings: + for path_string in settings['path_strings']: + if path_string.lower() in text.lower(): + return settings.get('rich_color') + return None + +def add_emoji(text, check_text = None): + if not check_text: + check_text = text + for component_name, settings in COLOR_SETTINGS.items(): + if component_name in check_text and settings.get('emoji'): + return settings['emoji'] + ' ' + text + if 'path_strings' in settings: + for path_string in settings['path_strings']: + if path_string in check_text and settings.get('emoji'): + return settings['emoji'] + ' ' + text + return text + +class ChirpyLogRender(LogRender): + def __call__( + self, + console: "Console", + renderables: Iterable["ConsoleRenderable"], + log_time: datetime = None, + time_format: Union[str, FormatTimeCallable] = None, + level: TextType = "", + path: str = None, + line_no: int = None, + link_path: str = None, + path_color: str = None, + ) -> "Table": + output = Table.grid(padding=(0, 1)) + output.expand = True + if self.show_level: + output.add_column(width=self.level_width) + if self.show_path and path: + output.add_column(width=PATH_WIDTH) #style="dim", + output.add_column(ratio=1, style="log.message", overflow="fold") + if self.show_time: + output.add_column(style="log.time") + row: List["RenderableType"] = [] + if self.show_level: + row.append(level[:3]) + if self.show_path and path: + path_text = Text.from_markup(path) + if line_no: + if len(path_text) > PATH_WIDTH - len(str(line_no)) - 2: + path_text.truncate(PATH_WIDTH - len(str(line_no)) - 2) + path_text.append("…") + path_text.append(f":{line_no}") + path_text.stylize(path_color) + row.append(path_text) + + row.append(Renderables(renderables)) + if self.show_time: + log_time = log_time or console.get_datetime() + time_format = time_format or self.time_format + if callable(time_format): + log_time_display = time_format(log_time) + else: + log_time_display = Text(log_time.strftime(time_format)[:-4] + ']') + if log_time_display == self._last_time and self.omit_repeated_times: + row.append(Text(" " * len(log_time_display))) + else: + row.append(log_time_display) + self._last_time = log_time_display + + output.add_row(*row) + return output + + +class ChirpyHandler(RichHandler): + DICT_OPEN_TAG: str = "[dict]\n" + DICT_CLOSE_TAG: str = "\n[/dict]" + + def __init__( + self, + level: Union[int, str] = logging.NOTSET, + console: Console = None, + *, + show_time: bool = True, + omit_repeated_times: bool = True, + show_level: bool = True, + show_path: bool = True, + enable_link_path: bool = True, + highlighter: Highlighter = None, + markup: bool = False, + rich_tracebacks: bool = False, + tracebacks_width: Optional[int] = None, + tracebacks_extra_lines: int = 3, + tracebacks_theme: Optional[str] = None, + tracebacks_word_wrap: bool = True, + tracebacks_show_locals: bool = False, + locals_max_length: int = 10, + locals_max_string: int = 80, + log_time_format: Union[str, FormatTimeCallable] = "[%x %X]", + filter_by_rg: str = None, + disable_annotation: bool = False, + ) -> None: + super().__init__( + level=level, + console=console, + show_time=show_time, + omit_repeated_times=omit_repeated_times, + show_level=show_level, + show_path=show_path, + enable_link_path=enable_link_path, + highlighter=highlighter, + markup=markup, + rich_tracebacks=rich_tracebacks, + tracebacks_width=tracebacks_width, + tracebacks_extra_lines=tracebacks_extra_lines, + tracebacks_theme=tracebacks_theme, + tracebacks_word_wrap=tracebacks_word_wrap, + tracebacks_show_locals=tracebacks_show_locals, + locals_max_length=locals_max_length, + locals_max_string=locals_max_string, + log_time_format=log_time_format, + ) + self._log_render = ChirpyLogRender( + show_time=show_time, + show_level=show_level, + show_path=show_path, + time_format=log_time_format, + omit_repeated_times=omit_repeated_times, + level_width=None, + ) + if filter_by_rg: + valid_rg_filenames = [f.name.lower() for f in os.scandir(os.path.join(COBOT_HOME, "chirpy/response_generators")) if f.is_dir()] + filter_by_rg = filter_by_rg.lower() + assert filter_by_rg in valid_rg_filenames, f"{filter_by_rg} does not specify a valid RG filename (must be a folder in chirpy/response_generators)" + self.filter_by_rg = filter_by_rg + self.disable_annotation = disable_annotation + + def process_dictionary(self, dict_text: str) -> "ConsoleRenderable": + lines = dict_text.split('\n') + pairs = [line.split('\u00a0' * 5) for line in lines] + assert all(len(p) == 2 for p in pairs) + grid = Table.grid(expand=True, padding=(0, 3)) + grid.add_column(justify="left", width=25) + grid.add_column(ratio=1) + for pair in pairs: + pair = [p.strip().strip("'") for p in pair] + name, value = pair + text_color = get_rich_color(name) + name = add_emoji(name) + if text_color: + grid.add_row(Text.from_markup(name, style=text_color), value) + else: + grid.add_row(name, value) + return grid + + def render_message(self, record: LogRecord, message: str) -> List["ConsoleRenderable"]: + """Render message text in to Text. + + record (LogRecord): logging Record. + message (str): String cotaining log message. + + Returns: + ConsoleRenderable: Renderable to display log message. + """ + use_markup = ( + getattr(record, "markup") if hasattr(record, "markup") else self.markup + ) + message_texts = [] + if record.levelname.lower() in LEVEL_LINE_COLORS: + color = LEVEL_LINE_COLORS[record.levelname.lower()] + message = "[" + color + "]" + message + message = message.replace('\n', "[/" + color + "]\n", 1) + if self.DICT_OPEN_TAG in message: + start = message.find(self.DICT_OPEN_TAG) + end = message.find(self.DICT_CLOSE_TAG) + dict_text = message[start + len(self.DICT_OPEN_TAG):end] + message_one = message[:start] + message_two = message[end + len(self.DICT_CLOSE_TAG):] + text_color = None + message_texts.append(Text.from_markup(message_one) if use_markup else Text(message_one)) + message_texts.append(self.process_dictionary(dict_text)) + message_texts.append(Text.from_markup(message_two) if use_markup else Text(message_two)) + else: + message_texts.append(Text.from_markup(message) if use_markup else Text(message)) + for message_text in message_texts: + if isinstance(message_text, Text): + if self.highlighter: + message_text = self.highlighter(message_text) + if self.KEYWORDS: + message_text.highlight_words(self.KEYWORDS, "logging.keyword") + return message_texts + + def get_level_text(self, record: LogRecord) -> Text: + """Get the level name from the record. + + Args: + record (LogRecord): LogRecord instance. + + Returns: + Text: A tuple of the style and level name. + """ + level_name = record.levelname + level_text = Text.styled( + level_name[:3].ljust(8).capitalize(), f"logging.level.{level_name.lower()}" + ) + if level_name.lower() in LEVEL_STYLES: + level_text.stylize(LEVEL_STYLES[level_name.lower()]) + return level_text + + def emit(self, record: LogRecord) -> None: + """Invoked by logging.""" + message = self.format(record) + traceback = None + if ( + self.rich_tracebacks + and record.exc_info + and record.exc_info != (None, None, None) + ): + exc_type, exc_value, exc_traceback = record.exc_info + assert exc_type is not None + assert exc_value is not None + traceback = Traceback.from_exception( + exc_type, + exc_value, + exc_traceback, + width=self.tracebacks_width, + extra_lines=self.tracebacks_extra_lines, + theme=self.tracebacks_theme, + word_wrap=self.tracebacks_word_wrap, + show_locals=self.tracebacks_show_locals, + locals_max_length=self.locals_max_length, + locals_max_string=self.locals_max_string, + ) + message = record.getMessage() + if self.formatter: + record.message = record.getMessage() + formatter = self.formatter + if hasattr(formatter, "usesTime") and formatter.usesTime(): # type: ignore + record.asctime = formatter.formatTime(record, formatter.datefmt) + message = formatter.formatMessage(record) + + if self.should_show(record): + message_renderable = self.render_message(record, message) + log_renderable = self.render( + record=record, traceback=traceback, message_renderable=message_renderable + ) + self.console.print(log_renderable) + + def should_show(self, record): + if record.levelname.lower() in ['error', 'warning']: + return True + path = record.pathname + if 'response_generators' not in path.lower(): + return not self.disable_annotation + if self.filter_by_rg is None: + return True + return self.filter_by_rg in path.lower() + + def render( + self, + *, + record: LogRecord, + traceback: Optional[Traceback], + message_renderable: "ConsoleRenderable", + ) -> "ConsoleRenderable": + """Render log for display. + + Args: + record (LogRecord): logging Record. + traceback (Optional[Traceback]): Traceback instance or None for no Traceback. + message_renderable (ConsoleRenderable): Renderable (typically Text) containing log message contents. + + Returns: + ConsoleRenderable: Renderable to display log. + """ + path_color = get_rich_color(record.pathname) + path = Path(record.pathname).name + path = add_emoji(path, record.pathname) + if record.levelname.lower() in LEVEL_LINE_COLORS: + path_color = LEVEL_LINE_COLORS[record.levelname.lower()] + level = self.get_level_text(record) + time_format = None if self.formatter is None else self.formatter.datefmt + log_time = datetime.fromtimestamp(record.created) + + if traceback: + message_renderable.append(traceback) + + log_renderable = self._log_render( + self.console, + message_renderable, + log_time=log_time, + time_format=time_format, + level=level, + path=path, + line_no=record.lineno, + link_path=record.pathname if self.enable_link_path else None, + path_color=path_color, + ) + return log_renderable \ No newline at end of file diff --git a/chirpy/core/logging_utils.py b/chirpy/core/logging_utils.py index e460eb7..e8daabe 100644 --- a/chirpy/core/logging_utils.py +++ b/chirpy/core/logging_utils.py @@ -1,5 +1,6 @@ """ -This file contains functions to create and configure the chirpylogger +This file contains functions to create and configure the chirpylogger, which is a single simple logger to replace +the more complicated LoggerFactory that came with Cobot. """ import logging @@ -8,6 +9,8 @@ from dataclasses import dataclass from typing import Optional from chirpy.core.logging_formatting import ChirpyFormatter +from chirpy.core.logging_rich import ChirpyHandler +from rich.highlighter import NullHighlighter PRIMARY_INFO_NUM = logging.INFO + 5 # between INFO and WARNING @@ -22,18 +25,24 @@ class LoggerSettings: logtoscreen_allow_multiline: bool # If true, log-to-screen messages contain \n. If false, all the \n are replaced with integ_test: bool # If True, we setup the logger in a special way to work with nosetests remove_root_handlers: bool # If True, we remove all other handlers on the root logger + allow_rich_formatting: bool = True + filter_by_rg: str = None + disable_annotation: bool = False # AWS adds a LambdaLoggerHandler to the root handler, which causes duplicate logging because we have our customized # StreamHandler on the root logger too. So we set remove_root_handlers=True to remove the LambdaLoggerHandler. # See here: https://stackoverflow.com/questions/50909824/getting-logs-twice-in-aws-lambda-function -PROD_LOGGER_SETTINGS = LoggerSettings(logtoscreen_level=logging.INFO, +PROD_LOGGER_SETTINGS = LoggerSettings(logtoscreen_level=logging.DEBUG, logtoscreen_usecolor=True, logtofile_level=None, logtofile_path=None, - logtoscreen_allow_multiline=False, + logtoscreen_allow_multiline=True, integ_test=False, - remove_root_handlers=True) + remove_root_handlers=True, + allow_rich_formatting=True, + filter_by_rg=None, + disable_annotation=False) def setup_logger(logger_settings, session_id=None): @@ -85,11 +94,23 @@ def setup_logger(logger_settings, session_id=None): chirpy_logger.setLevel(logging.DEBUG) # Create the stream handler and attach it to the root logger - stream_handler = logging.StreamHandler(sys.stdout) - stream_handler.setLevel(logger_settings.logtoscreen_level) - stream_formatter = ChirpyFormatter(allow_multiline=logger_settings.logtoscreen_allow_multiline, use_color=logger_settings.logtoscreen_usecolor, session_id=session_id) - stream_handler.setFormatter(stream_formatter) - root_logger.addHandler(stream_handler) + print("allow_multiline = ", logger_settings.logtoscreen_allow_multiline ) + print("rich formatting = ", logger_settings.allow_rich_formatting) + if logger_settings.logtoscreen_allow_multiline and logger_settings.allow_rich_formatting: + root_logger.addHandler(ChirpyHandler(log_time_format="[%H:%M:%S.%f]", + level=logger_settings.logtoscreen_level, + markup=True, + highlighter=NullHighlighter(), + filter_by_rg=logger_settings.filter_by_rg, + disable_annotation=logger_settings.disable_annotation)) + else: + # Use the stream handler if no multi-line to not mess up production logs + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setLevel(logger_settings.logtoscreen_level) + stream_formatter = ChirpyFormatter(allow_multiline=logger_settings.logtoscreen_allow_multiline, use_color=logger_settings.logtoscreen_usecolor, session_id=session_id) + stream_handler.setFormatter(stream_formatter) + root_logger.addHandler(stream_handler) + #root_logger.addHandler(RichHandler(log_time_format="[%H:%M:%S]", level=logger_settings.logtoscreen_level, markup=True)) # Create the file handler and attach it to the root logger if logger_settings.logtofile_path: @@ -102,7 +123,7 @@ def setup_logger(logger_settings, session_id=None): # Mark that the root logger has the chirpy handlers attached root_logger.chirpy_handlers = True - # Add the custom PRIMARY_INFO level to chirpy logger + # Add the color PRIMARY_INFO level to chirpy logger add_new_level(chirpy_logger, 'PRIMARY_INFO', PRIMARY_INFO_NUM) return chirpy_logger @@ -158,4 +179,4 @@ def update_logger(session_id, function_version): for handler in root_logger.handlers: if isinstance(handler.formatter, ChirpyFormatter): handler.formatter.update_session_id(session_id) - handler.formatter.update_function_version(function_version) + handler.formatter.update_function_version(function_version) \ No newline at end of file diff --git a/env.list b/env.list new file mode 100644 index 0000000..c9cedf7 --- /dev/null +++ b/env.list @@ -0,0 +1,21 @@ +export PYTHONPATH=$(pwd) +export ES_USER=chirpy1 +export ES_PASSWORD=4sMoNKNxQkMeVtrlEYqsK2Nzo7kBNU@ +export ES_HOST=search-genie-search-dev-36ydzvzvwb7oyyzvdbrs63rdny.us-west-2.es.amazonaws.com +export ES_PORT=443 +export ES_SCHEME=https +export POSTGRES_HOST=localhost +export POSTGRES_USER=postgres +export POSTGRES_PASSWORD=qyhqae-4Sepzy-zecget +export corenlp_URL=4080 +export dialogact_URL=4081 +export g2p_URL=4082 +export gpt2ed_URL=4083 +export question_URL=4084 +export convpara_URL=4085 +export entitylinker_URL=4086 +export blenderbot_URL=4087 +export responseranker_URL=4088 +export stanfordnlp_URL=4089 +export infiller_URL=4090 +export postgresql_URL=5432 From 89845258c2773f06afbe8d3f61a833384c05a3c3 Mon Sep 17 00:00:00 2001 From: thanawan-atc <106889996+thanawan-atc@users.noreply.github.com> Date: Mon, 18 Jul 2022 13:56:32 -0700 Subject: [PATCH 2/6] takeover --- agents/local_agent.py | 16 +- agents/portforwarding.sh | 2 +- chirpy/core/dialog_manager.py | 29 +++- .../core/entity_linker/wiki_data_fetching.py | 2 +- chirpy/core/entity_tracker/entity_tracker.py | 37 ++++- chirpy/core/logging_utils.py | 2 +- .../response_generator/response_generator.py | 151 +++++++++++++++++- chirpy/core/response_generator/state.py | 4 + chirpy/core/response_generator/treelet.py | 3 + chirpy/core/response_generator_datatypes.py | 20 ++- chirpy/core/response_priority.py | 2 +- chirpy/core/state.py | 1 + chirpy/core/state_manager.py | 2 + ...closing_confirmation_response_generator.py | 2 +- .../food/food_response_generator.py | 7 +- .../treelets/ask_favorite_food_treelet.py | 3 +- .../comment_on_favorite_type_treelet.py | 40 ++++- .../food/treelets/factoid_treelet.py | 8 +- .../food/treelets/introductory_treelet.py | 4 +- .../open_ended_user_comment_treelet.py | 7 +- .../launch/launch_response_generator.py | 2 +- .../music/music_response_generator.py | 2 +- .../neural_chat_response_generator.py | 2 +- .../response_generators/neural_chat/state.py | 15 +- .../offensive_user_response_generator.py | 2 +- .../opinion2/opinion_response_generator.py | 2 +- chirpy/response_generators/wiki2/state.py | 1 + .../wiki2/treelets/discuss_article_treelet.py | 3 +- .../discuss_section_further_treelet.py | 3 +- .../wiki2/treelets/handback_treelet.py | 32 ++++ .../wiki2/treelets/intro_entity_treelet.py | 3 +- .../wiki2/treelets/takeover_treelet.py | 76 +++++++++ .../wiki2/wiki_response_generator.py | 19 ++- .../response_generators/wiki2/wiki_utils.py | 5 +- env.list | 1 + servers/local/shell_chat.py | 2 +- 36 files changed, 454 insertions(+), 58 deletions(-) create mode 100644 chirpy/response_generators/wiki2/treelets/handback_treelet.py create mode 100644 chirpy/response_generators/wiki2/treelets/takeover_treelet.py diff --git a/agents/local_agent.py b/agents/local_agent.py index 7a02e18..36dd2f6 100644 --- a/agents/local_agent.py +++ b/agents/local_agent.py @@ -205,17 +205,17 @@ def create_handler(self): response_generator_classes=[LaunchResponseGenerator, FallbackResponseGenerator, NeuralFallbackResponseGenerator, NeuralChatResponseGenerator, - OffensiveUserResponseGenerator, - CategoriesResponseGenerator, - ClosingConfirmationResponseGenerator, - AcknowledgmentResponseGenerator, - PersonalIssuesResponseGenerator, - OpinionResponseGenerator2, - AliensResponseGenerator, + # OffensiveUserResponseGenerator, + # CategoriesResponseGenerator, + # ClosingConfirmationResponseGenerator, + # AcknowledgmentResponseGenerator, + # PersonalIssuesResponseGenerator, + # OpinionResponseGenerator2, + # AliensResponseGenerator, TransitionResponseGenerator, FoodResponseGenerator, WikiResponseGenerator, - MusicResponseGenerator, + # MusicResponseGenerator, ], annotator_classes = [QuestionAnnotator, DialogActAnnotator, NavigationalIntentModule, StanfordnlpModule, CorenlpModule, EntityLinkerModule, BlenderBot], diff --git a/agents/portforwarding.sh b/agents/portforwarding.sh index c356855..2f83df5 100644 --- a/agents/portforwarding.sh +++ b/agents/portforwarding.sh @@ -2,7 +2,7 @@ kubectl port-forward corenlp-7fd4974bb-8mq5g 4080:5001 -n chirpy kubectl port-forward dialogact-849b4b67d8-ngzd5 4081:5001 -n chirpy & kubectl port-forward g2p-7644ff75bd-cjj57 4082:5001 -n chirpy & kubectl port-forward gpt2ed-68f849f64b-wr8zw 4083:5001 -n chirpy & -kubectl port-forward questionclassifier-668c4fd6c6-fd586 4084:5001 -n chirpy & +kubectl port-forward questionclassifier-668c4fd6c6-7nl2k 4084:5001 -n chirpy & kubectl port-forward convpara-dbdc8dcfb-csktj 4085:5001 -n chirpy & kubectl port-forward entitylinker-59b9678b8-nmwx9 4086:5001 -n chirpy & kubectl port-forward blenderbot-695c7b5896-gkz2s 4087:5001 -n chirpy & diff --git a/chirpy/core/dialog_manager.py b/chirpy/core/dialog_manager.py index 9c0c473..9491dc1 100644 --- a/chirpy/core/dialog_manager.py +++ b/chirpy/core/dialog_manager.py @@ -284,7 +284,15 @@ def update_rg_states(self, results: RankedResults, selected_rg: str): # Get the args needed for the update_state_if_not_chosen fn. That's (state, conditional_state) for all RGs except selected_rg other_rgs = [rg for rg in results.keys() if rg != selected_rg and not is_killed(results[rg])] logger.info(f"now, current states are {rg_states}") - args_list = [[rg_states[rg], results[rg].conditional_state] for rg in other_rgs] + + def rg_was_taken_over(rg): # EDIT + if self.state_manager.last_state: + logger.error(f"DEBUG RG_WAS_TAKEN_OVER: {selected_rg} // {rg}, {rg == self.state_manager.last_state.active_rg}") + return rg_states[selected_rg].rg_that_was_taken_over and rg == self.state_manager.last_state.active_rg + else: + return None + + args_list = [[rg_states[rg], results[rg].conditional_state, rg_was_taken_over(rg)] for rg in other_rgs] # EDIT # Run update_state_if_not_chosen for other RGs logger.info(f'Starting to run update_state_if_not_chosen for {other_rgs}...') @@ -331,7 +339,7 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe # Get the states for the RGs we'll run, which we'll use as input to the get_response/get_prompt fn logger.debug('Copying RG states to use as input...') - input_rg_states = copy.copy([rg_states[rg] for rg in rgs_list]) # list of dicts + # input_rg_states = copy.copy([rg_states[rg] for rg in rgs_list]) # list of dicts # EDIT: COMMENT OUT # import pdb; pdb.set_trace() @@ -343,10 +351,21 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe priority_modules = [last_state_active_rg] else: priority_modules = [] - results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list, - function_name=f'get_{phase}', + + rg_was_taken_over = None # EDIT + if self.state_manager.last_state_response: # EDIT + rg_was_taken_over = self.state_manager.last_state_response.state.rg_that_was_taken_over + + def rg_to_resume(rg): # EDIT : ???? + logger.error(f"DEBUG RG_TO_RESUME: {rg_was_taken_over} // {rg}, {rg == rg_was_taken_over}") + return rg == rg_was_taken_over + + function_name = 'get_prompt_wrapper' if phase == 'prompt' else 'get_response' + args_list = copy.copy([[rg_states[rg], rg_to_resume(rg)] for rg in rgs_list]) # EDIT : ???? + results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list, # EDIT : ???? + function_name=function_name, timeout=timeout, - args_list=[[state] for state in input_rg_states], + args_list=args_list, # [[state] for state in input_rg_states], priority_modules=priority_modules) # Log the initial results diff --git a/chirpy/core/entity_linker/wiki_data_fetching.py b/chirpy/core/entity_linker/wiki_data_fetching.py index d60a454..b7c7f12 100644 --- a/chirpy/core/entity_linker/wiki_data_fetching.py +++ b/chirpy/core/entity_linker/wiki_data_fetching.py @@ -20,7 +20,7 @@ ANCHORTEXT_QUERY_TIMEOUT = 3.0 # seconds ENTITYNAME_QUERY_TIMEOUT = 1.0 # seconds -ARTICLES_INDEX_NAME = 'enwiki-20220107-articles' +ARTICLES_INDEX_NAME = 'enwiki-20200920-articles' # These are the fields we DO want to fetch from ES FIELDS_FILTER = ['doc_title', 'doc_id', 'categories', 'pageview', 'linkable_span_info', 'wikidata_categories_all', 'redirects', 'plural'] diff --git a/chirpy/core/entity_tracker/entity_tracker.py b/chirpy/core/entity_tracker/entity_tracker.py index 4cc8055..0d59043 100644 --- a/chirpy/core/entity_tracker/entity_tracker.py +++ b/chirpy/core/entity_tracker/entity_tracker.py @@ -8,6 +8,8 @@ from chirpy.core.entity_linker.thresholds import SCORE_THRESHOLD_NAV_ABOUT, SCORE_THRESHOLD_NAV_NOT_ABOUT, SCORE_THRESHOLD_EXPECTEDTYPE from chirpy.core.entity_linker.entity_groups import EntityGroup +import random # EDIT + logger = logging.getLogger('chirpylogger') class TransitionType(Enum): @@ -23,6 +25,8 @@ class EntityTrackerState(object): def __init__(self): self.cur_entity = None # the current entity under discussion (can be None) + self.talked_unfinished = [] # EDIT + self.able_to_takeover_entities = [] # EDIT self.talked_rejected = [] # entities we talked about in the past, and stopped talking about because the user indicated they didn't want to talk about it any more self.talked_finished = [] # entities we talked about in the past, that aren't in talked_rejected self.talked_transitionable = [] @@ -97,7 +101,7 @@ def finish_entity(self, entity: Optional[WikiEntity], transition_is_possible=Tru logger.error(f"This is an error. This should be a WikiEntity object but {entity} is of type {type(entity)}") entity = None - if entity is not None and entity not in self.talked_finished: + if entity is not None and entity not in self.talked_finished and entity not in self.talked_unfinished: # EDIT (?) logger.info(f'Putting entity {entity} on the talked_finished list') self.talked_finished.append(entity) @@ -277,16 +281,21 @@ def condition_fn(entity_linker_result, linked_span, entity) -> bool: if nav_intent_output.neg_intent or nav_intent_output.pos_intent or last_answer_type in [AnswerType.QUESTION_SELFHANDLING, AnswerType.QUESTION_HANDOFF]: self.cur_entity = self.entity_initiated_on_turn + self.able_to_takeover_entities = [] # EDIT + for linked_span in current_state.entity_linker.high_prec: if not self.talked(linked_span.top_ent): logger.info(f'Adding {linked_span.top_ent} to user_mentioned_untalked') self.user_mentioned_untalked.append(linked_span.top_ent) + self.able_to_takeover_entities.append(linked_span.top_ent) # EDIT logger.primary_info(f'The EntityTrackerState is now: {self}') + logger.error(f'ABLE_TO_TAKEOVER_ENTITIES: {self.able_to_takeover_entities}') # Update the entity tracker history self.history[-1]['user'] = self.cur_entity + def record_untalked_high_prec_entities(self, current_state): """ Take any entities in the entity linker's high precision set for this turn, and if they haven't been discussed, @@ -313,6 +322,7 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up result: ResponseGeneratorResult, PromptResult, or UpdateEntity rg: the name of the RG that provided the new entity """ + if isinstance(result, UpdateEntity): new_entity = result.cur_entity phase = 'get_entity' @@ -325,6 +335,14 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up transition_is_possible = not getattr(result, 'no_transition', False) + if self.able_to_takeover_entities: # EDIT + self.talked_unfinished.append(self.cur_entity) + new_entity = self.able_to_takeover_entities.pop() + logger.primary_info(f'Removing {new_entity} from {self.able_to_takeover_entities}') + self.able_to_takeover_entities = [e for e in self.able_to_takeover_entities if e != new_entity] + logger.error(f'[AFTER TAKEOVER 1] TALK_UNFINISHED: {self.talked_unfinished} // ABLE_TO_TAKEOVER_ENT: {self.able_to_takeover_entities} //' + f'/ TALKED_FINISHED = {self.talked_finished}') + if new_entity == self.cur_entity: logger.primary_info(f'new_entity={new_entity} from {rg} RG {phase} is the same as cur_entity, so keeping EntityTrackerState the same') else: @@ -344,7 +362,14 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up self.user_mentioned_untalked = [e for e in self.user_mentioned_untalked if e != new_entity] logger.primary_info(f'Set cur_entity to new_entity={new_entity} from {rg} RG {phase}') - logger.primary_info(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}') + + if new_entity in self.talked_unfinished: # EDIT + archived_entity = new_entity + logger.error( + f"Removing archived_entity [{archived_entity}] from talked_unfinished [{self.talked_unfinished}]") + self.talked_unfinished.remove(archived_entity) + + logger.error(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}') # If we're updating after receiving UpdateEntity from an RG, put any undiscussed high precision entities that # the user mentioned this turn in user_mentioned_untalked @@ -360,6 +385,8 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up def __repr__(self, show_history=False): output = f" bool: if ent is None: return True return ent in entities + + self.able_to_takeover_entities = [ent for ent in self.able_to_takeover_entities if keep_entity(ent)] # EDIT self.talked_finished = [ent for ent in self.talked_finished if keep_entity(ent)] self.talked_rejected = [ent for ent in self.talked_rejected if keep_entity(ent)] self.user_mentioned_untalked = [ent for ent in self.user_mentioned_untalked if keep_entity(ent)] @@ -393,6 +422,8 @@ def reduce_size(self, max_size: int): # Make a set (no duplicates) of all the WikiEntities stored in this EntityTrackerState entity_set = set() entity_set.add(self.cur_entity) + entity_set.update(self.talked_unfinished) # EDIT + entity_set.update(self.able_to_takeover_entities) # EDIT entity_set.update(self.talked_finished) entity_set.update(self.talked_rejected) entity_set.update(self.user_mentioned_untalked) @@ -408,6 +439,8 @@ def replace_ent(ent: Optional[WikiEntity]): return None return entname2ent[ent.name] self.cur_entity = replace_ent(self.cur_entity) + self.talked_unfinished = [replace_ent(ent) for ent in self.talked_unfinished] # EDIT + self.able_to_takeover_entities = [replace_ent(ent) for ent in self.able_to_takeover_entities] # EDIT self.talked_finished = [replace_ent(ent) for ent in self.talked_finished] self.talked_rejected = [replace_ent(ent) for ent in self.talked_rejected] self.user_mentioned_untalked = [replace_ent(ent) for ent in self.user_mentioned_untalked] diff --git a/chirpy/core/logging_utils.py b/chirpy/core/logging_utils.py index e8daabe..9d4f1e8 100644 --- a/chirpy/core/logging_utils.py +++ b/chirpy/core/logging_utils.py @@ -33,7 +33,7 @@ class LoggerSettings: # AWS adds a LambdaLoggerHandler to the root handler, which causes duplicate logging because we have our customized # StreamHandler on the root logger too. So we set remove_root_handlers=True to remove the LambdaLoggerHandler. # See here: https://stackoverflow.com/questions/50909824/getting-logs-twice-in-aws-lambda-function -PROD_LOGGER_SETTINGS = LoggerSettings(logtoscreen_level=logging.DEBUG, +PROD_LOGGER_SETTINGS = LoggerSettings(logtoscreen_level=logging.INFO + 5, logtoscreen_usecolor=True, logtofile_level=None, logtofile_path=None, diff --git a/chirpy/core/response_generator/response_generator.py b/chirpy/core/response_generator/response_generator.py index 6063aa9..d34ffbc 100644 --- a/chirpy/core/response_generator/response_generator.py +++ b/chirpy/core/response_generator/response_generator.py @@ -24,6 +24,7 @@ from chirpy.response_generators.music.utils import WikiEntityInterface from concurrent import futures +import copy logger = logging.getLogger('chirpylogger') @@ -41,7 +42,7 @@ def __init__(self, disallow_start_from=None, can_give_prompts=False, state_constructor=None, - conditional_state_constructor=None + conditional_state_constructor=None, ): """Creates a new Response Generator. @@ -93,7 +94,8 @@ def update_state_if_chosen(self, state, conditional_state): if response_types is not None: state.response_types = construct_response_types_tuple(response_types) - if conditional_state is None: return state + if conditional_state is None: + return state if conditional_state: for attr in dir(conditional_state): @@ -101,15 +103,21 @@ def update_state_if_chosen(self, state, conditional_state): val = getattr(conditional_state, attr) if val != NO_UPDATE: setattr(state, attr, val) state.num_turns_in_rg += 1 + return state - def update_state_if_not_chosen(self, state, conditional_state): + def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): # EDIT """ By default, this sets the prev_treelet_str and next_treelet_str to '' and resets num_turns_in_rg to 0. Response types are also saved. No other attributes are updated. All other attributes in ConditionalState are set to NO-UPDATE """ + if rg_was_taken_over: # EDIT + state.archived_state = copy.deepcopy(state) + logging.error(f"ARCHIVED_STATE: {state.archived_state}") + + response_types = self.get_cache(f'{self.name}_response_types') if response_types is not None: state.response_types = construct_response_types_tuple(response_types) @@ -285,6 +293,9 @@ def get_current_entity(self, initiated_this_turn=False): else: return self.state_manager.current_state.entity_tracker.cur_entity + def get_most_recent_able_to_takeover_entity(self): # EDIT + return self.state_manager.current_state.entity_tracker.able_to_takeover_entities[-1] + def get_entity_tracker(self): return self.state_manager.current_state.entity_tracker @@ -861,7 +872,7 @@ def get_last_rg_in_control(self) -> Optional[str]: return self.state_manager.last_state.selected_response_rg - def get_response(self, state) -> ResponseGeneratorResult: + def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResult: # EDIT : ???? response_types = self.identify_response_types(self.utterance) logger.primary_info(f"{self.name} identified response_types: {response_types}") self.state = state @@ -915,19 +926,38 @@ def get_response(self, state) -> ResponseGeneratorResult: if not is_continuing_conversation: # allow the first branch to divert here logger.primary_info(f"{self.name} is not currently active, so checking if it should activate") - + if self.name == 'FOOD': + logger.error(f"Self.state is {self.state}") activation_check_fns = { (lambda: self.get_last_active_rg() in self.disallow_start_from): self.get_fallback_result, (lambda: True): self.handle_direct_navigational_intent, + (lambda: (self.last_rg_willing_to_handover_control() and self.exist_able_to_takeover_entities())): self.get_takeover_response, # EDIT + # (lambda: (self.takeover_rg_willing_to_handback_control())): wrapped_partial(self.resume_response_generator, response_types), # EDIT (lambda: True): self.handle_current_entity, (lambda: True): self.get_intro_treelet_response, (lambda: True): self.handle_custom_activation_checks, } + logging.error(f"DEBUG HANDOVER {self.last_rg_willing_to_handover_control()}, {self.exist_able_to_takeover_entities()}") + logging.error(f"DEBUG HANDBACK {self.takeover_rg_willing_to_handback_control()}") + + for activation_condition, activation_check_fn in activation_check_fns.items(): if activation_condition(): response = activation_check_fn() - if response: return self.possibly_augment_with_prompt(response) + + # if response and activation_check_fn == self.get_takeover_response: # EDIT + # logger.primary_info(f"{self.name} is being taken over.") + # resumable_state = self.state_manager.last_state_response + # logger.error(f"STATE TAKEOVER: Change resumable state from {self.state.resumable_state} to {resumable_state})") + # updated_state = self.update_state_if_not_chosen(self.state, self.ConditionalState()) + # self.state = updated_state + # self.state.resumable_state = resumable_state + # logger.error(f"CURRENT STATE AFTER TAKEOVER: {self.state}") + + if response: + return self.possibly_augment_with_prompt(response) + response = self.handle_default_post_checks() if response: @@ -935,6 +965,113 @@ def get_response(self, state) -> ResponseGeneratorResult: return self.get_fallback_result() + def last_rg_willing_to_handover_control(self): # EDIT + last_active_rg_prompt = self.state_manager.last_state_response + # logging.error(f"LAST ACTIVE: {last_active_rg_prompt}") + # logging.error(f"LAST WILL: {last_active_rg_prompt.last_rg_willing_to_handover_control}") + if last_active_rg_prompt: + return last_active_rg_prompt.last_rg_willing_to_handover_control + else: + return False + + def exist_able_to_takeover_entities(self): # EDIT + return len(self.state_manager.current_state.entity_tracker.able_to_takeover_entities) != 0 + + def get_takeover_response(self): # EDIT + logging.error(f"TEST: {self.name} null get_takeover_response") + return None + + def takeover_rg_willing_to_handback_control(self): # EDIT + last_active_rg_prompt = self.state_manager.last_state_response + if last_active_rg_prompt: + return last_active_rg_prompt.takeover_rg_willing_to_handback_control + else: + return False + + # def resume_response_generator(self, response_types): # EDIT + # # logging.error(f"TEST: {self.name} null resume_conversation") + # # if self.name == 'FOOD': + # # logging.error(f"DEBUG: state is {self.state} ") + # # logging.error(f"ARCHIVED_RESUME: {archived_state}") + # archived_state = self.state.archived_state + # if archived_state: + # logging.error(f" HANDING OVER FROM: {self.name} // {self.state}") + # self.state = archived_state + # logging.error(f" HANDING OVER TO: {self.state}") + # return self.continue_conversation(response_types) + # else: + # return None + + def get_resuming_statement(self, state) -> ResponseGeneratorResult: + logging.error(f"TEST: {self.name} null get_resuming_statement") + return self.emptyPrompt() + + def augment_resuming_statement(self, resuming_statement_first_treelet): + resuming_conversation_second_treelet_str = resuming_statement_first_treelet.resuming_conversation_next_treelet + logger.error(f"DEBUG RESUMING PROMPT TREELET: {resuming_conversation_second_treelet_str}") + resuming_conversation_second_treelet = self.treelets[resuming_conversation_second_treelet_str] + resuming_prompt_second_treelet = resuming_conversation_second_treelet.get_prompt() + logger.error(f"DEBUG RESUMING PROMPT: {resuming_prompt_second_treelet}") + if resuming_prompt_second_treelet: + # for attr_copy in ['state', 'cur_entity', 'expected_type']: + # getattr(resuming_statement_first_treelet, attr_copy) = + + resuming_statement_first_treelet.text = f"{resuming_statement_first_treelet.text} {resuming_prompt_second_treelet.text}" + resuming_statement_first_treelet.state = resuming_prompt_second_treelet.state + resuming_statement_first_treelet.conditional_state.next_treelet_str = resuming_conversation_second_treelet_str + resuming_statement_first_treelet.cur_entity = resuming_prompt_second_treelet.cur_entity + resuming_statement_first_treelet.expected_type = resuming_prompt_second_treelet.expected_type + resuming_statement_first_treelet.answer_type = resuming_prompt_second_treelet.answer_type + resuming_statement_first_treelet.last_rg_willing_to_handover_control = resuming_prompt_second_treelet.last_rg_willing_to_handover_control + resuming_statement_first_treelet.rg_that_was_taken_over = resuming_prompt_second_treelet.rg_that_was_taken_over + resuming_statement_first_treelet.takeover_rg_willing_to_handback_control = resuming_prompt_second_treelet.takeover_rg_willing_to_handback_control + resuming_statement_first_treelet.resuming_conversation_next_treelet = None + logger.error(f"DEBUG RESUMING_STATEMENT_FIRST_TREELET.STATE: {resuming_statement_first_treelet}") + return resuming_statement_first_treelet + + def resume_conversation(self): + logger.error(f"DEBUG SELF WITH ARCHIVE: {self.name}") + logger.error( + f"DEBUG ARCHIVED_STATE_RESUMING_RG: {self.state_manager.current_state.response_generator_states[self.name].archived_state}") + archived_state = self.state_manager.current_state.response_generator_states[self.name].archived_state + self.state = archived_state + logger.error(f"DEBUG SELF AFTER RETREIVING ARCHIVE: {self}") + logger.error(f"DEBUG SELF.STATE AFTER RETREIVING ARCHIVE: {self.state}") + + first_treelet_str = self.state.next_treelet_str + assert first_treelet_str in self.treelets + first_treelet = self.treelets[first_treelet_str] + resuming_statement_first_treelet = first_treelet.get_resuming_statement() + logger.error(f"DEBUG RESUMING STATEMENT AFTER RETREIVING ARCHIVE: {resuming_statement_first_treelet}") + + resuming_conversation = self.augment_resuming_statement(resuming_statement_first_treelet) + logger.error(f"DEBUG WHOLE RESUMING RESPONSE: {resuming_conversation}") + + return resuming_conversation + + def get_prompt_wrapper(self, state, rg_to_resume=False): + if self.takeover_rg_willing_to_handback_control(): + if rg_to_resume: + return self.resume_conversation() + else: + return self.get_prompt(state) + + # def get_prompt_wrapper(self, state, rg_resuming_prompt=False): + # if self.takeover_rg_willing_to_handback_control(): + # rg_with_archived_state = self.state_manager.last_state_response.state.rg_that_was_taken_over + # logger.error(f"DEBUG RG_WITH_ARCHIVE {self.state_manager.last_state_response.state.rg_that_was_taken_over}") + # if self.name == rg_with_archived_state: + # logger.error(f"DEBUG SELF WITH ARCHIVE: {self.name} // {self.state_manager} // {self.state_manager.last_state_response}") + # archived_state = self.State.archived_state + # if archived_state: # TODO: Set to None later ???? + # self.state = archived_state + # archived_resuming_response = self.get_resuming_response(archived_state) + # logger.error("DEBUG ARCHIVE STATE PROMPT: {archived_resuming_response}") + # return archived_resuming_response + # else: + # return self.emptyPrompt() + # return self.get_prompt(state) + def possibly_augment_with_prompt(self, response): """ @@ -971,7 +1108,7 @@ def continue_conversation(self, response_types) -> Optional[ResponseGeneratorRes next_treelet = None response_priority = ResponsePriority.STRONG_CONTINUE - + logger.error(f"In continue_conversation, self.state is {self.state}, next_treelet_str is {next_treelet_str}") if next_treelet_str is None: return self.emptyResult() # continue from some other RG elif next_treelet_str == '': diff --git a/chirpy/core/response_generator/state.py b/chirpy/core/response_generator/state.py index 01d3080..180882e 100644 --- a/chirpy/core/response_generator/state.py +++ b/chirpy/core/response_generator/state.py @@ -22,12 +22,16 @@ class BaseState: next_treelet_str: Optional[str] = '' response_types: Tuple[str] = () num_turns_in_rg: int = 0 + archived_state: "BaseState" = None # EDIT + rg_that_was_taken_over: str = None # EDIT @dataclass class BaseConditionalState: prev_treelet_str: str = '' next_treelet_str: Optional[str] = '' response_types: Tuple[str] = NO_UPDATE + archived_state: "BaseState" = NO_UPDATE # EDIT + rg_that_was_taken_over: str = None # EDIT def construct_response_types_tuple(response_types): return tuple([str(x) for x in response_types]) diff --git a/chirpy/core/response_generator/treelet.py b/chirpy/core/response_generator/treelet.py index 805bf9b..13337da 100644 --- a/chirpy/core/response_generator/treelet.py +++ b/chirpy/core/response_generator/treelet.py @@ -50,6 +50,9 @@ def get_current_state(self): def get_current_entity(self, initiated_this_turn=False): return self.rg.get_current_entity(initiated_this_turn=initiated_this_turn) + def get_most_recent_able_to_takeover_entity(self): + return self.rg.get_most_recent_able_to_takeover_entity() + def get_sentiment(self): return self.rg.get_sentiment() diff --git a/chirpy/core/response_generator_datatypes.py b/chirpy/core/response_generator_datatypes.py index 3f12221..3394ff7 100644 --- a/chirpy/core/response_generator_datatypes.py +++ b/chirpy/core/response_generator_datatypes.py @@ -33,7 +33,11 @@ def __init__(self, smooth_handoff: Optional[SmoothHandoff] = None, conditional_state=None, tiebreak_priority=None, - no_transition=False): + no_transition=False, + last_rg_willing_to_handover_control=False, # EDIT + rg_that_was_taken_over =None, # EDIT + takeover_rg_willing_to_handback_control=False # EDIT + ): """ :param text: text of the response :param priority: priority of the response @@ -98,6 +102,9 @@ def __init__(self, self.conditional_state = conditional_state self.tiebreak_priority = tiebreak_priority self.no_transition = no_transition + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT + self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT def reduce_size(self, max_size:int = None): """Gracefully degrade by removing non essential attributes. @@ -124,7 +131,12 @@ def __init__(self, cur_entity: Optional[WikiEntity], expected_type: Optional[EntityGroup] = None, conditional_state=None, - answer_type: AnswerType = AnswerType.QUESTION_SELFHANDLING): + answer_type: AnswerType = AnswerType.QUESTION_SELFHANDLING, + last_rg_willing_to_handover_control=False, # EDIT + rg_that_was_taken_over =None, # EDIT + takeover_rg_willing_to_handback_control=False, # EDIT + resuming_conversation_next_treelet=None # EDIT + ): """ :param text: text of the response :param prompt_type: the type of response being given, typically CONTEXTUAL or GENERIC @@ -163,6 +175,10 @@ def __init__(self, self.state = state self.conditional_state = conditional_state self.answer_type = answer_type + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT + self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT + self.resuming_conversation_next_treelet = resuming_conversation_next_treelet # EDIT def __repr__(self): return 'PromptResult' + str(self.__dict__) diff --git a/chirpy/core/response_priority.py b/chirpy/core/response_priority.py index eb0710d..ba9232c 100644 --- a/chirpy/core/response_priority.py +++ b/chirpy/core/response_priority.py @@ -69,7 +69,7 @@ class TiebreakPriority(Enum): ACKNOWLEDGMENT = 62 EVI = 58 NEWS = 65 - WIKI = 64 + WIKI = 69 # EDIT: Change from 64 CATEGORIES = 60 MUSIC = 66 NEURAL_FALLBACK = 5 # fallback should always be lowest priority i.e. last resort diff --git a/chirpy/core/state.py b/chirpy/core/state.py index b9cb559..310c35e 100644 --- a/chirpy/core/state.py +++ b/chirpy/core/state.py @@ -78,6 +78,7 @@ def update_from_last_state(self, last_state): self.entity_tracker.init_for_new_turn() self.experiments = last_state.experiments self.turn_num = last_state.turn_num + 1 + try: self.turns_since_last_active = last_state.turns_since_last_active except AttributeError: diff --git a/chirpy/core/state_manager.py b/chirpy/core/state_manager.py index f211297..c535874 100644 --- a/chirpy/core/state_manager.py +++ b/chirpy/core/state_manager.py @@ -30,3 +30,5 @@ def last_state_response(self): if not self.last_state: return None if hasattr(self.last_state, 'prompt_results'): return self.last_state.prompt_results[self.last_state.active_rg] else: return self.last_state.response_results[self.last_state.active_rg] + + diff --git a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py index bbca842..7a55f33 100644 --- a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py +++ b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py @@ -101,7 +101,7 @@ def handle_custom_continuation_checks(self): # If neither matched, allow another RG to handle return self.emptyResult() - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> BaseState: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT state = super().update_state_if_not_chosen(state, conditional_state) state.has_just_asked_to_exit = False return state diff --git a/chirpy/response_generators/food/food_response_generator.py b/chirpy/response_generators/food/food_response_generator.py index 540a31b..b30497c 100644 --- a/chirpy/response_generators/food/food_response_generator.py +++ b/chirpy/response_generators/food/food_response_generator.py @@ -20,6 +20,7 @@ from chirpy.core.offensive_classifier.offensive_classifier import OffensiveClassifier from chirpy.response_generators.food.food_helpers import * + logger = logging.getLogger('chirpylogger') class FoodResponseGenerator(ResponseGenerator): @@ -31,9 +32,11 @@ def __init__(self, state_manager) -> None: self.comment_on_favorite_type_treelet = CommentOnFavoriteTypeTreelet(self) self.ask_favorite_food_treelet = AskFavoriteFoodTreelet(self) self.factoid_treelet = FactoidTreelet(self) + treelets = { treelet.name: treelet for treelet in [self.introductory_treelet, self.open_ended_user_comment_treelet, - self.comment_on_favorite_type_treelet, self.factoid_treelet, self.ask_favorite_food_treelet] + self.comment_on_favorite_type_treelet, self.factoid_treelet, self.ask_favorite_food_treelet + ] } super().__init__(state_manager, treelets=treelets, intent_templates=[], can_give_prompts=True, state_constructor=State, @@ -87,4 +90,4 @@ def get_neural_response(self, prefix=None, allow_questions=False, conditions=Non def get_prompt(self, state): self.state = state self.response_types = self.get_cache(f'{self.name}_response_types') - return self.emptyPrompt() + return self.emptyPrompt() \ No newline at end of file diff --git a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py index fa3c570..0d4f0f5 100644 --- a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py +++ b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py @@ -29,5 +29,6 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): conditional_state=ConditionalState( next_treelet_str="food_introductory_treelet", cur_food=None), - expected_type=ENTITY_GROUPS_FOR_EXPECTED_TYPE.food_related + expected_type=ENTITY_GROUPS_FOR_EXPECTED_TYPE.food_related, + last_rg_willing_to_handover_control=True # EDIT ) diff --git a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py index 6d21c16..08b5421 100644 --- a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py +++ b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py @@ -38,7 +38,8 @@ def get_prompt(self, conditional_state=None): return None return PromptResult(text, PromptType.CONTEXTUAL, state, conditional_state=conditional_state, - cur_entity=entity, answer_type=AnswerType.QUESTION_SELFHANDLING) + cur_entity=entity, answer_type=AnswerType.QUESTION_SELFHANDLING, + ) def get_best_candidate_user_entity(self, utterance, cur_food): def condition_fn(entity_linker_result, linked_span, entity): @@ -88,10 +89,43 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): other_type = sample_from_type(cur_food) text = f"That totally makes sense! I also really enjoy {user_answer}. Personally, I really like {other_type}." - return ResponseGeneratorResult(text=text, priority=priority, + return ResponseGeneratorResult(text=text, priority=priority, # EDIT needs_prompt=False, state=state, cur_entity=entity, conditional_state=ConditionalState( prompt_treelet=self.rg.open_ended_user_comment_treelet.name, - cur_food=cur_food_entity) + cur_food=cur_food_entity), + last_rg_willing_to_handover_control=False # EDIT ) + + + # return ResponseGeneratorResult(text="TODO: RESUMING RESPONSE", priority=priority, # EDIT + # needs_prompt=False, state=state, + # cur_entity=entity, + # conditional_state=ConditionalState( + # prompt_treelet=self.rg.open_ended_user_comment_treelet.name, + # cur_food=cur_food_entity), + # last_rg_willing_to_handover_control=False # EDIT + # ) + + def get_resuming_statement(self, prompt_type=PromptType.FORCE_START, **kwargs): + logger.error(f"GET_STATEMENT_RESPONSE got triggered.") + state, utterance, response_types = self.get_state_utterance_response_types() + entity = self.rg.get_current_entity(initiated_this_turn=False) + cur_food_entity = state.cur_food + cur_food = cur_food_entity.name + cur_talkable_food = cur_food_entity.talkable_name + + if get_custom_question(cur_food) is not None: + custom_question_answer = get_custom_question_answer(cur_food) + text = f"TODO: RESUMING_STATEMENT_FIRST_TREELET_A (e.g. Personally, when it comes to {cur_talkable_food}, I really like {custom_question_answer})." + else: + other_type = sample_from_type(cur_food) + text = f"TODO: RESUMING_STATEMENT_FIRST_TREELET_B (e.g. Personally, I really like {other_type})" + + return PromptResult(text=text, prompt_type=prompt_type, state=state, + conditional_state=ConditionalState( + prompt_treelet=self.rg.open_ended_user_comment_treelet.name, + cur_food=cur_food_entity), + cur_entity=entity, + resuming_conversation_next_treelet=self.rg.open_ended_user_comment_treelet.name) \ No newline at end of file diff --git a/chirpy/response_generators/food/treelets/factoid_treelet.py b/chirpy/response_generators/food/treelets/factoid_treelet.py index 217303f..c4f4402 100644 --- a/chirpy/response_generators/food/treelets/factoid_treelet.py +++ b/chirpy/response_generators/food/treelets/factoid_treelet.py @@ -27,7 +27,8 @@ def get_prompt(self, conditional_state=None): conditional_state = ConditionalState(cur_food=cur_food) entity = self.rg.state_manager.current_state.entity_tracker.cur_entity return PromptResult(text=get_factoid(cur_food), prompt_type=PromptType.CONTEXTUAL, - state=state, cur_entity=entity, conditional_state=conditional_state, answer_type=AnswerType.QUESTION_SELFHANDLING) + state=state, cur_entity=entity, conditional_state=conditional_state, answer_type=AnswerType.QUESTION_SELFHANDLING, + ) def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): """ Returns the response. """ @@ -50,5 +51,6 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): cur_entity=None, conditional_state=ConditionalState( prev_treelet_str=self.name, - cur_food=cur_food - )) + cur_food=cur_food), + last_rg_willing_to_handover_control=True # EDIT + ) diff --git a/chirpy/response_generators/food/treelets/introductory_treelet.py b/chirpy/response_generators/food/treelets/introductory_treelet.py index d4e0448..9f2d8ba 100644 --- a/chirpy/response_generators/food/treelets/introductory_treelet.py +++ b/chirpy/response_generators/food/treelets/introductory_treelet.py @@ -73,7 +73,8 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): needs_prompt=False, state=state, cur_entity=entity, conditional_state=ConditionalState(cur_food=entity, - prompt_treelet=prompt_treelet)) + prompt_treelet=prompt_treelet), + last_rg_willing_to_handover_control=True) # EDIT def get_prompt(self, **kwargs): return None @@ -90,3 +91,4 @@ def get_prompt(self, **kwargs): # cur_treelet_str="get_other_type", # cur_food=entity.name, # response=prompt_text)) + diff --git a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py index df132c4..b63aeba 100644 --- a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py +++ b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py @@ -33,7 +33,7 @@ def get_prompt(self, conditional_state=None): pronoun = infl('them', entity.is_plural) if best_attribute: text = 'What do you think?' else: text = f'What do you like best about {pronoun}?' - return PromptResult(text, PromptType.CONTEXTUAL, state=state, cur_entity=entity, conditional_state=conditional_state) + return PromptResult(text=f'TODO: RESUMING_CONV_SECOND_TREELET (e.g. {text})', prompt_type=PromptType.CONTEXTUAL, state=state, cur_entity=entity, conditional_state=conditional_state) def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): """ Returns the response. """ @@ -67,11 +67,12 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): text = f"{neural_response} {concluding_statement}" cur_entity = None - return ResponseGeneratorResult(text=text, priority=ResponsePriority.STRONG_CONTINUE, + return ResponseGeneratorResult(text=f"TODO: PICKING UP RESPONSE (eg. {text})", priority=ResponsePriority.STRONG_CONTINUE, needs_prompt=needs_prompt, state=state, cur_entity=cur_entity, conditional_state=ConditionalState( prev_treelet_str=self.name, prompt_treelet=prompt_treelet, - cur_food=None) + cur_food=None), + last_rg_willing_to_handover_control=False # EDIT ) diff --git a/chirpy/response_generators/launch/launch_response_generator.py b/chirpy/response_generators/launch/launch_response_generator.py index 8c0fe90..f2570d3 100644 --- a/chirpy/response_generators/launch/launch_response_generator.py +++ b/chirpy/response_generators/launch/launch_response_generator.py @@ -50,7 +50,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi # state.asked_name_counter = 1 return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> BaseState: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT state = super().update_state_if_not_chosen(state, conditional_state) state.next_treelet_str = None return state diff --git a/chirpy/response_generators/music/music_response_generator.py b/chirpy/response_generators/music/music_response_generator.py index c75fdbd..d7ee4ab 100644 --- a/chirpy/response_generators/music/music_response_generator.py +++ b/chirpy/response_generators/music/music_response_generator.py @@ -76,7 +76,7 @@ def update_state_if_chosen(self, state, conditional_state): state.discussed_entities.append(state.cur_singer_str) return state - def update_state_if_not_chosen(self, state, conditional_state): + def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): # EDIT state = super().update_state_if_not_chosen(state, conditional_state) return state diff --git a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py index f3f2056..a98258d 100644 --- a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py +++ b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py @@ -188,7 +188,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi state.update_if_chosen(conditional_state) return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> State: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT logger.primary_info(f"Neural chat state is {state}") if conditional_state is not None: state.update_if_not_chosen(conditional_state) diff --git a/chirpy/response_generators/neural_chat/state.py b/chirpy/response_generators/neural_chat/state.py index d83f9c6..920466b 100644 --- a/chirpy/response_generators/neural_chat/state.py +++ b/chirpy/response_generators/neural_chat/state.py @@ -3,6 +3,8 @@ from typing import List, Optional, Set, Tuple from chirpy.core.response_generator.state import NO_UPDATE +import copy + logger = logging.getLogger('chirpylogger') @dataclass @@ -32,7 +34,8 @@ class ConditionalState(object): def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Optional[str] = None, user_utterance: Optional[str] = None, user_labels: List[str] = [], bot_utterance: Optional[str] = None, bot_labels: List[str] = [], - neural_responses: Optional[List[str]] = None, num_topic_shifts: int = 0): + neural_responses: Optional[List[str]] = None, num_topic_shifts: int = 0, + archived_state: "State" = None, rg_that_was_taken_over: str = None): # EDIT """ @param next_treelet: the name of the treelet we should run on the next turn if our response/prompt is chosen. None means turn off next turn. @param most_recent_treelet: the name of the treelet that handled this turn, if applicable @@ -59,6 +62,8 @@ def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Opti self.bot_labels = bot_labels self.neural_responses = neural_responses self.num_topic_shifts = num_topic_shifts + self.archived_state = archived_state # EDIT + self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT def __repr__(self): return f"" @@ -161,7 +170,7 @@ def update_if_chosen(self, conditional_state: ConditionalState): self.update_conv_history(conditional_state) - def update_if_not_chosen(self, conditional_state: ConditionalState): + def update_if_not_chosen(self, conditional_state: ConditionalState, rg_was_taken_over=False): """If our response/prompt has not been chosen, update state""" # Set the next_treelet for the next turn to be None (off) diff --git a/chirpy/response_generators/offensive_user/offensive_user_response_generator.py b/chirpy/response_generators/offensive_user/offensive_user_response_generator.py index dda40ae..7ff75c1 100644 --- a/chirpy/response_generators/offensive_user/offensive_user_response_generator.py +++ b/chirpy/response_generators/offensive_user/offensive_user_response_generator.py @@ -78,7 +78,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi # state[key] += 1 # return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> BaseState: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT state = super().update_state_if_not_chosen(state, conditional_state) state.handle_response = False state.offense_type = None diff --git a/chirpy/response_generators/opinion2/opinion_response_generator.py b/chirpy/response_generators/opinion2/opinion_response_generator.py index 853d142..a3b3405 100644 --- a/chirpy/response_generators/opinion2/opinion_response_generator.py +++ b/chirpy/response_generators/opinion2/opinion_response_generator.py @@ -533,7 +533,7 @@ def update_state_if_chosen(self, state: State, conditional_state : Optional[Stat if val != NO_UPDATE: setattr(state, attr, val) return state - def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State]) -> State: + def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State], rg_was_taken_over=False) -> State: # EDIT new_state = state.reset_state() new_state.num_turns_since_long_policy += 1 return new_state diff --git a/chirpy/response_generators/wiki2/state.py b/chirpy/response_generators/wiki2/state.py index 7ee3a6b..9c4b547 100644 --- a/chirpy/response_generators/wiki2/state.py +++ b/chirpy/response_generators/wiki2/state.py @@ -60,6 +60,7 @@ class State(BaseState): context_used: Optional[str] = None + @dataclass class ConditionalState(BaseConditionalState): # This is only used in conditional state to update the information for each entity diff --git a/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py b/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py index caa9ea9..7ffe7a7 100644 --- a/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/discuss_article_treelet.py @@ -352,7 +352,8 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): conditional_state=ConditionalState( prev_treelet_str=self.name, next_treelet_str=None - )) + ), + last_rg_willing_to_handover_control=True) else: return ResponseGeneratorResult( text=f"{ack} {text}", diff --git a/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py b/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py index 5bbc527..ff3d9c4 100644 --- a/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/discuss_section_further_treelet.py @@ -207,7 +207,8 @@ def get_initial_response(self): return ResponseGeneratorResult(text=response, priority=ResponsePriority.STRONG_CONTINUE, needs_prompt=False, state=state, - cur_entity=entity, conditional_state=conditional_state) + cur_entity=entity, conditional_state=conditional_state, + last_rg_willing_to_handover_control=True) def get_followup_acknowledgement(self): state, utterance, response_types = self.get_state_utterance_response_types() diff --git a/chirpy/response_generators/wiki2/treelets/handback_treelet.py b/chirpy/response_generators/wiki2/treelets/handback_treelet.py new file mode 100644 index 0000000..a2a0564 --- /dev/null +++ b/chirpy/response_generators/wiki2/treelets/handback_treelet.py @@ -0,0 +1,32 @@ +import random + +from chirpy.core.response_generator.treelet import Treelet +from chirpy.core.response_generator_datatypes import ResponsePriority, ResponseGeneratorResult +from chirpy.response_generators.wiki2.state import ConditionalState +from chirpy.response_generators.neural_fallback.neural_helpers import get_random_fallback_neural_response +from typing import Optional +import logging +import chirpy.response_generators.wiki2.wiki_utils as wiki_utils + +logger = logging.getLogger('chirpylogger') + + +class WikiHandBackTreelet(Treelet): + name = "wiki_handback_treelet" + + def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): + state, utterance, response_types = self.get_state_utterance_response_types() + + + logger.error(f'WIKI HANDBACK') + + + return ResponseGeneratorResult( + text="TODO:HANDBACK_WIKI_TEXT (WRAP UP)", + priority=priority, + state=state, needs_prompt=True, cur_entity=self.get_current_entity(), + conditional_state=ConditionalState(prev_treelet_str=self.name, + next_treelet_str=None, + rg_that_was_taken_over=self.rg.state.rg_that_was_taken_over), + ) + diff --git a/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py b/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py index 61305cf..3e00a81 100644 --- a/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/intro_entity_treelet.py @@ -53,7 +53,8 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): priority=priority, state=state, needs_prompt=False, cur_entity=entity, conditional_state=ConditionalState(prev_treelet_str=self.name, - next_treelet_str=self.rg.discuss_article_treelet.name) + next_treelet_str=self.rg.discuss_article_treelet.name), + last_rg_willing_to_handover_control=True ) else: # no intro paragraph available neural_response = get_random_fallback_neural_response(self.get_current_state()) diff --git a/chirpy/response_generators/wiki2/treelets/takeover_treelet.py b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py new file mode 100644 index 0000000..5fb3cf6 --- /dev/null +++ b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py @@ -0,0 +1,76 @@ +import random + +from chirpy.core.response_generator.treelet import Treelet +from chirpy.core.response_generator_datatypes import ResponsePriority, ResponseGeneratorResult +from chirpy.response_generators.wiki2.state import ConditionalState +from chirpy.response_generators.neural_fallback.neural_helpers import get_random_fallback_neural_response +from typing import Optional +import logging +import chirpy.response_generators.wiki2.wiki_utils as wiki_utils + +logger = logging.getLogger('chirpylogger') + + +class WikiTakeOverTreelet(Treelet): + name = "wiki_takeover_treelet" + + def get_summary_takeover(self, related_wiki_section, sentseg_fn, max_words, max_sents): + summary = wiki_utils.get_summary(related_wiki_section['text'], sentseg_fn, max_words, max_sents) + logger.primary_info(f"Takeover Summary is: {summary}") + summary = wiki_utils.clean_wiki_text(summary) + logger.primary_info(f"Takeover Summary after clean is: {summary}") + if wiki_utils.contains_offensive(summary): + logger.primary_info(f"Found takeover overview to be offensive, discarding it") + return None + return summary + + def get_takeover_paragraph(self, cur_entity: str, takeover_entity: str) -> Optional[str]: + related_wiki_sections_from_cur_entity_doc = wiki_utils.search_wiki_sections(cur_entity, (takeover_entity,), (takeover_entity,)) + related_wiki_sections_from_takeover_entity_doc = wiki_utils.search_wiki_sections(takeover_entity, (cur_entity,), (cur_entity,)) + + logging.error(f"related_wiki_sections_from_cur_entity_doc: {related_wiki_sections_from_cur_entity_doc}") + logging.error(f"related_wiki_sections_from_takeover_entity_doc: {related_wiki_sections_from_takeover_entity_doc}") + + if related_wiki_sections_from_cur_entity_doc: + return self.get_summary_takeover(related_wiki_sections_from_cur_entity_doc, wiki_utils.get_sentseg_fn(self.rg), max_sents=4) + + if related_wiki_sections_from_takeover_entity_doc: + return self.get_summary_takeover(related_wiki_sections_from_takeover_entity_doc, wiki_utils.get_sentseg_fn(self.rg), max_sents=4) + + logger.info("No overview found") + return None + + def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): + state, utterance, response_types = self.get_state_utterance_response_types() + + rg_that_was_taken_over = self.rg.state_manager.last_state.active_rg + logger.error(f'RG_THAT_WAS_TAKEN_OVER: {rg_that_was_taken_over}') + + cur_entity = self.get_current_entity() + takeover_entity = self.get_most_recent_able_to_takeover_entity() + logger.error(f'WIKI TAKEOVER ENTITY: {takeover_entity}') + + takeover_text = "TODO:TAKEOVER_WIKI_TEXT" # self.get_takeover_paragraph(cur_entity.name, takeover_entity.name) + ack = random.choice([ + "Well, from what I've read,", + "Ah, to my knowledge," + ]) + + logger.error(f'WIKI TAKEOVER TEXT: {takeover_text}') + + + + if takeover_text: + return ResponseGeneratorResult( + text=f"{ack} {wiki_utils.clean_wiki_text(takeover_text)}", + priority=priority, + state=state, needs_prompt=False, cur_entity=takeover_entity, + conditional_state=ConditionalState(prev_treelet_str=self.name, + next_treelet_str=self.rg.handback_treelet.name, + rg_that_was_taken_over=rg_that_was_taken_over), + takeover_rg_willing_to_handback_control=True # EDIT + + ) + + else: + return None diff --git a/chirpy/response_generators/wiki2/wiki_response_generator.py b/chirpy/response_generators/wiki2/wiki_response_generator.py index b2c7436..21a7192 100644 --- a/chirpy/response_generators/wiki2/wiki_response_generator.py +++ b/chirpy/response_generators/wiki2/wiki_response_generator.py @@ -1,5 +1,7 @@ import os import logging +from concurrent import futures + from typing import Optional, Set, Tuple import random @@ -23,6 +25,8 @@ from chirpy.annotators.corenlp import Sentiment from chirpy.response_generators.wiki2.state import State,ConditionalState, NO_UPDATE +from chirpy.response_generators.wiki2.treelets.takeover_treelet import WikiTakeOverTreelet # EDIT +from chirpy.response_generators.wiki2.treelets.handback_treelet import WikiHandBackTreelet # EDIT logger = logging.getLogger('chirpylogger') @@ -30,11 +34,13 @@ from chirpy.annotators.responseranker import ResponseRanker use_responseranker = True except ModuleNotFoundError: - logger.warning('ResponseRanker module not found, defaulting to original DialoGPT and GPT2 Rankers') + logger.warning('ResponseRanker module not found, defaulting to original DialoGPT and GP T2 Rankers') from chirpy.annotators.dialogptranker import DialoGPTRanker from chirpy.annotators.gpt2ranker import GPT2Ranker use_responseranker = False +import threading + class WikiResponseGenerator(ResponseGenerator): name='WIKI' @@ -50,12 +56,15 @@ def __init__(self, state_manager) -> None: self.discuss_section_treelet = DiscussSectionTreelet(self) self.discuss_section_further_treelet = DiscussSectionFurtherTreelet(self) self.get_opinion_treelet = GetOpinionTreelet(self) + self.takeover_treelet = WikiTakeOverTreelet(self) # EDIT + self.handback_treelet = WikiHandBackTreelet(self) # EDIT treelets = {t.name: t for t in [self.check_user_knowledge_treelet, self.acknowledge_user_knowledge_treelet, self.factoid_treelet, self.intro_entity_treelet, self.combined_til_treelet, self.discuss_article_treelet, self.discuss_section_treelet, - self.discuss_section_further_treelet, self.get_opinion_treelet]} + self.discuss_section_further_treelet, self.get_opinion_treelet, + self.takeover_treelet, self.handback_treelet]} super().__init__(state_manager, treelets=treelets, state_constructor=State, can_give_prompts=True, conditional_state_constructor=ConditionalState, @@ -641,7 +650,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState]) -> State: + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT state = super().update_state_if_not_chosen(state, conditional_state) state.cur_doc_title = None state.suggested_sections = [] @@ -658,3 +667,7 @@ def update_state_if_not_chosen(self, state: State, conditional_state: Optional[C state.context_used = None return state + + def get_takeover_response(self): # EDIT + logger.info("WIKI TAKEOVER") + return self.takeover_treelet.get_response(ResponsePriority.FORCE_START) \ No newline at end of file diff --git a/chirpy/response_generators/wiki2/wiki_utils.py b/chirpy/response_generators/wiki2/wiki_utils.py index 0e8ea8a..c3ede9b 100644 --- a/chirpy/response_generators/wiki2/wiki_utils.py +++ b/chirpy/response_generators/wiki2/wiki_utils.py @@ -262,12 +262,15 @@ def search_wiki_sections(doc_title: str, phrases: tuple, wiki_links:tuple) -> Li } } } + import json + logger.error(f"QUERY: {json.dumps(query, indent=2)}") sections = es.search(index='enwiki-20200920-sections', body=query) - logger.debug(f"For phrases {phrases}, in wikipedia article {doc_title}, found following sections (unfiltered) {sections}") + logger.error(f"For phrases {phrases}, in wikipedia article {doc_title}, found following sections (unfiltered) {sections}") filtered_sections = filter_highlight_sections(doc_title, sections) return filtered_sections + def get_text_for_entity(entity): results = es.search(index='enwiki-20200920-sections', body={ 'query': { diff --git a/env.list b/env.list index c9cedf7..119b5f5 100644 --- a/env.list +++ b/env.list @@ -19,3 +19,4 @@ export responseranker_URL=4088 export stanfordnlp_URL=4089 export infiller_URL=4090 export postgresql_URL=5432 +export usecolbert=false diff --git a/servers/local/shell_chat.py b/servers/local/shell_chat.py index 0dcfb7c..069b6dd 100644 --- a/servers/local/shell_chat.py +++ b/servers/local/shell_chat.py @@ -47,7 +47,7 @@ } # Logging settings LOGTOSCREEN_LEVEL = logging.INFO + 5 -LOGTOFILE_LEVEL = logging.DEBUG +LOGTOFILE_LEVEL = logging.INFO # EDIT: logging.debug def init_logger(): logger_settings = LoggerSettings(logtoscreen_level=LOGTOSCREEN_LEVEL, logtoscreen_usecolor=True, From 1899b7951366586cb632da981c3786252596a43f Mon Sep 17 00:00:00 2001 From: thanawan-atc Date: Mon, 18 Jul 2022 14:27:45 -0700 Subject: [PATCH 3/6] Takeover: Unfilled text --- .../core/response_generator/response_generator.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/chirpy/core/response_generator/response_generator.py b/chirpy/core/response_generator/response_generator.py index d34ffbc..95ac8d4 100644 --- a/chirpy/core/response_generator/response_generator.py +++ b/chirpy/core/response_generator/response_generator.py @@ -1013,18 +1013,12 @@ def augment_resuming_statement(self, resuming_statement_first_treelet): resuming_prompt_second_treelet = resuming_conversation_second_treelet.get_prompt() logger.error(f"DEBUG RESUMING PROMPT: {resuming_prompt_second_treelet}") if resuming_prompt_second_treelet: - # for attr_copy in ['state', 'cur_entity', 'expected_type']: - # getattr(resuming_statement_first_treelet, attr_copy) = - resuming_statement_first_treelet.text = f"{resuming_statement_first_treelet.text} {resuming_prompt_second_treelet.text}" - resuming_statement_first_treelet.state = resuming_prompt_second_treelet.state resuming_statement_first_treelet.conditional_state.next_treelet_str = resuming_conversation_second_treelet_str - resuming_statement_first_treelet.cur_entity = resuming_prompt_second_treelet.cur_entity - resuming_statement_first_treelet.expected_type = resuming_prompt_second_treelet.expected_type - resuming_statement_first_treelet.answer_type = resuming_prompt_second_treelet.answer_type - resuming_statement_first_treelet.last_rg_willing_to_handover_control = resuming_prompt_second_treelet.last_rg_willing_to_handover_control - resuming_statement_first_treelet.rg_that_was_taken_over = resuming_prompt_second_treelet.rg_that_was_taken_over - resuming_statement_first_treelet.takeover_rg_willing_to_handback_control = resuming_prompt_second_treelet.takeover_rg_willing_to_handback_control + for attr_to_copy in ['state', 'cur_entity', 'expected_type', 'answer_type', + 'last_rg_willing_to_handover_control', 'takeover_rg_willing_to_handback_control']: + attr_template = getattr(resuming_prompt_second_treelet, attr_to_copy) + setattr(resuming_statement_first_treelet, attr_to_copy, attr_template) resuming_statement_first_treelet.resuming_conversation_next_treelet = None logger.error(f"DEBUG RESUMING_STATEMENT_FIRST_TREELET.STATE: {resuming_statement_first_treelet}") return resuming_statement_first_treelet From f50c376b4e963fef7c3877b4170ed83bab1d7898 Mon Sep 17 00:00:00 2001 From: thanawan-atc Date: Thu, 21 Jul 2022 15:22:02 -0700 Subject: [PATCH 4/6] Transition with EDIT marked up --- chirpy/core/dialog_manager.py | 16 +-- chirpy/core/entity_tracker/entity_tracker.py | 30 +++-- .../response_generator/response_generator.py | 53 +++------ chirpy/core/response_generator/state.py | 12 +- chirpy/core/response_generator_datatypes.py | 32 +++--- chirpy/core/response_priority.py | 2 +- ...closing_confirmation_response_generator.py | 2 +- .../treelets/ask_favorite_food_treelet.py | 2 +- .../comment_on_favorite_type_treelet.py | 18 +-- .../food/treelets/factoid_treelet.py | 2 +- .../food/treelets/introductory_treelet.py | 2 +- .../open_ended_user_comment_treelet.py | 6 +- .../launch/launch_response_generator.py | 2 +- .../music/music_response_generator.py | 2 +- .../neural_chat_response_generator.py | 2 +- .../response_generators/neural_chat/state.py | 16 ++- .../offensive_user_response_generator.py | 2 +- .../opinion2/opinion_response_generator.py | 2 +- .../response_templates/response_components.py | 1 + .../wiki2/treelets/handback_treelet.py | 46 ++++++-- .../wiki2/treelets/takeover_treelet.py | 66 +++++------ .../wiki2/wiki_response_generator.py | 30 +++-- .../response_generators/wiki2/wiki_utils.py | 103 ++++++++++++++++++ servers/local/shell_chat.py | 2 +- 24 files changed, 287 insertions(+), 164 deletions(-) diff --git a/chirpy/core/dialog_manager.py b/chirpy/core/dialog_manager.py index 9491dc1..c8dcabd 100644 --- a/chirpy/core/dialog_manager.py +++ b/chirpy/core/dialog_manager.py @@ -285,14 +285,14 @@ def update_rg_states(self, results: RankedResults, selected_rg: str): other_rgs = [rg for rg in results.keys() if rg != selected_rg and not is_killed(results[rg])] logger.info(f"now, current states are {rg_states}") - def rg_was_taken_over(rg): # EDIT + def rg_was_taken_over(rg): # EDIT: TAKEOVER if self.state_manager.last_state: logger.error(f"DEBUG RG_WAS_TAKEN_OVER: {selected_rg} // {rg}, {rg == self.state_manager.last_state.active_rg}") return rg_states[selected_rg].rg_that_was_taken_over and rg == self.state_manager.last_state.active_rg else: return None - args_list = [[rg_states[rg], results[rg].conditional_state, rg_was_taken_over(rg)] for rg in other_rgs] # EDIT + args_list = [[rg_states[rg], results[rg].conditional_state, rg_was_taken_over(rg)] for rg in other_rgs] # EDIT: TAKEOVER # Run update_state_if_not_chosen for other RGs logger.info(f'Starting to run update_state_if_not_chosen for {other_rgs}...') @@ -339,7 +339,7 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe # Get the states for the RGs we'll run, which we'll use as input to the get_response/get_prompt fn logger.debug('Copying RG states to use as input...') - # input_rg_states = copy.copy([rg_states[rg] for rg in rgs_list]) # list of dicts # EDIT: COMMENT OUT + # input_rg_states = copy.copy([rg_states[rg] for rg in rgs_list]) # list of dicts # EDIT: TAKEOVER (COMMENT OUT) # import pdb; pdb.set_trace() @@ -352,17 +352,17 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe else: priority_modules = [] - rg_was_taken_over = None # EDIT - if self.state_manager.last_state_response: # EDIT + rg_was_taken_over = None # EDIT: TAKEOVER + if self.state_manager.last_state_response: # EDIT: TAKEOVER rg_was_taken_over = self.state_manager.last_state_response.state.rg_that_was_taken_over - def rg_to_resume(rg): # EDIT : ???? + def rg_to_resume(rg): # EDIT: TAKEOVER logger.error(f"DEBUG RG_TO_RESUME: {rg_was_taken_over} // {rg}, {rg == rg_was_taken_over}") return rg == rg_was_taken_over function_name = 'get_prompt_wrapper' if phase == 'prompt' else 'get_response' - args_list = copy.copy([[rg_states[rg], rg_to_resume(rg)] for rg in rgs_list]) # EDIT : ???? - results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list, # EDIT : ???? + args_list = copy.copy([[rg_states[rg], rg_to_resume(rg)] for rg in rgs_list]) # EDIT: TAKEOVER + results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list, # EDIT: TAKEOVER function_name=function_name, timeout=timeout, args_list=args_list, # [[state] for state in input_rg_states], diff --git a/chirpy/core/entity_tracker/entity_tracker.py b/chirpy/core/entity_tracker/entity_tracker.py index 0d59043..376cd8d 100644 --- a/chirpy/core/entity_tracker/entity_tracker.py +++ b/chirpy/core/entity_tracker/entity_tracker.py @@ -8,8 +8,6 @@ from chirpy.core.entity_linker.thresholds import SCORE_THRESHOLD_NAV_ABOUT, SCORE_THRESHOLD_NAV_NOT_ABOUT, SCORE_THRESHOLD_EXPECTEDTYPE from chirpy.core.entity_linker.entity_groups import EntityGroup -import random # EDIT - logger = logging.getLogger('chirpylogger') class TransitionType(Enum): @@ -25,8 +23,8 @@ class EntityTrackerState(object): def __init__(self): self.cur_entity = None # the current entity under discussion (can be None) - self.talked_unfinished = [] # EDIT - self.able_to_takeover_entities = [] # EDIT + self.talked_unfinished = [] # EDIT: TAKEOVER (for storing entities needed for resuming conversation) + self.able_to_takeover_entities = [] # EDIT: TAKEOVER (Will be rewritten every turn) self.talked_rejected = [] # entities we talked about in the past, and stopped talking about because the user indicated they didn't want to talk about it any more self.talked_finished = [] # entities we talked about in the past, that aren't in talked_rejected self.talked_transitionable = [] @@ -101,7 +99,7 @@ def finish_entity(self, entity: Optional[WikiEntity], transition_is_possible=Tru logger.error(f"This is an error. This should be a WikiEntity object but {entity} is of type {type(entity)}") entity = None - if entity is not None and entity not in self.talked_finished and entity not in self.talked_unfinished: # EDIT (?) + if entity is not None and entity not in self.talked_finished and entity not in self.talked_unfinished: # EDIT: TAKEOVER logger.info(f'Putting entity {entity} on the talked_finished list') self.talked_finished.append(entity) @@ -281,13 +279,13 @@ def condition_fn(entity_linker_result, linked_span, entity) -> bool: if nav_intent_output.neg_intent or nav_intent_output.pos_intent or last_answer_type in [AnswerType.QUESTION_SELFHANDLING, AnswerType.QUESTION_HANDOFF]: self.cur_entity = self.entity_initiated_on_turn - self.able_to_takeover_entities = [] # EDIT + self.able_to_takeover_entities = [] # EDIT: TAKEOVER for linked_span in current_state.entity_linker.high_prec: if not self.talked(linked_span.top_ent): logger.info(f'Adding {linked_span.top_ent} to user_mentioned_untalked') self.user_mentioned_untalked.append(linked_span.top_ent) - self.able_to_takeover_entities.append(linked_span.top_ent) # EDIT + self.able_to_takeover_entities.append(linked_span.top_ent) # EDIT: TAKEOVER logger.primary_info(f'The EntityTrackerState is now: {self}') logger.error(f'ABLE_TO_TAKEOVER_ENTITIES: {self.able_to_takeover_entities}') @@ -335,7 +333,7 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up transition_is_possible = not getattr(result, 'no_transition', False) - if self.able_to_takeover_entities: # EDIT + if self.able_to_takeover_entities: # EDIT: TAKEOVER self.talked_unfinished.append(self.cur_entity) new_entity = self.able_to_takeover_entities.pop() logger.primary_info(f'Removing {new_entity} from {self.able_to_takeover_entities}') @@ -363,7 +361,7 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up logger.primary_info(f'Set cur_entity to new_entity={new_entity} from {rg} RG {phase}') - if new_entity in self.talked_unfinished: # EDIT + if new_entity in self.talked_unfinished: # EDIT: TAKEOVER archived_entity = new_entity logger.error( f"Removing archived_entity [{archived_entity}] from talked_unfinished [{self.talked_unfinished}]") @@ -385,8 +383,8 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up def __repr__(self, show_history=False): output = f" bool: return True return ent in entities - self.able_to_takeover_entities = [ent for ent in self.able_to_takeover_entities if keep_entity(ent)] # EDIT + self.able_to_takeover_entities = [ent for ent in self.able_to_takeover_entities if keep_entity(ent)] # EDIT: TAKEOVER self.talked_finished = [ent for ent in self.talked_finished if keep_entity(ent)] self.talked_rejected = [ent for ent in self.talked_rejected if keep_entity(ent)] self.user_mentioned_untalked = [ent for ent in self.user_mentioned_untalked if keep_entity(ent)] @@ -422,8 +420,8 @@ def reduce_size(self, max_size: int): # Make a set (no duplicates) of all the WikiEntities stored in this EntityTrackerState entity_set = set() entity_set.add(self.cur_entity) - entity_set.update(self.talked_unfinished) # EDIT - entity_set.update(self.able_to_takeover_entities) # EDIT + entity_set.update(self.talked_unfinished) # EDIT: TAKEOVER + entity_set.update(self.able_to_takeover_entities) # EDIT: TAKEOVER entity_set.update(self.talked_finished) entity_set.update(self.talked_rejected) entity_set.update(self.user_mentioned_untalked) @@ -439,8 +437,8 @@ def replace_ent(ent: Optional[WikiEntity]): return None return entname2ent[ent.name] self.cur_entity = replace_ent(self.cur_entity) - self.talked_unfinished = [replace_ent(ent) for ent in self.talked_unfinished] # EDIT - self.able_to_takeover_entities = [replace_ent(ent) for ent in self.able_to_takeover_entities] # EDIT + self.talked_unfinished = [replace_ent(ent) for ent in self.talked_unfinished] # EDIT: TAKEOVER + self.able_to_takeover_entities = [replace_ent(ent) for ent in self.able_to_takeover_entities] # EDIT: TAKEOVER self.talked_finished = [replace_ent(ent) for ent in self.talked_finished] self.talked_rejected = [replace_ent(ent) for ent in self.talked_rejected] self.user_mentioned_untalked = [replace_ent(ent) for ent in self.user_mentioned_untalked] diff --git a/chirpy/core/response_generator/response_generator.py b/chirpy/core/response_generator/response_generator.py index 95ac8d4..1f12d6a 100644 --- a/chirpy/core/response_generator/response_generator.py +++ b/chirpy/core/response_generator/response_generator.py @@ -106,18 +106,17 @@ def update_state_if_chosen(self, state, conditional_state): return state - def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): # EDIT + def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): """ By default, this sets the prev_treelet_str and next_treelet_str to '' and resets num_turns_in_rg to 0. Response types are also saved. No other attributes are updated. All other attributes in ConditionalState are set to NO-UPDATE """ - if rg_was_taken_over: # EDIT + if rg_was_taken_over: # EDIT: TAKEOVER state.archived_state = copy.deepcopy(state) logging.error(f"ARCHIVED_STATE: {state.archived_state}") - response_types = self.get_cache(f'{self.name}_response_types') if response_types is not None: state.response_types = construct_response_types_tuple(response_types) @@ -293,7 +292,7 @@ def get_current_entity(self, initiated_this_turn=False): else: return self.state_manager.current_state.entity_tracker.cur_entity - def get_most_recent_able_to_takeover_entity(self): # EDIT + def get_most_recent_able_to_takeover_entity(self): # EDIT: TAKEOVER return self.state_manager.current_state.entity_tracker.able_to_takeover_entities[-1] def get_entity_tracker(self): @@ -872,7 +871,7 @@ def get_last_rg_in_control(self) -> Optional[str]: return self.state_manager.last_state.selected_response_rg - def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResult: # EDIT : ???? + def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResult: response_types = self.identify_response_types(self.utterance) logger.primary_info(f"{self.name} identified response_types: {response_types}") self.state = state @@ -931,8 +930,7 @@ def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResul activation_check_fns = { (lambda: self.get_last_active_rg() in self.disallow_start_from): self.get_fallback_result, (lambda: True): self.handle_direct_navigational_intent, - (lambda: (self.last_rg_willing_to_handover_control() and self.exist_able_to_takeover_entities())): self.get_takeover_response, # EDIT - # (lambda: (self.takeover_rg_willing_to_handback_control())): wrapped_partial(self.resume_response_generator, response_types), # EDIT + (lambda: (self.last_rg_willing_to_handover_control() and self.exist_able_to_takeover_entities())): self.get_takeover_response, # EDIT: TAKEOVER (lambda: True): self.handle_current_entity, (lambda: True): self.get_intro_treelet_response, (lambda: True): self.handle_custom_activation_checks, @@ -946,15 +944,6 @@ def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResul if activation_condition(): response = activation_check_fn() - # if response and activation_check_fn == self.get_takeover_response: # EDIT - # logger.primary_info(f"{self.name} is being taken over.") - # resumable_state = self.state_manager.last_state_response - # logger.error(f"STATE TAKEOVER: Change resumable state from {self.state.resumable_state} to {resumable_state})") - # updated_state = self.update_state_if_not_chosen(self.state, self.ConditionalState()) - # self.state = updated_state - # self.state.resumable_state = resumable_state - # logger.error(f"CURRENT STATE AFTER TAKEOVER: {self.state}") - if response: return self.possibly_augment_with_prompt(response) @@ -965,43 +954,27 @@ def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResul return self.get_fallback_result() - def last_rg_willing_to_handover_control(self): # EDIT + def last_rg_willing_to_handover_control(self): # EDIT: TAKEOVER last_active_rg_prompt = self.state_manager.last_state_response - # logging.error(f"LAST ACTIVE: {last_active_rg_prompt}") - # logging.error(f"LAST WILL: {last_active_rg_prompt.last_rg_willing_to_handover_control}") if last_active_rg_prompt: return last_active_rg_prompt.last_rg_willing_to_handover_control else: return False - def exist_able_to_takeover_entities(self): # EDIT + def exist_able_to_takeover_entities(self): # EDIT: TAKEOVER return len(self.state_manager.current_state.entity_tracker.able_to_takeover_entities) != 0 - def get_takeover_response(self): # EDIT + def get_takeover_response(self): # EDIT: TAKEOVER logging.error(f"TEST: {self.name} null get_takeover_response") return None - def takeover_rg_willing_to_handback_control(self): # EDIT + def takeover_rg_willing_to_handback_control(self): # EDIT: TAKEOVER last_active_rg_prompt = self.state_manager.last_state_response if last_active_rg_prompt: return last_active_rg_prompt.takeover_rg_willing_to_handback_control else: return False - # def resume_response_generator(self, response_types): # EDIT - # # logging.error(f"TEST: {self.name} null resume_conversation") - # # if self.name == 'FOOD': - # # logging.error(f"DEBUG: state is {self.state} ") - # # logging.error(f"ARCHIVED_RESUME: {archived_state}") - # archived_state = self.state.archived_state - # if archived_state: - # logging.error(f" HANDING OVER FROM: {self.name} // {self.state}") - # self.state = archived_state - # logging.error(f" HANDING OVER TO: {self.state}") - # return self.continue_conversation(response_types) - # else: - # return None - def get_resuming_statement(self, state) -> ResponseGeneratorResult: logging.error(f"TEST: {self.name} null get_resuming_statement") return self.emptyPrompt() @@ -1101,8 +1074,11 @@ def continue_conversation(self, response_types) -> Optional[ResponseGeneratorRes next_treelet_str = self.state.next_treelet_str next_treelet = None - response_priority = ResponsePriority.STRONG_CONTINUE - logger.error(f"In continue_conversation, self.state is {self.state}, next_treelet_str is {next_treelet_str}") + if self.last_rg_willing_to_handover_control(): # we talked last turn and decided to handover... + response_priority = ResponsePriority.WEAK_CONTINUE + else: + response_priority = ResponsePriority.STRONG_CONTINUE + logger.error(f"In continue_conversation, self.state is {self.state}, next_treelet_str is {next_treelet_str}, priority is {response_priority}") if next_treelet_str is None: return self.emptyResult() # continue from some other RG elif next_treelet_str == '': @@ -1144,7 +1120,6 @@ def continue_conversation(self, response_types) -> Optional[ResponseGeneratorRes logger.info(f"Continuing conversation from {next_treelet_str} for {self.name}") assert next_treelet_str in self.treelets next_treelet = self.treelets[next_treelet_str] - response_priority = ResponsePriority.STRONG_CONTINUE if next_treelet is not None: response = next_treelet.get_response(response_priority, ) diff --git a/chirpy/core/response_generator/state.py b/chirpy/core/response_generator/state.py index 180882e..f5e5f57 100644 --- a/chirpy/core/response_generator/state.py +++ b/chirpy/core/response_generator/state.py @@ -3,6 +3,8 @@ from chirpy.core.response_generator.response_type import ResponseType +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity # EDIT: TAKEOVER + import logging logger = logging.getLogger('chirpylogger') @@ -22,16 +24,18 @@ class BaseState: next_treelet_str: Optional[str] = '' response_types: Tuple[str] = () num_turns_in_rg: int = 0 - archived_state: "BaseState" = None # EDIT - rg_that_was_taken_over: str = None # EDIT + archived_state: "BaseState" = None # EDIT: TAKEOVER + rg_that_was_taken_over: str = None # EDIT: TAKEOVER + takeover_entity: WikiEntity = None # EDIT: TAKEOVER @dataclass class BaseConditionalState: prev_treelet_str: str = '' next_treelet_str: Optional[str] = '' response_types: Tuple[str] = NO_UPDATE - archived_state: "BaseState" = NO_UPDATE # EDIT - rg_that_was_taken_over: str = None # EDIT + archived_state: "BaseState" = NO_UPDATE # EDIT: TAKEOVER + rg_that_was_taken_over: str = NO_UPDATE # EDIT: TAKEOVER + takeover_entity: WikiEntity = NO_UPDATE # EDIT: TAKEOVER def construct_response_types_tuple(response_types): return tuple([str(x) for x in response_types]) diff --git a/chirpy/core/response_generator_datatypes.py b/chirpy/core/response_generator_datatypes.py index 3394ff7..f56e9b3 100644 --- a/chirpy/core/response_generator_datatypes.py +++ b/chirpy/core/response_generator_datatypes.py @@ -34,9 +34,10 @@ def __init__(self, conditional_state=None, tiebreak_priority=None, no_transition=False, - last_rg_willing_to_handover_control=False, # EDIT - rg_that_was_taken_over =None, # EDIT - takeover_rg_willing_to_handback_control=False # EDIT + last_rg_willing_to_handover_control=False, # EDIT: TAKEOVER + rg_that_was_taken_over =None, # EDIT: TAKEOVER + takeover_entity=None, # EDIT: TAKEOVER + takeover_rg_willing_to_handback_control=False # EDIT: TAKEOVER ): """ :param text: text of the response @@ -102,9 +103,10 @@ def __init__(self, self.conditional_state = conditional_state self.tiebreak_priority = tiebreak_priority self.no_transition = no_transition - self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT - self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT - self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT: TAKEOVER + self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT: TAKEOVER + self.takeover_entity = takeover_entity # EDIT: TAKEOVER + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT: TAKEOVER def reduce_size(self, max_size:int = None): """Gracefully degrade by removing non essential attributes. @@ -132,10 +134,11 @@ def __init__(self, expected_type: Optional[EntityGroup] = None, conditional_state=None, answer_type: AnswerType = AnswerType.QUESTION_SELFHANDLING, - last_rg_willing_to_handover_control=False, # EDIT - rg_that_was_taken_over =None, # EDIT - takeover_rg_willing_to_handback_control=False, # EDIT - resuming_conversation_next_treelet=None # EDIT + last_rg_willing_to_handover_control=False, # EDIT: TAKEOVER + rg_that_was_taken_over =None, # EDIT: TAKEOVER + takeover_entity =None, # EDIT: TAKEOVER + takeover_rg_willing_to_handback_control=False, # EDIT: TAKEOVER + resuming_conversation_next_treelet=None # EDIT: TAKEOVER ): """ :param text: text of the response @@ -175,10 +178,11 @@ def __init__(self, self.state = state self.conditional_state = conditional_state self.answer_type = answer_type - self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT - self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT - self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT - self.resuming_conversation_next_treelet = resuming_conversation_next_treelet # EDIT + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT: TAKEOVER + self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT: TAKEOVER + self.takeover_entity = takeover_entity # EDIT: TAKEOVER + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT: TAKEOVER + self.resuming_conversation_next_treelet = resuming_conversation_next_treelet # EDIT: TAKEOVER def __repr__(self): return 'PromptResult' + str(self.__dict__) diff --git a/chirpy/core/response_priority.py b/chirpy/core/response_priority.py index ba9232c..eb0710d 100644 --- a/chirpy/core/response_priority.py +++ b/chirpy/core/response_priority.py @@ -69,7 +69,7 @@ class TiebreakPriority(Enum): ACKNOWLEDGMENT = 62 EVI = 58 NEWS = 65 - WIKI = 69 # EDIT: Change from 64 + WIKI = 64 CATEGORIES = 60 MUSIC = 66 NEURAL_FALLBACK = 5 # fallback should always be lowest priority i.e. last resort diff --git a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py index 7a55f33..5b9234f 100644 --- a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py +++ b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py @@ -101,7 +101,7 @@ def handle_custom_continuation_checks(self): # If neither matched, allow another RG to handle return self.emptyResult() - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT: TAKEOVER state = super().update_state_if_not_chosen(state, conditional_state) state.has_just_asked_to_exit = False return state diff --git a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py index 0d4f0f5..95d9388 100644 --- a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py +++ b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py @@ -30,5 +30,5 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): next_treelet_str="food_introductory_treelet", cur_food=None), expected_type=ENTITY_GROUPS_FOR_EXPECTED_TYPE.food_related, - last_rg_willing_to_handover_control=True # EDIT + last_rg_willing_to_handover_control=True # EDIT: TAKEOVER ) diff --git a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py index 08b5421..20814f0 100644 --- a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py +++ b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py @@ -89,25 +89,15 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): other_type = sample_from_type(cur_food) text = f"That totally makes sense! I also really enjoy {user_answer}. Personally, I really like {other_type}." - return ResponseGeneratorResult(text=text, priority=priority, # EDIT + return ResponseGeneratorResult(text=text, priority=priority, # EDIT: TAKEOVER needs_prompt=False, state=state, cur_entity=entity, conditional_state=ConditionalState( prompt_treelet=self.rg.open_ended_user_comment_treelet.name, cur_food=cur_food_entity), - last_rg_willing_to_handover_control=False # EDIT + last_rg_willing_to_handover_control=False # EDIT: TAKEOVER ) - - # return ResponseGeneratorResult(text="TODO: RESUMING RESPONSE", priority=priority, # EDIT - # needs_prompt=False, state=state, - # cur_entity=entity, - # conditional_state=ConditionalState( - # prompt_treelet=self.rg.open_ended_user_comment_treelet.name, - # cur_food=cur_food_entity), - # last_rg_willing_to_handover_control=False # EDIT - # ) - def get_resuming_statement(self, prompt_type=PromptType.FORCE_START, **kwargs): logger.error(f"GET_STATEMENT_RESPONSE got triggered.") state, utterance, response_types = self.get_state_utterance_response_types() @@ -118,10 +108,10 @@ def get_resuming_statement(self, prompt_type=PromptType.FORCE_START, **kwargs): if get_custom_question(cur_food) is not None: custom_question_answer = get_custom_question_answer(cur_food) - text = f"TODO: RESUMING_STATEMENT_FIRST_TREELET_A (e.g. Personally, when it comes to {cur_talkable_food}, I really like {custom_question_answer})." + text = f"Anyway, personally, when it comes to {cur_talkable_food}, I really like {custom_question_answer}." else: other_type = sample_from_type(cur_food) - text = f"TODO: RESUMING_STATEMENT_FIRST_TREELET_B (e.g. Personally, I really like {other_type})" + text = f"Anyway, personally, I really like {other_type}" return PromptResult(text=text, prompt_type=prompt_type, state=state, conditional_state=ConditionalState( diff --git a/chirpy/response_generators/food/treelets/factoid_treelet.py b/chirpy/response_generators/food/treelets/factoid_treelet.py index c4f4402..4a48041 100644 --- a/chirpy/response_generators/food/treelets/factoid_treelet.py +++ b/chirpy/response_generators/food/treelets/factoid_treelet.py @@ -52,5 +52,5 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): conditional_state=ConditionalState( prev_treelet_str=self.name, cur_food=cur_food), - last_rg_willing_to_handover_control=True # EDIT + last_rg_willing_to_handover_control=True # EDIT: TAKEOVER ) diff --git a/chirpy/response_generators/food/treelets/introductory_treelet.py b/chirpy/response_generators/food/treelets/introductory_treelet.py index 9f2d8ba..04919ab 100644 --- a/chirpy/response_generators/food/treelets/introductory_treelet.py +++ b/chirpy/response_generators/food/treelets/introductory_treelet.py @@ -74,7 +74,7 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): cur_entity=entity, conditional_state=ConditionalState(cur_food=entity, prompt_treelet=prompt_treelet), - last_rg_willing_to_handover_control=True) # EDIT + last_rg_willing_to_handover_control=True) # EDIT: TAKEOVER def get_prompt(self, **kwargs): return None diff --git a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py index b63aeba..9fece4c 100644 --- a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py +++ b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py @@ -33,7 +33,7 @@ def get_prompt(self, conditional_state=None): pronoun = infl('them', entity.is_plural) if best_attribute: text = 'What do you think?' else: text = f'What do you like best about {pronoun}?' - return PromptResult(text=f'TODO: RESUMING_CONV_SECOND_TREELET (e.g. {text})', prompt_type=PromptType.CONTEXTUAL, state=state, cur_entity=entity, conditional_state=conditional_state) + return PromptResult(text=text, prompt_type=PromptType.CONTEXTUAL, state=state, cur_entity=entity, conditional_state=conditional_state) def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): """ Returns the response. """ @@ -67,12 +67,12 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): text = f"{neural_response} {concluding_statement}" cur_entity = None - return ResponseGeneratorResult(text=f"TODO: PICKING UP RESPONSE (eg. {text})", priority=ResponsePriority.STRONG_CONTINUE, + return ResponseGeneratorResult(text=text, priority=ResponsePriority.STRONG_CONTINUE, needs_prompt=needs_prompt, state=state, cur_entity=cur_entity, conditional_state=ConditionalState( prev_treelet_str=self.name, prompt_treelet=prompt_treelet, cur_food=None), - last_rg_willing_to_handover_control=False # EDIT + last_rg_willing_to_handover_control=False # EDIT: TAKEOVER ) diff --git a/chirpy/response_generators/launch/launch_response_generator.py b/chirpy/response_generators/launch/launch_response_generator.py index f2570d3..86ecfcc 100644 --- a/chirpy/response_generators/launch/launch_response_generator.py +++ b/chirpy/response_generators/launch/launch_response_generator.py @@ -50,7 +50,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi # state.asked_name_counter = 1 return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT: TAKEOVER state = super().update_state_if_not_chosen(state, conditional_state) state.next_treelet_str = None return state diff --git a/chirpy/response_generators/music/music_response_generator.py b/chirpy/response_generators/music/music_response_generator.py index d7ee4ab..984d8a0 100644 --- a/chirpy/response_generators/music/music_response_generator.py +++ b/chirpy/response_generators/music/music_response_generator.py @@ -76,7 +76,7 @@ def update_state_if_chosen(self, state, conditional_state): state.discussed_entities.append(state.cur_singer_str) return state - def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): # EDIT + def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): # EDIT: TAKEOVER state = super().update_state_if_not_chosen(state, conditional_state) return state diff --git a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py index a98258d..5a73abc 100644 --- a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py +++ b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py @@ -188,7 +188,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi state.update_if_chosen(conditional_state) return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT: TAKEOVER logger.primary_info(f"Neural chat state is {state}") if conditional_state is not None: state.update_if_not_chosen(conditional_state) diff --git a/chirpy/response_generators/neural_chat/state.py b/chirpy/response_generators/neural_chat/state.py index 920466b..563fa85 100644 --- a/chirpy/response_generators/neural_chat/state.py +++ b/chirpy/response_generators/neural_chat/state.py @@ -3,6 +3,8 @@ from typing import List, Optional, Set, Tuple from chirpy.core.response_generator.state import NO_UPDATE +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity # EDIT: TAKEOVER + import copy logger = logging.getLogger('chirpylogger') @@ -35,7 +37,7 @@ def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Opti user_utterance: Optional[str] = None, user_labels: List[str] = [], bot_utterance: Optional[str] = None, bot_labels: List[str] = [], neural_responses: Optional[List[str]] = None, num_topic_shifts: int = 0, - archived_state: "State" = None, rg_that_was_taken_over: str = None): # EDIT + archived_state: "State" = None, rg_that_was_taken_over: str = None, takeover_entity: WikiEntity = None): # EDIT: TAKEOVER """ @param next_treelet: the name of the treelet we should run on the next turn if our response/prompt is chosen. None means turn off next turn. @param most_recent_treelet: the name of the treelet that handled this turn, if applicable @@ -62,8 +64,9 @@ def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Opti self.bot_labels = bot_labels self.neural_responses = neural_responses self.num_topic_shifts = num_topic_shifts - self.archived_state = archived_state # EDIT - self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT + self.archived_state = archived_state # EDIT: TAKEOVER + self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT: TAKEOVER + self.takeover_entity = takeover_entity # EDIT: TAKEOVER def __repr__(self): return f" BaseState: # EDIT + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT: TAKEOVER state = super().update_state_if_not_chosen(state, conditional_state) state.handle_response = False state.offense_type = None diff --git a/chirpy/response_generators/opinion2/opinion_response_generator.py b/chirpy/response_generators/opinion2/opinion_response_generator.py index a3b3405..7d4dc73 100644 --- a/chirpy/response_generators/opinion2/opinion_response_generator.py +++ b/chirpy/response_generators/opinion2/opinion_response_generator.py @@ -533,7 +533,7 @@ def update_state_if_chosen(self, state: State, conditional_state : Optional[Stat if val != NO_UPDATE: setattr(state, attr, val) return state - def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State], rg_was_taken_over=False) -> State: # EDIT + def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State], rg_was_taken_over=False) -> State: # EDIT: TAKEOVER new_state = state.reset_state() new_state.num_turns_since_long_policy += 1 return new_state diff --git a/chirpy/response_generators/wiki2/response_templates/response_components.py b/chirpy/response_generators/wiki2/response_templates/response_components.py index 912a20c..771ba73 100644 --- a/chirpy/response_generators/wiki2/response_templates/response_components.py +++ b/chirpy/response_generators/wiki2/response_templates/response_components.py @@ -9,6 +9,7 @@ 'cool', 'super', 'i didn\'t know', + 'interesting', ] GENERAL_BOT_ACKNOWLEDGEMENTS = [ diff --git a/chirpy/response_generators/wiki2/treelets/handback_treelet.py b/chirpy/response_generators/wiki2/treelets/handback_treelet.py index a2a0564..97231d7 100644 --- a/chirpy/response_generators/wiki2/treelets/handback_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/handback_treelet.py @@ -6,7 +6,11 @@ from chirpy.response_generators.neural_fallback.neural_helpers import get_random_fallback_neural_response from typing import Optional import logging + import chirpy.response_generators.wiki2.wiki_utils as wiki_utils +from chirpy.response_generators.wiki2.wiki_helpers import ResponseType +from chirpy.core.regex.response_lists import * +from chirpy.response_generators.wiki2.response_templates.response_components import * logger = logging.getLogger('chirpylogger') @@ -14,19 +18,47 @@ class WikiHandBackTreelet(Treelet): name = "wiki_handback_treelet" - def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): + def get_acknowledgement(self): state, utterance, response_types = self.get_state_utterance_response_types() + if ResponseType.CONFUSED in response_types: return random.choice(ERROR_ADMISSION) + + prefix = '' + if ResponseType.AGREEMENT in response_types: + return random.choice(RESPONSES_TO_USER_AGREEMENT) + if ResponseType.POS_SENTIMENT in response_types: + if ResponseType.OPINION in response_types: + prefix = random.choice(POS_OPINION_RESPONSES) + elif ResponseType.APPRECIATIVE in response_types: + return random.choice(APPRECIATION_DEFAULT_ACKNOWLEDGEMENTS) + elif ResponseType.NEG_SENTIMENT in response_types: + if ResponseType.OPINION in response_types: # negative opinion + prefix = "That's an interesting take," + else: # expression of sadness + return random.choice(COMMISERATION_ACKNOWLEDGEMENTS) + elif ResponseType.NEUTRAL_SENTIMENT in response_types: + if ResponseType.OPINION in response_types or ResponseType.PERSONAL_DISCLOSURE in response_types: + return random.choice(NEUTRAL_OPINION_SHARING_RESPONSES) + elif ResponseType.KNOW_MORE: + return "Yeah," + if prefix is not None: + return prefix + return random.choice(POST_SHARING_ACK) + def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): + state, utterance, response_types = self.get_state_utterance_response_types() + takeover_entity = state.takeover_entity logger.error(f'WIKI HANDBACK') + wrap_up_text = self.get_acknowledgement() return ResponseGeneratorResult( - text="TODO:HANDBACK_WIKI_TEXT (WRAP UP)", - priority=priority, - state=state, needs_prompt=True, cur_entity=self.get_current_entity(), - conditional_state=ConditionalState(prev_treelet_str=self.name, - next_treelet_str=None, - rg_that_was_taken_over=self.rg.state.rg_that_was_taken_over), + text=wrap_up_text, + priority=priority, + state=state, needs_prompt=True, cur_entity=self.get_current_entity(), + conditional_state=ConditionalState(prev_treelet_str=self.name, + next_treelet_str=None, + rg_that_was_taken_over=self.rg.state.rg_that_was_taken_over, + takeover_entity=takeover_entity), ) diff --git a/chirpy/response_generators/wiki2/treelets/takeover_treelet.py b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py index 5fb3cf6..cb848b4 100644 --- a/chirpy/response_generators/wiki2/treelets/takeover_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py @@ -8,10 +8,13 @@ import logging import chirpy.response_generators.wiki2.wiki_utils as wiki_utils + +from chirpy.annotators.blenderbot import BlenderBot + logger = logging.getLogger('chirpylogger') -class WikiTakeOverTreelet(Treelet): +class WikiTakeOverTreelet(Treelet): # EDIT: TAKEOVER name = "wiki_takeover_treelet" def get_summary_takeover(self, related_wiki_section, sentseg_fn, max_words, max_sents): @@ -24,23 +27,7 @@ def get_summary_takeover(self, related_wiki_section, sentseg_fn, max_words, max_ return None return summary - def get_takeover_paragraph(self, cur_entity: str, takeover_entity: str) -> Optional[str]: - related_wiki_sections_from_cur_entity_doc = wiki_utils.search_wiki_sections(cur_entity, (takeover_entity,), (takeover_entity,)) - related_wiki_sections_from_takeover_entity_doc = wiki_utils.search_wiki_sections(takeover_entity, (cur_entity,), (cur_entity,)) - - logging.error(f"related_wiki_sections_from_cur_entity_doc: {related_wiki_sections_from_cur_entity_doc}") - logging.error(f"related_wiki_sections_from_takeover_entity_doc: {related_wiki_sections_from_takeover_entity_doc}") - - if related_wiki_sections_from_cur_entity_doc: - return self.get_summary_takeover(related_wiki_sections_from_cur_entity_doc, wiki_utils.get_sentseg_fn(self.rg), max_sents=4) - - if related_wiki_sections_from_takeover_entity_doc: - return self.get_summary_takeover(related_wiki_sections_from_takeover_entity_doc, wiki_utils.get_sentseg_fn(self.rg), max_sents=4) - - logger.info("No overview found") - return None - - def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): + def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): state, utterance, response_types = self.get_state_utterance_response_types() rg_that_was_taken_over = self.rg.state_manager.last_state.active_rg @@ -48,29 +35,44 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): cur_entity = self.get_current_entity() takeover_entity = self.get_most_recent_able_to_takeover_entity() - logger.error(f'WIKI TAKEOVER ENTITY: {takeover_entity}') - - takeover_text = "TODO:TAKEOVER_WIKI_TEXT" # self.get_takeover_paragraph(cur_entity.name, takeover_entity.name) - ack = random.choice([ - "Well, from what I've read,", - "Ah, to my knowledge," - ]) - - logger.error(f'WIKI TAKEOVER TEXT: {takeover_text}') + takeover_text = wiki_utils.get_takeover_text(self.rg, cur_entity, takeover_entity) + logger.error(f"TAKEOVER TEXT: {takeover_text}") if takeover_text: + intro_intersect_text = wiki_utils.get_random_intro_intersect_text(cur_entity.talkable_name, takeover_entity.talkable_name) + starter_text = wiki_utils.get_random_starter_text() return ResponseGeneratorResult( - text=f"{ack} {wiki_utils.clean_wiki_text(takeover_text)}", + text=intro_intersect_text + starter_text + takeover_text, priority=priority, state=state, needs_prompt=False, cur_entity=takeover_entity, conditional_state=ConditionalState(prev_treelet_str=self.name, next_treelet_str=self.rg.handback_treelet.name, - rg_that_was_taken_over=rg_that_was_taken_over), - takeover_rg_willing_to_handback_control=True # EDIT - + rg_that_was_taken_over=rg_that_was_taken_over, + takeover_entity=takeover_entity), + takeover_rg_willing_to_handback_control=True ) else: - return None + neural_prefix = f'Speaking of {takeover_entity.talkable_name} and {cur_entity.talkable_name},' + takeover_neural_response = self.rg.get_neural_response(prefix=neural_prefix) + takeover_neural_response = takeover_neural_response.split('.')[0] + generated_response = takeover_neural_response[len(neural_prefix):] + logger.error(f"TAKEOVER_NEURAL_RESPONSE: {takeover_neural_response}") + if takeover_entity.talkable_name in generated_response and cur_entity.talkable_name in generated_response : + intro_intersect_text = wiki_utils.get_random_intro_intersect_text(cur_entity.talkable_name, + takeover_entity.talkable_name) + starter_text = wiki_utils.get_random_starter_text() + return ResponseGeneratorResult( + text=intro_intersect_text + starter_text + generated_response, + priority=priority, + state=state, needs_prompt=False, cur_entity=takeover_entity, + conditional_state=ConditionalState(prev_treelet_str=self.name, + next_treelet_str=self.rg.handback_treelet.name, + rg_that_was_taken_over=rg_that_was_taken_over, + takeover_entity=takeover_entity), + takeover_rg_willing_to_handback_control=True + ) + else: + return None \ No newline at end of file diff --git a/chirpy/response_generators/wiki2/wiki_response_generator.py b/chirpy/response_generators/wiki2/wiki_response_generator.py index 21a7192..7abd88e 100644 --- a/chirpy/response_generators/wiki2/wiki_response_generator.py +++ b/chirpy/response_generators/wiki2/wiki_response_generator.py @@ -25,8 +25,10 @@ from chirpy.annotators.corenlp import Sentiment from chirpy.response_generators.wiki2.state import State,ConditionalState, NO_UPDATE -from chirpy.response_generators.wiki2.treelets.takeover_treelet import WikiTakeOverTreelet # EDIT -from chirpy.response_generators.wiki2.treelets.handback_treelet import WikiHandBackTreelet # EDIT +from chirpy.response_generators.wiki2.treelets.takeover_treelet import WikiTakeOverTreelet # EDIT: TAKEOVER +from chirpy.response_generators.wiki2.treelets.handback_treelet import WikiHandBackTreelet # EDIT: TAKEOVER + +from chirpy.core.offensive_classifier.offensive_classifier import OffensiveClassifier # EDIT: TAKEOVER logger = logging.getLogger('chirpylogger') @@ -44,7 +46,7 @@ class WikiResponseGenerator(ResponseGenerator): name='WIKI' - killable = True + killable = False def __init__(self, state_manager) -> None: self.check_user_knowledge_treelet = CheckUserKnowledgeTreelet(self) self.acknowledge_user_knowledge_treelet = AcknowledgeUserKnowledgeTreelet(self) @@ -56,15 +58,15 @@ def __init__(self, state_manager) -> None: self.discuss_section_treelet = DiscussSectionTreelet(self) self.discuss_section_further_treelet = DiscussSectionFurtherTreelet(self) self.get_opinion_treelet = GetOpinionTreelet(self) - self.takeover_treelet = WikiTakeOverTreelet(self) # EDIT - self.handback_treelet = WikiHandBackTreelet(self) # EDIT + self.takeover_treelet = WikiTakeOverTreelet(self) # EDIT: TAKEOVER + self.handback_treelet = WikiHandBackTreelet(self) # EDIT: TAKEOVER treelets = {t.name: t for t in [self.check_user_knowledge_treelet, self.acknowledge_user_knowledge_treelet, self.factoid_treelet, self.intro_entity_treelet, self.combined_til_treelet, self.discuss_article_treelet, self.discuss_section_treelet, self.discuss_section_further_treelet, self.get_opinion_treelet, - self.takeover_treelet, self.handback_treelet]} + self.takeover_treelet, self.handback_treelet]} # EDIT: TAKEOVER super().__init__(state_manager, treelets=treelets, state_constructor=State, can_give_prompts=True, conditional_state_constructor=ConditionalState, @@ -650,7 +652,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT: TAKEOVER state = super().update_state_if_not_chosen(state, conditional_state) state.cur_doc_title = None state.suggested_sections = [] @@ -668,6 +670,14 @@ def update_state_if_not_chosen(self, state: State, conditional_state: Optional[C return state - def get_takeover_response(self): # EDIT - logger.info("WIKI TAKEOVER") - return self.takeover_treelet.get_response(ResponsePriority.FORCE_START) \ No newline at end of file + def get_takeover_response(self): # EDIT: TAKEOVER + logger.error("WIKI TAKEOVER") + return self.takeover_treelet.get_response(ResponsePriority.FORCE_START) + + def get_neural_response(self, prefix=None, allow_questions=False, conditions=None) -> Optional[str]: + if conditions is None: conditions = [] + offensive_classifier = OffensiveClassifier() + conditions = [lambda response: not offensive_classifier.contains_offensive(response)] + conditions + response = super().get_neural_response(prefix, allow_questions, conditions) + if response is None: return "That's great to hear." + return response \ No newline at end of file diff --git a/chirpy/response_generators/wiki2/wiki_utils.py b/chirpy/response_generators/wiki2/wiki_utils.py index c3ede9b..3ab6aae 100644 --- a/chirpy/response_generators/wiki2/wiki_utils.py +++ b/chirpy/response_generators/wiki2/wiki_utils.py @@ -23,6 +23,9 @@ import random import math +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity # EDIT: TAKEOVER +from chirpy.annotators.convpara import ConvPara # EDIT: TAKEOVER + lucene_stopwords = {'a', 'an', 'and', 'are', 'as', 'at', 'be', 'but', 'by', 'for', 'if', 'in', 'into', 'is', 'it', 'no', 'not', 'of', 'on', 'or', 'such', 'that', 'the', 'their', 'then', 'there', 'these', 'they', 'this', 'to', 'was', @@ -135,6 +138,8 @@ def filter_highlight_sections(title, es_sections): wiki_sections.append(wiki_section) filtered_sections = wiki_sections + logger.error(f"SECTIONS: {filtered_sections}") + # Filter sections and log why the were filtered filtered_sections = filter_and_log(lambda section: not contains_offensive(section.title), filtered_sections, 'Wiki Highlights', reason_for_filtering='section title contains offensive phrases') @@ -162,6 +167,7 @@ def filter_highlight_sections(title, es_sections): filtered_sections = filter_and_log(lambda section: not contains_offensive(section.highlight), filtered_sections, 'Wiki Highlights', reason_for_filtering='section highlight contains offensive phrases') + logger.error(f"SECTIONS2: {filtered_sections}") return filtered_sections def filter_sections(title, es_sections): wiki_sections = [] @@ -270,6 +276,103 @@ def search_wiki_sections(doc_title: str, phrases: tuple, wiki_links:tuple) -> Li return filtered_sections +def prune_section(section): + return section['text'][0] in {'†', '+', '*'} + +def clean_takeover_wiki_text(text: str) -> str: + modified_text = clean_wiki_text(text) + index_caption = modified_text.find(']]') + if index_caption != -1: + modified_text = modified_text[index_caption + 2:] + return modified_text + +def summarize_takeover_candidate(rg, text: str, span_to_keep: str, max_sents: int = 3) -> str: + logger.debug(f'Summarizing takeover text: {text}') + + local_sentseg_fn = lambda text: re.split('[.\n]', text) + sentseg_fn = NLTKSentenceSegmenter( + rg.state_manager).execute if rg.state_manager else local_sentseg_fn + sentences = sentseg_fn(text) + + summary = '' + num_sentences = 0 + found = False + for sentence in sentences: + if sentence == '': + continue + if "|" in sentence or "[" in sentence or "]" in sentence or "{" in sentence or "}" in sentence: + continue + if span_to_keep in sentence: + found = True + summary += sentence + ('.' if sentence[-1] not in {'.', '!', '?'} else ' ') + num_sentences += 1 + if num_sentences > max_sents and found: + break + return summary + + +def search_wiki_intersect_sections(rg, doc_title: str, search_entity: WikiEntity) -> List[str]: + query = {'query': {'bool': {'filter': [ + {'term': {'doc_title': doc_title}}]}}} + sections = es.search(index='enwiki-20200920-sections', body=query, size=100) + top_spans = list(search_entity.anchortext_counts.keys())[:3] + logger.error(f"SPAN: {search_entity.anchortext_counts}") + candidate_texts = [] + # logger.error(f"TEXTS: {sections['hits']['hits']}") + for section in sections['hits']['hits']: + source = section['_source'] + if not prune_section(source): + source_texts = list(filter(None, re.split('\n', source['text']))) + for text in source_texts: + cleaned_text = clean_takeover_wiki_text(text) + if not contains_offensive(cleaned_text) and not rg.has_overlap_with_history(cleaned_text, threshold=0.8): + for s in top_spans: + if s in cleaned_text: + modified_text = re.sub(r"\([^()]*\)", "", text) + summary_text = summarize_takeover_candidate(rg, modified_text, span_to_keep=s) + candidate_texts.append(summary_text) + break + logger.error(f"CANDIDATE_TEXTS (from doc_title {doc_title}): {candidate_texts}") + return candidate_texts + +def get_paraphrase(rg, text: str, entity: str) -> str: + conv_para = ConvPara(rg.state_manager) + paraphrases = conv_para.get_paraphrases(background=text, entity=entity) + logger.error(f"PARAPHRASES: {paraphrases}") + return paraphrases + +def get_takeover_text(rg, cur_entity: WikiEntity, takeover_entity: WikiEntity) -> Optional[str]: + related_wiki_texts_from_cur_entity_doc = search_wiki_intersect_sections(rg, cur_entity.talkable_name, takeover_entity) + related_wiki_texts_from_takeover_entity_doc = search_wiki_intersect_sections(rg, takeover_entity.talkable_name, cur_entity) + + intersect_wiki_texts = related_wiki_texts_from_cur_entity_doc + related_wiki_texts_from_takeover_entity_doc + + logger.error(f"INTERSECT_WIKI_TEXTS: {intersect_wiki_texts}") + if related_wiki_texts_from_cur_entity_doc: + takeover_text = random.choice(related_wiki_texts_from_cur_entity_doc) + return takeover_text + + elif related_wiki_texts_from_takeover_entity_doc: + takeover_text = random.choice(related_wiki_texts_from_takeover_entity_doc) + return get_paraphrase(rg, takeover_text, cur_entity.name) + else: + return None + +INTRO_INTERSECT_TEXT = ["Speaking of {} and {}, ", + "Relating to {} and {}, ", + "Since you mentioned {} and {}, "] + +def get_random_intro_intersect_text(cur_entity: str, takeover_entity: str) -> str: + return random.choice(INTRO_INTERSECT_TEXT).format(cur_entity, takeover_entity) + +STARTER_TEXTS = ["did you know that ", + "I recently learned that ", + "I was reading recently and found out that ", + "did you know that ", + "I was interested to learn that "] + +def get_random_starter_text(): + return random.choice(STARTER_TEXTS) def get_text_for_entity(entity): results = es.search(index='enwiki-20200920-sections', body={ diff --git a/servers/local/shell_chat.py b/servers/local/shell_chat.py index 069b6dd..0dcfb7c 100644 --- a/servers/local/shell_chat.py +++ b/servers/local/shell_chat.py @@ -47,7 +47,7 @@ } # Logging settings LOGTOSCREEN_LEVEL = logging.INFO + 5 -LOGTOFILE_LEVEL = logging.INFO # EDIT: logging.debug +LOGTOFILE_LEVEL = logging.DEBUG def init_logger(): logger_settings = LoggerSettings(logtoscreen_level=LOGTOSCREEN_LEVEL, logtoscreen_usecolor=True, From 292d641181c3901da121cc54b84694c8fb695afe Mon Sep 17 00:00:00 2001 From: thanawan-atc Date: Mon, 25 Jul 2022 13:23:35 -0700 Subject: [PATCH 5/6] Added takeover feature to Chirpy --- chirpy/core/dialog_manager.py | 21 +++---- chirpy/core/entity_tracker/entity_tracker.py | 42 ++++++------- .../response_generator/response_generator.py | 60 ++++++------------- chirpy/core/response_generator/state.py | 14 ++--- chirpy/core/response_generator_datatypes.py | 36 +++++------ ...closing_confirmation_response_generator.py | 2 +- .../treelets/ask_favorite_food_treelet.py | 2 +- .../comment_on_favorite_type_treelet.py | 4 +- .../food/treelets/factoid_treelet.py | 2 +- .../food/treelets/introductory_treelet.py | 2 +- .../open_ended_user_comment_treelet.py | 2 +- .../launch/launch_response_generator.py | 2 +- .../music/music_response_generator.py | 2 +- .../neural_chat_response_generator.py | 2 +- .../response_generators/neural_chat/state.py | 18 +++--- .../offensive_user_response_generator.py | 2 +- .../opinion2/opinion_response_generator.py | 2 +- .../wiki2/treelets/handback_treelet.py | 2 +- .../wiki2/treelets/takeover_treelet.py | 24 +++----- .../wiki2/wiki_response_generator.py | 16 ++--- .../response_generators/wiki2/wiki_utils.py | 47 +++++++-------- 21 files changed, 138 insertions(+), 166 deletions(-) diff --git a/chirpy/core/dialog_manager.py b/chirpy/core/dialog_manager.py index c8dcabd..52f6bb4 100644 --- a/chirpy/core/dialog_manager.py +++ b/chirpy/core/dialog_manager.py @@ -285,14 +285,15 @@ def update_rg_states(self, results: RankedResults, selected_rg: str): other_rgs = [rg for rg in results.keys() if rg != selected_rg and not is_killed(results[rg])] logger.info(f"now, current states are {rg_states}") - def rg_was_taken_over(rg): # EDIT: TAKEOVER + def rg_was_taken_over(rg): if self.state_manager.last_state: - logger.error(f"DEBUG RG_WAS_TAKEN_OVER: {selected_rg} // {rg}, {rg == self.state_manager.last_state.active_rg}") + logger.debug(f"Rg that is selected is {selected_rg}. Currently evaluated rg is {rg}. " + f"rg == self.state_manager.last_state.active_rg is {rg == self.state_manager.last_state.active_rg}") return rg_states[selected_rg].rg_that_was_taken_over and rg == self.state_manager.last_state.active_rg else: return None - args_list = [[rg_states[rg], results[rg].conditional_state, rg_was_taken_over(rg)] for rg in other_rgs] # EDIT: TAKEOVER + args_list = [[rg_states[rg], results[rg].conditional_state, rg_was_taken_over(rg)] for rg in other_rgs] # Run update_state_if_not_chosen for other RGs logger.info(f'Starting to run update_state_if_not_chosen for {other_rgs}...') @@ -339,7 +340,6 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe # Get the states for the RGs we'll run, which we'll use as input to the get_response/get_prompt fn logger.debug('Copying RG states to use as input...') - # input_rg_states = copy.copy([rg_states[rg] for rg in rgs_list]) # list of dicts # EDIT: TAKEOVER (COMMENT OUT) # import pdb; pdb.set_trace() @@ -352,17 +352,18 @@ def run_rgs_and_rank(self, phase: str, exclude_rgs : List[str] = []) -> RankedRe else: priority_modules = [] - rg_was_taken_over = None # EDIT: TAKEOVER - if self.state_manager.last_state_response: # EDIT: TAKEOVER + rg_was_taken_over = None + if self.state_manager.last_state_response: rg_was_taken_over = self.state_manager.last_state_response.state.rg_that_was_taken_over - def rg_to_resume(rg): # EDIT: TAKEOVER - logger.error(f"DEBUG RG_TO_RESUME: {rg_was_taken_over} // {rg}, {rg == rg_was_taken_over}") + def rg_to_resume(rg): + logger.debug(f"rg that was taken over is {rg_was_taken_over}. Currently evaluated rg is {rg}. " + f"rg == rg_was_taken_over is {rg == rg_was_taken_over}.") return rg == rg_was_taken_over function_name = 'get_prompt_wrapper' if phase == 'prompt' else 'get_response' - args_list = copy.copy([[rg_states[rg], rg_to_resume(rg)] for rg in rgs_list]) # EDIT: TAKEOVER - results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list, # EDIT: TAKEOVER + args_list = copy.copy([[rg_states[rg], rg_to_resume(rg)] for rg in rgs_list]) + results_dict = self.response_generators.run_multithreaded(rg_names=rgs_list, function_name=function_name, timeout=timeout, args_list=args_list, # [[state] for state in input_rg_states], diff --git a/chirpy/core/entity_tracker/entity_tracker.py b/chirpy/core/entity_tracker/entity_tracker.py index 376cd8d..b6c87f1 100644 --- a/chirpy/core/entity_tracker/entity_tracker.py +++ b/chirpy/core/entity_tracker/entity_tracker.py @@ -23,8 +23,8 @@ class EntityTrackerState(object): def __init__(self): self.cur_entity = None # the current entity under discussion (can be None) - self.talked_unfinished = [] # EDIT: TAKEOVER (for storing entities needed for resuming conversation) - self.able_to_takeover_entities = [] # EDIT: TAKEOVER (Will be rewritten every turn) + self.talked_unfinished = [] # entities that we have not finished talking about, but the rg is taken over + self.able_to_takeover_entities = [] # entities that are found in the response in that turn and can be used for wiki rg to takeover self.talked_rejected = [] # entities we talked about in the past, and stopped talking about because the user indicated they didn't want to talk about it any more self.talked_finished = [] # entities we talked about in the past, that aren't in talked_rejected self.talked_transitionable = [] @@ -99,7 +99,7 @@ def finish_entity(self, entity: Optional[WikiEntity], transition_is_possible=Tru logger.error(f"This is an error. This should be a WikiEntity object but {entity} is of type {type(entity)}") entity = None - if entity is not None and entity not in self.talked_finished and entity not in self.talked_unfinished: # EDIT: TAKEOVER + if entity is not None and entity not in self.talked_finished and entity not in self.talked_unfinished: logger.info(f'Putting entity {entity} on the talked_finished list') self.talked_finished.append(entity) @@ -279,16 +279,18 @@ def condition_fn(entity_linker_result, linked_span, entity) -> bool: if nav_intent_output.neg_intent or nav_intent_output.pos_intent or last_answer_type in [AnswerType.QUESTION_SELFHANDLING, AnswerType.QUESTION_HANDOFF]: self.cur_entity = self.entity_initiated_on_turn - self.able_to_takeover_entities = [] # EDIT: TAKEOVER + logger.info(f'Resetting able_to_takeover_entities to empty list') + self.able_to_takeover_entities = [] for linked_span in current_state.entity_linker.high_prec: if not self.talked(linked_span.top_ent): logger.info(f'Adding {linked_span.top_ent} to user_mentioned_untalked') self.user_mentioned_untalked.append(linked_span.top_ent) - self.able_to_takeover_entities.append(linked_span.top_ent) # EDIT: TAKEOVER + logger.info(f'Adding {linked_span.top_ent} to able_to_takeover_entities') + self.able_to_takeover_entities.append(linked_span.top_ent) logger.primary_info(f'The EntityTrackerState is now: {self}') - logger.error(f'ABLE_TO_TAKEOVER_ENTITIES: {self.able_to_takeover_entities}') + # logger.error(f'ABLE_TO_TAKEOVER_ENTITIES: {self.able_to_takeover_entities}') # Update the entity tracker history self.history[-1]['user'] = self.cur_entity @@ -333,13 +335,13 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up transition_is_possible = not getattr(result, 'no_transition', False) - if self.able_to_takeover_entities: # EDIT: TAKEOVER + if self.able_to_takeover_entities and result.state.takeover_entity: self.talked_unfinished.append(self.cur_entity) new_entity = self.able_to_takeover_entities.pop() logger.primary_info(f'Removing {new_entity} from {self.able_to_takeover_entities}') self.able_to_takeover_entities = [e for e in self.able_to_takeover_entities if e != new_entity] - logger.error(f'[AFTER TAKEOVER 1] TALK_UNFINISHED: {self.talked_unfinished} // ABLE_TO_TAKEOVER_ENT: {self.able_to_takeover_entities} //' - f'/ TALKED_FINISHED = {self.talked_finished}') + logger.info(f'After takeover, self.talk_unfinished is {self.talked_unfinished}, self.able_to_takeover_entities is {self.able_to_takeover_entities}' + f' and self.talked_unfinished is {self.talked_finished}.') if new_entity == self.cur_entity: logger.primary_info(f'new_entity={new_entity} from {rg} RG {phase} is the same as cur_entity, so keeping EntityTrackerState the same') @@ -356,18 +358,18 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up self.cur_entity = new_entity # Remove new_entity from user_mentioned_untalked if new_entity in self.user_mentioned_untalked: - logger.primary_info(f'Removing {new_entity} from {self.user_mentioned_untalked}') + logger.primary_info(f'Removing {new_entity} from {self.user_mentioned_untalked} after conversation is resumed.') self.user_mentioned_untalked = [e for e in self.user_mentioned_untalked if e != new_entity] logger.primary_info(f'Set cur_entity to new_entity={new_entity} from {rg} RG {phase}') - if new_entity in self.talked_unfinished: # EDIT: TAKEOVER + if new_entity in self.talked_unfinished: archived_entity = new_entity - logger.error( + logger.info( f"Removing archived_entity [{archived_entity}] from talked_unfinished [{self.talked_unfinished}]") self.talked_unfinished.remove(archived_entity) - logger.error(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}') + logger.info(f'EntityTrackerState after updating wrt {rg} RG {phase}: {self}') # If we're updating after receiving UpdateEntity from an RG, put any undiscussed high precision entities that # the user mentioned this turn in user_mentioned_untalked @@ -383,8 +385,8 @@ def update_from_rg(self, result: Union[ResponseGeneratorResult, PromptResult, Up def __repr__(self, show_history=False): output = f" bool: return True return ent in entities - self.able_to_takeover_entities = [ent for ent in self.able_to_takeover_entities if keep_entity(ent)] # EDIT: TAKEOVER + self.able_to_takeover_entities = [ent for ent in self.able_to_takeover_entities if keep_entity(ent)] self.talked_finished = [ent for ent in self.talked_finished if keep_entity(ent)] self.talked_rejected = [ent for ent in self.talked_rejected if keep_entity(ent)] self.user_mentioned_untalked = [ent for ent in self.user_mentioned_untalked if keep_entity(ent)] @@ -420,8 +422,8 @@ def reduce_size(self, max_size: int): # Make a set (no duplicates) of all the WikiEntities stored in this EntityTrackerState entity_set = set() entity_set.add(self.cur_entity) - entity_set.update(self.talked_unfinished) # EDIT: TAKEOVER - entity_set.update(self.able_to_takeover_entities) # EDIT: TAKEOVER + entity_set.update(self.talked_unfinished) + entity_set.update(self.able_to_takeover_entities) entity_set.update(self.talked_finished) entity_set.update(self.talked_rejected) entity_set.update(self.user_mentioned_untalked) @@ -437,8 +439,8 @@ def replace_ent(ent: Optional[WikiEntity]): return None return entname2ent[ent.name] self.cur_entity = replace_ent(self.cur_entity) - self.talked_unfinished = [replace_ent(ent) for ent in self.talked_unfinished] # EDIT: TAKEOVER - self.able_to_takeover_entities = [replace_ent(ent) for ent in self.able_to_takeover_entities] # EDIT: TAKEOVER + self.talked_unfinished = [replace_ent(ent) for ent in self.talked_unfinished] + self.able_to_takeover_entities = [replace_ent(ent) for ent in self.able_to_takeover_entities] self.talked_finished = [replace_ent(ent) for ent in self.talked_finished] self.talked_rejected = [replace_ent(ent) for ent in self.talked_rejected] self.user_mentioned_untalked = [replace_ent(ent) for ent in self.user_mentioned_untalked] diff --git a/chirpy/core/response_generator/response_generator.py b/chirpy/core/response_generator/response_generator.py index 1f12d6a..8c8c47c 100644 --- a/chirpy/core/response_generator/response_generator.py +++ b/chirpy/core/response_generator/response_generator.py @@ -113,9 +113,9 @@ def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over No other attributes are updated. All other attributes in ConditionalState are set to NO-UPDATE """ - if rg_was_taken_over: # EDIT: TAKEOVER + if rg_was_taken_over: state.archived_state = copy.deepcopy(state) - logging.error(f"ARCHIVED_STATE: {state.archived_state}") + logging.info(f"Save current state as archived_state for conversation to be resumed: {state.archived_state}") response_types = self.get_cache(f'{self.name}_response_types') if response_types is not None: @@ -292,7 +292,7 @@ def get_current_entity(self, initiated_this_turn=False): else: return self.state_manager.current_state.entity_tracker.cur_entity - def get_most_recent_able_to_takeover_entity(self): # EDIT: TAKEOVER + def get_most_recent_able_to_takeover_entity(self): return self.state_manager.current_state.entity_tracker.able_to_takeover_entities[-1] def get_entity_tracker(self): @@ -925,20 +925,16 @@ def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResul if not is_continuing_conversation: # allow the first branch to divert here logger.primary_info(f"{self.name} is not currently active, so checking if it should activate") - if self.name == 'FOOD': - logger.error(f"Self.state is {self.state}") activation_check_fns = { (lambda: self.get_last_active_rg() in self.disallow_start_from): self.get_fallback_result, (lambda: True): self.handle_direct_navigational_intent, - (lambda: (self.last_rg_willing_to_handover_control() and self.exist_able_to_takeover_entities())): self.get_takeover_response, # EDIT: TAKEOVER + (lambda: (self.last_rg_willing_to_handover_control() and self.exist_able_to_takeover_entities())): self.get_takeover_response, (lambda: True): self.handle_current_entity, (lambda: True): self.get_intro_treelet_response, (lambda: True): self.handle_custom_activation_checks, } - logging.error(f"DEBUG HANDOVER {self.last_rg_willing_to_handover_control()}, {self.exist_able_to_takeover_entities()}") - logging.error(f"DEBUG HANDBACK {self.takeover_rg_willing_to_handback_control()}") - + logging.debug(f"DEBUG HANDOVER {self.last_rg_willing_to_handover_control()}, {self.exist_able_to_takeover_entities()}") for activation_condition, activation_check_fn in activation_check_fns.items(): if activation_condition(): @@ -954,21 +950,21 @@ def get_response(self, state, rg_was_taken_over=False) -> ResponseGeneratorResul return self.get_fallback_result() - def last_rg_willing_to_handover_control(self): # EDIT: TAKEOVER + def last_rg_willing_to_handover_control(self): last_active_rg_prompt = self.state_manager.last_state_response if last_active_rg_prompt: return last_active_rg_prompt.last_rg_willing_to_handover_control else: return False - def exist_able_to_takeover_entities(self): # EDIT: TAKEOVER + def exist_able_to_takeover_entities(self): return len(self.state_manager.current_state.entity_tracker.able_to_takeover_entities) != 0 - def get_takeover_response(self): # EDIT: TAKEOVER - logging.error(f"TEST: {self.name} null get_takeover_response") + def get_takeover_response(self): + logging.info(f"{self.name} null get_takeover_response") return None - def takeover_rg_willing_to_handback_control(self): # EDIT: TAKEOVER + def takeover_rg_willing_to_handback_control(self): last_active_rg_prompt = self.state_manager.last_state_response if last_active_rg_prompt: return last_active_rg_prompt.takeover_rg_willing_to_handback_control @@ -976,15 +972,15 @@ def takeover_rg_willing_to_handback_control(self): # EDIT: TAKEOVER return False def get_resuming_statement(self, state) -> ResponseGeneratorResult: - logging.error(f"TEST: {self.name} null get_resuming_statement") + logging.info(f"{self.name} null get_resuming_statement") return self.emptyPrompt() def augment_resuming_statement(self, resuming_statement_first_treelet): resuming_conversation_second_treelet_str = resuming_statement_first_treelet.resuming_conversation_next_treelet - logger.error(f"DEBUG RESUMING PROMPT TREELET: {resuming_conversation_second_treelet_str}") + logger.debug(f"The prompt treelet for resuming conversation is {resuming_conversation_second_treelet_str}") resuming_conversation_second_treelet = self.treelets[resuming_conversation_second_treelet_str] resuming_prompt_second_treelet = resuming_conversation_second_treelet.get_prompt() - logger.error(f"DEBUG RESUMING PROMPT: {resuming_prompt_second_treelet}") + logger.debug(f"The prompt for resuming conversation is {resuming_prompt_second_treelet}") if resuming_prompt_second_treelet: resuming_statement_first_treelet.text = f"{resuming_statement_first_treelet.text} {resuming_prompt_second_treelet.text}" resuming_statement_first_treelet.conditional_state.next_treelet_str = resuming_conversation_second_treelet_str @@ -993,29 +989,27 @@ def augment_resuming_statement(self, resuming_statement_first_treelet): attr_template = getattr(resuming_prompt_second_treelet, attr_to_copy) setattr(resuming_statement_first_treelet, attr_to_copy, attr_template) resuming_statement_first_treelet.resuming_conversation_next_treelet = None - logger.error(f"DEBUG RESUMING_STATEMENT_FIRST_TREELET.STATE: {resuming_statement_first_treelet}") return resuming_statement_first_treelet def resume_conversation(self): - logger.error(f"DEBUG SELF WITH ARCHIVE: {self.name}") - logger.error( - f"DEBUG ARCHIVED_STATE_RESUMING_RG: {self.state_manager.current_state.response_generator_states[self.name].archived_state}") + logger.debug( + f"The archived_state for resuming conversation is {self.state_manager.current_state.response_generator_states[self.name].archived_state} in {self.name}") archived_state = self.state_manager.current_state.response_generator_states[self.name].archived_state self.state = archived_state - logger.error(f"DEBUG SELF AFTER RETREIVING ARCHIVE: {self}") - logger.error(f"DEBUG SELF.STATE AFTER RETREIVING ARCHIVE: {self.state}") + logger.error(f"The state of {self.name} after retrieving archived_state is: {self.state}") first_treelet_str = self.state.next_treelet_str assert first_treelet_str in self.treelets first_treelet = self.treelets[first_treelet_str] resuming_statement_first_treelet = first_treelet.get_resuming_statement() - logger.error(f"DEBUG RESUMING STATEMENT AFTER RETREIVING ARCHIVE: {resuming_statement_first_treelet}") + logger.info(f"The resuming statement generated from the current treelet is {resuming_statement_first_treelet}") resuming_conversation = self.augment_resuming_statement(resuming_statement_first_treelet) - logger.error(f"DEBUG WHOLE RESUMING RESPONSE: {resuming_conversation}") + logger.info(f"The resuming statement after augmented with a prompt is {resuming_conversation}") return resuming_conversation + def get_prompt_wrapper(self, state, rg_to_resume=False): if self.takeover_rg_willing_to_handback_control(): if rg_to_resume: @@ -1023,22 +1017,6 @@ def get_prompt_wrapper(self, state, rg_to_resume=False): else: return self.get_prompt(state) - # def get_prompt_wrapper(self, state, rg_resuming_prompt=False): - # if self.takeover_rg_willing_to_handback_control(): - # rg_with_archived_state = self.state_manager.last_state_response.state.rg_that_was_taken_over - # logger.error(f"DEBUG RG_WITH_ARCHIVE {self.state_manager.last_state_response.state.rg_that_was_taken_over}") - # if self.name == rg_with_archived_state: - # logger.error(f"DEBUG SELF WITH ARCHIVE: {self.name} // {self.state_manager} // {self.state_manager.last_state_response}") - # archived_state = self.State.archived_state - # if archived_state: # TODO: Set to None later ???? - # self.state = archived_state - # archived_resuming_response = self.get_resuming_response(archived_state) - # logger.error("DEBUG ARCHIVE STATE PROMPT: {archived_resuming_response}") - # return archived_resuming_response - # else: - # return self.emptyPrompt() - # return self.get_prompt(state) - def possibly_augment_with_prompt(self, response): """ diff --git a/chirpy/core/response_generator/state.py b/chirpy/core/response_generator/state.py index f5e5f57..6c451cf 100644 --- a/chirpy/core/response_generator/state.py +++ b/chirpy/core/response_generator/state.py @@ -3,7 +3,7 @@ from chirpy.core.response_generator.response_type import ResponseType -from chirpy.core.entity_linker.entity_linker_classes import WikiEntity # EDIT: TAKEOVER +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity import logging logger = logging.getLogger('chirpylogger') @@ -24,18 +24,18 @@ class BaseState: next_treelet_str: Optional[str] = '' response_types: Tuple[str] = () num_turns_in_rg: int = 0 - archived_state: "BaseState" = None # EDIT: TAKEOVER - rg_that_was_taken_over: str = None # EDIT: TAKEOVER - takeover_entity: WikiEntity = None # EDIT: TAKEOVER + archived_state: "BaseState" = None + rg_that_was_taken_over: str = None + takeover_entity: WikiEntity = None @dataclass class BaseConditionalState: prev_treelet_str: str = '' next_treelet_str: Optional[str] = '' response_types: Tuple[str] = NO_UPDATE - archived_state: "BaseState" = NO_UPDATE # EDIT: TAKEOVER - rg_that_was_taken_over: str = NO_UPDATE # EDIT: TAKEOVER - takeover_entity: WikiEntity = NO_UPDATE # EDIT: TAKEOVER + archived_state: "BaseState" = NO_UPDATE + rg_that_was_taken_over: str = NO_UPDATE + takeover_entity: WikiEntity = NO_UPDATE def construct_response_types_tuple(response_types): return tuple([str(x) for x in response_types]) diff --git a/chirpy/core/response_generator_datatypes.py b/chirpy/core/response_generator_datatypes.py index f56e9b3..510daa0 100644 --- a/chirpy/core/response_generator_datatypes.py +++ b/chirpy/core/response_generator_datatypes.py @@ -34,10 +34,10 @@ def __init__(self, conditional_state=None, tiebreak_priority=None, no_transition=False, - last_rg_willing_to_handover_control=False, # EDIT: TAKEOVER - rg_that_was_taken_over =None, # EDIT: TAKEOVER - takeover_entity=None, # EDIT: TAKEOVER - takeover_rg_willing_to_handback_control=False # EDIT: TAKEOVER + last_rg_willing_to_handover_control=False, + rg_that_was_taken_over =None, + takeover_entity=None, + takeover_rg_willing_to_handback_control=False ): """ :param text: text of the response @@ -103,10 +103,10 @@ def __init__(self, self.conditional_state = conditional_state self.tiebreak_priority = tiebreak_priority self.no_transition = no_transition - self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT: TAKEOVER - self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT: TAKEOVER - self.takeover_entity = takeover_entity # EDIT: TAKEOVER - self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT: TAKEOVER + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control + self.rg_that_was_taken_over = rg_that_was_taken_over + self.takeover_entity = takeover_entity + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control def reduce_size(self, max_size:int = None): """Gracefully degrade by removing non essential attributes. @@ -134,11 +134,11 @@ def __init__(self, expected_type: Optional[EntityGroup] = None, conditional_state=None, answer_type: AnswerType = AnswerType.QUESTION_SELFHANDLING, - last_rg_willing_to_handover_control=False, # EDIT: TAKEOVER - rg_that_was_taken_over =None, # EDIT: TAKEOVER - takeover_entity =None, # EDIT: TAKEOVER - takeover_rg_willing_to_handback_control=False, # EDIT: TAKEOVER - resuming_conversation_next_treelet=None # EDIT: TAKEOVER + last_rg_willing_to_handover_control=False, + rg_that_was_taken_over =None, + takeover_entity =None, + takeover_rg_willing_to_handback_control=False, + resuming_conversation_next_treelet=None ): """ :param text: text of the response @@ -178,11 +178,11 @@ def __init__(self, self.state = state self.conditional_state = conditional_state self.answer_type = answer_type - self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control # EDIT: TAKEOVER - self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT: TAKEOVER - self.takeover_entity = takeover_entity # EDIT: TAKEOVER - self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control # EDIT: TAKEOVER - self.resuming_conversation_next_treelet = resuming_conversation_next_treelet # EDIT: TAKEOVER + self.last_rg_willing_to_handover_control = last_rg_willing_to_handover_control + self.rg_that_was_taken_over = rg_that_was_taken_over + self.takeover_entity = takeover_entity + self.takeover_rg_willing_to_handback_control = takeover_rg_willing_to_handback_control + self.resuming_conversation_next_treelet = resuming_conversation_next_treelet def __repr__(self): return 'PromptResult' + str(self.__dict__) diff --git a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py index 5b9234f..e88f417 100644 --- a/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py +++ b/chirpy/response_generators/closing_confirmation/closing_confirmation_response_generator.py @@ -101,7 +101,7 @@ def handle_custom_continuation_checks(self): # If neither matched, allow another RG to handle return self.emptyResult() - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT: TAKEOVER + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: state = super().update_state_if_not_chosen(state, conditional_state) state.has_just_asked_to_exit = False return state diff --git a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py index 95d9388..c5b6a1a 100644 --- a/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py +++ b/chirpy/response_generators/food/treelets/ask_favorite_food_treelet.py @@ -30,5 +30,5 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): next_treelet_str="food_introductory_treelet", cur_food=None), expected_type=ENTITY_GROUPS_FOR_EXPECTED_TYPE.food_related, - last_rg_willing_to_handover_control=True # EDIT: TAKEOVER + last_rg_willing_to_handover_control=False ) diff --git a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py index 20814f0..af3a02f 100644 --- a/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py +++ b/chirpy/response_generators/food/treelets/comment_on_favorite_type_treelet.py @@ -89,13 +89,13 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): other_type = sample_from_type(cur_food) text = f"That totally makes sense! I also really enjoy {user_answer}. Personally, I really like {other_type}." - return ResponseGeneratorResult(text=text, priority=priority, # EDIT: TAKEOVER + return ResponseGeneratorResult(text=text, priority=priority, needs_prompt=False, state=state, cur_entity=entity, conditional_state=ConditionalState( prompt_treelet=self.rg.open_ended_user_comment_treelet.name, cur_food=cur_food_entity), - last_rg_willing_to_handover_control=False # EDIT: TAKEOVER + last_rg_willing_to_handover_control=False ) def get_resuming_statement(self, prompt_type=PromptType.FORCE_START, **kwargs): diff --git a/chirpy/response_generators/food/treelets/factoid_treelet.py b/chirpy/response_generators/food/treelets/factoid_treelet.py index 4a48041..3c39b60 100644 --- a/chirpy/response_generators/food/treelets/factoid_treelet.py +++ b/chirpy/response_generators/food/treelets/factoid_treelet.py @@ -52,5 +52,5 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): conditional_state=ConditionalState( prev_treelet_str=self.name, cur_food=cur_food), - last_rg_willing_to_handover_control=True # EDIT: TAKEOVER + last_rg_willing_to_handover_control=True ) diff --git a/chirpy/response_generators/food/treelets/introductory_treelet.py b/chirpy/response_generators/food/treelets/introductory_treelet.py index 04919ab..20c8155 100644 --- a/chirpy/response_generators/food/treelets/introductory_treelet.py +++ b/chirpy/response_generators/food/treelets/introductory_treelet.py @@ -74,7 +74,7 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): cur_entity=entity, conditional_state=ConditionalState(cur_food=entity, prompt_treelet=prompt_treelet), - last_rg_willing_to_handover_control=True) # EDIT: TAKEOVER + last_rg_willing_to_handover_control=True) def get_prompt(self, **kwargs): return None diff --git a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py index 9fece4c..ab660c5 100644 --- a/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py +++ b/chirpy/response_generators/food/treelets/open_ended_user_comment_treelet.py @@ -74,5 +74,5 @@ def get_response(self, priority=ResponsePriority.STRONG_CONTINUE, **kwargs): prev_treelet_str=self.name, prompt_treelet=prompt_treelet, cur_food=None), - last_rg_willing_to_handover_control=False # EDIT: TAKEOVER + last_rg_willing_to_handover_control=False ) diff --git a/chirpy/response_generators/launch/launch_response_generator.py b/chirpy/response_generators/launch/launch_response_generator.py index 86ecfcc..c8a0bd4 100644 --- a/chirpy/response_generators/launch/launch_response_generator.py +++ b/chirpy/response_generators/launch/launch_response_generator.py @@ -50,7 +50,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi # state.asked_name_counter = 1 return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: # EDIT: TAKEOVER + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: state = super().update_state_if_not_chosen(state, conditional_state) state.next_treelet_str = None return state diff --git a/chirpy/response_generators/music/music_response_generator.py b/chirpy/response_generators/music/music_response_generator.py index 984d8a0..68bbe46 100644 --- a/chirpy/response_generators/music/music_response_generator.py +++ b/chirpy/response_generators/music/music_response_generator.py @@ -76,7 +76,7 @@ def update_state_if_chosen(self, state, conditional_state): state.discussed_entities.append(state.cur_singer_str) return state - def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): # EDIT: TAKEOVER + def update_state_if_not_chosen(self, state, conditional_state, rg_was_taken_over=False): state = super().update_state_if_not_chosen(state, conditional_state) return state diff --git a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py index 5a73abc..8ac8b18 100644 --- a/chirpy/response_generators/neural_chat/neural_chat_response_generator.py +++ b/chirpy/response_generators/neural_chat/neural_chat_response_generator.py @@ -188,7 +188,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi state.update_if_chosen(conditional_state) return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT: TAKEOVER + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: logger.primary_info(f"Neural chat state is {state}") if conditional_state is not None: state.update_if_not_chosen(conditional_state) diff --git a/chirpy/response_generators/neural_chat/state.py b/chirpy/response_generators/neural_chat/state.py index 563fa85..e5045b8 100644 --- a/chirpy/response_generators/neural_chat/state.py +++ b/chirpy/response_generators/neural_chat/state.py @@ -3,7 +3,7 @@ from typing import List, Optional, Set, Tuple from chirpy.core.response_generator.state import NO_UPDATE -from chirpy.core.entity_linker.entity_linker_classes import WikiEntity # EDIT: TAKEOVER +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity import copy @@ -37,7 +37,7 @@ def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Opti user_utterance: Optional[str] = None, user_labels: List[str] = [], bot_utterance: Optional[str] = None, bot_labels: List[str] = [], neural_responses: Optional[List[str]] = None, num_topic_shifts: int = 0, - archived_state: "State" = None, rg_that_was_taken_over: str = None, takeover_entity: WikiEntity = None): # EDIT: TAKEOVER + archived_state: "State" = None, rg_that_was_taken_over: str = None, takeover_entity: WikiEntity = None): """ @param next_treelet: the name of the treelet we should run on the next turn if our response/prompt is chosen. None means turn off next turn. @param most_recent_treelet: the name of the treelet that handled this turn, if applicable @@ -64,9 +64,9 @@ def __init__(self, next_treelet: Optional[str] = None, most_recent_treelet: Opti self.bot_labels = bot_labels self.neural_responses = neural_responses self.num_topic_shifts = num_topic_shifts - self.archived_state = archived_state # EDIT: TAKEOVER - self.rg_that_was_taken_over = rg_that_was_taken_over # EDIT: TAKEOVER - self.takeover_entity = takeover_entity # EDIT: TAKEOVER + self.archived_state = archived_state + self.rg_that_was_taken_over = rg_that_was_taken_over + self.takeover_entity = takeover_entity def __repr__(self): return f" BaseState: # EDIT: TAKEOVER + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> BaseState: state = super().update_state_if_not_chosen(state, conditional_state) state.handle_response = False state.offense_type = None diff --git a/chirpy/response_generators/opinion2/opinion_response_generator.py b/chirpy/response_generators/opinion2/opinion_response_generator.py index 7d4dc73..6805d0f 100644 --- a/chirpy/response_generators/opinion2/opinion_response_generator.py +++ b/chirpy/response_generators/opinion2/opinion_response_generator.py @@ -533,7 +533,7 @@ def update_state_if_chosen(self, state: State, conditional_state : Optional[Stat if val != NO_UPDATE: setattr(state, attr, val) return state - def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State], rg_was_taken_over=False) -> State: # EDIT: TAKEOVER + def update_state_if_not_chosen(self, state: State, conditional_state : Optional[State], rg_was_taken_over=False) -> State: new_state = state.reset_state() new_state.num_turns_since_long_policy += 1 return new_state diff --git a/chirpy/response_generators/wiki2/treelets/handback_treelet.py b/chirpy/response_generators/wiki2/treelets/handback_treelet.py index 97231d7..5d0acae 100644 --- a/chirpy/response_generators/wiki2/treelets/handback_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/handback_treelet.py @@ -48,7 +48,7 @@ def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): state, utterance, response_types = self.get_state_utterance_response_types() takeover_entity = state.takeover_entity - logger.error(f'WIKI HANDBACK') + logger.debug(f'WIKI handback_treelet is triggered.') wrap_up_text = self.get_acknowledgement() diff --git a/chirpy/response_generators/wiki2/treelets/takeover_treelet.py b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py index cb848b4..040a7c2 100644 --- a/chirpy/response_generators/wiki2/treelets/takeover_treelet.py +++ b/chirpy/response_generators/wiki2/treelets/takeover_treelet.py @@ -14,31 +14,21 @@ logger = logging.getLogger('chirpylogger') -class WikiTakeOverTreelet(Treelet): # EDIT: TAKEOVER +class WikiTakeOverTreelet(Treelet): name = "wiki_takeover_treelet" - def get_summary_takeover(self, related_wiki_section, sentseg_fn, max_words, max_sents): - summary = wiki_utils.get_summary(related_wiki_section['text'], sentseg_fn, max_words, max_sents) - logger.primary_info(f"Takeover Summary is: {summary}") - summary = wiki_utils.clean_wiki_text(summary) - logger.primary_info(f"Takeover Summary after clean is: {summary}") - if wiki_utils.contains_offensive(summary): - logger.primary_info(f"Found takeover overview to be offensive, discarding it") - return None - return summary - def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): state, utterance, response_types = self.get_state_utterance_response_types() rg_that_was_taken_over = self.rg.state_manager.last_state.active_rg - logger.error(f'RG_THAT_WAS_TAKEN_OVER: {rg_that_was_taken_over}') + logger.debug(f'rg that was taken over is {rg_that_was_taken_over}.') cur_entity = self.get_current_entity() takeover_entity = self.get_most_recent_able_to_takeover_entity() takeover_text = wiki_utils.get_takeover_text(self.rg, cur_entity, takeover_entity) - logger.error(f"TAKEOVER TEXT: {takeover_text}") + logger.info(f"takenover_text is {takeover_text}") if takeover_text: intro_intersect_text = wiki_utils.get_random_intro_intersect_text(cur_entity.talkable_name, takeover_entity.talkable_name) @@ -59,10 +49,11 @@ def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): takeover_neural_response = self.rg.get_neural_response(prefix=neural_prefix) takeover_neural_response = takeover_neural_response.split('.')[0] generated_response = takeover_neural_response[len(neural_prefix):] - logger.error(f"TAKEOVER_NEURAL_RESPONSE: {takeover_neural_response}") - if takeover_entity.talkable_name in generated_response and cur_entity.talkable_name in generated_response : + logger.info(f"takenover_neural_response is {takeover_neural_response}") + if takeover_entity.talkable_name in generated_response and cur_entity.talkable_name in generated_response: intro_intersect_text = wiki_utils.get_random_intro_intersect_text(cur_entity.talkable_name, takeover_entity.talkable_name) + logger.info("takenover_neural_response is used.") starter_text = wiki_utils.get_random_starter_text() return ResponseGeneratorResult( text=intro_intersect_text + starter_text + generated_response, @@ -75,4 +66,7 @@ def get_response(self, priority=ResponsePriority.FORCE_START, **kwargs): takeover_rg_willing_to_handback_control=True ) else: + logger.info("takenover_neural_response is not used because it does not contain takeover_entity and cur_entity in it.") + logger.info( + "WIKI fails to takeover.") return None \ No newline at end of file diff --git a/chirpy/response_generators/wiki2/wiki_response_generator.py b/chirpy/response_generators/wiki2/wiki_response_generator.py index 7abd88e..df30c9d 100644 --- a/chirpy/response_generators/wiki2/wiki_response_generator.py +++ b/chirpy/response_generators/wiki2/wiki_response_generator.py @@ -25,10 +25,10 @@ from chirpy.annotators.corenlp import Sentiment from chirpy.response_generators.wiki2.state import State,ConditionalState, NO_UPDATE -from chirpy.response_generators.wiki2.treelets.takeover_treelet import WikiTakeOverTreelet # EDIT: TAKEOVER -from chirpy.response_generators.wiki2.treelets.handback_treelet import WikiHandBackTreelet # EDIT: TAKEOVER +from chirpy.response_generators.wiki2.treelets.takeover_treelet import WikiTakeOverTreelet +from chirpy.response_generators.wiki2.treelets.handback_treelet import WikiHandBackTreelet -from chirpy.core.offensive_classifier.offensive_classifier import OffensiveClassifier # EDIT: TAKEOVER +from chirpy.core.offensive_classifier.offensive_classifier import OffensiveClassifier logger = logging.getLogger('chirpylogger') @@ -58,15 +58,15 @@ def __init__(self, state_manager) -> None: self.discuss_section_treelet = DiscussSectionTreelet(self) self.discuss_section_further_treelet = DiscussSectionFurtherTreelet(self) self.get_opinion_treelet = GetOpinionTreelet(self) - self.takeover_treelet = WikiTakeOverTreelet(self) # EDIT: TAKEOVER - self.handback_treelet = WikiHandBackTreelet(self) # EDIT: TAKEOVER + self.takeover_treelet = WikiTakeOverTreelet(self) + self.handback_treelet = WikiHandBackTreelet(self) treelets = {t.name: t for t in [self.check_user_knowledge_treelet, self.acknowledge_user_knowledge_treelet, self.factoid_treelet, self.intro_entity_treelet, self.combined_til_treelet, self.discuss_article_treelet, self.discuss_section_treelet, self.discuss_section_further_treelet, self.get_opinion_treelet, - self.takeover_treelet, self.handback_treelet]} # EDIT: TAKEOVER + self.takeover_treelet, self.handback_treelet]} super().__init__(state_manager, treelets=treelets, state_constructor=State, can_give_prompts=True, conditional_state_constructor=ConditionalState, @@ -652,7 +652,7 @@ def update_state_if_chosen(self, state: State, conditional_state: Optional[Condi return state - def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: # EDIT: TAKEOVER + def update_state_if_not_chosen(self, state: State, conditional_state: Optional[ConditionalState], rg_was_taken_over=False) -> State: state = super().update_state_if_not_chosen(state, conditional_state) state.cur_doc_title = None state.suggested_sections = [] @@ -670,7 +670,7 @@ def update_state_if_not_chosen(self, state: State, conditional_state: Optional[C return state - def get_takeover_response(self): # EDIT: TAKEOVER + def get_takeover_response(self): logger.error("WIKI TAKEOVER") return self.takeover_treelet.get_response(ResponsePriority.FORCE_START) diff --git a/chirpy/response_generators/wiki2/wiki_utils.py b/chirpy/response_generators/wiki2/wiki_utils.py index 3ab6aae..7a6ede8 100644 --- a/chirpy/response_generators/wiki2/wiki_utils.py +++ b/chirpy/response_generators/wiki2/wiki_utils.py @@ -23,8 +23,7 @@ import random import math -from chirpy.core.entity_linker.entity_linker_classes import WikiEntity # EDIT: TAKEOVER -from chirpy.annotators.convpara import ConvPara # EDIT: TAKEOVER +from chirpy.core.entity_linker.entity_linker_classes import WikiEntity lucene_stopwords = {'a', 'an', 'and', 'are', 'as', 'at', 'be', 'but', 'by', 'for', 'if', 'in', 'into', 'is', 'it', @@ -169,6 +168,8 @@ def filter_highlight_sections(title, es_sections): logger.error(f"SECTIONS2: {filtered_sections}") return filtered_sections + + def filter_sections(title, es_sections): wiki_sections = [] for section in es_sections['hits']['hits']: @@ -279,6 +280,7 @@ def search_wiki_sections(doc_title: str, phrases: tuple, wiki_links:tuple) -> Li def prune_section(section): return section['text'][0] in {'†', '+', '*'} + def clean_takeover_wiki_text(text: str) -> str: modified_text = clean_wiki_text(text) index_caption = modified_text.find(']]') @@ -286,8 +288,9 @@ def clean_takeover_wiki_text(text: str) -> str: modified_text = modified_text[index_caption + 2:] return modified_text -def summarize_takeover_candidate(rg, text: str, span_to_keep: str, max_sents: int = 3) -> str: - logger.debug(f'Summarizing takeover text: {text}') + +def summarize_takeover_candidate_text(rg, text: str, span_to_keep: str, max_words: int = 50, max_sents: int = 3) -> str: + logger.info(f'Summarizing takeover text: {text}') local_sentseg_fn = lambda text: re.split('[.\n]', text) sentseg_fn = NLTKSentenceSegmenter( @@ -306,7 +309,7 @@ def summarize_takeover_candidate(rg, text: str, span_to_keep: str, max_sents: in found = True summary += sentence + ('.' if sentence[-1] not in {'.', '!', '?'} else ' ') num_sentences += 1 - if num_sentences > max_sents and found: + if found and (num_sentences > max_sents or len(summary.split(' ')) < max_words): break return summary @@ -328,52 +331,45 @@ def search_wiki_intersect_sections(rg, doc_title: str, search_entity: WikiEntity if not contains_offensive(cleaned_text) and not rg.has_overlap_with_history(cleaned_text, threshold=0.8): for s in top_spans: if s in cleaned_text: - modified_text = re.sub(r"\([^()]*\)", "", text) - summary_text = summarize_takeover_candidate(rg, modified_text, span_to_keep=s) - candidate_texts.append(summary_text) + summarized_text = summarize_takeover_candidate_text(rg, cleaned_text, span_to_keep=s) + candidate_texts.append(summarized_text) break - logger.error(f"CANDIDATE_TEXTS (from doc_title {doc_title}): {candidate_texts}") + logger.info(f"candidate_texts from doc_title {doc_title} is {candidate_texts}") return candidate_texts -def get_paraphrase(rg, text: str, entity: str) -> str: - conv_para = ConvPara(rg.state_manager) - paraphrases = conv_para.get_paraphrases(background=text, entity=entity) - logger.error(f"PARAPHRASES: {paraphrases}") - return paraphrases def get_takeover_text(rg, cur_entity: WikiEntity, takeover_entity: WikiEntity) -> Optional[str]: related_wiki_texts_from_cur_entity_doc = search_wiki_intersect_sections(rg, cur_entity.talkable_name, takeover_entity) + if related_wiki_texts_from_cur_entity_doc: + return random.choice(related_wiki_texts_from_cur_entity_doc) + related_wiki_texts_from_takeover_entity_doc = search_wiki_intersect_sections(rg, takeover_entity.talkable_name, cur_entity) + if related_wiki_texts_from_takeover_entity_doc: + return random.choice(related_wiki_texts_from_takeover_entity_doc) - intersect_wiki_texts = related_wiki_texts_from_cur_entity_doc + related_wiki_texts_from_takeover_entity_doc + return None - logger.error(f"INTERSECT_WIKI_TEXTS: {intersect_wiki_texts}") - if related_wiki_texts_from_cur_entity_doc: - takeover_text = random.choice(related_wiki_texts_from_cur_entity_doc) - return takeover_text - - elif related_wiki_texts_from_takeover_entity_doc: - takeover_text = random.choice(related_wiki_texts_from_takeover_entity_doc) - return get_paraphrase(rg, takeover_text, cur_entity.name) - else: - return None INTRO_INTERSECT_TEXT = ["Speaking of {} and {}, ", "Relating to {} and {}, ", "Since you mentioned {} and {}, "] + def get_random_intro_intersect_text(cur_entity: str, takeover_entity: str) -> str: return random.choice(INTRO_INTERSECT_TEXT).format(cur_entity, takeover_entity) + STARTER_TEXTS = ["did you know that ", "I recently learned that ", "I was reading recently and found out that ", "did you know that ", "I was interested to learn that "] + def get_random_starter_text(): return random.choice(STARTER_TEXTS) + def get_text_for_entity(entity): results = es.search(index='enwiki-20200920-sections', body={ 'query': { @@ -397,6 +393,7 @@ def replaceByLength(matchobj): sections = sorted(sections, key=(lambda x: -len(x[1]))) return sections + def check_section_summary(rg, section_summary, selected_section, allow_history_overlap=False): """ Check that the section summary is present, non-offensive, and does not overlap with history. From 8db9262808e995d57f29199498c238e1721cc472 Mon Sep 17 00:00:00 2001 From: thanawan-atc Date: Mon, 25 Jul 2022 13:40:27 -0700 Subject: [PATCH 6/6] Added WIKI takeover feature --- agents/local_agent.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/agents/local_agent.py b/agents/local_agent.py index 36dd2f6..7a02e18 100644 --- a/agents/local_agent.py +++ b/agents/local_agent.py @@ -205,17 +205,17 @@ def create_handler(self): response_generator_classes=[LaunchResponseGenerator, FallbackResponseGenerator, NeuralFallbackResponseGenerator, NeuralChatResponseGenerator, - # OffensiveUserResponseGenerator, - # CategoriesResponseGenerator, - # ClosingConfirmationResponseGenerator, - # AcknowledgmentResponseGenerator, - # PersonalIssuesResponseGenerator, - # OpinionResponseGenerator2, - # AliensResponseGenerator, + OffensiveUserResponseGenerator, + CategoriesResponseGenerator, + ClosingConfirmationResponseGenerator, + AcknowledgmentResponseGenerator, + PersonalIssuesResponseGenerator, + OpinionResponseGenerator2, + AliensResponseGenerator, TransitionResponseGenerator, FoodResponseGenerator, WikiResponseGenerator, - # MusicResponseGenerator, + MusicResponseGenerator, ], annotator_classes = [QuestionAnnotator, DialogActAnnotator, NavigationalIntentModule, StanfordnlpModule, CorenlpModule, EntityLinkerModule, BlenderBot],