Add warmup for DQN and fix minor bugs (#150)
* Initial commit

* Update README.md

* sync with commit aa1af0ee81ba591d1cf3c222c9d71963ed1dca98

* add gitignore

* update tutorial

* update mdrg, not use dbPointer

* update mdrg, not use dbPointer

* update mdrg, download before use dbPointer

* update analyzer

* update README

* update tutorial

* Fix dbquery when matching name

* move change to dev branch

* move dbquery change from master to dev branch

* disable travis for now

* do nothing in travis

* do nothing in travis for now

* not deploy now

* add docs

* update .travis.yml

* update .travis.yml

* update .travis.yml

* update rst files

* add alias center for centre in dbquery

* fix the policy training bug

* fix the bug of nan gradient

* add cross-lingual dst data

* Update README.md

* Update README.md

* Add CrossWoz Web support and some minor bug fix (#19)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>

* modify xdst data name

* Translation train on MultiWOZ (Chinese) nad CrossWOZ (English) of SUMBT (#17)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* updata xdst baseline

* Update README.md

* fix allennlp==0.9.0

* Update README.md

* modify build message function for goal generation

* Fix goal generator and dbquery for multiwoz (#32)

* move dbquery change from master to dev branch

* add alias center for centre in dbquery

* replace attraction type 'mutliple sports' to 'multiple sports', involving only one entity

* add depart and destination constraints for searching db (ignore=False), modify goal generator to draw the values of these two slots from database

* fix bug (#35)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* fix bug #34

* revert changes

* update demo video link

* Update README.md

* some changes in #36 (#37)

* multiwoz_zh

* crosswoz_en

* translation train

* test translation train

* update evaluation code

* update evaluation code for crosswoz

* evaluate human val set

* update readme

* evaluate machine val

* extract all ontology, bad result

* update evalutate

* update evalutation result on crosswoz-en

* fix bug #34

* revert changes

* revert changes

* some changes of #36

* fix analyzer example.py

* dst/evaluate.py: Use utf-8 encoding

* use transformers library to automate model caching

* Update README.md

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* fix nlu max len

* update travis

* Update run_agent.py

* Create README.md

* Update README.md

* modify human_eval README

* fix sclstm crosswoz import issues

* update travis.yml

* try to fix deploy

* Update README.md

* Update README.md

* Improve agenda policy (#52)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* Update README.md

#53

* Update README.md (#57)

* Improve agenda policy (#60)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* fix self.cur_domain=None when system offer book

* Improve agenda policy (#62)

* cut sentences that exceed 512 tokens in jointBERT

* Notice: The results are for commits before bdc9dba (inclusive). We will update the results after improving user policy.

* improve agenda police #31, the order of NLG could be more detailed in TemplateNLG:sorted_dialog_act

* improve goal sample strategy

* fix self.cur_domain=None when system offer book

* fix agenda for 0 choice

* fix sequicityy

* fix sequicityy

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* update setup.py:add tokenizers requirement

* fix typo

* update user nlg template

* Update README.md

* remove fail book in multiwoz goal generator

* fix taxi dontcare problem

* can manually set user goal in agenda now

* test goal overlap between generator and trainset

* change default taxi depart and destination from address to name/'the hotel/restaurant'

* change initiative from 4 to randint(2,4)

* agenda pop more da when only answer dontcare

* add 'the same area/pricerange/people/day' in agenda with 0.3 probability

* remove unnecessary thank you

* add domain for postcode and Phone in user templateNLG

* add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy

* add template for interent-no, parking-no in templatenlg

* update Evaluator: check whether final goal satisfies constraints

* update evaluator: check booked entity

* output goal analysis to file

* update goal analysis

* update

* Update analyzer.py

* Fix simulator (#83)

* remove fail book in multiwoz goal generator

* fix taxi dontcare problem

* can manually set user goal in agenda now

* test goal overlap between generator and trainset

* change default taxi depart and destination from address to name/'the hotel/restaurant'

* change initiative from 4 to randint(2,4)

* agenda pop more da when only answer dontcare

* add 'the same area/pricerange/people/day' in agenda with 0.3 probability

* remove unnecessary thank you

* add domain for postcode and Phone in user templateNLG

* add **kwargs in init_session for self-defined goal; remove request for nooffer-slot in rule-sys-policy

* add template for interent-no, parking-no in templatenlg

* remove police and hospital domain in goal generator

* update multiwoz evaluator: adding 'internet/parking-none, 24:** to valid value

* fix nlg template (#88)

* add new_goal_model without police and hospital domain (#89)

* Normalize string comparisons in multiwoz template nlg to be case insensitive (#87)

* normalize template nlg keys to be lower case

* fix slot comparison in multiwoz nlg to be case insensitive

* use value_lower instead of calling .lower() on each comparison

* Add police n hospital (#95)

* add back police and hospital goal

* update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result

* update dbquery and session (#99)

* update dbquery: ? matches all; fix bug in init_session

* update multiwoz_eval, check Ref of booked

* filter domain in final_goal_analyze

Co-authored-by: newRuntieException <wdz15@mails.tsinghua.edu.cn>

* Add dockerfile (#98)

* fix nlg template

* add dockerfile

* include missing packages at setup.py (#102)

* multiwoz dbquery doesnt require mutable constraints (#106)

* Add police n hospital (#107)

* add back police and hospital goal

* update police db:add postcode; update hospital db:add address and postcode; update dbquery: query hospital with department, deepcopy query result

* update user templatenlg

* add test set example for dstc9 (multiwoz_zh, crosswoz_en) (#108)

* Add dockerfile (#110)

* fix nlg template

* add dockerfile

* add package for dockerfile

* update versions

* Update README.md

* Update versions in setup (#111)

* move dbquery change from master to dev branch

* add alias center for centre in dbquery

* fix sequicityy

* update versions

Co-authored-by: zqwerty <zhuq96@hotmail.com>
Co-authored-by: zhuqi <zqwerty@users.noreply.github.com>

* Update README.md

* Update README.md

* Update README.md

* fix system nlg template bug (#117)

* add 'book' in DST evaluation. (#85)

* Maintenance (#119)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* dstc9 eval

* dstc9 xldst evaluation

* Nlg template fix (#121)

* fix nlg template

* fix user nlg template issue

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* dstc9 xldst evaluation (#122)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* revise xldst evaluation (#124)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* Update dst.py

* update precision, recall, f1 calculation

* minor change

* fix policy evaluation

* Nlg template fix (#127)

* fix nlg template

* fix user nlg template issue

* fix system NLG template

* nlu update and bugfix (#118)

* jointBERT_new avaliable && fix milu dataset_reader && fix jointBERT/tag2id

* remove jointBERT_new

* update milu/multiwoz/nlu.py model_file path

* add metrics in XLDST evaluation (#126)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* add input reqt vals in human eval (#128)

* Maintenance (#129)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* Human (#131)

* change task config

* add final goal logging

* encapsule PipelineAgent internal state interface for return and
replacement

* Maintenance (#132)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix a database typo

* Maintenance (#134)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix goal generator for police domain message

* fix a minor typo in crosswoz database (#133)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* remove low performance baselines (#136)

* Human2 (#137)

* change task config

* add final goal logging

* encapsule PipelineAgent internal state interface for return and
replacement

* fix bug associted with the issue of strange user input

* Fix a bug in TRADE CrossWOZ training (#138)

* add 'book' in DST evaluation.

* Fix TRADE crosswoz training evaluation bug

Co-authored-by: zheng <zheng@zhangzheng-PC.lan>

* Maintenance (#140)

* add test set example for dstc9 (multiwoz_zh, crosswoz_en)

* update new_goal_model.pkl

* update crosswoz auto_sys_template_nlg

* add postcode as special case for NLU tokenization

* fix lower case for int value in nlg.py

* fix empty user utterance problem in multiwoz simulator, issue #130

* remove debug output

* fix goal generator for police domain message

* update template NLG

* Add note for deploy web service (#139)

* add 'book' in DST evaluation.

* Fix TRADE crosswoz training evaluation bug

* Add note on deploy

Co-authored-by: zheng <zheng@zhangzheng-PC.lan>

* add value unification

* fix XLDST evaluation (#141)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification

* fix user Nlg template (#142)

* fix system nlg template bug

* fix user nlg issue

* fix white character issue #144

* deal with white charater in XLDST evaluation (#145)

* update sumbt translation train result with evaluation mode set

* update extract values

* automatically download sumbt model

* dstc9 eval

* dstc9 xldst evaluation

* modify example

* add .gitignore

* remove precision, recall, f1

* release 250 test data

* revise evaluation

* fix file submission example

* update precision, recall, f1 calculation

* minor change

* fix a database typo

* use selectedResults for missing name

* add value unification

* fix white character issue #144

* DQN (#113)

* implemented script to extract all the statistics for all dialogue_act in data

* changed script for actions be compatible to sys_da_voc.txt actions

* multiwoz vector now supports composite actions

* implemented ReplayMemory and EpsilongGreedyPolicy

* implemented a basic version of dqn

* included some comments

* Add DQN Test and Change file structure  (#146)

* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

* Update server.py

* add test for DQN

* change server

Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>

* update eval

* dump dst eval results

* make value lower

* add progress bar

* fix bug in last commit

* Update policy_agenda_multiwoz.py

* remove unnecessary mapping (#147)

* udpate dstc9 eval

* make value lower

* add warm up for dqn and fix bugs

* rm unrelated files

Co-authored-by: zhuqi <zqwerty@users.noreply.github.com>
Co-authored-by: zqwerty <zhuq96@hotmail.com>
Co-authored-by: Ryuichi Takanobu <truthless11@gmail.com>
Co-authored-by: newRuntieException <wdz15@mails.tsinghua.edu.cn>
Co-authored-by: liangrz <liangrz15@mails.tsinghua.edu.cn>
Co-authored-by: Carrey Wang <cwhongru@cuc.edu.cn>
Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
Co-authored-by: 罗崚骁 <function2@qq.com>
Co-authored-by: mehrad <mehrad@stanford.edu>
Co-authored-by: pengbaolin <39398162+pengbaolin@users.noreply.github.com>
Co-authored-by: Jinchao Li <38700695+jincli@users.noreply.github.com>
Co-authored-by: Shahin Shayandeh <shahins@microsoft.com>
Co-authored-by: aaa123git <43716234+aaa123git@users.noreply.github.com>
Co-authored-by: Bruno Eidi Nishimoto <bruno_nishimoto@hotmail.com>
Co-authored-by: Vojtěch Hudeček <vojta.hudecek@gmail.com>
Co-authored-by: zhangzthu <zhangz.goal@gmail.com>
Co-authored-by: xw <48146603+xwwwwww@users.noreply.github.com>
Co-authored-by: zheng <zheng@zhangzheng-PC.lan>
Co-authored-by: zimozhou <47972969+zimozhou@users.noreply.github.com>
Co-authored-by: MR. WANG <hrwang@kfsrv03.se.cuhk.edu.hk>
22 people authored Oct 22, 2020
1 parent 3811af8 commit e368dee
Showing 3 changed files with 103 additions and 10 deletions.
33 changes: 26 additions & 7 deletions convlab2/policy/dqn/dqn.py
@@ -11,6 +11,7 @@
 from convlab2.policy.rlmodule import EpsilonGreedyPolicy, MemoryReplay
 from convlab2.util.train_util import init_logging_handler
 from convlab2.policy.vector.vector_multiwoz import MultiWozVector
+from convlab2.policy.rule.multiwoz.rule_based_multiwoz_bot import RuleBasedMultiwozBot
 from convlab2.util.file_util import cached_path
 import zipfile
 import sys
@@ -32,6 +33,8 @@ def __init__(self, is_train=False, dataset='Multiwoz'):
         self.training_iter = cfg['training_iter']
         self.training_batch_iter = cfg['training_batch_iter']
         self.batch_size = cfg['batch_size']
+        self.epsilon = cfg['epsilon_spec']['start']
+        self.rule_bot = RuleBasedMultiwozBot()
         self.gamma = cfg['gamma']
         self.is_train = is_train
         if is_train:
@@ -58,22 +61,38 @@ def __init__(self, is_train=False, dataset='Multiwoz'):
         self.loss_fn = nn.MSELoss()

     def update_memory(self, sample):
+        self.memory.reset()
         self.memory.append(sample)

-    def predict(self, state):
+    def predict(self, state, warm_up=False):
         """
         Predict an system action given state.
         Args:
             state (dict): Dialog state. Please refer to util/state.py
         Returns:
             action : System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2, value_2, ...})
         """
-        s_vec = torch.Tensor(self.vector.state_vectorize(state))
-        a = self.net.select_action(s_vec.to(device=DEVICE))
-
-        action = self.vector.action_devectorize(a.numpy())
-
-        state['system_action'] = action
+        if warm_up:
+            action = self.rule_action(state)
+            state['system_action'] = action
+        else:
+            s_vec = torch.Tensor(self.vector.state_vectorize(state))
+            a = self.net.select_action(s_vec.to(device=DEVICE), is_train=self.is_train)
+            action = self.vector.action_devectorize(a.numpy())
+            state['system_action'] = action
         return action

+    def rule_action(self, state):
+        if self.epsilon > np.random.rand():
+            a = torch.randint(self.vector.da_dim, (1, ))
+            # transforms action index to a vector action (one-hot encoding)
+            a_vec = torch.zeros(self.vector.da_dim)
+            a_vec[a] = 1.
+            action = self.vector.action_devectorize(a_vec.numpy())
+        else:
+            # rule-based warm up
+            action = self.rule_bot.predict(state)
+
+        return action
+
     def init_session(self):
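
Review note: the new `rule_action` method mixes exploration into the warm-up phase. With probability `epsilon` it picks a uniformly random action and encodes it as a one-hot vector, otherwise it defers to the rule-based bot. The sketch below illustrates only that branching, using the standard library instead of torch. The names `warmup_action` and `rule_policy` are hypothetical, and the rule policy is assumed to return an action index, whereas the real `RuleBasedMultiwozBot.predict` returns a dialog act.

```python
import random
from typing import Callable, List


def warmup_action(state: dict,
                  rule_policy: Callable[[dict], int],
                  num_actions: int,
                  epsilon: float,
                  rng: random.Random) -> List[float]:
    """Return a one-hot action: random with probability epsilon, else the rule's choice."""
    if epsilon > rng.random():
        # Exploration branch: uniform random action index.
        index = rng.randrange(num_actions)
    else:
        # Warm-up branch: follow the (hypothetical) rule-based policy.
        index = rule_policy(state)
    one_hot = [0.0] * num_actions
    one_hot[index] = 1.0
    return one_hot


if __name__ == "__main__":
    rng = random.Random(0)
    # epsilon = 0.0 never explores, so the rule's action (index 2) is always used.
    print(warmup_action({}, lambda s: 2, 4, 0.0, rng))
```

Because `rng.random()` returns values in [0, 1), an `epsilon` of 0.0 always takes the rule branch and an `epsilon` of 1.0 always takes the random branch.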
78 changes: 76 additions & 2 deletions convlab2/policy/dqn/train.py
@@ -90,8 +90,71 @@ def sampler(pid, queue, evt, env, policy, batchsz):
     queue.put([pid, buff])
     evt.wait()

+def warmupsampler(pid, queue, evt, env, policy, batchsz):
+    """
+    This is a sampler function, and it will be called by multiprocess.Process to sample data from environment by multiple
+    processes.
+    :param pid: process id
+    :param queue: multiprocessing.Queue, to collect sampled data
+    :param evt: multiprocessing.Event, to keep the process alive
+    :param env: environment instance
+    :param policy: policy network, to generate action from current policy
+    :param batchsz: total sampled items
+    :return:
+    """
+    buff = Memory()
+
+    # we need to sample batchsz of (state, action, next_state, reward, mask)
+    # each trajectory contains `trajectory_len` num of items, so we only need to sample
+    # `batchsz//trajectory_len` num of trajectory totally
+    # the final sampled number may be larger than batchsz.
+
+    sampled_num = 0
+    sampled_traj_num = 0
+    traj_len = 50
+    real_traj_len = 0
+
+    while sampled_num < batchsz:
+        # for each trajectory, we reset the env and get initial state
+        s = env.reset()
+
+        for t in range(traj_len):
+
+            # [s_dim] => [a_dim]
+            s_vec = torch.Tensor(policy.vector.state_vectorize(s))
+            a = policy.predict(s, warm_up=True)
+
+            # interact with env
+            next_s, r, done = env.step(a)
+
+            # a flag indicates ending or not
+            mask = 0 if done else 1
+
+            # get reward compared to demostrations
+            next_s_vec = torch.Tensor(policy.vector.state_vectorize(next_s))
+
+            # save to queue
+            buff.push(s_vec.numpy(), policy.vector.action_vectorize(a), r, next_s_vec.numpy(), mask)
+
+            # update per step
+            s = next_s
+            real_traj_len = t
+
+            if done:
+                break

-def sample(env, policy, batchsz, process_num):
+        # this is end of one trajectory
+        sampled_num += real_traj_len
+        sampled_traj_num += 1
+        # t indicates the valid trajectory length
+
+    # this is end of sampling all batchsz of items.
+    # when sampling is over, push all buff data into queue
+    queue.put([pid, buff])
+    evt.wait()
+
+
+def sample(env, policy, batchsz, process_num, warm_up=False):
     """
     Given batchsz number of task, the batchsz will be splited equally to each processes
     and when processes return, it merge all data and return
@@ -119,7 +182,10 @@ def sample(env, policy, batchsz, process_num):
     processes = []
     for i in range(process_num):
         process_args = (i, queue, evt, env, policy, process_batchsz)
-        processes.append(mp.Process(target=sampler, args=process_args))
+        if warm_up:
+            processes.append(mp.Process(target=warmupsampler, args=process_args))
+        else:
+            processes.append(mp.Process(target=sampler, args=process_args))
     for p in processes:
         # set the process as daemon, and it will be killed once the main process is stoped.
         p.daemon = True
@@ -146,6 +212,13 @@ def update(env, policy, batchsz, epoch, process_num):
     policy.update(epoch)


+def warm_start(env, policy, batchsz, epoch, process_num):
+    # sample data asynchronously
+    buff = sample(env, policy, batchsz, process_num, warm_up=True)
+    policy.update_memory(buff)
+    policy.update(epoch)
+
+
 if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--load_path", type=str, default="", help="path of model to load")
@@ -170,6 +243,7 @@ def update(env, policy, batchsz, epoch, process_num):
     evaluator = MultiWozEvaluator()
     env = Environment(None, simulator, None, dst_sys, evaluator)

+    warm_start(env, policy_sys, args.batchsz, 0, args.process_num)

     for i in range(args.epoch):
         update(env, policy_sys, args.batchsz, i, args.process_num)
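
Review note: taken together with the `self.memory.reset()` call added to `update_memory`, this change runs one warm-up pass before the epoch loop, and each pass replaces the buffer contents instead of accumulating them. The sketch below shows that control flow with stand-in objects. `ReplayMemory`, `run_training`, and `collect` are hypothetical names for illustration and are not ConvLab-2 APIs, and the network update itself is left out.

```python
from typing import Callable, List, Tuple


class ReplayMemory:
    """Minimal stand-in for a replay buffer with reset/append semantics."""

    def __init__(self) -> None:
        self.items: List[str] = []

    def reset(self) -> None:
        self.items = []

    def append(self, batch: List[str]) -> None:
        self.items.extend(batch)


def run_training(collect: Callable[[bool], List[str]],
                 memory: ReplayMemory,
                 epochs: int) -> List[Tuple[str, int]]:
    """One warm-up pass, then `epochs` regular passes; returns (phase, buffer size) per pass."""
    log: List[Tuple[str, int]] = []

    # Warm-up pass: transitions come from the rule-driven sampler.
    memory.reset()
    memory.append(collect(True))
    log.append(("warm_up", len(memory.items)))

    # Regular passes: transitions come from the learned policy.
    for epoch in range(epochs):
        memory.reset()  # mirrors update_memory clearing the buffer first
        memory.append(collect(False))
        log.append((f"epoch_{epoch}", len(memory.items)))
    return log


if __name__ == "__main__":
    memory = ReplayMemory()
    fake_collect = lambda warm_up: ["rule"] * 3 if warm_up else ["net"] * 2
    print(run_training(fake_collect, memory, epochs=2))
```

One consequence visible in the sketch is that warm-up transitions are only used for the first update: once the first regular epoch runs, the buffer holds learned-policy samples only.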
2 changes: 1 addition & 1 deletion convlab2/policy/hdsa/multiwoz/transformer/Beam.py
@@ -66,7 +66,7 @@ def advance(self, word_prob):

         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
-        prev_k = best_scores_id / num_words
+        prev_k = best_scores_id // num_words
         self.prev_ks.append(prev_k)
         self.next_ys.append(best_scores_id - prev_k * num_words)

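
Review note: the Beam.py fix matters because, in Python 3, `/` is true division and yields a fractional beam index, while `//` floors to the integer index the flattened (beam x word) layout requires. The sketch below reproduces the arithmetic on plain integers. It assumes the flat index is `beam * num_words + word`, as the code comment describes, and `split_flat_index` is a hypothetical helper, not part of Beam.py.

```python
from typing import Tuple


def split_flat_index(flat_id: int, num_words: int) -> Tuple[int, int]:
    """Recover (beam index, word index) from an index into a flattened beam x word array."""
    beam = flat_id // num_words          # floor division keeps the beam index integral
    word = flat_id - beam * num_words    # remainder is the word index within that beam
    return beam, word


if __name__ == "__main__":
    # With 5 words per beam, flat index 7 is beam 1, word 2.
    print(split_flat_index(7, 5))
    # True division would have produced a fractional "beam index" instead.
    print(7 / 5)
```

With true division the word index computed on the next line of `advance` would collapse to zero (up to rounding) for every entry, which is why the one-character change affects both `prev_ks` and `next_ys`.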
