From 55ad63958a3eff92a04a233c198fbc25ea29502a Mon Sep 17 00:00:00 2001 From: macota Date: Thu, 2 Jun 2022 16:03:20 -0400 Subject: [PATCH] cache api calls and responses on setup (#4565) * cache api calls and responses on setup * use PathManager instead of os and json instead of pickle * fix json.load and file open * fix typo Co-authored-by: Martin Corredor --- parlai/tasks/multiwoz_v22/agents.py | 30 +++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/parlai/tasks/multiwoz_v22/agents.py b/parlai/tasks/multiwoz_v22/agents.py index 063661f1a3b..fe2cf3ce572 100644 --- a/parlai/tasks/multiwoz_v22/agents.py +++ b/parlai/tasks/multiwoz_v22/agents.py @@ -262,10 +262,19 @@ def _get_round(self, dialogue_id, raw_episode, turn_id): ) if not valid: continue + call = maybe_call - resp = self._get_find_api_response( - intent, frame["state"]["slot_values"], sys_dialog_act - ) + call_key = str(call) + if call_key not in self.call_response_cache: + resp = self._get_find_api_response( + intent, + frame["state"]["slot_values"], + sys_dialog_act, + ) + self.call_response_cache[call_key] = resp + else: + resp = self.call_response_cache[call_key] + elif "book" in intent: for key in sys_dialog_act: if "Book" in key: # and "Inform" not in key: @@ -300,8 +309,17 @@ def setup_episodes(self, fold): """ Parses into TodStructuredEpisode. """ - self.dbs = self.load_dbs() self.schemas = self.load_schemas() + cache_path = os.path.join(self.dpath, f"{fold}_call_response_cache.json") + + if PathManager.exists(cache_path): + with PathManager.open(cache_path, 'r') as f: + self.call_response_cache = json.load(f) + self.dbs = None + else: + self.call_response_cache = {} + self.dbs = self.load_dbs() + with PathManager.open(os.path.join(self.dpath, "dialog_acts.json")) as f: self.dialog_acts = json.load(f) @@ -345,6 +363,10 @@ def setup_episodes(self, fold): rounds=rounds, ) episodes.append(episode) + + with PathManager.open(cache_path, 'w') as f: + json.dump(self.call_response_cache, f) + return episodes def get_id_task_prefix(self):