diff --git a/.gitignore b/.gitignore index bdb4d680..2823cc38 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ convlab2/nlu/jointBERT/**/output/ convlab2/dst/sumbt/multiwoz/output/ convlab2/nlg/sclstm/**/generated_sens_sys.json convlab2/nlg/template/**/generated_sens_sys.json +convlab2/nlu/jointBERT/crosswoz/**/data # test script *_test.py diff --git a/convlab2/nlu/jointBERT/crosswoz/nlu.py b/convlab2/nlu/jointBERT/crosswoz/nlu.py index d9f7e4ca..594f7285 100755 --- a/convlab2/nlu/jointBERT/crosswoz/nlu.py +++ b/convlab2/nlu/jointBERT/crosswoz/nlu.py @@ -17,7 +17,8 @@ def __init__(self, mode='all', config_file='crosswoz_all_context.json', assert mode == 'usr' or mode == 'sys' or mode == 'all' config_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs/{}'.format(config_file)) config = json.load(open(config_file)) - DEVICE = config['DEVICE'] + # DEVICE = config['DEVICE'] + DEVICE = 'cpu' if not torch.cuda.is_available() else config['DEVICE'] root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) data_dir = os.path.join(root_dir, config['data_dir']) output_dir = os.path.join(root_dir, config['output_dir']) diff --git a/convlab2/policy/mle/crosswoz/mle.py b/convlab2/policy/mle/crosswoz/mle.py index abb12493..ef65fd72 100755 --- a/convlab2/policy/mle/crosswoz/mle.py +++ b/convlab2/policy/mle/crosswoz/mle.py @@ -41,4 +41,4 @@ def __init__(self, if not os.path.exists(os.path.join(model_dir, 'best_mle.pol.mdl')): archive = zipfile.ZipFile(archive_file, 'r') archive.extractall(model_dir) - self.load(archive_file, model_file, cfg['load']) + self.load_from_pretrained(archive_file, model_file, cfg['load']) diff --git a/convlab2/policy/mle/mle.py b/convlab2/policy/mle/mle.py index 21d6436c..54e2f026 100755 --- a/convlab2/policy/mle/mle.py +++ b/convlab2/policy/mle/mle.py @@ -25,7 +25,7 @@ def predict(self, state): """ s_vec = torch.Tensor(self.vector.state_vectorize(state)) a = self.policy.select_action(s_vec.to(device=DEVICE), False).cpu() - action = self.vector.action_devectorize(a.numpy()) + action = self.vector.action_devectorize(a.detach().numpy()) state['system_action'] = action return action diff --git a/convlab2/policy/pg/pg.py b/convlab2/policy/pg/pg.py index 82df69a8..c605c808 100755 --- a/convlab2/policy/pg/pg.py +++ b/convlab2/policy/pg/pg.py @@ -126,7 +126,6 @@ def update(self, epoch, batchsz, s, a, r, mask): # backprop surrogate.backward() - for p in self.policy.parameters(): p.grad[p.grad != p.grad] = 0.0 # gradient clipping, for stability diff --git a/deploy/dep_config.json b/deploy/dep_config.json index b9c478ba..4afd8884 100755 --- a/deploy/dep_config.json +++ b/deploy/dep_config.json @@ -27,6 +27,19 @@ "preload": false, "enable": true }, + "bert-cro": { + "class_path": "convlab2.nlu.jointBERT.crosswoz.nlu.BERTNLU", + "data_set": "crosswoz", + "ini_params": { + "mode": "all", + "config_file": "crosswoz_all.json", + "model_file": "https://convlab.blob.core.windows.net/convlab-2/bert_crosswoz_all.zip" + }, + "model_name": "bert-cro", + "max_core": 1, + "preload": false, + "enable": true + }, "bert-mul": { "class_path": "convlab2.nlu.jointBERT.multiwoz.nlu.BERTNLU", "data_set": "multiwoz", @@ -60,6 +73,15 @@ "preload": true, "enable": true }, + "rule-cro": { + "class_path": "convlab2.dst.rule.crosswoz.dst.RuleDST", + "data_set": "crosswoz", + "ini_params": {}, + "model_name": "rule-cro", + "max_core": 1, + "preload": true, + "enable": true + }, "trade-mul": { "class_path": "convlab2.dst.trade.multiwoz.trade.MultiWOZTRADE", "data_set": "multiwoz", @@ -106,6 +128,15 @@ "max_core": 1, "preload": true, "enable": true + }, + "mle-cro": { + "class_path": "convlab2.policy.mle.crosswoz.mle.MLE", + "data_set": "crosswoz", + "ini_params": {}, + "model_name": "mle-cro", + "max_core": 1, + "preload": false, + "enable": true } }, "nlg": { @@ -143,6 +174,18 @@ "max_core": 1, "preload": true, "enable": true + }, + "tmp-auto_manual-cro": { + "class_path": "convlab2.nlg.template.crosswoz.nlg.TemplateNLG", + "data_set": "crosswoz", + "ini_params": { + "is_user": false, + "mode": "auto_manual" + }, + "model_name": "tmp-auto_manual-cro", + "max_core": 1, + "preload": true, + "enable": true } } } \ No newline at end of file diff --git a/deploy/templates/dialog.html b/deploy/templates/dialog.html index d5218094..ac3e15c9 100755 --- a/deploy/templates/dialog.html +++ b/deploy/templates/dialog.html @@ -279,7 +279,7 @@ data: { dataset: 'MultiWoz', dataset_short: 'mul', - dataset_list: ['MultiWoz'], + dataset_list: ['MultiWoz', 'CrossWoz'], nlu: 'BERTNLU', nlu_list: [], nlu_output: {}, @@ -317,6 +317,8 @@ dataset: function() { if (this.dataset === 'MultiWoz') { this.dataset_short = 'mul' + } else if (this.dataset == 'CrossWoz') { + this.dataset_short = 'cro' } else { this.dataset_short = 'cam' }