Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

cache api calls and responses on setup #4565

Merged
merged 4 commits into from
Jun 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions parlai/tasks/multiwoz_v22/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,19 @@ def _get_round(self, dialogue_id, raw_episode, turn_id):
)
if not valid:
continue

call = maybe_call
resp = self._get_find_api_response(
intent, frame["state"]["slot_values"], sys_dialog_act
)
call_key = str(call)
if call_key not in self.call_response_cache:
resp = self._get_find_api_response(
intent,
frame["state"]["slot_values"],
sys_dialog_act,
)
self.call_response_cache[call_key] = resp
else:
resp = self.call_response_cache[call_key]

elif "book" in intent:
for key in sys_dialog_act:
if "Book" in key: # and "Inform" not in key:
Expand Down Expand Up @@ -300,8 +309,17 @@ def setup_episodes(self, fold):
"""
Parses into TodStructuredEpisode.
"""
self.dbs = self.load_dbs()
self.schemas = self.load_schemas()
cache_path = os.path.join(self.dpath, f"{fold}_call_response_cache.json")

if PathManager.exists(cache_path):
with PathManager.open(cache_path, 'r') as f:
self.call_response_cache = json.load(f)
self.dbs = None
else:
self.call_response_cache = {}
self.dbs = self.load_dbs()

with PathManager.open(os.path.join(self.dpath, "dialog_acts.json")) as f:
self.dialog_acts = json.load(f)

Expand Down Expand Up @@ -345,6 +363,10 @@ def setup_episodes(self, fold):
rounds=rounds,
)
episodes.append(episode)

with PathManager.open(cache_path, 'w') as f:
json.dump(self.call_response_cache, f)

return episodes

def get_id_task_prefix(self):
Expand Down