Skip to content

Commit

Permalink
Add CrossWoz Web support and some minor bug fix (thu-coai#19)
Browse files Browse the repository at this point in the history
* Initial commit

* first commit

* add build

* add build

* add build

* add recommend

* add crosswoz config in deploy

* add crosswoz at html

* debug chinese vision

* fix system bug according to convlab2

* master change

* modify .gitignore

* delete svm_camrest_usr.pickle

Co-authored-by: kflab_2018 <kflab_2018@kflab-2018s-MacBook-Air.local>
Co-authored-by: CarreyWong <carreywong@CarreyWongs-MacBook-Pro.local>
  • Loading branch information
3 people authored Jun 15, 2020
1 parent 9f8d8ec commit 1f1b919
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ convlab2/nlu/jointBERT/**/output/
convlab2/dst/sumbt/multiwoz/output/
convlab2/nlg/sclstm/**/generated_sens_sys.json
convlab2/nlg/template/**/generated_sens_sys.json
convlab2/nlu/jointBERT/crosswoz/**/data
# test script
*_test.py

Expand Down
3 changes: 2 additions & 1 deletion convlab2/nlu/jointBERT/crosswoz/nlu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def __init__(self, mode='all', config_file='crosswoz_all_context.json',
assert mode == 'usr' or mode == 'sys' or mode == 'all'
config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file))
config = json.load(open(config_file))
DEVICE = config['DEVICE']
# DEVICE = config['DEVICE']
DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE']
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
data_dir = os.path.join(root_dir, config['data_dir'])
output_dir = os.path.join(root_dir, config['output_dir'])
Expand Down
2 changes: 1 addition & 1 deletion convlab2/policy/mle/crosswoz/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def __init__(self,
if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')):
archive = zipfile.ZipFile(archive_file, 'r')
archive.extractall(model_dir)
self.load(archive_file, model_file, cfg['load'])
self.load_from_pretrained(archive_file, model_file, cfg['load'])
2 changes: 1 addition & 1 deletion convlab2/policy/mle/mle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def predict(self, state):
"""
s_vec = torch.Tensor(self.vector.state_vectorize(state))
a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu()
action = self.vector.action_devectorize(a.numpy())
action = self.vector.action_devectorize(a.detach().numpy())
state['system_action'] = action
return action

Expand Down
1 change: 0 additions & 1 deletion convlab2/policy/pg/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def update(self, epoch, batchsz, s, a, r, mask):

# backprop
surrogate.backward()

for p in self.policy.parameters():
p.grad[p.grad != p.grad] = 0.0
# gradient clipping, for stability
Expand Down
43 changes: 43 additions & 0 deletions deploy/dep_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@
"preload": false,
"enable": true
},
"bert-cro": {
"class_path": "convlab2.nlu.jointBERT.crosswoz.nlu.BERTNLU",
"data_set": "crosswoz",
"ini_params": {
"mode": "all",
"config_file": "crosswoz_all.json",
"model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_crosswoz_all.zip"
},
"model_name": "bert-cro",
"max_core": 1,
"preload": false,
"enable": true
},
"bert-mul": {
"class_path": "convlab2.nlu.jointBERT.multiwoz.nlu.BERTNLU",
"data_set": "multiwoz",
Expand Down Expand Up @@ -60,6 +73,15 @@
"preload": true,
"enable": true
},
"rule-cro": {
"class_path": "convlab2.dst.rule.crosswoz.dst.RuleDST",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "rule-cro",
"max_core": 1,
"preload": true,
"enable": true
},
"trade-mul": {
"class_path": "convlab2.dst.trade.multiwoz.trade.MultiWOZTRADE",
"data_set": "multiwoz",
Expand Down Expand Up @@ -106,6 +128,15 @@
"max_core": 1,
"preload": true,
"enable": true
},
"mle-cro": {
"class_path": "convlab2.policy.mle.crosswoz.mle.MLE",
"data_set": "crosswoz",
"ini_params": {},
"model_name": "mle-cro",
"max_core": 1,
"preload": false,
"enable": true
}
},
"nlg": {
Expand Down Expand Up @@ -143,6 +174,18 @@
"max_core": 1,
"preload": true,
"enable": true
},
"tmp-auto_manual-cro": {
"class_path": "convlab2.nlg.template.crosswoz.nlg.TemplateNLG",
"data_set": "crosswoz",
"ini_params": {
"is_user": false,
"mode": "auto_manual"
},
"model_name": "tmp-auto_manual-cro",
"max_core": 1,
"preload": true,
"enable": true
}
}
}
4 changes: 3 additions & 1 deletion deploy/templates/dialog.html
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@
data: {
dataset: 'MultiWoz',
dataset_short: 'mul',
dataset_list: ['MultiWoz'],
dataset_list: ['MultiWoz', 'CrossWoz'],
nlu: 'BERTNLU',
nlu_list: [],
nlu_output: {},
Expand Down Expand Up @@ -317,6 +317,8 @@
dataset: function() {
if (this.dataset === 'MultiWoz') {
this.dataset_short = 'mul'
} else if (this.dataset == 'CrossWoz') {
this.dataset_short = 'cro'
} else {
this.dataset_short = 'cam'
}
Expand Down

0 comments on commit 1f1b919

Please sign in to comment.