Skip to content

Commit

Permalink
Maintenance (thu-coai#132)
Browse files Browse the repository at this point in the history
* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue thu-coai#130

* remove debug output
  • Loading branch information
zqwerty authored Sep 27, 2020
1 parent 424a38b commit afd3528
Showing 1 changed file with 123 additions and 90 deletions.
213 changes: 123 additions & 90 deletions convlab2/policy/rule/multiwoz/policy_agenda_multiwoz.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,20 @@ def predict(self, sys_dialog_act):
self.agenda.close_session()
else:
sys_action = self._transform_sysact_in(sys_action)
# print('sys action before update agenda', sys_action)
self.agenda.update(sys_action, self.goal)
if self.goal.task_complete():
self.agenda.close_session()

# A -> A' + user_action
# action = self.agenda.get_action(random.randint(2, self.max_initiative))
action = self.agenda.get_action(self.max_initiative)
action = {}
while len(action) == 0:
# A -> A' + user_action
# action = self.agenda.get_action(random.randint(2, self.max_initiative))
action = self.agenda.get_action(self.max_initiative)

# transform to DA
action = self._transform_usract_out(action)
# transform to DA
action = self._transform_usract_out(action)
# print(action)

tuples = []
for domain_intent, svs in action.items():
Expand Down Expand Up @@ -169,6 +173,8 @@ def _transform_usract_out(cls, action):
new_action[new_act].append(['NotBook', 'none'])
elif slot is not None:
new_action[new_act].append([slot, pairs[1]])
if len(new_action[new_act]) == 0:
new_action.pop(new_act)
# new_action[new_act] = [[REF_USR_DA_M[dom.capitalize()].get(pairs[0], pairs[0]), pairs[1]] for pairs in action[act]]
else:
new_action[act] = action[act]
Expand Down Expand Up @@ -848,7 +854,6 @@ def __pop(self, initiative=1):
diaacts = []
slots = []
values = []

p_diaact, p_slot = self.__check_next_diaact_slot()
if p_diaact.split('-')[1] == 'inform' and p_slot in BOOK_SLOT:
for _ in range(10 if self.__cur_push_num == 0 else self.__cur_push_num):
Expand Down Expand Up @@ -914,38 +919,38 @@ def __str__(self):
from convlab2.dst.rule.multiwoz.dst import RuleDST
from convlab2.nlu.jointBERT.multiwoz.nlu import BERTNLU

seed = 50
seed = 41
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
#
# sys_nlu = BERTNLU()
# sys_dst = RuleDST()
# sys_policy = RulePolicy()
# sys_nlg = TemplateNLG(is_user=False)
# sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')

sys_nlu = BERTNLU()
sys_dst = RuleDST()
sys_policy = RulePolicy()
sys_nlg = TemplateNLG(is_user=False)
sys_agent = PipelineAgent(sys_nlu, sys_dst, sys_policy, sys_nlg, name='sys')

user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
user_dst = None
user_policy = RulePolicy(character='usr')
user_nlg = TemplateNLG(is_user=True)
user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')
# user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
# model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
# user_dst = None
# user_policy = RulePolicy(character='usr')
# user_nlg = TemplateNLG(is_user=True)
# user_agent = PipelineAgent(user_nlu, user_dst, user_policy, user_nlg, name='user')

# evaluator = MultiWozEvaluator()
# sess = BiSession(sys_agent=sys_agent, user_agent=user_agent, kb_query=None, evaluator=evaluator)




# user_policy = UserPolicyAgendaMultiWoz()
user_policy = UserPolicyAgendaMultiWoz()
#
# sys_policy = RuleBasedMultiwozBot()
sys_policy = RulePolicy(character='sys')
#
# user_nlg = TemplateNLG(is_user=True, mode='manual')
# sys_nlg = TemplateNLG(is_user=False, mode='manual')
user_nlg = TemplateNLG(is_user=True, mode='manual')
sys_nlg = TemplateNLG(is_user=False, mode='manual')
#
# dst = RuleDST()
dst = RuleDST()
#
# user_nlu = BERTNLU(mode='sys', config_file='multiwoz_sys_context.json',
# model_file='https://convlab.blob.core.windows.net/convlab-2/bert_multiwoz_sys_context.zip')
Expand All @@ -956,12 +961,24 @@ def __str__(self):
# if 'restaurant' in goal['domain_ordering'] and 'hotel' in goal['domain_ordering']:
# break
# # pprint(goal)
user_goal = {'domain_ordering': ('restaurant', 'hotel', 'taxi'),
'hotel': {'book': {'day': 'sunday', 'people': '6', 'stay': '4'},
'info': {'internet': 'no',
'parking': 'no',
'pricerange': 'moderate',
'area': 'centre'}},
user_goal = {'domain_ordering': ('hotel', 'attraction'),
'train': {
'info': {'arriveBy': '16:00',
'day': 'monday',
'departure': 'cambridge',
'destination': 'stansted airport'},
'book': {'people': 2}, 'booked': '?'
},
'attraction': {
'info': {'type': 'museum'},
'reqt': ['phone']
},
'hotel': {
'info': {'internet': 'yes',
'parking': 'yes',
'stars': '4',
'type': 'hotel'},
'reqt': ['postcode']},
'restaurant': {'info': {'area': 'centre',
'food': 'portuguese',
'pricerange': 'cheap'},
Expand All @@ -986,7 +1003,7 @@ def __str__(self):
user_policy.init_session(ini_goal=goal)
print('init goal:')
# pprint(user_policy.get_goal())
pprint(user_agent.policy.get_goal())
# pprint(user_agent.policy.get_goal())
# pprint(sess.evaluator.goal)
# print('-' * 50)
# for i in range(20):
Expand All @@ -1005,85 +1022,101 @@ def __str__(self):
# print('=' * 100)

history = []
user_utt = user_agent.response('')
print(user_utt)
user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?'
# history.append(['user', user_utt])
sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese'
sys_utt = sys_agent.response(user_utt)
pprint(sys_agent.dst.state)
print(sys_utt)
sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ."
# history.append(['user', user_utt])

user_utt = user_agent.response(sys_utt)
print(user_utt)
user_utt = "It just needs to be cheap ."
sys_utt = sys_agent.response(user_utt)
print(sys_utt)
sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?"

user_utt = user_agent.response(sys_utt)
print(user_utt)
sys_utt = sys_agent.response(user_utt)
print(sys_utt)

user_utt = user_agent.response(sys_utt)
print(user_utt)
sys_utt = sys_agent.response(user_utt)
print(sys_utt)

user_utt = user_agent.response(sys_utt)
print(user_utt)
sys_utt = sys_agent.response(user_utt)
print(sys_utt)

# user_utt = user_agent.response('')
# print(user_utt)
# user_utt = 'I need a restaurant . It just needs to be expensive . I am also in the market for a new restaurant . Is there something in the centre of town ? Do you have portuguese food ?'
# # history.append(['user', user_utt])
# sys_agent.dst.state['belief_state']['restaurant']['semi']['food'] = 'portuguese'
# sys_utt = sys_agent.response(user_utt)
# pprint(sys_agent.dst.state)
# print(sys_utt)
# sys_utt = "I have n't found any in the centre. I am unable to find any portuguese restaurants in town ."
# # history.append(['user', user_utt])
#
# print(user_policy.agenda)
# user_act = user_policy.predict([])
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# user_utt = user_agent.response(sys_utt)
# print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# sys_utt = sys_nlg.generate(sys_act)
# # sys_act.append(["Request", "Restaurant", "Price", "?"])
# # sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
# print(sys_act)
# user_utt = "It just needs to be cheap ."
# sys_utt = sys_agent.response(user_utt)
# print(sys_utt)
# sys_utt = "It is in the centre area . They serve portuguese . Would you like to try nandos city centre ? They are in the cheap price range . I will book it for you and get a reference number ?"
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# user_utt = user_agent.response(sys_utt)
# print(user_utt)
# sys_utt = sys_agent.response(user_utt)
# print(sys_utt)
#
# user_utt = user_agent.response(sys_utt)
# print(user_utt)
# sys_utt = sys_agent.response(user_utt)
# print(sys_utt)
#
# user_utt = user_agent.response(sys_utt)
# print(user_utt)
# sys_utt = sys_agent.response(user_utt)
# print(sys_utt)

#
print(user_policy.agenda)
user_act = user_policy.predict([])
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
history.append(['user', user_utt])
state = dst.state
state['user_action'] = user_act
dst.update(user_act)
# pprint(state)
sys_act = sys_policy.predict(state)
sys_utt = sys_nlg.generate(sys_act)
# sys_act.append(["Request", "Restaurant", "Price", "?"])
# sys_act = [['Request', 'Hotel', 'Area', '?'], ['Request', 'Hotel', 'Stars', '?']]
sys_act = [['Inform', 'Hotel', 'Post', 'pe296fl']]
print(sys_act)
history.append(['sys', user_utt])

# sys_utt = sys_agent.response(user_utt)
# print(sys_utt)
#
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
history.append(['user', user_utt])
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [['Inform', 'Hotel', 'Choice', '3']]
# print(sys_act)
sys_act = [
['Inform', 'Hotel', 'Post', 'pe296fl']
]
print(sys_act)
# sys_utt = sys_agent.response(user_utt)
# print(sys_utt)
# sys_utt = 'The arrive time is 15:08 . The train will be departing from cambridge . The booking is for arriving in stansted airport . TR6936 will be your perfect fit . How about 14:40 will that work for you ?'
# history.append(['sys', user_utt])
#
#
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
# sys_act = user_nlu.predict(sys_utt, history)
# print(sys_act)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
# # pprint(state)
# sys_act = sys_policy.predict(state)
# # sys_act = [["Book", "Booking", "Ref", "7GAWK763"]]
# print(sys_act)
sys_act = [['Request', 'Hotel', 'Price', '?'], ['Request', 'Attraction', 'Price', '?']]
print(sys_act)
# #
# user_act = user_policy.predict(sys_act)
# print(user_act)
# user_utt = user_nlg.generate(user_act)
# print(user_utt)
user_act = user_policy.predict(sys_act)
print(user_act)
user_utt = user_nlg.generate(user_act)
print(user_utt)
# state = dst.state
# state['user_action'] = user_act
# dst.update(user_act)
Expand Down

0 comments on commit afd3528

Please sign in to comment.