Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
cache api calls and responses on setup (#4565)
Browse files Browse the repository at this point in the history
* cache api calls and responses on setup

* use PathManager instead of os and json instead of pickle

* fix json.load and file open

* fix typo

Co-authored-by: Martin Corredor <mcorredor@devfair0237.h2.fair>
  • Loading branch information
2 people authored and kushalarora committed Jun 15, 2022
1 parent ad842f3 commit 55ad639
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions parlai/tasks/multiwoz_v22/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,19 @@ def _get_round(self, dialogue_id, raw_episode, turn_id):
)
if not valid:
continue

call = maybe_call
resp = self._get_find_api_response(
intent, frame["state"]["slot_values"], sys_dialog_act
)
call_key = str(call)
if call_key not in self.call_response_cache:
resp = self._get_find_api_response(
intent,
frame["state"]["slot_values"],
sys_dialog_act,
)
self.call_response_cache[call_key] = resp
else:
resp = self.call_response_cache[call_key]

elif "book" in intent:
for key in sys_dialog_act:
if "Book" in key: # and "Inform" not in key:
Expand Down Expand Up @@ -300,8 +309,17 @@ def setup_episodes(self, fold):
"""
Parses into TodStructuredEpisode.
"""
self.dbs = self.load_dbs()
self.schemas = self.load_schemas()
cache_path = os.path.join(self.dpath, f"{fold}_call_response_cache.json")

if PathManager.exists(cache_path):
with PathManager.open(cache_path, 'r') as f:
self.call_response_cache = json.load(f)
self.dbs = None
else:
self.call_response_cache = {}
self.dbs = self.load_dbs()

with PathManager.open(os.path.join(self.dpath, "dialog_acts.json")) as f:
self.dialog_acts = json.load(f)

Expand Down Expand Up @@ -345,6 +363,10 @@ def setup_episodes(self, fold):
rounds=rounds,
)
episodes.append(episode)

with PathManager.open(cache_path, 'w') as f:
json.dump(self.call_response_cache, f)

return episodes

def get_id_task_prefix(self):
Expand Down

0 comments on commit 55ad639

Please sign in to comment.